Jun 8 – 13, 2025
OAC conference center, Kolymbari, Crete, Greece.
Europe/Athens timezone

Differentiable Computation with Awkward Array and JAX

Jun 10, 2025, 12:30 PM
25m
OAC conference center, Kolymbari, Crete, Greece.

Talk Methods and tools

Speaker

Saransh Chopra (Princeton University (US))

Description

Modern scientific computing often involves nested and variable-length data structures, which pose challenges for automatic differentiation (AD). Awkward Array is a library for manipulating irregular data, and its integration with JAX enables forward- and reverse-mode AD on such data. Several Python libraries, such as PyTorch, TensorFlow, and Zarr, offer variations of ragged data structures, but differentiating through their ragged types remains impossible or problematic. Awkward's JAX backend allows users to differentiate through nested and variable-length data structures without compromising readability, ease of use, or performance.
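To make "forward- and reverse-mode AD on irregular data" concrete, here is a minimal sketch in plain JAX. It does not use Awkward Array's API. As an assumption made only for this illustration, the ragged list [[1, 2, 3], [], [4, 5]] is stored as one flat buffer plus a row index per element, and the names `values`, `row_ids`, and `loss` are hypothetical.

```python
import jax
import jax.numpy as jnp

# Ragged data [[1, 2, 3], [], [4, 5]] held as a flat buffer plus the row
# index of every element (an illustrative layout, not Awkward's internals).
values = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])
row_ids = jnp.array([0, 0, 0, 2, 2])
num_rows = 3


def loss(flat):
    # Sum of squares within each row, then a total over all rows.
    # The empty row (row 1) simply contributes zero.
    per_row = jax.ops.segment_sum(flat**2, row_ids, num_segments=num_rows)
    return jnp.sum(per_row)


# Reverse mode: one gradient entry per element of the flat buffer.
grad = jax.grad(loss)(values)

# Forward mode: push a tangent that perturbs only the last element.
tangent = jnp.array([0.0, 0.0, 0.0, 0.0, 1.0])
primal_out, tangent_out = jax.jvp(loss, (values,), (tangent,))
```

Here `grad` has the same ragged layout as the input (it equals `2 * values`), which is the property a ragged-aware AD backend needs to preserve.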
This talk presents the current status of Awkward Array's JAX backend, highlighting its implementation using JAX's pytrees and tracing mechanisms and its compatibility with JAX's AD system. We discuss the coverage of Awkward Array's AD support, strategies for differentiable programming with nested data, and challenges encountered in extending JAX's API to support non-rectilinear array structures. Finally, we outline future development directions, including keeping up with JAX's evolving AD ecosystem, improved interoperability with ML frameworks, and potential applications in physics and beyond.
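The role pytrees play can be sketched with a toy container. The example below is not the backend's implementation: `Ragged`, its `content`/`offsets` fields, and `total_of_row_means` are hypothetical names chosen for illustration. It only shows the general JAX mechanism of registering a custom type so that its numeric buffer is a traced, differentiable leaf while its structure travels as static auxiliary data.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class Ragged:
    """Hypothetical ragged container: flat content plus static row offsets."""

    def __init__(self, content, offsets):
        self.content = content  # flat float buffer (differentiable leaf)
        self.offsets = offsets  # tuple of ints marking row boundaries

    def tree_flatten(self):
        # Children are traced by JAX; aux data must be hashable and static.
        return (self.content,), self.offsets

    @classmethod
    def tree_unflatten(cls, offsets, children):
        return cls(children[0], offsets)


def total_of_row_means(r):
    # Offsets are static Python ints, so this loop unrolls during tracing.
    total = 0.0
    for start, stop in zip(r.offsets[:-1], r.offsets[1:]):
        if stop > start:  # skip empty rows
            total = total + jnp.mean(r.content[start:stop])
    return total


# [[1, 2, 3], [], [4, 5]] as content + offsets.
data = Ragged(jnp.array([1.0, 2.0, 3.0, 4.0, 5.0]), (0, 3, 3, 5))
value = total_of_row_means(data)
grads = jax.grad(total_of_row_means)(data)
```

Because `Ragged` is a registered pytree, `jax.grad` returns another `Ragged` with the same offsets, whose `content` holds the per-element derivatives (1/3 for each element of the first row, 1/2 for each element of the last).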

Authors

Ianna Osborne (Princeton University)
Saransh Chopra (Princeton University (US))

Presentation materials

There are no materials yet.