TPU Topo
HBM feeds two on-chip memories: VMEM, where matrices go, and SMEM, where scalars and conditional logic go.
Recall that the math happens on data in VMEM.
OG Pallas
- memory pipelining: mapping blocks between HBM and VMEM
- custom prefetch: PrefetchScalarGridSpec enables runtime-dependent scalar transforms
- input-output aliasing: recycle input and output buffers for in-place updates
- passing HBM buffers: for collectives where there’s no real math etc., you can just pass references to HBM
pl.kernel
A brand-new API. It also needs to support multiple cores, e.g. Megacore.
refs
So there are pointers now.
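import jax

def add_refs(xr, yr, o_ref):
    # read both inputs out of their refs, then write the sum into the output ref
    # (the output is named o_ref because `or` is a reserved word in Python)
    x = xr[...]
    y = yr[...]
    o_ref[...] = x + y

def add_refsnow(x, y):
    xr, yr = jax.new_ref(x), jax.new_ref(y)
    o_ref = jax.empty_ref(x)  # x shape
    add_refs(xr, yr, o_ref)
    return jax.freeze(o_ref)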
You can just use references.
You can:
- pltpu.sync_copy between refs, which gives you dynamic slicing in these kernels
- sync_copy into SMEM and make the TPU do conditional logic (this is purely DYNAMIC runtime data!)
- copy mutable refs into SMEM arbitrarily for in-kernel conditional logic
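# pseudocode: hbm, smem, o_hbm are refs; the names below are placeholders
def dynamic_slice_copy(hbm, smem, o_hbm):
    # bring the runtime data into SMEM so the scalar side can read it
    pltpu.sync_copy(hbm, smem)
    block_i = some_dynamic_computation
    # build a slice from values that are only known at runtime
    x_slc = pl.dynamic_slice(things_in_smem)
    # copy just that slice of HBM to the output
    pltpu.sync_copy(hbm.at[x_slc], o_hbm)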
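To see what "slice chosen by runtime data" means without a TPU, here is a plain-JAX analogue. It is not Pallas and does not touch SMEM or HBM; it is only a sketch of the semantics, and the names `copy_selected_block`, `block_ids`, and `BLOCK_ROWS` are made up for illustration. The block index is read from an array while the function is traced under `jax.jit`, so the slice offset is a runtime value, not a Python constant.

```python
import jax
import jax.numpy as jnp

BLOCK_ROWS = 2  # hypothetical block height; slice sizes must be static


@jax.jit
def copy_selected_block(table, block_ids, step):
    # block_i is only known at runtime: it is read out of an array
    block_i = block_ids[step]
    start = block_i * BLOCK_ROWS
    # dynamic_slice takes runtime start indices but static slice sizes
    return jax.lax.dynamic_slice(table, (start, 0), (BLOCK_ROWS, table.shape[1]))


table = jnp.arange(24).reshape(6, 4)
block_ids = jnp.array([2, 0, 1])
out = copy_selected_block(table, block_ids, 0)  # picks block 2, i.e. rows 4 and 5
```

The start index is dynamic while the slice size stays fixed at trace time, which is the same split the kernel version relies on.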
and so, a new matmul kernel
In principle, the idea is that you no longer have to specify data movement and allocation by hand in an imperative way. Instead, the compiler behind pl.kernel figures out what needs to move where. In particular, if you pass in a ref as input, Pallas will happily recycle the memory, with the semantics you would expect.
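import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

@jax.jit
def matmul_kernel(x, y):
    # block sizes of computation
    block_m, block_k, block_n = 128, 128, 128
    # measuring shape: x is (m, k), y is (k, n)
    m, k = x.shape
    _, n = y.shape

    # these are using the nice ref api: input refs, then the output ref, then scratch
    def kernel_body(xr, yr, o_ref, acc_ref):
        grid = (m // block_m, n // block_n, k // block_k)

        def pipeline(xr_vmem, yr_vmem, o_vmem):
            # accumulating over the k grid axis assumes the output block starts at zero
            o_vmem[...] += jnp.dot(xr_vmem[...], yr_vmem[...])

        pltpu.emit_pipeline(
            pipeline,
            in_specs=[
                # x block for grid step (i, j, k): rows i, columns k
                pl.BlockSpec((block_m, block_k), lambda i, j, k: (i, k)),
                # y block for grid step (i, j, k): rows k, columns j
                pl.BlockSpec((block_k, block_n), lambda i, j, k: (k, j)),
            ],
            out_specs=[
                # output block: rows i, columns j (revisited for every k)
                pl.BlockSpec((block_m, block_n), lambda i, j, k: (i, j)),
            ],
            grid=grid,
        )(xr, yr, o_ref)

    # so megacores are natively supported
    mesh = pltpu.create_tensorcore_mesh(axis_name="core", num_cores=1)
    # keyword names out_type / scratch_type are as given in these notes; check your JAX version
    o = pl.kernel(
        kernel_body,
        mesh=mesh,
        out_type=jax.ShapeDtypeStruct((m, n), x.dtype),  # this sets up an out_ref buffer of this size
        scratch_type=(pltpu.VMEM((block_m, block_n), x.dtype),),
    )(x, y)
    return o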
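As a way to check what the grid and the three index maps are meant to compute, here is a plain-JAX reference of the same loop nest. This is a sketch, not Pallas: nothing is pipelined, and the function name `blocked_matmul_reference` is hypothetical. It assumes the shapes divide evenly by the block sizes. Grid step (i, j, k) reads block (i, k) of x and block (k, j) of y and accumulates into block (i, j) of the output.

```python
import jax.numpy as jnp


def blocked_matmul_reference(x, y, block_m=128, block_k=128, block_n=128):
    m, k = x.shape
    k2, n = y.shape
    assert k == k2, "inner dimensions must match"
    assert m % block_m == 0 and k % block_k == 0 and n % block_n == 0
    grid = (m // block_m, n // block_n, k // block_k)

    out = jnp.zeros((m, n), dtype=x.dtype)  # output starts zeroed
    for i in range(grid[0]):
        for j in range(grid[1]):
            for kk in range(grid[2]):
                # index map (i, k) for x, (k, j) for y, (i, j) for the output
                x_blk = x[i * block_m:(i + 1) * block_m, kk * block_k:(kk + 1) * block_k]
                y_blk = y[kk * block_k:(kk + 1) * block_k, j * block_n:(j + 1) * block_n]
                out = out.at[i * block_m:(i + 1) * block_m,
                             j * block_n:(j + 1) * block_n].add(jnp.dot(x_blk, y_blk))
    return out


x = jnp.arange(32, dtype=jnp.float32).reshape(8, 4)
y = jnp.arange(24, dtype=jnp.float32).reshape(4, 6)
out = blocked_matmul_reference(x, y, block_m=4, block_k=2, block_n=3)
```

Comparing a kernel's output against a reference like this on small shapes is a cheap way to catch a swapped index map or a wrong block shape.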
Importantly! If your pl.kernel takes references as input, no copies will be made. So you can just do a jax.lax.fori_loop(some_pallas_kernel_takes_refs, input_refs), and you have no launch overhead beyond what you explicitly specified in your kernel.
This naturally expresses:
- collectives, which do no real math
- any memory-recycling operation
Other Pallas Developments
SparseCore
TPUs have SparseCores now, and you can program a kernel that 1) communicates between them and 2) moves data that is visible to the same VMEM.
You can express this in Pallas using semaphore signals etc. between the main, scalar, and sparse cores. You can use SparseCores as semicores to drive communication.
Next steps
- full device MPMD (multi-program, multi-data) kernels
- fused gather (<- initial communication, each iteration gathers a small chunk while we overlap computation and gather)
- ragged all to all
