
khirotaka / SAnD

License: MIT License
[Implementation example] Attend and Diagnose: Clinical Time Series Analysis Using Attention Models

Programming Languages

python
Dockerfile

Projects that are alternatives of or similar to SAnD

pairs trading cryptocurrencies strategy catalyst
Pairs trading strategy example based on Catalyst
Stars: ✭ 34 (-12.82%)
Mutual labels:  time-series
Deep-Signature-Transforms
Code for "Deep Signature Transforms" (NeurIPS 2019)
Stars: ✭ 65 (+66.67%)
Mutual labels:  time-series
time series clustering via community detection
Code used in the paper "Time Series Clustering via Community Detection in Networks"
Stars: ✭ 27 (-30.77%)
Mutual labels:  time-series
tscompdata
Time series competition data
Stars: ✭ 17 (-56.41%)
Mutual labels:  time-series
Fred
A fast, scalable and light-weight C++ Fréchet distance library, exposed to python and focused on (k,l)-clustering of polygonal curves.
Stars: ✭ 13 (-66.67%)
Mutual labels:  time-series
gpu accelerated forecasting modeltime gluonts
GPU-Accelerated Deep Learning for Time Series using Modeltime GluonTS (Learning Lab 53). Event sponsors: Saturn Cloud, NVIDIA, & Business Science.
Stars: ✭ 20 (-48.72%)
Mutual labels:  time-series
support resistance line
A well-tuned algorithm to automatically generate & draw support/resistance lines on a time series.
Stars: ✭ 53 (+35.9%)
Mutual labels:  time-series
questdb.io
The official QuestDB website, database documentation and blog.
Stars: ✭ 75 (+92.31%)
Mutual labels:  time-series
MogrifierLSTM
A quick walk-through of the innards of LSTMs and a naive implementation of the Mogrifier LSTM paper in PyTorch
Stars: ✭ 58 (+48.72%)
Mutual labels:  paper-implementations
dts
A Keras library for multi-step time-series forecasting.
Stars: ✭ 130 (+233.33%)
Mutual labels:  time-series
SMC.jl
Sequential Monte Carlo algorithm for approximation of posterior distributions.
Stars: ✭ 53 (+35.9%)
Mutual labels:  time-series
NALU-Keras
A Keras implementation of [Neural Arithmetic Logic Units](https://arxiv.org/pdf/1808.00508.pdf) by Trask et al.
Stars: ✭ 14 (-64.1%)
Mutual labels:  paper-implementations
mvts-ano-eval
A repository for code accompanying the manuscript 'An Evaluation of Anomaly Detection and Diagnosis in Multivariate Time Series' (published at TNNLS)
Stars: ✭ 26 (-33.33%)
Mutual labels:  time-series
awesome-time-series
Resources for working with time series and sequence data
Stars: ✭ 178 (+356.41%)
Mutual labels:  time-series
Awesome CV Research
No description or website provided.
Stars: ✭ 18 (-53.85%)
Mutual labels:  paper-implementations
Robust-Deep-Learning-Pipeline
Deep Convolutional Bidirectional LSTM for Complex Activity Recognition with Missing Data. Human Activity Recognition Challenge. Springer SIST (2020)
Stars: ✭ 20 (-48.72%)
Mutual labels:  time-series
ts-forecasting-ensemble
CentOS based Docker container for Time Series Analysis and Modeling.
Stars: ✭ 19 (-51.28%)
Mutual labels:  time-series
barrage
Barrage is an opinionated supervised deep learning tool built on top of TensorFlow 2.x designed to standardize and orchestrate the training and scoring of complicated models.
Stars: ✭ 16 (-58.97%)
Mutual labels:  time-series
Time-Series-Transformer
A data preprocessing package for time series data. Design for machine learning and deep learning.
Stars: ✭ 123 (+215.38%)
Mutual labels:  time-series
gmwm
Generalized Method of Wavelet Moments (GMWM) is an estimation technique for the parameters of time series models. It uses the wavelet variance in a moment matching approach that makes it particularly suitable for the estimation of certain state-space models.
Stars: ✭ 21 (-46.15%)
Mutual labels:  time-series

SAnD

AAAI 2018 Attend and Diagnose: Clinical Time Series Analysis Using Attention Models


Warning: This code is UNOFFICIAL.

Paper: Attend and Diagnose: Clinical Time Series Analysis Using Attention Models

If you want to run this code, you need to download a dataset and write your own experiment code.

from comet_ml import Experiment                          # experiment logging via Comet.ml
from SAnD.core.model import SAnD                         # the SAnD model itself
from SAnD.utils.trainer import NeuralNetworkClassifier   # bundled training helper

model = SAnD( ... )                     # model hyperparameters (see "Simple Usage" below)
clf = NeuralNetworkClassifier( ... )    # wraps the model, loss function, and optimizer
clf.fit( ... )                          # train on your own DataLoaders

Installation

git clone https://github.com/khirotaka/SAnD.git

Requirements

  • Python 3.6
  • Comet.ml
  • PyTorch v1.1.0 or later
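A quick way to confirm that your environment matches this list is shown below (just a sketch using standard version attributes, not part of this repository):

import sys

import comet_ml               # verifies the Comet.ml SDK is installed
import torch

print(sys.version)            # should report Python 3.6 or later
print(torch.__version__)      # should report 1.1.0 or later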

Simple Usage

Here's a brief overview of how you can use this project to solve a classification task.

Download this project

First, create an empty directory.
In this example, I'll call it "playground".
Run the git init and git submodule add commands to register the SAnD project as a submodule.

$ mkdir playground/
$ cd playground/
$ git init
$ git submodule add https://github.com/khirotaka/SAnD.git

Now you're ready to use SAnD in your project.
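As a quick sanity check (just a sketch; it assumes the submodule sits directly under playground/ and that you run Python from that directory), you can confirm the package is importable:

# Run from inside playground/.
# If the submodule was registered correctly, this import should succeed.
from SAnD.core.model import SAnD

print(SAnD)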

Preparing the Dataset

Prepare the dataset of your choice.
Remember that the input to the SAnD model is a three-dimensional tensor of shape [N, seq_len, features].

This example uses torch.randn() to generate a pseudo dataset.

from comet_ml import Experiment

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader

from SAnD.core.model import SAnD
from SAnD.utils.trainer import NeuralNetworkClassifier


x_train = torch.randn(1024, 256, 23)    # [N, seq_len, features]
x_val = torch.randn(128, 256, 23)       # [N, seq_len, features]
x_test = torch.randn(512, 256, 23)      # [N, seq_len, features]

y_train = torch.randint(0, 10, (1024, ))    # class labels in [0, num_class)
y_val = torch.randint(0, 10, (128, ))
y_test = torch.randint(0, 10, (512, ))


train_ds = TensorDataset(x_train, y_train)
val_ds = TensorDataset(x_val, y_val)
test_ds = TensorDataset(x_test, y_test)

train_loader = DataLoader(train_ds, batch_size=128)
val_loader = DataLoader(val_ds, batch_size=128)
test_loader = DataLoader(test_ds, batch_size=128)

Note:
In my experience, SAnD tends to perform better on problems with a large number of features.

Training SAnD model using Trainer

Finally, train the SAnD model using the included NeuralNetworkClassifier.
Of course, you can also use a well-known training tool such as PyTorch Lightning instead; a plain-PyTorch alternative is sketched after the example below.
Note that the included NeuralNetworkClassifier depends on Comet.ml's logging service.

in_feature = 23
seq_len = 256
n_heads = 32
factor = 32
num_class = 10
num_layers = 6

clf = NeuralNetworkClassifier(
    SAnD(in_feature, seq_len, n_heads, factor, num_class, num_layers),
    nn.CrossEntropyLoss(),
    optim.Adam, optimizer_config={"lr": 1e-5, "betas": (0.9, 0.98), "eps": 4e-09, "weight_decay": 5e-4},
    experiment=Experiment()
)

# training network
clf.fit(
    {"train": train_loader,
     "val": val_loader},
    epochs=200
)

# evaluating
clf.evaluate(test_loader)

# save
clf.save_to_file("save_params/")

For the actual task, choose the appropriate hyperparameters for your model and optimizer.
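If you prefer not to depend on Comet.ml, a plain PyTorch training loop works as well. The following is a minimal sketch, not part of this repository; it reuses the hyperparameters and train_loader defined above and assumes that SAnD's forward pass accepts a [N, seq_len, features] batch and returns [N, num_class] logits, as the classification example implies.

import torch
import torch.nn as nn
import torch.optim as optim

from SAnD.core.model import SAnD

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SAnD(in_feature, seq_len, n_heads, factor, num_class, num_layers).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-5, betas=(0.9, 0.98), eps=4e-09, weight_decay=5e-4)

for epoch in range(200):
    model.train()
    for x, y in train_loader:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        logits = model(x)               # [N, num_class] (assumed output shape)
        loss = criterion(logits, y)
        loss.backward()
        optimizer.step()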

Regression Task

There are two ways to use SAnD in a regression task.

  1. Specify the number of output dimensions in num_class.
  2. Inherit class SAnD and overwrite ClassificationModule with RegressionModule.

The second approach is shown below.

from SAnD.core.model import SAnD
from SAnD.core.modules import RegressionModule


class RegSAnD(SAnD):
    def __init__(self, *args, **kwargs):
        super(RegSAnD, self).__init__(*args, **kwargs)
        d_model = kwargs.get("d_model")        # note: pass d_model (and factor, n_class) as keyword arguments so these lookups succeed
        factor = kwargs.get("factor")
        output_size = kwargs.get("n_class")    # n_class is reused as the regression output size

        self.clf = RegressionModule(d_model, factor, output_size)


model = RegSAnD(
    input_features=..., seq_len=..., n_heads=..., factor=...,
    n_class=..., n_layers=...
)

The contents of ClassificationModule and RegressionModule are almost the same, so the first approach is recommended.
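For completeness, here is a minimal sketch of the first approach (again, an illustration rather than an API provided by this repository): set num_class to the regression output dimension and pair the model with a regression loss such as nn.MSELoss in an ordinary PyTorch training step.

import torch
import torch.nn as nn
import torch.optim as optim

from SAnD.core.model import SAnD

output_dim = 1                       # e.g. one continuous target per sequence
model = SAnD(in_feature, seq_len, n_heads, factor, output_dim, num_layers)

criterion = nn.MSELoss()             # regression loss instead of CrossEntropyLoss
optimizer = optim.Adam(model.parameters(), lr=1e-5)

x = torch.randn(16, seq_len, in_feature)   # [N, seq_len, features]
y = torch.randn(16, output_dim)            # float targets, [N, output_dim]

optimizer.zero_grad()
loss = criterion(model(x), y)        # assumes model(x) returns [N, output_dim]
loss.backward()
optimizer.step()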

Please let me know if my code has been used to produce products or research results.
It's very encouraging :)

Author

Hirotaka Kawashima (川島 寛隆)

License

Copyright (c) 2019 Hirotaka Kawashima
Released under the MIT license
