brentyi / jaxfg

License: MIT
Factor graphs and nonlinear optimization for JAX

Projects that are alternatives to or similar to jaxfg

MINLPLib.jl
A JuMP-based library of Non-Linear and Mixed-Integer Non-Linear Programs
Stars: ✭ 30 (-75.81%)
Mutual labels:  nonlinear-optimization
revisiting rainbow
Revisiting Rainbow
Stars: ✭ 71 (-42.74%)
Mutual labels:  jax
GPJax
A didactic Gaussian process package for researchers in Jax.
Stars: ✭ 159 (+28.23%)
Mutual labels:  jax
snopt-matlab
Matlab interface for sparse nonlinear optimizer SNOPT
Stars: ✭ 49 (-60.48%)
Mutual labels:  nonlinear-optimization
jax-cfd
Computational Fluid Dynamics in JAX
Stars: ✭ 399 (+221.77%)
Mutual labels:  jax
mlp-gpt-jax
A GPT, made only of MLPs, in Jax
Stars: ✭ 53 (-57.26%)
Mutual labels:  jax
graphsignal
Graphsignal Python agent
Stars: ✭ 158 (+27.42%)
Mutual labels:  jax
lbfgsb-gpu
An open source library for the GPU-implementation of L-BFGS-B algorithm
Stars: ✭ 70 (-43.55%)
Mutual labels:  nonlinear-optimization
jax-models
Unofficial JAX implementations of deep learning research papers
Stars: ✭ 108 (-12.9%)
Mutual labels:  jax
uvadlc notebooks
Repository of Jupyter notebook tutorials for teaching the Deep Learning Course at the University of Amsterdam (MSc AI), Fall 2022/Spring 2022
Stars: ✭ 901 (+626.61%)
Mutual labels:  jax
ADAM
ADAM implements a collection of algorithms for calculating rigid-body dynamics in Jax, CasADi, PyTorch, and Numpy.
Stars: ✭ 51 (-58.87%)
Mutual labels:  jax
jax-rl
JAX implementations of core Deep RL algorithms
Stars: ✭ 61 (-50.81%)
Mutual labels:  jax
rA9
JAX-based Spiking Neural Network framework
Stars: ✭ 60 (-51.61%)
Mutual labels:  jax
NAGPythonExamples
Examples and demos showing how to call functions from the NAG Library for Python
Stars: ✭ 46 (-62.9%)
Mutual labels:  nonlinear-optimization
ShinRL
ShinRL: A Library for Evaluating RL Algorithms from Theoretical and Practical Perspectives (Deep RL Workshop 2021)
Stars: ✭ 30 (-75.81%)
Mutual labels:  jax
omd
JAX code for the paper "Control-Oriented Model-Based Reinforcement Learning with Implicit Differentiation"
Stars: ✭ 43 (-65.32%)
Mutual labels:  jax
galini
An extensible MINLP solver
Stars: ✭ 29 (-76.61%)
Mutual labels:  nonlinear-optimization
dm pix
PIX is an image processing library in JAX, for JAX.
Stars: ✭ 271 (+118.55%)
Mutual labels:  jax
efficientnet-jax
EfficientNet, MobileNetV3, MobileNetV2, MixNet, etc in JAX w/ Flax Linen and Objax
Stars: ✭ 114 (-8.06%)
Mutual labels:  jax
pdfo
Powell's Derivative-Free Optimization solvers
Stars: ✭ 56 (-54.84%)
Mutual labels:  nonlinear-optimization

jaxfg

[badges: build · lint · mypy · codecov]

jaxfg is a factor graph-based nonlinear least squares library for JAX. Typical applications include sensor fusion, SLAM, bundle adjustment, and optimal control.

The premise: we provide a high-level interface for defining probability densities as factor graphs. MAP inference then reduces to nonlinear optimization, which we accelerate by analyzing the structure of the graph: operations over repeated factor and variable types are vectorized, and the sparsity of graph connections is translated into sparse matrix operations.
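To make the reduction above concrete, here is a minimal, framework-agnostic NumPy sketch (deliberately not jaxfg's API) of MAP inference as nonlinear least squares, solved with Gauss-Newton: two "factors" on a scalar variable, a residual vector, and repeated normal-equation solves.

```python
# Minimal sketch: MAP inference as nonlinear least squares via Gauss-Newton.
# This is NOT jaxfg's API -- just an illustration of the reduction described
# above, with a hand-written Jacobian where jaxfg would use autodiff.
import numpy as np

def residuals(x):
    # Two "factors" on a scalar variable x: a prior pulling x toward 1.0,
    # and a nonlinear measurement factor pulling x**2 toward 4.0.
    return np.array([x[0] - 1.0, x[0] ** 2 - 4.0])

def jacobian(x):
    # Jacobian of the residual vector with respect to x.
    return np.array([[1.0], [2.0 * x[0]]])

def gauss_newton(x, iters=20):
    for _ in range(iters):
        r, J = residuals(x), jacobian(x)
        # Solve the normal equations (J^T J) dx = -J^T r for the update.
        dx = np.linalg.solve(J.T @ J, -J.T @ r)
        x = x + dx
    return x

x_map = gauss_newton(np.array([3.0]))
```

The MAP estimate lands between the prior (x = 1) and the measurement (x = 2), weighted by how sharply each residual penalizes deviation.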

Features:

  • Autodiff-powered sparse Jacobians.
  • Automatic vectorization for repeated factor and variable types.
  • Manifold definition interface, with implementations provided for SO(2), SE(2), SO(3), and SE(3) Lie groups.
  • Support for standard JAX function transformations: jit, vmap, pmap, grad, etc.
  • Nonlinear optimizers: Gauss-Newton, Levenberg-Marquardt, Dogleg.
  • Sparse linear solvers: conjugate gradient (Jacobi-preconditioned), sparse Cholesky (via CHOLMOD).
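The Jacobi-preconditioned conjugate gradient solver in the feature list is a standard technique; the sketch below is a generic PCG loop in NumPy (not jaxfg's implementation), applied to a small SPD system standing in for the Gauss-Newton normal equations. The Jacobi preconditioner is simply the inverse of the matrix diagonal, which is cheap to build and apply.

```python
# Generic Jacobi-preconditioned conjugate gradient -- an illustration of the
# solver idea above, not jaxfg's implementation.
import numpy as np

def jacobi_pcg(A, b, iters=50, tol=1e-10):
    # Jacobi preconditioner: apply M^-1 = 1 / diag(A) elementwise.
    M_inv = 1.0 / np.diag(A)
    x = np.zeros_like(b)
    r = b - A @ x          # initial residual
    z = M_inv * r          # preconditioned residual
    p = z.copy()           # initial search direction
    for _ in range(iters):
        Ap = A @ p
        alpha = (r @ z) / (p @ Ap)
        x = x + alpha * p
        r_new = r - alpha * Ap
        if np.linalg.norm(r_new) < tol:
            return x
        z_new = M_inv * r_new
        beta = (r_new @ z_new) / (r @ z)
        p = z_new + beta * p
        r, z = r_new, z_new
    return x

# A small SPD system standing in for the normal equations J^T J dx = -J^T r.
A = np.array([[4.0, 1.0], [1.0, 3.0]])
b = np.array([1.0, 2.0])
x = jacobi_pcg(A, b)
```

For well-conditioned graphs the preconditioner matters little, but for poorly scaled problems it can cut iteration counts substantially.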

This library was released alongside our IROS 2021 paper (more info in our core experiment repository here) and borrows heavily from a wide set of existing libraries, including GTSAM, Ceres Solver, minisam, SwiftFusion, and g2o. For technical background and concepts, GTSAM has a great set of tutorials.

Installation

scikit-sparse requires SuiteSparse:

sudo apt update
sudo apt install -y libsuitesparse-dev

Then, from your environment of choice:

git clone https://github.com/brentyi/jaxfg.git
cd jaxfg
pip install -e .

Example scripts

Toy pose graph optimization:

python scripts/pose_graph_simple.py

Pose graph optimization from .g2o files:

python scripts/pose_graph_g2o.py  # For options, pass in a --help flag

Development

If you're interested in extending this library to define your own factor graphs, we'd recommend first familiarizing yourself with:

  1. Pytrees in JAX: https://jax.readthedocs.io/en/latest/jax-101/05.1-pytrees.html
  2. Python dataclasses: https://docs.python.org/3/library/dataclasses.html
    • We currently take a "make everything a dataclass" philosophy for software engineering in this library. This is convenient for several reasons, but notably makes it easy for objects to be registered as pytree nodes. See jax_dataclasses for details on this.
  3. Type annotations: https://docs.python.org/3/library/typing.html
    • We rely on generics (typing.Generic and typing.TypeVar) particularly heavily. If you're familiar with C++ this should come very naturally (~templates).
  4. Explicit decorators for overrides/inheritance: https://github.com/mkorpela/overrides
    • The @overrides and @final decorators signal which methods are being overridden and which must not be overridden; the same applies to @abc.abstractmethod.
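The dataclass-plus-generics style described above can be sketched with the standard library alone. The class names here are hypothetical, not jaxfg's actual types; the point is how frozen dataclasses and typing.Generic/TypeVar combine to give typed, immutable graph components.

```python
# Hedged sketch of the "make everything a dataclass" + generics style
# described above. Class names are hypothetical, not jaxfg's actual API.
import dataclasses
from typing import Generic, TypeVar

T = TypeVar("T")  # the value type a variable carries, e.g. a pose or vector

@dataclasses.dataclass(frozen=True)
class Variable(Generic[T]):
    """A node in the graph, parameterized over its value type."""
    name: str
    value: T

@dataclasses.dataclass(frozen=True)
class PriorFactor(Generic[T]):
    """A unary factor anchoring one variable toward a target value."""
    variable: Variable[T]
    target: T

# Instantiating the generic alias keeps the value type visible to checkers.
pose = Variable[float](name="x0", value=0.0)
factor = PriorFactor(variable=pose, target=1.0)
```

Frozen dataclasses also pair naturally with pytree registration, since their fields enumerate exactly the leaves JAX should flatten.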

From there, we have a few references for defining your own factor graphs, factors, and manifolds.

Current limitations

  1. In XLA, JIT compilation needs to happen for each unique set of input shapes. Modifying graph structures can thus introduce significant re-compilation overhead, which can restrict dynamic or online applications.
  2. Our marginalization implementation is not very good.

To-do

This library's still in development mode! Here's our TODO list:

  • Preliminary graph, variable, factor interfaces
  • Real vector variable types
  • Refactor into package
  • Nonlinear optimization for MAP inference
    • Conjugate gradient linear solver
    • CHOLMOD linear solver
      • Basic implementation. JIT-able, but no vmap, pmap, or autodiff support.
      • Custom VJP rule? vmap support?
    • Gauss-Newton implementation
    • Termination criteria
    • Damped least squares
    • Dogleg
    • Inexact Newton steps
    • Revisit termination criteria
    • Reduce redundant code
    • Robust losses
  • Marginalization
    • Prototype using sksparse/CHOLMOD (works but fairly slow)
    • JAX implementation?
  • Validate g2o example
  • Performance
    • More intentional JIT compilation
    • Re-implement parallel factor computation
    • Vectorized linearization
    • Basic (Jacobi) CGLS preconditioning
  • Manifold optimization (mostly offloaded to jaxlie)
    • Basic interface
    • Manifold optimization on SO2
    • Manifold optimization on SE2
    • Manifold optimization on SO3
    • Manifold optimization on SE3
  • Usability + code health (low priority)
    • Basic cleanup/refactor
      • Better parallel factor interface
      • Separate out utils, lie group helpers
      • Put things in folders
    • Resolve typing errors
    • Cleanup/refactor (more)
    • Package cleanup: dependencies, etc
    • Add CI:
      • mypy
      • lint
      • build
      • coverage
    • More comprehensive tests
    • Clean up docstrings
    • New name