Migrating from Transformers NeuronX to NeuronX Distributed (NxD) Inference#
For customers who are currently using Transformers NeuronX, this migration guide explains the steps involved in migrating to the NxD Inference library.
How is writing modeling code different in NxD Inference?#
In Transformers NeuronX, you write modeling code in HLO format using a Python HLO interface. In NeuronX Distributed Inference, you write modeling code in native PyTorch and Python, and the library converts it to HLO for you. This change makes it easier to develop models to run on Neuron, because you can start from existing PyTorch or Python modeling code.
How can I migrate from Transformers NeuronX to use NxD Inference with vLLM?#
The Transformers NeuronX library currently supports the Llama and Mistral model architectures with vLLM integration. If you are using one of these models, such as Llama 3.1, Llama 3, Llama 2, or Mistral-7b-V2, you can migrate to the NxD Inference library with vLLM using the following steps:
Update Environment Variable to force vLLM to use NxD Inference#
Because vLLM currently supports both the Transformers NeuronX and NeuronX Distributed Inference libraries for the Llama and Mistral models, you need to set the following environment variable in your inference scripts to force vLLM to use NxD Inference.
import os

# Force vLLM framework to use neuronx-distributed-inference
os.environ['VLLM_NEURON_FRAMEWORK'] = "neuronx-distributed-inference"
Compiling and loading the model#
Transformers NeuronX uses Neuron Persistent Cache to load a pre-compiled model so that there is no additional compilation delay when loading the model in vLLM. NxD Inference does not currently support Neuron Persistent Cache, but it provides the following way to load a pre-compiled model.
For production use cases where you want to avoid compiling the model in NxD Inference on first load, set the NEURON_COMPILED_ARTIFACTS environment variable to the directory that contains your pre-compiled artifacts. If the artifacts are not present in the specified directory, compilation of the model is triggered as a fallback, and the resulting artifacts are stored by default in neuron-compiled-artifacts/{unique_hash}/.
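For example, you can point NxD Inference at a directory of pre-compiled artifacts by setting the environment variable in your inference script before loading the model. The path below is illustrative; replace it with your own artifacts directory.
import os

# Illustrative path; replace with the directory that contains your pre-compiled artifacts.
os.environ['NEURON_COMPILED_ARTIFACTS'] = "/home/ubuntu/neuron-compiled-artifacts/open_llama_3b"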
Features currently not supported in NxD Inference through vLLM#
NxD Inference does not yet support the following features that TNx supports in its vLLM integration:
Multi-Node Inference
Persistent Cache
concurrency > 1 support for speculation
You can use exactly the same set of parameters to test vLLM with the NxD Inference library as you specify with Transformers NeuronX, with the exception of override_neuron_config. Both Transformers NeuronX and NxD Inference allow overriding available NeuronConfig values, but not all NeuronConfig parameters that are available with Transformers NeuronX are still valid or applicable in NxD Inference. Refer to the Neuron Config Migration section to migrate your override_neuron_config params from Transformers NeuronX to NxD Inference.
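The following sketch shows an offline vLLM invocation that forces the NxD Inference backend and passes an NxD Inference-style override_neuron_config. The model name, parallelism settings, and the sequence_parallel_enabled override are illustrative values only; the exact set of valid override keys depends on the NxD Inference version you use.
import os

# Force vLLM framework to use neuronx-distributed-inference
os.environ['VLLM_NEURON_FRAMEWORK'] = "neuronx-distributed-inference"

from vllm import LLM, SamplingParams

# Illustrative configuration; adjust the model, parallelism, and overrides for your setup.
llm = LLM(
    model="meta-llama/Meta-Llama-3.1-8B",
    max_num_seqs=4,
    max_model_len=2048,
    tensor_parallel_size=32,
    device="neuron",
    override_neuron_config={"sequence_parallel_enabled": True},
)
outputs = llm.generate(["Hello, my name is"], SamplingParams(top_k=1))
print(outputs[0].outputs[0].text)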
Serialization support#
In both libraries, you serialize the compiled model, so you can use the model in subsequent runs without compiling it each time.
In Transformers NeuronX, the save function does not serialize sharded weights by default, and you can enable this functionality with the sharded_weights flag. In NeuronX Distributed Inference, the compile function serializes sharded weights by default, and you can disable this functionality with the save_sharded_checkpoint flag in NeuronConfig.
Transformers NeuronX#
from transformers_neuronx.config import NeuronConfig
from transformers_neuronx.llama.model import LlamaForSampling

# Create and compile the Neuron model
neuron_config = NeuronConfig()
model_neuron = LlamaForSampling.from_pretrained(
'openlm-research/open_llama_3b',
batch_size=1,
tp_degree=8,
n_positions=128,
neuron_config=neuron_config
)
# Compile the model.
model_neuron.to_neuron()
# Save the presharded weights and compiled artifacts to a directory.
model_neuron.save('llama-artifacts', sharded_weights=True)
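In a later run, you can reload the saved artifacts instead of recompiling. A minimal sketch of the Transformers NeuronX load path, assuming the same configuration as the save example above:
# Recreate the model with the same configuration used when saving.
model_neuron = LlamaForSampling.from_pretrained(
    'openlm-research/open_llama_3b',
    batch_size=1,
    tp_degree=8,
    n_positions=128,
    neuron_config=NeuronConfig()
)
# Load the presharded weights and compiled artifacts saved earlier.
model_neuron.load('llama-artifacts')
model_neuron.to_neuron()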
NeuronX Distributed Inference#
# Imports below reflect the NxD Inference module layout; exact paths may vary across releases.
from neuronx_distributed_inference.models.config import NeuronConfig
from neuronx_distributed_inference.models.llama.modeling_llama import LlamaInferenceConfig, NeuronLlamaForCausalLM
from neuronx_distributed_inference.utils.hf_adapter import load_pretrained_config

model_path = "/home/ubuntu/models/open_llama_3b"
compiled_model_path = "/home/ubuntu/compiled_models/open_llama_3b"
neuron_config = NeuronConfig(
batch_size=1,
tp_degree=8,
seq_len=128
)
config = LlamaInferenceConfig(
neuron_config,
load_config=load_pretrained_config(model_path)
)
model = NeuronLlamaForCausalLM(model_path, config)
# Compile the model, shard the weights, and save to the given path.
model.compile(compiled_model_path)
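Similarly, in a later run with NxD Inference, you can load the serialized model from the compiled path rather than compiling again. A minimal sketch, assuming the same config objects as above (API details can vary slightly across NxD Inference releases):
# Load the previously compiled and sharded model from the saved path.
model = NeuronLlamaForCausalLM(model_path, config)
model.load(compiled_model_path)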
Models supported in Transformers NeuronX and NxD Inference model hubs#
The following table lists the models currently supported by TNx and their status in the NxD Inference library. For a more detailed list of models currently supported in NeuronX Distributed Inference, refer to the NxD Inference model hub guide.
| Model | Supported in TNx | vLLM Support (TNx) | Supported in NxDI | vLLM Support (NxD Inference) |
|---|---|---|---|---|
| BLOOM | Yes | No | No | No |
| GPT2 | Yes | No | No | No |
| GPT-J | Yes | No | No | No |
| GPT-Neox | Yes | No | No | No |
| Llama 2 | Yes | Yes | Yes | Yes |
| Llama 3 | Yes | Yes | Yes | Yes |
| Llama 3.1 | Yes | Yes | Yes | Yes |
| Llama 3.2 (1B and 3B) | Yes | Yes | Yes | Yes |
| Llama 3.2 90B | No | No | Yes | Yes |
| Mistral-V2 | Yes | Yes | Yes | Yes |
| Mixtral | Yes | No | Yes | Yes |
| DBRX | No | No | Yes | Yes |
Onboarding custom or private models with NxD Inference#
If you need support for a model that is not currently supported in NxD Inference, or if you have a private model that you previously implemented in Transformers NeuronX, you need to implement the model using the NxD Inference library. You can use the Onboarding models to run on NxD Inference guide.
Neuron Config Migration#
There are differences in the Neuron Config parameters between the Transformers NeuronX and NxD Inference libraries.
If you use TNx directly without vLLM, or if you use the override_neuron_config parameter in vLLM with TNx, you must update the config parameters according to the following table. An illustrative example of migrating a small set of override parameters follows the table.
| Transformers NeuronX parameter | NxD Inference parameter | Notes |
|---|---|---|
| sparse_attn | N/A | |
| quant.quant_dtype | quantization_dtype | To use quantization, set |
| quant.dequant_dtype | torch_dtype | NxD Inference uses the inference dtype as the dequant dtype. |
| quant.quantize_method | quantization_type | |
| quant.quantize_attn | N/A | |
| quant.no_quantize_list | N/A | |
| kv_cache_quant.quant_dtype | N/A | NxD Inference uses FP8 (torch.float8_e4m3fn) for KV cache quantization. To use KV cache quantization, set |
| kv_cache_quant.dequant_dtype | torch_dtype | NxD Inference uses the inference dtype as the dequant dtype. |
| kv_cache_quant.quantize_method | N/A | NxD Inference uses direct cast. |
| continuous_batching.max_num_seqs | max_batch_size | To use continuous batching, set |
| continuous_batching.max_model_len | seq_len | |
| continuous_batching.optimized_paged_attention | N/A | |
| continuous_batching.block_size | N/A | |
| continuous_batching.num_blocks | N/A | |
| attention_layout | N/A | NxD Inference uses BHSD layout. |
| collectives_layout | N/A | NxD Inference uses BHSD layout. |
| cache_layout | N/A | NxD Inference uses BHSD layout. |
| padding_side | padding_side | NxD Inference defaults to padding on the right side. |
| group_query_attention | N/A | |
| sequence_parallel_norm | sequence_parallel_enabled | |
| sequence_parallel_norm_threshold | N/A | |
| bf16_rms_norm | N/A | NxD Inference upcasts RMS norm inputs to fp32. |
| on_device_embedding | N/A | |
| on_device_generation | on_device_sampling_config | |
| on_device_generation.max_length | seq_len | NxD Inference uses the model’s sequence length. |
| on_device_generation.do_sample | on_device_sampling_config.do_sample | |
| on_device_generation.top_k | on_device_sampling_config.top_k | NxD Inference supports top_k through dynamic sampling. Pass the top_k values to the model inputs. |
| on_device_generation.top_p | N/A | NxD Inference supports top_p through dynamic sampling. Pass the top_p values to the model inputs. |
| on_device_generation.temperature | N/A | NxD Inference supports temperature through dynamic sampling. Pass the temperature values to the model inputs. |
| on_device_generation.top_p_min_tokens | N/A | NxD Inference defaults to a minimum of 1 token. |
| on_device_generation.global_top_k | on_device_sampling_config.global_topk | |
| on_device_generation.eos_token_id | N/A | NxD Inference sampling treats EOS like any other token. |
| on_device_generation.dynamic | on_device_sampling_config.dynamic | |
| on_device_generation.deterministic | on_device_sampling_config.deterministic | |
| on_device_generation.per_batch_line | N/A | |
| all_reduce_dtype | rpl_reduce_dtype | NxD Inference applies this dtype to only the all_reduce in attention’s |
| cast_logits_dtype | N/A | |
| fuse_qkv | fused_qkv | |
| qkv_tiling | N/A | |
| weight_tiling | N/A | |
| mlp_in_weight_tiling_permute_order | N/A | |
| mlp_out_weight_tiling_permute_order | N/A | |
| mlp_out_weight_transpose | N/A | |
| log_softmax_scores | N/A | |
| shard_over_sequence | flash_decoding_enabled | |
| duplicate_q_weight_sos | N/A | |
| output_all_logits | N/A | |
| fused_rmsnorm_qkv | qkv_kernel_enabled | |
| fused_rmsnorm_mlp | mlp_kernel_enabled | |
| attn_output_transposed | N/A | |
| compilation_worker_count | N/A | |
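As an illustration of this mapping, the following sketch shows a Transformers NeuronX-style override dictionary and its NxD Inference equivalent. The specific parameters (fuse_qkv and shard_over_sequence) are chosen only as examples from the table above; substitute the parameters that apply to your configuration.
# Transformers NeuronX override_neuron_config values (illustrative).
tnx_override_neuron_config = {
    "fuse_qkv": True,
    "shard_over_sequence": True,
}

# Equivalent NxD Inference override_neuron_config values, per the mapping table above.
nxdi_override_neuron_config = {
    "fused_qkv": True,
    "flash_decoding_enabled": True,
}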