JAX Feature Releases

JAX news!

“Feel the API”

jit -> sharding hints -> explicit sharding -> shard_map collectives -> pallas -> FFI

Refs

There are refs now! Dynamic slicing! Plumbing out metrics is hard! Batch norms! How do we express these things in JAX?

“I want to keep this thing in memory at the same place.”

This is very good for things like gradient accumulation and sparse updates.

history

  • we can’t express mutability
  • and we think the change

fine.

There are refs now!

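import jax
import jax.numpy as jnp

x_ref = jax.new_ref(jnp.zeros(3))  # new ref with initial value [0., 0., 0.]

@jax.jit
def f():
    x_ref[1] += 1.  # indexed add-update, applied in place

f()
f()
print(x_ref)  # x_ref now holds [0., 2., 0.]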

This is basically Pallas-style refs.

API

  • new_ref(init_val: Array) -> Ref: creates a ref initialized to init_val
  • freeze(ref: Ref) -> Array: materializes a ref, copying its value out as an Array

limitations

How do you prevent statefulness from escaping?

  • you can’t return refs!
  • only callers can freeze refs!
  • you can’t pass a ref more than once!
  • refs are invalidated after freeze!
    • …refs can’t do math unless they’re no longer refs (read out as Arrays)!
    • …but you can set them!

some details

  • basically Pallas refs; now with 100% more autodiff
  • a new type; you can’t accidentally pass in an Array
  • prevents aliasing

Fixable-ish: you can’t do math on refs. You can only do indexed reads and writes on them.

questions

  • Does indexing into refs create a copy? Or will XLA magically avoid it?