JAX-Fluids: Toward Differentiable CFD of Compressible Single- and Two-phase Flows
Deniz Bezgin and Aaron Buhendwa
Computer Physics Communications Seminar Series
Host: Computer Physics Communications
Date: Monday, March 17, 2025, 2:00 PM to 3:00 PM (UTC)
Live event: The live event will be accessible via this page.
This talk presents an overview of JAX-Fluids, an automatically differentiable computational fluid dynamics (CFD) solver. JAX-Fluids is a high-order Godunov-type finite-volume solver for compressible single- and multi-phase flows. It is implemented with the JAX Python package, which allows automatic differentiation (AD) gradients to be computed through the entire code framework. The talk is structured in three parts. First, we discuss the numerical methods implemented in JAX-Fluids, including the two available two-phase models: a level-set-based sharp-interface model and a five-equation diffuse-interface model. Second, we explore a parallelization strategy built on JAX primitives that scales effectively on GPU and TPU clusters while retaining AD capabilities in distributed settings. In this part, we also highlight JAX-specific implementation aspects that differ from traditional HPC languages such as C++ or Fortran. Third, we showcase applications that combine high-order numerical methods with automatic differentiation. In particular, we demonstrate that JAX-Fluids enables the end-to-end optimization of numerical models and the solution of inverse problems, thereby facilitating research at the intersection of conventional CFD and machine learning.
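To illustrate what "differentiable CFD" and "solving inverse problems" mean in practice, here is a minimal sketch in JAX. It is not JAX-Fluids code and does not use its API. It assumes a deliberately simple toy setup (the 1D inviscid Burgers equation, a first-order Godunov flux, forward-Euler time stepping, and a periodic grid) and a hypothetical inverse problem: recovering the amplitude of a sinusoidal initial condition from a synthetic final state by gradient descent through the whole solver. All names (`godunov_flux`, `solve`, `recover_amplitude`, etc.) are illustrative.

```python
import jax
import jax.numpy as jnp

# Use double precision so gradients are accurate to near machine precision.
jax.config.update("jax_enable_x64", True)

N = 200                          # number of finite-volume cells on [0, 1)
DX = 1.0 / N
DT = 0.4 * DX                    # CFL number 0.4, assuming |u| <= 1
N_STEPS = 75                     # final time t = 0.15, before shock formation here
X = (jnp.arange(N) + 0.5) * DX   # cell centres


def flux(u):
    """Physical flux of the inviscid Burgers equation, f(u) = u^2 / 2."""
    return 0.5 * u**2


def godunov_flux(u_left, u_right):
    """Exact Godunov flux for the convex Burgers flux (minimum of f at u = 0)."""
    return jnp.maximum(flux(jnp.maximum(u_left, 0.0)),
                       flux(jnp.minimum(u_right, 0.0)))


def step(u):
    """One forward-Euler finite-volume update on a periodic grid."""
    f_right = godunov_flux(u, jnp.roll(u, -1))  # F_{i+1/2}
    f_left = jnp.roll(f_right, 1)                # F_{i-1/2}
    return u - DT / DX * (f_right - f_left)


def solve(amplitude):
    """March u0(x) = amplitude * sin(2 pi x) forward for N_STEPS steps."""
    u0 = amplitude * jnp.sin(2.0 * jnp.pi * X)
    u_final, _ = jax.lax.scan(lambda u, _: (step(u), None), u0, None,
                              length=N_STEPS)
    return u_final


# Hypothetical synthetic "measurement" generated with a known amplitude.
TRUE_AMPLITUDE = 0.8
target = solve(TRUE_AMPLITUDE)


def loss(amplitude):
    """Mean-squared mismatch between the simulated and target final states."""
    return jnp.mean((solve(amplitude) - target) ** 2)


# Reverse-mode AD through every time step of the solver, compiled with XLA.
loss_and_grad = jax.jit(jax.value_and_grad(loss))


def recover_amplitude(initial_guess=0.5, learning_rate=0.5, n_iters=50):
    """Plain gradient descent on the initial-condition amplitude."""
    a = jnp.asarray(initial_guess)
    for _ in range(n_iters):
        _, g = loss_and_grad(a)
        a = a - learning_rate * g
    return a


if __name__ == "__main__":
    a_est = recover_amplitude()
    print("recovered amplitude:", float(a_est))
```

Starting from a guess of 0.5, gradient descent should drive the amplitude toward the value used to generate the target. The time loop is written with `jax.lax.scan` rather than a Python loop so that the traced program stays compact when it is differentiated and compiled; a production solver would add high-order reconstruction, a Runge-Kutta integrator, and a system of conservation laws instead of a single scalar equation.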