# Set the matplotlib font family to serif for the figures produced below.
import matplotlib as mpl
mpl.rcParams.update({
'font.family': 'serif',
})
# Scale-bar artist and font properties, used by the two-panel focusing figure.
from mpl_toolkits.axes_grid1.anchored_artists import AnchoredSizeBar
import matplotlib.font_manager as fm
Optimizing through GMRES
This example demonstrates how to take advantage of the implicit function theorem
to differentiate through a fixed-point algorithm with an $O(1)$ memory requirement. Here, the iterative solver is GMRES, which is used to solve the Helmholtz equation.
A thorough discussion of how to take derivatives of a generic fixed-point algorithm is given in the DEQ paper.
Helmholtz equation
We assume that waves are transmitted from a set of $n$ transducers, which act as monopole sources. This means that we can define an apodization vector
$$ \mathbf a = (a_1, \dots, a_n), \qquad a_i \in \mathbb{C}, \; |a_i| \le 1 $$
such that $s(\mathbf a)$ is the transmit wavefield. The bound on the modulus $|a_i|$ is needed to enforce the fact that each transducer has an upper limit on the maximum power it can transmit.
We could use several methods to represent this vector and its constraint. Here, we use
$$ a_j(\rho_j, \theta_j) = \frac{e^{i\theta_j}}{1 + \rho_j^2}. $$
Focusing
Often, we want to find the apodization vector that produces a field with certain properties. For example, in a neurostimulation session we may want to maximize the acoustic power delivered to a certain spot, while keeping the acoustic field below an arbitrary threshold in another region.
Let's call $\mathbf p\in\mathbb{R}^2$ the point where we want to maximize the wavefield. For a field $\phi(\mathbf x,\mathbf a)$ generated by the apodization $\mathbf a$, the optimal apodization is then given by
$$ \hat {\mathbf a} = \operatorname*{arg\,max}_{\mathbf a} \|\phi(\mathbf p, \mathbf a) \| $$
We start by setting up the simulation:
from functools import partial
import jax
import numpy as np
from jax import numpy as jnp
from jax import random
# PRNG key (seed 42). NOTE: it is re-seeded in the speed-of-sound section
# before any use is visible in this file.
key = random.PRNGKey(42)
import matplotlib
# NOTE(review): this resets every rcParam to matplotlib's defaults, which undoes
# the serif 'font.family' set at the top of the file — confirm this is intended.
matplotlib.rcParams.update(matplotlib.rcParamsDefault)
# Defining geometry
from jwave import FourierSeries
from jwave.geometry import Domain, Medium

N = (320, 512)  # Grid size
dx = (1e-4, 1e-4)  # Spatial resolution
omega = 1.7e6 * 2 * jnp.pi  # Wavefield omega = 2*pi*f
target = [160, 360]  # Target location, as (row, column) grid indices

# Making geometry
domain = Domain(N, dx)

# Constructing medium physical properties: a unit background with four
# rectangular inclusions written in order (later writes win on overlap),
# then a trailing channel axis and a global scale factor of 1480.
sound_speed = (
    jnp.ones(N)
    .at[30:80, 50:80].set(1.6)
    .at[80:140, 50:100].set(1.4)
    .at[140:220, 45:130].set(1.2)
    .at[220:280, 70:100].set(1.5)
)
sound_speed = FourierSeries(jnp.expand_dims(sound_speed, -1) * 1480, domain)
medium = Medium(domain=domain, sound_speed=sound_speed, pml_size=25)
# Build the vector that holds the parameters of the apodization and the
# functions required to transform it into a source wavefield.
# 64 parameters, all equal to 1: the first 32 are the radial terms and the
# last 32 the angles consumed by ``phase_to_apod``.
transmit_phase = jnp.ones((2 * 32,))
# Grid rows of the 32 transducers: every 8th row, starting from row 32.
position = [32 + 8 * i for i in range(32)]
def phase_to_apod(phases):
    """Map a flat parameter vector to complex apodization coefficients.

    ``phases`` holds ``2 * n`` real values. The first ``n`` are the radial
    parameters ``rho_j`` and the last ``n`` are the angles ``theta_j``. Each
    coefficient is ``exp(1j * theta_j) / (1 + rho_j ** 2)``, so its modulus
    never exceeds 1.

    Args:
        phases: 1D array of even length ``2 * n``.

    Returns:
        Complex array of ``n`` apodization coefficients.

    Raises:
        ValueError: if ``len(phases)`` is odd. Such a vector cannot be split
            into equal ``rho`` and ``theta`` halves; without this check the
            halves would silently broadcast against each other.
    """
    if len(phases) % 2:
        raise ValueError(
            f"phases must have even length (rho and theta halves), got {len(phases)}"
        )
    dim = len(phases) // 2
    rho, theta = phases[:dim], phases[dim:]
    return jnp.exp(1j * theta) / (1 + rho ** 2)
def phases_to_field(phases, domain):
    """Scatter the apodization into a complex source map on ``domain``.

    The coefficients from ``phase_to_apod(phases)`` are written at the rows
    listed in the module-level ``position``, in grid column 25. Every other
    grid point is zero. A trailing singleton axis is added before the map is
    wrapped in a ``FourierSeries``.
    """
    apodization = phase_to_apod(phases)
    source_map = jnp.zeros(domain.N, dtype=jnp.complex64)
    source_map = source_map.at[position, 25].set(apodization)
    return FourierSeries(source_map[..., None], domain)


linear_phase = phases_to_field(transmit_phase, domain)
from matplotlib import pyplot as plt
from jwave.utils import display_complex_field, show_positive_field
# Plot the sound-speed map with the transducer positions and the target.
plt.figure(figsize=(8, 5))
plt.imshow(medium.sound_speed.on_grid)
plt.colorbar()
plt.title("Sound speed map")
# scatter takes (x, y) = (column, row), hence the swapped grid indices below;
# the transducers sit in column 25.
plt.scatter([25] * len(position), position, marker=".", label="Transducers")
plt.scatter(target[1], target[0], label="Target", marker='x')
<matplotlib.collections.PathCollection at 0x7f40bb45c310>
We will now extract the default parameters of the Helmholtz solver. This is not strictly necessary, but in 3D simulations it reduces the compilation time at the expense of a slightly longer runtime: see https://github.com/google/jax/issues?q=constant+folding
from jax import jit
from jwave.acoustics.operators import helmholtz
from jwave.acoustics.time_harmonic import helmholtz_solver, helmholtz_solver_verbose
# Precompute the Helmholtz operator parameters once; they are later passed to
# the solver through its `params` argument.
# NOTE(review): built with omega=1.0, while the solver is later called with the
# module-level `omega` — confirm the parameters do not depend on omega.
op_params = helmholtz.default_params(linear_phase, medium, omega=1.0)
print("Operator parameters: " + str(list(op_params.keys())))
# Inspect the imaginary part of the first PML coordinate field.
plt.imshow(op_params["pml_on_grid"][0].on_grid[...,0].imag)
plt.title("Imaginary component of the 1st PML coordinate field")
plt.show()
Operator parameters: ['pml_on_grid', 'fft_u']
@jit
def fixed_medium_solver(src_field, op_params, guess=None, tol=1e-3):
    """Solve the Helmholtz problem for ``src_field`` in the fixed ``medium``.

    ``medium`` and ``omega`` are read from module scope. Because the function
    is jitted, they are captured as constants when it is first traced. The
    precomputed ``op_params`` are forwarded to the solver as ``params``.
    """
    solution = helmholtz_solver(
        medium, omega, src_field, guess=guess, tol=tol, params=op_params
    )
    return solution


# Solve once with the initial (uniform) apodization and display the result.
field = fixed_medium_solver(linear_phase, op_params)
_ = display_complex_field(field, figsize=(20, 20))
We can now define our loss function $L(\mathbf a)$ and perform gradient descent to reach a local minimum.
Note that this is possible because the GMRES implementation, which computes the solution of the Helmholtz problem, is differentiable.
from jax import value_and_grad, vmap


def loss(field):
    """Negative field amplitude at the module-level ``target`` grid point.

    Minimizing this value maximizes the amplitude at the target.
    """
    values_at_target = field.on_grid[target[0], target[1]]
    return -jnp.sum(jnp.abs(values_at_target))


def get_field(transmit_phase, tol, guess, op_params):
    """Solve for the wavefield produced by the apodization parameters."""
    return fixed_medium_solver(
        phases_to_field(transmit_phase, domain), op_params, guess, tol
    )


def full_loss(transmit_phase, tol, guess, op_params):
    """Return ``(loss, field)`` so the field can be returned as auxiliary output."""
    solved = get_field(transmit_phase, tol, guess, op_params)
    return loss(solved), solved


# Differentiates w.r.t. the first argument (the apodization parameters);
# has_aux=True passes the solved field through alongside the loss.
loss_with_grad = value_and_grad(full_loss, has_aux=True)
⚠️ Run the next cell only if you don't have tqdm
installed (if it is already installed, pip will simply report that the requirement is already satisfied).
# IPython shell escape (notebook only, not plain Python): install tqdm for the progress bars.
!pip install tqdm
Requirement already satisfied: tqdm in /home/antonio/anaconda3/envs/jwave/lib/python3.11/site-packages (4.65.0)
from jax import jit
from jax.example_libraries import optimizers
from tqdm import tqdm

# Optimize the apodization parameters with Adam, differentiating through the
# GMRES-based Helmholtz solve at every step.
losshistory = []
init_fun, update_fun, get_params = optimizers.adam(0.1, b1=0.9, b2=0.9)
opt_state = init_fun(transmit_phase)


@partial(jit, static_argnums=(1,))
def update(opt_state, tol, guess, op_params, step=0):
    """Take one Adam step on the apodization parameters.

    Args:
        opt_state: Current optimizer state.
        tol: Solver tolerance; static under ``jit`` (``static_argnums=(1,)``).
        guess: Initial guess forwarded to the Helmholtz solver (may be None).
        op_params: Precomputed Helmholtz operator parameters.
        step: Iteration index given to ``update_fun`` (Adam's bias correction
            depends on it). It must be an argument: a global such as the loop
            variable, read inside a jitted function, is frozen at its
            trace-time value, so every step would reuse index 0.

    Returns:
        ``(lossval, field, new_opt_state)``.
    """
    (lossval, field), gradient = loss_with_grad(
        get_params(opt_state), tol, guess, op_params
    )
    return lossval, field, update_fun(step, gradient, opt_state)


pbar = tqdm(range(100))
tol = 1e-3
guess = None
for k in pbar:
    lossval, new_field, opt_state = update(opt_state, tol, guess, op_params, k)
    # For logging
    pbar.set_description("Ampl: {:01.4f}".format(-lossval))
    losshistory.append(lossval)

transmit_phase = get_params(opt_state)
The following figure shows the wavefield for the optimized apodization.
# Two-panel summary figure: sound-speed map with transducers and target (left)
# and amplitude of the optimized field (right); saved as harmonic_focusing.pdf.
fig, ax = plt.subplots(1,2,figsize=(10,3), dpi=200)
im1 = ax[0].imshow(medium.sound_speed.on_grid, cmap="PuBu")
cbar = fig.colorbar(im1, ax=ax[0])
cbar.ax.get_yaxis().labelpad = 15
# scatter takes (x, y) = (column, row); the transducers sit in column 25.
ax[0].scatter([25] * len(position), position, marker=".", color="black", label="Transducers")
ax[0].scatter(target[1], target[0], label="Target", color="green", marker='o')
ax[0].axis('off')
ax[0].set_title('Speed of sound map')
ax[0].legend()
# Scale bar
# 100 grid points at dx = 1e-4 correspond to the '1 cm' label.
fontprops = fm.FontProperties(size=12)
scalebar = AnchoredSizeBar(
ax[0].transData,
100, '1 cm', 'lower right',
pad=0.3,
color='black',
frameon=False,
size_vertical=2,
fontproperties=fontprops)
ax[0].add_artist(scalebar)
# Right panel: |field| returned by the final `update` call; colour scale capped at 0.5.
im1 = ax[1].imshow(jnp.abs(new_field.on_grid), cmap="inferno", vmax=0.5)
cbar = fig.colorbar(im1, ax=ax[1])
cbar.ax.get_yaxis().labelpad = 15
ax[1].axis('off')
ax[1].set_title('Focused field amplitude')
ax[1].scatter(target[1], target[0], label="Target", color="green", marker='o')
fig.tight_layout()
plt.savefig("harmonic_focusing.pdf")
Lastly, we can visualize the learned apodization:
# Real and imaginary parts of the optimized apodization coefficient of each transducer.
plt.figure(figsize=(10, 3))
plt.plot(jnp.real(phase_to_apod(transmit_phase)))
plt.plot(jnp.imag(phase_to_apod(transmit_phase)))
# Optional: overlay the modulus of each coefficient.
# plt.plot(jnp.abs(phase_to_apod(transmit_phase)), "r.")
plt.title("Apodization")
Text(0.5, 1.0, 'Apodization')
# losshistory stores negative amplitudes, so negate it to plot the amplitude itself.
plt.plot(-jnp.array(losshistory))
plt.title("Amplitude at target location")
plt.xlabel("Optimization step")
plt.show()
Speed of sound gradients
Gradients can be evaluated with respect to every parameter of the simulation. In this example, we will keep the source term fixed and vary the sound speed of an acoustic lens to focus on a target.
from jax import random
# NOTE(review): `smooth` is imported but not used in the visible code.
from jwave.signal_processing import smooth
# Fresh PRNG seed and a new target location for the lens-optimization example.
key = random.PRNGKey(12)
target = [60, 360] # Target location
# Constructing medium physical properties
def get_sos(segments, start_point=30, height=4, width=30, col_start=50):
    """Build a sound-speed map containing a segmented lens.

    The map starts as all ones. For each control value in ``segments``,
    ``jax.nn.sigmoid(value)`` (a value in (0, 1)) is added to a band of
    ``height`` rows by ``width`` columns. Bands are stacked downward from row
    ``start_point``, and the lens spans columns ``col_start`` to
    ``col_start + width``.

    Args:
        segments: 1D sequence of unconstrained lens control values.
        start_point: First grid row of the lens.
        height: Number of grid rows per segment.
        width: Number of grid columns covered by the lens.
        col_start: First grid column of the lens (defaults to the previously
            hard-coded 50).

    Returns:
        ``FourierSeries`` on the module-level ``domain``, with a trailing
        singleton axis.
    """
    sos = jnp.ones(N)
    for k, segment in enumerate(segments):
        row_start = start_point + k * height
        sos = sos.at[row_start:row_start + height, col_start:col_start + width].add(
            jax.nn.sigmoid(segment)
        )
    return FourierSeries(jnp.expand_dims(sos, -1), domain)
# Random initial lens: one unconstrained control value for each of the 65 segments.
key, _ = random.split(key)
sos_control_points = random.normal(key, shape=(65,))
sos = get_sos(sos_control_points)
show_positive_field(sos, aspect="equal")
from jwave.acoustics.operators import helmholtz
# Medium for the lens example (no pml_size is given here, unlike the first section).
medium = Medium(domain, sound_speed=get_sos(sos_control_points))
# NOTE(review): op_params is only printed here; the lens `get_field` does not pass it to the solver.
op_params = helmholtz.default_params(linear_phase, medium, omega=1.0)
print(op_params.keys())
dict_keys(['fft_u', 'pml_on_grid'])
from jax import value_and_grad
from jwave.acoustics.time_harmonic import helmholtz_solver


def loss(field):
    """Negative field amplitude at the module-level ``target`` grid point."""
    return -jnp.sum(jnp.abs(field.on_grid[target[0], target[1]]))


def get_field(params, tol, field):
    """Solve the Helmholtz problem for the lens described by ``params``.

    The source is the fixed ``linear_phase``, and ``field`` is used as the
    solver's initial guess.
    """
    lens_medium = Medium(domain, sound_speed=get_sos(params))
    # NOTE(review): omega is 1.0 here, whereas the first section solves at the
    # module-level `omega` (1.7e6 * 2 * pi) — confirm this is intended.
    return helmholtz_solver(
        lens_medium, 1.0, linear_phase, guess=field, tol=tol, checkpoint=False
    )


def full_loss(params, tol, field):
    """Return ``(loss, field)`` so the solved field is returned as auxiliary output."""
    solution = get_field(params, tol, field)
    return loss(solution), solution


loss_with_grad = value_and_grad(full_loss, has_aux=True)
from jax import jit
from jax.example_libraries import optimizers
from tqdm import tqdm

# Optimize the lens control points with Adam, differentiating through the
# Helmholtz solve with respect to the sound speed.
losshistory = []
key, _ = random.split(key)
# NOTE(review): sos_vector is not used below; the optimizer starts from
# sos_control_points.
sos_vector = random.normal(key, shape=(65,))
init_fun, update_fun, get_params = optimizers.adam(0.1, b1=0.9, b2=0.9)
opt_state = init_fun(sos_control_points)


@jit
def update(opt_state, tol, field, step=0):
    """Take one Adam step on the lens control points.

    Args:
        opt_state: Current optimizer state.
        tol: Solver tolerance.
        field: Initial guess for the Helmholtz solve (the previous solution).
        step: Iteration index given to ``update_fun`` (Adam's bias correction
            depends on it). It must be an argument: a global such as the loop
            variable, read inside a jitted function, is frozen at its
            trace-time value, so every step would reuse index 0.

    Returns:
        ``(lossval, field, new_opt_state)``.
    """
    (lossval, field), gradient = loss_with_grad(get_params(opt_state), tol, field)
    return lossval, field, update_fun(step, gradient, opt_state)


pbar = tqdm(range(100))
tol = 1e-3
field = -linear_phase
for k in pbar:
    lossval, field, opt_state = update(opt_state, tol, field, k)
    # For logging
    pbar.set_description("Tol: {} Ampl: {:01.4f}".format(tol, -lossval))
    losshistory.append(lossval)

# These parameters are sound-speed control points, so they get their own name
# and the optimized apodization in `transmit_phase` is left intact.
opt_sos_vector = get_params(opt_state)
Tol: 0.001 Ampl: 0.3812: 100%|āāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāā| 100/100 [10:26<00:00, 6.26s/it]
plt.plot(-jnp.array(losshistory)) # losses are negative amplitudes; negate to plot the amplitude
plt.title("Amplitude at target location")
Text(0.5, 1.0, 'Amplitude at target location')
# Optimized lens control points, and |field| returned by the final `update` call.
opt_sos_vector = get_params(opt_state)
plt.figure(figsize=(10, 6))
plt.imshow(jnp.abs(field.on_grid), vmax=0.35, cmap="inferno")
plt.colorbar()
# scatter takes (x, y) = (column, row).
plt.scatter(target[1], target[0])
<matplotlib.collections.PathCollection at 0x7f529118a830>
# Sound-speed map of the optimized lens, with the target marked.
sos = get_sos(opt_sos_vector)
plt.figure(figsize=(8, 8))
plt.imshow(sos.on_grid[..., 0])
plt.title("Sound speed map")
plt.scatter(target[1], target[0], label="Target")
plt.legend()
<matplotlib.legend.Legend at 0x7f5291216c80>
# Sound-speed profile down grid column 64, which lies inside the lens columns (50 to 79 with get_sos defaults).
plt.plot(sos.on_grid[..., 64, 0])
[<matplotlib.lines.Line2D at 0x7f52a19abca0>]