Quickstart

This notebook is designed to showcase the primary features of `jaxdf`. Throughout the notebook, we will use several libraries from the `jax` ecosystem, including `optax` and `diffrax`. Although these libraries are not necessary for using `jaxdf`, our goal is to demonstrate `jaxdf`'s compatibility with generic `jax`-based libraries.
Let's begin by downloading an image that we will use in this notebook.
!wget -O test_img.jpg https://upload.wikimedia.org/wikipedia/commons/a/a0/Black_and_White_1_-_Augusto_De_Luca_photographer.jpg >/dev/null 2>&1
!pip install opencv-python # Run if opencv is not installed in your system
!pip install matplotlib # Run if matplotlib is not installed in your system
from jax import config
config.update("jax_enable_x64", True)
import cv2
from matplotlib import pyplot as plt
import numpy as np
img = cv2.imread("test_img.jpg")
img = cv2.resize(img, (256, 256), interpolation=cv2.INTER_AREA)
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY).astype(np.float32) / 255.
plt.imshow(img, cmap="gray")
plt.colorbar()
plt.show()
Heat Equation
For illustration, we'll simulate the heat equation.
$$ \frac{\partial}{\partial t}u = \nabla^2 u $$
NOTE: We don't enforce specific boundary conditions here, as they are not yet implemented. At present, boundary conditions are defined implicitly by the padding used in the convolutional operators: both the Finite Differences and Fourier discretizations pad periodically, which amounts to periodic boundary conditions. While this isn't ideal for properly integrating the heat equation, it works whenever some form of absorbing layer is used at the boundary, which is often the case in acoustics.
Obviously, more general handling of boundary conditions would be a valuable addition to the package. Contributions are welcome 😊
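To make the periodic-padding behaviour concrete, here is a minimal standalone sketch in plain `jax.numpy` (not `jaxdf` internals; the padding width is illustrative): wrapping the array before convolving is what makes a stencil see the opposite edge of the domain.
import jax.numpy as jnp
x = jnp.arange(6.0)
# "wrap" padding repeats the opposite edge, i.e. periodic boundary conditions
print(jnp.pad(x, 2, mode="wrap"))  # [4. 5. 0. 1. 2. 3. 4. 5. 0. 1.]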
We first define the domain where fields live. Next, we specify the discretization family that will represent the inputs to the operator. In this case, we use Finite Differences, as given by the `FiniteDifferences` discretization.
from jaxdf.geometry import Domain
# Setting the domain.
domain = Domain(N=img.shape, dx=(1., 1.))
from jax import numpy as jnp
import numpy as np
from jaxdf.discretization import FiniteDifferences, OnGrid
# We define the grid values from the image, making sure they are a jax.numpy array of floats
grid_values = jnp.asarray(img, dtype=np.float32)
# Then the field is defined by the combination of grid values and domain
u = FiniteDifferences.from_grid(grid_values, domain)
from matplotlib import pyplot as plt
# We can get the field values on the domain grid using the .on_grid property
plt.imshow(u.on_grid, cmap="gray")
Now, let's define the right-hand side of the heat equation. Since the only operator required is the Laplacian, we can use `operators.laplacian` to compute it.
from jaxdf.operators.differential import laplacian
from jax import jit
# Make RHS operator
@jit # We can jit the entire function, `Field`s from jaxdf are compatible with jax
def heat_rhs(u):
    return laplacian(u)
# Apply the RHS operator to the field
z = heat_rhs(u)
# Look at the output of the laplacian
plt.figure(figsize=(8,6))
plt.imshow(z.on_grid, cmap='RdBu'); plt.colorbar()
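Before handing the integration over to a dedicated solver, note that `jaxdf` fields support arithmetic, so we can already take a few explicit Euler steps by hand. This is only a quick sketch (the step size 0.1 and the 50 steps are arbitrary choices): the image visibly blurs as heat diffuses.
# Explicit Euler sketch: u_{n+1} = u_n + dt * laplacian(u_n)
u_euler = u
for _ in range(50):
    u_euler = u_euler + 0.1 * heat_rhs(u_euler)
plt.imshow(u_euler.on_grid, cmap="gray")
plt.colorbar()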
Integration using diffrax

In this section, we will integrate our heat equation using the `diffrax` library. `diffrax` offers a suite of differential equation solvers, making it ideal for our needs.
1. Setting up the PDE: first, we define our PDE field, which in this case is the Laplacian of $u$.
2. Setting the solver and time steps:
    - We choose `Tsit5` as our solver.
    - We set our initial time `t0` to 0 and the final time `t_final` to 20.
    - The time step `dt` is set to 0.1.
    - Additionally, with the `SaveAt` feature, we save the results at evenly spaced time points.
3. Solving the PDE: using `dfx.diffeqsolve`, we pass in our term, the chosen solver, the time configuration, and the initial field `u` to obtain the solution.

Let's dive into the code:
!pip install diffrax # Install this for numerical integration in jax
# Integrate with diffrax
import diffrax as dfx
def pde_field(t, u, args):
    return laplacian(u)
term = dfx.ODETerm(pde_field)
t0 = 0
t_final = 20
dt = 0.1
saveat = dfx.SaveAt(ts=jnp.linspace(t0, t_final, 10))
solver = dfx.Tsit5()
sol = dfx.diffeqsolve(term, solver, t0, t_final, dt, u, saveat=saveat)
snapshots = sol.ys
# Plot the solutions on a grid
fig, ax = plt.subplots(2, 5, figsize=(10,4))
# Flatten the axes
ax = ax.flatten()
for i in range(10):
    ax[i].imshow(snapshots[i].on_grid, vmin=0, vmax=1.0, cmap="gray")
    ax[i].set_title(f"t={sol.ts[i]:.2f}")
    ax[i].set_xticks([])
    ax[i].set_yticks([])
Anisotropic Diffusion example
Anisotropic diffusion is a more sophisticated form of heat diffusion where the diffusion rate can vary in different directions and magnitudes, often dependent on the underlying structure of the data. This is particularly useful in applications like image processing, where it can help preserve edges while diffusing noise.
Let's see how we can leverage `jaxdf` to define and solve an anisotropic diffusion problem.
The anisotropic diffusion equation can be represented as:
$$ \frac{\partial u}{\partial t} = \nabla \cdot (c(u) \nabla u) $$
Where:
- $u$ is the field of interest.
- $c(u)$ is the diffusion conductivity, which varies with the gradient magnitude of $u$.
To translate this to code:

1. Divergence operator: we start by defining the divergence of a vector field, `divergence(u, stagger)`. This computes the rate at which density exits at each point, which is essential for understanding how diffusion spreads. Note that we are using staggered gradients, a feature of `OnGrid` fields.
2. Diffusion conductivity: `conductivity_kernel(u)` computes the conductivity based on the magnitude of the gradient of $u$. The kernel ensures that areas with high gradients (like edges) have lower conductivity, preserving features (see the sketch after this list).
3. Gradient magnitude: `norm(u)` computes the magnitude of the gradient, which is needed to determine the diffusion rate.
4. Complete anisotropic diffusion: `anisotropic_diffusion(t, u, args)` combines the above functions to compute the divergence of the conductivity-weighted gradient of $u$, yielding the right-hand side of our PDE.
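To make the edge-preserving behaviour concrete, we can plot the conductivity function used below, a Perona-Malik-style kernel $c(x) = 1/(1 + (x/0.03)^2)$; the constant 0.03 matches the kernel defined in the next cell.
# Plot the conductivity as a function of the gradient magnitude:
# high gradients (edges) get low conductivity, so they diffuse slowly
x_mag = jnp.linspace(0, 0.3, 200)
plt.plot(x_mag, 1 / (1 + (x_mag / 0.03) ** 2))
plt.xlabel("gradient magnitude")
plt.ylabel("conductivity c")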
from jaxdf.operators.differential import laplacian, gradient, diag_jacobian
from jaxdf.operators.functions import compose, sum_over_dims
# What if we want to use anisotropic diffusion?
def divergence(u, stagger): # Defining the divergence operator
    return sum_over_dims(diag_jacobian(u, stagger=stagger))

def conductivity_kernel(u): # Defining the diffusion kernel
    kernel = lambda x: 1/(1 + (x/0.03)**2)
    return compose(u)(kernel)

def norm(u):
    z = sum_over_dims(u**2)
    return compose(z)(jnp.sqrt)

@jit
def anisotropic_diffusion(t, u, args):
    grad_u = gradient(u, stagger=[0.5])
    mod_gradient = norm(grad_u)
    c = conductivity_kernel(mod_gradient)
    return divergence(c * grad_u, stagger=[-0.5])
# Plot the effect of the kernel
z = anisotropic_diffusion(0, u, None)
plt.imshow(z.on_grid, cmap="RdBu")
plt.colorbar()
plt.savefig("anisotropic_kernel.png")
plt.close()
term = dfx.ODETerm(anisotropic_diffusion)
sol = dfx.diffeqsolve(term, solver, t0, t_final, dt, u, saveat=saveat)
snapshots = sol.ys
# Plot the solutions on a grid
fig, ax = plt.subplots(2, 5, figsize=(10,4))
# Flatten the axes
ax = ax.flatten()
for i in range(10):
    ax[i].imshow(snapshots[i].on_grid, vmin=0, vmax=1.0, cmap="gray")
    ax[i].set_title(f"t={sol.ts[i]:.2f}")
    ax[i].set_xticks([])
    ax[i].set_yticks([])
By constructing the anisotropic diffusion operator with `jaxdf`, we've highlighted several features of the library:

- Composition: the ability to define complex differential operators, such as the divergence, the gradient, and a custom conductivity kernel, in terms of other operators.
- Performance: using `@jit` to just-in-time compile our functions, ensuring optimal performance during execution.
Customizing the Discretization

Exploring different discretizations is a powerful feature of `jaxdf`. For instance, to move from a `FiniteDifferences` to a `FourierSeries` discretization, the transition is seamless:

- Define a new field: initialize your field using the desired `FourierSeries` discretization.
- Invoke the operator: simply call the previously defined operator on the newly discretized field.

This flexibility ensures that researchers and developers can effortlessly experiment with various discretizations, often without restructuring their primary operators.
from jaxdf.discretization import FourierSeries
u_f = FourierSeries.from_grid(u.on_grid, domain)
# We can reuse the previous code!
sol = dfx.diffeqsolve(term, solver, t0, t_final, dt, u_f, saveat=saveat)
snapshots_fourier = sol.ys
# Plot the solutions on a grid
fig, ax = plt.subplots(1, 3, figsize=(10, 3))
ax = ax.flatten()
ax[0].imshow(snapshots[-1].on_grid, vmin=0, vmax=0.8, cmap="gray")
ax[0].set_title(f"FiniteDifferences")
ax[1].imshow(snapshots_fourier[-1].on_grid, vmin=0, vmax=0.8, cmap="gray")
ax[1].set_title(f"FourierSeries")
difference = np.abs(snapshots_fourier[-1].on_grid - snapshots[-1].on_grid)
ax[2].imshow(difference, vmin=0, vmax=0.1, cmap="inferno")
ax[2].set_title(f"Abs. difference")
for i in range(3):
    ax[i].set_xticks([])
    ax[i].set_yticks([])
Automatic Differentiation with `jaxdf`

One of the significant strengths of the `jaxdf` framework is its seamless integration with `jax`'s automatic differentiation capabilities. By leveraging `jax`'s native transformations, users can efficiently compute gradients with respect to the arguments of operators.
from jax import value_and_grad
@value_and_grad
def loss_fn(u: FourierSeries):
    y = compose(2*jnp.pi*u)(jnp.sin)
    z = laplacian(y)
    return jnp.mean(z.on_grid**2)
lossval, z = loss_fn(u_f)
# The gradient is once again a Field
plt.imshow(z.on_grid, cmap="RdBu" , vmin=-0.01, vmax=0.01)
plt.colorbar()
print(f"Loss: {lossval}")
Loss: 1.5509734922469067
Handling Operator Parameters

Operators in computational methods often come with associated parameters that play a critical role in their computation. In `jaxdf`, these parameters can be discretization-dependent, adapting to the specific method in use.

For instance, the `laplacian` operator:

- In the context of `FiniteDifferences`, it is implemented using a stencil.
- For `FourierSeries`, it relies on transformations along the `k`-axis (the spatial frequency axis).

The `.default_params` method provides insight into these parameters: when invoked with the same arguments as the operator, it returns the default parameters it uses.
# Get the operator parameters for a FiniteDifferences field
# (remember that type(u) = FiniteDifferences)
fd_stencil = laplacian.default_params(u)
plt.imshow(fd_stencil, vmin=-1, vmax=1, cmap="RdBu")
plt.colorbar()
plt.axis('off')
plt.title("Laplacian stencil for FiniteDifferences fields")
# Get the operator parameters for a FourierSeries field
# (remember that type(u) = FourierSeries)
f_params = laplacian.default_params(u_f)
f_params.keys()
dict_keys(['k_vec'])
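The `k_vec` entry presumably holds the spatial frequency axis mentioned above. As a standalone illustration of the idea in plain `jnp` (a sketch, not `jaxdf`'s actual implementation), a 1D spectral Laplacian multiplies the Fourier transform of the signal by $-k^2$:
# 1D spectral Laplacian sketch: d^2/dx^2 becomes multiplication by -k^2
N, dx = 256, 1.0
k = 2 * jnp.pi * jnp.fft.fftfreq(N, d=dx)  # spatial frequency axis
x_axis = jnp.arange(N) * dx
f = jnp.sin(2 * jnp.pi * 5 * x_axis / (N * dx))  # a periodic test signal
lap_f = jnp.real(jnp.fft.ifft(-(k**2) * jnp.fft.fft(f)))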
By design, these parameters are statically compiled into the XLA function. However, `jaxdf` offers flexibility for scenarios such as when:

- Parameters have been manually adjusted.
- Parameters are sizeable, making static compilation inefficient.
- Different operators share the same parameters, as seen in Fourier methods, and users want to use a single consistent variable across all of them.

To accommodate these scenarios, every operator in `jaxdf` has a `params` keyword. This reserved keyword allows users to override the default parameters, giving them precise control over the computation.
# Get the default parameters
stencil = laplacian.default_params(u)
# Change the middle element of the stencil
# (if your jaxdf version returns an immutable jax array here, use
#  `stencil = stencil.at[5, 5].set(-10.0)` instead)
stencil[5, 5] = -10.0
# Call the laplacian with the new parameters
z = laplacian(u, params=stencil)
# Check the result
plt.imshow(z.on_grid)
plt.colorbar()
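The same mechanism covers the shared-parameters scenario mentioned above: for Fourier methods we can fetch the parameters once and reuse them across calls. A minimal sketch (assuming the field stays on the same domain, so the same `k_vec` applies):
# Reuse the same Fourier-space parameters across repeated applications
fourier_params = laplacian.default_params(u_f)
z1 = laplacian(u_f, params=fourier_params)
z2 = laplacian(z1, params=fourier_params)  # e.g. a double (biharmonic-like) application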
`jaxdf` doesn't stop at providing access to parameters: users can also apply functional transformations directly to them. To illustrate this, let's work through an example in which we optimize the stencil of the `laplacian` operator for a `FiniteDifferences` field, to align its output more closely with the `FourierSeries` version.
stencil = laplacian.default_params(u)
@value_and_grad
def loss_fn(stencil, u: FiniteDifferences, v: FourierSeries):
    z_fd = laplacian(u, params=stencil)  # Note the explicit parameters being passed
    z_fs = laplacian(v)
    sse = jnp.sum(((z_fd - z_fs)**2).on_grid)  # sum of squared errors
    return sse
lossval, stencil_grad = loss_fn(stencil, u, u_f)
!pip install optax # Install optax for optimization in jax
import optax
optimizer = optax.adam(0.02)
opt_state = optimizer.init(stencil)
new_stencil = stencil
@jit
def step(stencil, opt_state, u, v):
    lossval, stencil_grad = loss_fn(stencil, u, v)
    updates, opt_state = optimizer.update(stencil_grad, opt_state)
    stencil = optax.apply_updates(stencil, updates)
    return lossval, stencil, opt_state
for i in range(4001):
    lossval, new_stencil, opt_state = step(new_stencil, opt_state, u, u_f)
    if i % 500 == 0:
        print(f"Step: {i} - Loss {lossval}")
Step: 0 - Loss 142.6593552085341
Step: 500 - Loss 74.48526520856781
Step: 1000 - Loss 39.88113557681022
Step: 1500 - Loss 20.86625953473459
Step: 2000 - Loss 13.724514375728456
Step: 2500 - Loss 10.244072447332503
Step: 3000 - Loss 8.509387549827993
Step: 3500 - Loss 7.6116108961437625
Step: 4000 - Loss 8.342067758718336
# Check the performance
z_FD = laplacian(u)
z_FS = laplacian(u_f)
z_opt = laplacian(u, params = new_stencil)
fd_diff = jnp.abs((z_FD - z_FS).on_grid)
opt_diff = jnp.abs((z_opt - z_FS).on_grid)
fig, ax = plt.subplots(2, 3, figsize=(10, 6))
ax = ax.flatten()
ax[0].imshow(z_FS.on_grid, vmin=-1, vmax=1, cmap="RdBu")
ax[0].set_title(f"FourierSeries")
ax[1].imshow(z_FD.on_grid, vmin=-1, vmax=1, cmap="RdBu")
ax[1].set_title(f"FiniteDifferences")
ax[2].imshow(fd_diff, vmin=0, vmax=0.3, cmap="inferno")
ax[2].set_title(f"FiniteDifferences Error")
ax[3].imshow(new_stencil, vmin=-1, vmax=1, cmap="RdBu")
ax[3].set_title(f"Optimized Stencil")
ax[4].imshow(z_opt.on_grid, vmin=-1, vmax=1, cmap="RdBu")
ax[4].set_title(f"Optimized")
ax[5].imshow(opt_diff, vmin=0, vmax=0.3, cmap="inferno")
ax[5].set_title(f"Optimized Error")
for i in range(6):
    ax[i].set_xticks([])
    ax[i].set_yticks([])
More complex operators

We can of course construct operators more complex than the heat equation, using composition. Beyond the core functionality of operators, `jaxdf` grants direct access to the numerical parameters of the fields. This feature is invaluable for scenarios requiring fine-tuned control or experimentation with parameters.
from jaxdf.operators.functions import compose
from jaxdf import Field
from jax.nn import relu
from jax import random, lax
seed = random.PRNGKey(42)
cnn_kernel = random.normal(seed, (1,1,3,3))
def silly_cnn(x: jnp.ndarray, kernel: jnp.ndarray):
    # Reshape (H, W, C) -> (1, C, H, W) for lax.conv
    x = jnp.moveaxis(x, -1, 0)
    x = jnp.expand_dims(x, 0)
    out_conv = lax.conv(x, kernel, (1, 1), padding='SAME')
    out_conv = relu(out_conv)[0]
    # Back to (H, W, C)
    out = jnp.moveaxis(out_conv, 0, -1)
    return out

@jit
def f(u: Field):
    L = laplacian(u)
    field_on_grid = L.on_grid # Represent the field as a standard jnp array
    new_grid_values = silly_cnn(field_on_grid, cnn_kernel) # Manipulate its values with a neural network
    p = FiniteDifferences.from_grid(new_grid_values, u.domain) # Generate a new FiniteDifferences field
    # Apply jaxdf operators again
    p = compose(p)(jnp.sin)
    return 0.01*p
z = f(u)
print(z)
FiniteDifferences( params=f64[256,256,1], domain=Domain(N=(256, 256), dx=(1.0, 1.0)), accuracy=8 )
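Since every piece of this hybrid operator is `jax`-compatible, we can also differentiate through it. A minimal sketch follows; the scalar loss below is an arbitrary choice for illustration:
from jax import grad

def kernel_loss(kernel):
    L = laplacian(u)
    out = silly_cnn(L.on_grid, kernel)  # mix a jaxdf operator with the CNN
    return jnp.mean(out**2)

# Gradient of the loss with respect to the convolution kernel
kernel_grad = grad(kernel_loss)(cnn_kernel)
print(kernel_grad.shape)  # (1, 1, 3, 3)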