mods

Module

Bases: `equinox.Module`

Base class for jaxdf modules. It subclasses Equinox's `eqx.Module` and adds a functional `replace` method.

Source code in jaxdf/mods.py
# Imports needed to run this listing (assumed; lines 1-4 of mods.py are not shown).
import equinox as eqx
from jaxtyping import PyTree


class Module(eqx.Module):
  """
  A custom module inheriting from Equinox's Module class.
  """

  def replace(self, name: str, value: PyTree):
    """
    Replaces the attribute of the module with the given name with a new value.

    This method utilizes `eqx.tree_at` to update the attribute in a functional
    manner, ensuring compatibility with JAX's functional approach and autodiff capabilities.

    Args:
        name (str): The name of the attribute to be replaced.
        value (PyTree): The new value to set for the attribute. This should be
                        compatible with JAX's PyTree structure.

    Returns:
        A new instance of Module with the specified attribute updated.
        The rest of the module's attributes remain unchanged.

    !!! example
    ```python
        >>> import jaxdf
        >>> class Linear(jaxdf.Module):  # hypothetical subclass with two fields
        ...     weight: float
        ...     bias: float
        >>> module = Linear(weight=1.0, bias=2.0)
        >>> new_module = module.replace("weight", 3.0)
        >>> new_module.weight
        3.0
    ```
    """
    # Select the field to replace by name; getattr is the idiomatic form of
    # m.__getattribute__(name).
    f = lambda m: getattr(m, name)
    return eqx.tree_at(f, self, value)

replace(name, value)

Replaces the attribute of the module with the given name with a new value.

This method uses `eqx.tree_at` to return an updated copy of the module rather than mutating it in place, since Equinox modules are immutable. As a result, it can be used inside JAX transformations such as `jax.jit` and `jax.grad`.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `name` | `str` | The name of the attribute to be replaced. | required |
| `value` | `PyTree` | The new value to set for the attribute. This should be compatible with JAX's PyTree structure. | required |

Returns:

| Type | Description |
|------|-------------|
| same class as `self` | A new instance of `Module` with the specified attribute updated. The rest of the module's attributes remain unchanged. |

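Example

    >>> import jaxdf
    >>> class Linear(jaxdf.Module):  # hypothetical subclass with two fields
    ...     weight: float
    ...     bias: float
    >>> module = Linear(weight=1.0, bias=2.0)
    >>> new_module = module.replace("weight", 3.0)
    >>> new_module.weight
    3.0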