Homogeneous Medium
This example notebook follows the Homogeneous Propagation Medium Example of k-Wave.
Setup
Domain
Similarly to k-Wave, j-Wave requires the user to specify the computational domain in which the simulation takes place. This is done using the Domain dataclass, which is imported from jaxdf.
The constructor takes the number of grid points along each spatial direction and the corresponding grid spacing (in meters).
from jwave.geometry import Domain
N, dx = (128, 128), (0.1e-3, 0.1e-3)  # grid points per axis, grid spacing in meters
domain = Domain(N, dx)
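As a quick sanity check of these numbers, the physical size of the domain is the number of points times the spacing along each axis: 128 × 0.1 mm = 12.8 mm per side. The minimal NumPy sketch below computes this. The helper names grid_extent and grid_axes are hypothetical, and the origin-centered coordinate convention is our own assumption, not necessarily the one jaxdf uses.

```python
# Sketch: physical extent and coordinates of a uniform grid.
# Assumption: this only mirrors the idea of jaxdf's Domain; the helpers are
# hypothetical and the origin-centered convention is our own choice.
import numpy as np

def grid_extent(N, dx):
    """Physical size of the domain along each axis, in meters."""
    return tuple(n * d for n, d in zip(N, dx))

def grid_axes(N, dx):
    """1D coordinate arrays for each axis, symmetric around the origin."""
    return [(np.arange(n) - (n - 1) / 2) * d for n, d in zip(N, dx)]

N, dx = (128, 128), (0.1e-3, 0.1e-3)
print(grid_extent(N, dx))  # about (0.0128, 0.0128), i.e. 12.8 mm per side
x, y = grid_axes(N, dx)
print(x[0], x[-1])         # first and last grid coordinates along the first axis
```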
Acoustic medium
In this example, the speed of sound has a constant value of 1500 m/s. The speed of sound is defined as part of the Medium dataclass, which also requires the computational domain as a mandatory argument.
from jwave.geometry import Medium
medium = Medium(domain=domain, sound_speed=1500.0)
print(medium)
Medium:
- attenuation: 0.0
- density: 1.0
- domain: Domain(N=(128, 128), dx=(0.0001, 0.0001))
- pml_size: 20
- sound_speed: 1500.0
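Since the medium is homogeneous, the sound speed together with the grid spacing fixes the highest frequency the simulation can represent. Assuming the usual pseudospectral limit of two grid points per wavelength, f_max = c / (2 Δx), which gives 7.5 MHz for this setup. The helpers below are hypothetical and only illustrate that arithmetic; they are not part of jwave.

```python
# Sketch: frequency limits implied by the grid in a homogeneous medium.
# Assumption: two grid points per wavelength as the resolution limit,
# i.e. f_max = c / (2 * dx). These helpers are hypothetical, not jwave API.

def max_supported_frequency(sound_speed, dx):
    """Nyquist-limited frequency (Hz) for the coarsest grid spacing."""
    return sound_speed / (2.0 * max(dx))

def points_per_wavelength(sound_speed, dx, frequency):
    """Number of grid points per wavelength at the given frequency (Hz)."""
    wavelength = sound_speed / frequency
    return wavelength / max(dx)

f_max = max_supported_frequency(1500.0, (0.1e-3, 0.1e-3))
print(f"f_max = {f_max / 1e6:.2f} MHz")                      # 7.50 MHz
print(points_per_wavelength(1500.0, (0.1e-3, 0.1e-3), 1e6))  # about 15 points at 1 MHz
```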
Time
Time-stepping simulations require a TimeAxis object, which is used by the time-stepping scheme of the numerical simulation. To ensure a stable simulation, this object can be constructed from the medium object for a given CFL number, which sets the time step relative to the grid spacing and the sound speed.
from jwave.geometry import TimeAxis
time_axis = TimeAxis.from_medium(medium, cfl=0.3)
time_axis
<jwave.geometry.TimeAxis at 0x7f831ef2d3f0>
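The CFL number ties the time step to the grid spacing and the fastest sound speed. Assuming TimeAxis.from_medium uses the textbook relation dt = CFL · min(Δx) / max(c), the settings above give a time step of 20 ns. The sketch below is a hypothetical helper illustrating that relation, not the jwave implementation, and it does not reproduce how the end time of the simulation is chosen.

```python
# Sketch of the CFL condition used to pick a stable time step:
#     dt = cfl * min(dx) / max(c)
# Assumption: TimeAxis.from_medium follows this textbook form; this helper
# is hypothetical and ignores how the simulation end time is chosen.

def cfl_time_step(cfl, dx, max_sound_speed):
    """Time step (s) corresponding to the given CFL number."""
    return cfl * min(dx) / max_sound_speed

dt = cfl_time_step(0.3, (0.1e-3, 0.1e-3), 1500.0)
print(f"dt = {dt * 1e9:.1f} ns")  # 20.0 ns
```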
Initial pressure
The initial pressure distribution is a Field, so it must be represented using some discretization. Most jwave functions are tested with the FourierSeries discretization; since this example does not need to tweak the underlying discretization, we use that class to define the initial pressure field.
from jax import numpy as jnp
from jwave import FourierSeries
from jwave.geometry import circ_mask
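# Disc of radius 4 grid points centered at grid index (80, 60), as floats,
# with a trailing axis for the single field component
p0 = 1.0 * jnp.expand_dims(circ_mask(N, 4, (80, 60)), -1)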
p0 = FourierSeries(p0, domain)
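To see what the initial pressure array contains, here is a minimal NumPy sketch of a filled-disc mask with the same arguments. The helper disc_mask is hypothetical, and we assume that circ_mask includes the points whose distance from the center is at most the radius and reads the center as (row, column) indices; jwave's exact conventions may differ.

```python
# Sketch of a filled-disc mask similar to the one used for p0.
# Assumption: points with distance <= radius from the center are inside,
# and the center is given as (row, column); circ_mask may differ.
import numpy as np

def disc_mask(N, radius, center):
    """Array of shape N with 1.0 inside the disc and 0.0 outside."""
    ix, iy = np.meshgrid(np.arange(N[0]), np.arange(N[1]), indexing="ij")
    dist2 = (ix - center[0]) ** 2 + (iy - center[1]) ** 2
    return (dist2 <= radius**2).astype(np.float32)

mask = disc_mask((128, 128), 4, (80, 60))
p0_array = np.expand_dims(mask, -1)  # trailing axis: one field component
print(p0_array.shape)                # (128, 128, 1)
print(int(mask.sum()))               # 49 grid points lie inside the disc
```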
from matplotlib import pyplot as plt
from jwave.utils import show_field
show_field(p0)
plt.title("Initial pressure field")
plt.show()
Run the simulation
from jax import jit
from jwave.acoustics import simulate_wave_propagation
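@jit
def compiled_simulator(medium, p0):
    # time_axis is captured from the enclosing scope, so it is fixed at trace time
    return simulate_wave_propagation(medium, time_axis, p0=p0)

# The first call traces and compiles the function; later calls reuse the compiled code
pressure = compiled_simulator(medium, p0)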
t = 250  # index of the time step to display
show_field(pressure[t])
plt.title(f"Pressure field at t={time_axis.to_array()[t]}")
plt.show()
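A rough check of this snapshot: assuming a time step of about 20 ns (from dt = CFL · Δx / c) and that pressure[t] corresponds to the time t · dt shown in the title, the wavefront has travelled c · t = 7.5 mm, or 75 grid points, after 250 steps. Both assumptions are ours; the arithmetic is sketched below.

```python
# Back-of-the-envelope estimate of the wavefront radius at step 250.
# Assumptions: dt = cfl * dx / c (about 20 ns here) and snapshot index n
# corresponds to time n * dt; neither is taken from jwave source code.

c, dx, cfl = 1500.0, 0.1e-3, 0.3
dt = cfl * dx / c           # time step in seconds
t = 250 * dt                # elapsed time in seconds
radius_m = c * t            # distance travelled by the wavefront
radius_pts = radius_m / dx  # same distance in grid points

print(f"t = {t * 1e6:.1f} us, radius = {radius_m * 1e3:.1f} mm ({radius_pts:.0f} points)")
```

With a fixed CFL number in a homogeneous medium, the front advances CFL grid points per step (0.3 × 250 = 75 here), independently of c and Δx. A 75-point radius exceeds the distance from the source at (80, 60) to three of the four edges of the 128 × 128 grid, so by this step much of the ring should already have reached the boundary.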
Timings
%timeit compiled_simulator(medium, p0).params.block_until_ready()
478 ms ± 1.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
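Outside IPython, a comparable measurement can be scripted with the standard library. Because JAX dispatches work asynchronously, the timed callable should block until its result is ready, as the block_until_ready() call above does, and the first call should be excluded because it includes JIT compilation. The benchmark helper below is a hypothetical sketch of that pattern, demonstrated on a stand-in workload.

```python
# Sketch: timing a (possibly jitted) function outside IPython.
# `fn` is a hypothetical zero-argument callable; for JAX code it must block
# until its result is ready (e.g. via .block_until_ready()), otherwise only
# the asynchronous dispatch is timed.
import statistics
import timeit

def benchmark(fn, repeat=7, number=1):
    """Return (mean, standard deviation) of per-call wall time in seconds."""
    fn()  # warm-up call so compilation/caching is not included in the timings
    times = timeit.repeat(fn, repeat=repeat, number=number)
    per_call = [t / number for t in times]
    return statistics.mean(per_call), statistics.stdev(per_call)

# Stand-in workload; in this notebook it would be
# lambda: compiled_simulator(medium, p0).params.block_until_ready()
mean_s, std_s = benchmark(lambda: sum(i * i for i in range(10_000)))
print(f"{mean_s * 1e3:.2f} ms ± {std_s * 1e3:.2f} ms per call")
```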