Operator Parameters Guide¶
In jaxdf, every @operator has a params keyword argument that controls its behavior. For example, the Fourier gradient uses frequency vectors, while a finite-difference gradient uses stencil coefficients. These params are:
- Automatically computed if you don't provide them
- Inspectable via .default_params()
- Overridable by passing params=your_values
- Differentiable — jax.grad flows through them
- Any PyTree — including eqx.Module objects (neural networks)
This guide walks through the params system using a real image, building up from basic inspection to training a neural network that improves the accuracy of a finite-difference gradient operator.
Setup¶
Let's load an image and create fields from it, just like in the quickstart tutorial.
import jax
import jax.numpy as jnp
import equinox as eqx
import optax
import numpy as np
from matplotlib import pyplot as plt
import urllib.request, os
from jaxdf import operator, Domain, FourierSeries, FiniteDifferences, OnGrid
from jaxdf.operators import gradient, laplacian, diag_jacobian, sum_over_dims
from jaxdf.util import get_implementations, has_implementation
# Download a test image
if not os.path.exists("test_img.jpg"):
    url = "https://upload.wikimedia.org/wikipedia/commons/a/a0/Black_and_White_1_-_Augusto_De_Luca_photographer.jpg"
    req = urllib.request.Request(url, headers={'User-Agent': 'Mozilla/5.0'})
    with urllib.request.urlopen(req) as response:
        with open("test_img.jpg", 'wb') as f:
            f.write(response.read())
import cv2
img = cv2.imread("test_img.jpg")
img = cv2.resize(img, (128, 128), interpolation=cv2.INTER_AREA)
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY).astype(np.float32) / 255.
plt.figure(figsize=(4, 4))
plt.imshow(img, cmap="gray")
plt.title("Test image")
plt.axis("off")
plt.show()
# Create fields from the image
domain = Domain(N=img.shape, dx=(1.0, 1.0))
grid_values = jnp.asarray(img)
u_fourier = FourierSeries.from_grid(grid_values, domain)
u_fd = FiniteDifferences.from_grid(grid_values, domain)
print(f"Fourier field: {u_fourier}")
print(f"FD field: {u_fd}")
Fourier field: FourierSeries(params=f32[128,128,1], domain=Domain(N=(128, 128), dx=(1.0, 1.0)))
FD field: FiniteDifferences(params=f32[128,128,1], domain=Domain(N=(128, 128), dx=(1.0, 1.0)))
1. Inspecting Implementations and Parameters¶
Each @operator function dispatches based on the type of its first argument. The gradient operator, for example, has separate implementations for FourierSeries (using FFTs), FiniteDifferences (using stencil convolution), and Continuous (using jax.jacfwd).
You can see which types are supported:
print("Implementations of `gradient`:")
for impl in get_implementations(gradient):
    print(f"  {impl}")
print()
print("Does gradient support FourierSeries?", has_implementation(gradient, FourierSeries))
print("Does gradient support FiniteDifferences?", has_implementation(gradient, FiniteDifferences))
Implementations of `gradient`:
('Continuous',)
('FiniteDifferences',)
('FourierSeries',)
Does gradient support FourierSeries? True
Does gradient support FiniteDifferences? True
Each implementation may have different parameters. For example, the Fourier gradient needs frequency vectors, while the FD gradient needs stencil kernels. You can inspect these with .default_params():
# Fourier gradient params: a dict containing frequency vectors
fourier_params = gradient.default_params(u_fourier)
print("Fourier params type:", type(fourier_params).__name__)
print("Keys:", list(fourier_params.keys()))
print("k_vec shapes:", [k.shape for k in fourier_params["k_vec"]])
print()
# FD gradient params: a list of stencil kernels (one per spatial axis)
fd_params = gradient.default_params(u_fd)
print("FD params type:", type(fd_params).__name__)
print("Number of stencil kernels:", len(fd_params))
print("Kernel for axis 0 (shape", fd_params[0].shape, "):")
print(fd_params[0].flatten())
Fourier params type: dict
Keys: ['k_vec']
k_vec shapes: [(65,), (65,)]
FD params type: list
Number of stencil kernels: 2
Kernel for axis 0 (shape (9, 1) ):
[ 0.00357143 -0.03809524  0.2 -0.8 -0.  0.8 -0.2  0.03809524 -0.00357143]
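The printed kernel matches the textbook 8th-order central-difference coefficients for a first derivative. As a standalone sanity check (plain NumPy, independent of jaxdf), we can rebuild it from the exact fractions and verify the moment conditions any first-derivative stencil must satisfy:

```python
import numpy as np

# Textbook 8th-order central-difference coefficients for d/dx (offsets -4..4);
# these match the kernel printed above.
c = np.array([1/280, -4/105, 1/5, -4/5, 0.0, 4/5, -1/5, 4/105, -1/280])
offsets = np.arange(-4, 5)

assert np.allclose(c, -c[::-1])               # antisymmetric
assert np.isclose(c.sum(), 0.0)               # exact on constants
assert np.isclose(np.sum(c * offsets), 1.0)   # unit first moment: f'(x) ~ sum_i c_i f(x + i*h) / h

# An 8th-order stencil differentiates low-degree polynomials exactly; try f(x) = x**3.
h = 0.1
x = np.arange(-10, 10) * h
approx = np.convolve(x**3, c[::-1], mode="valid") / h   # correlate the stencil with the samples
exact = 3 * x[4:-4]**2
assert np.max(np.abs(approx - exact)) < 1e-10
```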
Let's see what the gradient looks like on our image with each discretization:
grad_fourier = gradient(u_fourier)
grad_fd = gradient(u_fd)
fig, axes = plt.subplots(1, 3, figsize=(14, 4))
axes[0].imshow(img, cmap="gray")
axes[0].set_title("Original image")
# Show x-component of gradient (axis 0)
axes[1].imshow(grad_fourier.params[..., 0], cmap="RdBu", vmin=-0.3, vmax=0.3)
axes[1].set_title("Fourier gradient (x)")
axes[2].imshow(grad_fd.params[..., 0], cmap="RdBu", vmin=-0.3, vmax=0.3)
axes[2].set_title("FD gradient (x, accuracy=8)")
for ax in axes:
    ax.axis("off")
plt.tight_layout()
plt.show()
2. Overriding Parameters¶
When you call an operator without params=, it auto-computes the default params. But you can pass your own values to override them.
For example, let's manually modify the FD stencil to see the effect:
# Get the default stencil for FD gradient
default_stencil = gradient.default_params(u_fd)
print("Default stencil (axis 0):", default_stencil[0].flatten())
# Create a simpler stencil: [-1, 0, 1] / 2 (central difference, accuracy=2)
simple_kernel_1d = np.array([-0.5, 0.0, 0.5])
simple_stencil = [
simple_kernel_1d.reshape(-1, 1) / domain.dx[0], # axis 0
simple_kernel_1d.reshape(1, -1) / domain.dx[1], # axis 1
]
# Apply both
grad_default = gradient(u_fd)
grad_simple = gradient(u_fd, params=simple_stencil)
fig, axes = plt.subplots(1, 3, figsize=(14, 4))
axes[0].imshow(img, cmap="gray")
axes[0].set_title("Original")
axes[1].imshow(grad_default.params[..., 0], cmap="RdBu", vmin=-0.3, vmax=0.3)
axes[1].set_title("Default stencil (accuracy=8)")
axes[2].imshow(grad_simple.params[..., 0], cmap="RdBu", vmin=-0.3, vmax=0.3)
axes[2].set_title("Simple stencil [-0.5, 0, 0.5]")
for ax in axes:
    ax.axis("off")
plt.tight_layout()
plt.show()
Default stencil (axis 0): [ 0.00357143 -0.03809524 0.2 -0.8 -0. 0.8 -0.2 0.03809524 -0.00357143]
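The visual difference between the two stencils is subtle on an image, but the accuracy gap is large on smooth data. A quick standalone comparison (plain NumPy, a sine wave rather than the image) makes the trade-off concrete:

```python
import numpy as np

h = 0.1
x = np.arange(0, 200) * h
f = np.sin(x)

c2 = np.array([-0.5, 0.0, 0.5])                                          # accuracy=2
c8 = np.array([1/280, -4/105, 1/5, -4/5, 0, 4/5, -1/5, 4/105, -1/280])   # accuracy=8

# Correlate each stencil with the samples (interior points only).
d2 = np.convolve(f, c2[::-1], mode="valid") / h
d8 = np.convolve(f, c8[::-1], mode="valid") / h

err2 = np.max(np.abs(d2 - np.cos(x[1:-1])))
err8 = np.max(np.abs(d8 - np.cos(x[4:-4])))
# The 8th-order stencil is several orders of magnitude more accurate at the same h.
assert err8 < err2
```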
3. Sharing Parameters Between Operators¶
Different operators sometimes use the same underlying data. For FourierSeries fields, both gradient and diag_jacobian need the same frequency vectors (k_vec). Instead of computing them separately, you can compute once and share:
# Both operators use the same k_vec
shared_k = {"k_vec": u_fourier._freq_axis}
# Use the shared params in a composed computation: laplacian = sum(diag_jacobian(gradient(u)))
g = gradient(u_fourier, params=shared_k)
dj = diag_jacobian(g, params=shared_k)
lap = sum_over_dims(dj)
# Compare with the built-in laplacian
lap_builtin = laplacian(u_fourier)
fig, axes = plt.subplots(1, 3, figsize=(14, 4))
axes[0].imshow(lap.params[..., 0], cmap="RdBu", vmin=-5, vmax=5)
axes[0].set_title("Laplacian (shared params)")
axes[1].imshow(lap_builtin.params[..., 0], cmap="RdBu", vmin=-5, vmax=5)
axes[1].set_title("Built-in laplacian")
axes[2].imshow(img, cmap="gray")
axes[2].set_title("Original")
for ax in axes:
    ax.axis("off")
plt.tight_layout()
plt.show()
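Why can gradient and diag_jacobian share k_vec? For a Fourier series, every spatial derivative is a multiplication by ik in frequency space, so the same frequency vector drives both. A minimal plain-NumPy sketch of that mechanism (an illustration of the math, not the jaxdf implementation):

```python
import numpy as np

n = 128
dx = 2 * np.pi / n
x = np.arange(n) * dx
k = 2 * np.pi * np.fft.fftfreq(n, d=dx)   # the reusable frequency vector

def spectral_derivative(f, k):
    # d/dx in frequency space is multiplication by i*k
    return np.real(np.fft.ifft(1j * k * np.fft.fft(f)))

f = np.sin(x)
df = spectral_derivative(f, k)        # gradient
d2f = spectral_derivative(df, k)      # second derivative, reusing the same k
assert np.allclose(df, np.cos(x), atol=1e-10)
assert np.allclose(d2f, -np.sin(x), atol=1e-10)
```

Computing k once and passing it to every spectral operator avoids redundant work inside composed computations, which is exactly what the shared_k dict above does.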
4. Composed Operators with Nested Parameters¶
When you build an operator that internally calls other operators, you can manage their parameters as a nested dictionary. The init_params function collects params from all sub-operators:
# Define a composed operator: edge detector = |gradient|
def edge_detect_init(u: FourierSeries, *args, **kwargs):
    """Collect params from the sub-operators we'll use."""
    return {
        "gradient": gradient.default_params(u),
    }
@operator(init_params=edge_detect_init)
def edge_detect(u: FourierSeries, *, params=None):
    """Computes edge magnitude from the gradient."""
    g = gradient(u, params=params["gradient"])
    # Magnitude: sqrt(gx^2 + gy^2)
    magnitude = jnp.sqrt(jnp.sum(g.params ** 2, axis=-1, keepdims=True))
    return u.replace_params(magnitude)
# Use it
edges = edge_detect(u_fourier)
# Inspect the nested params tree
p = edge_detect.default_params(u_fourier)
print("Params keys:", list(p.keys()))
print("Nested gradient params keys:", list(p["gradient"].keys()))
fig, axes = plt.subplots(1, 2, figsize=(10, 4))
axes[0].imshow(img, cmap="gray")
axes[0].set_title("Original")
axes[1].imshow(edges.params[..., 0], cmap="hot", vmin=0, vmax=0.5)
axes[1].set_title("Edge detection via composed operator")
for ax in axes:
    ax.axis("off")
plt.tight_layout()
plt.show()
Params keys: ['gradient']
Nested gradient params keys: ['k_vec']
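Since the nested params tree is an ordinary dict, you can override a single sub-operator's parameters while keeping every other default. A small sketch of the pattern (the "smoother" and "sigma" names here are hypothetical, purely for illustration):

```python
# Hypothetical nested params tree mirroring a composed operator's sub-calls.
defaults = {"gradient": {"k_vec": [1.0, 2.0]}, "smoother": {"sigma": 1.0}}

# Override one leaf while keeping every other default (shallow merge per level).
custom = {**defaults, "smoother": {**defaults["smoother"], "sigma": 2.5}}

assert custom["gradient"] is defaults["gradient"]   # untouched sub-tree is shared, not copied
assert custom["smoother"]["sigma"] == 2.5
```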
5. Neural Networks as Operator Parameters¶
Operator params can be any PyTree, including eqx.Module objects. This means you can put an entire neural network as the params of an operator, and jax.grad will flow through it.
Convention: pass the whole eqx.Module, not just its weights. This follows the equinox idiom and gives you eqx.apply_updates, eqx.partition, and clean PyTree structure for free.
Here's a simple example: a learnable pointwise operator on our image.
class PointwiseNet(eqx.Module):
    """A simple learnable pointwise transformation."""
    scale: jax.Array
    bias: jax.Array

    def __call__(self, x):
        return self.scale * x + self.bias
# The init_params returns the entire network
def pointwise_init(u: OnGrid, *args, **kwargs):
    return PointwiseNet(
        scale=jnp.ones(u.params.shape),
        bias=jnp.zeros(u.params.shape),
    )
@operator(init_params=pointwise_init)
def pointwise_op(u: OnGrid, *, params=None):
    """Apply a learnable pointwise transformation. params IS the network."""
    return u.replace_params(params(u.params))
# Default params: identity transform (scale=1, bias=0)
result_default = pointwise_op(u_fourier)
# Custom: contrast enhancement (scale=2, bias=-0.3)
enhancer = PointwiseNet(
scale=jnp.full(u_fourier.params.shape, 2.0),
bias=jnp.full(u_fourier.params.shape, -0.3),
)
result_enhanced = pointwise_op(u_fourier, params=enhancer)
fig, axes = plt.subplots(1, 3, figsize=(14, 4))
axes[0].imshow(img, cmap="gray", vmin=0, vmax=1)
axes[0].set_title("Original")
axes[1].imshow(result_default.params[..., 0], cmap="gray", vmin=0, vmax=1)
axes[1].set_title("Default (identity)")
axes[2].imshow(jnp.clip(result_enhanced.params[..., 0], 0, 1), cmap="gray", vmin=0, vmax=1)
axes[2].set_title("Custom (2x - 0.3)")
for ax in axes:
    ax.axis("off")
plt.tight_layout()
plt.show()
Since params is a PyTree, jax.grad flows through the network automatically:
# Differentiate a loss w.r.t. the network params
def loss(net):
    result = pointwise_op(u_fourier, params=net)
    return jnp.mean(result.params ** 2)
grads = jax.grad(loss)(enhancer)
print("Gradient type:", type(grads).__name__)
print("d(loss)/d(scale) — finite?", bool(jnp.all(jnp.isfinite(grads.scale))))
print("d(loss)/d(bias) — finite?", bool(jnp.all(jnp.isfinite(grads.bias))))
# jit + grad also works
grads_jit = jax.jit(jax.grad(loss))(enhancer)
print("jit+grad matches?", bool(jnp.allclose(grads.scale, grads_jit.scale)))
Gradient type: PointwiseNet
d(loss)/d(scale) — finite? True
d(loss)/d(bias) — finite? True
jit+grad matches? True
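Nothing here is specific to equinox: any PyTree of arrays can serve as params, and jax.grad hands back a gradient with the same structure. A minimal sketch with a plain dict standing in for the eqx.Module:

```python
import jax
import jax.numpy as jnp

x = jnp.linspace(0.0, 1.0, 8)
net = {"scale": jnp.full(8, 2.0), "bias": jnp.full(8, -0.3)}   # dict-as-module

def loss(net):
    y = net["scale"] * x + net["bias"]     # the "forward pass"
    return jnp.mean(y ** 2)

grads = jax.grad(loss)(net)                # same dict structure back

# Analytic check: d/d(scale_i) of mean((s*x + b)^2) is 2*(s_i*x_i + b_i)*x_i / n
expected = 2 * (net["scale"] * x + net["bias"]) * x / 8
assert jnp.allclose(grads["scale"], expected, atol=1e-5)
assert set(grads.keys()) == {"scale", "bias"}
```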
6. Dynamic Parameters: Neural Network-Generated Stencils¶
A more advanced pattern: the operator params contain a neural network that generates the stencil dynamically, based on the input field. The neural network weights are the learnable parameters; the stencil is computed fresh on each call.
This is different from the previous section: there, the network was the operator. Here, the network generates the parameters of a classical operator (a convolution stencil).
                ┌─────────────────┐
field stats ──► │  NN (in params) │ ──► stencil coefficients ──► convolve(field, stencil)
                └─────────────────┘
from jaxdf.conv import reflection_conv
class StencilGenerator(eqx.Module):
    """Neural network that generates a 1D FD stencil from field statistics."""
    linear1: eqx.nn.Linear
    linear2: eqx.nn.Linear
    linear3: eqx.nn.Linear

    def __init__(self, stencil_size, key):
        k1, k2, k3 = jax.random.split(key, 3)
        self.linear1 = eqx.nn.Linear(2, 16, key=k1)
        self.linear2 = eqx.nn.Linear(16, 16, key=k2)
        self.linear3 = eqx.nn.Linear(16, stencil_size, key=k3)

    def __call__(self, stats):
        x = jax.nn.tanh(self.linear1(stats))
        x = jax.nn.tanh(self.linear2(x))
        return self.linear3(x)
@operator
def nn_gradient(u: FiniteDifferences, *, params=None):
    """Gradient using an NN-generated stencil. params is a StencilGenerator."""
    # Compute statistics of the input field
    field_stats = jnp.array([jnp.mean(u.params), jnp.var(u.params)])
    # Generate the stencil from the NN
    stencil_1d = params(field_stats)
    # Apply as a derivative kernel along each axis
    array = u.on_grid
    outs = []
    for axis in range(u.domain.ndim):
        kernel = stencil_1d
        for _ in range(u.domain.ndim - 1):
            kernel = jnp.expand_dims(kernel, 0)
        kernel = jnp.moveaxis(kernel, -1, axis)
        outs.append(reflection_conv(kernel, array[..., 0], reverse=True))
    return u.replace_params(jnp.stack(outs, axis=-1))
# Create a generator and use it
gen = StencilGenerator(stencil_size=5, key=jax.random.PRNGKey(0))
result = nn_gradient(u_fd, params=gen)
print("Generated stencil:", gen(jnp.array([jnp.mean(u_fd.params), jnp.var(u_fd.params)])))
print("Result shape:", result.params.shape)
# Verify jax.grad works through the stencil generator
def gen_loss(generator):
    result = nn_gradient(u_fd, params=generator)
    target = gradient(u_fd)
    return jnp.mean((result.params - target.params) ** 2)
grads = jax.grad(gen_loss)(gen)
print("Gradient through stencil generator — finite?",
all(bool(jnp.all(jnp.isfinite(leaf))) for leaf in jax.tree.leaves(grads)))
Generated stencil: [ 0.09484302 -0.14690985  0.15311614  0.0696418  -0.35178652]
Result shape: (128, 128, 2)
Gradient through stencil generator — finite? True
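One design choice worth noting: the raw network output above is not guaranteed to behave like a derivative stencil (its coefficients need not sum to zero). A common remedy — a constraint we suggest here, not something jaxdf or the code above enforces — is to project the output onto valid first-derivative stencils by antisymmetrizing it and normalizing its first moment:

```python
import numpy as np

def project_to_derivative_stencil(raw):
    """Map an arbitrary vector to an antisymmetric, unit-first-moment stencil."""
    s = 0.5 * (raw - raw[::-1])                  # antisymmetric, hence zero-sum
    offsets = np.arange(len(s)) - len(s) // 2    # e.g. [-2, -1, 0, 1, 2]
    moment = np.sum(s * offsets)                 # assumed nonzero for this sketch
    return s / moment                            # so that sum(s_i * i) == 1

raw = np.array([0.095, -0.147, 0.153, 0.070, -0.352])   # an arbitrary "NN output"
s = project_to_derivative_stencil(raw)
assert np.isclose(s.sum(), 0.0)
assert np.allclose(s, -s[::-1])
assert np.isclose(np.sum(s * np.array([-2, -1, 0, 1, 2])), 1.0)
```

With this projection applied inside the operator, the network only has to learn the shape of the stencil, while consistency with a first derivative is guaranteed by construction.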
7. Training a Learned Gradient Correction¶
Let's put it all together with a practical example. We'll take a low-accuracy finite-difference gradient (accuracy=2, which only uses a 3-point stencil) and train a CNN to correct it to match a high-accuracy reference (accuracy=8, which uses a 9-point stencil).
The key idea:
- The operator params contain both the FD stencil AND a CNN correction network
- During each call, we compute the cheap FD gradient and add a learned correction
- We train the CNN by backpropagating through the entire operator
# Create coarse (accuracy=2) and fine (accuracy=8) fields from the same image
u_coarse = FiniteDifferences.from_grid(grid_values, domain)
u_coarse = FiniteDifferences(u_coarse.params, domain, accuracy=2)
u_fine = FiniteDifferences.from_grid(grid_values, domain)
# u_fine already has accuracy=8 by default
# The target: high-accuracy gradient
target = gradient(u_fine)
# The baseline: low-accuracy gradient (this is what we want to improve)
baseline = gradient(u_coarse)
# Show the difference
fig, axes = plt.subplots(1, 4, figsize=(18, 4))
axes[0].imshow(img, cmap="gray")
axes[0].set_title("Original image")
axes[1].imshow(target.params[..., 0], cmap="RdBu", vmin=-0.3, vmax=0.3)
axes[1].set_title("Target gradient\n(accuracy=8)")
axes[2].imshow(baseline.params[..., 0], cmap="RdBu", vmin=-0.3, vmax=0.3)
axes[2].set_title("Baseline gradient\n(accuracy=2)")
axes[3].imshow((target.params - baseline.params)[..., 0], cmap="RdBu", vmin=-0.05, vmax=0.05)
axes[3].set_title("Error to correct")
for ax in axes:
    ax.axis("off")
plt.tight_layout()
plt.show()
print(f"Baseline MSE: {float(jnp.mean((target.params - baseline.params) ** 2)):.6f}")
Baseline MSE: 0.001319
# The correction CNN
class GradientCorrector(eqx.Module):
    """CNN that corrects the coarse FD gradient to match a fine one."""
    conv1: eqx.nn.Conv2d
    conv2: eqx.nn.Conv2d
    conv3: eqx.nn.Conv2d

    def __init__(self, key):
        k1, k2, k3 = jax.random.split(key, 3)
        self.conv1 = eqx.nn.Conv2d(2, 8, kernel_size=5, padding=2, key=k1)
        self.conv2 = eqx.nn.Conv2d(8, 8, kernel_size=5, padding=2, key=k2)
        self.conv3 = eqx.nn.Conv2d(8, 2, kernel_size=5, padding=2, key=k3)

    def __call__(self, x):
        # x: (H, W, C) -> (C, H, W) for Conv2d
        x = jnp.moveaxis(x, -1, 0)
        x = jax.nn.gelu(self.conv1(x))
        x = jax.nn.gelu(self.conv2(x))
        x = self.conv3(x)
        return jnp.moveaxis(x, 0, -1)
# Define the corrected gradient operator
# params = {"stencil": <FD stencil>, "corrector": <CNN>}
def corrected_init(u: FiniteDifferences, *args, **kwargs):
    return {
        "stencil": gradient.default_params(u),
        "corrector": GradientCorrector(jax.random.PRNGKey(0)),
    }
@operator(init_params=corrected_init)
def corrected_gradient(u: FiniteDifferences, *, params=None):
    """FD gradient + learned CNN correction."""
    base = gradient(u, params=params["stencil"])
    correction = params["corrector"](base.params)
    return base.replace_params(base.params + correction)
# Quick test
test_result = corrected_gradient(u_coarse)
print("Corrected gradient shape:", test_result.params.shape)
Corrected gradient shape: (128, 128, 2)
Now let's train the correction network. The loss is the MSE between the corrected gradient and the high-accuracy target:
def train_loss(corrector):
    params = corrected_gradient.default_params(u_coarse)
    params = {**params, "corrector": corrector}
    result = corrected_gradient(u_coarse, params=params)
    return jnp.mean((result.params - target.params) ** 2)
corrector = GradientCorrector(jax.random.PRNGKey(42))
optimizer = optax.adam(3e-4)
opt_state = optimizer.init(corrector)
@jax.jit
def step(net, opt_state):
    loss_val, grads = jax.value_and_grad(train_loss)(net)
    updates, opt_state = optimizer.update(grads, opt_state)
    net = optax.apply_updates(net, updates)
    return net, opt_state, loss_val
losses = []
for i in range(5001):
    corrector, opt_state, loss_val = step(corrector, opt_state)
    losses.append(float(loss_val))
    if i % 500 == 0:
        print(f"Step {i:4d} | Loss: {loss_val:.8f}")
Step    0 | Loss: 0.00275118
Step  500 | Loss: 0.00006772
Step 1000 | Loss: 0.00002710
Step 1500 | Loss: 0.00001886
Step 2000 | Loss: 0.00001488
Step 2500 | Loss: 0.00001239
Step 3000 | Loss: 0.00001063
Step 3500 | Loss: 0.00000930
Step 4000 | Loss: 0.00000823
Step 4500 | Loss: 0.00000731
Step 5000 | Loss: 0.00000653
# Training curve
plt.figure(figsize=(8, 3))
plt.semilogy(losses)
plt.xlabel("Training step")
plt.ylabel("MSE loss")
plt.title("Training the gradient correction CNN")
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
Let's see the results — the corrected gradient should look much closer to the high-accuracy target:
# Apply the trained corrector
params = corrected_gradient.default_params(u_coarse)
params = {**params, "corrector": corrector}
corrected = corrected_gradient(u_coarse, params=params)
fig, axes = plt.subplots(2, 4, figsize=(18, 8))
# Row 1: gradient fields
axes[0, 0].imshow(img, cmap="gray")
axes[0, 0].set_title("Original image")
axes[0, 1].imshow(target.params[..., 0], cmap="RdBu", vmin=-0.3, vmax=0.3)
axes[0, 1].set_title("Target (accuracy=8)")
axes[0, 2].imshow(baseline.params[..., 0], cmap="RdBu", vmin=-0.3, vmax=0.3)
axes[0, 2].set_title("Baseline (accuracy=2)")
axes[0, 3].imshow(corrected.params[..., 0], cmap="RdBu", vmin=-0.3, vmax=0.3)
axes[0, 3].set_title("Corrected (accuracy=2 + CNN)")
# Row 2: errors
axes[1, 0].axis("off")
axes[1, 1].axis("off")
err_baseline = (target.params - baseline.params)[..., 0]
err_corrected = (target.params - corrected.params)[..., 0]
vmax_err = max(float(jnp.max(jnp.abs(err_baseline))), 0.01)
axes[1, 2].imshow(err_baseline, cmap="RdBu", vmin=-vmax_err, vmax=vmax_err)
axes[1, 2].set_title(f"Baseline error\nMSE={float(jnp.mean(err_baseline**2)):.6f}")
axes[1, 3].imshow(err_corrected, cmap="RdBu", vmin=-vmax_err, vmax=vmax_err)
axes[1, 3].set_title(f"Corrected error\nMSE={float(jnp.mean(err_corrected**2)):.6f}")
for ax in axes.flat:
    ax.axis("off")
plt.suptitle("Learned Gradient Correction: Before vs After", fontsize=14, y=1.01)
plt.tight_layout()
plt.show()
What did the CNN learn?¶
We can visualize the CNN's effective kernel by applying it to a delta function (impulse response). This shows what spatial pattern the network learned to correct:
# Impulse response: apply corrector to a delta function
delta_input = jnp.zeros((128, 128, 2))
delta_input = delta_input.at[64, 64, :].set(1.0)
impulse = corrector(delta_input)
fig, axes = plt.subplots(1, 2, figsize=(12, 5))
for i, title in enumerate(["Learned kernel (d/dx correction)", "Learned kernel (d/dy correction)"]):
    # Zoom into the center 14x14 region
    kernel = impulse[57:71, 57:71, i]
    vmax = float(jnp.max(jnp.abs(kernel)))
    im = axes[i].imshow(kernel, cmap="RdBu", vmin=-vmax, vmax=vmax)
    axes[i].set_title(title)
    plt.colorbar(im, ax=axes[i], shrink=0.8)
    # Add a pixel grid
    axes[i].set_xticks(jnp.arange(-0.5, 14, 1), minor=True)
    axes[i].set_yticks(jnp.arange(-0.5, 14, 1), minor=True)
    axes[i].grid(which='minor', color='gray', linewidth=0.5, alpha=0.3)
plt.suptitle("CNN Impulse Response (center 14x14 pixels)", fontsize=12)
plt.tight_layout()
plt.show()
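The impulse-response picture rests on the fact that a linear, translation-invariant map is fully described by its response to a delta; for a nonlinear CNN like ours it is only a heuristic linearization, but it still reveals the dominant spatial pattern. A tiny 1D NumPy check of the underlying identity:

```python
import numpy as np

kernel = np.array([0.5, -1.0, 2.0])          # some fixed, non-symmetric kernel

def apply_conv(signal):
    # A purely linear, translation-invariant map
    return np.convolve(signal, kernel, mode="same")

# Impulse response: feeding a centered delta returns the kernel itself.
delta = np.zeros(9)
delta[4] = 1.0
response = apply_conv(delta)
assert np.allclose(response[3:6], kernel)    # kernel recovered around the center
assert np.allclose(response[:3], 0.0) and np.allclose(response[6:], 0.0)
```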
Summary¶
| Pattern | Example | params contains |
|---|---|---|
| Inspect & override | gradient(u, params=my_stencil) | Arrays (stencils, k-vectors) |
| Shared params | Same k_vec for gradient and diag_jacobian | Dict referenced by multiple operators |
| Nested params | Edge detector composing gradient | {"gradient": {...}} |
| Network IS the operator | pointwise_op(u, params=my_net) | eqx.Module — the whole network |
| Network GENERATES stencils | nn_gradient(u, params=stencil_gen) | eqx.Module — called inside the operator to produce the stencil |
| Learned correction | corrected_gradient(u, params={"stencil": ..., "corrector": cnn}) | Dict with both classical params and an eqx.Module |
All patterns work with jax.grad, jax.jit, and jax.vmap.
Key convention: pass the whole eqx.Module as params, not just its weights. This follows the equinox idiom and gives you eqx.apply_updates and eqx.partition for free.