mods
Module
Bases: equinox.Module
A custom module inheriting from Equinox's Module class.
Source code in jaxdf/mods.py
replace(name, value)
Replaces the attribute with the given name by a new value.
This method uses eqx.tree_at to perform the update functionally: it returns a new module and leaves the original untouched, which keeps it compatible with JAX's functional programming model and its autodiff transformations.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
name | str | The name of the attribute to be replaced. | required |
value | PyTree | The new value to set for the attribute. This should be compatible with JAX's PyTree structure. | required |
Returns:
Type | Description |
---|---|
Module | A new instance of Module with the specified attribute updated. The rest of the module's attributes remain unchanged. |
Example
>>> class Linear(jaxdf.Module):
...     weight: float
...     bias: float
>>> module = Linear(weight=1.0, bias=2.0)
>>> module.replace("weight", 3.0).weight
3.0