Pointers, masks, and boundaries
tl.load and tl.store are how a tile talks to HBM. Two ideas dominate: strides as the address calculator, and masks as the boundary handler. Get them right and the kernel reads from the right places without segfaulting on irregular sizes.
A pointer in Triton is a vector of addresses
When you pass a PyTorch tensor to a Triton kernel, what arrives is a pointer to the first element. To read element i, you add i to the pointer (and the compiler scales by dtype size). To read a whole tile, you add a tile of offsets:
```python
pid = tl.program_id(0)
offs = pid * BLOCK + tl.arange(0, BLOCK)  # tile of BLOCK indices
x = tl.load(a_ptr + offs)                 # tile of BLOCK values
```
So a_ptr + offs is itself a tile: a vector of pointers, one per lane. tl.load reads all BLOCK addresses in one call. Because the offsets are contiguous, the compiler can emit coalesced (and often vectorized) loads, where each warp of 32 threads reads one contiguous span of memory.
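A minimal plain-Python model of that address arithmetic may help. It assumes a made-up base address of 0x1000 and float32 elements (4 bytes); `lane_byte_addresses` is a hypothetical helper, not part of Triton. The offsets are element indices, and the scaling by element size is the part the compiler does for you.

```python
# Plain-Python model of `a_ptr + offs` (not Triton code).
# Assumptions: hypothetical base byte address 0x1000, 4-byte (float32) elements.

def lane_byte_addresses(base_addr, pid, block, elem_size):
    """Byte address each lane of program `pid` would touch."""
    offs = [pid * block + lane for lane in range(block)]  # pid*BLOCK + tl.arange(0, BLOCK)
    return [base_addr + o * elem_size for o in offs]      # pointer arithmetic scales by dtype size

addrs = lane_byte_addresses(base_addr=0x1000, pid=2, block=8, elem_size=4)
print([hex(a) for a in addrs])
```

Consecutive lanes land exactly one element (4 bytes) apart, which is the contiguous pattern the hardware can coalesce.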
Strides — when memory is not contiguous
For 1D arrays, indexing is just ptr + i. For 2D arrays, the row stride need not equal the column count: a view, a transpose, or a slice can all give non-trivial strides. You must compute addresses from the strides explicitly:
```python
# For x of shape (M, K), the (i, j) element is at:
#   x_ptr + i * stride_M + j * stride_K
offs_m = pid_m * BM + tl.arange(0, BM)           # (BM,)
offs_k = tl.arange(0, BK)                        # (BK,)
addrs = x_ptr + offs_m[:, None] * stride_M \
              + offs_k[None, :] * stride_K       # (BM, BK)
x_tile = tl.load(addrs)                          # (BM, BK)
```
The [:, None] / [None, :] dance is broadcasting — same as NumPy. The result is a 2D tile of addresses. tl.load turns it into a 2D tile of values.
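To see what the broadcast produces, here is a plain-Python model of the address tile for a small row-major (M, K) tensor. The sizes (M = 6, K = 5, BM = 2, BK = 4, pid_m = 1) are arbitrary illustration values, and `address_tile` is a hypothetical helper; offsets are in elements relative to x_ptr.

```python
# Plain-Python model of the 2D address tile (not Triton code):
#   addrs[r][c] = offs_m[r] * stride_M + offs_k[c] * stride_K
# Offsets are in elements, relative to x_ptr.

def address_tile(offs_m, offs_k, stride_M, stride_K):
    # offs_m[:, None] varies down the rows, offs_k[None, :] varies across the columns
    return [[i * stride_M + j * stride_K for j in offs_k] for i in offs_m]

M, K = 6, 5                      # row-major (M, K) tensor: strides (K, 1)
BM, BK, pid_m = 2, 4, 1
offs_m = [pid_m * BM + r for r in range(BM)]   # rows 2, 3
offs_k = list(range(BK))                        # columns 0..3
tile = address_tile(offs_m, offs_k, stride_M=K, stride_K=1)
for row in tile:
    print(row)
# [10, 11, 12, 13]
# [15, 16, 17, 18]
```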
tensor.stride(0) and tensor.stride(1) are in elements, not bytes. Pass them in the same order your kernel multiplies them. A contiguous row-major (M, K) tensor has stride(0) = K and stride(1) = 1; its transpose x.t() is a (K, M) view over the same memory with the strides swapped, stride(0) = 1 and stride(1) = K. Garbage output for non-contiguous tensors almost always traces back here.
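A plain-Python sketch of why the order matters, assuming a small square row-major tensor so that every read stays in bounds. The same flat buffer read with the transposed view's strides yields the transpose; reusing the original tensor's strides silently returns untransposed data. `read` is a hypothetical helper standing in for the kernel's address formula.

```python
# Plain-Python model (not Triton/PyTorch code): one flat buffer, two views.
# Assumption: a contiguous row-major (N, N) tensor, so x.stride() == (N, 1)
# and x.t().stride() == (1, N).

N = 3                                                    # square keeps every read in bounds
buf = [10 * i + j for i in range(N) for j in range(N)]   # x[i][j] == 10*i + j

def read(buf, r, c, stride0, stride1):
    # element (r, c) of a view lives at offset r*stride(0) + c*stride(1)
    return buf[r * stride0 + c * stride1]

x_strides = (N, 1)    # strides of the contiguous tensor x
xt_strides = (1, N)   # strides of the view x.t(): same buffer, strides swapped

xt_good = [[read(buf, r, c, *xt_strides) for c in range(N)] for r in range(N)]
xt_bad = [[read(buf, r, c, *x_strides) for c in range(N)] for r in range(N)]

print(xt_good)  # [[0, 10, 20], [1, 11, 21], [2, 12, 22]]  the true transpose
print(xt_bad)   # [[0, 1, 2], [10, 11, 12], [20, 21, 22]]  silently untransposed
```

Nothing crashes in the bad case, which is exactly why stride-order bugs are easy to miss.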
The boundary problem — and why every load needs a mask
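The grid is cdiv(N, BLOCK). When N isn't a multiple of BLOCK, the last program's tile spills past the end of the array. For example, with N = 1000 and BLOCK = 256 the grid has 4 programs, and pid 3 covers offsets 768 to 1023.
If you load pid 3's 256 addresses without a mask, the last 24 of them (offsets 1000 to 1023) read past the tensor's allocation. Best case: you silently load garbage that contaminates whatever you compute from the tile. Worst case: the kernel crashes with an illegal memory access.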
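The boundary arithmetic is easy to check in plain Python. N = 1000 and BLOCK = 256 are values consistent with the pid 3 example (a grid of 4 programs and 24 lanes past the end); `last_tile_report` is a hypothetical helper.

```python
# Plain-Python arithmetic for the boundary tile (not Triton code).

def cdiv(a, b):
    # ceiling division, same result as triton.cdiv
    return (a + b - 1) // b

def last_tile_report(N, BLOCK):
    """Return (grid size, in-bounds lanes, out-of-bounds lanes) of the last program."""
    grid = cdiv(N, BLOCK)
    last = grid - 1
    offs = [last * BLOCK + lane for lane in range(BLOCK)]
    in_bounds = sum(o < N for o in offs)
    return grid, in_bounds, BLOCK - in_bounds

print(last_tile_report(1000, 256))   # (4, 232, 24)
print(last_tile_report(1024, 256))   # (4, 256, 0): power-of-two N hides the bug
```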
The fix is the mask argument:
```python
offs = pid * BLOCK + tl.arange(0, BLOCK)
m = offs < N
x = tl.load(a_ptr + offs, mask=m, other=0.0)
# lanes where m is False return `other`, no HBM access
tl.store(c_ptr + offs, y, mask=m)
# lanes where m is False are dropped on store
```
The mask is a tile of bools with the same shape as the offsets (or a shape that broadcasts to it). Lanes where the mask is False become no-ops: a load returns other instead of touching memory, and a store writes nothing. There is no branch divergence; the hardware predicates these per lane.
Every tl.load on the boundary axis needs a mask, and so does every tl.store. If you forget one, irregular shapes will produce wrong outputs or crash, but your tests on power-of-two shapes will pass and you'll merge the bug.
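A plain-Python model of these semantics, using a tiny N = 6, BLOCK = 4 example chosen for illustration. The Python conditionals only model the effect; on the GPU the masked lanes are predicated, not branched. `masked_load` and `masked_store` are hypothetical helpers, not Triton functions.

```python
# Plain-Python model of masked tl.load / tl.store semantics (not Triton code).

def masked_load(mem, offs, mask, other):
    # masked-off lanes never index `mem`; they just return `other`
    return [mem[o] if ok else other for o, ok in zip(offs, mask)]

def masked_store(mem, offs, values, mask):
    # masked-off lanes are dropped: no write happens
    for o, v, ok in zip(offs, values, mask):
        if ok:
            mem[o] = v

N, BLOCK, pid = 6, 4, 1
a = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
c = [0.0] * N
offs = [pid * BLOCK + lane for lane in range(BLOCK)]   # [4, 5, 6, 7]
m = [o < N for o in offs]                              # [True, True, False, False]
x = masked_load(a, offs, m, other=0.0)                 # [5.0, 6.0, 0.0, 0.0]
masked_store(c, offs, [2 * v for v in x], m)           # writes c[4] and c[5] only
print(x, c)
```

Offsets 6 and 7 are never used to index `a` or `c`, so the model raises no IndexError, mirroring the "no HBM access" guarantee.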
Picking other correctly
The other argument is what masked-off lanes read. The right choice depends on what you do with the value next:
| Operation after load | Use other = | Why |
|---|---|---|
| Sum / addition | 0.0 | Identity of + |
| Product / multiplication | 1.0 | Identity of × |
| Max reduction | -float('inf') | Identity of max |
| Min reduction | +float('inf') | Identity of min |
| Matmul accumulation | 0.0 | Zero-padded lanes add 0 to every dot product |
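Forgetting this is the second-most-common silent correctness bug. Note that tl.load has no implicit other=0.0: if you omit other, the values in masked-off lanes are unspecified. other=0.0 is the right choice for sum and matmul. For softmax, which starts with a max reduction, you need other=-inf on the boundary tile. With 0.0, a padded lane wins the max whenever every real value is negative, and every padded lane adds exp(0 - max) to the denominator. With -inf, padded lanes become exp(-inf) = 0 and drop out of both reductions.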
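A plain-Python sketch of the softmax case, assuming a boundary tile of BLOCK = 4 lanes holding only 3 real values. `padded_softmax` models what a kernel computes after a masked load with a given other; `reference_softmax` is an unpadded reference. Both are hypothetical helpers for illustration.

```python
import math

# Plain-Python model of softmax over one padded boundary tile (not Triton code).
# Assumption: BLOCK = 4 lanes, but the row has only 3 real values.

def padded_softmax(row, block, other):
    tile = row + [other] * (block - len(row))   # what tl.load(..., mask, other) yields
    mx = max(tile)                               # max reduction over the whole tile
    exps = [math.exp(v - mx) for v in tile]      # exp(-inf - mx) == 0.0
    denom = sum(exps)                            # sum reduction over the whole tile
    return [e / denom for e in exps[:len(row)]]  # keep only the real lanes

def reference_softmax(row):
    mx = max(row)
    exps = [math.exp(v - mx) for v in row]
    return [e / sum(exps) for e in exps]

row = [-3.0, -1.0, -2.0]                 # all negative
print(max(row + [0.0]))                  # 0.0: the padded lane wins the max
good = padded_softmax(row, 4, other=-math.inf)
bad = padded_softmax(row, 4, other=0.0)
ref = reference_softmax(row)
print(all(math.isclose(g, r) for g, r in zip(good, ref)))  # True
print(all(math.isclose(b, r) for b, r in zip(bad, ref)))   # False
```

The shifted max alone would be harmless, since softmax is shift-invariant; the wrong answer comes from the padded lane's extra exp(0 - max) term in the denominator.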
Coalesced vs strided — what the compiler can do for you
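A load is coalesced when consecutive lanes in a warp read consecutive addresses: the hardware merges the warp's requests into as few memory transactions as possible (32 lanes reading fp32 values span exactly one 128-byte line). If lanes read strided or scattered addresses, the same warp needs many separate transactions, most of each fetched line goes unused, and the kernel becomes bandwidth-starved.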
Concretely: for a row-major matrix x[M, K] where K is contiguous, the tile

```python
x_ptr + offs_m[:, None] * stride_M + offs_k[None, :] * stride_K
```

is coalesced when stride_K = 1 (the K axis is contiguous): the BK lanes of one row hit consecutive addresses. If you pass a transposed view such as x.t() without realising it, stride_K is no longer 1. The result is still correct, but consecutive lanes now read addresses stride_K elements apart, and bandwidth craters.
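A deliberately simplified model of the difference: count how many distinct 128-byte-aligned segments one warp's 32 lanes touch. It assumes 4-byte elements, a base address of 0, and ignores real-hardware details such as 32-byte sectors and caches; `segments_touched` is a hypothetical helper.

```python
# Simplified coalescing model (not a specification of any particular GPU):
# count the distinct 128-byte-aligned segments one warp's 32 lanes touch.
# Assumptions: 4-byte (fp32) elements, 128-byte segments, base address 0.

def segments_touched(stride_elems, lanes=32, elem_size=4, segment=128):
    addrs = [lane * stride_elems * elem_size for lane in range(lanes)]
    return len({a // segment for a in addrs})

K = 1024
print(segments_touched(1))   # 1: stride_K == 1, lanes read consecutive floats
print(segments_touched(K))   # 32: stride_K == K, every lane hits its own segment
```

Even under this crude model, the strided case moves 32 times as many segments for the same 128 useful bytes.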
What about writes?
tl.store takes the same mask argument but has no other: masked-off lanes simply don't write, which is exactly what you want for the boundary tile.

```python
tl.store(c_ptr + offs, x + y, mask=offs < N)
```
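What happens if the store forgets its mask can be modeled in plain Python. The sketch assumes the output buffer of N elements sits directly before another tensor's data in one flat memory, which is one plausible layout but not a guarantee; on a real GPU the stray writes may instead land in unallocated memory and fault. `run_last_program` is a hypothetical helper, and the stored value 1.0 is a placeholder for x + y.

```python
# Plain-Python model of the boundary store with and without the mask (not Triton code).
# Assumption: `out` (N elements) sits directly before a neighbor's data in one flat memory.

def run_last_program(N, BLOCK, use_mask):
    memory = [0.0] * N + [-1.0] * BLOCK           # out[0:N], then a neighbor's data
    grid = (N + BLOCK - 1) // BLOCK
    pid = grid - 1                                # the boundary program
    offs = [pid * BLOCK + lane for lane in range(BLOCK)]
    for o in offs:
        if not use_mask or o < N:                 # mask=offs < N
            memory[o] = 1.0                       # placeholder for x + y
    return memory[N:]                             # what happened to the neighbor

print(run_last_program(10, 4, use_mask=True))    # [-1.0, -1.0, -1.0, -1.0]: untouched
print(run_last_program(10, 4, use_mask=False))   # [1.0, 1.0, -1.0, -1.0]: clobbered
```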
The full pattern, all together
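```python
import triton
import triton.language as tl

@triton.jit
def safe_kernel(x_ptr, y_ptr, out_ptr, N, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)   # this program's BLOCK indices
    m = offs < N                               # False past the end of the array
    x = tl.load(x_ptr + offs, mask=m, other=0.0)
    y = tl.load(y_ptr + offs, mask=m, other=0.0)
    z = x + y
    tl.store(out_ptr + offs, z, mask=m)        # masked lanes are not written

# Launch with one program per tile, e.g.:
#   safe_kernel[(triton.cdiv(N, BLOCK),)](x, y, out, N, BLOCK=1024)
```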
One comparison builds the mask, the mask guards all three memory accesses, and each load supplies other=0.0. This is the safe skeleton of every 1D Triton kernel.
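To check that the skeleton covers every element exactly once for irregular N, here is a sequential plain-Python emulation of the whole grid. Real Triton programs run in parallel on the GPU; this sketch only mirrors the indexing and masking, and `program` and `launch` are hypothetical helpers.

```python
# Plain-Python emulation of launching safe_kernel over a whole grid (not Triton code).
# Each call of program() mirrors one program instance of the kernel.

def cdiv(a, b):
    return (a + b - 1) // b

def program(pid, x, y, out, N, BLOCK):
    offs = [pid * BLOCK + lane for lane in range(BLOCK)]
    m = [o < N for o in offs]
    xs = [x[o] if ok else 0.0 for o, ok in zip(offs, m)]   # tl.load(..., mask=m, other=0.0)
    ys = [y[o] if ok else 0.0 for o, ok in zip(offs, m)]
    z = [a + b for a, b in zip(xs, ys)]
    for o, v, ok in zip(offs, z, m):                       # tl.store(..., mask=m)
        if ok:
            out[o] = v

def launch(x, y, BLOCK):
    N = len(x)
    out = [None] * N                                       # None marks "never written"
    for pid in range(cdiv(N, BLOCK)):                      # grid = (cdiv(N, BLOCK),)
        program(pid, x, y, out, N, BLOCK)
    return out

x = [float(i) for i in range(1000)]
y = [2.0 * i for i in range(1000)]
assert launch(x, y, BLOCK=256) == [a + b for a, b in zip(x, y)]
```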
What's next
You can move bytes correctly. Lesson 05 introduces the op you're moving them for: tl.dot, the tile-level matmul that drops to tensor cores. The lesson sets up exactly when that lowering happens and what dtypes the hardware requires.