JAX stuff news!
“Feel the API”
jit -> sharding hints -> explicit sharding -> shard_map collectives -> pallas -> FFI
Refs
There are refs now! Dynamic slicing! Plumbing out metrics is hard! Batch norms! How do we express these things in JAX?
“I want to keep this thing in memory at the same place.”
For stuff like gradient accumulation / sparse updating this is very good.
history
- we can’t express mutability
- and we think the change is fine.
There’s refs now!
import jax
import jax.numpy as jnp

x_ref = jax.new_ref(jnp.zeros(3))  # a mutable ref wrapping an array

@jax.jit
def f():
    x_ref[1] += 1  # indexed in-place add on the closed-over ref
These are basically Pallas-style refs already.
API
- new_ref(init_val: Array) -> Ref: makes a ref
- freeze(ref: Ref) -> Array: materializes a ref and copies its value out
limitations
How do you prevent statefulness from escaping?
- you can’t return refs!
- only callers can freeze refs!
- you can’t pass a ref more than once!
- refs are invalidated after freeze!
- …refs can’t do math unless they are no longer refs!
- …but you can set them!
some details
- basically Pallas refs; now with 100% more autodiff
- a new type; you can’t accidentally pass in an Array
- prevents aliasing
Fixableish: you can’t do math on refs. You can only index reads and writes out of them.
questions
- Does indexing into refs create a copy? Or will XLA magically magic
