Image matching example
Let's start by installing the required libraries, if they are missing.
In [ ]:
Copied!
!pip install tqdm optax matplotlib
!pip install tqdm optax matplotlib
Now we import all the libraries that will be used in this tutorial.
In [1]:
Copied!
import holab as hb
# Import standard and third-party libraries
from functools import partial
from importlib import util
import jax
from jax import numpy as jnp
from jaxtyping import Num
from jwave import Domain
import optax
from matplotlib import pyplot as plt
from tqdm import trange
import holab as hb
# Import standard and third-party libraries
from functools import partial
from importlib import util
import jax
from jax import numpy as jnp
from jaxtyping import Num
from jwave import Domain
import optax
from matplotlib import pyplot as plt
from tqdm import trange
Define the experiment parameters
In [2]:
Copied!
# Experiment parameters (lengths in meters, frequency in Hz).
f0 = 1_000_000.0  # Source frequency: 1 MHz
lens_thickness = 0.006  # 6 mm
lens_radius = 0.025  # 25 mm
transducer_radius = 0.0127  # 12.7 mm
# Misspelled name kept as an alias because the source-definition cell uses it.
trasducer_radius = transducer_radius
projection_distance = 0.012  # 12 mm; forwarded to hb.compute_hologram
intensity_regularizer = 0.1  # Weight of the field-intensity term in the loss
learning_rate = 0.1  # Step size of the Adam optimizer
background_material = hb.materials.water  # Medium surrounding the lens
Define the simulation settings
In [3]:
Copied!
# Start from the default simulation settings and override only the paddings
# (presumably in meters, like the other lengths in this notebook).
settings = hb.Settings()
settings.lateral_padding = 0.0027
settings.axial_padding = 0.001
settings  # Display the resulting settings as the cell output
Out[3]:
Settings(random_seed=42, ppw=6, pml_size=16, lateral_padding=0.0027, axial_padding=0.001)
Create simulation domain
In [4]:
Copied!
# Build the simulation grid: lateral extent equal to the lens diameter,
# axial extent equal to the lens thickness, filled with the background medium.
domain = settings.construct_domain(
    f0=f0,
    lateral_size=lens_radius * 2,
    axial_size=lens_thickness,
    background=background_material,
)
print(f"Simulation domain: {domain}")
Simulation domain: Domain(N=(256, 256, 64), dx=(0.0002466666666666667, 0.0002466666666666667, 0.0002466666666666667))
Initialize lens object
In [5]:
Copied!
# Disk-shaped lens made of two printable materials (presumably blended via its
# `interpolation_coefficient`, which is the quantity optimized below).
lens = hb.TwoMaterialsInterpolated.disk(
    domain=domain,
    radius=lens_radius,
    thickness=lens_thickness,
    material1=hb.materials.agilus30,
    material2=hb.materials.veroclear,
)
print(f"Lens object: {lens}")
Lens object: TwoMaterialsInterpolated[Agilus30, VeroClear]
Define the transducer source
In [6]:
Copied!
# Put the transducer one grid step below z = -lens_thickness / 2, i.e. just
# outside the lens (assumes the domain's z axis is centered on the lens -
# TODO confirm against hb.Settings.construct_domain).
transducer_z_pos = -lens_thickness / 2.0 - domain.dx[2]
# `+ 0j` promotes the source field to complex values for the simulation.
source = hb.make_thin_disk(domain, trasducer_radius, transducer_z_pos) + 0j
print(f"Source: {source}")
Source: FourierSeries[dims=1, size=(256, 256, 64)]
Load target image
In [21]:
Copied!
# The target image lives on the first two (lateral) axes of the 3D grid.
domain_2D = Domain(N=domain.N[:2], dx=domain.dx[:2])
target_image = hb.load_image(
    domain_2D,
    image_name="dove.png",
    folder="../experiments/images/",
)

plt.imshow(jnp.abs(target_image.on_grid))
plt.colorbar()
plt.title("Target image")
plt.show()
We now define the loss function, generate the corresponding gradient function, and compile it for fast GPU computation using JAX program transformations.
In [8]:
Copied!
@jax.jit
@partial(jax.value_and_grad, has_aux=True)
def loss_function(
    lens_interpolation_coefficient,
    *,
    lens=lens,
    source=source,
    target_image=target_image,
) -> tuple[Num[jax.Array, ""], object]:
    """Hologram-matching loss as a function of the lens coefficients.

    Because of the decorators, calling ``loss_function(x, ...)`` returns
    ``((full_loss, hologram_plane_field), grad)``, where ``grad`` is the
    gradient of ``full_loss`` with respect to ``x``. The return annotation
    describes the undecorated function: a scalar loss plus the auxiliary
    hologram field.

    Args:
        lens_interpolation_coefficient: Value written into
            ``lens.interpolation_coefficient``; the optimized parameters.
        lens: Lens object to simulate. Pass it explicitly (see note below).
        source: Complex source field fed to ``hb.compute_hologram``.
        target_image: Target compared against the hologram amplitude.

    Returns:
        ``(full_loss, hologram_plane_field)``.
    """
    # NOTE: this assigns to the `lens` object received by the function.
    # `update` always passes `lens=` explicitly, so under `jax.jit` the
    # assignment hits the traced copy jit rebuilds from the pytree. Relying on
    # the default `lens` would instead store tracers on the global lens
    # object, so always pass `lens=` explicitly.
    lens.interpolation_coefficient = lens_interpolation_coefficient

    hologram_plane_field = hb.compute_hologram(
        lens=lens,
        source=source,
        f0=f0,
        lens_thickness=lens_thickness,
        projection_distance=projection_distance,
    )

    corr_val = hb.losses.amplitude_correlation(
        hologram_plane_field,
        target_image,
    )
    regularizer = hb.losses.field_intensity(hologram_plane_field)

    # The optimizer minimizes, so both terms are negated: higher correlation
    # and (weighted) higher field intensity both lower the loss.
    full_loss = -corr_val - intensity_regularizer * regularizer
    return full_loss, hologram_plane_field
At this point, we can write the optimization step.
In [10]:
Copied!
# Adam optimizer over the lens interpolation coefficients, initialized from
# the lens' current coefficients.
optimizer = optax.adam(learning_rate)
params = lens.interpolation_coefficient
opt_state = optimizer.init(params)


@jax.jit
def update(params, opt_state, lens, source, target_image):
    """Perform one optimization step.

    Returns:
        ``(params, opt_state, lossval, aux, grads)``: updated parameters and
        optimizer state, the loss evaluated at the incoming ``params``, the
        hologram field returned as auxiliary output by ``loss_function``, and
        the gradient used for the step.
    """
    (lossval, aux), grads = loss_function(
        params, lens=lens, source=source, target_image=target_image
    )
    # Adam ignores `params`, but transformations such as weight decay need
    # them; passing them keeps this step valid if the optimizer is swapped.
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, lossval, aux, grads
Finally, we run the optimization loop.
In [11]:
Copied!
# Run `optim_steps` optimization steps. Each step evaluates hb.compute_hologram
# and its gradient; the recorded run below took roughly a minute per step.
optim_steps = 20
with trange(optim_steps) as pbar:
    for _ in pbar:
        params, opt_state, lossval, aux, grads = update(
            params,
            opt_state,
            lens=lens,
            source=source,
            target_image=target_image,
        )
        pbar.set_description(f"Loss value: {lossval}")
Loss value: -0.5439061522483826: 100%|███████████████████████████████████████| 20/20 [21:41<00:00, 65.06s/it]
Let's now visualize the resulting hologram and the optimized lens.
In [ ]:
Copied!
# Convert the FourierSeries fields into plain arrays for plotting.
# `getattr` makes the conversion idempotent: if this cell runs again,
# `target_image` is already an array (which has no `.on_grid`) and is kept.
# NOTE: rebinding `target_image` means the optimization cells cannot be
# re-run afterwards without reloading the target image.
target_image = getattr(target_image, "on_grid", target_image)
result_hologram = aux.on_grid
In [37]:
Copied!
# Scale both fields to unit peak magnitude so the amplitude panels share vmax=1.
result_hologram = result_hologram / jnp.max(jnp.abs(result_hologram))
target_image = target_image / jnp.max(jnp.abs(target_image))

# Target amplitude, achieved amplitude and achieved phase, side by side.
# "twilight" is a cyclic colormap, suited to wrapped phase values.
fig, ax = plt.subplots(1, 3, figsize=(12, 5))
ax[0].imshow(jnp.abs(target_image), cmap="Greys", interpolation="nearest", vmax=1)
ax[0].set_title("Target")
ax[1].imshow(jnp.abs(result_hologram), cmap="Greys", interpolation="nearest", vmax=1)
ax[1].set_title("Amplitude")
ax[2].imshow(jnp.angle(result_hologram), cmap="twilight", interpolation="nearest")
ax[2].set_title("Phase")
plt.show()
In [34]:
Copied!
# Write the optimized coefficients back into the lens and inspect the
# resulting sound-speed map through three central sections.
lens.interpolation_coefficient = params
lens_medium = lens.as_medium(f0=f0)
sound_speed = lens_medium.sound_speed.on_grid
# Shared upper color limit so the three panels use the same scale.
maxval = jnp.amax(sound_speed)

# Central slice along each spatial axis (128, 128, 32 for the 256x256x64 grid
# used here), derived from the array shape instead of hard-coded.
cx, cy, cz = (size // 2 for size in sound_speed.shape[:3])

# vmin=2000 clips the low end of the color scale (presumably to mute the
# surrounding water, which is much slower - confirm against the materials).
fig, ax = plt.subplots(1, 3, figsize=(12, 6), gridspec_kw={"width_ratios": [1, 1, 3]})
ax[0].imshow(jnp.squeeze(sound_speed[cx]), cmap="inferno", vmin=2000, vmax=maxval)
ax[0].set_title("Y-Z section")
ax[1].imshow(jnp.squeeze(sound_speed[:, cy]), cmap="inferno", vmin=2000, vmax=maxval)
ax[1].set_title("X-Z section")
ax[2].imshow(jnp.squeeze(sound_speed[:, :, cz]), cmap="inferno", vmin=2000, vmax=maxval)
ax[2].set_title("X-Y section")
plt.show()
Out[34]:
Text(0.5, 1.0, 'X-Y section')
In [38]:
Copied!
# Keep only voxels faster than 1480 m/s, presumably to drop the surrounding
# water and keep the lens material (TODO confirm the threshold).
values = sound_speed[sound_speed > 1480]

plt.figure(figsize=(12, 4))
n, bins, patches = plt.hist(values, bins=20, facecolor="#2ab0ff", linewidth=0.5, alpha=0.7)

# Color each bar by its height relative to the tallest one. The 0.1 exponent
# pushes small ratios toward 1, so sparsely populated bins stay visible.
# Guard against an empty histogram (all counts zero) to avoid 0/0.
peak = n.max() if n.size and n.max() > 0 else 1.0
for count, patch in zip(n, patches):
    patch.set_facecolor(plt.cm.viridis((count / peak) ** 0.10))

plt.title("Sound speed values in lens", fontsize=12)
plt.xlabel("Value", fontsize=10)
plt.ylabel("Num voxels", fontsize=10)
plt.yscale("log")
# Reference lines, presumably the sound speeds of the two lens materials
# (TODO confirm against hb.materials.agilus30 / hb.materials.veroclear).
plt.vlines([2035, 2475], ymin=0, ymax=6e5, linestyle="-.", linewidths=2, color="#ff7777")
plt.xlim([2020, 2490])
plt.show()