jaxdf.core
Module Overview
This module contains the core building blocks of the jaxdf framework.
At its center is the Field class, which derives from equinox.Module and is therefore a JAX-compatible dataclass. Every discretization in jaxdf is a subclass of Field.
The other key component is the operator decorator, which implements multiple dispatch on top of the plum library. It is the main tool for defining new operators in the framework.
operator = Operator() (module attribute)
Decorator for defining operators using multiple dispatch. The type annotations of the
evaluate function determine the dispatch rules. The dispatch semantics are the
same as Julia's: operators are dispatched on the types of their positional arguments.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| evaluate | Callable | The function implementing the operator. The type annotations of its positional arguments determine the dispatch rules. | required |
| init_params | Callable | A function that overrides the default parameters initializer for the operator. Useful when running the operator just to get the parameters is expensive. | optional |
| precedence | int | The precedence of the operator if an ambiguous match is found. | optional |
Returns:
| Type | Description |
|---|---|
| Callable | The operator function. |
Keyword arguments, i.e. those declared after the * in the function signature, are not considered for dispatching.
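The dispatch rules described above can be modelled with a small pure-Python registry. The sketch below uses a hypothetical `ToyOperator` class and is not jaxdf's implementation, which delegates to plum. It dispatches only on the annotated types of positional arguments and ignores keyword-only arguments. It discards any match that another match strictly refines, and it uses `precedence` to break ties between equally specific candidates.

```python
import inspect


class ToyOperator:
    """Hypothetical multiple-dispatch operator (illustration only, not jaxdf code)."""

    def __init__(self):
        self._methods = []  # entries: (positional types, precedence, implementation)

    def register(self, precedence=0):
        def decorator(fn):
            # Only positional parameters take part in dispatch; keyword-only
            # parameters (declared after `*`) are ignored.
            positional = (inspect.Parameter.POSITIONAL_ONLY,
                          inspect.Parameter.POSITIONAL_OR_KEYWORD)
            types = tuple(
                object if p.annotation is inspect.Parameter.empty else p.annotation
                for p in inspect.signature(fn).parameters.values()
                if p.kind in positional
            )
            self._methods.append((types, precedence, fn))
            return self
        return decorator

    def __call__(self, *args, **kwargs):
        matches = [m for m in self._methods
                   if len(m[0]) == len(args)
                   and all(isinstance(a, t) for a, t in zip(args, m[0]))]
        if not matches:
            raise TypeError("no method matches these positional argument types")

        def refines(s, t):
            # s strictly refines t if each type in s subclasses the matching
            # type in t and the two signatures differ.
            return s != t and all(issubclass(x, y) for x, y in zip(s, t))

        # Keep only matches that no other match strictly refines.
        best = [m for m in matches
                if not any(refines(o[0], m[0]) for o in matches)]
        # Among equally specific candidates, the highest precedence wins.
        best.sort(key=lambda m: m[1], reverse=True)
        if len(best) > 1 and best[0][1] == best[1][1]:
            raise TypeError("ambiguous dispatch; raise one method's precedence")
        return best[0][2](*args, **kwargs)


describe = ToyOperator()


@describe.register()
def _(x: object):
    return "object"


@describe.register()
def _(x: int):
    return "int"


class A:
    pass


combine = ToyOperator()


@combine.register()
def _(a: A, b: object, *, tag=""):
    return "left" + tag


@combine.register(precedence=1)
def _(a: object, b: A, *, tag=""):
    return "right" + tag
```

Here `describe(3)` selects the `int` method over the `object` one. `combine(A(), A())` matches both methods, and neither is more specific than the other, so the higher precedence selects `"right"`. The `tag` keyword is passed through but never affects which method is chosen.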
Example
```python
from jaxdf import operator
from jaxdf.discretization import FourierSeries

@operator
def my_operator(x: FourierSeries, *, dx: float, params=None):
    ...
```
The params argument is mandatory and must be a keyword argument. It carries the
operator's parameters, for example the stencil coefficients of a finite-difference operator.
Default parameter values are produced by the init_params function, as follows:
Example
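```python
import jax.numpy as jnp
from jaxdf import operator

def params_initializer(x, *, dx):
    # Default stencil: same shape as the field parameters, scaled by dx
    return {"stencil": jnp.ones(x.params.shape) * dx}

@operator(init_params=params_initializer)
def my_operator(x, *, dx, params=None):
    b = params["stencil"] / dx
    # jnp.convolve expects 1D arrays, so x.params is assumed to be 1D here
    y_params = jnp.convolve(x.params, b, mode="same")
    return x.replace_params(y_params)
```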
The default value declared for params in the function signature is never used during computation.
If the operator has no parameters, the init_params function can be omitted; params is then
set to None.
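The way params is resolved can be summarised by a small wrapper. The sketch uses a hypothetical `toy_operator` decorator written in plain Python with no dispatch. It models only the rules stated above. An explicitly passed params is used as-is. Otherwise init_params is called with the operator's arguments. Without init_params, the value is None, whatever default the signature declares.

```python
from functools import wraps


def toy_operator(init_params=None):
    """Hypothetical model of how an operator fills in `params` (not jaxdf code)."""

    def decorator(evaluate):
        @wraps(evaluate)
        def wrapper(*args, params=None, **kwargs):
            # The default declared for `params` in `evaluate` is never used:
            # the value comes from the caller, from init_params, or is None.
            if params is None and init_params is not None:
                params = init_params(*args, **kwargs)
            return evaluate(*args, params=params, **kwargs)

        return wrapper

    return decorator


def scale_init(x, *, dx):
    # Default parameters computed from the operator arguments
    return {"scale": 1.0 / dx}


@toy_operator(init_params=scale_init)
def scale(x, *, dx, params=None):
    return [v * params["scale"] for v in x]


@toy_operator()
def passthrough(x, *, params="never used"):
    # No init_params: params arrives as None, not as the declared default
    return x, params
```

For example, `scale([1.0, 2.0], dx=0.5)` uses the initializer and returns `[2.0, 4.0]`. Passing `params={"scale": 3.0}` bypasses the initializer entirely.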
For constant parameters, the constants function can be used:
Example
```python
from jaxdf.core import constants, operator

@operator(init_params=constants({"a": 1, "b": 2.0}))
def my_operator(x, *, params):
    return x + params["a"] + params["b"]
```
Field
Bases: Module
dims (property)
The dimension of the field values.
is_complex (property)
Checks whether the field is complex.
Returns:
| Type | Description |
|---|---|
| bool | Whether the field is complex. |
on_grid (property)
Returns the field on the grid points of the domain.
θ (property)
Handy alias for the params attribute.
__call__(x)
A Field can be called as a function, returning its value at the given point.
Example
```python
...
a = Continuous.from_function(init_params, domain, get_field)
field_at_x = a(1.0)
```
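The idea of a field that is both callable and sampleable can be sketched with a plain dataclass. `ToyContinuous`, its `grid` attribute and its `get_fn` attribute are hypothetical stand-ins for a continuous discretization and its domain. The sketch only mirrors the behaviour of `__call__` and the `on_grid` property described above.

```python
from dataclasses import dataclass
from typing import Any, Callable, Sequence


@dataclass(frozen=True)
class ToyContinuous:
    """Hypothetical continuous field: parameters plus an evaluation rule."""

    params: Any
    grid: Sequence[float]  # stand-in for the domain's grid points
    get_fn: Callable[[Any, float], float]

    def __call__(self, x):
        # Evaluate the field at an arbitrary point x
        return self.get_fn(self.params, x)

    @property
    def on_grid(self):
        # Sample the field at every grid point of the (toy) domain
        return [self(x) for x in self.grid]


# A linear field f(x) = a * x + b, with params (a, b) = (2.0, 1.0)
line = ToyContinuous(params=(2.0, 1.0), grid=[0.0, 0.5, 1.0],
                     get_fn=lambda p, x: p[0] * x + p[1])
```

With this setup, `line(1.0)` evaluates to `3.0` and `line.on_grid` to `[1.0, 2.0, 3.0]`.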
replace_params(new_params)
Returns a new field of the same type, with the same domain and auxiliary data, but with new parameters.
Example
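```python
import jax.numpy as jnp
from jaxdf.discretization import FourierSeries

x = FourierSeries(jnp.ones(10), domain=domain)  # `domain`: an existing 10-point 1D Domain
y_params = x.params + 1
y = x.replace_params(y_params)
```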
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| new_params | Any | The new parameters. | required |
Returns:
| Type | Description |
|---|---|
| Field | A new field with the same domain and auxiliary data, but with new parameters. |
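Field is an immutable dataclass, so replace_params builds a new object rather than modifying the old one. The following minimal sketch shows that pattern, using a frozen standard-library dataclass in place of equinox.Module. `ToyField` and its `domain` attribute are hypothetical, and the `θ` property simply mirrors the alias described above.

```python
from dataclasses import dataclass, replace
from typing import Any


@dataclass(frozen=True)
class ToyField:
    """Hypothetical immutable field (stand-in for an equinox.Module subclass)."""

    params: Any
    domain: Any  # auxiliary data carried over unchanged

    @property
    def θ(self):
        # Alias for params
        return self.params

    def replace_params(self, new_params):
        # dataclasses.replace builds a new instance of the same class;
        # the original object is left untouched
        return replace(self, params=new_params)


x = ToyField(params=[1.0, 1.0], domain="10-point grid")
y = x.replace_params([v + 1 for v in x.θ])
```

The result `y` shares `x.domain` and has the same type as `x`, while `x.params` is still `[1.0, 1.0]`.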
Operator
abstract(evaluate)
Decorator for defining abstract operators. This is mainly used to define generic docstrings.
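One plausible way to picture an abstract operator is as a wrapper that keeps a generic name and docstring but has no implementation of its own. The sketch below is an assumption for illustration and is not jaxdf's abstract decorator. `toy_abstract` and `toy_gradient` are hypothetical names.

```python
from functools import wraps


def toy_abstract(evaluate):
    """Hypothetical model of an abstract-operator decorator (not jaxdf code)."""

    @wraps(evaluate)  # keeps the generic name and docstring
    def wrapper(*args, **kwargs):
        arg_types = ", ".join(type(a).__name__ for a in args)
        raise NotImplementedError(
            f"{evaluate.__name__} has no implementation for ({arg_types})")

    return wrapper


@toy_abstract
def toy_gradient(x, *, params=None):
    """Generic docstring: computes the gradient of a field."""
```

The docstring is available through `toy_gradient.__doc__`, while calling `toy_gradient(1.0)` raises `NotImplementedError` until a concrete method is supplied.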
constants(value)
A higher-order function for defining constant operator parameters, independent of the operator arguments.
Example
```python
from jaxdf.core import constants, operator

@operator(init_params=constants({"a": 1, "b": 2.0}))
def my_operator(x, *, params):
    return x + params["a"] + params["b"]
```
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| value | Any | The value of the constant parameters. | required |
Returns:
| Type | Description |
|---|---|
| Callable | The parameters initializer function that returns the constant value. |
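The returned initializer ignores the operator arguments, so it can be written as a closure over value. The following minimal sketch uses the hypothetical name `toy_constants` to keep it distinct from jaxdf's own function.

```python
from typing import Any, Callable


def toy_constants(value: Any) -> Callable[..., Any]:
    """Hypothetical constant-parameters initializer (not jaxdf's implementation)."""

    def init_params(*args, **kwargs):
        # Arguments are accepted but ignored: the parameters never depend on them
        return value

    return init_params


init = toy_constants({"a": 1, "b": 2.0})
```

Whatever the initializer is called with, for example `init("any field", dx=0.1)`, it returns the same `{"a": 1, "b": 2.0}` object.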