Checkpoint Conversion

The NxD Training (NxDT) library can convert checkpoints between NxDT and other model styles. This tutorial covers the supported use cases and shows how to run each conversion.

Supported Model Architectures

The checkpoint conversion functionality supports conversion of the following model styles to/from NxDT checkpoints:

  1. HuggingFace (HF) style models

  2. Megatron style models

Both GQA models (e.g., Llama-3) and non-GQA models (e.g., Llama-2) are supported.

Conversion Scenarios and Usage

The tool supports the following conversion scenarios. It internally uses NeuronxDistributed (NxD) to convert to/from checkpoints. Run the following commands from the /examples/checkpoint_conversion_scripts/ directory:


  Important: You must set the --hw_backend argument to match your hardware. The sample commands below use trn1.

    • Set --hw_backend trn1 for Trainium (Trn1) hardware

    • Set --hw_backend trn2 for Trainium 2 (Trn2) hardware

All example commands in this tutorial use trn1. If you’re using Trn2, remember to replace trn1 with trn2 in every command.

  1. Ensure that the model configuration file (config.json) is present, because it is required for checkpoint conversion. We suggest using a model-specific JSON file, such as the provided examples. If it is not present, you will need to create it.

  2. If your HF or custom checkpoint is split across multiple .bin, .pt, or .pth files, merge them into a single file before conversion.
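For an HF-style sharded .bin checkpoint, the shards are listed in a pytorch_model.bin.index.json file whose "weight_map" maps each parameter name to the shard file that holds it. The following is a minimal sketch of the merge step, assuming that layout. The helper names merge_state_dicts and merge_hf_bin_checkpoint are hypothetical and not part of NxDT. The sketch loads every shard into CPU memory, so the host needs enough RAM for the full model, and it does not handle safetensors checkpoints.

```python
import json
import os
from typing import Dict, Iterable, Mapping


def merge_state_dicts(shards: Iterable[Mapping[str, object]]) -> Dict[str, object]:
    """Combine several partial state dicts into one, refusing duplicate keys."""
    merged: Dict[str, object] = {}
    for shard in shards:
        for name, tensor in shard.items():
            if name in merged:
                raise ValueError(f"parameter {name!r} appears in more than one shard")
            merged[name] = tensor
    return merged


def merge_hf_bin_checkpoint(checkpoint_dir: str, output_file: str) -> None:
    """Merge the shards listed in pytorch_model.bin.index.json into a single .bin file."""
    import torch  # imported here so merge_state_dicts stays usable without torch

    index_path = os.path.join(checkpoint_dir, "pytorch_model.bin.index.json")
    with open(index_path) as f:
        weight_map = json.load(f)["weight_map"]  # parameter name -> shard file name
    shard_files = sorted(set(weight_map.values()))
    shards = (
        torch.load(os.path.join(checkpoint_dir, name), map_location="cpu")
        for name in shard_files
    )
    torch.save(merge_state_dicts(shards), output_file)
```

For example, merge_hf_bin_checkpoint("/home/ubuntu/pretrained_llama_3_8B_hf", "/home/ubuntu/pretrained_llama_3_8B_hf/pytorch_model.bin") would produce a single file of the kind passed as --input_dir in the HF to NxDT example below. Rejecting duplicate keys catches shard sets that overlap instead of silently keeping whichever copy was loaded last.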

For conversion of non-GQA models (e.g., Llama-2), set the --qkv_linear argument to False.
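The --n_layers value and the --qkv_linear setting can usually be read from the model's config.json. The sketch below assumes an HF Llama-style config with the num_hidden_layers, num_attention_heads, and num_key_value_heads keys. It treats a model as GQA when it has fewer key-value heads than attention heads. flags_from_hf_config and flags_from_config_file are hypothetical helpers and not part of the conversion script.

```python
import json


def flags_from_hf_config(config: dict) -> dict:
    """Suggest --n_layers and --qkv_linear values from an HF Llama-style config."""
    n_heads = config["num_attention_heads"]
    # Configs without num_key_value_heads are treated as one KV head per query head.
    n_kv_heads = config.get("num_key_value_heads", n_heads)
    return {
        "--n_layers": config["num_hidden_layers"],
        # GQA (assumed here): several query heads share each key-value head.
        "--qkv_linear": n_kv_heads < n_heads,
    }


def flags_from_config_file(path: str) -> dict:
    """Read config.json from disk and derive the same suggestions."""
    with open(path) as f:
        return flags_from_hf_config(json.load(f))
```

With a Llama-3-8B config (32 layers, 32 attention heads, 8 key-value heads) this suggests --n_layers 32 and --qkv_linear True, which matches the commands below. A Llama-2-7B config (32 attention heads and 32 key-value heads) gives --qkv_linear False.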

  1. HF style model:

    1. HF to NxDT checkpoint:


      python3 --model_style hf --hw_backend trn1 --input_dir /home/ubuntu/pretrained_llama_3_8B_hf/pytorch_model.bin --output_dir /home/ubuntu/converted_hf_style_hf_to_nxdt_tp8pp4/ --save_xser True --config /home/ubuntu/pretrained_llama_3_8B_hf/config.json --tp_size 8 --pp_size 4 --n_layers 32 --kv_size_multiplier 1 --qkv_linear True --convert_from_full_state

    This converts an HF-style checkpoint to an NxDT checkpoint.

    2. NxDT to HF checkpoint:


    python3 --model_style hf --hw_backend trn1 --input_dir ~/examples/nemo_experiments/hf_llama3_8B_SFT/2024-07-19_23-07-40/checkpoints/hf_llama3_8B--step=5-consumed_samples=160.0.ckpt/model --output_dir ~/converted_hf_style_nxdt_to_hf_tp8pp4/ --load_xser True --config ~/config.json --tp_size 8 --pp_size 4 --kv_size_multiplier 1 --qkv_linear True --convert_to_full_state

    This converts an NxDT checkpoint to an HF-style checkpoint.

  2. Megatron style model (non-GQA models: e.g., Llama-2, and GQA models: e.g., Llama-3):

    1. HF to NxDT Megatron checkpoint:


    python3 --model_style megatron --hw_backend trn1 --input_dir ~/megatron-tp8pp4-nxdt-to-hf4/ --output_dir ~/meg_nxdt_hf3_nxdt3 --config ~/llama_gqa/config.json --save_xser True --tp_size 8 --pp_size 4 --n_layers 32 --kv_size_multiplier 1 --qkv_linear True --convert_from_full_state

    This converts an HF-style checkpoint to an NxDT Megatron-style checkpoint.

    2. NxDT Megatron checkpoint to HF:


    python3 --model_style megatron --hw_backend trn1 --input_dir ~/examples/nemo_experiments/megatron_llama/2024-07-23_21-07-30/checkpoints/megatron_llama--step=5-consumed_samples=5120.0.ckpt/model --output_dir ~/megatron-tp8pp4-nxdt-to-hf4 --load_xser True --config ~/llama_gqa/config.json --tp_size 8 --pp_size 4 --kv_size_multiplier 1 --qkv_linear True --convert_to_full_state

    This converts an NxDT Megatron-style checkpoint to an HF-style checkpoint. Because the model is GQA-based, --qkv_linear is set to True.
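The four commands above follow one pattern. Converting into NxDT uses --save_xser True and --convert_from_full_state (plus --n_layers). Converting out of NxDT uses --load_xser True and --convert_to_full_state. The Python sketch below assembles argument lists in that pattern. conversion_args and its "to_nxdt"/"from_nxdt" direction values are hypothetical names. The resulting list still needs python3 and the converter script from /examples/checkpoint_conversion_scripts/ prepended before it can be run.

```python
import shlex
from typing import List, Optional


def conversion_args(
    direction: str,
    model_style: str,
    hw_backend: str,
    input_dir: str,
    output_dir: str,
    config: str,
    tp_size: int,
    pp_size: int,
    n_layers: Optional[int] = None,
    kv_size_multiplier: int = 1,
    qkv_linear: bool = True,
) -> List[str]:
    """Build converter arguments following the pattern of the example commands."""
    if hw_backend not in ("trn1", "trn2"):
        raise ValueError("hw_backend must be 'trn1' or 'trn2'")
    if model_style not in ("hf", "megatron"):
        raise ValueError("model_style must be 'hf' or 'megatron'")
    args = [
        "--model_style", model_style,
        "--hw_backend", hw_backend,
        "--input_dir", input_dir,
        "--output_dir", output_dir,
        "--config", config,
        "--tp_size", str(tp_size),
        "--pp_size", str(pp_size),
        "--kv_size_multiplier", str(kv_size_multiplier),
        "--qkv_linear", str(qkv_linear),  # rendered as "True" or "False"
    ]
    if n_layers is not None:
        args += ["--n_layers", str(n_layers)]
    if direction == "to_nxdt":
        # Full checkpoint in, sharded NxDT checkpoint out, saved with xser.
        args += ["--save_xser", "True", "--convert_from_full_state"]
    elif direction == "from_nxdt":
        # Sharded NxDT checkpoint (saved with xser) in, full checkpoint out.
        args += ["--load_xser", "True", "--convert_to_full_state"]
    else:
        raise ValueError("direction must be 'to_nxdt' or 'from_nxdt'")
    return args


if __name__ == "__main__":
    # Same settings as the HF to NxDT example, but targeting Trn2.
    print(shlex.join(conversion_args(
        "to_nxdt", "hf", "trn2",
        "/home/ubuntu/pretrained_llama_3_8B_hf/pytorch_model.bin",
        "/home/ubuntu/converted_hf_style_hf_to_nxdt_tp8pp4/",
        "/home/ubuntu/pretrained_llama_3_8B_hf/config.json",
        tp_size=8, pp_size=4, n_layers=32,
    )))
```

Validating --hw_backend in one place helps avoid the trn1/trn2 mix-up described in the note above. Pairing the xser flag with the conversion direction keeps the save and load sides consistent.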

Key Arguments

The script supports the following key arguments:

  • --model_style: Specifies the model style, either hf (HuggingFace: default) or megatron

  • --hw_backend: (required) Specifies the hardware backend: trn1 or trn2

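  • --input_dir: (required unless --hf_model_name is used) path to the input checkpoint: a single checkpoint file for HF-style input (for example, pytorch_model.bin), or the checkpoint's model directory for NxDT input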

  • --hf_model_name: (optional) HuggingFace model identifier for directly converting models hosted on HuggingFace

  • --output_dir: (required) directory in which to save the converted checkpoint

  • --save_xser: Saves the checkpoint with torch_xla serialization

  • --load_xser: Loads the checkpoint with torch_xla serialization

  • --convert_from_full_state: Converts full model checkpoint to sharded model checkpoint

  • --convert_to_full_state: Converts sharded model checkpoint to full model checkpoint

  • --config: path to the model configuration file (create the JSON file if it is not present)

  • --tp_size: tensor parallelism degree

  • --pp_size: pipeline parallelism degree

  • --n_layers: number of layers in the model

  • --kv_size_multiplier: key-value size multiplier

  • --qkv_linear: boolean; set to True for GQA models (e.g., Llama-3) and False for non-GQA models (e.g., Llama-2)

  • --fuse_qkv: boolean to specify fused QKV in GQA models
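In NeuronxDistributed, the K and V projection weights of a GQA model can be replicated by kv_size_multiplier so that they can be sharded across tensor-parallel ranks. The sketch below assumes the rule is that num_key_value_heads times kv_size_multiplier must be divisible by tp_size, and it computes the smallest multiplier that satisfies that rule. Treat the result as a starting point and confirm it against the NxD documentation for your model. min_kv_size_multiplier is a hypothetical helper.

```python
from math import gcd


def min_kv_size_multiplier(num_kv_heads: int, tp_size: int) -> int:
    """Smallest m such that num_kv_heads * m is divisible by tp_size (assumed rule)."""
    if num_kv_heads <= 0 or tp_size <= 0:
        raise ValueError("num_kv_heads and tp_size must be positive")
    # num_kv_heads already covers gcd(num_kv_heads, tp_size) of tp_size's factors;
    # the multiplier has to supply the remaining tp_size // gcd factor.
    return tp_size // gcd(num_kv_heads, tp_size)
```

For Llama-3-8B (8 key-value heads) with --tp_size 8 this gives 1, matching the example commands. Under the same assumption, --tp_size 32 would need a multiplier of 4.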

We recommend enabling xser, because it makes saving and loading significantly faster. A checkpoint saved with xser can only be loaded with xser, and a checkpoint saved without xser can only be loaded without it.

Conversion Example

Assume you have a pre-trained HF-style Llama-3-8B model checkpoint similar to the following:

input_dir: /hf/checkpoint/pytorch_model.bin

$ ls /hf/checkpoint

-rw-r--r-- 1 user group 123 Aug 27 2024 pytorch_model.bin

Convert the HF-style checkpoint to an NxDT checkpoint on a single instance:

python3 --model_style hf --hw_backend trn1 --input_dir /hf/checkpoint/pytorch_model.bin --output_dir /nxdt/checkpoint --save_xser True --config /path/to/config.json --tp_size 8 --pp_size 4 --n_layers 32 --kv_size_multiplier 1 --qkv_linear True --convert_from_full_state

This command creates an NxDT checkpoint in the output directory (/nxdt/checkpoint), sharded with tp=8 and pp=4, similar to the following:

$ ls /nxdt/checkpoint/model

-rw-r--r-- 1 user group 123 Aug 27 2024
-rw-r--r-- 1 user group 456 Aug 27 2024
-rw-r--r-- 1 user group 789 Aug 27 2024
-rw-r--r-- 1 user group 122 Aug 27 2024
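With --tp_size 8 and --pp_size 4, the model is split across 8 × 4 = 32 model-parallel ranks. Each pipeline stage holds a contiguous block of decoder layers, and each block is further split across 8 tensor-parallel ranks. The sketch below shows which layers each pipeline stage would hold for the 32-layer Llama-3-8B, assuming an even split. It ignores where the embedding and output layers are placed and does not model the actual checkpoint file names. stage_layers is a hypothetical helper.

```python
from typing import Dict, List


def stage_layers(n_layers: int, pp_size: int) -> Dict[int, List[int]]:
    """Assign decoder layers to pipeline stages, assuming an even, contiguous split."""
    if pp_size <= 0 or n_layers % pp_size != 0:
        raise ValueError("this sketch assumes n_layers is divisible by pp_size")
    per_stage = n_layers // pp_size
    return {
        stage: list(range(stage * per_stage, (stage + 1) * per_stage))
        for stage in range(pp_size)
    }


if __name__ == "__main__":
    tp_size, pp_size, n_layers = 8, 4, 32
    print(f"{tp_size * pp_size} model-parallel ranks")
    for stage, layers in stage_layers(n_layers, pp_size).items():
        print(f"pp stage {stage}: layers {layers[0]}-{layers[-1]}, sharded over {tp_size} tp ranks")
```

Under this assumption, each stage holds 8 of the 32 layers. This is why --n_layers and --pp_size need to agree when converting into NxDT.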

Direct HuggingFace Model Conversion

The --hf_model_name argument lets you convert checkpoint files hosted on HuggingFace directly, without manually downloading or merging them.

To use this feature, pass the HuggingFace model identifier with --hf_model_name. The script then downloads the model and converts it directly to the NxDT format.


  1. When using --hf_model_name, do not specify --input_dir. The two arguments are not meant to be used together.

  2. If both --hf_model_name and --input_dir are specified, the script prioritizes --input_dir and ignores --hf_model_name.

  3. You will be prompted to enter your HuggingFace API token. If you don’t have one, you can create it at

  4. Ensure you have sufficient disk space to download and process the model.
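The precedence rule in item 2 above can be written as a small function. This is a sketch of the documented behavior, not the script's actual code. resolve_checkpoint_source is a hypothetical name.

```python
from typing import Optional, Tuple


def resolve_checkpoint_source(
    input_dir: Optional[str], hf_model_name: Optional[str]
) -> Tuple[str, str]:
    """Return ("local", path) or ("hub", model_id), preferring --input_dir as documented."""
    if input_dir:
        # --input_dir wins; --hf_model_name is ignored when both are given.
        return ("local", input_dir)
    if hf_model_name:
        return ("hub", hf_model_name)
    raise ValueError("specify either --input_dir or --hf_model_name")
```

Because --input_dir silently wins, a leftover --input_dir in a reused command means the model is not downloaded from HuggingFace at all.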

Example usage:

python3 --model_style hf --hw_backend trn1 --hf_model_name "meta-llama/Llama-2-7b-hf" --output_dir /path/to/output --save_xser True --config /path/to/config.json --tp_size 8 --pp_size 4 --n_layers 32 --kv_size_multiplier 1 --qkv_linear False --convert_from_full_state

This command downloads the Llama-2-7b model from HuggingFace, converts it to NxDT format, and saves it in the specified output directory.


  • If you encounter an error related to HuggingFace authentication, ensure you’re using a valid API token.

  • If the download fails, check your internet connection and verify that the model identifier is correct.