
chrischoy / Pytorch Custom Cuda Tutorial

Tutorial for building a custom CUDA function for Pytorch


Pytorch Custom CUDA kernel Tutorial

This repository contains tutorial code for building a custom CUDA function for pytorch. The code is based on the pytorch C extension example.

Disclaimer

This tutorial was written when pytorch did not support broadcasting sum. Now that it does, you probably won't need to write your own broadcasting sum function, but you can still follow this tutorial to build your own custom layer with a custom CUDA kernel.

In this repository, we will build a simple CUDA-based broadcasting sum function. Without broadcasting support, we have to expand a tensor manually, e.g. with expand_as, which creates a new tensor and incurs additional memory and computation.

For example,

a = torch.randn(3, 5)
b = torch.randn(3, 1)
# The following line will give an error
# a += b

# Expand b to have the same dimension as a
b_like_a = b.expand_as(a)
a += b_like_a

In this post, we will build a function that can compute a += b without explicitly expanding b.

mathutil.broadcast_sum(a, b, *map(int, a.size()))
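For clarity, the broadcasting semantics we want (the effect of a += b.expand_as(a)) can be sketched in plain Python, with nested lists standing in for tensors. This is a minimal reference illustration, not code from the repository:

```python
def broadcast_sum_ref(a, b):
    """In-place reference: add column vector b (x by 1) to matrix a (x by y)."""
    for j in range(len(a)):          # rows of a
        for k in range(len(a[0])):   # columns of a
            a[j][k] += b[j][0]       # b is broadcast across the columns
    return a

a = [[1.0, 2.0, 3.0],
     [4.0, 5.0, 6.0]]
b = [[10.0],
     [20.0]]
broadcast_sum_ref(a, b)
print(a)  # [[11.0, 12.0, 13.0], [24.0, 25.0, 26.0]]
```

The CUDA version below performs the same update element-wise on the GPU, without ever materializing an expanded copy of b.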

Make a CUDA kernel

First, let's write a CUDA kernel that adds b to a without making a copy of the tensor b.

__global__ void broadcast_sum_kernel(float *a, float *b, int x, int y, int size)
{
    // Flat thread index over all x * y elements
    int i = (blockIdx.x + blockIdx.y * gridDim.x) * blockDim.x + threadIdx.x;
    if(i >= size) return;
    // Decompose the flat index into 2D coordinates (j, k)
    int j = i % x; i = i / x;
    int k = i % y;
    a[IDX2D(j, k, y)] += b[k];
}
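The kernel relies on a BLOCK constant and an IDX2D macro defined elsewhere in the repository. Its index arithmetic can be checked with a pure-Python simulation that runs one loop iteration per CUDA "thread". The BLOCK value and the row-major idx2d definition below are assumptions for illustration, not the repository's exact code:

```python
BLOCK = 512  # assumed threads per block; the repository defines its own value

def idx2d(j, k, y):
    # Row-major offset into an x-by-y matrix, as IDX2D presumably computes it
    return j * y + k

def broadcast_sum_sim(a, b, x, y):
    """Simulate the kernel: a is a flat row-major list of x*y floats,
    and b[k] is added at position (j, k), exactly as in the CUDA code."""
    size = x * y
    for i in range(size):        # one iteration per CUDA thread index
        j = i % x
        k = (i // x) % y
        a[idx2d(j, k, y)] += b[k]
    return a

# 2 x 3 matrix stored row-major; b[k] is added to column k of every row
a = [0.0, 1.0, 2.0,
     10.0, 11.0, 12.0]
b = [100.0, 200.0, 300.0]
broadcast_sum_sim(a, b, 2, 3)
print(a)  # [100.0, 201.0, 302.0, 110.0, 211.0, 312.0]
```

Note that the decomposition visits every (j, k) pair exactly once, so each element of a receives exactly one update even though threads are launched over a flat index.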

Make a C wrapper

Once you have made the CUDA kernel, you have to wrap it in C code. We are not touching the pytorch backend yet; note that the inputs are already device pointers.

void broadcast_sum_cuda(float *a, float *b, int x, int y, cudaStream_t stream)
{
    int size = x * y;
    cudaError_t err;

    broadcast_sum_kernel<<<cuda_gridsize(size), BLOCK, 0, stream>>>(a, b, x, y, size);

    err = cudaGetLastError();
    if (cudaSuccess != err)
    {
        fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
        exit(-1);
    }
}

Connect the pytorch backend with the C wrapper

Next, we connect the pytorch backend with our C wrapper. You can get the underlying device pointer of a tensor with the function THCudaTensor_data; the pointers a and b below are device pointers (on the GPU).

extern THCState *state;

int broadcast_sum(THCudaTensor *a_tensor, THCudaTensor *b_tensor, int x, int y)
{
    float *a = THCudaTensor_data(state, a_tensor);
    float *b = THCudaTensor_data(state, b_tensor);
    cudaStream_t stream = THCState_getCurrentStream(state);

    broadcast_sum_cuda(a, b, x, y, stream);

    return 1;
}

Make a python wrapper

Now that we have built the CUDA function and its C wrapper, we need to expose the function to python so that we can call it from python code.

We will first build a shared library using nvcc.

nvcc ... -o build/mathutil_cuda_kernel.so src/mathutil_cuda_kernel.cu

Then, we use pytorch's torch.utils.ffi.create_extension function, which automatically adds the appropriate headers and builds a python-loadable shared library.

from torch.utils.ffi import create_extension

...

ffi = create_extension(
    'mathutils',
    headers=[...],
    sources=[...],
    ...
)

ffi.build()

Test!

Finally, we can test our function by building it. This readme omits a lot of details, but you can find a working example in the repository:

git clone https://github.com/chrischoy/pytorch-cffi-tutorial
cd pytorch-cffi-tutorial
make

Note

The function only accepts THCudaTensor, which corresponds to torch.FloatTensor().cuda() in python.
