This document is relevant for: Inf2, Trn1, Trn2
Single program, multiple data tensor addition#
In this tutorial we write a simple tensor-addition kernel using NKI and call it from both PyTorch and JAX. In doing so, we learn about:
The NKI syntax and the SPMD programming model.
Best practices for validating and benchmarking your custom kernel against a reference native PyTorch or JAX implementation.
Note
This tutorial is written using the SPMD programming model in NKI. However, as discussed in the NKI programming guide, adopting the SPMD programming model has no impact on the performance of a NKI kernel, so it is considered optional in the current NKI release.
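For reference, below is a minimal sketch of what a non-SPMD version of this tutorial's kernel could look like, where a single kernel instance loops over all tiles. It assumes nl.affine_range loop iteration as described in the NKI API reference, and is illustrative rather than part of this tutorial's scripts.

import neuronxcc.nki as nki
import neuronxcc.nki.language as nl


@nki.jit
def nki_tensor_add_kernel_loops(a_input, b_input):
    # Sketch of a non-SPMD variant: one kernel instance iterates over all
    # [128, 512] tiles with loops instead of relying on an SPMD launch grid.
    c_output = nl.ndarray(a_input.shape, dtype=a_input.dtype, buffer=nl.shared_hbm)
    for i in nl.affine_range(a_input.shape[0] // 128):
        for j in nl.affine_range(a_input.shape[1] // 512):
            ix = i * 128 + nl.arange(128)[:, None]
            iy = j * 512 + nl.arange(512)[None, :]
            a_tile = nl.load(a_input[ix, iy])
            b_tile = nl.load(b_input[ix, iy])
            nl.store(c_output[ix, iy], value=a_tile + b_tile)
    return c_output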
PyTorch#
Compute kernel#
We start by defining a compute kernel that takes large input tensors but operates on only a [128, 512] tile of them at a time. The partition-dimension tile size of 128 is dictated by the tile-size restrictions (nki.language.tile_size.pmax), while the free-dimension tile size of 512 is chosen arbitrarily.
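Note that the partition-dimension limit is available as a named constant, so the hard-coded 128 in the kernel below could be written against it instead. A small sketch (TILE_P and TILE_F are illustrative names, not NKI identifiers):

import neuronxcc.nki.language as nl

# nl.tile_size.pmax is the maximum partition-dimension tile size (128 on
# current NeuronCore generations); 512 is this tutorial's arbitrary choice
# for the free dimension.
TILE_P = nl.tile_size.pmax  # == 128
TILE_F = 512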
import neuronxcc.nki as nki
import neuronxcc.nki.language as nl


@nki.jit
def nki_tensor_add_kernel_(a_input, b_input):
    """NKI kernel to compute element-wise addition of two input tensors

    This kernel assumes the input/output sizes can be uniformly tiled to [128, 512]

    Args:
        a_input: a first input tensor
        b_input: a second input tensor

    Returns:
        c_output: an output tensor
    """
    # Create output tensor shared between all SPMD instances as result tensor
    c_output = nl.ndarray(a_input.shape, dtype=a_input.dtype, buffer=nl.shared_hbm)

    # Calculate tile offsets based on current 'program'
    offset_i_x = nl.program_id(0) * 128
    offset_i_y = nl.program_id(1) * 512

    # Generate tensor indices to index tensors a and b
    ix = offset_i_x + nl.arange(128)[:, None]
    iy = offset_i_y + nl.arange(512)[None, :]

    # Load input data from device memory (HBM) to on-chip memory (SBUF)
    # We refer to an indexed portion of a tensor as an intermediate tensor
    a_tile = nl.load(a_input[ix, iy])
    b_tile = nl.load(b_input[ix, iy])

    # compute a + b
    c_tile = a_tile + b_tile

    # store the addition results back to device memory (c_output)
    nl.store(c_output[ix, iy], value=c_tile)

    # Transfer the ownership of `c_output` to the caller
    return c_output
In this example:
- We define the NKI kernel in nki_tensor_add_kernel_ and decorate it with the nki.jit decorator, which invokes the NKI compiler to compile the kernel.
- Inside, we first allocate the tensor c_output to hold the result of the kernel.
- Next, we define offsets into the tensors based on the ID of the worker executing the code (nl.program_id), and generate tile indices from these offsets with nl.arange. We use advanced indexing here to showcase how it works; basic indexing with slicing also works (see the sketch after this list, and NKI Programming Model for more information on the different tensor indexing modes).
- We use nl.program_id to enable SPMD execution (single-program, multiple-data; see SPMD: Launching Multiple Instances of a Kernel), where each worker operates on only a (sub-tensor) tile of the input/output tensors. By reading its own program_id, each worker can calculate the offsets it needs to access the correct tiles.
- The first axis of the tensor (mapped to the partition dimension) is tiled into blocks of 128, based on hardware restrictions (see Tile Size Considerations). The second axis (mapped to the free dimension) is tiled into blocks of 512; there is no tile-size constraint here, since the addition runs on the Vector engine, so the only restriction is on-chip memory capacity.
- We then load sub-tensor data from tensors a_input and b_input using nl.load, placing the tiles a_tile and b_tile in the on-chip memory (SBUF).
- We sum them to compute c_tile, and store it back to DRAM in the relevant portion of the c_output tensor using nl.store. Since both inputs and the output have the same shape, we can use the same set of indices to access all three tensors.
- At the end, we use a return statement to transfer ownership of the tensor c_output to the caller of the kernel.
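As noted in the list above, the same tile accesses can be written with basic indexing. A minimal sketch of the kernel's load/store lines rewritten with Python slices, assuming basic slicing support as described in the NKI programming model:

# Equivalent tile access with basic indexing (Python slices) instead of
# nl.arange-based advanced indexing; offsets are as computed in the kernel.
a_tile = nl.load(a_input[offset_i_x:offset_i_x + 128, offset_i_y:offset_i_y + 512])
b_tile = nl.load(b_input[offset_i_x:offset_i_x + 128, offset_i_y:offset_i_y + 512])
nl.store(c_output[offset_i_x:offset_i_x + 128, offset_i_y:offset_i_y + 512],
         value=a_tile + b_tile)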
SPMD execution#
We declare a helper function that launches the compute kernel with an appropriate grid size, so that the computation covers the whole input tensors.
def nki_tensor_add(a_input, b_input):
    """NKI kernel caller to compute element-wise addition of two input tensors

    This kernel caller lifts the tile-size restriction by applying the kernel on tiles of the inputs/outputs

    Args:
        a_input: a first input tensor, of shape [N*128, M*512]
        b_input: a second input tensor, of shape [N*128, M*512]

    Returns:
        a tensor of shape [N*128, M*512], the result of a_input + b_input
    """

    # The SPMD launch grid denotes the number of kernel instances.
    # In this case, we use a 2D grid where each kernel instance handles a 128x512 tile
    grid_x = a_input.shape[0] // 128
    grid_y = a_input.shape[1] // 512

    return nki_tensor_add_kernel_[grid_x, grid_y](a_input, b_input)
We are using a two-dimensional grid, where the first dimension of the tensor is tiled in the X dimension of the grid and the second dimension is tiled in the Y dimension. In this scenario we assume that the tensor sizes are a multiple of the tile sizes, so we do not need to handle partial tiles.
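If the even-tiling assumption might not hold for your inputs, a fail-fast check before launching makes the failure obvious. A minimal sketch (check_tileable is a hypothetical helper, not part of NKI):

import torch

def check_tileable(x: torch.Tensor, p: int = 128, f: int = 512) -> None:
    # Hypothetical guard: nki_tensor_add does not handle partial tiles,
    # so reject shapes that the [p, f] tile does not divide evenly.
    assert x.dim() == 2, "expected a 2D tensor"
    assert x.shape[0] % p == 0 and x.shape[1] % f == 0, (
        f"shape {tuple(x.shape)} is not a multiple of the [{p}, {f}] tile size"
    )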
Launching kernel and testing correctness#
To execute the kernel, we prepare tensors a and b, and call the nki_tensor_add helper function. We also verify the correctness of the NKI kernel against PyTorch by comparing the outputs of both, using torch.allclose:
import torch
from torch_xla.core import xla_model as xm

if __name__ == "__main__":
    device = xm.xla_device()

    a = torch.rand((256, 1024), dtype=torch.bfloat16).to(device=device)
    b = torch.rand((256, 1024), dtype=torch.bfloat16).to(device=device)

    output_nki = nki_tensor_add(a, b)
    print(f"output_nki={output_nki}")

    output_torch = a + b
    print(f"output_torch={output_torch}")

    allclose = torch.allclose(output_torch, output_nki, atol=1e-4, rtol=1e-2)
    if allclose:
        print("NKI and Torch match")
    else:
        print("NKI and Torch differ")

    assert allclose
Output:
2023-12-29 15:18:00.000558: 14283 INFO ||NEURON_CACHE||: Compile cache path: /var/tmp/neuron-compile-cache
2023-12-29 15:18:00.000559: 14283 INFO ||NEURON_CC_WRAPPER||: Call compiler with cmd: ['neuronx-cc', '--target=trn1', 'compile', '--framework', 'XLA', '/tmp/neuroncc_compile_workdir/49f554a2-2c55-4a88-8054-cc9f20824a46/model.MODULE_5007921933048625946+d41d8cd9.hlo.pb', '--output', '/tmp/neuroncc_compile_workdir/49f554a2-2c55-4a88-8054-cc9f20824a46/model.MODULE_5007921933048625946+d41d8cd9.neff', '--verbose=35']
.
Compiler status PASS
output_nki=tensor([[0.9297, 0.8359, 1.1719, ..., 0.4648, 0.2188, 0.9336],
[0.3906, 1.3125, 0.8789, ..., 1.6562, 1.7734, 0.9531],
[0.6445, 1.1406, 1.3281, ..., 0.9531, 0.8711, 0.9336],
...,
[0.4023, 0.6406, 1.5312, ..., 0.7617, 0.7734, 0.3359],
[0.8125, 0.7422, 1.2109, ..., 0.8516, 1.2031, 0.5430],
[1.3281, 1.2812, 1.3984, ..., 1.2344, 0.8711, 0.5664]],
device='xla:1', dtype=torch.bfloat16)
2023-12-29 15:18:02.000219: 14463 INFO ||NEURON_CACHE||: Compile cache path: /var/tmp/neuron-compile-cache
2023-12-29 15:18:02.000220: 14463 INFO ||NEURON_CC_WRAPPER||: Call compiler with cmd: ['neuronx-cc', '--target=trn1', 'compile', '--framework', 'XLA', '/tmp/neuroncc_compile_workdir/2e135b73-1c3b-45e4-a6f0-2c4b105c20e5/model.MODULE_10032327759287407517+d41d8cd9.hlo.pb', '--output', '/tmp/neuroncc_compile_workdir/2e135b73-1c3b-45e4-a6f0-2c4b105c20e5/model.MODULE_10032327759287407517+d41d8cd9.neff', '--verbose=35']
.
Compiler status PASS
output_torch=tensor([[0.9297, 0.8359, 1.1719, ..., 0.4648, 0.2188, 0.9336],
[0.3906, 1.3125, 0.8789, ..., 1.6562, 1.7734, 0.9531],
[0.6445, 1.1406, 1.3281, ..., 0.9531, 0.8711, 0.9336],
...,
[0.4023, 0.6406, 1.5312, ..., 0.7617, 0.7734, 0.3359],
[0.8125, 0.7422, 1.2109, ..., 0.8516, 1.2031, 0.5430],
[1.3281, 1.2812, 1.3984, ..., 1.2344, 0.8711, 0.5664]],
device='xla:1', dtype=torch.bfloat16)
2023-12-29 15:18:03.000797: 14647 INFO ||NEURON_CACHE||: Compile cache path: /var/tmp/neuron-compile-cache
2023-12-29 15:18:03.000798: 14647 INFO ||NEURON_CC_WRAPPER||: Call compiler with cmd: ['neuronx-cc', '--target=trn1', 'compile', '--framework', 'XLA', '/tmp/neuroncc_compile_workdir/74f8b6ae-76d9-4dd8-af7f-e5e1c40a27a3/model.MODULE_5906037506311912405+d41d8cd9.hlo.pb', '--output', '/tmp/neuroncc_compile_workdir/74f8b6ae-76d9-4dd8-af7f-e5e1c40a27a3/model.MODULE_5906037506311912405+d41d8cd9.neff', '--verbose=35']
.
Compiler status PASS
NKI and Torch match
Note that the tensor values you see will differ from what’s printed above, since this example uses torch.rand to initialize the inputs.
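If you prefer reproducible inputs across runs, you can seed PyTorch's generator before creating a and b:

import torch

# Seeding the default generator makes torch.rand deterministic across runs.
torch.manual_seed(0)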
JAX#
Compute kernel#
We can reuse the same NKI compute kernel defined for PyTorch above.
import neuronxcc.nki as nki
import neuronxcc.nki.language as nl


@nki.jit
def nki_tensor_add_kernel_(a_input, b_input):
    """NKI kernel to compute element-wise addition of two input tensors

    This kernel assumes the input/output sizes can be uniformly tiled to [128, 512]

    Args:
        a_input: a first input tensor
        b_input: a second input tensor

    Returns:
        c_output: an output tensor
    """
    # Create output tensor shared between all SPMD instances as result tensor
    c_output = nl.ndarray(a_input.shape, dtype=a_input.dtype, buffer=nl.shared_hbm)

    # Calculate tile offsets based on current 'program'
    offset_i_x = nl.program_id(0) * 128
    offset_i_y = nl.program_id(1) * 512

    # Generate tensor indices to index tensors a and b
    ix = offset_i_x + nl.arange(128)[:, None]
    iy = offset_i_y + nl.arange(512)[None, :]

    # Load input data from device memory (HBM) to on-chip memory (SBUF)
    # We refer to an indexed portion of a tensor as an intermediate tensor
    a_tile = nl.load(a_input[ix, iy])
    b_tile = nl.load(b_input[ix, iy])

    # compute a + b
    c_tile = a_tile + b_tile

    # store the addition results back to device memory (c_output)
    nl.store(c_output[ix, iy], value=c_tile)

    # Transfer the ownership of `c_output` to the caller
    return c_output
SPMD execution#
As before, we declare a helper function that launches the compute kernel with an appropriate grid size to perform the computation:
def nki_tensor_add(a_input, b_input):
    """NKI kernel caller to compute element-wise addition of two input tensors

    This kernel caller lifts the tile-size restriction by applying the kernel on tiles of the inputs/outputs

    Args:
        a_input: a first input tensor, of shape [N*128, M*512]
        b_input: a second input tensor, of shape [N*128, M*512]

    Returns:
        a tensor of shape [N*128, M*512], the result of a_input + b_input
    """

    # The SPMD launch grid denotes the number of kernel instances.
    # In this case, we use a 2D grid where each kernel instance handles a 128x512 tile
    grid_x = a_input.shape[0] // 128
    grid_y = a_input.shape[1] // 512

    return nki_tensor_add_kernel_[grid_x, grid_y](a_input, b_input)
We are using a two-dimensional grid, where the first dimension of the tensor is tiled in the X dimension of the grid and the second dimension is tiled in the Y dimension. In this scenario we assume that the tensor sizes are a multiple of the tile sizes, so we do not need to handle partial tiles.
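As an alternative to rejecting non-tileable shapes, inputs can be padded up to the next tile multiple before the launch and the result cropped back afterwards. A minimal sketch (pad_to_tile_multiple is a hypothetical helper):

import jax.numpy as jnp

def pad_to_tile_multiple(x, p=128, f=512):
    # Hypothetical helper: zero-pad a 2D array up to the next multiple of
    # the [p, f] tile shape so the SPMD grid divides it evenly; crop the
    # kernel output back to x.shape afterwards.
    pad_p = (-x.shape[0]) % p
    pad_f = (-x.shape[1]) % f
    return jnp.pad(x, ((0, pad_p), (0, pad_f)))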
Launching kernel and testing correctness#
To execute the kernel, we prepare arrays a and b, and call the nki_tensor_add helper function. We also verify the correctness of the NKI kernel against JAX by comparing the outputs of both, using jax.numpy.allclose:
import jax
import jax.numpy as jnp

if __name__ == "__main__":

    seed_a, seed_b = jax.random.split(jax.random.PRNGKey(42))
    a = jax.random.uniform(seed_a, (256, 1024), dtype=jnp.bfloat16)
    b = jax.random.uniform(seed_b, (256, 1024), dtype=jnp.bfloat16)

    output_nki = nki_tensor_add(a, b)
    print(f"output_nki={output_nki}")

    output_jax = a + b
    print(f"output_jax={output_jax}")

    allclose = jnp.allclose(output_jax, output_nki, atol=1e-4, rtol=1e-2)
    if allclose:
        print("NKI and JAX match")
    else:
        print("NKI and JAX differ")

    assert allclose
Output:
.
Compiler status PASS
.
Compiler status PASS
.
Compiler status PASS
output_nki=[[0.992188 1.27344 1.65625 ... 0.90625 1.34375 1.77344]
[0 0.90625 1.34375 ... 0.390625 0.703125 0.914062]
[0.5 0.390625 0.703125 ... 1.22656 1.15625 1.01562]
...
[1.98438 1.98438 1.98438 ... 1.33594 1.64062 1.35938]
[0.992188 1.33594 1.64062 ... 1.16406 1.67188 1.20312]
[1.49219 1.16406 1.67188 ... 1.375 1 1.6875]]
.
Compiler status PASS
output_jax=[[0.992188 1.27344 1.65625 ... 0.90625 1.34375 1.77344]
[0 0.90625 1.34375 ... 0.390625 0.703125 0.914062]
[0.5 0.390625 0.703125 ... 1.22656 1.15625 1.01562]
...
[1.98438 1.98438 1.98438 ... 1.33594 1.64062 1.35938]
[0.992188 1.33594 1.64062 ... 1.16406 1.67188 1.20312]
[1.49219 1.16406 1.67188 ... 1.375 1 1.6875]]
.
Compiler status PASS
NKI and JAX match
Note that, because this example seeds jax.random with a fixed PRNGKey, repeated runs produce the same inputs; the exact values may still differ from what's printed above across JAX versions or platforms.
Download All Source Code#
Click the links to download the source code of the kernels and the testing code discussed in this tutorial.
- NKI baremetal implementation: spmd_tensor_addition_nki_kernels.py
- PyTorch implementation: spmd_tensor_addition_torch.py
  You must also download spmd_tensor_addition_nki_kernels.py into the same folder to run this PyTorch script.
- JAX implementation: spmd_tensor_addition_jax.py
  You must also download spmd_tensor_addition_nki_kernels.py into the same folder to run this JAX script.
You can also view the source code in the GitHub repository nki_samples.
Example usage of the scripts:#
Run NKI baremetal implementation:
python3 spmd_tensor_addition_nki_kernels.py
Run PyTorch implementation:
python3 spmd_tensor_addition_torch.py
Run JAX implementation:
python3 spmd_tensor_addition_jax.py
This document is relevant for: Inf2, Trn1, Trn2