This document is relevant for: Inf2, Trn1, Trn2
nki.isa.affine_select
- nki.isa.affine_select(pred, on_true_tile, on_false_value, *, mask=None, dtype=None, **kwargs)
Select elements between an input tile on_true_tile and a scalar value on_false_value according to a boolean predicate tile, using the GpSimd Engine. The predicate tile is calculated on the fly in the engine by evaluating, element by element, the affine expression given in pred.

pred must meet the following requirements:
- It must not depend on any runtime variables that can't be resolved at compile time.
- It can't combine multiple masks using logical operators such as & and |.

For a complex predicate that doesn't meet the above requirements, consider using nl.where.
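To make the element-by-element evaluation concrete, the following pure-Python sketch (it does not use NKI) models an affine predicate over a 2-D index grid with rows ix and columns iy, like the grid nl.mgrid produces in the example on this page. The helper affine_predicate and its coefficient form c0 + cx*ix + cy*iy are hypothetical and chosen only for illustration; the > 0 test mirrors the equation in the Returns section. It sketches the semantics under these assumptions and says nothing about how GpSimdE implements them.

```python
# Hypothetical pure-Python model of an affine predicate (not an NKI API).
# The predicate is a linear function of the tile indices plus a constant,
# evaluated independently at every (ix, iy) position:
#     pred[ix][iy] = (c0 + cx * ix + cy * iy) > 0
# The coefficients are plain ints, so the whole mask is known before run time.


def affine_predicate(rows, cols, c0, cx, cy):
    """Return a rows x cols nested list of bools for (c0 + cx*ix + cy*iy) > 0."""
    return [
        [(c0 + cx * ix + cy * iy) > 0 for iy in range(cols)]
        for ix in range(rows)
    ]


# For integer indices, iy < ix is equivalent to ix - iy > 0,
# i.e. c0 = 0, cx = 1, cy = -1: True strictly below the main diagonal.
causal = affine_predicate(4, 4, c0=0, cx=1, cy=-1)
for row in causal:
    print("".join("1" if p else "0" for p in row))
# 0000
# 1000
# 1100
# 1110
```

Because every coefficient is a constant, the full mask depends only on the index grid, which is the property the compile-time requirement above asks for.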
The input tile on_true_tile, the calculated boolean predicate tile expressed by pred, and the returned output tile of this instruction must have the same shape. If the predicate value at a given position is True, the corresponding output element takes the element from on_true_tile at the same position. If the predicate value at a given position is False, the corresponding output element takes the value of on_false_value.

A common use case for affine_select is to apply a causal mask on the attention scores for transformer decoder models.

This instruction allows any float or 8-bit/16-bit integer data types for both the input data tile and output tile (see Supported Data Types for more information). The output tile data type is specified using the dtype field. If dtype is not specified, the output data type will be the same as the input data type of on_true_tile. However, the data type of on_false_value must be float32, regardless of the input/output tile data types.

Estimated instruction cost:
GPSIMD_START + N GpSimd Engine cycles, where N is the number of elements per partition in on_true_tile and GPSIMD_START is the instruction startup overhead on GpSimdE, roughly 150 engine cycles.

- Parameters:
pred – an affine expression that defines the boolean predicate
on_true_tile – an input tile for selection with a True predicate value
on_false_value – a scalar value for selection with a False predicate value
mask – (optional) a compile-time constant predicate that controls whether/how this instruction is executed (see NKI API Masking for details)
dtype – (optional) data type to cast the output type to (see Supported Data Types for more information); if not specified, it defaults to the data type of the input tiles, or whichever input type has the highest precision (see NKI Type Promotion for more information).
- Returns:
an output tile with values selected from either on_true_tile or on_false_value according to the following equation:
output[x] = (pred[x] > 0) ? on_true_tile[x] : on_false_value
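The selection rule in the Returns equation can be checked on small inputs with a pure-Python reference model. The function affine_select_reference is hypothetical (it is not part of NKI): it takes pred as an already-evaluated tile of numbers or booleans, represents tiles as nested lists, and does not model dtype casting or the float32 requirement on on_false_value.

```python
# Hypothetical pure-Python reference model of the Returns equation
# (not an NKI API). Tiles are nested lists; Python floats stand in for
# float32 and no dtype casting is modeled.


def affine_select_reference(pred, on_true_tile, on_false_value):
    """Return output[x] = on_true_tile[x] if pred[x] > 0 else on_false_value."""
    # pred and on_true_tile must have the same shape.
    if len(pred) != len(on_true_tile) or any(
        len(p_row) != len(t_row) for p_row, t_row in zip(pred, on_true_tile)
    ):
        raise ValueError("pred and on_true_tile must have the same shape")
    return [
        [t if p > 0 else float(on_false_value) for p, t in zip(p_row, t_row)]
        for p_row, t_row in zip(pred, on_true_tile)
    ]


tile = [[1.0, 2.0, 3.0],
        [4.0, 5.0, 6.0],
        [7.0, 8.0, 9.0]]
# Integer-valued affine expression ix - iy: positive strictly below the diagonal.
pred = [[ix - iy for iy in range(3)] for ix in range(3)]
out = affine_select_reference(pred, tile, -9984.0)
print(out)
# [[-9984.0, -9984.0, -9984.0], [4.0, -9984.0, -9984.0], [7.0, 8.0, -9984.0]]
```

Positions where the expression is zero (the diagonal here) count as False, because the equation tests pred[x] > 0 rather than >= 0.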
Example:
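```python
import neuronxcc.nki.isa as nisa
import neuronxcc.nki.language as nl

##################################################################
# Example 1: Take tile a of shape [128, 128] and replace its
# upper triangle (including the main diagonal) with -9984.0;
##################################################################
# a_tensor (input) and b_tensor (output) are assumed to be [128, 128]
# HBM tensors available in the enclosing NKI kernel.
ix, iy = nl.mgrid[0:128, 0:128]
a = nl.load(a_tensor[ix, iy])

# Keep a where iy < ix (strictly below the diagonal); elsewhere write -9984.0.
b = nisa.affine_select(pred=(iy < ix), on_true_tile=a[ix, iy], on_false_value=-9984.0)

nl.store(b_tensor[ix, iy], b)
```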
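As a worked example of the estimated instruction cost given earlier on this page, the sketch below computes GPSIMD_START + N for a tile shape. It assumes GPSIMD_START is exactly 150 cycles (the text only says roughly 150) and treats the first tile dimension as the partition dimension, so N is the product of the remaining dimensions; estimated_affine_select_cycles is a hypothetical helper, not an NKI function.

```python
# Hypothetical helper applying the estimated instruction cost:
#     cycles ~= GPSIMD_START + N
# where N is the number of elements per partition. In NKI the first tile
# dimension is the partition dimension, so N is the product of the rest.

GPSIMD_START_APPROX = 150  # assumption: the documented "roughly 150 engine cycles"


def estimated_affine_select_cycles(tile_shape, start_overhead=GPSIMD_START_APPROX):
    """Rough GpSimd Engine cycle estimate for affine_select on tile_shape."""
    n_per_partition = 1
    for dim in tile_shape[1:]:  # skip the partition dimension
        n_per_partition *= dim
    return start_overhead + n_per_partition


# Example 1: a [128, 128] tile has N = 128 elements per partition.
print(estimated_affine_select_cycles((128, 128)))  # 278
```

For Example 1 this gives about 278 GpSimd Engine cycles; the startup overhead is a large share of the total for small tiles, so the estimate only reflects the formula stated on this page.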