Training
Basic usage
```bash
python -m birdnet_stm32 train \
  --data_path_train data/train \
  --audio_frontend hybrid \
  --mag_scale pwl \
  --checkpoint_path checkpoints/my_model.keras
```
The script saves these files alongside the checkpoint:
- my_model.keras: trained Keras model
- my_model_model_config.json: conversion metadata (frontend, shapes, etc.)
- my_model_labels.txt: ordered class names
- my_model_history.csv: per-epoch training metrics (loss, ROC-AUC)
- my_model_curves.png: loss and ROC-AUC training curves plot
- my_model_train_state.json: epoch counter for --resume
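The naming pattern above can be expressed as a small helper, which is handy when scripting around the outputs. This is a sketch: output_paths is a hypothetical name, not part of birdnet_stm32, and it assumes only that every file is named after the checkpoint stem plus the suffixes listed above.

```python
from pathlib import Path

# Suffixes copied from the output list above; the helper name is hypothetical.
OUTPUT_SUFFIXES = {
    "config": "_model_config.json",
    "labels": "_labels.txt",
    "history": "_history.csv",
    "curves": "_curves.png",
    "train_state": "_train_state.json",
}


def output_paths(checkpoint_path: str) -> dict:
    """Return the expected location of every file written next to a checkpoint."""
    ckpt = Path(checkpoint_path)
    stem = ckpt.stem  # "my_model" for "checkpoints/my_model.keras"
    paths = {"model": ckpt}
    for kind, suffix in OUTPUT_SUFFIXES.items():
        paths[kind] = ckpt.with_name(stem + suffix)
    return paths


for kind, path in output_paths("checkpoints/my_model.keras").items():
    print(f"{kind:12s} {path}")
```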
Audio frontends
| Frontend | Input to model | Description |
|---|---|---|
| hybrid (default) | Linear magnitude STFT | Model applies a learned mel mixer and magnitude scaling. Best for deployment. |
| librosa | Mel spectrogram | Spectrogram computed offline with librosa. Simplest, but the frontend is not part of the model graph. |
| raw | Raw waveform | Model learns the filterbank from scratch via Conv2D. Most flexible, highest memory. |
Deprecated aliases: precomputed (now librosa) and tf (now raw) still work but emit deprecation warnings and will be removed in a future release.
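A minimal sketch of how this kind of alias handling typically works, using Python's warnings module. The function and table names are hypothetical; the valid names come from the --audio_frontend row of the argument reference, and the actual birdnet_stm32 code may differ.

```python
import warnings

# Hypothetical alias table mirroring the deprecations described above.
_DEPRECATED_FRONTENDS = {"precomputed": "librosa", "tf": "raw"}
_VALID_FRONTENDS = {"librosa", "hybrid", "raw", "mfcc", "log_mel"}


def resolve_frontend(name: str) -> str:
    """Return the canonical frontend name, warning on deprecated aliases."""
    if name in _DEPRECATED_FRONTENDS:
        new_name = _DEPRECATED_FRONTENDS[name]
        warnings.warn(
            f"audio_frontend '{name}' is deprecated; use '{new_name}' instead",
            DeprecationWarning,
            stacklevel=2,
        )
        return new_name
    if name not in _VALID_FRONTENDS:
        raise ValueError(f"unknown audio_frontend: {name!r}")
    return name
```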
Raw frontend memory limit: at 22 kHz × 3 s the raw input is about 66,000 samples, which exceeds 65,536 samples (the 16-bit activation size limit on the N6 NPU); the default 24 kHz × 3 s gives 72,000 samples. Use hybrid or librosa for deployment, or reduce the sample rate or chunk duration.
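The limit is simple arithmetic on sample rate × chunk duration. This sketch checks a few configurations against it; it assumes a chunk of exactly 65,536 samples is still allowed, which is not confirmed above, and the helper names are hypothetical.

```python
# Limit quoted above for the N6 NPU; treating 65,536 itself as allowed is an assumption.
MAX_RAW_SAMPLES = 65_536


def raw_input_samples(sample_rate_hz: int, chunk_seconds: float) -> int:
    """Waveform samples the raw frontend receives for one chunk."""
    return round(sample_rate_hz * chunk_seconds)


def fits_raw_limit(sample_rate_hz: int, chunk_seconds: float) -> bool:
    return raw_input_samples(sample_rate_hz, chunk_seconds) <= MAX_RAW_SAMPLES


def max_chunk_seconds(sample_rate_hz: int) -> float:
    """Longest chunk that stays within the limit at this sample rate."""
    return MAX_RAW_SAMPLES / sample_rate_hz


for sr, dur in [(22_050, 3), (24_000, 3), (16_000, 3), (22_050, 2)]:
    print(sr, dur, raw_input_samples(sr, dur), fits_raw_limit(sr, dur))
```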
Magnitude scaling
| Mode | Description | Quantization friendliness |
|---|---|---|
| pwl (default) | Piecewise-linear learned compression | Excellent; recommended for deployment |
| pcen | Per-channel energy normalization | Good; uses pooling + convolution |
| db | Log-scale (decibels) | Poor; avoid for quantized models |
| none | No compression | Baseline only |
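To illustrate why pwl and pcen are quantization-friendly, the sketch below builds a piecewise-linear curve from multiplies, adds and one max() per knot, and computes PCEN with a recursive smoother. The knot positions and slopes are hypothetical (the real layer learns them). The PCEN function is the textbook formulation with commonly cited default parameters and a first-order IIR smoother, not the pooling-and-convolution layer used by birdnet_stm32.

```python
import math


def pwl_compress(x, knots, slopes):
    """Continuous piecewise-linear curve.

    slopes[0] applies below knots[0], slopes[i + 1] above knots[i]. Each extra
    ReLU term changes the slope at one knot, so decreasing slopes give a
    concave, log-like compression made only of multiplies, adds and max().
    """
    y = slopes[0] * x
    for i, knot in enumerate(knots):
        y += (slopes[i + 1] - slopes[i]) * max(0.0, x - knot)
    return y


def pcen(energy, s=0.025, alpha=0.98, delta=2.0, r=0.5, eps=1e-6):
    """Per-channel energy normalization for one frequency channel.

    energy is a list of non-negative frame energies. m is a first-order IIR
    smoother of the energy; each frame is divided by (eps + m) ** alpha
    (automatic gain control) and then root-compressed with offset delta.
    """
    out = []
    m = energy[0]  # start the smoother at the first frame (one common choice)
    for e in energy:
        m = (1.0 - s) * m + s * e
        out.append((e / (eps + m) ** alpha + delta) ** r - delta ** r)
    return out


knots, slopes = [1.0, 10.0], [1.0, 0.1, 0.01]
print([round(pwl_compress(x, knots, slopes), 3) for x in (0.5, 5.0, 100.0)])
# A 100x louder steady input changes the PCEN output only slightly.
print(round(pcen([1.0] * 5)[-1], 4), round(pcen([100.0] * 5)[-1], 4))
```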
Model architecture
The DS-CNN is scaled with two knobs:
- --alpha (width multiplier): scales channel counts across all stages. Default 1.0. Values like 0.5 or 0.75 produce smaller models.
- --depth_multiplier: repeats each depthwise-separable block. Default 1. Increase to 2 for deeper models.
Channel alignment: keep channel counts as multiples of 8 for optimal NPU vectorization. The model builder enforces this automatically via _make_divisible.
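The rounding rule below is the widely used MobileNet-style helper, shown as a sketch of what a function like _make_divisible does; the project's version may differ in detail. The base stage widths and the one-block-per-stage layout are hypothetical and only show how --alpha and --depth_multiplier interact.

```python
def make_divisible(value, divisor=8, min_value=None):
    """Round value to the nearest multiple of divisor (MobileNet-style).

    The result never drops below min_value, and is bumped up by one divisor
    if plain rounding would lose more than 10% of the requested channels.
    """
    if min_value is None:
        min_value = divisor
    new_value = max(min_value, int(value + divisor / 2) // divisor * divisor)
    if new_value < 0.9 * value:
        new_value += divisor
    return new_value


def block_channels(base_channels, alpha=1.0, depth_multiplier=1):
    """Channel count of every depthwise-separable block, stage by stage.

    base_channels is a hypothetical per-stage width list; each stage's block
    is repeated depth_multiplier times.
    """
    plan = []
    for channels in base_channels:
        plan.extend([make_divisible(channels * alpha)] * depth_multiplier)
    return plan


BASE = [32, 64, 128, 256]  # hypothetical stage widths at alpha = 1.0
print(block_channels(BASE, alpha=0.75))
print(block_channels(BASE, alpha=0.35, depth_multiplier=2))
```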
Training options
Data augmentation
- Mixup: controlled by --mixup_alpha (default 0.2, 0 disables) and --mixup_probability (default 0.25). Uses Dirichlet multi-source mixing (2–3 sources per sample) to realistically emulate overlapping bird vocalizations. Labels are combined via element-wise max.
- SpecAugment: enabled by default. Applies random frequency and time masking to spectrograms during training. Disable with --no_spec_augment. Control mask widths with --freq_mask_max (default 8 bins) and --time_mask_max (default 25 frames).
- Smart crop: long recordings (> 2 chunks) are automatically cropped to salient regions using short-time energy (STE) analysis, reducing label noise from silent or irrelevant segments.
- Multi-chunk I/O reuse: long files (e.g. 60 s recordings) yield up to --max_chunks_per_file (default 3) salient chunks per file open, stored in a shuffled in-memory reservoir. This avoids redundant FLAC decode + resample for the same file across epochs.
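A pure-Python sketch of the two on-the-fly augmentations. Assumptions not stated above: Dirichlet weights are drawn through Gamma samples, the number of sources (2 or 3) is chosen uniformly, each example is mixed independently with probability --mixup_probability, and SpecAugment applies exactly one frequency mask and one time mask filled with zeros. All function names are hypothetical.

```python
import random


def dirichlet_weights(alpha, k, rng):
    """Sample k mixing weights from a symmetric Dirichlet(alpha) via Gamma draws."""
    draws = [rng.gammavariate(alpha, 1.0) for _ in range(k)]
    total = sum(draws)
    if total == 0.0:  # guard against extreme underflow with tiny alpha
        return [1.0 / k] * k
    return [d / total for d in draws]


def mix_sources(waves, labels, alpha, rng):
    """Weighted sum of equal-length waveforms; labels merged by element-wise max."""
    weights = dirichlet_weights(alpha, len(waves), rng)
    mixed = [sum(w * wave[i] for w, wave in zip(weights, waves))
             for i in range(len(waves[0]))]
    mixed_labels = [max(column) for column in zip(*labels)]
    return mixed, mixed_labels


def mixup_batch(waves, labels, alpha=0.2, probability=0.25, rng=None):
    """With the given probability, replace an example by a 2- or 3-source mix."""
    rng = rng or random.Random()
    out_waves, out_labels = [], []
    for i in range(len(waves)):
        if alpha > 0 and rng.random() < probability:  # alpha = 0 disables mixup
            k = min(rng.choice((2, 3)), len(waves))
            others = [j for j in range(len(waves)) if j != i]
            idx = [i] + rng.sample(others, k - 1)
            mixed, mixed_labels = mix_sources(
                [waves[j] for j in idx], [labels[j] for j in idx], alpha, rng)
        else:
            mixed, mixed_labels = list(waves[i]), list(labels[i])
        out_waves.append(mixed)
        out_labels.append(mixed_labels)
    return out_waves, out_labels


def spec_augment(spec, freq_mask_max=8, time_mask_max=25, rng=None):
    """Zero one random frequency band and one random time span of spec[freq][time]."""
    rng = rng or random.Random()
    n_freq, n_time = len(spec), len(spec[0])
    out = [list(row) for row in spec]  # leave the input untouched
    f = rng.randint(0, min(freq_mask_max, n_freq))
    f0 = rng.randint(0, n_freq - f)
    for row in out[f0:f0 + f]:
        row[:] = [0.0] * n_time
    t = rng.randint(0, min(time_mask_max, n_time))
    t0 = rng.randint(0, n_time - t)
    for row in out:
        row[t0:t0 + t] = [0.0] * t
    return out
```

Taking the element-wise max (rather than the weighted average used by classic mixup) keeps every species that is audible in the mix as a full positive, which suits multi-label detection.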
Loss function
- Binary crossentropy (default, --loss auto): standard multi-label loss.
- Focal loss: --loss focal down-weights well-classified examples, focusing training on hard, misclassified ones. Tune with --focal_gamma (default 2.0). Useful for imbalanced class distributions.
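For reference, binary focal loss scales each cross-entropy term by (1 - p_t)^γ, where p_t is the probability assigned to the correct outcome. The sketch below is not the project's implementation: it omits the optional α class-balancing factor of the original formulation, and with γ = 0 it reduces to plain binary cross-entropy.

```python
import math


def binary_focal_loss(y_true, y_pred, gamma=2.0, eps=1e-7):
    """Mean binary focal loss over the classes of one multi-label example.

    For hard 0/1 targets each term is the cross-entropy term scaled by
    (1 - p_t) ** gamma. Soft targets (e.g. after label smoothing) weight the
    positive and negative terms. gamma = 0 recovers binary cross-entropy.
    """
    total = 0.0
    for y, p in zip(y_true, y_pred):
        p = min(max(p, eps), 1.0 - eps)  # avoid log(0)
        positive = y * (1.0 - p) ** gamma * math.log(p)
        negative = (1.0 - y) * p ** gamma * math.log(1.0 - p)
        total -= positive + negative
    return total / len(y_true)


easy = ([1, 0], [0.9, 0.1])  # confident and correct
hard = ([1, 0], [0.1, 0.9])  # confident and wrong
for name, (y, p) in {"easy": easy, "hard": hard}.items():
    bce = binary_focal_loss(y, p, gamma=0.0)
    focal = binary_focal_loss(y, p, gamma=2.0)
    print(f"{name}: BCE={bce:.4f} focal={focal:.4f} ratio={focal / bce:.2f}")
```

With γ = 2 the easy example keeps 1% of its cross-entropy while the hard one keeps 81%, which is what shifts the gradient toward difficult examples.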
Optimizer
Select with --optimizer (default adam):
| Optimizer | Description |
|---|---|
| adam | Adaptive moment estimation (default) |
| sgd | SGD with momentum 0.9 |
| adamw | AdamW with decoupled weight decay |
Set weight decay with --weight_decay (default 0, only used by adamw).
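The difference between AdamW and Adam with an L2 penalty is where the decay enters the update. This scalar sketch is not the Keras implementation (its epsilon of 1e-7 only mirrors the Keras default): it applies decoupled decay directly to the weight, then contrasts that with the same decay folded into the gradient, where Adam's normalization rescales it.

```python
import math


def adam_step(w, g, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-7,
              weight_decay=0.0):
    """One Adam/AdamW update of a scalar weight; returns (w, m, v).

    weight_decay > 0 applies decoupled (AdamW-style) decay: the weight is
    shrunk by lr * weight_decay * w directly, outside the adaptive update.
    """
    m = beta1 * m + (1 - beta1) * g
    v = beta2 * v + (1 - beta2) * g * g
    m_hat = m / (1 - beta1 ** t)  # bias correction, t counts from 1
    v_hat = v / (1 - beta2 ** t)
    w = w - lr * weight_decay * w
    w = w - lr * m_hat / (math.sqrt(v_hat) + eps)
    return w, m, v


lr, wd, w0 = 0.01, 0.1, 1.0
# AdamW: zero loss gradient, decay applied directly to the weight.
w_adamw, _, _ = adam_step(w0, 0.0, 0.0, 0.0, t=1, lr=lr, weight_decay=wd)
# Adam + L2: the same decay folded into the gradient is rescaled by Adam.
w_l2, _, _ = adam_step(w0, wd * w0, 0.0, 0.0, t=1, lr=lr)
print(w_adamw, w_l2)
```

In the first step the L2 variant shrinks the weight ten times more than AdamW here, because Adam normalizes the small decay gradient to roughly unit size.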
Deterministic mode
Training is always deterministic — all random seeds (Python, NumPy,
TensorFlow) are set and TF_DETERMINISTIC_OPS is enabled automatically.
Use --seed (default 42) to change the RNG seed.
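A minimal sketch of the seeding described above. The function name is hypothetical; NumPy and TensorFlow are seeded only if they are installed, and the project may use additional TensorFlow determinism switches.

```python
import os
import random


def set_global_determinism(seed: int = 42) -> None:
    """Seed Python, NumPy and TensorFlow RNGs and request deterministic TF ops."""
    # Set before TensorFlow builds any ops so the flag is seen.
    os.environ["TF_DETERMINISTIC_OPS"] = "1"
    random.seed(seed)
    try:
        import numpy as np
        np.random.seed(seed)
    except ImportError:
        pass  # NumPy not installed; nothing to seed
    try:
        import tensorflow as tf
        tf.random.set_seed(seed)
    except ImportError:
        pass  # TensorFlow not installed; nothing to seed


set_global_determinism(42)
```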
Gradient clipping
Gradient clipping by global norm is enabled by default (--grad_clip 1.0).
Set to 0 to disable. Prevents exploding gradients, especially useful with
large models or unstable training.
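Clipping by global norm rescales all gradients by one shared factor, so the update direction is preserved. This sketch follows the semantics of tf.clip_by_global_norm (scale by clip_norm / max(global_norm, clip_norm)) on plain Python lists; Keras optimizers expose the same idea through their global_clipnorm argument, though which mechanism birdnet_stm32 uses is not stated here.

```python
import math


def clip_by_global_norm(grads, clip_norm):
    """Scale all gradients together so their joint L2 norm is at most clip_norm.

    grads is a list of per-variable gradient lists. Returns (clipped, norm).
    clip_norm <= 0 disables clipping, matching --grad_clip 0.
    """
    global_norm = math.sqrt(sum(g * g for grad in grads for g in grad))
    if clip_norm <= 0 or global_norm <= clip_norm:
        return [list(grad) for grad in grads], global_norm
    scale = clip_norm / global_norm  # one factor for every variable
    return [[g * scale for g in grad] for grad in grads], global_norm


clipped, norm = clip_by_global_norm([[3.0], [4.0]], clip_norm=1.0)
print(norm, clipped)
```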
Class weighting
Balanced inverse-frequency class weights are enabled by default. Use
--no_class_weights to disable. Useful for imbalanced datasets where some
species have fewer training files.
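Balanced inverse-frequency weights are usually computed as n_total / (n_classes × n_class), the same rule scikit-learn uses for class_weight="balanced". A sketch, assuming the counts are training files per class; for Keras Model.fit, the keys would be class indices. The function name is hypothetical.

```python
def balanced_class_weights(counts):
    """Inverse-frequency weights: n_total / (n_classes * n_class).

    counts maps a class index or name to its number of training files. A
    perfectly balanced dataset gets weight 1.0 for every class. Classes with
    zero files are skipped (an assumption; they cannot be weighted).
    """
    total = sum(counts.values())
    n_classes = len(counts)
    return {cls: total / (n_classes * n) for cls, n in counts.items() if n > 0}


print(balanced_class_weights({"Parus major": 10, "Turdus merula": 30}))
```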
Mixed precision
Use --mixed_precision to enable FP16 compute with FP32 accumulation.
Reduces memory usage and speeds up training on GPUs with Tensor Cores.
Resumable training
Use --resume to continue training from a previously saved checkpoint.
The optimizer is recompiled and training resumes from the last saved epoch, which is recorded in the {name}_train_state.json file next to the checkpoint. Example:
```bash
# Initial training (interrupted or completed at epoch 30)
python -m birdnet_stm32 train --epochs 30 --checkpoint_path ckpt/model.keras ...
# Resume and extend to 50 epochs
python -m birdnet_stm32 train --epochs 50 --resume --checkpoint_path ckpt/model.keras ...
```
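How the epoch counter maps onto a resumed run can be sketched as follows. The JSON key "epoch" is an assumption (the state file's exact layout is not documented here), and both helper names are hypothetical.

```python
import json
from pathlib import Path


def load_train_state(path):
    """Read the saved train-state JSON; an empty dict if it does not exist."""
    path = Path(path)
    return json.loads(path.read_text()) if path.exists() else {}


def resume_plan(state, target_epochs):
    """Return (initial_epoch, epochs_left) for a resumed run.

    Assumes the last completed epoch is stored under the key "epoch".
    """
    initial_epoch = int(state.get("epoch", 0))
    return initial_epoch, max(0, target_epochs - initial_epoch)


initial_epoch, remaining = resume_plan({"epoch": 30}, target_epochs=50)
print(initial_epoch, remaining)
# With Keras this would map to model.fit(..., initial_epoch=30, epochs=50),
# where epochs is the index of the final epoch, not a count of new ones.
```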
Quantization-Aware Training (QAT)
Use --qat to fine-tune a pretrained model with simulated INT8 quantization
noise. This closes the accuracy gap between the float Keras model and the
quantized TFLite model by teaching the weights to survive quantization.
QAT requires a pretrained model: always train normally first, then fine-tune with --qat. Do not use --qat from scratch; the quantization noise destabilizes randomly initialized weights and the model will not converge. The dataset must have the same classes as the pretrained model; use --linear_probe to adapt to a different class set first.
QAT works by injecting fake-quantization noise into kernel weights during training while maintaining full-precision shadow copies. BatchNorm layers are frozen to prevent their running statistics from drifting. No FakeQuant ops remain in the saved model, so it can be converted and deployed to the N6 NPU like a normally trained model.
```bash
# Step 1: Normal training
python -m birdnet_stm32 train --data_path_train data/train \
  --epochs 50 --checkpoint_path checkpoints/model.keras
# Step 2: QAT fine-tuning (lower LR, fewer epochs)
python -m birdnet_stm32 train --data_path_train data/train --qat \
  --checkpoint_path checkpoints/model.keras \
  --epochs 10 --learning_rate 0.0001
# Step 3: Convert the QAT model
python -m birdnet_stm32 convert \
  --checkpoint_path checkpoints/model_qat.keras \
  --model_config checkpoints/model_model_config.json \
  --data_path_train data/train
```
The QAT model is saved as {name}_qat.keras alongside the original.
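The core QAT mechanism can be sketched in a few lines: the forward pass uses fake-quantized weights, while the update is applied to the full-precision shadow copy (a straight-through estimator). The sketch assumes symmetric per-tensor INT8 quantization for simplicity; converters often quantize convolution kernels per channel, and the project's exact scheme is not specified here. Function names are hypothetical.

```python
def fake_quantize_int8(weights):
    """Symmetric per-tensor INT8 fake quantization (quantize, then dequantize).

    Values are snapped to the grid {-127..127} * scale, scale = max|w| / 127,
    so the forward pass sees the rounding error the INT8 model will have.
    """
    max_abs = max(abs(w) for w in weights)
    if max_abs == 0.0:
        return list(weights)
    scale = max_abs / 127.0
    return [max(-127, min(127, round(w / scale))) * scale for w in weights]


def qat_step(shadow, grad_fn, lr):
    """One QAT update with a straight-through estimator.

    shadow: full-precision weights kept by the optimizer.
    grad_fn: returns dLoss/dWeights evaluated at the weights it is given.
    The gradient is computed at the quantized weights but applied to the
    shadow copy, so updates smaller than one quantization step still add up.
    """
    quantized = fake_quantize_int8(shadow)
    grads = grad_fn(quantized)
    return [w - lr * g for w, g in zip(shadow, grads)]


w = [1.0, -0.5, 0.003]
print(fake_quantize_int8(w))  # 0.003 falls below half a step and becomes 0.0
print(qat_step(w, lambda q: [1.0] * len(q), lr=0.001))
```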
Linear probing
Use --linear_probe to freeze a pretrained backbone and train only a new
classification head on your custom species dataset. This is useful when you
have a pretrained model (e.g. a large BirdNET checkpoint) and want to adapt
it to a different set of species with limited data.
```bash
python -m birdnet_stm32 train --data_path_train data/my_species \
  --linear_probe --checkpoint_path checkpoints/pretrained.keras \
  --epochs 20 --learning_rate 0.001
```
The probe model is saved as {name}_probe.keras with a new labels file.
Learning rate
The learning rate follows a cosine decay schedule from --learning_rate (default 0.001) to near zero over --epochs (default 50). Early stopping monitors validation loss with a patience of 10 epochs.
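Both behaviors are short enough to sketch. The cosine formula matches the usual cosine-decay schedule (as in Keras CosineDecay, with final_fraction playing the role of its alpha argument); the exact floor used for "near zero" is not stated above, so it defaults to 0 here. The early-stopping class is a simplified, hypothetical stand-in for a patience-based callback.

```python
import math


def cosine_decay_lr(epoch, total_epochs, initial_lr=1e-3, final_fraction=0.0):
    """Cosine decay from initial_lr down to initial_lr * final_fraction."""
    progress = min(epoch, total_epochs) / total_epochs
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return initial_lr * ((1.0 - final_fraction) * cosine + final_fraction)


class EarlyStopper:
    """Stop after `patience` epochs without a new best validation loss."""

    def __init__(self, patience=10):
        self.patience = patience
        self.best = math.inf
        self.wait = 0

    def should_stop(self, val_loss):
        if val_loss < self.best:
            self.best, self.wait = val_loss, 0
        else:
            self.wait += 1
        return self.wait >= self.patience


print([round(cosine_decay_lr(e, 50), 6) for e in (0, 10, 25, 40, 50)])
```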
Hyperparameter tuning with Optuna
Use --tune to run an automated hyperparameter search with Optuna (requires
pip install -e ".[tune]"). The tuner explores alpha, depth_multiplier,
embeddings_size, learning_rate, dropout, batch_size, mixup_alpha,
label_smoothing, optimizer, weight_decay, grad_clip, use_se,
use_inverted_residual, use_attention_pooling, se_reduction, and
expansion_factor. It maximizes val_roc_auc and uses Optuna's MedianPruner to stop unpromising trials early.
Set --n_trials to control how many configurations to try (default 20).
Full argument reference
| Argument | Default | Description |
|---|---|---|
| --data_path_train | (required) | Path to training data |
| --max_samples | None | Max files per class |
| --upsample_ratio | 0.5 | Minority class upsample ratio |
| --sample_rate | 24000 | Audio sample rate (Hz) |
| --num_mels | 64 | Number of mel frequency bins |
| --spec_width | 256 | Spectrogram width (frames) |
| --fft_length | 512 | FFT window length |
| --chunk_duration | 3 | Chunk duration (seconds) |
| --max_duration | 60 | Max seconds to load per file |
| --audio_frontend | hybrid | librosa, hybrid, raw, mfcc, or log_mel |
| --mag_scale | pwl | pwl, pcen, db, or none |
| --embeddings_size | 256 | Embedding channels before head |
| --alpha | 1.0 | Model width scaling |
| --depth_multiplier | 1 | Block repeats per stage |
| --frontend_trainable | False | Make frontend weights trainable |
| --mixup_alpha | 0.2 | Mixup alpha (0 disables) |
| --mixup_probability | 0.25 | Fraction of batch to mix |
| --no_spec_augment | False | Disable SpecAugment masking (on by default) |
| --freq_mask_max | 8 | Max frequency mask width (bins) |
| --time_mask_max | 25 | Max time mask width (frames) |
| --dropout | 0.5 | Dropout rate before classifier head |
| --optimizer | adam | adam, sgd, or adamw |
| --weight_decay | 0.0 | Weight decay (adamw only) |
| --loss | auto | auto (BCE) or focal |
| --focal_gamma | 2.0 | Focal loss focusing parameter |
| --label_smoothing | 0.1 | Label smoothing factor (0 = off) |
| --no_se | False | Disable SE channel attention (on by default) |
| --se_reduction | 8 | SE channel reduction factor |
| --no_inverted_residual | False | Use plain DS blocks (inverted residuals on by default) |
| --expansion_factor | 2 | Expansion factor for inverted residuals |
| --use_attention_pooling | False | Use attention pooling instead of GAP |
| --n_mfcc | 20 | Number of MFCC coefficients (mfcc frontend only) |
| --grad_clip | 1.0 | Max gradient norm for clipping (0 = disabled) |
| --no_class_weights | False | Disable balanced class weighting (on by default) |
| --mixed_precision | False | Enable FP16 mixed precision training |
| --resume | False | Resume training from checkpoint |
| --seed | 42 | Random seed |
| --batch_size | 32 | Batch size |
| --num_workers | 8 | Parallel data loading workers (0 = sequential) |
| --max_chunks_per_file | 3 | Max salient chunks per file open (reduces redundant I/O) |
| --epochs | 50 | Number of epochs |
| --learning_rate | 0.001 | Initial learning rate |
| --val_split | 0.2 | Validation split fraction |
| --checkpoint_path | checkpoints/best_model.keras | Output path (.keras) |
| --tune | False | Run Optuna hyperparameter search |
| --n_trials | 20 | Number of Optuna trials |
| --qat | False | Quantization-aware fine-tuning |
| --linear_probe | False | Freeze backbone and train only classifier head |
Data pipeline
The training pipeline uses a multiprocessing pool for parallel data loading, bypassing the GIL so FLAC decode, resampling, smart-crop, and spectrogram computation run across separate CPU cores.
When --max_chunks_per_file is greater than 1 (default 3), each file open
extracts multiple salient chunks which are buffered in a shuffled in-memory
reservoir (~135 MB for 512 samples). This dramatically reduces I/O for
long recordings: a 60 s file decoded once yields 3 usable chunks instead of
re-opening the same file 3 times across epochs.
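The ~135 MB figure is consistent with the reservoir holding float32 waveform chunks: 512 × 22,050 Hz × 3 s × 4 bytes ≈ 135.5 MB. The contents are an assumption here; at the default 24 kHz the same reservoir would need about 147 MB. The reservoir itself can be sketched as a fixed-size shuffle buffer in the style of tf.data's shuffle; the function names are hypothetical.

```python
import random


def reservoir_bytes(capacity, sample_rate_hz, chunk_seconds, bytes_per_value=4):
    """Memory for `capacity` float32 waveform chunks (an assumption about the contents)."""
    return capacity * round(sample_rate_hz * chunk_seconds) * bytes_per_value


def shuffle_buffer(chunks, capacity=512, seed=None):
    """Yield chunks in shuffled order using a fixed-size buffer.

    The buffer is filled first; afterwards each new chunk replaces a randomly
    chosen buffered one, which is yielded. Leftovers are shuffled at the end.
    """
    rng = random.Random(seed)
    buffer = []
    for chunk in chunks:
        if len(buffer) < capacity:
            buffer.append(chunk)
            continue
        slot = rng.randrange(capacity)
        yield buffer[slot]
        buffer[slot] = chunk
    rng.shuffle(buffer)
    yield from buffer


print(reservoir_bytes(512, 22_050, 3) / 1e6, "MB at 22.05 kHz")
print(reservoir_bytes(512, 24_000, 3) / 1e6, "MB at 24 kHz")
# Chunks arrive grouped by file: (file_id, chunk_index), 3 chunks per file.
stream = ((f, c) for f in range(300) for c in range(3))
first_batch = list(shuffle_buffer(stream, seed=0))[:32]
```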
The reservoir maintains batch diversity by shuffling samples from many different files before yielding them. With a reservoir of 512 samples from ~200 different files (2–3 chunks each), a batch of 32 contains on average only about 1.6 pairs of chunks from the same file out of its 496 possible pairs, so nearly every sample in a batch comes from a different recording.
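The batch-diversity estimate is a birthday-problem style count: the expected number of same-file pairs in a batch is C(32, 2) times the probability that two random reservoir entries come from the same file. The sketch assumes uniform sampling from the reservoir and a hypothetical split of 512 chunks over 200 files.

```python
from math import comb


def expected_same_file_pairs(chunks_per_file, batch_size=32):
    """Expected number of same-file pairs in a batch drawn uniformly at random
    (without replacement) from a reservoir holding chunks_per_file[i] chunks
    of file i."""
    total = sum(chunks_per_file)
    same_file = sum(comb(n, 2) for n in chunks_per_file)
    return comb(batch_size, 2) * same_file / comb(total, 2)


# Hypothetical reservoir: 512 chunks from 200 files (112 x 3 chunks + 88 x 2).
reservoir = [3] * 112 + [2] * 88
print(sum(reservoir), len(reservoir), round(expected_same_file_pairs(reservoir), 2))
```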
Tune with:
- --num_workers N: number of worker processes (default 8, 0 = sequential)
- --max_chunks_per_file N: chunks per file open (default 3, 1 = one chunk per file open, the original behavior)
Noise classes
Place audio in folders named noise, silence, background, or other
under data/train/. These receive all-zero label vectors and help the model
learn to reject non-bird sounds.
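A sketch of the labeling rule, assuming class names come from the remaining (non-noise) folder names in sorted order; the actual ordering is whatever the labels file records. Function names are hypothetical.

```python
NOISE_FOLDERS = {"noise", "silence", "background", "other"}


def class_names_from_folders(folder_names):
    """Ordered species classes; noise folders get no class of their own (assumed)."""
    return sorted(name for name in folder_names if name not in NOISE_FOLDERS)


def label_vector(folder_name, class_names):
    """Multi-hot target for a file from folder_name; all zeros for noise folders."""
    vector = [0.0] * len(class_names)
    if folder_name not in NOISE_FOLDERS:
        vector[class_names.index(folder_name)] = 1.0
    return vector


folders = ["Turdus merula", "noise", "Parus major", "background"]
classes = class_names_from_folders(folders)
print(classes)
print(label_vector("noise", classes), label_vector("Turdus merula", classes))
```

An all-zero target teaches the sigmoid outputs to stay low for every species on such clips, which is what lets the model reject non-bird audio without a dedicated noise class.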