Distillation#
NeMo 2.0 offers an easy-to-enable Knowledge Distillation (KD) training setup. The following section explains how to use it.
KD involves using information from an existing trained model to train a second (usually smaller, faster) model, thereby “distilling” knowledge from one to the other.
Distillation has two primary benefits: faster convergence and higher final accuracy than traditional training.
In NeMo, distillation is enabled by the NVIDIA TensorRT Model Optimizer (ModelOpt), a library for optimizing deep-learning models for inference on GPUs.
Logits-Distillation Process#
The logits-distillation process involves these steps:
Loads Checkpoints: Loads both the student and teacher model checkpoints. They must both support the same parallelism strategy.
Replaces Loss Function: Replaces the standard loss function with the KL divergence between the student's and teacher's output logits (see the sketch after this list).
Trains Models: Runs forward passes on both models, but executes the backward pass only on the student model.
Saves Checkpoints: Saves only the student model checkpoint, allowing it to be used later in the same manner as before.
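To make the replaced objective concrete, below is a minimal PyTorch sketch of a logits-distillation loss: the KL divergence between the teacher's and the student's output distributions. It is for illustration only and is not the ModelOpt implementation; the optional temperature term is a common convention in logits distillation, not something taken from NeMo.
import torch
import torch.nn.functional as F

def logits_distillation_loss(student_logits: torch.Tensor,
                             teacher_logits: torch.Tensor,
                             temperature: float = 1.0) -> torch.Tensor:
    """KL(teacher || student), computed over the vocabulary dimension.

    Both tensors are shaped [batch, sequence_length, vocab_size].
    """
    # The teacher only provides targets; no gradients flow through it.
    teacher_probs = F.softmax(teacher_logits.detach() / temperature, dim=-1)
    student_log_probs = F.log_softmax(student_logits / temperature, dim=-1)
    # "batchmean" reduction with the conventional T^2 scaling.
    return F.kl_div(student_log_probs, teacher_probs, reduction="batchmean") * temperature ** 2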
Limitations#
Only GPT-based NeMo 2.0 checkpoints are supported.
Only logit-pair distillation is enabled for now.
Use NeMo-Run Recipes#
Note
Prerequisite: Before proceeding, follow the example in Quickstart with NeMo-Run to familiarize yourself with NeMo-Run.
import nemo_run as run
from nemo.collections import llm
from nemo.collections.llm.modelopt.recipes import distillation_recipe
recipe = distillation_recipe(
    student_model_path="path/to/student/nemo2-checkpoint/",
    teacher_model_path="path/to/teacher/nemo2-checkpoint/",
    dir="./distill_logs",  # Path to store logs and checkpoints
    name="distill_testrun",
    num_nodes=1,
    num_gpus_per_node=8,
)
# Override the configuration with desired components:
recipe.data = run.Config(llm.PreTrainingDataModule, ...)
recipe.trainer.strategy.tensor_model_parallel_size = 8
...
run.run(recipe)
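As an illustration, a data-module override might look like the sketch below. The dataset path, sequence length, and batch sizes are placeholders, and the parameter names are assumptions; check the PreTrainingDataModule signature in your NeMo version before using them.
# Hypothetical override sketch; values and parameter names are placeholders.
recipe.data = run.Config(
    llm.PreTrainingDataModule,
    paths=["path/to/tokenized/data"],  # prefix of a pre-tokenized dataset
    seq_length=8192,
    micro_batch_size=1,
    global_batch_size=4,
)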
Use with torchrun or Slurm#
Alternatively, you can run a traditional script with a finer degree of customization.
STUDENT_CKPT="path/to/student/nemo2-checkpoint/"
TEACHER_CKPT="path/to/teacher/nemo2-checkpoint/"
DATA_PATHS="1.0 path/to/tokenized/data"
SEQUENCE_LEN=8192
MICRO_BATCHSIZE=1
GLOBAL_BATCHSIZE=4
STEPS=100
TP=8
CP=1
PP=1
DP=1
NUM_NODES=1
DEVICES_PER_NODE=8
NAME="distill_testrun"
LOG_DIR="./distill_logs/"
launch_cmd="torchrun --nproc_per_node=$(($TP * $CP * $PP * $DP))"
${launch_cmd} scripts/llm/gpt_train.py \
    --name ${NAME} \
    --model_path ${STUDENT_CKPT} \
    --teacher_path ${TEACHER_CKPT} \
    --tp_size ${TP} \
    --cp_size ${CP} \
    --pp_size ${PP} \
    --devices ${DEVICES_PER_NODE} \
    --num_nodes ${NUM_NODES} \
    --log_dir ${LOG_DIR} \
    --max_steps ${STEPS} \
    --gbs ${GLOBAL_BATCHSIZE} \
    --mbs ${MICRO_BATCHSIZE} \
    --data_paths ${DATA_PATHS} \
    --seq_length ${SEQUENCE_LEN}
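On a Slurm cluster, the same script can be wrapped in a batch job instead of being launched with torchrun. The sketch below is illustrative only; the resource flags, time limit, and the assumption that the Slurm-managed distributed environment is picked up automatically depend on your cluster and NeMo version, so treat it as a starting point rather than a verified recipe.
#!/bin/bash
#SBATCH --job-name=distill_testrun
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=8   # one task per GPU (assumption)
#SBATCH --gpus-per-node=8
#SBATCH --time=04:00:00

# Define the same variables as above (STUDENT_CKPT, TEACHER_CKPT, DATA_PATHS, ...),
# then launch one process per GPU with srun.
srun python scripts/llm/gpt_train.py \
    --name ${NAME} \
    --model_path ${STUDENT_CKPT} \
    --teacher_path ${TEACHER_CKPT} \
    --tp_size ${TP} \
    --cp_size ${CP} \
    --pp_size ${PP} \
    --devices ${DEVICES_PER_NODE} \
    --num_nodes ${NUM_NODES} \
    --log_dir ${LOG_DIR} \
    --max_steps ${STEPS} \
    --gbs ${GLOBAL_BATCHSIZE} \
    --mbs ${MICRO_BATCHSIZE} \
    --data_paths ${DATA_PATHS} \
    --seq_length ${SEQUENCE_LEN}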
SFT Distillation#
To perform SFT Knowledge Distillation on a chat dataset, follow the script above and add the --tokenizer and --use-chat-data arguments as well. See scripts/llm/gpt_train.py for full argument descriptions.
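For example, the launch might look like the abbreviated sketch below. The tokenizer value is a placeholder, the argument list is shortened for readability, and the exact flag names should be confirmed against scripts/llm/gpt_train.py.
# Abbreviated sketch: keep every argument from the torchrun example above and
# append the two SFT-related flags. The tokenizer value is a placeholder.
${launch_cmd} scripts/llm/gpt_train.py \
    --model_path ${STUDENT_CKPT} \
    --teacher_path ${TEACHER_CKPT} \
    --data_paths ${DATA_PATHS} \
    --seq_length ${SEQUENCE_LEN} \
    --tokenizer "path/to/tokenizer" \
    --use-chat-data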