distillation
birdnet_stm32.training.distillation
Knowledge distillation loss for training with soft teacher labels.
Implements a combined loss that blends the standard hard-label loss with a KL-divergence distillation loss from a teacher model's soft predictions.
DistillationLoss
Bases: Loss
Combined hard-label + soft-label distillation loss.
Loss = (1 - alpha) * student_loss(y_true, y_pred) + alpha * T^2 * KL(softmax(teacher_logits/T) || softmax(student_logits/T))
For simplicity, this implementation accepts pre-computed soft labels (teacher probabilities) rather than teacher logits, and uses categorical crossentropy as the distillation term.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `alpha` | `float` | Weight for the distillation loss (0 = pure hard labels, 1 = pure distillation). | `0.5` |
| `temperature` | `float` | Softmax temperature for smoothing teacher predictions. | `3.0` |
| `student_loss` | `Loss \| None` | Base loss function for hard labels. | `None` |
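The blended loss above can be sketched in plain NumPy. This is a simplified illustration, not the package's actual implementation; it assumes one-hot hard labels, precomputed teacher probabilities as soft labels, and categorical crossentropy for both terms, matching the simplification described above:

```python
import numpy as np

def distillation_loss(hard_labels, soft_labels, student_probs,
                      alpha=0.5, temperature=3.0):
    """Sketch of (1 - alpha) * hard CE + alpha * T^2 * soft CE."""
    eps = 1e-7  # avoid log(0)
    # Hard-label term: standard categorical crossentropy against one-hot labels.
    hard_ce = -np.sum(hard_labels * np.log(student_probs + eps), axis=-1)
    # Soft-label term: crossentropy against precomputed teacher probabilities
    # (used here in place of KL divergence over logits, per the note above).
    # The T^2 factor rescales the soft term's gradient magnitude.
    soft_ce = -np.sum(soft_labels * np.log(student_probs + eps), axis=-1)
    return np.mean((1.0 - alpha) * hard_ce + alpha * temperature ** 2 * soft_ce)
```

With `alpha=0` this reduces to plain categorical crossentropy on the hard labels; with `alpha=1` only the teacher's soft targets drive training.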
Source code in birdnet_stm32/training/distillation.py
call(y_true, y_pred)
Compute combined distillation loss.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `y_true` | | Ground truth labels. Should be a concatenation of `[hard_labels, soft_labels]` along the last axis, where `hard_labels` and `soft_labels` each have shape `[B, C]`. Total shape: `[B, 2*C]`. | *required* |
| `y_pred` | | Student model predictions, shape `[B, C]`. | *required* |
Returns:

| Type | Description |
|---|---|
| | Scalar loss value. |
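The packed `y_true` convention can be sketched as follows (a hypothetical NumPy illustration of the splitting logic, not the library's code; names like `call_sketch` are made up for this example):

```python
import numpy as np

def call_sketch(y_true, y_pred, alpha=0.5, temperature=3.0):
    # y_true packs [hard_labels, soft_labels] along the last axis: shape [B, 2*C].
    C = y_pred.shape[-1]
    hard_labels = y_true[..., :C]
    soft_labels = y_true[..., C:]
    eps = 1e-7
    hard_ce = -np.sum(hard_labels * np.log(y_pred + eps), axis=-1)
    soft_ce = -np.sum(soft_labels * np.log(y_pred + eps), axis=-1)
    return np.mean((1 - alpha) * hard_ce + alpha * temperature ** 2 * soft_ce)

# Hypothetical usage: C = 3 classes, batch of 2.
hard = np.eye(3)[[0, 2]]                                 # one-hot hard labels [2, 3]
soft = np.array([[0.6, 0.3, 0.1], [0.2, 0.2, 0.6]])      # teacher probabilities [2, 3]
y_true = np.concatenate([hard, soft], axis=-1)           # packed labels [2, 6]
y_pred = np.array([[0.7, 0.2, 0.1], [0.1, 0.3, 0.6]])    # student predictions [2, 3]
loss = call_sketch(y_true, y_pred)                       # scalar
```

Packing both label sets into `y_true` keeps the loss compatible with Keras's standard `(y_true, y_pred)` calling convention, at the cost of doubling the label tensor's last dimension.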