
jaxdf.core

Module Overview

This module is the fundamental part of the jaxdf framework.

At its core is the Field class. Field subclasses equinox.Module, which makes every field a JAX-compatible dataclass (a pytree). All discretizations in jaxdf derive from Field.

The other key feature of jaxdf is the operator decorator, which implements multiple dispatch through the plum library: a single operator can have several implementations, and the one that runs is selected from the types of its arguments. This is the mechanism for defining new operators in the framework.

operator = Operator() module-attribute

Decorator for defining operators using multiple dispatch. The type annotations of the evaluate function determine the dispatch rules. The dispatch semantics follow Julia's: operators are dispatched on the types of the positional arguments.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| evaluate | Callable | A function with the signature `evaluate(field, *args, **kwargs, params)`. It must return a tuple, with the first element being a field and the second element being the default parameters for the operator. | None (the call then returns a decorator) |
| init_params | Callable | A function that overrides the default parameters initializer for the operator. Useful when running the operator just to get the parameters is expensive. | None |
| precedence | int | The precedence of the operator when an ambiguous match is found. | 0 |

Returns:

| Type | Description |
|---|---|
| Callable | The operator function with signature `evaluate(field, *args, **kwargs, params)`. |

Keyword arguments, i.e. those declared after the * in the function signature, are not considered for dispatching.
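The dispatch rule can be illustrated with a toy registry. The sketch below is a hypothetical, heavily simplified stand-in written for this page (ToyDispatcher is not part of jaxdf or plum, and plum's real resolution logic is more complete). It assumes only the rules stated above: the implementation is chosen from the types of the positional arguments, keyword arguments are passed through without affecting the choice, the most specific match wins, and precedence breaks ties between equally specific matches.

```python
from typing import Callable, List, Tuple


class ToyDispatcher:
    """Hypothetical, simplified multiple-dispatch registry (not plum's algorithm).

    Implementations are selected on the types of the positional arguments;
    keyword arguments play no role in the selection.
    """

    def __init__(self) -> None:
        self._methods: List[Tuple[tuple, int, Callable]] = []

    def register(self, *types: type, precedence: int = 0):
        def decorator(func: Callable) -> Callable:
            self._methods.append((types, precedence, func))
            return func

        return decorator

    def __call__(self, *args, **kwargs):
        # Candidates whose signature matches the positional arguments only
        matches = [
            m for m in self._methods
            if len(m[0]) == len(args)
            and all(isinstance(a, t) for a, t in zip(args, m[0]))
        ]
        if not matches:
            names = ", ".join(type(a).__name__ for a in args)
            raise TypeError(f"no implementation for ({names})")

        def dominates(s1: tuple, s2: tuple) -> bool:
            # True if signature s1 is strictly more specific than s2
            at_least = all(issubclass(a, b) for a, b in zip(s1, s2))
            reverse = all(issubclass(b, a) for a, b in zip(s1, s2))
            return at_least and not reverse

        # Drop candidates that are strictly less specific than another one
        best = [m for m in matches if not any(dominates(o[0], m[0]) for o in matches)]
        # Among equally specific candidates, the highest precedence wins
        top = max(m[1] for m in best)
        winners = [m for m in best if m[1] == top]
        if len(winners) > 1:
            raise TypeError("ambiguous call")
        return winners[0][2](*args, **kwargs)


combine = ToyDispatcher()


@combine.register(object, object)
def _generic(x, y, *, tag=""):
    return "object, object" + tag


@combine.register(int, object, precedence=1)
def _int_left(x, y, *, tag=""):
    return "int, object" + tag


@combine.register(object, int)
def _int_right(x, y, *, tag=""):
    return "object, int" + tag


print(combine("a", "b"))       # object, object
print(combine(1.5, 2))         # object, int
print(combine(1, 2, tag="!"))  # int, object!  (tie broken by precedence=1)
```

The last call matches both (int, object) and (object, int), neither of which is more specific than the other, so precedence decides; the tag keyword reaches the chosen implementation but plays no part in choosing it. In jaxdf this registration role is played by the @operator decorator described above.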

Example

@operator
def my_operator(x: FourierSeries, *, dx: float, params=None):
  ...

The params argument is mandatory and must be keyword-only. It carries the parameters of the operator, for example the stencil coefficients of a finite difference operator.

The default value of the parameters is specified by the init_params function, as follows:

Example

import jax.numpy as jnp
from jaxdf.core import operator

# Builds the default parameters from the same arguments as the operator
def params_initializer(x, *, dx):
  return {"stencil": jnp.ones(x.shape) * dx}

@operator(init_params=params_initializer)
def my_operator(x, *, dx, params=None):
  b = params["stencil"] / dx
  y_params = jnp.convolve(x.params, b, mode="same")
  return x.replace_params(y_params)

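The default value given to params in the signature is never used in the computation: if params is not passed explicitly, the operator builds it by calling init_params with the same arguments. If the operator has no parameters, init_params can be omitted, in which case params is set to None.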
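The parameter-resolution rule can be sketched as a plain wrapper. make_operator below is a hypothetical helper written for illustration, not jaxdf's internal implementation; it assumes only the behaviour described in this section: an explicitly passed params wins, otherwise init_params is called with the operator's own arguments, and without an initializer params is None.

```python
import functools
from typing import Callable, Optional


def make_operator(evaluate: Callable, init_params: Optional[Callable] = None) -> Callable:
    """Hypothetical wrapper mimicking the documented handling of `params`."""
    if init_params is None:
        # No initializer: the operator receives params=None
        def init_params(*args, **kwargs):
            return None

    @functools.wraps(evaluate)
    def wrapper(*args, **kwargs):
        if "params" not in kwargs:
            # The default written in evaluate's signature is never reached:
            # params is always supplied, built from the same arguments.
            kwargs["params"] = init_params(*args, **kwargs)
        return evaluate(*args, **kwargs)

    return wrapper


def params_initializer(x, *, dx):
    # Hypothetical 3-point stencil, scaled by dx
    return {"stencil": [dx, dx, dx]}


def scaled_sum(x, *, dx, params=None):
    return sum(s / dx for s in params["stencil"]) * x


op = make_operator(scaled_sum, init_params=params_initializer)
print(op(2.0, dx=0.5))                             # 6.0 (params built by the initializer)
print(op(2.0, dx=0.5, params={"stencil": [1.0]}))  # 4.0 (explicit params win)

no_params = make_operator(lambda x, *, params="unused": (x, params))
print(no_params(3))                                # (3, None)
```

The last call shows why the signature default does not matter: the wrapper always fills params before evaluate runs, so "unused" is never seen.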

For constant parameters, the constants function can be used:

Example

@operator(init_params=constants({"a": 1, "b": 2.0}))
def my_operator(x, *, params):
  return x + params["a"] + params["b"]

Field

Bases: Module

Source code in jaxdf/core.py
class Field(Module):
  params: PyTree
  domain: Domain

  # For concise code
  @property
  def θ(self):
    r"""Handy alias for the `params` attribute"""
    return self.params

  def __call__(self, x):
    r"""
        A Field can be called as a function, returning the field at a
        desired point.

        !!! example
            ```python
            ...
            a = Continuous.from_function(init_params, domain, get_field)
            field_at_x = a(1.0)
            ```
        """
    raise NotImplementedError(
        f"Not implemented for {self.__class__.__name__} discretization")

  @property
  def on_grid(self):
    """Returns the field on the grid points of the domain."""
    raise NotImplementedError(
        f"Not implemented for {self.__class__.__name__} discretization")

  @property
  def dims(self):
    r"""The dimension of the field values"""
    raise NotImplementedError

  @property
  def is_complex(self) -> bool:
    r"""Checks if a field is complex.

        Returns:
          bool: Whether the field is complex.
        """
    raise NotImplementedError

  @property
  def is_field_complex(self) -> bool:
    warnings.warn(
        "Field.is_field_complex is deprecated. Use Field.is_complex instead.",
        DeprecationWarning,
    )
    return self.is_complex

  @property
  def is_real(self) -> bool:
    return not self.is_complex

  def replace_params(self, new_params):
    r"""Returns a new field of the same type, with the same domain and auxiliary data, but with new parameters.

        !!! example
            ```python
            x = FourierSeries(jnp.ones(10), domain=domain)
            y_params = x.params + 1
            y = x.replace_params(y_params)
            ```

        Args:
          new_params (Any): The new parameters.

        Returns:
          Field: A new field with the same domain and auxiliary data, but with new parameters.
        """
    return self.__class__(new_params, self.domain)

  # Dummy magic functions to make it work with
  # the dispatch system
  def __add__(self, other):
    return __add__(self, other)

  def __radd__(self, other):
    return __radd__(self, other)

  def __sub__(self, other):
    return __sub__(self, other)

  def __rsub__(self, other):
    return __rsub__(self, other)

  def __mul__(self, other):
    return __mul__(self, other)

  def __rmul__(self, other):
    return __rmul__(self, other)

  def __neg__(self):
    return __neg__(self)

  def __pow__(self, other):
    return __pow__(self, other)

  def __rpow__(self, other):
    return __rpow__(self, other)

  def __truediv__(self, other):
    return __truediv__(self, other)

  def __rtruediv__(self, other):
    return __rtruediv__(self, other)
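The dunder methods above contain no arithmetic: they forward to module-level functions, so support for new operand types can be added by registering implementations rather than editing the class. The sketch below reproduces that forwarding pattern with the standard library's functools.singledispatch as a stand-in for jaxdf's plum-based operators. It is a simplified toy under stated assumptions: singledispatch selects on the first argument only, so the helper takes the other operand first, and ToyField and _add_other are hypothetical names.

```python
from dataclasses import dataclass
from functools import singledispatch


@dataclass(frozen=True)
class ToyField:
    """Hypothetical stand-in for Field: only the forwarding pattern is kept."""
    params: tuple

    # Like Field.__add__ / __radd__: no logic here, just forwarding.
    def __add__(self, other):
        return _add_other(other, self)

    def __radd__(self, other):
        return _add_other(other, self)


@singledispatch
def _add_other(other, field):
    # Unsupported operand: returning NotImplemented lets Python raise TypeError
    return NotImplemented


@_add_other.register(float)
@_add_other.register(int)
def _(other, field):
    return ToyField(tuple(p + other for p in field.params))


@_add_other.register(ToyField)
def _(other, field):
    return ToyField(tuple(p + q for p, q in zip(field.params, other.params)))


f = ToyField((1.0, 2.0))
print(f + 1)  # ToyField(params=(2.0, 3.0))
print(1 + f)  # ToyField(params=(2.0, 3.0))
print(f + f)  # ToyField(params=(2.0, 4.0))
```

Because the class only forwards, adding a new operand type is a matter of one more register call; with plum, as in jaxdf, the selection can additionally take every positional argument into account.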

dims property

The dimension of the field values

is_complex: bool property

Checks if a field is complex.

Returns:

| Type | Description |
|---|---|
| bool | Whether the field is complex. |

on_grid property

Returns the field on the grid points of the domain.

θ property

Handy alias for the params attribute

__call__(x)

A Field can be called as a function, returning the field at a desired point.

Example

...
a = Continuous.from_function(init_params, domain, get_field)
field_at_x = a(1.0)

replace_params(new_params)

Returns a new field of the same type, with the same domain and auxiliary data, but with new parameters.

Example

x = FourierSeries(jnp.ones(10), domain=domain)
y_params = x.params + 1
y = x.replace_params(y_params)

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| new_params | Any | The new parameters. | required |

Returns:

| Type | Description |
|---|---|
| Field | A new field with the same domain and auxiliary data, but with new parameters. |
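replace_params never mutates the field it is called on; since Field is an equinox.Module, whose instances are immutable, this is how "updated" fields are produced. The sketch below uses a hypothetical frozen dataclass in place of Field to show the pattern. It also illustrates a design point visible in the source: the base method forwards only params and domain, so a subclass that adds further fields would need its own override (here written with dataclasses.replace, which copies every field except those passed as keywords).

```python
import dataclasses
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyField:
    """Hypothetical immutable stand-in for Field (not jaxdf's class)."""
    params: tuple
    domain: str

    def replace_params(self, new_params):
        # Same pattern as Field.replace_params: a new instance of the same
        # class, reusing the domain and swapping only the parameters.
        return self.__class__(new_params, self.domain)


@dataclass(frozen=True)
class ToyFieldWithAux(ToyField):
    aux: str = "default"

    def replace_params(self, new_params):
        # Extra fields must be carried over explicitly; the base version
        # would silently reset aux to its default.
        return dataclasses.replace(self, params=new_params)


x = ToyField((1.0, 1.0), domain="grid")
y = x.replace_params(tuple(p + 1 for p in x.params))
print(y)         # ToyField(params=(2.0, 2.0), domain='grid')
print(x.params)  # (1.0, 1.0): the original field is untouched

z = ToyFieldWithAux((1.0,), "grid", aux="custom").replace_params((5.0,))
print(z.aux)     # custom
```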


Operator

Source code in jaxdf/core.py
class Operator:

  def __call__(
      self,
      evaluate: Union[Callable, None] = None,
      init_params: Union[Callable, None] = None,
      precedence: int = 0,
  ):
    if evaluate is None:
      # Returns the decorator
      def decorator(evaluate):
        return _operator(evaluate, precedence, init_params)

      return decorator
    else:
      return _operator(evaluate, precedence, init_params)

  def abstract(self, evaluate: Callable):
    """Decorator for defining abstract operators. This is mainly used
        to define generic docstrings."""
    return _abstract_operator(evaluate)
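Operator.__call__ accepts either the function itself (bare @operator) or only keyword options (@operator(init_params=..., precedence=...)); in the second case evaluate is None and it returns a decorator that receives the function later. The toy below isolates that two-way calling convention. ToyOperator and _make_operator are hypothetical names, and the work done by jaxdf's private _operator helper is replaced by a function that merely records the options.

```python
from typing import Callable, Optional


def _make_operator(evaluate: Callable, precedence: int, init_params: Optional[Callable]):
    # Hypothetical stand-in for jaxdf's private _operator: it only records
    # the options so the two call styles can be compared.
    evaluate.precedence = precedence
    evaluate.init_params = init_params
    return evaluate


class ToyOperator:
    def __call__(self, evaluate: Optional[Callable] = None,
                 init_params: Optional[Callable] = None, precedence: int = 0):
        if evaluate is None:
            # Used as @op(...): the options arrive first, the function later.
            def decorator(evaluate):
                return _make_operator(evaluate, precedence, init_params)
            return decorator
        # Used as bare @op: the function is the first positional argument.
        return _make_operator(evaluate, precedence, init_params)


op = ToyOperator()


@op
def plain(x, *, params=None):
    return x


@op(precedence=2)
def ranked(x, *, params=None):
    return x


print(plain.precedence, ranked.precedence)  # 0 2
```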

abstract(evaluate)

Decorator for defining abstract operators. This is mainly used to define generic docstrings.


constants(value)

A higher-order function for defining constant operator parameters, i.e. parameters that do not depend on the operator arguments.

Example

@operator(init_params=constants({"a": 1, "b": 2.0}))
def my_operator(x, *, params):
  return x + params["a"] + params["b"]

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| value | Any | The value of the constant parameters. | required |

Returns:

| Type | Description |
|---|---|
| Callable | The parameters initializer function that returns the constant value. |

Source code in jaxdf/core.py
def constants(value) -> Callable:
  r"""This is a higher order function for defining constant parameters of
    operators, independent of the operator arguments.

    !!! example

        ```python
        @operator(init_params=constants({"a": 1, "b": 2.0}))
        def my_operator(x, *, params):
          return x + params["a"] + params["b"]
        ```

    Args:
      value (Any): The value of the constant parameters.

    Returns:
      Callable: The parameters initializer function that returns the constant value.
    """

  def init_params(*args, **kwargs):
    return value

  return init_params
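A quick check of what constants returns: the function below copies the body shown above, and the calls confirm that the initializer ignores whatever arguments the operator passes to it. One consequence of the closure design is that every call returns the very same object, so a mutable value such as a dict is shared between calls and should not be modified in place.

```python
def constants(value):
    # Same body as jaxdf.core.constants shown above
    def init_params(*args, **kwargs):
        return value

    return init_params


init = constants({"a": 1, "b": 2.0})
print(init("any field", dx=0.1))  # {'a': 1, 'b': 2.0}
print(init() is init(1, 2, 3))    # True: the same dict every time
```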