j-Wave
Fast and differentiable acoustic simulations in JAX
j-Wave is a customizable Python simulator, written on top of the JAX library and the discretization framework JaxDF, designed for fast, parallelizable, and differentiable acoustic simulations.
j-Wave solves both time-varying and time-harmonic forms of the wave equation, with support for multiple discretizations, including finite differences and Fourier spectral methods. Custom discretizations, including those based on neural networks, can also be used through the JaxDF framework.
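The following minimal sketch contrasts the two discretization families named in the previous paragraph. It is written in plain JAX and does not use j-Wave's API. It differentiates sin(x) on a periodic grid, once with a second-order central finite difference and once with a Fourier spectral method. The function names `finite_difference_dx` and `spectral_dx` and the grid size are hypothetical choices for illustration.

```python
import jax.numpy as jnp

def finite_difference_dx(u, dx):
    # Second-order central difference on a periodic grid:
    # (u[i+1] - u[i-1]) / (2 dx), with wrap-around via jnp.roll
    return (jnp.roll(u, -1) - jnp.roll(u, 1)) / (2.0 * dx)

def spectral_dx(u, dx):
    # Fourier spectral derivative: multiply each Fourier coefficient by i*k
    n = u.shape[0]
    k = 2.0 * jnp.pi * jnp.fft.fftfreq(n, d=dx)  # angular wavenumbers
    return jnp.real(jnp.fft.ifft(1j * k * jnp.fft.fft(u)))

n = 32                       # grid points on the periodic domain [0, 2*pi)
dx = 2.0 * jnp.pi / n        # grid spacing
x = jnp.arange(n) * dx
u = jnp.sin(x)
exact = jnp.cos(x)           # analytic derivative of sin(x)

fd_err = jnp.max(jnp.abs(finite_difference_dx(u, dx) - exact))
sp_err = jnp.max(jnp.abs(spectral_dx(u, dx) - exact))
print(f"finite difference max error: {float(fd_err):.2e}")
print(f"spectral max error:          {float(sp_err):.2e}")
```

For smooth periodic fields like this one, the spectral derivative is accurate to roughly float32 round-off, while the central-difference error shrinks with the square of the grid spacing. This accuracy is one reason spectral methods can often use coarser grids for the same error.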
The use of the JAX library provides direct support for program transformations, such as automatic differentiation, Single-Program Multiple-Data (SPMD) parallelism, and just-in-time compilation.
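These transformations can be illustrated on a toy problem. The sketch below is written in plain JAX and does not use j-Wave's API. It assumes a 1D periodic domain, a second-order finite-difference Laplacian, and a leapfrog time stepper; all names and parameters are hypothetical. `jax.jit` compiles the solver, `jax.grad` differentiates a data-mismatch loss with respect to the sound speed `c`, and `jax.vmap` evaluates the loss for several sound speeds at once. Note that `jax.vmap` is single-device vectorization. SPMD parallelism across devices would use `jax.pmap` or sharding instead, which is not shown here.

```python
import jax
import jax.numpy as jnp

# Toy 1D problem: all names and parameters are illustrative, not j-Wave API.
N = 128          # number of grid points on the periodic domain [0, 1)
DX = 1.0 / N     # grid spacing
DT = 0.5 * DX    # time step (CFL number 0.5 for c = 1)
STEPS = 60       # number of time steps

def laplacian(u):
    # Second-order central difference for d^2u/dx^2 with periodic boundaries
    return (jnp.roll(u, -1) - 2.0 * u + jnp.roll(u, 1)) / DX**2

def simulate(c, u0):
    # Leapfrog update: u[n+1] = 2 u[n] - u[n-1] + (c dt)^2 * lap(u[n]).
    # Starting with u[-1] = u[0] approximates zero initial velocity.
    def step(carry, _):
        u_prev, u = carry
        u_next = 2.0 * u - u_prev + (c * DT) ** 2 * laplacian(u)
        return (u, u_next), None

    (_, u_final), _ = jax.lax.scan(step, (u0, u0), None, length=STEPS)
    return u_final

x = jnp.arange(N) * DX
u0 = jnp.exp(-((x - 0.5) ** 2) / 0.005)   # Gaussian initial pressure
target = simulate(1.2, u0)                # synthetic "measurement" with c = 1.2

@jax.jit
def loss(c):
    # Squared mismatch between the simulated field and the target field
    return jnp.sum((simulate(c, u0) - target) ** 2) * DX

grad_loss = jax.jit(jax.grad(loss))   # reverse-mode derivative d(loss)/dc
g = grad_loss(1.0)

# vmap evaluates the loss for a batch of sound speeds on a single device
losses = jax.vmap(loss)(jnp.array([0.9, 1.0, 1.1, 1.2]))
print(float(g), losses)
```

Because the time loop is expressed with `jax.lax.scan`, it compiles to a single XLA program, and reverse-mode differentiation through it gives the gradient of the loss with respect to `c`. This is the kind of quantity used in inverse problems, such as estimating a sound speed from measurements.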
Following the philosophy of JAX, j-Wave is developed with the following principles in mind:
- Fully differentiable
- Fast through hardware-specific jit compilation
- Easy to run on GPUs and TPUs
- Easily customizable to support novel research ideas, including novel discretizations via jaxdf