2.6 Knowledge Distillation

Knowledge Distillation #

Knowledge distillation (KD) refers to a principled way to transfer the capabilities of a large teacher model into a smaller, faster student model – often achieving substantial reductions in latency, memory footprint, and cost. In inference optimization pipelines, KD can often be used in conjunction with other model compression techniques like quantization. The following schematic captures the general idea of KD:

Figure: Knowledge distillation.

In the above figure, for each example $x$ in the dataset, the teacher generates a predictive distribution $p_t(y \vert x)$ over outputs $y$, and the student generates its own distribution $p_s(y \vert x)$. The distillation loss $\ell(p_t, p_s)$ trains the student to mimic the teacher.

A smaller model distilled from a large teacher consistently outperforms the same model trained from scratch on the same data budget. This happens for several fundamental reasons, rooted in optimization, representational capacity, data distributions, and the nature of the supervision signal, which are intuitively described below:

  • The teacher provides a richer learning signal than one-hot labels. The soft distribution over labels yields a much smoother optimization landscape, allowing the small student model to learn subtle patterns it could never discover from one-hot labels when trained from scratch. Consequently, the effective sample complexity is reduced, since the teacher provides supervisory information for every token, not just the correct one.
  • Big models can learn higher-quality representations that small models can imitate but cannot discover on their own. Rather than needing to infer long-range semantics or complex reasoning from raw text, the student only needs to approximate the teacher’s function, which is a simpler task than learning that function from scratch. A useful analogy: a college student can learn calculus from a professor in months, but would not reinvent calculus on their own even in years.

Types of knowledge distillation #

While KD encompasses diverse strategies, there are three major categories used for LLMs, which are described below:

Logit/Soft-label distillation #

This is the classic form of KD, where the student learns directly from the teacher’s probability distribution over tokens, derived from its logits. Given an input $x$, let $z_t$ and $z_s$ be the logits of the teacher and student, respectively; logits here refer to the raw, unnormalized scores before the softmax. Then, $p_t = \sigma(z_t/T)$ and $p_s = \sigma(z_s/T)$ are the corresponding probability distributions over tokens, where $\sigma$ is the softmax function, and $T$ is the distillation temperature. For some $0 < \alpha \le 1$, the distillation loss used by the student is: $$L_{\rm KD} = (1 - \alpha)L_{\rm CE}(z_s, y) + \alpha T^2L_{\rm KL}(p_t \parallel p_s).$$ Here, $L_{\rm CE}$ is the cross-entropy loss between the student’s logits and the ground-truth labels $y$, $L_{\rm KL}$ is the KL divergence between the teacher’s and the student’s distributions, and $\alpha$ blends the two losses. Note: $\alpha = 1$ corresponds to pure logit-level distillation. This type of KD ignores the teacher’s intermediate computation; only the logits from the teacher (which may be available through an API) are needed. This is the most common type of KD, and is primarily what Minitron uses.
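
To make the loss concrete, below is a minimal PyTorch sketch of the blended objective above. The function name and default hyperparameters are illustrative assumptions, not taken from Minitron or any specific library.

```python
import torch.nn.functional as F

def logit_distillation_loss(student_logits, teacher_logits, labels, T=2.0, alpha=0.5):
    """Blend hard-label cross-entropy with a temperature-scaled KL term.

    student_logits, teacher_logits: (num_tokens, vocab_size) raw scores.
    labels: (num_tokens,) ground-truth token ids.
    """
    # Hard-label term: standard cross-entropy against the ground truth y.
    ce = F.cross_entropy(student_logits, labels)

    # Soft-label term: KL(p_t || p_s) with both distributions softened by T.
    # F.kl_div expects log-probabilities for the input (student) and
    # probabilities for the target (teacher).
    log_p_s = F.log_softmax(student_logits / T, dim=-1)
    p_t = F.softmax(teacher_logits / T, dim=-1)
    kl = F.kl_div(log_p_s, p_t, reduction="batchmean")

    # The T^2 factor keeps the soft-label gradients on a scale comparable
    # to the hard-label term across different temperatures.
    return (1 - alpha) * ce + alpha * (T ** 2) * kl
```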

Sequence-level distillation #

In sequence-level KD, instead of (or in addition to) matching logits, the student is trained on text sequences generated by the teacher, using a standard cross-entropy training loss. This is sometimes also referred to as supervised fine-tuning with synthetic data. Given an input $x$, the teacher generates a complete output sequence: $$y_t = (y_{t,1}, y_{t,2}, \ldots, y_{t,L})$$ of length $L$. The student is trained to predict this teacher-generated sequence as if it were the ground truth, using the training objective: $$L_{\rm seq} = -\sum_{i=1}^{L}\log p_s(y_{t,i} \vert x, y_{t, <i}).$$ Note that this uses teacher-generated labels instead of human-provided labels. A minimal sketch of one such training step is shown below.
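
The sketch below assumes Hugging Face-style `student`/`teacher` causal LM objects and a `tokenizer`; all names, arguments, and defaults are illustrative assumptions rather than a reference implementation.

```python
import torch

def sequence_kd_step(student, teacher, tokenizer, prompts, max_new_tokens=256):
    # 1. Let the teacher produce "pseudo-labels" for the prompts (no gradients needed).
    with torch.no_grad():
        inputs = tokenizer(prompts, return_tensors="pt", padding=True)
        teacher_ids = teacher.generate(**inputs, max_new_tokens=max_new_tokens)

    # 2. Train the student with ordinary next-token cross-entropy on the
    #    teacher-generated sequences, exactly as if they were ground truth.
    #    (In practice, prompt positions are usually masked out of the loss,
    #    e.g. by setting their label ids to -100.)
    outputs = student(input_ids=teacher_ids, labels=teacher_ids)
    return outputs.loss
```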

Sequence-level KD is effective because human-labeled datasets, in general, often contain noise and stylistic variance. Teacher outputs, on the other hand, come from a single model and exhibit a coherent, consistent behavior pattern, giving the student clean supervision. Moreover, teacher-generated outputs can be produced cheaply at scale (without human labeling), providing abundant training data.

While logit-level KD transfers token-level distributions from the teacher to the student, sequence-level KD transfers semantic structures such as reasoning traces and formatting preferences. This enables compressing models beyond logit fidelity: even though a small student cannot match the teacher’s softmax landscape on reasoning tasks, it can still imitate the teacher’s reasoning patterns via generated sequences.

Quite importantly, sequence-level KD is especially beneficial when the teacher is a closed model (e.g., GPT-5, Claude, Gemini), because unlike open models, closed models often do not expose logits, hidden states, or internal probabilities.

Feature-level/Intermediate representation distillation #

In feature-level KD, the student is trained to match the teacher’s hidden states, attention maps, or intermediate embeddings. This can lead to dramatically stronger students, but at a significantly higher computational cost. For example, to match attention maps, a Frobenius-norm error loss like $$L_{\rm attn}= \sum_{(l,m)} \lVert A_s^{(m)} - A_t^{(l)} \rVert_F^2$$ may be used. Here, the attention scores of the $m^{\rm th}$ layer of the student are mapped to those of the $l^{\rm th}$ layer of the teacher.
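
A minimal sketch of such an attention-transfer term is given below; the tensor shapes, head-averaging, and layer mapping are assumptions made for illustration.

```python
def attention_transfer_loss(student_attns, teacher_attns, layer_map):
    """Sum of squared Frobenius-norm errors between mapped attention maps.

    student_attns / teacher_attns: lists of per-layer attention tensors,
    each assumed to have shape (batch, heads, seq_len, seq_len).
    layer_map: list of (student_layer_m, teacher_layer_l) index pairs.
    """
    loss = 0.0
    for m, l in layer_map:
        # Average over heads so layers with different head counts stay comparable.
        a_s = student_attns[m].mean(dim=1)  # (batch, seq_len, seq_len)
        a_t = teacher_attns[l].mean(dim=1)
        # ||A_s^(m) - A_t^(l)||_F^2, averaged over the batch.
        loss = loss + ((a_s - a_t) ** 2).sum(dim=(-2, -1)).mean()
    return loss
```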

The intuition behind why feature-level KD is effective lies in the fact that LLMs build hierarchical neural representations. For instance, lower layers learn lexical and morphological features, mid layers learn syntax and semantics, whereas higher layers learn abstraction and world knowledge. While it is difficult for a small student to learn these structures from scratch, it can learn to imitate them from the teacher.

While very effective, it should be noted that feature-level KD is rare for distilling knowledge from large-scale LLMs because the computation and memory costs explode. To distill hidden states for a sequence length $S$, hidden dimension $d$, and across $L$ layers, all hidden states must be stored, which requires $O(LSd)$ memory. Hence, logit-level KD and sequence-level KD are more popular for large-scale LLMs. Feature-level KD is more common in vision models and relatively small text models ($\leq 3{\rm B}$ params).
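
As a rough back-of-the-envelope illustration (with hypothetical numbers), consider a teacher with $L = 80$ layers, sequence length $S = 4096$, and hidden dimension $d = 8192$: storing all hidden states in 16-bit precision already takes $80 \times 4096 \times 8192 \times 2 \text{ bytes} \approx 5.4\ {\rm GB}$ per sequence, before counting attention maps or the student’s own activations.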

References #

  1. Sreenivas et al., LLM Pruning and Distillation in Practice: The Minitron Approach, 2024. https://arxiv.org/abs/2408.11796
  2. Hinton, Vinyals, and Dean, Distilling the Knowledge in a Neural Network, NeurIPS 2014 Deep Learning Workshop. https://arxiv.org/abs/1503.02531
  3. Romero et al., FitNets: Hints for Thin Deep Nets, ICLR 2015. https://arxiv.org/abs/1412.6550
  4. Huang and Wang, Like What You Like: Knowledge Distill via Neuron Selectivity Transfer, 2017. https://arxiv.org/abs/1707.01219
  5. Zagoruyko and Komodakis, Paying More Attention to Attention: Improving the Performance of Convolutional Neural Networks via Attention Transfer, ICLR 2017. https://arxiv.org/abs/1612.03928