Checkpoint Conversion
The NxD Training (NxDT) library converts checkpoints between HuggingFace-style, Megatron-style, and NxDT formats. This tutorial walks through the supported conversion scenarios and shows how to run each one.
Supported Model Architectures
The checkpoint conversion functionality supports converting the following model styles to and from NxDT checkpoints:
HuggingFace (HF) style models
Megatron style models
Both GQA models (e.g., Llama-3) and non-GQA models (e.g., Llama-2) are supported.
Conversion Scenarios and Usage
The tool supports the conversion scenarios below. Internally, it uses NeuronxDistributed (NxD) to read and write the checkpoints.
Run the following commands from the /examples/checkpoint_conversion_scripts/ directory:
Note
Important: You must set the --hw_backend argument correctly for your hardware:
Set --hw_backend trn1 for Trainium (Trn1) hardware.
Set --hw_backend trn2 for Trainium2 (Trn2) hardware.
All example commands in this tutorial use trn1. If you're using Trn2, replace trn1 with trn2 in every command.
Ensure that the model configuration file (config.json) is present; it is required for every checkpoint conversion. We suggest using the model-specific JSON files provided in the examples. If no config.json is available for your model, you need to create one.
If your HF or custom checkpoint is split across multiple .bin, .pt, or .pth files, merge them into a single file before conversion.
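A minimal sketch of the merge step, assuming each file holds a flat PyTorch state dict (parameter name to tensor) and that no parameter name appears in more than one file. The function names, glob pattern, and output filename are hypothetical; adjust them to your checkpoint. The merge logic is kept separate from file I/O so it can be checked without loading real weights.

```python
import glob
from typing import Dict, Iterable, Mapping


def merge_state_dicts(shards: Iterable[Mapping]) -> Dict:
    """Combine several state-dict shards into one dict, refusing duplicate keys."""
    merged: Dict = {}
    for shard in shards:
        for name, tensor in shard.items():
            if name in merged:
                raise ValueError(f"parameter {name!r} appears in more than one shard")
            merged[name] = tensor
    return merged


def merge_checkpoint_files(pattern: str, out_path: str) -> None:
    """Load every file matching `pattern` and save the merged state dict to `out_path`."""
    import torch  # imported here so merge_state_dicts() is usable without PyTorch

    paths = sorted(glob.glob(pattern))
    if not paths:
        raise FileNotFoundError(f"no checkpoint files match {pattern!r}")
    # map_location="cpu" keeps all weights in host memory during the merge.
    shards = (torch.load(p, map_location="cpu") for p in paths)
    torch.save(merge_state_dicts(shards), out_path)
```

For example, merge_checkpoint_files("/path/to/ckpt/pytorch_model-*.bin", "/path/to/ckpt/pytorch_model.bin") would produce a single file to pass as --input_dir. The merged dict holds the whole model, so the host needs enough RAM for one full copy of the weights.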
For non-GQA models (e.g., Llama-2), set the --qkv_linear argument to False.
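Because --qkv_linear follows from whether the model uses GQA, you can read its value, and --n_layers, off config.json. This sketch assumes the HF Llama config field names num_hidden_layers, num_attention_heads, and num_key_value_heads; flags_from_config is a hypothetical helper, not part of checkpoint_converter.py.

```python
import json
from typing import Dict, Union


def flags_from_config(config: Dict) -> Dict[str, Union[int, bool]]:
    """Suggest --n_layers and --qkv_linear values from an HF-style config dict."""
    n_heads = config["num_attention_heads"]
    # Configs that predate GQA may omit num_key_value_heads: one KV head per query head.
    n_kv_heads = config.get("num_key_value_heads", n_heads)
    return {
        "n_layers": config["num_hidden_layers"],
        # Fewer KV heads than query heads means grouped-query attention (GQA).
        "qkv_linear": n_kv_heads < n_heads,
    }


def flags_from_config_file(path: str) -> Dict[str, Union[int, bool]]:
    """Read a config.json file and derive the flags from it."""
    with open(path) as f:
        return flags_from_config(json.load(f))
```

For Llama-3-8B (32 layers, 32 attention heads, 8 KV heads) this yields n_layers=32 and qkv_linear=True; for Llama-2-7B (32 layers, 32 attention heads, 32 KV heads) it yields qkv_linear=False, matching the examples on this page.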
HF style model:
HF to NxDT checkpoint:
Command:
python3 checkpoint_converter.py --model_style hf --hw_backend trn1 --input_dir /home/ubuntu/pretrained_llama_3_8B_hf/pytorch_model.bin --output_dir /home/ubuntu/converted_hf_style_hf_to_nxdt_tp8pp4/ --save_xser True --config /home/ubuntu/pretrained_llama_3_8B_hf/config.json --tp_size 8 --pp_size 4 --n_layers 32 --kv_size_multiplier 1 --qkv_linear True --convert_from_full_state
This converts an HF-style checkpoint to an NxDT checkpoint.
NxDT to HF checkpoint:
Command:
python3 checkpoint_converter.py --model_style hf --hw_backend trn1 --input_dir ~/examples/nemo_experiments/hf_llama3_8B_SFT/2024-07-19_23-07-40/checkpoints/hf_llama3_8B--step=5-consumed_samples=160.0.ckpt/model --output_dir ~/converted_hf_style_nxdt_to_hf_tp8pp4/ --load_xser True --config ~/config.json --tp_size 8 --pp_size 4 --kv_size_multiplier 1 --qkv_linear True --convert_to_full_state
This converts an NxDT checkpoint to an HF-style checkpoint.
Megatron style model (non-GQA models: e.g., Llama-2, and GQA models: e.g., Llama-3):
HF to NxDT Megatron checkpoint:
Command:
python3 checkpoint_converter.py --model_style megatron --hw_backend trn1 --input_dir ~/megatron-tp8pp4-nxdt-to-hf4/checkpoint.pt --output_dir ~/meg_nxdt_hf3_nxdt3 --config ~/llama_gqa/config.json --save_xser True --tp_size 8 --pp_size 4 --n_layers 32 --kv_size_multiplier 1 --qkv_linear True --convert_from_full_state
This converts an HF-style checkpoint to an NxDT Megatron-style checkpoint.
NxDT Megatron checkpoint to HF:
Command:
python3 checkpoint_converter.py --model_style megatron --hw_backend trn1 --input_dir ~/examples/nemo_experiments/megatron_llama/2024-07-23_21-07-30/checkpoints/megatron_llama--step=5-consumed_samples=5120.0.ckpt/model --output_dir ~/megatron-tp8pp4-nxdt-to-hf4 --load_xser True --config ~/llama_gqa/config.json --tp_size 8 --pp_size 4 --kv_size_multiplier 1 --qkv_linear True --convert_to_full_state
This converts an NxDT Megatron-style checkpoint to an HF-style checkpoint. Because this example uses a GQA-based model, --qkv_linear is set to True.
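The four commands above differ mainly in direction: converting to NxDT passes --convert_from_full_state with --save_xser True, while converting back passes --convert_to_full_state with --load_xser True. The hypothetical helper below assembles either command from the flags documented on this page, so --hw_backend or the parallelism degrees can be changed in one place. It assumes the sharded checkpoint was saved with xser; drop --load_xser if it was not.

```python
import shlex
from typing import List, Optional


def build_convert_command(
    direction: str,        # "to_nxdt" (full -> sharded) or "to_hf" (sharded -> full)
    model_style: str,      # "hf" or "megatron"
    hw_backend: str,       # "trn1" or "trn2"
    input_dir: str,
    output_dir: str,
    config: str,
    tp_size: int,
    pp_size: int,
    qkv_linear: bool,      # True for GQA models, False for non-GQA models
    kv_size_multiplier: int = 1,
    n_layers: Optional[int] = None,
) -> List[str]:
    """Assemble a checkpoint_converter.py argument list mirroring this page's examples."""
    if model_style not in ("hf", "megatron"):
        raise ValueError(f"unknown model_style: {model_style!r}")
    if hw_backend not in ("trn1", "trn2"):
        raise ValueError(f"unknown hw_backend: {hw_backend!r}")
    cmd = [
        "python3", "checkpoint_converter.py",
        "--model_style", model_style,
        "--hw_backend", hw_backend,
        "--input_dir", input_dir,
        "--output_dir", output_dir,
        "--config", config,
        "--tp_size", str(tp_size),
        "--pp_size", str(pp_size),
        "--kv_size_multiplier", str(kv_size_multiplier),
        "--qkv_linear", str(qkv_linear),  # str(True) -> "True", as in the examples
    ]
    if n_layers is not None:
        cmd += ["--n_layers", str(n_layers)]
    if direction == "to_nxdt":
        # Full checkpoint in, sharded checkpoint out: save the shards with xser.
        cmd += ["--save_xser", "True", "--convert_from_full_state"]
    elif direction == "to_hf":
        # Sharded checkpoint in: assumes it was saved with xser, so load it with xser.
        cmd += ["--load_xser", "True", "--convert_to_full_state"]
    else:
        raise ValueError(f"unknown direction: {direction!r}")
    return cmd


if __name__ == "__main__":
    # Paths from the HF to NxDT example above; flag order differs from that command.
    print(shlex.join(build_convert_command(
        "to_nxdt", "hf", "trn1",
        "/home/ubuntu/pretrained_llama_3_8B_hf/pytorch_model.bin",
        "/home/ubuntu/converted_hf_style_hf_to_nxdt_tp8pp4/",
        "/home/ubuntu/pretrained_llama_3_8B_hf/config.json",
        tp_size=8, pp_size=4, qkv_linear=True, n_layers=32,
    )))
```

Returning a list rather than a string lets you pass it straight to subprocess.run without shell quoting; shlex.join is used only to print it.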
Key Arguments
The checkpoint_converter.py script supports the following key arguments:
--model_style: model style, either hf (HuggingFace, the default) or megatron
--hw_backend: (required) hardware backend, either trn1 or trn2
--input_dir: (required unless --hf_model_name is used) path to the input checkpoint; the examples pass the single checkpoint file when converting from full state, and the sharded checkpoint's model directory when converting to full state
--hf_model_name: (optional) HuggingFace model identifier, for converting a model hosted on HuggingFace directly
--output_dir: (required) directory in which to save the converted checkpoint
--save_xser: save the checkpoint with torch_xla serialization (xser)
--load_xser: load the checkpoint with torch_xla serialization (xser)
--convert_from_full_state: convert a full model checkpoint to a sharded model checkpoint
--convert_to_full_state: convert a sharded model checkpoint to a full model checkpoint
--config: path to the model configuration file (create the JSON file if it is not present)
--tp_size: tensor parallelism degree
--pp_size: pipeline parallelism degree
--n_layers: number of layers in the model
--kv_size_multiplier: key-value size multiplier
--qkv_linear: boolean specifying whether the model uses GQA (True, e.g., Llama-3) or not (False, e.g., Llama-2)
--fuse_qkv: boolean specifying whether QKV is fused in GQA models
We recommend enabling xser for significantly faster save and load times. Note that if the checkpoint is saved with xser, it can only be loaded with xser, and vice versa.
Conversion Example
Assume you have a pretrained HF-style Llama-3-8B checkpoint laid out as follows (input_dir: /hf/checkpoint/pytorch_model.bin):
$ ls -l /hf/checkpoint
-rw-r--r-- 1 user group 123 Aug 27 2024 pytorch_model.bin
Convert the HF-style checkpoint to an NxDT checkpoint on a single instance:
python3 checkpoint_converter.py --model_style hf --hw_backend trn1 --input_dir /hf/checkpoint/pytorch_model.bin --output_dir /nxdt/checkpoint --save_xser True --config /path/to/config.json --tp_size 8 --pp_size 4 --n_layers 32 --kv_size_multiplier 1 --qkv_linear True --convert_from_full_state
This command creates an NxDT checkpoint in the output directory /nxdt/checkpoint, sharded with tp=8 and pp=4, so the model directory holds one file per (tp_rank, pp_rank) pair:
$ ls -l /nxdt/checkpoint/model
-rw-r--r-- 1 user group 123 Aug 27 2024 dp_rank_00_tp_rank_00_pp_rank_00.pt
-rw-r--r-- 1 user group 456 Aug 27 2024 dp_rank_00_tp_rank_01_pp_rank_00.pt
...........................................................................
-rw-r--r-- 1 user group 789 Aug 27 2024 dp_rank_00_tp_rank_07_pp_rank_02.pt
-rw-r--r-- 1 user group 122 Aug 27 2024 dp_rank_00_tp_rank_07_pp_rank_03.pt
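To sanity-check a conversion, you can compare the output directory against the naming pattern in the listing above: one file per (tp_rank, pp_rank) pair, so tp=8 and pp=4 give 8 × 4 = 32 files. This sketch assumes dp_rank is always 00, as in the listing, and that the shards sit directly under model/; the helper names are hypothetical, and any extra entries in the directory are ignored.

```python
import os
from typing import List


def expected_shard_names(tp_size: int, pp_size: int, dp_rank: int = 0) -> List[str]:
    """File names following the dp_rank_XX_tp_rank_YY_pp_rank_ZZ.pt pattern shown above."""
    return [
        f"dp_rank_{dp_rank:02d}_tp_rank_{tp:02d}_pp_rank_{pp:02d}.pt"
        for tp in range(tp_size)
        for pp in range(pp_size)
    ]


def missing_shards(model_dir: str, tp_size: int, pp_size: int) -> List[str]:
    """Return the expected shard names that are absent from model_dir."""
    present = set(os.listdir(model_dir))
    return [name for name in expected_shard_names(tp_size, pp_size) if name not in present]
```

An empty result from missing_shards("/nxdt/checkpoint/model", 8, 4) means every expected shard file exists; it does not validate the file contents.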
Direct HuggingFace Model Conversion
The --hf_model_name argument lets you convert a model hosted on HuggingFace directly, without manually downloading or merging its checkpoint files. Pass the HuggingFace model identifier to --hf_model_name; the script downloads the model and converts it to the NxDT format.
Note
When using --hf_model_name, do not specify --input_dir; the two arguments are mutually exclusive. If both are specified, the script uses --input_dir and ignores --hf_model_name.
You will be prompted to enter your HuggingFace API token. If you don’t have one, you can create it at https://huggingface.co/settings/tokens.
Ensure you have sufficient disk space to download and process the model.
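A rough way to judge "sufficient" disk space is to multiply the parameter count by the bytes per parameter. For a 7-billion-parameter model stored in 16-bit precision, one copy of the weights is about 7e9 × 2 = 14e9 bytes (about 13 GiB). The sketch below assumes 16-bit weights and budgets three copies (download, converted output, headroom); both numbers and the helper names are assumptions, not requirements of the script.

```python
import shutil


def estimate_checkpoint_bytes(n_params: int, bytes_per_param: int = 2) -> int:
    """Rough size of one full copy of the weights (2 bytes/param for fp16/bf16)."""
    return n_params * bytes_per_param


def has_room(path: str, n_params: int, bytes_per_param: int = 2, copies: int = 3) -> bool:
    """Check free space at `path` against an assumed budget of `copies` weight copies."""
    needed = copies * estimate_checkpoint_bytes(n_params, bytes_per_param)
    return shutil.disk_usage(path).free >= needed
```

For example, has_room("/path/to/output", 7_000_000_000) checks for roughly 42e9 bytes of free space before you start the conversion.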
Example usage:
python3 checkpoint_converter.py --model_style hf --hw_backend trn1 --hf_model_name "meta-llama/Llama-2-7b-hf" --output_dir /path/to/output --save_xser True --config /path/to/config.json --tp_size 8 --pp_size 4 --n_layers 32 --kv_size_multiplier 1 --qkv_linear False --convert_from_full_state
This command downloads the Llama-2-7b model from HuggingFace, converts it to NxDT format, and saves it in the specified output directory.
Troubleshooting
If you encounter an error related to HuggingFace authentication, ensure you’re using a valid API token.
If the download fails, check your internet connection and verify that the model identifier is correct.