This document is relevant for: Inf2, Trn1, Trn2
Matrix multiplication#
In this tutorial, we will start with a simple NKI matrix multiplication kernel and optimize it step by step. In doing so, we learn about:
The NKI syntax and programming model.
Layout, tiling, and memory management considerations when performing matrix multiplication in NKI.
Best practices for validating and benchmarking your custom kernel against a reference native torch implementation.
Basic compute kernel#
Fig. 77 illustrates how a simple matrix multiplication, lhs [M, K] * rhs [K, N] = output [M, N], is mapped from its original mathematical view onto the Tensor Engine (TensorE) and the SRAMs (SBUF and PSUM). Note that the PSUM partition dimension is drawn rotated 90 degrees from the SBUF partition dimension solely for layout visualization.
The copy from PSUM to SBUF preserves the output tile layout by copying data from each PSUM partition to the corresponding SBUF partition.
The NKI example below implements a compute kernel for a single-tile matrix multiplication. It computes a 64 (M) x 128 (K) x 512 (N) matrix multiplication.
import neuronxcc.nki as nki
import neuronxcc.nki.language as nl


@nki.jit
def nki_matmul_basic_(lhsT, rhs):
    """NKI kernel to compute a 64x128x512 matrix multiplication operation

    Args:
        lhsT: an input tensor of shape [128,64], a left hand side argument of the
          matrix multiplication, delivered transposed for optimal performance
        rhs: an input tensor of shape [128,512], a right hand side argument of the
          matrix multiplication
    Returns:
        result: the resulting output tensor of shape [64,512]
    """
    result = nl.ndarray((64, 512), dtype=lhsT.dtype, buffer=nl.shared_hbm)

    # Defining indexes for input LHS.T
    # - Note: here we take LayoutConstraint #1 into account:
    # "For MatMult, contraction axis must be mapped to P-dim"
    i_lhsT_p, i_lhsT_f = nl.mgrid[0:128, 0:64]

    # Defining indexes for input RHS
    # - Note: here we take LayoutConstraint #1 into account:
    # "For MatMult, contraction axis must be mapped to P-dim"
    i_rhs_p, i_rhs_f = nl.mgrid[0:128, 0:512]

    # Defining indexes for the output ([64,128]@[128,512] -> [64,512])
    i_out_p, i_out_f = nl.mgrid[0:64, 0:512]

    # Loading the inputs (HBM->SBUF)
    # Note: here we take Tile dtype definition into account,
    # which forces P-dim as the left most index
    lhs_tile = nl.load(lhsT[i_lhsT_p, i_lhsT_f])
    rhs_tile = nl.load(rhs[i_rhs_p, i_rhs_f])

    # Perform the matrix-multiplication
    # Note1: We set transpose_x to True, to indicate that the LHS input is transposed
    # Note2: A NKI matmul instruction always writes to PSUM in float32 data-type
    result_psum = nl.matmul(lhs_tile, rhs_tile, transpose_x=True)

    # Copy the result from PSUM back to SBUF, and cast to expected output data-type
    result_sbuf = nl.copy(result_psum, dtype=result.dtype)

    # The result of a [64,128] x [128,512] matrix multiplication has a shape of [64, 512].
    # This dictates which indices to use to address the result tile.
    nl.store(result[i_out_p, i_out_f], value=result_sbuf)

    return result
In this example, we define the NKI kernel as nki_matmul_basic_:
We define indices to access the LHS and RHS input tensors.
To adhere to NKI’s layout considerations (Layout Considerations), we map the contraction axis of both LHS and RHS to the P-dimension, which means we load LHS in transposed form.
To adhere to NKI’s tile size considerations (Tile Size Considerations), we limit the matmul instruction arguments to tiles of up to [128,128] for LHS, and [128,512] for RHS.
Using the nl.load operation, we load the inputs from HBM tensors to SBUF tiles.
We then use the nl.matmul operation to perform the matrix multiplication. Note that we set the transpose_x argument to True, since the LHS argument is transposed. Also note that the 64x128 dimension here actually under-utilizes the TensorE, but it helps to distinguish the M, K and N dimensions for education purposes in this first code example.
nl.matmul always writes its result to PSUM, and since nl.store only moves data from SBUF to HBM, we copy the multiplication result from PSUM back to SBUF using nl.copy.
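To make the role of the transposed LHS concrete, here is a minimal pure-Python sketch (not NKI code) of the arithmetic that a matmul with transpose_x=True is described as performing: both operands are indexed by the contraction index k on their first axis, mirroring the mapping of the contraction axis to the P-dimension. The helper name matmul_lhsT and the tiny shapes are hypothetical and chosen only for illustration.

```python
def matmul_lhsT(lhsT, rhs):
    """Compute lhsT^T @ rhs, where lhsT is [K][M] and rhs is [K][N] (lists of lists)."""
    K, M, N = len(lhsT), len(lhsT[0]), len(rhs[0])
    assert len(rhs) == K, "both operands must share the contraction dimension K"
    out = [[0.0] * N for _ in range(M)]
    for k in range(K):  # contraction index: first axis of BOTH operands
        for m in range(M):
            for n in range(N):
                out[m][n] += lhsT[k][m] * rhs[k][n]
    return out


# lhs is [M=2][K=3]; the kernel would receive it transposed as lhsT [K=3][M=2].
lhs = [[1.0, 2.0, 3.0],
       [4.0, 5.0, 6.0]]
lhsT = [list(col) for col in zip(*lhs)]
rhs = [[1.0, 0.0],
       [0.0, 1.0],
       [1.0, 1.0]]

result = matmul_lhsT(lhsT, rhs)
print(result)  # [[4.0, 5.0], [10.0, 11.0]]
```

The output has shape [M][N], which is why the kernel addresses the result with indices over [64, 512] even though both inputs lead with the 128-wide contraction axis.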
We can then execute the kernel and verify correctness against the torch implementation as follows. Note that we use torch.allclose to tolerate numerical error inherent to floating-point arithmetic.
import torch
from torch_xla.core import xla_model as xm

device = xm.xla_device()
cpu = torch.device('cpu')

# Test the small workload with basic kernel
lhs_small = torch.rand((64, 128), dtype=torch.bfloat16, device=device)
rhs_small = torch.rand((128, 512), dtype=torch.bfloat16, device=device)

# Run NKI kernel
output_small = nki_matmul_basic_(lhs_small.T, rhs_small)

# Run torch reference
output_small_torch = torch.matmul(lhs_small, rhs_small)

# Compare results
print("Checking correctness of nki_matmul_basic")
if torch.allclose(output_small_torch, output_small, atol=1e-4, rtol=1e-2):
    print("NKI and Torch match")
else:
    print("NKI and Torch differ")
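As a rough illustration of what the tolerances mean, the sketch below re-implements the elementwise criterion that torch.allclose applies, |actual - expected| <= atol + rtol * |expected|, in plain Python. The helper name is_close and the sample values are hypothetical; the value 0.25 is used because it is one bfloat16 step for numbers in [32, 64).

```python
def is_close(actual, expected, atol=1e-4, rtol=1e-2):
    """Elementwise test: |actual - expected| <= atol + rtol * |expected|."""
    return abs(actual - expected) <= atol + rtol * abs(expected)


expected = 32.0
one_step_off = is_close(32.25, expected)   # error 0.25 <= 1e-4 + 0.32
two_steps_off = is_close(32.5, expected)   # error 0.50 >  1e-4 + 0.32

print(one_step_off)   # True
print(two_steps_off)  # False
```

With rtol=1e-2 the check tolerates a single bfloat16 rounding step at this magnitude but flags larger deviations, which is the kind of slack the comparison above relies on.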
Tiling matrix multiplications#
So far, we’ve limited our matrix multiplication to the tile sizes allowed by NKI’s tile size and layout constraints. Next, we’ll see how to handle larger matrix multiplications. Let’s start with pseudocode for tiling an [M,K] @ [K,N] matrix multiplication. Note that we assume the left-hand-side matrix ([M,K]) is already transposed to LHS_T ([K,M]) for optimal performance of the underlying TensorE.
# LHS_T: left-hand-side matmul argument (shape [K,M])
# RHS: right-hand-side matmul argument (shape [K,N])
# RES: matmul result (shape [M,N])

# Tile LHS_T free dimension
for m in range(0, M, 128):
    # Tile RHS free dimension
    for n in range(0, N, 512):
        # Zero-out the accumulator buffer
        accum = zeros((128, 512))
        # Tile contraction dimension
        for k in range(0, K, 128):
            # LHS_T is [K,M], so the contraction index k comes first
            lhsT_tile = LHS_T[k : k+128, m : m+128]
            rhs_tile = RHS[k : k+128, n : n+512]
            # dot() contracts over the shared first (K) axis of both tiles
            accum += dot(lhsT_tile, rhs_tile)
        RES[m : m+128, n : n+512] = accum
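The loop nest above can be checked with a small runnable sketch. The version below is plain Python on nested lists, with deliberately tiny tile sizes (2, 2 and 4 instead of 128, 128 and 512) so that it runs instantly; the function names tiled_matmul and reference and the test data are hypothetical.

```python
def tiled_matmul(lhsT, rhs, tile_m=2, tile_k=2, tile_n=4):
    """Tiled lhsT^T @ rhs for lhsT [K][M] and rhs [K][N]; sizes must divide evenly."""
    K, M, N = len(lhsT), len(lhsT[0]), len(rhs[0])
    assert M % tile_m == 0 and K % tile_k == 0 and N % tile_n == 0
    res = [[0.0] * N for _ in range(M)]
    for m in range(0, M, tile_m):          # tile the LHS_T free dimension
        for n in range(0, N, tile_n):      # tile the RHS free dimension
            accum = [[0.0] * tile_n for _ in range(tile_m)]  # zeroed accumulator
            for k in range(0, K, tile_k):  # tile the contraction dimension
                for kk in range(k, k + tile_k):
                    for i in range(tile_m):
                        for j in range(tile_n):
                            accum[i][j] += lhsT[kk][m + i] * rhs[kk][n + j]
            for i in range(tile_m):        # write the finished output tile
                res[m + i][n:n + tile_n] = accum[i]
    return res


def reference(lhsT, rhs):
    """Untiled lhsT^T @ rhs, used only to validate the tiled version."""
    K, M, N = len(lhsT), len(lhsT[0]), len(rhs[0])
    return [[sum(lhsT[k][m] * rhs[k][n] for k in range(K)) for n in range(N)]
            for m in range(M)]


K, M, N = 4, 4, 8
lhsT = [[float(k * M + m) for m in range(M)] for k in range(K)]
rhs = [[float((k + n) % 3) for n in range(N)] for k in range(K)]
assert tiled_matmul(lhsT, rhs) == reference(lhsT, rhs)
```

The accumulator is reset once per output tile and only written back after the whole contraction loop finishes, which is the same role the PSUM buffer plays in the NKI kernels.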
This form of tiling can be achieved in NKI as follows:
@nki.jit
def nki_matmul_tiled_(lhsT, rhs):
    """NKI kernel to compute a matrix multiplication operation in a tiled manner

    Args:
        lhsT: an input tensor of shape [K,M], where both K and M are multiples of
          128. It is the left-hand-side argument of the matrix multiplication,
          delivered transposed for optimal performance.
        rhs: an input tensor of shape [K,N], where K is a multiple of 128, and N
          is a multiple of 512. It is the right-hand-side argument of the matrix
          multiplication.
    Returns:
        result: the resulting output tensor of shape [M,N]
    """

    K, M = lhsT.shape
    K_, N = rhs.shape
    assert K == K_, "lhsT and rhs must have the same contraction dimension"
    result = nl.ndarray((M, N), dtype=lhsT.dtype, buffer=nl.shared_hbm)

    TILE_M = nl.tile_size.gemm_stationary_fmax  # 128
    TILE_K = nl.tile_size.pmax  # 128
    TILE_N = nl.tile_size.gemm_moving_fmax  # 512

    # Use affine_range to loop over tiles
    for m in nl.affine_range(M // TILE_M):
        for n in nl.affine_range(N // TILE_N):
            # Allocate a tensor in PSUM
            res_psum = nl.zeros((TILE_M, TILE_N), nl.float32, buffer=nl.psum)

            for k in nl.affine_range(K // TILE_K):
                # Declare the tiles on SBUF
                lhsT_tile = nl.ndarray((TILE_K, TILE_M), dtype=lhsT.dtype, buffer=nl.sbuf)
                rhs_tile = nl.ndarray((TILE_K, TILE_N), dtype=rhs.dtype, buffer=nl.sbuf)

                # Load tiles from lhsT and rhs
                lhsT_tile[...] = nl.load(lhsT[k * TILE_K:(k + 1) * TILE_K,
                                              m * TILE_M:(m + 1) * TILE_M])
                rhs_tile[...] = nl.load(rhs[k * TILE_K:(k + 1) * TILE_K,
                                            n * TILE_N:(n + 1) * TILE_N])

                # Accumulate partial-sums into PSUM
                res_psum += nl.matmul(lhsT_tile[...], rhs_tile[...], transpose_x=True)

            # Copy the result from PSUM back to SBUF, and cast to expected output data-type
            res_sb = nl.copy(res_psum, dtype=result.dtype)
            nl.store(result[m * TILE_M:(m + 1) * TILE_M, n * TILE_N:(n + 1) * TILE_N],
                     value=res_sb)

    return result
Several kernels in this tutorial use nl.mgrid to define indices, which works the same way as mgrid in NumPy. It is similar to defining indices through nl.arange, but it offers a more concise way to introduce indices for multiple dimensions. nl.affine_range is used to define loop-level iterators. Loops defined with affine_range are not unrolled by the compiler, which enables faster compilation.
There is an alternative way to implement this tiled matrix multiplication kernel using the SPMD programming model. We can use the SPMD model to launch (M/128) x (N/512) instances of the kernel to complete the innermost loop. For more details, refer to the SPMD programming model.
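The sketch below does not use the NKI SPMD API. It only enumerates, in plain Python, which output tile each of the (M/128) x (N/512) instances would be responsible for, assuming one instance per output tile. The function name spmd_tile_assignments and the 2D instance index (pm, pn) are hypothetical.

```python
def spmd_tile_assignments(M, N, tile_m=128, tile_n=512):
    """Map each hypothetical 2D instance index (pm, pn) to the output slice it owns."""
    grid = (M // tile_m, N // tile_n)
    assignments = {}
    for pm in range(grid[0]):
        for pn in range(grid[1]):
            rows = (pm * tile_m, (pm + 1) * tile_m)   # [start, stop) in M
            cols = (pn * tile_n, (pn + 1) * tile_n)   # [start, stop) in N
            assignments[(pm, pn)] = (rows, cols)
    return grid, assignments


grid, assignments = spmd_tile_assignments(M=4096, N=2048)
print(grid)                 # (32, 4)
print(len(assignments))     # 128
print(assignments[(1, 2)])  # ((128, 256), (1024, 1536))
```

Under this assumption each instance would only need the loop over the contraction dimension, because the two outer loops are replaced by the launch grid.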
Optimization 1: Removing Redundant Loads#
Currently, every nl.matmul is accompanied by two nl.load calls in the inner loop, both of which move data from HBM to SBUF. Let’s introduce a metric, arithmetic intensity, to help understand why this is problematic. The arithmetic intensity of a workload is defined as the average number of computation operations performed per byte of data accessed from HBM. We do not count data accessed from SBUF in this metric because the SBUF bandwidth (~20x higher than HBM) is high enough to sustain the peak computation throughput of TensorE.
Fig. 78 shows the roofline model, which models the relationship between the arithmetic intensity of a workload and its achievable performance on a given computing platform. To saturate TensorE in a NeuronCore-v2, the arithmetic intensity threshold of a workload is 222 Flops/Byte for the bfloat16 data type. Inside the inner loop of nki_matmul_tiled_, accessing lhsT_tile (128x128 bfloat16 values, 32 KB) and rhs_tile (128x512 bfloat16 values, 128 KB) requires 160 KB of data read from HBM, while the nl.matmul call involves 16 MFlops. This leads to an arithmetic intensity of 102 Flops/Byte, which is significantly lower than the saturation threshold of 222. Therefore, nki_matmul_tiled_ operates in the memory-bound region of the roofline model and under-utilizes TensorE. To get the most out of TensorE, we need to improve the arithmetic intensity of the matmul kernel.
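The arithmetic behind these numbers can be reproduced with a few lines of Python. This is a sketch that assumes 2 bytes per bfloat16 element, counts one multiply and one add per multiply-accumulate, and uses the 222 Flops/Byte threshold quoted above; the variable names are hypothetical.

```python
BYTES_BF16 = 2
TILE_M, TILE_K, TILE_N = 128, 128, 512

lhsT_bytes = TILE_K * TILE_M * BYTES_BF16   # 32 KB
rhs_bytes = TILE_K * TILE_N * BYTES_BF16    # 128 KB
hbm_bytes = lhsT_bytes + rhs_bytes          # 160 KB read per inner iteration
flops = 2 * TILE_M * TILE_K * TILE_N        # multiply + add per MAC: 16 MFlops

intensity = flops / hbm_bytes               # Flops per byte read from HBM
SATURATION = 222                            # threshold quoted in the text

# Roofline upper bound on the fraction of peak throughput when memory-bound
fraction_of_peak = min(1.0, intensity / SATURATION)

print(hbm_bytes // 1024, flops // 2**20, intensity, round(fraction_of_peak, 2))
# 160 16 102.4 0.46
```

In the roofline picture, a kernel at 102.4 Flops/Byte can reach at most roughly 46% of peak TensorE throughput, no matter how well the compute itself is scheduled.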
With NKI, programmers can control when and how to load data from HBM into SBUF and also perform computation. We will demonstrate in the upcoming steps how to increase the arithmetic intensity of the matmul kernel using NKI, thereby maximizing the utilization of TensorE.
First, we notice that in nki_matmul_tiled_, the same tiles from the lhsT and rhs matrices are loaded more than once across different iterations of the inner loop. The following example reduces these redundant loads by hoisting them out of the innermost loop.
@nki.jit
def nki_matmul_hoist_load_(lhsT, rhs):
    """NKI kernel to compute a matrix multiplication operation in a tiled manner
    while hoisting the load of the lhsT and rhs to outer loops.

    Args:
        lhsT: an input tensor of shape [K,M], where both K and M are multiples of
          128. It is the left-hand-side argument of the matrix multiplication,
          delivered transposed for optimal performance.
        rhs: an input tensor of shape [K,N], where K is a multiple of 128, and N
          is a multiple of 512. It is the right-hand-side argument of the matrix
          multiplication.
    Returns:
        result: the resulting output tensor of shape [M,N]
    """

    K, M = lhsT.shape
    K_, N = rhs.shape
    assert K == K_, "lhsT and rhs must have the same contraction dimension"
    result = nl.ndarray((M, N), dtype=lhsT.dtype, buffer=nl.shared_hbm)

    TILE_M = nl.tile_size.gemm_stationary_fmax  # 128
    TILE_K = nl.tile_size.pmax  # 128
    TILE_N = nl.tile_size.gemm_moving_fmax  # 512

    # Define the indices (shape) of the tiles
    i_lhsT = nl.mgrid[0:TILE_K, 0:TILE_M]
    i_rhs = nl.mgrid[0:TILE_K, 0:TILE_N]
    i_res = nl.mgrid[0:TILE_M, 0:TILE_N]

    # Use affine_range to loop over tiles
    for m in nl.affine_range(M // TILE_M):
        # Load a whole column of tiles from lhsT (with K * TILE_M numbers)
        # This corresponds to the whole row in the original lhs
        lhsT_tiles = nl.ndarray((K // TILE_K, nl.par_dim(TILE_K), TILE_M),
                                dtype=lhsT.dtype,
                                buffer=nl.sbuf)

        for k in nl.affine_range(K // TILE_K):
            # use `.p` for partition dimension and `.x` for the first free dimension
            lhsT_tiles[k, i_lhsT.p, i_lhsT.x] = nl.load(lhsT[k * TILE_K + i_lhsT.p,
                                                             m * TILE_M + i_lhsT.x])

        for n in nl.affine_range(N // TILE_N):

            # Load a whole column of tiles from rhs (with K * TILE_N numbers)
            rhs_tiles = nl.ndarray((K // TILE_K, nl.par_dim(TILE_K), TILE_N),
                                   dtype=rhs.dtype,
                                   buffer=nl.sbuf)
            for k in nl.affine_range(K // TILE_K):
                rhs_tiles[k, i_rhs.p, i_rhs.x] = nl.load(rhs[k * TILE_K + i_rhs.p,
                                                             n * TILE_N + i_rhs.x])

            # Allocate a tile in PSUM for the result
            res_psum = nl.zeros((TILE_M, TILE_N), nl.float32, buffer=nl.psum)
            for k in nl.affine_range(K // TILE_K):
                # Accumulate partial-sums into PSUM
                res_psum[...] += nl.matmul(lhsT_tiles[k, i_lhsT.p, i_lhsT.x],
                                           rhs_tiles[k, i_rhs.p, i_rhs.x],
                                           transpose_x=True)

            # Copy the result from PSUM back to SBUF, and cast to expected output data-type
            res_sb = nl.copy(res_psum, dtype=result.dtype)
            nl.store(result[m * TILE_M + i_res.p, n * TILE_N + i_res.x], value=res_sb)

    return result
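To get a feel for how much HBM traffic the hoisting removes, the sketch below counts input bytes for the loop structures of nki_matmul_tiled_ and nki_matmul_hoist_load_. It is an idealized count that assumes bfloat16 inputs, ignores the output stores, and assumes every nl.load in the listing actually reaches HBM; the function names are hypothetical.

```python
BYTES_BF16 = 2
TILE_M, TILE_K, TILE_N = 128, 128, 512


def input_bytes_tiled(M, K, N):
    """Both tiles are loaded in every (m, n, k) iteration."""
    iters = (M // TILE_M) * (N // TILE_N) * (K // TILE_K)
    return iters * (TILE_K * TILE_M + TILE_K * TILE_N) * BYTES_BF16


def input_bytes_hoisted(M, K, N):
    """lhsT column is loaded once per m; rhs column is loaded once per (m, n)."""
    lhsT_elems = (M // TILE_M) * K * TILE_M
    rhs_elems = (M // TILE_M) * (N // TILE_N) * K * TILE_N
    return (lhsT_elems + rhs_elems) * BYTES_BF16


M, K, N = 4096, 8192, 8192
MIB = 2**20
tiled_mib = input_bytes_tiled(M, K, N) // MIB
hoisted_mib = input_bytes_hoisted(M, K, N) // MIB
print(tiled_mib, hoisted_mib)  # 5120 4160
```

Under this count the saving comes entirely from lhsT, which is no longer re-read for every n; the rhs traffic is unchanged, which suggests why a further reordering of the loops is worth considering.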
Optimization 2: Reuse More Load Through Blocking#
While hoisting the load out of the innermost loop eliminates some redundant loads, we can push this further by reordering the computation and the associated memory accesses. The technique we are going to use is called blocking. Blocking explicitly improves temporal locality and reduces memory accesses. It is similar in spirit to the tiling step we did earlier.
Note that we reserve the word “tile” for defining the granularity of computation and “tiling” for the previous optimization technique that maps the high-level computation onto multiple matrix multiplication instructions executed on the TensorE. TensorE processes a specific “tile size” in a single instruction, leveraging the inherent parallelism in matrix multiplication.
Here, we do blocking by grouping the work associated with a set of tiles together at another loop nest level. Blocking effectively interleaves a set of compute instructions and loading (DMA) instructions. This optimization does not bring additional parallelism in computation, but rather improves the arithmetic intensity. This shifts a memory-bound matrix multiplication implementation to a compute-bound one, in order to fully leverage the compute capabilities of TensorE.
Fig. 80 below visualizes the memory pattern after blocking both free dimensions.
@nki.jit
def nki_matmul_block_free_dimension_(lhsT, rhs):
    """NKI kernel to compute a matrix multiplication operation while blocking the
    free dimensions of the LHS and RHS to improve memory access pattern.

    Args:
        lhsT: an input tensor of shape [K,M], where both K and M are multiples of
          128. It is the left-hand-side argument of the matrix multiplication,
          delivered transposed for optimal performance.
        rhs: an input tensor of shape [K,N], where K is a multiple of 128, and N
          is a multiple of 512. It is the right-hand-side argument of the matrix
          multiplication.
    Returns:
        result: the resulting output tensor of shape [M,N]
    """

    K, M = lhsT.shape
    K_, N = rhs.shape
    assert K == K_, "lhsT and rhs must have the same contraction dimension"
    result = nl.ndarray((M, N), dtype=lhsT.dtype, buffer=nl.shared_hbm)

    TILE_M = nl.tile_size.gemm_stationary_fmax  # 128
    TILE_K = nl.tile_size.pmax  # 128
    TILE_N = nl.tile_size.gemm_moving_fmax  # 512

    # Define the indices (shape) of the tiles
    i_lhsT = nl.mgrid[0:TILE_K, 0:TILE_M]
    i_rhs = nl.mgrid[0:TILE_K, 0:TILE_N]
    i_res = nl.mgrid[0:TILE_M, 0:TILE_N]

    # Configuring the blocking size for the free dimensions
    TILES_IN_BLOCK_M = 2
    TILES_IN_BLOCK_N = 2

    BLOCK_M = TILE_M * TILES_IN_BLOCK_M  # 256
    BLOCK_N = TILE_N * TILES_IN_BLOCK_N  # 1024

    # the size has to be multiple of block size
    assert M % BLOCK_M == 0
    assert N % BLOCK_N == 0

    # Loop over blocks over the M dimension
    for m in nl.affine_range(M // BLOCK_M):
        # Load TILES_IN_BLOCK_M columns tiles from lhsT
        lhsT_tiles = nl.ndarray(
            (TILES_IN_BLOCK_M, K // TILE_K, nl.par_dim(TILE_K), TILE_M),
            dtype=lhsT.dtype,
            buffer=nl.sbuf)
        for bm in nl.affine_range(TILES_IN_BLOCK_M):
            for k in nl.affine_range(K // TILE_K):
                lhsT_tiles[bm, k, i_lhsT.p, i_lhsT.x] = nl.load(
                    lhsT[k * TILE_K + i_lhsT.p,
                         (m * TILES_IN_BLOCK_M + bm) * TILE_M + i_lhsT.x])

        for n in nl.affine_range(N // BLOCK_N):
            # Load TILES_IN_BLOCK_N columns from rhs
            rhs_tiles = nl.ndarray(
                (TILES_IN_BLOCK_N, K // TILE_K, nl.par_dim(TILE_K), TILE_N),
                dtype=rhs.dtype,
                buffer=nl.sbuf)
            for bn in nl.affine_range(TILES_IN_BLOCK_N):
                for k in nl.affine_range(K // TILE_K):
                    rhs_tiles[bn, k, i_rhs.p, i_rhs.x] = nl.load(
                        rhs[k * TILE_K + i_rhs.p,
                            (n * TILES_IN_BLOCK_N + bn) * TILE_N + i_rhs.x])

            for bm in nl.affine_range(TILES_IN_BLOCK_M):
                for bn in nl.affine_range(TILES_IN_BLOCK_N):
                    # Allocate a tensor in PSUM
                    res_psum = nl.zeros((TILE_M, TILE_N), nl.float32, buffer=nl.psum)
                    for k in nl.affine_range(K // TILE_K):
                        # Accumulate partial-sums into PSUM
                        res_psum += nl.matmul(lhsT_tiles[bm, k, i_lhsT.p, i_lhsT.x],
                                              rhs_tiles[bn, k, i_rhs.p, i_rhs.x],
                                              transpose_x=True)

                    # Copy the result from PSUM back to SBUF, and cast to expected output data-type
                    res_sb = nl.copy(res_psum, dtype=result.dtype)
                    nl.store(result[(m * TILES_IN_BLOCK_M + bm) * TILE_M + i_res.p,
                                    (n * TILES_IN_BLOCK_N + bn) * TILE_N + i_res.x],
                             value=res_sb)

    return result
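The effect of blocking on arithmetic intensity can be estimated with the same kind of idealized count as before. The sketch below assumes bfloat16 inputs, ignores output stores, and models the loop structure of nki_matmul_block_free_dimension_: lhsT is read once in total, while all of rhs is re-read for every block of M. The function names are hypothetical, and tiles_in_block_m=1 corresponds to the hoisted kernel without blocking.

```python
BYTES_BF16 = 2
TILE_M = 128


def blocked_input_bytes(M, K, N, tiles_in_block_m):
    """Idealized HBM input traffic when the M dimension is blocked."""
    block_m = TILE_M * tiles_in_block_m
    lhsT_elems = K * M                      # every lhsT element is loaded once
    rhs_elems = (M // block_m) * K * N      # rhs is re-read once per M block
    return (lhsT_elems + rhs_elems) * BYTES_BF16


def arithmetic_intensity(M, K, N, tiles_in_block_m):
    flops = 2 * M * K * N                   # multiply + add per MAC
    return flops / blocked_input_bytes(M, K, N, tiles_in_block_m)


M, K, N = 4096, 8192, 8192
ai_unblocked = arithmetic_intensity(M, K, N, tiles_in_block_m=1)
ai_blocked = arithmetic_intensity(M, K, N, tiles_in_block_m=2)
print(round(ai_unblocked, 1), round(ai_blocked, 1))  # 126.0 248.2
```

Doubling the number of M tiles per block roughly halves the rhs traffic in this model. Measured utilization also depends on factors this count ignores, such as DMA efficiency and the output stores.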
Optimization 3: Further Blocking and DMA Efficiency Optimization#
Next, let’s also consider blocking the contraction dimension.
Without blocking the contraction dimension, each block of computation directly produces the final result of its output block, since the input blocks in both lhsT and rhs cover the entire contraction dimension. After contraction dimension blocking, the accumulation is split into separate groups. We can accumulate the partial sum from each computation block into an SBUF tensor to form the final result. A small amount of HBM traffic might also be introduced if the partial sums cannot be kept in SBUF before being consumed. On the bright side, we can increase the block size for the free dimensions, which continues to improve the arithmetic intensity.
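The two-level accumulation can be illustrated on a single output element. In the plain-Python sketch below, the inner variable stands in for the per-block accumulator and the outer variable for the longer-lived tensor that collects partial sums across contraction blocks; the function name blocked_dot and the data are hypothetical.

```python
def blocked_dot(a, b, block_k):
    """Accumulate a dot product one contraction block at a time."""
    assert len(a) == len(b) and len(a) % block_k == 0
    total = 0.0                      # stands in for the SBUF accumulator
    for k0 in range(0, len(a), block_k):
        partial = 0.0                # stands in for the per-block PSUM tile
        for k in range(k0, k0 + block_k):
            partial += a[k] * b[k]
        total += partial             # fold the block's partial sum into the result
    return total


a = [float(i) for i in range(8)]
b = [float(i % 3) for i in range(8)]
full = sum(x * y for x, y in zip(a, b))

print(full, blocked_dot(a, b, 4), blocked_dot(a, b, 2))  # 26.0 26.0 26.0
```

The block size changes only how the sum is grouped, not its value for this exactly representable data. With real floating-point inputs, different groupings can differ by rounding, which is one more reason to compare against the reference with a tolerance.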
One final step we can do with NKI is to optimize the layout of the loaded tiles to improve DMA efficiency. This is done through arranging the order of dimensions in nl.ndarray and marking the partition dimension.
By putting all these optimizations together, we can use NKI to implement optimized matrix multiplication for different sizes. Note that different sizes of input matrices require different optimization plans. The following code optimizes for large matrix multiplication where the free dimensions of both input matrices are multiples of 2048 and, with the default blocking parameters, the contraction dimension is a multiple of 1024.
With the blocking configuration in the code (16 tiles or 2048 numbers in the M dimension; 2 tiles or 1024 numbers in the N dimension; and 8 tiles or 1024 numbers in the K dimension), this computation has an arithmetic intensity of 683 Flops/Byte (2048*1024*1024/(2048*1024 + 1024*1024)). This is well above the threshold of 222.
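The 683 Flops/Byte figure follows from the block sizes alone. A short check, assuming bfloat16 inputs (2 bytes per element) and 2 Flops per multiply-accumulate, so that the two factors of 2 cancel as in the formula above:

```python
TILE_M, TILE_K, TILE_N = 128, 128, 512
BLOCK_M = TILE_M * 16   # 2048
BLOCK_N = TILE_N * 2    # 1024
BLOCK_K = TILE_K * 8    # 1024

flops = 2 * BLOCK_M * BLOCK_N * BLOCK_K                       # per block of work
hbm_bytes = 2 * (BLOCK_K * BLOCK_M + BLOCK_K * BLOCK_N)       # lhsT block + rhs block
intensity = flops / hbm_bytes

print(round(intensity))  # 683
```

This per-block figure counts one lhsT block and one rhs block per unit of work; it is meant as a sketch of the estimate in the text rather than a measurement.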
At the same time, this blocking configuration keeps all the tensors within the SBUF limit as much as possible. With all matrices in BF16 data type, lhsT_tiles requires 4 MB and rhs_tiles requires 2 MB of SBUF memory. result_tiles requires 4 * NUM_BLOCK_M MB of SBUF memory, where NUM_BLOCK_M is M // 2048. Thus, as long as M <= 8192, the required SBUF memory is under the 24 MB budget (4 + 2 + 4 * (8192 // 2048) == 22 MB). When the M dimension becomes bigger, spilling and reloading of result_tiles will happen, but because this happens relatively infrequently, the computation can still be efficient.
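The SBUF budget can be recomputed from the tensor shapes used in the kernel. The sketch below assumes 2-byte BF16 elements, treats 1 MB as 2**20 bytes, and takes the 24 MB budget from the text; the function name sbuf_usage_mb is hypothetical.

```python
MB = 2**20
BYTES_BF16 = 2
TILE_M, TILE_K, TILE_N = 128, 128, 512
SBUF_BUDGET_MB = 24  # budget quoted in the text


def sbuf_usage_mb(M, tiles_m=16, tiles_n=2, tiles_k=8):
    """SBUF footprint (in MB) of lhsT_tiles, rhs_tiles and result_tiles."""
    block_m, block_n = TILE_M * tiles_m, TILE_N * tiles_n
    lhsT_tiles = tiles_k * TILE_K * block_m * BYTES_BF16
    rhs_tiles = tiles_k * TILE_K * block_n * BYTES_BF16
    result_tiles = (M // block_m) * tiles_m * tiles_n * TILE_M * TILE_N * BYTES_BF16
    return lhsT_tiles / MB, rhs_tiles / MB, result_tiles / MB


for M in (8192, 10240):
    parts = sbuf_usage_mb(M)
    print(M, parts, sum(parts), sum(parts) <= SBUF_BUDGET_MB)
# 8192 (4.0, 2.0, 16.0) 22.0 True
# 10240 (4.0, 2.0, 20.0) 26.0 False
```

The result tensor is the only term that grows with M, so it is the one that eventually exceeds the budget and triggers spilling.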
Since the K blocking loop is hand-optimized for our ideal data locality, we do not actually want the compiler to rewrite this loop during its vectorization and other loop-level optimization passes. To communicate this, we use nl.sequential_range() to construct the K blocking loop.
import neuronxcc.nki.isa as nisa  # provides nc_matmul


@nki.jit
def nki_matmul_fully_optimized_(
    lhsT,
    rhs,
    # Meta-parameters
    TILES_IN_BLOCK_M=16,
    TILES_IN_BLOCK_N=2,
    TILES_IN_BLOCK_K=8,
):
    """NKI kernel to compute a large matrix multiplication efficiently by
    blocking all dimensions and doing layout optimization.

    Args:
        lhsT: an input tensor of shape [K,M], where K is a multiple of 128 *
          TILES_IN_BLOCK_K and M is a multiple of 128 * TILES_IN_BLOCK_M. It is the
          left-hand-side argument of the matrix multiplication, delivered transposed
          for optimal performance.
        rhs: an input tensor of shape [K,N], where K is a multiple of 128 *
          TILES_IN_BLOCK_K and N is a multiple of 512 * TILES_IN_BLOCK_N. It is
          the right-hand-side argument of the matrix multiplication.
        TILES_IN_BLOCK_*: meta parameters to control blocking dimensions
    Returns:
        result: the resulting output tensor of shape [M,N]
    """

    K, M = lhsT.shape
    K_, N = rhs.shape
    assert K == K_, "lhsT and rhs must have the same contraction dimension"
    result = nl.ndarray((M, N), dtype=lhsT.dtype, buffer=nl.shared_hbm)

    TILE_M = nl.tile_size.gemm_stationary_fmax  # 128
    TILE_K = nl.tile_size.pmax  # 128
    TILE_N = nl.tile_size.gemm_moving_fmax  # 512

    BLOCK_M = TILE_M * TILES_IN_BLOCK_M
    BLOCK_N = TILE_N * TILES_IN_BLOCK_N
    BLOCK_K = TILE_K * TILES_IN_BLOCK_K

    # the size has to be multiple of block size
    assert M % BLOCK_M == 0
    assert N % BLOCK_N == 0
    assert K % BLOCK_K == 0

    NUM_BLOCK_M = M // BLOCK_M
    NUM_BLOCK_N = N // BLOCK_N
    NUM_BLOCK_K = K // BLOCK_K

    # Blocking N dimension (the RHS free dimension)
    for n in nl.affine_range(NUM_BLOCK_N):
        result_tiles = nl.zeros((NUM_BLOCK_M, TILES_IN_BLOCK_M, TILES_IN_BLOCK_N,
                                 nl.par_dim(TILE_M), TILE_N),
                                dtype=lhsT.dtype,
                                buffer=nl.sbuf)

        # Blocking K dimension (the contraction dimension)
        # Use `sequential_range` because we do not want the compiler to change this loop by,
        # for example, vectorizing it
        for k in nl.sequential_range(NUM_BLOCK_K):
            # Loading tiles from rhs
            # setting the load tile to `TILE_K x BLOCK_N` to optimize DMA performance
            i_rhs = nl.mgrid[0:TILE_K, 0:BLOCK_N]
            rhs_tiles = nl.ndarray((TILES_IN_BLOCK_K, nl.par_dim(TILE_K), BLOCK_N),
                                   dtype=rhs.dtype,
                                   buffer=nl.sbuf)

            for bk_r in nl.affine_range(TILES_IN_BLOCK_K):
                rhs_tiles[bk_r, i_rhs.p, i_rhs.x] = nl.load(
                    rhs[(TILES_IN_BLOCK_K * k + bk_r) * TILE_K + i_rhs.p,
                        BLOCK_N * n + i_rhs.x])

            # Blocking M dimension (the LHS free dimension)
            for m in nl.affine_range(NUM_BLOCK_M):
                # Loading tiles from lhsT
                i_lhsT = nl.mgrid[0:TILE_K, 0:BLOCK_M]
                lhsT_tiles = nl.ndarray((TILES_IN_BLOCK_K, nl.par_dim(TILE_K), BLOCK_M),
                                        dtype=lhsT.dtype,
                                        buffer=nl.sbuf)
                for bk_l in nl.affine_range(TILES_IN_BLOCK_K):
                    lhsT_tiles[bk_l, i_lhsT.p, i_lhsT.x] = nl.load(
                        lhsT[(TILES_IN_BLOCK_K * k + bk_l) * TILE_K + i_lhsT.p,
                             BLOCK_M * m + i_lhsT.x])

                # Do matmul with all tiles in the blocks
                i_lhsT_mm = nl.mgrid[0:TILE_K, 0:TILE_M]
                i_rhs_mm = nl.mgrid[0:TILE_K, 0:TILE_N]
                i_res_mm = nl.mgrid[0:TILE_M, 0:TILE_N]
                for bn in nl.affine_range(TILES_IN_BLOCK_N):
                    for bm in nl.affine_range(TILES_IN_BLOCK_M):
                        res_tile = nl.zeros((TILE_M, TILE_N), dtype=nl.float32, buffer=nl.psum)

                        for bk in nl.affine_range(TILES_IN_BLOCK_K):
                            res_tile[...] += nisa.nc_matmul(
                                lhsT_tiles[bk, i_lhsT_mm.p, bm * TILE_M + i_lhsT_mm.x],
                                rhs_tiles[bk, i_rhs_mm.p, bn * TILE_N + i_rhs_mm.x])

                        # Accumulate on corresponding SBUF tile
                        result_tiles[m, bm, bn, i_res_mm.p,
                                     i_res_mm.x] += res_tile[i_res_mm.p, i_res_mm.x]

        # Copying the result from SBUF to HBM
        # (output tiles have TILE_M rows; TILE_M and TILE_K are both 128 here)
        for m in nl.affine_range(NUM_BLOCK_M):
            for bm in nl.affine_range(TILES_IN_BLOCK_M):
                i_res = nl.mgrid[0:TILE_M, 0:TILE_N]
                i_res_packed = nl.mgrid[0:TILE_M, 0:BLOCK_N]
                result_packed = nl.ndarray((TILE_M, BLOCK_N),
                                           dtype=result_tiles.dtype,
                                           buffer=nl.sbuf)

                # coalesce result tiles for better DMA performance
                for bn in nl.affine_range(TILES_IN_BLOCK_N):
                    result_packed[i_res.p,
                                  bn * TILE_N + i_res.x] = nl.copy(result_tiles[m, bm, bn,
                                                                                i_res.p,
                                                                                i_res.x])
                nl.store(result[(TILES_IN_BLOCK_M * m + bm) * TILE_M + i_res_packed.p,
                                BLOCK_N * n + i_res_packed.x],
                         value=result_packed[i_res_packed.p, i_res_packed.x])

    return result
Testing Correctness and Benchmarking#
To test the correctness of the kernels, we compare their results against torch.matmul using torch.allclose.
# Test the large workload with tiled kernels
lhs = torch.rand((4096, 1024), dtype=torch.bfloat16, device=device)
rhs = torch.rand((1024, 2048), dtype=torch.bfloat16, device=device)

# Run torch reference
output_torch = torch.matmul(lhs, rhs).to(device=cpu)

def check_match(nki_func):
    output = nki_func(lhs.T, rhs)
    output_nki = output.to(device=cpu)
    if torch.allclose(output_torch, output_nki, atol=1e-4, rtol=1e-2):
        print("NKI and Torch match")
    else:
        print("NKI and Torch differ")

print("Checking correctness of nki_matmul_tiled")
check_match(nki_matmul_tiled_)

print("Checking correctness of nki_matmul_hoist_load")
check_match(nki_matmul_hoist_load_)

print("Checking correctness of nki_matmul_block_free_dimension")
check_match(nki_matmul_block_free_dimension_)

print("Checking correctness of nki_matmul_fully_optimized")
check_match(nki_matmul_fully_optimized_)
Output from the test:
Checking correctness of nki_matmul_tiled
NKI and Torch match
Checking correctness of nki_matmul_hoist_load
NKI and Torch match
Checking correctness of nki_matmul_block_free_dimension
NKI and Torch match
Checking correctness of nki_matmul_fully_optimized
NKI and Torch match
To test the performance of each kernel, we can use NKI’s benchmark capability to measure the four different kernels on a [4096,8192] @ [8192,8192] matrix multiplication.
import neuronxcc.nki.typing as nt

if __name__ == "__main__":
    # Benchmarking with large matrices to show the differences more clearly
    lhsT = nt.tensor[[8192, 4096], nl.bfloat16]
    rhs = nt.tensor[[8192, 8192], nl.bfloat16]

    def benchmark_nki(nki_func):
        bench_func = nki.benchmark(warmup=5, iters=10)(nki_func)
        bench_func(lhsT, rhs)
        latency_res = bench_func.benchmark_result.nc_latency
        p99 = latency_res.get_latency_percentile(99)
        print("Latency: {:.2f} ms (P99)".format(p99 / 1000.0))

    print("Benchmarking nki_matmul_tiled")
    benchmark_nki(nki_matmul_tiled_)

    print("Benchmarking nki_matmul_hoist_load")
    benchmark_nki(nki_matmul_hoist_load_)

    print("Benchmarking nki_matmul_block_free_dimension")
    benchmark_nki(nki_matmul_block_free_dimension_)

    print("Benchmarking nki_matmul_fully_optimized")
    benchmark_nki(nki_matmul_fully_optimized_)
Kernels | Latency (ms) | Hardware FLOPs Utilization (HFU, %)
---|---|---
Original Tiled | 51.80 | 10.98
Optimization 1 | 42.96 | 13.24
Optimization 2 | 22.07 | 26.51
Optimization 3 | 6.97 | 85.24
As shown in the table above, with all the optimizations, the matrix multiplication kernel is about 7.4x faster than the original tiled version (6.97 ms versus 51.80 ms). We also profile the four kernel implementations for HFU (hardware FLOPs utilization). With all the optimizations, the final version reaches an HFU of 85.24%.
The performance numbers here are specific to input matrix sizes ([4096,8192] @ [8192,8192]), data types (BF16), and server instance (Trn1.32xlarge).
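The speedups implied by the latency table can be recomputed directly. The sketch below uses only the latencies reported above and the benchmark shape; the achieved throughput assumes 2 Flops per multiply-accumulate and is derived arithmetic, not an additional measurement. The variable names are hypothetical.

```python
latency_ms = {
    "Original Tiled": 51.80,
    "Optimization 1": 42.96,
    "Optimization 2": 22.07,
    "Optimization 3": 6.97,
}
M, K, N = 4096, 8192, 8192
total_flops = 2 * M * K * N  # multiply + add per MAC

baseline = latency_ms["Original Tiled"]
speedup = {name: baseline / ms for name, ms in latency_ms.items()}
tflops = {name: total_flops / (ms * 1e-3) / 1e12 for name, ms in latency_ms.items()}

for name in latency_ms:
    print(f"{name}: {speedup[name]:.2f}x, {tflops[name]:.1f} TFLOP/s")
# Original Tiled: 1.00x, 10.6 TFLOP/s
# Optimization 1: 1.21x, 12.8 TFLOP/s
# Optimization 2: 2.35x, 24.9 TFLOP/s
# Optimization 3: 7.43x, 78.9 TFLOP/s
```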
Download All Source Code#
Click the links to download source code of the kernels and the testing code discussed in this tutorial.
All matrix multiplication NKI kernels:
matrix_multiplication_nki_kernels.py
PyTorch implementation:
matrix_multiplication_torch.py
You can also view the source code in the GitHub repository nki_samples
Example usage of the scripts:#
Run benchmarking of different NKI kernels:
python3 matrix_multiplication_nki_kernels.py
Run PyTorch implementation to validate the NKI results against the PyTorch implementation:
python3 matrix_multiplication_torch.py