This document is relevant for: Inf2, Trn1, Trn2

nki.isa.affine_select

nki.isa.affine_select(pred, on_true_tile, on_false_value, *, mask=None, dtype=None, **kwargs)

Select elements from either an input tile on_true_tile or a scalar value on_false_value according to a boolean predicate tile, using the GpSimd Engine. The predicate tile is computed on the fly in the engine by evaluating the affine expression given in pred element by element.

pred must meet the following requirements:

  • It must not depend on any runtime variables that cannot be resolved at compile time.

  • It cannot be a combination of multiple masks joined with logical operators such as & and |.

For a complex predicate that doesn’t meet the above requirements, consider using nl.where.

The input tile on_true_tile, the calculated boolean predicate tile expressed by pred, and the returned output tile of this instruction must have the same shape. If the predicate value of a given position is True, the corresponding output element will take the element from on_true_tile in the same position. If the predicate value of a given position is False, the corresponding output element will take the value of on_false_value.
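The selection rule can be modeled in plain Python to make the semantics concrete. This is a hypothetical reference sketch, not part of NKI: `affine_select_reference` is an illustrative name, the predicate is passed as a Python callable of the partition index `p` and free index `f` (standing in for an NKI affine expression), and tiles are nested lists.

```python
from typing import Callable, List

Tile = List[List[float]]


def affine_select_reference(
    pred: Callable[[int, int], bool],
    on_true_tile: Tile,
    on_false_value: float,
) -> Tile:
    """Hypothetical reference model: evaluate pred at every (p, f) index of
    on_true_tile and pick the tile element (True) or the scalar (False)."""
    out: Tile = []
    for p, row in enumerate(on_true_tile):
        # The predicate is evaluated over the input's own index space, so the
        # predicate tile and the output always share the input's shape.
        out.append([x if pred(p, f) else on_false_value
                    for f, x in enumerate(row)])
    return out


tile = [[1.0, 2.0, 3.0],
        [4.0, 5.0, 6.0],
        [7.0, 8.0, 9.0]]
# Keep elements on or below the diagonal (f <= p); replace the rest with -1.0.
result = affine_select_reference(lambda p, f: f <= p, tile, -1.0)
print(result)
```

Because the output is built by walking on_true_tile, the shape requirement stated above holds by construction in this model.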

A common use case for affine_select is to apply a causal mask on the attention scores for transformer decoder models.
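To illustrate the causal-mask use case, the following pure-Python sketch masks a small score matrix and then applies a softmax. It assumes rows are query positions and columns are key positions (a layout choice, not something this page specifies), and it reuses -9984.0 as the fill value only because the example on this page does. The helper names are hypothetical.

```python
import math
from typing import List


def causal_mask(scores: List[List[float]], fill: float = -9984.0) -> List[List[float]]:
    # Assumed layout: row i = query position, column j = key position.
    # Keep scores where j <= i; "future" keys (j > i) are replaced by fill.
    return [[s if j <= i else fill for j, s in enumerate(row)]
            for i, row in enumerate(scores)]


def softmax(row: List[float]) -> List[float]:
    m = max(row)  # subtract the row max for numerical stability
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


scores = [[0.5, 1.0, 2.0],
          [0.1, 0.2, 0.3],
          [1.0, 1.0, 1.0]]
masked = causal_mask(scores)
probs = [softmax(r) for r in masked]
print(probs)
```

The large negative fill makes `math.exp` underflow to 0.0, so masked positions receive zero attention weight.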

This instruction supports any floating-point data type, as well as 8-bit and 16-bit integer types, for both the input tile and the output tile (see Supported Data Types for more information). The output tile data type is specified using the dtype field. If dtype is not specified, the output data type is the same as the data type of on_true_tile. However, the data type of on_false_value must be float32, regardless of the input/output tile data types.

Estimated instruction cost:

GPSIMD_START + N GpSimd Engine cycles, where N is the number of elements per partition in on_true_tile and GPSIMD_START is the instruction startup overhead on GpSimdE, roughly 150 engine cycles.
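The cost formula can be turned into a quick back-of-the-envelope estimator. This is a rough sketch: `estimate_affine_select_cycles` is a hypothetical helper, 150 is the approximate startup overhead quoted above, and it assumes the first tile dimension is the partition dimension, so N is the product of the remaining (free) dimensions.

```python
from typing import Sequence

GPSIMD_START_CYCLES = 150  # approximate startup overhead quoted above


def estimate_affine_select_cycles(tile_shape: Sequence[int],
                                  startup: int = GPSIMD_START_CYCLES) -> int:
    """Rough estimate: startup + N, where N is elements per partition."""
    # Assumption: tile_shape[0] is the partition dimension; every other
    # dimension contributes to the number of elements per partition.
    n = 1
    for d in tile_shape[1:]:
        n *= d
    return startup + n


print(estimate_affine_select_cycles((128, 128)))
print(estimate_affine_select_cycles((128, 512)))
```

Note that, per the formula, the partition count does not enter the estimate; only the per-partition element count does.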

Parameters:
  • pred – an affine expression that defines the boolean predicate

  • on_true_tile – an input tile for selection with a True predicate value

  • on_false_value – a scalar value for selection with a False predicate value

  • mask – (optional) a compile-time constant predicate that controls whether/how this instruction is executed (see NKI API Masking for details)

  • dtype – (optional) data type to cast the output to (see Supported Data Types for more information); if not specified, it defaults to the data type of the input tiles, or to whichever input type has the highest precision (see NKI Type Promotion for more information).

Returns:

an output tile with values selected from either on_true_tile or on_false_value according to the following equation: output[x] = (pred[x] > 0) ? on_true_tile[x] : on_false_value

Example:

import neuronxcc.nki.isa as nisa
import neuronxcc.nki.language as nl

##################################################################
# Example 1: Take tile a of shape [128, 128] and replace its
# upper triangle (including the diagonal) with -9984.0
##################################################################
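# a_tensor and b_tensor are assumed to be [128, 128] HBM tensors
# available in the enclosing kernel.
ix, iy = nl.mgrid[0:128, 0:128]
a = nl.load(a_tensor[ix, iy])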

# Keep a where iy < ix (strictly below the diagonal); elsewhere use -9984.0
b = nisa.affine_select(pred=(iy < ix), on_true_tile=a[ix, iy], on_false_value=-9984.0)

nl.store(b_tensor[ix, iy], b)
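Because the predicate iy < ix is strict, the diagonal is replaced along with the upper triangle. The pure-Python sketch below mirrors Example 1 on a small 4 x 4 stand-in tile to show exactly which positions change; it does not call NKI, and `SIZE` is an illustrative choice.

```python
SIZE = 4  # small stand-in for the 128 x 128 tile in Example 1

# Fill the tile with distinct values so kept elements are easy to spot.
a = [[float(SIZE * i + j) for j in range(SIZE)] for i in range(SIZE)]

# Mirror pred=(iy < ix): ix is the row (partition) index, iy the column index.
b = [[a[ix][iy] if iy < ix else -9984.0 for iy in range(SIZE)]
     for ix in range(SIZE)]

for row in b:
    print(row)
```

Only the strictly lower triangle of a survives; row 0 is replaced entirely because no column index is less than 0.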