Discretization API
jaxdf revolves around the concept of discretization.
We will call a discretization family the mapping $\mathcal{D}$ that associates a set of discrete parameters $\theta$ with a function $f$
$$ \theta \xrightarrow{\mathcal{D}} f $$
with $f \in \text{Range}(\mathcal{D})$ or, in other words, $f_\theta(x) = \mathcal{D}(\theta, x)$ is a function parametrized by $\theta$.
$\theta$ is the discrete representation of $f$ over $\mathcal{D}$. The mapping $\mathcal{D}$ is analogous to the interpolation operator defined in other libraries (see, for example, the Operator Discretization Library).
Example
A simple example of a discretization family is the set of $N$-th order polynomials on the interval $[0,1)$:
$$ \mathcal{P}_N(\theta,x) = \sum_{i=0}^N \theta_i x^i, \qquad \theta \in \mathbb{R}^{N+1} $$
In jaxdf, we construct such a field using the Continuous discretization. To do so, we have to provide the function $\mathcal{P}_N$, the parameters of the function, and the domain where the function is defined:
from jaxdf.discretization import Continuous
from jaxdf.geometry import Domain
from jax.random import normal, PRNGKey
from jax import numpy as jnp

seed = PRNGKey(42)
N = 5

# This defines the spatial domain of the function
domain = Domain((256,), (1/64.,))

# This is the mapping from the parameters to the function.
# Note that with N coefficients the powers run from 0 to N-1,
# so this is a polynomial of degree N-1.
def p_n(theta, x):
  i = jnp.arange(N)
  powers = x**i
  return jnp.expand_dims(jnp.sum(theta*powers), -1)

# Here we generate a random set of parameters
params = normal(seed, (N,))

# Finally, we put everything together in a single object
# that can be used to evaluate the function and apply operators
# to it.
u = Continuous(params, domain, p_n)
print(u)
Continuous( params=f32[5], domain=Domain(N=(256,), dx=(0.015625,)), get_fun=<function p_n> )
Let's look at the field over the domain using the on_grid method:
from matplotlib import pyplot as plt

def show_field(grid_representation, domain):
  plt.plot(domain.spatial_axis[0], grid_representation)
  plt.xlabel("x")
  plt.ylabel("$f$")
  plt.show()

field_on_grid = u.on_grid
show_field(field_on_grid, domain)
To get the field at a specific location $x$, we can simply call the field at the required coordinates:
x = 1.2
field_at_x = u(x)
print(f"Field at x={x} : {field_at_x}")
Field at x=1.2 : [-2.5217302]
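Since evaluating a field is just a function call, we can also evaluate it at a batch of coordinates with standard jax transformations. A minimal sketch using jax.vmap (the coordinate values below are made up for illustration):

from jax import vmap

# Evaluate the field at a batch of coordinates (illustrative values)
xs = jnp.linspace(0.0, 2.0, 5)
field_at_xs = vmap(lambda xi: u(xi))(xs)
print(field_at_xs.shape)  # (5, 1)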
Fields are pytrees based on the equinox Module, so they natively work with jax.jit, jax.grad, etc.
Let's try to apply the derivative operator to this newly defined field:
from jaxdf.operators import derivative
du_dx = derivative(u)
show_field(du_dx.on_grid, domain)
The parameters of the field du_dx are the same as those of u, since for Continuous fields the derivative operator is evaluated using autograd, an operation that only affects the computational graph of the function, not its inputs:
u.θ, du_dx.θ  # .θ is a shorthand for .params
(Array([ 0.6122652, 1.1225883, -0.8544134, -0.8127325, -0.890405 ], dtype=float32), Array([ 0.6122652, 1.1225883, -0.8544134, -0.8127325, -0.890405 ], dtype=float32))
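Since fields are pytrees, they also compose with transformations such as jax.grad. A minimal sketch (the scalar loss below is made up for illustration):

import jax

# Differentiate a scalar loss with respect to the field parameters
def loss(theta):
  field = Continuous(theta, domain, p_n)
  return jnp.sum(field(0.5)**2)

grad_theta = jax.grad(loss)(u.params)
print(grad_theta.shape)  # (5,)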
Customizing discretizations
For a polynomial field, we actually know analytically how to compute derivatives:
$$ \frac{\partial}{\partial x}\mathcal{P}_N(\theta,x) = \frac{\partial}{\partial x} \sum_{i=0}^N \theta_i x^i = \sum_{i=0}^{N-1} (i+1)\theta_{i+1} x^i $$
To use this knowledge, we first define a new discretization family derived from the Continuous one, and then define the derivative operator using the analytical formula.
The new class constructs the parent class with the attributes params, domain and get_fun; since we know the analytical form of the field, we can hard-code it in get_fun.
The params attribute must be a PyTree compatible with jax.numpy (arrays, dictionaries of arrays, equinox modules, etc.). The domain attribute must be a jaxdf.geometry.Domain object defining the domain of the field.
The last attribute, get_fun, must be a function that evaluates the field at a coordinate using the parameters contained in params, and has the signature get_fun(params: Field.params, x: Union[jnp.ndarray, float]).
class Polynomial(Continuous):
  @classmethod
  def from_params(cls, params, domain):
    # The analytical form of the polynomial, as before
    def get_fun(params, x):
      i = jnp.arange(len(params))
      return jnp.expand_dims(jnp.sum(params*(x**i)), -1)
    return cls(
      params=params,
      domain=domain,
      get_fun=get_fun
    )

  @property
  def degree(self):
    return len(self.params) - 1

  def __repr__(self):
    return "Polynomial(degree={})".format(self.degree)

# Construct a polynomial field from the same parameters as before
u_custom = Polynomial.from_params(u.params, u.domain)
print(u_custom, u_custom.params)
show_field(u_custom.on_grid, domain)
Polynomial(degree=4) [ 0.6122652 1.1225883 -0.8544134 -0.8127325 -0.890405 ]
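Since the new discretization uses the same parameters and the same analytical form as u, the two fields should evaluate to the same values; a quick sanity check (sketch):

print(u(1.2), u_custom(1.2))  # Both should print the same value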
To define the derivative operator acting on polynomials, we have several options. One is to simply define a Python function that builds the Polynomial object resulting from the derivative computation:
def derivative(u: Polynomial):
  # Find the parameters of the polynomial after differentiation
  coeffs = jnp.arange(1, u.params.shape[0])
  new_params = u.params[1:]*coeffs
  # Return a new polynomial with the new parameters
  return Polynomial.from_params(new_params, u.domain)
du_custom = derivative(u_custom)
show_field(du_custom.on_grid, domain)
print(du_custom) # Note that the degree is one less than before
Polynomial(degree=3)
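As a quick sanity check (a sketch, assuming the fields built above are still in scope), the analytical derivative should match the autograd-based one computed earlier for the Continuous field:

# The two grid representations should agree up to floating point error
print(jnp.allclose(du_custom.on_grid, du_dx.on_grid, atol=1e-5))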
Note that the code is fully differentiable and can be compiled:
import jax
@jax.jit
def f(u):
  x = derivative(u) + 0.3
  return x(0.1)
print(f(u))
[1.223762]
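Differentiation works just as well; a minimal sketch taking the gradient of the (jitted) function f with respect to the polynomial coefficients (the loss construction below is made up for illustration):

# Gradient of the scalar output of f with respect to the coefficients
df_dtheta = jax.grad(
  lambda p: jnp.sum(f(Polynomial.from_params(p, domain)))
)(u_custom.params)
print(df_dtheta.shape)  # (5,)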
However, note that we now have a derivative operator defined for all types, and we get incorrect results for fields that are not Polynomials:
def sinfun(params, x):
  theta = jnp.sum(params)  # Dummy operation to make the parameter vector a scalar
  y = jnp.sin(theta*x*10)
  return y
params = u.params
sin = Continuous(params, domain, sinfun)
show_field(sin.on_grid, domain)
z = derivative(sin)
show_field(z.on_grid, domain)
print('z: ', z)
z: Polynomial(degree=3)
Operators and Multiple Dispatch
One way to avoid this is to implement the operators as methods of the fields, which are then overridden by the child classes. This was the approach used in jaxdf when it was first made public, and it can still be used:
class PolynomialWithMethods(Polynomial):
  @classmethod
  def from_params(cls, params, domain):
    def get_fun(params, x):
      i = jnp.arange(len(params))
      return jnp.expand_dims(jnp.sum(params*(x**i)), -1)
    return cls(
      params=params,
      domain=domain,
      get_fun=get_fun
    )

  @property
  def degree(self):
    return len(self.params) - 1

  def __repr__(self):
    return "Polynomial(degree={})".format(self.degree)

  # Custom derivative code, with parameters
  def derivative(self, exponent=1.0):
    new_params = self.params[1:]*jnp.arange(1, self.params.shape[0])*exponent
    return PolynomialWithMethods.from_params(new_params, self.domain)
@jax.jit
def g(u, exponent):
  return u.derivative(exponent) + 3.0
u_2 = PolynomialWithMethods.from_params(u.params, u.domain)
show_field(g(u_2, 1.3).on_grid, domain)
However, it can become cumbersome to deal with many derived methods for different kinds of discretizations, especially when operators accept multiple operands with different combinations of discretizations (e.g. dot products, +, heterogeneous differential operators, etc.).
This problem is elegantly resolved in some programming languages using multiple dispatch. One language that notably supports multiple dispatch is Julia; if you are familiar with Julia, or interested in learning it, I suggest having a look at the packages of the SciML ecosystem (those packages look great!).
For those of us sticking with Python, jax and jaxdf, we borrow those ideas here using the Python multiple dispatch library plum. The jaxdf.operator decorator can be used to define new (parametric) operators, as in:
@operator
def new_operator(x: Polynomial, y: Continuous, *, params=1.0):
  ...  # Any jaxdf or jax-compatible code
  return Field(...)
The inputs of the function can have arbitrary types: if they are fields, or any other type traceable by jax, they will be traced. The function has a mandatory keyword argument params, which is reserved for the parameters of the operator, such as the stencil coefficients of a differential operator. The output of the function can be any type, including jaxdf.Field or jax.numpy.ndarray.
The @operator decorator is most useful when the arguments are annotated with types: in that way, we use the dispatch mechanism of plum to define an implementation of the function that is specific to the annotated types.
from jaxdf import operator
@operator
def derivative(x: Polynomial, *, axis=None, params=None):
  print('Applying derivative to a polynomial')
  if axis is not None:
    print("Warning: axis argument is ignored for polynomials")
  # Use the field passed to the operator, not the global `u`
  new_params = x.params[1:]*jnp.arange(1, x.params.shape[0])
  return Polynomial.from_params(new_params, x.domain)

@operator
def derivative(x: Continuous, *, axis=0, params=None):
  print('Applying derivative to a generic Continuous field')
  get_x = x.get_fun
  def grad_fun(p, coords):
    f_jac = jax.jacfwd(get_x, argnums=(1,))
    return jnp.expand_dims(f_jac(p, coords)[0][0][axis], -1)
  return Continuous(x.params, x.domain, grad_fun)
du_custom = derivative(u_custom)
show_field(du_custom.on_grid, domain)
Applying derivative to a polynomial
z = derivative(sin)
show_field(z.on_grid, domain)
print('z: ', z)
Applying derivative to a generic Continuous field
z: Continuous( params=f32[5], domain=Domain(N=(256,), dx=(0.015625,)), get_fun=<function grad_fun> )
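To illustrate the params keyword, here is a sketch of a parametric operator (the name scaled_derivative and its scaling semantics are made up for illustration), which multiplies the analytical derivative coefficients by the operator parameter:

@operator
def scaled_derivative(x: Polynomial, *, params=2.0):
  # Here `params` holds the operator parameter (a scale factor),
  # not the field parameters
  new_params = x.params[1:]*jnp.arange(1, x.params.shape[0])*params
  return Polynomial.from_params(new_params, x.domain)

print(scaled_derivative(u_custom))  # Polynomial(degree=3)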