
zh217 / torch-asg

License: GPL-3.0
Auto Segmentation Criterion (ASG) implemented in pytorch

Programming Languages

C++ (36,643 projects, #6 most used programming language)
Python (139,335 projects, #7 most used programming language)
Cuda (1,817 projects)
CMake (9,771 projects)

Projects that are alternatives of or similar to torch-asg

Neural sp
End-to-end ASR/LM implementation with PyTorch
Stars: ✭ 408 (+871.43%)
Mutual labels:  speech, seq2seq, asr, ctc
Kerasdeepspeech
A Keras CTC implementation of Baidu's DeepSpeech for model experimentation
Stars: ✭ 245 (+483.33%)
Mutual labels:  speech, asr, ctc
Delta
DELTA is a deep learning based natural language and speech processing platform.
Stars: ✭ 1,479 (+3421.43%)
Mutual labels:  speech, seq2seq, asr
Lingvo
Lingvo
Stars: ✭ 2,361 (+5521.43%)
Mutual labels:  speech, seq2seq, asr
Pytorch Asr
ASR with PyTorch
Stars: ✭ 124 (+195.24%)
Mutual labels:  speech, asr, ctc
Naver-AI-Hackathon-Speech
2019 Clova AI Hackathon: Speech - Rank 12 / Team Kai.Lib
Stars: ✭ 26 (-38.1%)
Mutual labels:  speech, seq2seq
Multimodal-Gesture-Recognition-with-LSTMs-and-CTC
An end-to-end system that performs temporal recognition of gesture sequences using speech and skeletal input. The model combines three networks with a CTC output layer that recognises gestures from a continuous stream.
Stars: ✭ 25 (-40.48%)
Mutual labels:  speech, ctc
ASR-Audio-Data-Links
A list of publicly available audio data that anyone can download for ASR or other speech activities
Stars: ✭ 179 (+326.19%)
Mutual labels:  speech, asr
kospeech
Open-Source Toolkit for End-to-End Korean Automatic Speech Recognition leveraging PyTorch and Hydra.
Stars: ✭ 456 (+985.71%)
Mutual labels:  seq2seq, asr
End2end Asr Pytorch
End-to-End Automatic Speech Recognition on PyTorch
Stars: ✭ 175 (+316.67%)
Mutual labels:  speech, asr
opensource-voice-tools
A repo listing known open source voice tools, ordered by where they sit in the voice stack
Stars: ✭ 21 (-50%)
Mutual labels:  speech, asr
ctc-asr
End-to-end trained speech recognition system, based on RNNs and the connectionist temporal classification (CTC) cost function.
Stars: ✭ 112 (+166.67%)
Mutual labels:  asr, ctc
Edgedict
Working online speech recognition based on RNN Transducer. (A trained model is available in the releases.)
Stars: ✭ 205 (+388.1%)
Mutual labels:  speech, asr
wav2vec2-live
Live speech recognition using Facebook's wav2vec 2.0 model.
Stars: ✭ 205 (+388.1%)
Mutual labels:  speech, asr
speech-transformer
A Transformer implementation specialized for speech recognition tasks, using PyTorch.
Stars: ✭ 40 (-4.76%)
Mutual labels:  speech, asr
avsr-tf1
Audio-Visual Speech Recognition using Sequence to Sequence Models
Stars: ✭ 76 (+80.95%)
Mutual labels:  seq2seq, asr
sentence2vec
Deep sentence embedding using Sequence to Sequence learning
Stars: ✭ 23 (-45.24%)
Mutual labels:  torch, seq2seq
opensnips
Open source projects related to Snips https://snips.ai/.
Stars: ✭ 50 (+19.05%)
Mutual labels:  speech, asr
speech recognition ctc
Chinese speech recognition with Keras and CTC
Stars: ✭ 40 (-4.76%)
Mutual labels:  speech, ctc
Asr audio data links
A list of publicly available audio data that anyone can download for ASR or other speech activities
Stars: ✭ 128 (+204.76%)
Mutual labels:  speech, asr

Auto Segmentation Criterion (ASG) for pytorch

This repo contains a pytorch implementation of the auto segmentation criterion (ASG), introduced in the paper Wav2Letter: an End-to-End ConvNet-based Speech Recognition System by Facebook.

As mentioned in this blog post by Daniel Galvez, ASG is an alternative to the connectionist temporal classification (CTC) criterion widely used in deep learning. Its advantages are that it is a globally normalized model, that it avoids CTC's conditional independence assumption, and that it has the potential to play better with WFST frameworks.
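To make "globally normalized" concrete, here is a minimal single-sequence sketch of what the criterion computes. This is illustrative only, not the optimized implementation in this repo: the function name is made up, and batching, length handling and wav2letter's repeat labels are omitted. The loss is the log-sum-exp score of all length-T label paths minus the log-sum-exp score of the paths compatible with the target, where a path is scored by frame-wise emissions plus learned label-to-label transition scores.

import torch

def naive_asg_loss(emissions, transition, target):
    # emissions: (T, L) per-frame label scores; transition[i, j]: score of
    # moving from label j to label i (same convention as ASGLoss.transition);
    # target: (U,) label indices with U <= T. Single sequence, no batching.
    T, L = emissions.shape
    U = target.shape[0]

    # Denominator: logadd over *all* label paths of length T (the "full graph").
    alpha = emissions[0]
    for t in range(1, T):
        # alpha[i] = emissions[t, i] + logadd_j(alpha_prev[j] + transition[i, j])
        alpha = emissions[t] + torch.logsumexp(alpha.unsqueeze(0) + transition, dim=1)
    log_z_full = torch.logsumexp(alpha, dim=0)

    # Numerator: logadd over paths that emit the target labels in order, each
    # label occupying one or more consecutive frames (the "constrained graph").
    score = torch.cat([emissions[0, target[:1]],
                       torch.full((U - 1,), float('-inf'))])
    for t in range(1, T):
        stay = score + transition[target, target]
        advance = torch.cat([torch.full((1,), float('-inf')),
                             score[:-1] + transition[target[1:], target[:-1]]])
        score = emissions[t, target] + torch.logaddexp(stay, advance)
    log_z_constrained = score[U - 1]

    # Globally normalized negative log-likelihood of the target labelling.
    return log_z_full - log_z_constrained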

Unfortunately, Facebook's implementation in its official wav2letter++ project is based on the ArrayFire C++ framework, which makes experimentation rather difficult. Hence we have ported the ASG implementation in wav2letter++ to pytorch as C++ extensions.

Our implementation should produce the same results as Facebook's, but the implementation itself is completely different. For example, in their implementation, after doing an alpha recursion during the forward pass, they brute-force the back-propagation during the backward pass, whereas we do a proper alpha-beta recursion during the forward pass, so that no recursion is needed at all during the backward pass. Our implementation therefore has much higher parallelism potential. Another difference is that we try to use pytorch's native functions as much as possible, whereas Facebook's implementation is basically a gigantic piece of hand-written C code working on raw arrays.
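For intuition on why the backward pass needs no recursion: for a log-sum-exp dynamic program of this kind, the gradient of the full-graph partition function with respect to an emission score is a posterior marginal that can be read off elementwise from the alpha and beta tables produced in the forward pass. Schematically (a standard identity, stated here for the emissions only; the doc folder has the exact derivation used in this implementation):

\frac{\partial \log Z_{\text{full}}}{\partial f_t(i)} = \exp\big( \alpha_t(i) + \beta_t(i) - \log Z_{\text{full}} \big)

where \alpha_t(i) accumulates path scores up to and including frame t ending in label i, and \beta_t(i) accumulates path scores after frame t given label i at frame t. A similar expression over pairs of labels gives the gradient with respect to the transition matrix.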

In the doc folder, you can find the maths derivation of our implementation.

Project status

  • CPU (openmp) implementation
  • GPU (cuda) implementation
  • testing
  • performance tuning and comparison
  • Viterbi decoders
  • generalization to better integrate with general WFSTs decoders

Using the project

Ensure pytorch > 1.0.1 is installed, clone the project, and in a terminal run:

cd torch_asg
pip install .

Tested with Python 3.7.1. You need to have a suitable C++ toolchain installed. For GPU support, you need an NVIDIA card with compute capability >= 6.0.
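As a quick sanity check before building (an optional, illustrative snippet using standard pytorch APIs), you can verify that the interpreter sees a recent pytorch and a GPU of sufficient compute capability:

import torch

print(torch.__version__)                       # want > 1.0.1
print(torch.cuda.is_available())               # True if a CUDA device is visible
if torch.cuda.is_available():
    print(torch.cuda.get_device_capability())  # want (6, 0) or higher for the GPU build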

Then in your python code:

import torch
from torch_asg import ASGLoss


def test_run():
    num_labels = 7
    input_batch_len = 6
    num_batches = 2
    target_batch_len = 5
    asg_loss = ASGLoss(num_labels=num_labels,
                       reduction='mean',  # mean (default), sum, none
                       gpu_no_stream_impl=False, # see below for explanation
                       forward_only=False # see below for explanation                      
                       )
    for i in range(1):
        # Note that inputs follow the CTC convention: the batch dimension is dimension 1 instead of 0,
        # which allows a more efficient GPU implementation
        inputs = torch.randn(input_batch_len, num_batches, num_labels, requires_grad=True)
        targets = torch.randint(0, num_labels, (num_batches, target_batch_len))
        input_lengths = torch.randint(1, input_batch_len + 1, (num_batches,))
        target_lengths = torch.randint(1, target_batch_len + 1, (num_batches,))
        loss = asg_loss.forward(inputs, targets, input_lengths, target_lengths)
        print('loss', loss)
        # You can get the transition matrix if you need it.
        # transition[i, j] is transition score from label j to label i.
        print('transition matrix', asg_loss.transition)
        loss.backward()
        print('transition matrix grad', asg_loss.transition.grad)
        print('inputs grad', inputs.grad)

test_run()

There are two options in the loss constructor that warrant further explanation:

  • gpu_no_stream_impl: by default, when running on the GPU, we use a highly concurrent implementation that relies on some rather involved CUDA stream manipulation. You can turn this concurrent implementation off by setting this parameter to True, in which case CUDA kernels are launched serially. This is useful for debugging.
  • forward_only: by default, our implementation does quite a lot of work concurrently during the forward pass that is only useful for calculating the gradients. If you don't need the gradients, setting this parameter to True gives a further speed boost. Note that forward-only mode is also activated automatically whenever your model is in evaluation mode (see the sketch below).
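For example (an illustrative sketch, assuming ASGLoss behaves like a standard pytorch module with respect to train/eval modes), an evaluation-time computation could look like this:

asg_loss.eval()                       # evaluation mode implies forward-only behaviour
with torch.no_grad():                 # also skip autograd bookkeeping
    val_loss = asg_loss.forward(inputs, targets, input_lengths, target_lengths)
asg_loss.train()                      # switch back before resuming training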

Compared to Facebook's implementation, we have also omitted scaling based on input/output lengths. If you need it, you can do it yourself by using the 'none' reduction and scaling the individual scores before summing or averaging.
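For example (an illustrative sketch, assuming the 'none' reduction returns one score per batch element), normalizing each score by its target length could look like this:

asg_loss = ASGLoss(num_labels=num_labels, reduction='none')
per_sample = asg_loss.forward(inputs, targets, input_lengths, target_lengths)  # shape: (num_batches,)
loss = (per_sample / target_lengths.float()).mean()   # scale each score, then average
loss.backward()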
