
chnsh / Dcrnn_pytorch

License: MIT
Diffusion Convolutional Recurrent Neural Network Implementation in PyTorch


Diffusion Convolutional Recurrent Neural Network: Data-Driven Traffic Forecasting

Diffusion Convolutional Recurrent Neural Network

This is a PyTorch implementation of the Diffusion Convolutional Recurrent Neural Network (DCRNN) described in the following paper:
Yaguang Li, Rose Yu, Cyrus Shahabi, Yan Liu, Diffusion Convolutional Recurrent Neural Network: Data-Driven Traffic Forecasting, ICLR 2018.

Requirements

  • torch
  • scipy>=0.19.0
  • numpy>=1.12.1
  • pandas>=0.19.2
  • pyyaml
  • statsmodels
  • tensorflow>=1.3.0
  • tables
  • future

The dependencies can be installed with the following command:

pip install -r requirements.txt

Comparison with the TensorFlow implementation

MAE on the METR-LA dataset, lower is better (PEMS-BAY results to follow):

Horizon    TensorFlow    PyTorch
1 Hour     3.69          3.12
30 Min     3.15          2.82
15 Min     2.77          2.56
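MAE is the mean of |prediction - ground truth| over all predicted values. Traffic datasets such as METR-LA commonly store missing sensor readings as 0, so evaluations often mask those entries out. The sketch below assumes that convention (the helper name `masked_mae` and the `null_val=0.0` default are assumptions); it is not necessarily the exact metric code used to produce the table above.

```python
import numpy as np

def masked_mae(y_pred, y_true, null_val=0.0):
    """Mean absolute error over entries where y_true is not a missing reading.

    Assumption: missing readings are encoded as `null_val` (0.0 by default).
    """
    y_pred = np.asarray(y_pred, dtype=float)
    y_true = np.asarray(y_true, dtype=float)
    mask = y_true != null_val  # True where a real reading exists
    if not mask.any():
        return 0.0  # nothing to evaluate
    return float(np.abs(y_pred[mask] - y_true[mask]).mean())

if __name__ == "__main__":
    # The third target is "missing" (0.0), so only the first two errors count.
    print(masked_mae([60.0, 62.0, 5.0], [61.0, 64.0, 0.0]))
```

Without the mask, the zero-valued target would add a spurious error of 5.0 and raise the average.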

Data Preparation

The traffic data files for Los Angeles (METR-LA) and the Bay Area (PEMS-BAY), i.e., metr-la.h5 and pems-bay.h5, are available at Google Drive or Baidu Yun and should be placed in the data/ folder. The *.h5 files store the data as a pandas.DataFrame in the HDF5 file format. Here is an example:

                     sensor_0  sensor_1  sensor_2  sensor_n
2018/01/01 00:00:00  60.0      65.0      70.0      ...
2018/01/01 00:05:00  61.0      64.0      65.0      ...
2018/01/01 00:10:00  63.0      65.0      60.0      ...
...                  ...       ...       ...       ...
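The example above is a wide DataFrame: a timestamp index at 5-minute steps and one column per sensor. A minimal sketch for inspecting that layout follows. It rebuilds the three example rows in memory so it runs without the data files; reading the real file with `pd.read_hdf` is shown commented out and needs the `tables` package from the requirements. The helper names are hypothetical.

```python
import pandas as pd

def make_example_frame():
    """Rebuild the example above: 5-minute DatetimeIndex, one column per sensor."""
    index = pd.date_range("2018-01-01 00:00:00", periods=3, freq="5min")
    data = {
        "sensor_0": [60.0, 61.0, 63.0],
        "sensor_1": [65.0, 64.0, 65.0],
        "sensor_2": [70.0, 65.0, 60.0],
    }
    return pd.DataFrame(data, index=index)

def describe_traffic_df(df):
    """Return (num_timesteps, num_sensors, sampling interval) of a traffic frame."""
    interval = df.index[1] - df.index[0]
    return df.shape[0], df.shape[1], interval

if __name__ == "__main__":
    # With the real file in data/ (requires the `tables` package):
    # df = pd.read_hdf("data/metr-la.h5")
    df = make_example_frame()
    print(describe_traffic_df(df))
```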

Here is an article about Using HDF5 with Python.

Run the following commands to generate the train/val/test datasets at data/{METR-LA,PEMS-BAY}/{train,val,test}.npz.

# Create data directories
mkdir -p data/{METR-LA,PEMS-BAY}

# METR-LA
python -m scripts.generate_training_data --output_dir=data/METR-LA --traffic_df_filename=data/metr-la.h5

# PEMS-BAY
python -m scripts.generate_training_data --output_dir=data/PEMS-BAY --traffic_df_filename=data/pems-bay.h5
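The general idea behind this step is a sliding window over the time axis. The sketch below illustrates it under stated assumptions: 12 past steps (one hour at 5-minute resolution) as input, 12 future steps as targets, and a chronological 70/10/20 train/val/test split. The window sizes, split ratios and helper names are assumptions for illustration, and the actual script may differ in details such as adding time-of-day features.

```python
import numpy as np

def make_windows(data, x_offsets, y_offsets):
    """Slice a (num_timesteps, num_nodes, num_features) array into samples.

    For each anchor time t: x = data[t + x_offsets] (past and present),
    y = data[t + y_offsets] (future).
    """
    num_timesteps = data.shape[0]
    min_t = abs(min(x_offsets))                 # first t with a full input window
    max_t = num_timesteps - max(y_offsets)      # exclusive bound for a full target window
    xs, ys = [], []
    for t in range(min_t, max_t):
        xs.append(data[t + x_offsets])
        ys.append(data[t + y_offsets])
    return np.stack(xs), np.stack(ys)

def split_train_val_test(x, y, train_ratio=0.7, test_ratio=0.2):
    """Chronological split; validation takes the samples between train and test."""
    n = x.shape[0]
    num_test = round(n * test_ratio)
    num_train = round(n * train_ratio)
    num_val = n - num_test - num_train
    train = (x[:num_train], y[:num_train])
    val = (x[num_train:num_train + num_val], y[num_train:num_train + num_val])
    test = (x[n - num_test:], y[n - num_test:])
    return train, val, test

if __name__ == "__main__":
    data = np.random.rand(100, 5, 1)  # 100 steps, 5 sensors, 1 feature (synthetic)
    x, y = make_windows(data, np.arange(-11, 1), np.arange(1, 13))
    train, val, test = split_train_val_test(x, y)
    print(x.shape, train[0].shape, val[0].shape, test[0].shape)
```

Each split could then be written out with `np.savez_compressed`; the key names the training code expects are not shown here.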

Graph Construction

As the current implementation is based on pre-calculated road network distances between sensors, it only supports sensor IDs in Los Angeles (see data/sensor_graph/sensor_info_201206.csv).

python -m scripts.gen_adj_mx  --sensor_ids_filename=data/sensor_graph/graph_sensor_ids.txt --normalized_k=0.1\
    --output_pkl_filename=data/sensor_graph/adj_mx.pkl
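The DCRNN paper builds the weighted sensor graph with a thresholded Gaussian kernel over road network distances: W_ij = exp(-(dist_ij / sigma)^2), where sigma is the standard deviation of the distances, and small weights are dropped. The sketch below shows that construction, assuming `--normalized_k` is the cut-off below which weights are set to zero and that unreachable pairs are given as infinity. It illustrates the technique and is not the script itself.

```python
import numpy as np

def gaussian_kernel_adjacency(dist_mx, normalized_k=0.1):
    """Thresholded Gaussian kernel adjacency from a pairwise distance matrix.

    W_ij = exp(-(dist_ij / sigma)^2), sigma = std of the finite distances.
    Weights below `normalized_k` are zeroed to keep the graph sparse.
    Unreachable pairs (np.inf) get weight exp(-inf) = 0.
    """
    dist_mx = np.asarray(dist_mx, dtype=float)
    sigma = dist_mx[np.isfinite(dist_mx)].std()
    adj = np.exp(-np.square(dist_mx / sigma))
    adj[adj < normalized_k] = 0.0
    return adj

if __name__ == "__main__":
    inf = np.inf
    dist = [[0.0, 1.0, inf],
            [1.0, 0.0, 3.0],
            [inf, 3.0, 0.0]]
    print(gaussian_kernel_adjacency(dist, normalized_k=0.1))
```

Raising `normalized_k` removes more of the weak, long-distance edges.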

In addition, the locations of the sensors in Los Angeles (METR-LA) are available at data/sensor_graph/graph_sensor_locations.csv.

Run the Pre-trained Model on METR-LA and PEMS-BAY

# METR-LA
python run_demo_pytorch.py --config_filename=data/model/pretrained/METR-LA/config.yaml

# PEMS-BAY
python run_demo_pytorch.py --config_filename=data/model/pretrained/PEMS-BAY/config.yaml

The predictions generated by DCRNN are saved in data/results/dcrnn_predictions.

Model Training

# METR-LA
python dcrnn_train_pytorch.py --config_filename=data/model/dcrnn_la.yaml

# PEMS-BAY
python dcrnn_train_pytorch.py --config_filename=data/model/dcrnn_bay.yaml

The training loss may occasionally explode. A temporary workaround is to restart from the last model saved before the explosion, or to decay the learning rate earlier in the learning rate schedule.
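Decaying the learning rate earlier means moving the decay milestones to earlier epochs. The sketch below computes a step-decay schedule with the same rule as PyTorch's `torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=..., gamma=...)`: the rate is multiplied by the decay ratio once for every milestone reached. The base rate, milestones and decay ratio used here are hypothetical; check the YAML config for the values actually used in training.

```python
import bisect

def step_decay_lr(epoch, base_lr, milestones, decay_ratio):
    """Learning rate at `epoch`: base_lr * decay_ratio ** (milestones reached).

    A milestone counts as reached once epoch >= milestone, matching MultiStepLR.
    """
    num_decays = bisect.bisect_right(sorted(milestones), epoch)
    return base_lr * decay_ratio ** num_decays

if __name__ == "__main__":
    original = [20, 30, 40, 50]  # hypothetical milestone epochs
    earlier = [10, 20, 30, 40]   # the same schedule shifted 10 epochs earlier
    for epoch in (0, 10, 15, 20, 35):
        print(epoch,
              step_decay_lr(epoch, 0.01, original, 0.1),
              step_decay_lr(epoch, 0.01, earlier, 0.1))
```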

Evaluate Baseline Methods

# METR-LA
python -m scripts.eval_baseline_methods --traffic_reading_filename=data/metr-la.h5
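One of the baselines the DCRNN paper compares against is the historical average (HA), which predicts each step from past readings taken at the same point in a recurring period. The sketch below is an illustration of that idea and not the script's implementation. It assumes the last 20% of rows form the test set, that only the training portion supplies the history, and that readings equal to 0 are missing. For a weekly pattern at 5-minute resolution the period would be 12 * 24 * 7 = 2016 steps.

```python
import numpy as np
import pandas as pd

def historical_average_predict(df, period, test_ratio=0.2, null_val=0.0):
    """Historical-average forecast for the last `test_ratio` of `df`.

    Each test step is predicted, per sensor, as the mean of the training
    readings with the same phase within `period`. Readings equal to
    `null_val` are ignored; sensors with no valid history are predicted as 0.
    """
    values = df.to_numpy(dtype=float)
    n = values.shape[0]
    num_test = int(round(n * test_ratio))
    num_train = n - num_test
    preds = np.zeros((num_test, values.shape[1]))
    for i, t in enumerate(range(num_train, n)):
        # Rows with the same phase as t, restricted to the training portion.
        history = values[t % period:num_train:period]
        mask = history != null_val
        counts = mask.sum(axis=0)
        sums = np.where(mask, history, 0.0).sum(axis=0)
        preds[i] = np.divide(sums, counts, out=np.zeros_like(sums), where=counts > 0)
    return pd.DataFrame(preds, index=df.index[num_train:], columns=df.columns)

if __name__ == "__main__":
    idx = pd.date_range("2018-01-01", periods=12, freq="5min")
    df = pd.DataFrame({"sensor_0": [10.0, 20.0, 30.0, 40.0] * 3}, index=idx)
    print(historical_average_predict(df, period=4, test_ratio=0.25))
```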

PyTorch Results

Citation

If you find this repository, e.g., the code and the datasets, useful in your research, please cite the following paper:

@inproceedings{li2018dcrnn_traffic,
  title={Diffusion Convolutional Recurrent Neural Network: Data-Driven Traffic Forecasting},
  author={Li, Yaguang and Yu, Rose and Shahabi, Cyrus and Liu, Yan},
  booktitle={International Conference on Learning Representations (ICLR '18)},
  year={2018}
}