
mods

Module

Bases: equinox.Module

A custom module inheriting from Equinox's Module class.

replace(name, value)

Replaces the module attribute called name with a new value and returns the result as a new module.

This method uses eqx.tree_at to perform the update functionally: the original module is not mutated, and a new module is returned instead. This keeps the method compatible with JAX transformations such as jax.jit and jax.grad.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| name | str | The name of the attribute to be replaced. | required |
| value | PyTree | The new value to set for the attribute. It should be a valid JAX PyTree (for example an array, a scalar, or another module). | required |

Returns:

| Type | Description |
|---|---|
| Module | A new instance of Module with the specified attribute updated. The rest of the module's attributes remain unchanged. |

Example

    >>> import jaxdf
    >>> class Affine(jaxdf.Module):  # declare the fields on a subclass
    ...     weight: float
    ...     bias: float
    >>> new_module = Affine(weight=1.0, bias=2.0).replace("weight", 3.0)
    >>> new_module.weight == 3.0    # True