Pallas Tutorial Notes

TPU Topology

Data in HBM is staged into two on-chip memories: VMEM, where matrices (array data) go, and SMEM, where scalars and the values that drive conditional logic go.

Recall that vector math operates on data in VMEM.
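A minimal sketch of that split, assuming a recent JAX where `pltpu.SMEM` and `pltpu.VMEM` can be passed as a `BlockSpec` memory space: the scalar operand is placed in SMEM and read as a single value, while the array operand and the output sit in VMEM, where the vector math happens. `scale_kernel` and `scale_array` are hypothetical names, and `interpret=True` is only there so the sketch runs on CPU.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu


def scale_kernel(s_ref, x_ref, o_ref):
    # s_ref is placed in SMEM: read one scalar out of it
    factor = s_ref[0]
    # x_ref and o_ref are placed in VMEM: whole-array vector math happens here
    o_ref[...] = x_ref[...] * factor


def scale_array(s, x):
    return pl.pallas_call(
        scale_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        in_specs=[
            pl.BlockSpec(memory_space=pltpu.SMEM),  # scalars -> SMEM
            pl.BlockSpec(memory_space=pltpu.VMEM),  # arrays -> VMEM
        ],
        out_specs=pl.BlockSpec(memory_space=pltpu.VMEM),
        interpret=True,  # run on CPU for illustration; drop on a real TPU
    )(s, x)
```

With no `grid` and no block shape, each operand is handled as one whole block.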

OG Pallas

  1. memory pipelining: mapping blocks between HBM and VMEM
  2. scalar prefetch: PrefetchScalarGridSpec, which enables runtime-dependent scalar transforms (e.g., data-dependent block indices)
  3. input-output aliasing: recycle input buffers as output buffers for in-place updates
  4. passing HBM buffers: for collectives and other kernels with no real math, you can pass references to HBM directly

pl.kernel

A brand-new API, with native multi-core support (e.g., Megacore).

refs

JAX now has mutable references (refs), which behave like pointers.

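import jax
import jax.numpy as jnp

def add_refs(xr, yr, or_):
    # read the current values out of the input refs
    x = xr[...]
    y = yr[...]
    # write the sum into the output ref in place
    or_[...] = x + y

def add_refsnow(x, y):
    # wrap the plain arrays in mutable refs
    xr, yr = jax.new_ref(x), jax.new_ref(y)
    # output ref with x's shape/dtype (the talk used jax.empty_ref for an uninitialized one)
    or_ = jax.new_ref(jnp.zeros_like(x))
    add_refs(xr, yr, or_)
    # freeze turns the ref back into an ordinary immutable array
    return jax.freeze(or_)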

You can just use references directly.

You can:

  • pltpu.sync_copy between refs
  • therefore dynamically slice inside these kernels: sync_copy arbitrary mutable refs into SMEM and have the TPU run conditional logic on them (this is purely DYNAMIC runtime data!)

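from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

BLOCK = 8  # hypothetical: rows per slice

# body for pl.kernel: inputs (idx_hbm, x_hbm), output o_hbm, and an SMEM
# scratch ref idx_smem, e.g. pltpu.SMEM((1,), jnp.int32)
def slice_kernel(idx_hbm, x_hbm, o_hbm, idx_smem):
    # copy runtime index data from HBM into SMEM so scalar code can read it
    pltpu.sync_copy(idx_hbm, idx_smem)

    # purely dynamic: this offset is only known at run time
    block_i = idx_smem[0] * BLOCK  # hypothetical stand-in for some_dynamic_computation

    # dynamic slice (pl.ds) of the HBM ref, copied straight to the output
    pltpu.sync_copy(x_hbm.at[pl.ds(block_i, BLOCK)], o_hbm)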

And so, a new matmul kernel:

In principle, you no longer have to specify data movement and allocation imperatively by hand; instead, pl.kernel lets the compiler work out what needs to move where. In particular, if you pass in a ref as an input, Pallas will happily recycle its memory with the expected (in-place) semantics.

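import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

@jax.jit
def matmul_kernel(x, y):
    # block sizes of computation (assumes m, k, n are multiples of these)
    block_m, block_k, block_n = 128, 128, 128

    # measuring shapes: x is (m, k), y is (k, n)
    m, k = x.shape
    _, n = y.shape
    # k is the innermost grid axis, so each output block is finished before moving on
    grid = (m // block_m, n // block_n, k // block_k)

    # these are using the nice ref api: pl.kernel passes the input refs, then the
    # output ref, then the scratch refs; x, y and o stay in HBM
    def kernel_body(x_hbm, y_hbm, o_hbm, acc_vmem):

        # emit_pipeline streams the tiles described by the BlockSpecs into VMEM
        def pipeline_body(x_vmem, y_vmem, o_vmem):
            k_i = pl.program_id(2)  # assumes the pipeline grid index is visible here

            @pl.when(k_i == 0)
            def _():
                acc_vmem[...] = jnp.zeros(acc_vmem.shape, acc_vmem.dtype)

            acc_vmem[...] += jnp.dot(
                x_vmem[...], y_vmem[...], preferred_element_type=jnp.float32
            )

            @pl.when(k_i == grid[2] - 1)
            def _():
                o_vmem[...] = acc_vmem[...].astype(o_vmem.dtype)

        pltpu.emit_pipeline(
            pipeline_body,
            in_specs=[
                pl.BlockSpec((block_m, block_k), lambda i, j, kk: (i, kk)),  # x tile (i, k)
                pl.BlockSpec((block_k, block_n), lambda i, j, kk: (kk, j)),  # y tile (k, j)
            ],
            out_specs=[
                pl.BlockSpec((block_m, block_n), lambda i, j, kk: (i, j)),  # o tile (i, j), revisited over k
            ],
            grid=grid,
        )(x_hbm, y_hbm, o_hbm)

    # so megacores are natively supported: the TensorCore mesh selects the cores
    mesh = pltpu.create_tensorcore_mesh(axis_name="core", num_cores=1)

    # keyword names as given in the talk; some JAX versions call these
    # out_shape / scratch_shapes instead
    o = pl.kernel(
        kernel_body,
        mesh=mesh,
        out_type=jax.ShapeDtypeStruct((m, n), x.dtype),  # sets up an output ref of this size
        scratch_type=(pltpu.VMEM((block_m, block_n), jnp.float32),),  # f32 accumulator
    )(x, y)

    return o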

Importantly! If your pl.kernel takes references as input, no copies will be made. So you can call a ref-taking Pallas kernel inside jax.lax.fori_loop on the same input refs, and you pay no overhead beyond what you explicitly specified in your kernel.

Good for the following, which are expressed naturally:

  • collectives, which do no real math
  • any memory-recycling operation
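For comparison, with the original pallas_call API the in-place loop pattern is requested explicitly through input_output_aliases (item 3 in the OG Pallas list). This is a minimal sketch: `input_output_aliases={0: 0}` lets output 0 reuse input 0's buffer, and `jax.lax.fori_loop` feeds each result straight back in. Whether XLA actually avoids a copy is up to the compiler. `add_one_kernel`, `add_one_inplace` and `add_n` are hypothetical names, and `interpret=True` only makes it runnable on CPU.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def add_one_kernel(x_ref, o_ref):
    o_ref[...] = x_ref[...] + 1.0


def add_one_inplace(x):
    return pl.pallas_call(
        add_one_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        input_output_aliases={0: 0},  # output 0 may reuse input 0's buffer
        interpret=True,  # run on CPU for illustration; drop on a real TPU
    )(x)


@jax.jit
def add_n(x, n):
    # each iteration passes the previous output back in as the next input
    return jax.lax.fori_loop(0, n, lambda _, acc: add_one_inplace(acc), x)
```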

Other Pallas Developments

SparseCore

TPUs now have SparseCores, and you can program a kernel that (1) communicates between them and (2) moves data so that it is visible in the same VMEM.

In Pallas you can express this with semaphore signals etc. between the main, scalar, and sparse cores. You can also use SparseCores as helper cores that drive communication.

Next steps

  • full-device MPMD (multiple-program, multiple-data) kernels
  • fused gather: instead of one up-front communication step, each iteration gathers a small chunk, overlapping the gather with computation
  • ragged all-to-all