thomaspinder / GPJax

Licence: Apache-2.0
A didactic Gaussian process package for researchers in JAX.

Programming Languages

Python, TeX, Makefile

Projects that are alternatives to or similar to GPJax

Ipynotebook machinelearning
This contains a number of IP[y]: Notebooks that hopefully shed light on areas of Bayesian machine learning.
Stars: ✭ 27 (-83.02%)
Mutual labels:  bayesian-inference, gaussian-processes
Bcpd
Bayesian Coherent Point Drift (BCPD/BCPD++); Source Code Available
Stars: ✭ 116 (-27.04%)
Mutual labels:  bayesian-inference, gaussian-processes
Neural Tangents
Fast and Easy Infinite Neural Networks in Python
Stars: ✭ 1,357 (+753.46%)
Mutual labels:  bayesian-inference, gaussian-processes
lgpr
R-package for interpretable nonparametric modeling of longitudinal data using additive Gaussian processes. Contains functionality for inferring covariate effects and assessing covariate relevances. Various models can be specified using a convenient formula syntax.
Stars: ✭ 22 (-86.16%)
Mutual labels:  bayesian-inference, gaussian-processes
Survival Analysis Using Deep Learning
This repository contains modern Bayesian statistics and deep-learning-based research articles and software for survival analysis.
Stars: ✭ 139 (-12.58%)
Mutual labels:  bayesian-inference, gaussian-processes
FBNN
Code for "Functional variational Bayesian neural networks" (https://arxiv.org/abs/1903.05779)
Stars: ✭ 67 (-57.86%)
Mutual labels:  bayesian-inference, gaussian-processes
Numpy Ml
Machine learning, in numpy
Stars: ✭ 11,100 (+6881.13%)
Mutual labels:  bayesian-inference, gaussian-processes
Gpstuff
GPstuff - Gaussian process models for Bayesian analysis
Stars: ✭ 106 (-33.33%)
Mutual labels:  bayesian-inference, gaussian-processes
Aboleth
A bare-bones TensorFlow framework for Bayesian deep learning and Gaussian process approximation
Stars: ✭ 127 (-20.13%)
Mutual labels:  bayesian-inference, gaussian-processes
Vbmc
Variational Bayesian Monte Carlo (VBMC) algorithm for posterior and model inference in MATLAB
Stars: ✭ 123 (-22.64%)
Mutual labels:  bayesian-inference, gaussian-processes
Stheno.jl
Probabilistic Programming with Gaussian processes in Julia
Stars: ✭ 318 (+100%)
Mutual labels:  bayesian-inference, gaussian-processes
TemporalGPs.jl
Fast inference for Gaussian processes in problems involving time. Partly built on results from https://proceedings.mlr.press/v161/tebbutt21a.html
Stars: ✭ 89 (-44.03%)
Mutual labels:  bayesian-inference, gaussian-processes
TrendinessOfTrends
The Trendiness of Trends
Stars: ✭ 14 (-91.19%)
Mutual labels:  bayesian-inference, gaussian-processes
models
Forecasting 🇫🇷 elections with Bayesian statistics 🥳
Stars: ✭ 24 (-84.91%)
Mutual labels:  bayesian-inference, gaussian-processes
Exoplanet
Fast & scalable MCMC for all your exoplanet needs!
Stars: ✭ 122 (-23.27%)
Mutual labels:  bayesian-inference, gaussian-processes
approxposterior
A Python package for approximate Bayesian inference and optimization using Gaussian processes
Stars: ✭ 36 (-77.36%)
Mutual labels:  bayesian-inference, gaussian-processes
revisiting rainbow
Revisiting Rainbow
Stars: ✭ 71 (-55.35%)
Mutual labels:  jax
mlp-gpt-jax
A GPT, made only of MLPs, in Jax
Stars: ✭ 53 (-66.67%)
Mutual labels:  jax
ml course
"Learning Machine Learning" Course, Bogotá, Colombia 2019 #LML2019
Stars: ✭ 22 (-86.16%)
Mutual labels:  gaussian-processes


Quickstart | Install guide | Documentation | Slack Community

GPJax aims to provide a low-level interface to Gaussian process (GP) models in JAX, structured to give researchers maximum flexibility to extend the code to suit their own needs. A GP prior is defined in GPJax by specifying a mean function and a kernel; multiplying this prior by a likelihood then constructs the posterior. The idea is that the code should be as close as possible to the mathematics we write on paper when working with GP models.
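In Bayes' theorem terms, this construction mirrors the unnormalised posterior over a latent function f given observations y,

p(f | y) ∝ p(y | f) p(f),

so writing posterior = prior * likelihood in code expresses the same product we write on paper.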

Package support

GPJax was created by Thomas Pinder. Today, the maintenance of GPJax is undertaken by Thomas and Daniel Dodd.

We would be delighted to review pull requests (PRs) from new contributors. Before contributing, please read our guide for contributing. If you do not have the capacity to open a PR, or you would like guidance on how best to structure a PR, then please open an issue. For broader discussions on best practices for fitting GPs, technical questions surrounding the mathematics of GPs, or anything else that you feel doesn't quite constitute an issue, please start a discussion thread in our discussion tracker.

We have recently set up a Slack channel where we hope to facilitate discussions around the development of GPJax and broader support for Gaussian process modelling. If you'd like to join the channel, then please follow this invitation link, which will take you to the GPJax Slack community.

Supported methods and interfaces

Examples

Guides for customisation

Simple example

This simple regression example aims to illustrate how closely GPJax's API resembles the way we write the mathematics of Gaussian processes.

After importing the necessary dependencies, we'll simulate some data.

import gpjax as gpx
from jax import grad, jit
import jax.numpy as jnp
import jax.random as jr
import optax as ox

key = jr.PRNGKey(123)

# 50 inputs drawn uniformly from [-3, 3], sorted and shaped as a column vector
x = jr.uniform(key=key, minval=-3.0, maxval=3.0, shape=(50,)).sort().reshape(-1, 1)
# Noisy observations of a sinusoid
y = jnp.sin(x) + jr.normal(key, shape=x.shape) * 0.05
training = gpx.Dataset(X=x, y=y)

The function of interest here is sinusoidal, but our observations of it have been perturbed by independent zero-mean Gaussian noise. We aim to use a Gaussian process to recover this latent function.
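For a quick visual check before modelling, one can plot the simulated data. A minimal sketch, assuming matplotlib is installed (it is not a GPJax requirement shown here):

import matplotlib.pyplot as plt

# Noisy observations against the noiseless latent function
plt.scatter(x, y, s=10, label="noisy observations")
plt.plot(x, jnp.sin(x), color="tab:orange", label="latent function")
plt.legend()
plt.show()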

We begin by defining a zero-mean Gaussian process prior with a radial basis function kernel and assume the likelihood to be Gaussian.

prior = gpx.Prior(kernel=gpx.RBF())
likelihood = gpx.Gaussian(num_datapoints=x.shape[0])

The posterior is then constructed as the product of our prior and our likelihood.

posterior = prior * likelihood

Equipped with the posterior, we proceed to train the model's hyperparameters through gradient-based optimisation of the marginal log-likelihood.

We begin by defining a set of initial parameter values through the initialise callable.

parameter_state = gpx.initialise(posterior, key=key)
params, trainables, constrainer, unconstrainer = parameter_state.unpack()
# Map the parameters into an unconstrained space for optimisation
params = gpx.transform(params, unconstrainer)

Next, we define the marginal log-likelihood, applying JAX's just-in-time (JIT) compilation to accelerate training. Notice that this is the first point at which data enters our model. This mirrors how model building works in principle: we first define the prior model, then observe some data and use it to construct the posterior.

mll = jit(posterior.marginal_log_likelihood(training, constrainer, negative=True))
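As a quick sanity check, the objective can be evaluated at the current unconstrained parameters; because we set negative=True, smaller values indicate a better fit:

mll(params)  # scalar negative marginal log-likelihood at the initial parameters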

Finally, we use Optax's Adam optimiser and run an optimisation loop.

opt = ox.adam(learning_rate=0.01)
opt_state = opt.init(params)

for _ in range(100):
  # Gradient of the (negative) marginal log-likelihood w.r.t. the parameters
  grads = grad(mll)(params)
  # Optax returns the parameter updates and the new optimiser state
  updates, opt_state = opt.update(grads, opt_state)
  params = ox.apply_updates(params, updates)
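As an aside, the loop above can also be compiled end to end. A minimal sketch using jax.lax.scan, assuming the mll, opt and params objects defined above:

from jax import lax

def step(carry, _):
  # One optimisation step: compute gradients and apply the Adam update
  params, opt_state = carry
  grads = grad(mll)(params)
  updates, opt_state = opt.update(grads, opt_state)
  params = ox.apply_updates(params, updates)
  return (params, opt_state), None

(params, opt_state), _ = lax.scan(step, (params, opt_state), None, length=100)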

Now that our parameters are optimised, we transform them back to their original constrained space. Using the learned values, we can obtain the posterior distribution of the latent function at novel test points.

final_params = gpx.transform(params, constrainer)

xtest = jnp.linspace(-3.0, 3.0, 100).reshape(-1, 1)

# Distribution of the latent function at the test inputs
latent_distribution = posterior(training, final_params)(xtest)
# Push the latent distribution through the likelihood to obtain predictions
predictive_distribution = likelihood(latent_distribution, final_params)

predictive_mean = predictive_distribution.mean()
predictive_stddev = predictive_distribution.stddev()
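To visualise the fit, one option is to plot the predictive mean with a two-standard-deviation band. Again a matplotlib sketch rather than part of GPJax itself:

import matplotlib.pyplot as plt

mu = predictive_mean.squeeze()
sigma = predictive_stddev.squeeze()

plt.scatter(x, y, s=10, label="observations")
plt.plot(xtest.squeeze(), mu, label="predictive mean")
plt.fill_between(xtest.squeeze(), mu - 2 * sigma, mu + 2 * sigma, alpha=0.3, label="two stddev")
plt.legend()
plt.show()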

Installation

Stable version

To install the latest stable version of GPJax run

pip install gpjax

Development version

To install the latest (possibly unstable) version, follow the steps below. It is by no means compulsory, but we advise that you do all of the below inside a virtual environment.

git clone https://github.com/thomaspinder/GPJax.git
cd GPJax
python setup.py develop
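Note that setup.py develop is deprecated in recent versions of setuptools; an editable pip install should be an equivalent alternative:

pip install -e .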

We then recommend you check your installation using the supplied unit tests.

python -m pytest tests/

Citing GPJax

If you use GPJax in your research, please cite our JOSS paper. Sample BibTeX is given below:

@article{Pinder2022,
  doi = {10.21105/joss.04455},
  url = {https://doi.org/10.21105/joss.04455},
  year = {2022},
  publisher = {The Open Journal},
  volume = {7},
  number = {75},
  pages = {4455},
  author = {Thomas Pinder and Daniel Dodd},
  title = {GPJax: A Gaussian Process Framework in JAX},
  journal = {Journal of Open Source Software}
}