Distillation#

NeMo 2.0 offers an easy-to-enable Knowledge Distillation (KD) training setup. The following section explains how to use it.

KD involves using information from an existing trained model to train a second (usually smaller, faster) model, thereby “distilling” knowledge from one to the other.

Distillation has two primary benefits: faster convergence and higher final accuracy than traditional training.

In NeMo, distillation is enabled by the NVIDIA TensorRT Model Optimizer (ModelOpt) library, which optimizes deep-learning models for inference on GPUs.

Logits-Distillation Process#

The logits-distillation process involves these steps:

  1. Loads Checkpoints: Loads both the student and teacher model checkpoints. Both checkpoints must support the same parallelism strategy.

  2. Replaces Loss Function: Replaces the standard loss function with the KL divergence between the student and teacher output logits (see the sketch after this list).

  3. Trains Models: Runs forward passes on both models, but executes the backward pass only on the student model.

  4. Saves Checkpoints: Saves only the student model checkpoint, allowing it to be used later in the same manner as before.
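
Conceptually, the loss substituted in step 2 looks like the following PyTorch sketch. It is illustrative only and not the ModelOpt implementation; the temperature argument and function name are assumptions, and ModelOpt may handle scaling, masking, and parallelism differently.

import torch
import torch.nn.functional as F

def logits_kd_loss(student_logits: torch.Tensor,
                   teacher_logits: torch.Tensor,
                   temperature: float = 1.0) -> torch.Tensor:
    """KL(teacher || student) over the vocabulary, averaged across all tokens."""
    vocab = student_logits.size(-1)
    # Soften both distributions with the same temperature (assumed knob).
    student_log_probs = F.log_softmax(
        student_logits.reshape(-1, vocab) / temperature, dim=-1)
    teacher_probs = F.softmax(
        teacher_logits.reshape(-1, vocab) / temperature, dim=-1)
    # Gradients flow only through the student; the teacher forward pass
    # is expected to run under torch.no_grad().
    return F.kl_div(student_log_probs, teacher_probs,
                    reduction="batchmean") * temperature ** 2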

Limitations#

  • Only GPT-based NeMo 2.0 checkpoints are supported.

  • Only logit-pair distillation is enabled for now.

Use NeMo-Run Recipes#

Note

Prerequisite: Follow the example in Quickstart with NeMo-Run to familiarize yourself with NeMo-Run before proceeding.

import nemo_run as run
from nemo.collections import llm
from nemo.collections.llm.modelopt.recipes import distillation_recipe

recipe = distillation_recipe(
    student_model_path="path/to/student/nemo2-checkpoint/",
    teacher_model_path="path/to/teacher/nemo2-checkpoint/",
    dir="./distill_logs",  # Path to store logs and checkpoints
    name="distill_testrun",
    num_nodes=1,
    num_gpus_per_node=8,
)

# Override the configuration with desired components:
recipe.data = run.Config(llm.PreTrainingDataModule, ...)
recipe.trainer.strategy.tensor_model_parallel_size = 8
...

run.run(recipe)
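
As a reference for the override step, a fuller data-module configuration might look like the sketch below. The parameter names (paths, seq_length, micro_batch_size, global_batch_size) are assumptions based on typical NeMo 2.0 pretraining configs; check the PreTrainingDataModule API for the exact signature.

# Illustrative override only -- verify parameter names against the
# PreTrainingDataModule documentation before use.
recipe.data = run.Config(
    llm.PreTrainingDataModule,
    paths="path/to/tokenized/data",  # pre-tokenized dataset prefix
    seq_length=8192,
    micro_batch_size=1,
    global_batch_size=4,
)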

Use with torchrun or Slurm#

Alternatively, you can run a traditional training script directly for a finer degree of customization.

# Student and teacher NeMo 2.0 checkpoints
STUDENT_CKPT="path/to/student/nemo2-checkpoint/"
TEACHER_CKPT="path/to/teacher/nemo2-checkpoint/"

# Data and training schedule
DATA_PATHS="1.0 path/to/tokenized/data"  # dataset weight followed by data path
SEQUENCE_LEN=8192
MICRO_BATCHSIZE=1
GLOBAL_BATCHSIZE=4
STEPS=100

# Parallelism and hardware
TP=8   # tensor parallel
CP=1   # context parallel
PP=1   # pipeline parallel
DP=1   # data parallel
NUM_NODES=1
DEVICES_PER_NODE=8

# Run name and output location
NAME="distill_testrun"
LOG_DIR="./distill_logs/"

# One torchrun process per GPU (TP * CP * PP * DP processes on this single node)
launch_cmd="torchrun --nproc_per_node=$(($TP * $CP * $PP * $DP))"

${launch_cmd} scripts/llm/gpt_train.py \
    --name ${NAME} \
    --model_path ${STUDENT_CKPT} \
    --teacher_path ${TEACHER_CKPT} \
    --tp_size ${TP} \
    --cp_size ${CP} \
    --pp_size ${PP} \
    --devices ${DEVICES_PER_NODE} \
    --num_nodes ${NUM_NODES} \
    --log_dir ${LOG_DIR} \
    --max_steps ${STEPS} \
    --gbs ${GLOBAL_BATCHSIZE} \
    --mbs ${MICRO_BATCHSIZE} \
    --data_paths ${DATA_PATHS} \
    --seq_length ${SEQUENCE_LEN}

SFT Distillation#

To perform SFT Knowledge Distillation on a chat dataset, follow the script above and add the --tokenizer and --use-chat-data arguments as well. See scripts/llm/gpt_train.py for full argument descriptions.