mods
Module
Bases: equinox.Module
A custom module inheriting from Equinox's Module class.
replace(name, value)
Returns a copy of the module in which the attribute called name is set to value.
Equinox modules are immutable, so this method uses eqx.tree_at to perform the update functionally instead of mutating the module in place. This keeps it compatible with JAX transformations such as jax.jit and jax.grad.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| name | str | The name of the attribute to be replaced. | required |
| value | PyTree | The new value to set for the attribute. This should be compatible with JAX's PyTree structure. | required |
Returns:
| Type | Description |
|---|---|
| Module | A new instance of Module with the specified attribute updated. The rest of the module's attributes remain unchanged. |
Example
>>> import jaxdf
>>> class Params(jaxdf.Module):
...     weight: float
...     bias: float
>>> new_module = Params(weight=1.0, bias=2.0).replace("weight", 3.0)
>>> new_module.weight
3.0