Skip to content

discretization

Continuous

Bases: Field

A continuous discretization. This discretization assumes that the field is a function of the parameters contained in Continuous.params and a point in the domain. The function that computes the field from the parameters and the point in the domain is contained in Continuous.get_fun. This function has the signature get_fun(params, x), where params is the parameter vector and x is the point in the domain.

Source code in jaxdf/discretization.py
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
class Continuous(Field):
  r"""A continuous discretization. This discretization assumes that the field is a
    function of the parameters contained in `Continuous.params` and a point in the
    domain. The function that computes the field from the parameters and the point in
    the domain is contained in `Continuous.get_fun`. This function has the signature
    `get_fun(params, x)`, where `params` is the parameter vector and `x` is the point
    in the domain."""
  # Parameters passed as the first argument of `get_fun`.
  params: PyTree
  # Domain on which the field is defined.
  domain: Domain
  # Callable with signature `get_fun(params, x)`; declared as a static field.
  get_fun: Callable = eqx.field(static=True)

  def __eq__(self, other):
    # Two fields are equal when their pytrees match and they share the same
    # evaluation function. The product combines both conditions.
    return tree_equal(self, other) * (self.get_fun == other.get_fun)

  @property
  def is_complex(self) -> bool:
    r"""Checks if a field is complex.

        The check evaluates `get_fun` at `domain.origin` and inspects the result.

        Returns:
          bool: Whether the field is complex.
        """
    origin = self.domain.origin
    value = self.get_fun(self.params, origin)
    return jnp.iscomplexobj(value)

  @property
  def dims(self):
    r"""Shape of the field value at a single point.

        Obtained with `eval_shape`, i.e. from the abstract evaluation of
        `get_fun` at `domain.origin`.
        """
    x = self.domain.origin
    return eval_shape(self.get_fun, self.params, x).shape

  def replace_params(self, new_params: PyTree) -> "Continuous":
    r"""Replaces the parameters of the discretization with new ones. The domain
        and `get_fun` function are not changed.

        Args:
          new_params (PyTree): The new parameters of the discretization.

        Returns:
          Continuous: A continuous discretization with the new parameters.
        """
    # Forward the stored `get_fun(params, x)` callable. The deprecated
    # `get_field` bound method only accepts `x` and must not be used here.
    return self.__class__(new_params, self.domain, self.get_fun)

  def update_fun_and_params(
      self,
      params: PyTree,
      get_field: Callable,
  ) -> "Continuous":
    r"""Updates the parameters and the function of the discretization.

        Args:
          params (PyTree): The new parameters of the discretization.
          get_field (Callable): A function that takes a parameter vector and a point in
            the domain and returns the field at that point. The signature of this
            function is `get_field(params, x)`.

        Returns:
          Continuous: A continuous discretization with the new parameters and function.
        """
    return self.__class__(params, self.domain, get_field)

  @classmethod
  def from_function(cls, domain, init_fun: Callable, get_field: Callable,
                    seed):
    r"""Creates a continuous discretization from an `init_fun` function.

        Args:
          domain (Domain): The domain of the discretization.
          init_fun (Callable): A function that initializes the parameters of the
            discretization. It is called as `init_fun(seed, domain)`.
          get_field (Callable): A function that takes a parameter vector and a point in
            the domain and returns the field at that point. The signature of this
            function is `get_field(params, x)`.
          seed: The seed for the random number generator. It is passed
            unchanged as the first argument of `init_fun`.

        Returns:
          Continuous: A continuous discretization.
        """
    params = init_fun(seed, domain)
    return cls(params, domain=domain, get_fun=get_field)

  def __call__(self, x):
    r"""
        An object of this class can be called as a function, returning the field at a
        desired point.

        !!! example
            ```python
            ...
            a = Continuous.from_function(domain, init_params, get_field, seed)
            field_at_x = a(1.0)
            ```

        Args:
          x: The point at which to evaluate the field.

        Returns:
          The value of `get_fun(params, x)`.
        """
    return self.get_fun(self.params, x)

  def get_field(self, x):
    r"""Deprecated alias of `__call__`; emits a `DeprecationWarning`."""
    warnings.warn(
        "Continuous.get_field is deprecated. Use Continuous.__call__ instead.",
        DeprecationWarning,
        # Attribute the warning to the caller rather than to this wrapper.
        stacklevel=2,
    )
    return self.__call__(x)

  @property
  def on_grid(self):
    """Returns the field on the grid points of the domain."""
    fun = self.get_fun
    ndims = len(self.domain.N)
    # One `vmap` per domain dimension: the parameters are shared (None) and
    # only the leading axis of the coordinate argument is mapped each time.
    for _ in range(ndims):
      fun = vmap(fun, in_axes=(None, 0))

    return fun(self.params, self.domain.grid)

is_complex: bool property

Checks if a field is complex.

Returns:

Name Type Description
bool bool

Whether the field is complex.

on_grid property

Returns the field on the grid points of the domain.

__call__(x)

An object of this class can be called as a function, returning the field at a desired point.

Example

...
a = Continuous.from_function(domain, init_params, get_field, seed)
field_at_x = a(1.0)
Source code in jaxdf/discretization.py
117
118
119
120
121
122
123
124
125
126
127
128
129
def __call__(self, x):
  r"""
      An object of this class can be called as a function, returning the field at a
      desired point.

      !!! example
          ```python
          ...
          a = Continuous.from_function(domain, init_params, get_field, seed)
          field_at_x = a(1.0)
          ```

      Args:
        x: The point at which to evaluate the field.

      Returns:
        The value of `get_fun(params, x)`.
      """
  return self.get_fun(self.params, x)

from_function(domain, init_fun, get_field, seed) classmethod

Creates a continuous discretization from an init_fun function.

Parameters:

Name Type Description Default
domain Domain

The domain of the discretization.

required
init_fun Callable

A function that initializes the parameters of the discretization. The signature of this function is init_fun(rng, domain).

required
get_field Callable

A function that takes a parameter vector and a point in the domain and returns the field at that point. The signature of this function is get_field(params, x).

required
seed int

The seed for the random number generator.

required

Returns:

Name Type Description
Continuous

A continuous discretization.

Source code in jaxdf/discretization.py
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
@classmethod
def from_function(cls, domain, init_fun: Callable, get_field: Callable,
                  seed):
  r"""Builds a continuous discretization whose parameters come from `init_fun`.

      Args:
        domain (Domain): The domain of the discretization.
        init_fun (Callable): Parameter initializer, invoked as
          `init_fun(seed, domain)`.
        get_field (Callable): Evaluation function with signature
          `get_field(params, x)`; it is stored as `get_fun` on the new object.
        seed: Value handed unchanged to `init_fun` as its first argument.

      Returns:
        Continuous: A continuous discretization.
      """
  # Initialise the parameters and build the object in a single step.
  return cls(init_fun(seed, domain), domain=domain, get_fun=get_field)

replace_params(new_params)

Replaces the parameters of the discretization with new ones. The domain and get_fun function are not changed.

Parameters:

Name Type Description Default
new_params PyTree

The new parameters of the discretization.

required

Returns:

Name Type Description
Continuous Continuous

A continuous discretization with the new parameters.

Source code in jaxdf/discretization.py
67
68
69
70
71
72
73
74
75
76
77
def replace_params(self, new_params: PyTree) -> "Continuous":
  r"""Replaces the parameters of the discretization with new ones. The domain
      and `get_fun` function are not changed.

      Args:
        new_params (PyTree): The new parameters of the discretization.

      Returns:
        Continuous: A continuous discretization with the new parameters.
      """
  # Forward the stored `get_fun(params, x)` callable. The deprecated
  # `get_field` bound method only accepts `x` and must not be used here.
  return self.__class__(new_params, self.domain, self.get_fun)

update_fun_and_params(params, get_field)

Updates the parameters and the function of the discretization.

Parameters:

Name Type Description Default
params PyTree

The new parameters of the discretization.

required
get_field Callable

A function that takes a parameter vector and a point in the domain and returns the field at that point. The signature of this function is get_field(params, x).

required

Returns:

Name Type Description
Continuous Continuous

A continuous discretization with the new parameters and function.

Source code in jaxdf/discretization.py
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
def update_fun_and_params(
    self,
    params: PyTree,
    get_field: Callable,
) -> "Continuous":
  r"""Returns a copy of the discretization with new parameters and a new
      evaluation function; the domain is carried over.

      Args:
        params (PyTree): The new parameters of the discretization.
        get_field (Callable): Evaluation function with signature
          `get_field(params, x)`, returning the field at the point `x`.

      Returns:
        Continuous: A continuous discretization with the new parameters and function.
      """
  constructor = self.__class__
  return constructor(params, self.domain, get_field)

FiniteDifferences

Bases: OnGrid

A Finite Differences field defined on a collocation grid.

Source code in jaxdf/discretization.py
331
332
333
334
335
class FiniteDifferences(OnGrid):
  r"""A Finite Differences field defined on a collocation grid."""
  # Field values stored on the grid.
  params: PyTree
  # Domain describing the grid.
  domain: Domain
  # Static field, defaults to 8.
  # NOTE(review): presumably the accuracy order of the finite-difference
  # stencils used by the operators -- confirm against the operator code.
  accuracy: int = eqx.field(default=8, static=True)

FourierSeries

Bases: OnGrid

A Fourier series field defined on a collocation grid.

Source code in jaxdf/discretization.py
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
class FourierSeries(OnGrid):
  r"""A Fourier series field defined on a collocation grid."""

  def __call__(self, x):
    r"""Evaluates the field at an arbitrary point via the Fourier shift theorem.

        Every domain axis is processed with one forward and one inverse 1D
        transform: the spectrum is multiplied by a phase ramp so that the
        requested point lands on the first grid sample, which is then read out.

        Args:
          x: Coordinates of the point, with one entry per domain axis (the
            offsets derived from it are paired with the per-axis frequencies).

        Returns:
          float, jnp.ndarray: The value of the field at the point.
        """
    spacing = jnp.asarray(self.domain.dx)
    extent = jnp.asarray(self.domain.N) * spacing
    offsets = x - extent / 2 + 0.5 * spacing

    # One complex phase ramp per axis.
    phases = [
        jnp.exp(-1j * freqs * offset)
        for freqs, offset in zip(self._freq_axis, offsets)
    ]
    forward, inverse = self._ffts

    shifted = self.params
    for axis in range(self.domain.ndim):
      # Work on the last axis, then restore the original layout.
      moved = jnp.moveaxis(shifted, axis, -1)
      spectrum = forward(moved, axis=-1) * phases[axis]
      restored = inverse(spectrum, axis=-1, n=moved.shape[-1])
      shifted = jnp.moveaxis(restored, -1, axis)

    return shifted[(0, ) * self.domain.ndim]

  @property
  def _freq_axis(self):
    r"""Angular frequency vector of every grid axis.

        Complex fields use the full (`fftfreq`) layout, other fields the
        half-spectrum (`rfftfreq`) layout.
        """
    freqs_of = jnp.fft.fftfreq if self.is_complex else jnp.fft.rfftfreq
    return [
        freqs_of(n, delta) * 2 * jnp.pi
        for n, delta in zip(self.domain.N, self.domain.dx)
    ]

  @property
  def _ffts(self):
    r"""The `[forward, inverse]` 1D transforms matching the field type
        (real or complex).
        """
    if not self.is_real:
      return [jnp.fft.fft, jnp.fft.ifft]
    return [jnp.fft.rfft, jnp.fft.irfft]

  @property
  def _cut_freq_axis(self):
    r"""Like `FourierSeries._freq_axis`, except that only the last axis uses
        the real-FFT layout for non-complex fields, which matches the
        frequency axes of the `rfftn` function.
        """
    sizes, steps = self.domain.N, self.domain.dx
    axes = [
        jnp.fft.fftfreq(n, delta) * 2 * jnp.pi
        for n, delta in zip(sizes, steps)
    ]
    if self.is_complex:
      return axes

    axes[-1] = jnp.fft.rfftfreq(sizes[-1], steps[-1]) * 2 * jnp.pi
    return axes

  @property
  def _cut_freq_grid(self):
    r"""Meshgrid of `_cut_freq_axis`, stacked along a new last axis."""
    mesh = jnp.meshgrid(*self._cut_freq_axis, indexing="ij")
    return jnp.stack(mesh, axis=-1)

  @property
  def _freq_grid(self):
    r"""Meshgrid of `_freq_axis`, stacked along a new last axis."""
    mesh = jnp.meshgrid(*self._freq_axis, indexing="ij")
    return jnp.stack(mesh, axis=-1)

__call__(x)

Uses the Fourier shift theorem to compute the value of the field at an arbitrary point. Requires N*2 one dimensional FFTs.

Parameters:

Name Type Description Default
x float

The point at which to evaluate the field.

required

Returns:

Type Description

float, jnp.ndarray: The value of the field at the point.

Source code in jaxdf/discretization.py
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
def __call__(self, x):
  r"""Uses the Fourier shift theorem to compute the value of the field
      at an arbitrary point. Performs one forward and one inverse
      one-dimensional FFT for every axis of the domain.

      Args:
        x: Coordinates of the point at which to evaluate the field, with one
          entry per domain axis (the derived shifts are zipped with the
          per-axis frequency vectors).

      Returns:
        float, jnp.ndarray: The value of the field at the point.
      """
  dx = jnp.asarray(self.domain.dx)
  domain_size = jnp.asarray(self.domain.N) * dx
  # Per-axis offset applied to the field before reading its first sample.
  shift = x - domain_size / 2 + 0.5 * dx

  # Phase ramp for each axis, built from that axis' angular frequencies.
  k_vec = [
      jnp.exp(-1j * k * delta) for k, delta in zip(self._freq_axis, shift)
  ]
  # ffts[0] is the forward transform, ffts[1] the inverse one.
  ffts = self._ffts

  new_params = self.params

  def single_shift(axis, u):
    # Shift `u` along `axis`: transform, apply the phase ramp, transform back.
    u = jnp.moveaxis(u, axis, -1)
    Fx = ffts[0](u, axis=-1)
    iku = Fx * k_vec[axis]
    # `n` restores the original length after the inverse transform.
    du = ffts[1](iku, axis=-1, n=u.shape[-1])
    return jnp.moveaxis(du, -1, axis)

  for ax in range(self.domain.ndim):
    new_params = single_shift(ax, new_params)

  # After the shift, the requested value sits at index 0 of every domain axis.
  origin = tuple([0] * self.domain.ndim)
  return new_params[origin]

Linear

Bases: Field

This discretization assumes that the field is a linear function of the parameters contained in Linear.params.

Source code in jaxdf/discretization.py
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
class Linear(Field):
  r"""This discretization assumes that the field is a linear function of the
    parameters contained in `Linear.params`.
    """
  params: PyTree
  domain: Domain

  def __eq__(self, other):
    # Equal when both the pytrees and the domains match; the product
    # combines the two conditions.
    same_tree = tree_equal(self, other)
    same_domain = self.domain == other.domain
    return same_tree * same_domain

  @property
  def is_complex(self) -> bool:
    r"""Checks if a field is complex.

        The decision is based on the dtype of `params`.

        Returns:
          bool: Whether the field is complex.
        """
    dtype = self.params.dtype
    return dtype == jnp.complex64 or dtype == jnp.complex128

is_complex: bool property

Checks if a field is complex.

Returns:

Name Type Description
bool bool

Whether the field is complex.

OnGrid

Bases: Linear

Source code in jaxdf/discretization.py
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
class OnGrid(Linear):
  r"""A linear discretization whose parameters are the field values on the
    grid points of the domain, with a trailing axis for the field components.
    """

  def __check_init__(self):
    r"""Validates the shape of `params`, adding the trailing axis if missing.

        Raises:
          ValueError: If, after the fix-up, `params` does not have at least one
            more dimension than the domain.
        """
    grid_shape = tuple(self.domain.N)

    # If only the trailing (component) axis is missing, add it. The object is
    # immutable, hence the explicit `object.__setattr__`.
    if self.params.shape == grid_shape:
      object.__setattr__(self, "params", jnp.expand_dims(self.params,
                                                         axis=-1))

    # The previous check compared the shape against `domain.N` again, which
    # can never match after the fix-up above; compare the ranks instead.
    if len(self.params.shape) <= len(grid_shape):
      raise ValueError(
          f"The number of dimensions of the parameters is incorrect. It should be the number of dimensions of the domain plus at least one more. The parameters have shape {self.params.shape} and the domain has shape {self.domain.N}"
      )

  def add_dim(self):
    """Adds a dimension at the end of the `params` array.

        Returns:
          A new object of the same class; the current one is left untouched.
        """
    new_params = jnp.expand_dims(self.params, axis=-1)
    return self.__class__(new_params, self.domain)

  @property
  def dims(self):
    r"""Size of the last axis of `params`."""
    return self.params.shape[-1]

  def __getitem__(self, idx):
    r"""Allow indexing when leading batch / time dimensions are
        present in the parameters

        !!! example
            ```python
            ...
            domain = Domain((16,), (1.0,))

            # 10 fields
            params = random.uniform(key, (10, 16, 1))
            a = OnGrid(params, domain)

            # Field at the 5th index
            field = a[5]
            ```

        Args:
          idx: Index applied to `params`.

        Returns:
          OnGrid: A linear discretization on the grid points of the domain.

        Raises:
          IndexError: If the field is not indexable (single field).
        """
    # Domain axes plus the component axis only: there is nothing to index.
    if self.domain.ndim + 1 == len(self.params.shape):
      raise IndexError(
          "Indexing is only supported if there's at least one batch / time dimension"
      )

    new_params = self.params[idx]
    return self.__class__(new_params, self.domain)

  @classmethod
  def empty(cls, domain, dims=1):
    r"""Creates an empty OnGrid field (zero field). For `dims=1` this is
        equivalent to `OnGrid(jnp.zeros(domain.N), domain)`.

        Args:
          domain (Domain): The domain of the discretization.
          dims (int): Size of the trailing axis of the parameters.
        """
    shape = (*domain.N, dims)
    return cls(jnp.zeros(shape), domain)

  @classmethod
  def from_grid(cls, grid_values, domain):
    r"""Creates an OnGrid field from a grid of values.

        Args:
          grid_values (ndarray): The grid of values.
          domain (Domain): The domain of the discretization.
        """
    return cls(grid_values, domain)

  def replace_params(self, new_params):
    r"""Replaces the parameters of the discretization with new ones. The domain
        is not changed.

        Args:
          new_params (PyTree): The new parameters of the discretization.

        Returns:
          OnGrid: A linear discretization with the new parameters.
        """
    return self.__class__(new_params, self.domain)

  @property
  def on_grid(self):
    r"""The field on the grid points of the domain."""
    return self.params

on_grid property

The field on the grid points of the domain.

__getitem__(idx)

Allow indexing when leading batch / time dimensions are present in the parameters

Example

...
domain = Domain((16,), (1.0,))

# 10 fields
params = random.uniform(key, (10, 16, 1))
a = OnGrid(params, domain)

# Field at the 5th index
field = a[5]

Returns:

Name Type Description
OnGrid

A linear discretization on the grid points of the domain.

Raises:

Type Description
IndexError

If the field is not indexable (single field).

Source code in jaxdf/discretization.py
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
def __getitem__(self, idx):
  r"""Allow indexing when leading batch / time dimensions are
      present in the parameters

      !!! example
          ```python
          ...
          domain = Domain((16,), (1.0,))

          # 10 fields
          params = random.uniform(key, (10, 16, 1))
          a = OnGrid(params, domain)

          # Field at the 5th index
          field = a[5]
          ```

      Args:
        idx: Index applied to `params`.

      Returns:
        OnGrid: A linear discretization on the grid points of the domain.

      Raises:
        IndexError: If the field is not indexable (single field).
      """
  # Domain axes plus one extra axis only: there is no batch / time axis to index.
  if self.domain.ndim + 1 == len(self.params.shape):
    raise IndexError(
        "Indexing is only supported if there's at least one batch / time dimension"
    )

  new_params = self.params[idx]
  return self.__class__(new_params, self.domain)

add_dim()

Adds a dimension at the end of the params array.

Source code in jaxdf/discretization.py
164
165
166
167
168
def add_dim(self):
  """Returns a new object whose `params` array has one extra trailing axis."""
  return self.__class__(jnp.expand_dims(self.params, axis=-1), self.domain)

empty(domain, dims=1) classmethod

Creates an empty OnGrid field (zero field). Equivalent to OnGrid(jnp.zeros(domain.N), domain).

Source code in jaxdf/discretization.py
205
206
207
208
209
210
211
212
@classmethod
def empty(cls, domain, dims=1):
  r"""Creates a zero-valued field on `domain`.

      The parameters are an array of zeros of shape `domain.N` followed by a
      trailing axis of size `dims`.
      """
  shape = (*domain.N, dims)
  return cls(jnp.zeros(shape), domain)

from_grid(grid_values, domain) classmethod

Creates an OnGrid field from a grid of values.

Parameters:

Name Type Description Default
grid_values ndarray

The grid of values.

required
domain Domain

The domain of the discretization.

required
Source code in jaxdf/discretization.py
214
215
216
217
218
219
220
221
222
223
@classmethod
def from_grid(cls, grid_values, domain):
  r"""Wraps a grid of values into a field of this class.

      Args:
        grid_values (ndarray): The grid of values.
        domain (Domain): The domain of the discretization.

      Returns:
        The object built as `cls(grid_values, domain)`.
      """
  return cls(grid_values, domain)

replace_params(new_params)

Replaces the parameters of the discretization with new ones. The domain is not changed.

Parameters:

Name Type Description Default
new_params PyTree

The new parameters of the discretization.

required

Returns:

Name Type Description
OnGrid

A linear discretization with the new parameters.

Source code in jaxdf/discretization.py
225
226
227
228
229
230
231
232
233
234
235
def replace_params(self, new_params):
  r"""Returns a field of the same class, on the same domain, holding
      `new_params` instead of the current parameters.

      Args:
        new_params (PyTree): The new parameters of the discretization.

      Returns:
        OnGrid: A linear discretization with the new parameters.
      """
  constructor = self.__class__
  return constructor(new_params, self.domain)