
ucl-bug / jaxdf

License: LGPL-3.0
A JAX-based research framework for writing differentiable numerical simulators with arbitrary discretizations

Programming Languages

Python
139335 projects - #7 most used programming language
Makefile
30231 projects

Projects that are alternatives of or similar to jaxdf

GalerkinSparseGrids.jl
Sparse Grid Discretization with the Discontinuous Galerkin Method for solving PDEs
Stars: ✭ 39 (-22%)
Mutual labels:  pde, discretization
madam
👩 Pytorch and Jax code for the Madam optimiser.
Stars: ✭ 46 (-8%)
Mutual labels:  jax
PyFRAP
PyFRAP: A Python based FRAP analysis tool box
Stars: ✭ 15 (-70%)
Mutual labels:  pde
featool-multiphysics
FEATool - "Physics Simulation Made Easy" (Fully Integrated FEA, FEniCS, OpenFOAM, SU2 Solver GUI & Multi-Physics Simulation Platform)
Stars: ✭ 190 (+280%)
Mutual labels:  pde
dm pix
PIX is an image processing library in JAX, for JAX.
Stars: ✭ 271 (+442%)
Mutual labels:  jax
MAESTRO
A low Mach number stellar hydrodynamics code
Stars: ✭ 29 (-42%)
Mutual labels:  pde
efficientnet-jax
EfficientNet, MobileNetV3, MobileNetV2, MixNet, etc in JAX w/ Flax Linen and Objax
Stars: ✭ 114 (+128%)
Mutual labels:  jax
SciMLBenchmarks.jl
Benchmarks for scientific machine learning (SciML) software and differential equation solvers
Stars: ✭ 195 (+290%)
Mutual labels:  pde
rom-operator-inference-Python3
Operator Inference for data-driven, non-intrusive model reduction of dynamical systems.
Stars: ✭ 31 (-38%)
Mutual labels:  pde
curve-shortening-demo
Visualize curve shortening flow in your browser.
Stars: ✭ 19 (-62%)
Mutual labels:  pde
hydro examples
Simple one-dimensional examples of various hydrodynamics techniques
Stars: ✭ 83 (+66%)
Mutual labels:  pde
jaxfg
Factor graphs and nonlinear optimization for JAX
Stars: ✭ 124 (+148%)
Mutual labels:  jax
robustness-vit
Contains code for the paper "Vision Transformers are Robust Learners" (AAAI 2022).
Stars: ✭ 78 (+56%)
Mutual labels:  jax
Numerical-Algorithms
Numerical Experiments
Stars: ✭ 15 (-70%)
Mutual labels:  pde
koclip
KoCLIP: Korean port of OpenAI CLIP, in Flax
Stars: ✭ 80 (+60%)
Mutual labels:  jax
devsim
TCAD Semiconductor Device Simulator
Stars: ✭ 104 (+108%)
Mutual labels:  pde
adversarial-code-generation
Source code for the ICLR 2021 work "Generating Adversarial Computer Programs using Optimized Obfuscations"
Stars: ✭ 16 (-68%)
Mutual labels:  differentiable-programming
chef-transformer
Chef Transformer 🍲.
Stars: ✭ 29 (-42%)
Mutual labels:  jax
Teg
A differentiable programming language with an integration primitive that soundly handles interactions among the derivative, integral, and discontinuities.
Stars: ✭ 25 (-50%)
Mutual labels:  differentiable-programming
cr-sparse
Functional models and algorithms for sparse signal processing
Stars: ✭ 38 (-24%)
Mutual labels:  jax

jaxdf - JAX-based Discretization Framework

License: LGPL v3 codecov Continuous Integration Documentation

Overview | Example | Installation | Documentation

⚠️ This library is still in development. Breaking changes may occur.


Overview

jaxdf is a JAX-based package that provides a framework for writing differentiable numerical simulators with arbitrary discretizations.

The intended use is to build numerical models of physical systems, such as wave propagation, or the numerical solution of partial differential equations, that are easy to customize to the user's research needs. Such models are pure functions that can be included in arbitrary differentiable programs written in JAX: for example, they can be used as layers of neural networks, or to build a physics-informed loss function.
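As a minimal illustration of this pattern (plain JAX, not jaxdf code; all names here are hypothetical), a hand-rolled finite-difference simulator step is a pure function, so JAX can differentiate through an entire rollout of it to obtain a gradient with respect to the initial field:

```python
import jax
import jax.numpy as jnp

# A toy "simulator": one explicit finite-difference step of the 1D heat
# equation on a periodic grid. Being a pure function of the field, it can
# be composed with any other differentiable JAX code.
def heat_step(u, dt=0.1, dx=1.0):
  lap = (jnp.roll(u, 1) - 2 * u + jnp.roll(u, -1)) / dx**2
  return u + dt * lap

# A physics-style loss: roll the simulator forward, compare with a target.
def loss(u0, target):
  u = u0
  for _ in range(10):
    u = heat_step(u)
  return jnp.mean((u - target) ** 2)

u0 = jnp.sin(2 * jnp.pi * jnp.arange(32) / 32)
target = jnp.zeros(32)
g = jax.grad(loss)(u0, target)  # gradient w.r.t. the initial field
```

jaxdf generalizes this idea: the operators it defines stay pure and composable, while abstracting away the specific discretization used to represent the fields.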


Example

The following script builds the non-linear operator (∇² + sin) using a Fourier spectral discretization on a square 2D domain, and uses it to define a loss function whose gradient is evaluated with JAX's automatic differentiation.

from jaxdf import operators as jops
from jaxdf import FourierSeries, operator
from jaxdf.geometry import Domain
from jax import numpy as jnp
from jax import jit, grad


# Defining operator
@operator
def custom_op(u):
  grad_u = jops.gradient(u)
  diag_jacobian = jops.diag_jacobian(grad_u)
  laplacian = jops.sum_over_dims(diag_jacobian)
  sin_u = jops.compose(u)(jnp.sin)
  return laplacian + sin_u

# Defining discretizations
domain = Domain((128, 128), (1., 1.))
parameters = jnp.ones((128, 128, 1))
u = FourierSeries(parameters, domain)

# Define a differentiable loss function
@jit
def loss(u):
  v = custom_op(u)
  return jnp.mean(jnp.abs(v.on_grid)**2)

gradient = grad(loss)(u) # gradient is a FourierSeries
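The FourierSeries discretization above represents a field by its values on a grid and computes derivatives spectrally. As a rough sketch of the underlying idea (plain NumPy, not jaxdf code), the Laplacian of a periodic 1D field can be computed by multiplying its Fourier transform by −k²:

```python
import numpy as np

# Periodic 1D grid
n, L = 128, 1.0
x = np.arange(n) * L / n
u = np.sin(2 * np.pi * x)

# Angular wavenumbers, matching NumPy's FFT frequency ordering
k = 2 * np.pi * np.fft.fftfreq(n, d=L / n)

# Laplacian in Fourier space: multiply by (ik)^2 = -k^2
lap_u = np.real(np.fft.ifft(-(k ** 2) * np.fft.fft(u)))

# Analytic check: d²/dx² sin(2πx) = -(2π)² sin(2πx)
expected = -(2 * np.pi) ** 2 * u
```

For a single Fourier mode like sin(2πx), the spectral derivative is exact up to floating-point error, which is one reason Fourier discretizations are attractive on smooth periodic problems.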

Installation

Before installing jaxdf, make sure that you have JAX installed. Follow the instructions for installing JAX with NVIDIA GPU support if you want to run jaxdf on a GPU.

Install jaxdf by cloning the repository, or by downloading and extracting the compressed archive. Then navigate to the root folder in a terminal and run

pip install -r .requirements/requirements.txt
pip install .

Citation

arXiv

This package will be presented at the Differentiable Programming workshop at NeurIPS 2021.

@article{stanziola2021jaxdf,
    author={Stanziola, Antonio and Arridge, Simon and Cox, Ben T. and Treeby, Bradley E.},
    title={A research framework for writing differentiable PDE discretizations in JAX},
    year={2021},
    journal={Differentiable Programming workshop at Neural Information Processing Systems 2021}
}

Acknowledgements

Related projects

  1. odl: the Operator Discretization Library (ODL), a Python library for fast prototyping, focusing on (but not restricted to) inverse problems.
  2. DeepXDE: a TensorFlow and PyTorch library for scientific machine learning.
  3. SciML: a NumFOCUS-sponsored open-source software organization created to unify the packages for scientific machine learning.