jax-boltzmann — A CMB Boltzmann Solver in Pure JAX


Initial public release (alpha). Repository: github.com/anovickis/jax-boltzmann.
Validation against CAMB and CLASS is ongoing; feedback and pull
requests are welcome.


What this is

jax-boltzmann is a CMB power-spectrum solver written
entirely in JAX. It computes the angular power spectra C_l^TT, C_l^EE,
and C_l^TE for flat ΛCDM plus a free ΔN_eff — the same job CAMB
does, but with a few key differences:

  • Pure JAX, single file. ~1,500 lines of Python.
    (CAMB is ~50,000 lines of Fortran.)
  • GPU-accelerated. JIT-compiled with
    jax.jit; runs the whole pipeline on a GPU when one is
    available.
  • Auto-differentiable. Every cosmological parameter
    has an exact analytic Jacobian via jacfwd — no finite
    differences, no parameter sweep.
  • MIT-licensed. No restrictions.
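As a sketch of the basic workflow: JIT-compile the solver once, then call it on a parameter vector. The stand-in solver body below is a toy so the snippet runs on its own, and the parameter ordering is inferred from the example later in this post — neither is the package's actual implementation.

```python
import jax
import jax.numpy as jnp

# Toy stand-in so this sketch is self-contained; in practice:
#   from jax_boltzmann import compute_cls_array
def compute_cls_array(theta):
    # Placeholder spectrum shape only, NOT the real physics
    ell = jnp.arange(2, 2501, dtype=jnp.float32)
    return theta[3] * 1e9 * jnp.exp(-ell * 1e-3)

# Planck 2018 best-fit values; assumed ordering:
# H0, Ωb h², Ωc h², A_s, n_s, τ, N_eff
theta = jnp.array([67.36, 0.02237, 0.1200, 2.1e-9, 0.9649, 0.0544, 3.046])

# jit compiles the whole pipeline once; later calls skip compilation
# and run on a GPU when one is available
cls = jax.jit(compute_cls_array)(theta)   # one value per multipole
```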

Why bother — CAMB already exists

Two reasons.

1. Auto-differentiable Fisher forecasts. When you
compute a Fisher matrix for a cosmological survey, you need the Jacobian
dC_l / dθ for every parameter θ. With CAMB, you compute it numerically —
perturb each parameter by a small ε, re-run the solver, take a finite
difference. That’s slow, noisy at low ε, and biased at high ε.

With JAX, the Jacobian is exact and comes for free:

from jax import jacfwd
import jax.numpy as jnp
from jax_boltzmann import compute_cls_array

# Planck 2018 best fit: H0, Ωb h², Ωc h², A_s, n_s, τ, N_eff (assumed ordering)
theta = jnp.array([67.36, 0.02237, 0.1200, 2.1e-9, 0.9649, 0.0544, 3.046])

# Exact dC_l/dθ for all 7 parameters simultaneously
J = jacfwd(compute_cls_array)(theta)

That’s the entire derivative computation. One line. Exact to machine
precision.

2. GPU parameter scans. When CAMB runs on a CPU,
every parameter point is a fresh process: parse, allocate, integrate,
write. With a JIT-compiled JAX kernel, the second call onwards skips
compilation and runs the whole pipeline on-GPU. Internal benchmarking
shows ~100× speedup for parameter scans of a few hundred points or
more.

For MCMC chains and grid scans common in cosmology, that’s the
difference between an overnight run and a coffee break.
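The scan pattern fits in a few lines. Batching the solver with `vmap` is my assumption about how you would drive it (not a documented API), and the stand-in body is a toy so the sketch runs on its own.

```python
import jax
import jax.numpy as jnp

# Toy stand-in so the sketch is self-contained; in practice:
#   from jax_boltzmann import compute_cls_array
def compute_cls_array(theta):
    ell = jnp.arange(2, 2501, dtype=jnp.float32)
    return theta[3] * 1e9 * jnp.exp(-ell * 1e-3 / (theta[0] / 67.36))

fiducial = jnp.array([67.36, 0.02237, 0.1200, 2.1e-9, 0.9649, 0.0544, 3.046])

# A 200-point 1-D scan over H0: copy the fiducial vector, vary column 0
thetas = jnp.tile(fiducial, (200, 1)).at[:, 0].set(jnp.linspace(60.0, 75.0, 200))

# vmap batches over parameter points; jit compiles the batched kernel once.
# From the second call onwards the compiled kernel runs entirely on-device.
scan = jax.jit(jax.vmap(compute_cls_array))
cls = scan(thetas)          # shape (200, 2499): one spectrum per point
```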


What’s in v0.1.0 (alpha)

  • Background cosmology: Friedmann equation for H(a), conformal time
    η(a)
  • Recombination: Peebles three-level atom (Saha + ODE for x_e)
  • Perturbation hierarchy: truncated multipole expansion for photons (T
    + E-mode polarisation), baryons, CDM, massless neutrinos
  • Initial conditions: adiabatic super-horizon modes from
    inflation
  • Line-of-sight integration: source function convolved with spherical
    Bessel functions (Seljak-Zaldarriaga)
  • Output: C_l^TT, C_l^EE, C_l^TE up to ℓ_max ≈ 2500
  • Free ΔN_eff: extra effective relativistic degrees of freedom
  • Fisher-matrix example for σ(N_eff) under a Planck-like noise
    model
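The Fisher-matrix example in the last bullet follows directly from the jacfwd Jacobian: for a diagonal per-multipole noise σ_l, F_ij = Σ_l (∂C_l/∂θ_i)(∂C_l/∂θ_j)/σ_l², and the forecast error is σ(θ_i) = √((F⁻¹)_ii). A minimal sketch — the stand-in solver and the unit noise are placeholders, not the package's Planck-like model:

```python
import jax
import jax.numpy as jnp

# Toy stand-in with a full-rank Jacobian so the pipeline runs end-to-end;
# in practice: from jax_boltzmann import compute_cls_array
def compute_cls_array(theta):
    ell = jnp.arange(2, 2501, dtype=jnp.float32)
    return jnp.sin(jnp.outer(ell * 1e-3, jnp.arange(1.0, 8.0))) @ theta

theta0 = jnp.array([67.36, 0.02237, 0.1200, 2.1e-9, 0.9649, 0.0544, 3.046])

J = jax.jacfwd(compute_cls_array)(theta0)   # (n_ell, 7): exact dC_l/dθ
sigma_l = jnp.ones(J.shape[0])              # placeholder unit noise per multipole

# F_ij = Σ_l (dC_l/dθ_i)(dC_l/dθ_j) / σ_l²
W = J / sigma_l[:, None]
F = W.T @ W
forecast = jnp.sqrt(jnp.diag(jnp.linalg.inv(F)))  # marginalised 1σ per parameter
```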

What’s not in yet

  • Massive neutrinos
  • Tensor modes (B-mode polarisation from primordial gravitational
    waves)
  • Non-flat geometries (open / closed universes)
  • Dark-energy equations of state beyond w = −1
  • Lensing reconstruction
  • Boltzmann hierarchy beyond the truncation order needed for ℓ_max ≈
    2500

These are the main features planned to land as the full Fortran-to-JAX
port progresses. The plan is to bring CAMB's complete capability set
into JAX, keeping the auto-diff and GPU advantages along the way.


Status — alpha, validation in progress

The code is public and runs end-to-end on the Planck 2018 best-fit
cosmology. What’s still being closed out:

  • Numerical regression against CAMB and CLASS —
    confirming sub-1% agreement in C_l across the whole multipole range
  • Auto-diff sanity checks — comparing JAX gradients
    to finite-difference gradients on isolated parameters
  • Performance benchmarking on real GPUs vs the ~100×
    target
  • Documentation pass for cosmologists who haven’t
    used JAX before
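The gradient sanity check in the second bullet can be scripted directly: compare one jacfwd column against a central finite difference on an isolated parameter. The stand-in solver and the use of 64-bit mode are assumptions on my part, to keep the comparison clean of float32 rounding.

```python
import jax
jax.config.update("jax_enable_x64", True)   # 64-bit floats for a clean comparison
import jax.numpy as jnp

# Toy stand-in; in practice: from jax_boltzmann import compute_cls_array
def compute_cls_array(theta):
    ell = jnp.arange(2, 2501, dtype=jnp.float64)
    return jnp.sin(jnp.outer(ell * 1e-3, jnp.arange(1.0, 8.0))) @ theta

theta0 = jnp.array([67.36, 0.02237, 0.1200, 2.1e-9, 0.9649, 0.0544, 3.046])
i, eps = 0, 1e-3                            # isolate one parameter (H0 here)

auto = jax.jacfwd(compute_cls_array)(theta0)[:, i]
bump = jnp.zeros_like(theta0).at[i].set(eps)
fd = (compute_cls_array(theta0 + bump) - compute_cls_array(theta0 - bump)) / (2 * eps)

# Relative disagreement, normalised by the largest derivative entry
rel_err = float(jnp.max(jnp.abs(auto - fd)) / jnp.max(jnp.abs(auto)))
```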

The release is tagged v0.1.0 (alpha) to set
expectations. Issues and pull requests at the repo are welcome —
especially regression-failure cases.


Why this fits the consulting work

Most of the consulting on this site is hardware: silicon, boards,
safety. This is on the other end of the stack — pure scientific
computing. They’re connected: the same auto-differentiation discipline
that makes JAX useful for cosmological forecasting also matters for ML
inference silicon (JAX is one of the standard frameworks for the
workloads I’m currently designing chips around). Building a non-trivial
JAX codebase from scratch keeps the framework expertise current.

If you’re working on cosmological data analysis, parameter
forecasting, or scientific computing at the GPU/auto-diff intersection,
the announcement is simple: a fully open-source, modern alternative to
CAMB is on its way.


Get the code

  • Repository: github.com/anovickis/jax-boltzmann
  • License: MIT
  • Install:
    git clone https://github.com/anovickis/jax-boltzmann.git && pip install -e .
  • Issues / PRs welcome. Validation reports and
    edge-case bug reports especially.

Validation results and benchmark numbers will land here on the blog
as they’re confirmed.


Working on cosmological forecasts, JAX scientific computing, or
the silicon side of ML inference? Get in touch.

#JAX #Cosmology #CMB #Boltzmann #CAMB #ScientificComputing #AutoDiff
#GPU #OpenSource