JAX-Fluids: Toward Differentiable CFD of Compressible Single- and Two-phase Flows - presented by Deniz Bezgin and Aaron Buhendwa

Slide at 31:33
JAX-specific programming aspects

from functools import partial
from typing import Tuple

import numpy as np
import jax
import jax.numpy as jnp
from jax import Array

Spatial conditionals in jit-compiled functions
Fixed shapes require the use of masks

@partial(jax.jit, static_argnums=(3,))
def upwind_finite_difference(
    buffer: Array,    # shape (..., nx + 2*nh)
    velocity: Array,  # shape (nx,)
    dx: float,
    nh: int,
) -> Array:
    s0 = jnp.s_[..., nh:-nh]        # u_i
    s1 = jnp.s_[..., nh-1:-nh-1]    # u_{i-1}
    s2 = jnp.s_[..., nh+1:-nh+1]    # u_{i+1}
    dbuffer_dx_left = (buffer[s0] - buffer[s1]) / dx
    dbuffer_dx_right = (-buffer[s0] + buffer[s2]) / dx
    dbuffer_dx = jnp.where(velocity > 0.0,
                           dbuffer_dx_left,
                           dbuffer_dx_right)
    return dbuffer_dx

Function transformations

jax.vmap to vectorize functions on a single device

def linear_interpolation(
    interpolation_position: Array,  # shape (3,)
    buffer: Array,                  # shape (Nx, Ny, Nz)
    cell_centers: Tuple[Array],
) -> Array:
    interpolated_buffer = ...       # interpolation body elided on the slide
    return interpolated_buffer

linear_interpolation_vec = jax.vmap(
    linear_interpolation,
    in_axes=(0, None, None),
    out_axes=0,
)
# NOTE: linear_interpolation_vec can be called with
# interpolation_position of shape (N_points, 3)

jax.pmap to evaluate functions in parallel on multiple devices

# Large array
A = jnp.array(np.random.randn(N, N))
foo = lambda x: x**2
B = foo(A)

# Reshape A and distribute on N_device XLA devices
A = A.reshape(N_device, N // N_device, N)
foo_pmap = jax.pmap(foo, in_axes=0, out_axes=0, axis_name="i")
B = foo_pmap(A)
B = B.reshape(N, N)
CPC Seminar 17 March 2025
Summary (AI generated)

JIT compilation in JAX comes with specific requirements. One crucial requirement is that array shapes must be fixed at compile time. When a function is JIT-compiled, its array arguments are traced: JAX sees only their shapes and data types, not their actual values. To implement control flow that depends on the actual value of an input argument, that argument must be marked as static (e.g., via static_argnums).
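As a minimal sketch of this mechanism (illustrative code, not from the talk; the function name scale is hypothetical): marking an argument in static_argnums makes it a compile-time constant, so ordinary Python control flow may branch on it, at the cost of one recompilation per distinct value.

from functools import partial
import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnums=(1,))
def scale(x, order):
    # `order` is a compile-time constant here, so a Python `if` is legal;
    # branching on a traced array value would fail at trace time
    if order == 1:
        return 2.0 * x
    return 4.0 * x

scale(jnp.ones(3), 1)  # compiles the order == 1 branch
scale(jnp.ones(3), 2)  # new static value triggers a fresh compilation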

In the example on the slide, a JIT-compiled function takes a 3D buffer as input, along with the cell sizes, the number of halo cells, and an additional argument specifying the axis direction along which the derivative is computed. The axis direction must be declared static, as JAX requires its value at compile time to resolve the corresponding control flow. Similarly, the number of halo cells must be marked static so that JAX can infer the shapes of all generated buffers at compile time.
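A minimal sketch of such a function, under the assumption of a second-order central difference on a 3D buffer (the name central_difference and the slicing details are illustrative, not JAX-Fluids' actual implementation):

from functools import partial
import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnums=(2, 3))
def central_difference(buffer, dx, axis, nh):
    # `axis` selects the derivative direction, `nh` is the halo width; both
    # are static, so the slice construction below happens at trace time and
    # every resulting shape is fixed.
    n = buffer.shape[axis]
    interior = [slice(nh, -nh)] * buffer.ndim    # trim halo cells everywhere
    s_plus, s_minus = list(interior), list(interior)
    s_plus[axis] = slice(nh + 1, n - nh + 1)     # u_{i+1} along `axis`
    s_minus[axis] = slice(nh - 1, n - nh - 1)    # u_{i-1} along `axis`
    return (buffer[tuple(s_plus)] - buffer[tuple(s_minus)]) / (2.0 * dx)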

These restrictions necessitate special care when implementing spatial conditionals in compiled functions, particularly through the use of masks. For instance, in the upwind finite-difference implementation shown on the slide, we compute first derivatives based on the local velocity field. The function receives a 1D buffer, the velocity, the cell size, and the number of halo cells. We create slice objects for the left- and right-sided derivatives and define a mask based on the velocity: one where the velocity is greater than zero, zero elsewhere. We then use jnp.where to combine the two derivative candidates, ensuring that all shapes are fixed and known at compile time.
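A hypothetical call to the slide's function (with nh = 2, since its relative slice bounds need at least two halo cells per side):

u = jnp.arange(8.0)                      # nx = 4 interior plus 2*2 halo cells
vel = jnp.array([1.0, -1.0, 1.0, -1.0])  # one velocity per interior cell
dudx = upwind_finite_difference(u, vel, 1.0, 2)
# jnp.where picks the left-sided difference where vel > 0 and the
# right-sided one elsewhere; both candidates share the same fixed shape.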

Next, we explore further useful function transformations in JAX beyond JIT compilation. jax.vmap vectorizes functions across a single device, which allows us to evaluate functions on batches of data efficiently. For example, the linear interpolation function above takes a single query point and a computational grid buffer as input. While it can be evaluated for one point at a time, it is far more efficient to evaluate it for millions of points in parallel. Instead of looping over the points sequentially, which would be inefficient on a GPU, we use vmap and specify for each input argument the axis to map over (or None to broadcast it to every call).
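The same pattern in a self-contained, runnable form, simplified to a 1D grid (interp_1d and all data below are illustrative, not JAX-Fluids code):

import jax
import jax.numpy as jnp

def interp_1d(x, cell_centers, values):
    # linear interpolation of `values` at a single query point x
    i = jnp.clip(jnp.searchsorted(cell_centers, x) - 1, 0, cell_centers.size - 2)
    w = (x - cell_centers[i]) / (cell_centers[i + 1] - cell_centers[i])
    return (1.0 - w) * values[i] + w * values[i + 1]

# map over the batch of points (axis 0), broadcast grid and values
interp_1d_vec = jax.vmap(interp_1d, in_axes=(0, None, None), out_axes=0)

centers = jnp.linspace(0.0, 1.0, 11)    # (Nx,)
values = centers**2                     # (Nx,)
points = jnp.linspace(0.05, 0.95, 10)   # (N_points,)
print(interp_1d_vec(points, centers, values))  # (N_points,) results in one call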

Additionally, jax.pmap enables parallel evaluation of functions across multiple devices. For example, given a large array of shape N x N, we can reshape it so that its leading axis matches the number of available devices. By specifying that axis as the mapped axis, JAX distributes the shards across the devices and evaluates the function on them in parallel.
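A runnable variant of the slide's sketch, assuming N is chosen divisible by the local device count (on a single-device machine, n_dev is simply 1):

import numpy as np
import jax
import jax.numpy as jnp

n_dev = jax.local_device_count()
N = 8 * n_dev                            # keep N divisible by n_dev
A = jnp.array(np.random.randn(N, N))

foo = lambda x: x**2
foo_pmap = jax.pmap(foo)                 # in_axes=0, out_axes=0 by default
B_shards = foo_pmap(A.reshape(n_dev, N // n_dev, N))  # one shard per device
B = B_shards.reshape(N, N)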

Finally, we will also discuss the computation of gradients, which is another important function transformation in JAX.
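As a generic illustration of that transformation (not code from the talk): jax.grad turns a scalar-valued function into a function that returns its gradient, and it composes freely with jit and vmap.

import jax
import jax.numpy as jnp

loss = lambda u: jnp.sum(u**2)
grad_loss = jax.jit(jax.grad(loss))      # u -> d(loss)/du, itself jit-compiled
print(grad_loss(jnp.array([1.0, 2.0, 3.0])))  # [2. 4. 6.]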