jaxdf.core
Module Overview
This module contains the core building blocks of the jaxdf framework.
At its center is the Field class, which derives from equinox.Module and is therefore a JAX-compatible dataclass (a pytree). All discretizations in jaxdf are subclasses of Field.
The other key component is the operator decorator, which implements multiple dispatch through the plum library. It is the mechanism used to define new operators in the framework.
operator = Operator() (module attribute)
Decorator for defining operators using multiple dispatch. The type annotations of the evaluate function determine the dispatch rules. Dispatch follows the same rule as in Julia: operators are dispatched on the types of the positional arguments.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
evaluate | Callable | A function with the signature | required |
init_params | Callable | A function that overrides the default parameters initializer for the operator. Useful when running the operator just to get the parameters is expensive. | required |
precedence | int | The precedence of the operator if an ambiguous match is found. | required |
Returns:
Type | Description |
---|---|
Callable | The operator function with signature |
Keyword arguments are not considered for dispatching. Keyword arguments are those defined after the * in the function signature.
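To make the dispatch rule concrete, here is a toy, standard-library-only sketch that picks an implementation from the types of the positional arguments while passing keyword arguments through untouched. `ToyOperator`, `register` and `scale` are hypothetical names invented for illustration; this is not how jaxdf or plum are implemented, and unlike plum it only matches exact types.

```python
from typing import Callable, Dict, Tuple


class ToyOperator:
    """Toy multiple dispatch on positional-argument types (not plum)."""

    def __init__(self, name: str):
        self.name = name
        self.methods: Dict[Tuple[type, ...], Callable] = {}

    def register(self, *types: type) -> Callable:
        # Store `fn` under the exact tuple of positional types.
        def decorator(fn: Callable) -> Callable:
            self.methods[types] = fn
            return fn
        return decorator

    def __call__(self, *args, **kwargs):
        # Only positional arguments form the lookup key;
        # keyword arguments are forwarded but never inspected.
        key = tuple(type(a) for a in args)
        if key not in self.methods:
            raise TypeError(f"{self.name}: no method for {key}")
        return self.methods[key](*args, **kwargs)


scale = ToyOperator("scale")


@scale.register(int)
def _scale_int(x, *, factor=2):
    return x * factor


@scale.register(str)
def _scale_str(x, *, factor=2):
    return x.upper() * factor


print(scale(3), scale(3, factor=5), scale("ab"))  # 6 15 ABAB
```

Real multiple-dispatch libraries such as plum also match subclasses and must resolve ambiguous matches, which is the situation the precedence parameter above addresses.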
Example
```python
from jaxdf.discretization import FourierSeries
from jaxdf.core import operator

@operator
def my_operator(x: FourierSeries, *, dx: float, params=None):
    ...
```
The params argument is mandatory and must be a keyword argument. It carries the parameters of the operator, for example the stencil coefficients of a finite-difference operator.
The default value of the parameters is specified by the init_params function, as follows:
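Example
```python
import jax.numpy as jnp
from jaxdf.core import operator

# Builds the default parameters from the same arguments as the operator.
def params_initializer(x, *, dx):
    return {"stencil": jnp.ones(x.shape) * dx}

@operator(init_params=params_initializer)
def my_operator(x, *, dx, params=None):
    b = params["stencil"] / dx
    y_params = jnp.convolve(x.params, b, mode="same")
    return x.replace_params(y_params)
```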
The default value of params in the function signature is not used during computation.
If the operator has no parameters, the init_params function can be omitted. In this case, params is set to None.
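The defaulting rule above can be mimicked with a small decorator. This is a hypothetical, standard-library-only sketch: `with_init_params`, `stencil_init`, `weighted_sum` and `passthrough` are illustrative names, plain Python lists stand in for JAX arrays, and the real operator decorator additionally performs dispatch.

```python
import functools
from typing import Callable, Optional


def with_init_params(init_params: Optional[Callable] = None) -> Callable:
    """Hypothetical stand-in for operator(init_params=...)."""
    def decorator(evaluate: Callable) -> Callable:
        @functools.wraps(evaluate)
        def wrapper(*args, params=None, **kwargs):
            # The signature default (params=None) is never used as-is:
            # when no params are passed, they are built by init_params
            # from the same arguments the operator was called with.
            if params is None and init_params is not None:
                params = init_params(*args, **kwargs)
            return evaluate(*args, params=params, **kwargs)
        return wrapper
    return decorator


def stencil_init(x, *, dx):
    return {"stencil": [dx, dx, dx]}


@with_init_params(stencil_init)
def weighted_sum(x, *, dx, params=None):
    return sum(w * x for w in params["stencil"])


@with_init_params()  # no initializer: params stays None
def passthrough(x, *, params=None):
    return x, params


print(weighted_sum(2.0, dx=0.5))  # 3.0
```

Explicitly passed params always take priority over the initializer in this sketch.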
For constant parameters, the constants function can be used:
Example
```python
from jaxdf.core import constants, operator

@operator(init_params=constants({"a": 1, "b": 2.0}))
def my_operator(x, *, params):
    return x + params["a"] + params["b"]
```
Field
Bases: Module
dims (property)
The dimension of the field values.
is_complex: bool (property)
Checks whether the field is complex.
Returns:
Type | Description |
---|---|
bool | Whether the field is complex. |
on_grid (property)
Returns the field on the grid points of the domain.
θ (property)
Handy alias for the params attribute.
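A read-only alias like θ is typically just a property that returns the aliased attribute. The toy class below is a hypothetical illustration using a plain dataclass, not jaxdf's Field.

```python
from dataclasses import dataclass
from typing import Any


@dataclass(frozen=True)
class AliasedField:
    """Hypothetical toy class; only `params` and its alias are modeled."""
    params: Any

    @property
    def θ(self) -> Any:
        # Returns the same object as `params`, not a copy.
        return self.params


f = AliasedField(params=[1.0, 2.0])
print(f.θ is f.params)  # True
```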
__call__(x)
A Field can be called as a function, returning the value of the field at a given point.
Example
```python
...
a = Continuous.from_function(init_params, domain, get_field)
field_at_x = a(1.0)
```
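The idea behind calling a field is that a continuous field pairs a set of parameters with a function that evaluates them at a point. The sketch below is a hypothetical, standard-library-only analogue: `PointwiseField` and its `(params, x)` evaluation convention are assumptions made for illustration and do not reproduce the Continuous.from_function signature.

```python
from dataclasses import dataclass
from typing import Any, Callable


@dataclass(frozen=True)
class PointwiseField:
    """Hypothetical toy: parameters plus an evaluation function."""
    params: Any
    get_fn: Callable[[Any, float], float]  # assumed (params, x) -> value

    def __call__(self, x: float) -> float:
        # Evaluate the field at the single point x.
        return self.get_fn(self.params, x)


# A linear field u(x) = a * x + b with parameters a = 2, b = 1.
u = PointwiseField(params={"a": 2.0, "b": 1.0},
                   get_fn=lambda p, x: p["a"] * x + p["b"])
print(u(1.0))  # 3.0
```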
replace_params(new_params)
Returns a new field of the same type, with the same domain and auxiliary data but with new parameters.
Example
```python
import jax.numpy as jnp
from jaxdf.discretization import FourierSeries

x = FourierSeries(jnp.ones(10), domain=domain)  # `domain`: an existing Domain
y_params = x.params + 1
y = x.replace_params(y_params)
```
Parameters:
Name | Type | Description | Default |
---|---|---|---|
new_params | Any | The new parameters. | required |
Returns:
Type | Description |
---|---|
Field | A new field with the same domain and auxiliary data, but with new parameters. |
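Because fields are immutable dataclasses, "replacing" parameters means constructing a new instance. The sketch below shows that pattern with the standard library's dataclasses.replace on a hypothetical `ImmutableField`; jaxdf's Field is an equinox.Module, so its actual implementation may differ.

```python
import dataclasses
from typing import Any


@dataclasses.dataclass(frozen=True)
class ImmutableField:
    """Hypothetical toy field: params plus domain and auxiliary data."""
    params: Any
    domain: Any
    aux: Any = None

    def replace_params(self, new_params: Any) -> "ImmutableField":
        # Build a new instance of the same class; every field except
        # `params` is carried over, and `self` is left unchanged.
        return dataclasses.replace(self, params=new_params)


x = ImmutableField(params=(1.0, 1.0), domain="grid-10", aux={"kind": "toy"})
y = x.replace_params(tuple(p + 1 for p in x.params))
print(y.params, x.params)  # (2.0, 2.0) (1.0, 1.0)
```

Keeping the original untouched is what makes this pattern safe inside JAX transformations, which expect functions not to mutate their inputs.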
Operator
abstract(evaluate)
Decorator for defining abstract operators. This is mainly used to define generic docstrings.
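One way an abstract operator can carry a generic docstring without an implementation is sketched below. `toy_abstract` is a hypothetical, standard-library-only illustration, not jaxdf's abstract: it keeps the name and docstring of evaluate via functools.wraps and raises NotImplementedError when called.

```python
import functools
from typing import Callable


def toy_abstract(evaluate: Callable) -> Callable:
    """Hypothetical: keep metadata of `evaluate`, provide no behavior."""
    @functools.wraps(evaluate)  # copies __name__, __doc__, etc.
    def placeholder(*args, **kwargs):
        raise NotImplementedError(
            f"{evaluate.__name__} is abstract and has no implementation")
    return placeholder


@toy_abstract
def gradient(x, *, params=None):
    """Computes the gradient of a field."""


print(gradient.__name__, "-", gradient.__doc__)
```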
constants(value)
A higher-order function for defining constant operator parameters that do not depend on the operator arguments.
Example
```python
from jaxdf.core import constants, operator

@operator(init_params=constants({"a": 1, "b": 2.0}))
def my_operator(x, *, params):
    return x + params["a"] + params["b"]
```
Parameters:
Name | Type | Description | Default |
---|---|---|---|
value | Any | The value of the constant parameters. | required |
Returns:
Type | Description |
---|---|
Callable | The parameters initializer function that returns the constant value. |
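Conceptually, such an initializer only needs to accept the operator's arguments, ignore them, and return the stored value. The following is a minimal sketch under that assumption; `toy_constants` is a hypothetical name, not jaxdf's source.

```python
from typing import Any, Callable


def toy_constants(value: Any) -> Callable[..., Any]:
    """Hypothetical: build an init_params function returning `value`."""
    def init_params(*args, **kwargs) -> Any:
        # The operator's arguments are accepted but deliberately ignored.
        return value
    return init_params


init = toy_constants({"a": 1, "b": 2.0})
print(init(3.0, dx=0.1))  # {'a': 1, 'b': 2.0}
```

Since the same object is returned on every call, the stored value should be treated as read-only.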