# Model Export

Convert a trained checkpoint to portable inference formats for deployment.
## Quick Start

```bash
python convert.py                              # ONNX FP16 (default)
python convert.py --formats onnx tflite_fp16   # multiple formats
python convert.py --formats all                # everything
```
## Supported Formats

| Format | Flag | Size (scale 0.75, 12K species) | Description |
|---|---|---|---|
| ONNX FP32 | `onnx` | ~15 MB | Full-precision ONNX model |
| ONNX FP16 | `onnx_fp16` | ~7 MB | Half-precision ONNX (default) |
| TFLite FP32 | `tflite` | ~14 MB | TensorFlow Lite, full precision |
| TFLite FP16 | `tflite_fp16` | ~7 MB | TensorFlow Lite, half precision |
| TFLite INT8 | `tflite_int8` | ~4 MB | TensorFlow Lite, dynamic-range quantization |
| TF SavedModel | `tf` | ~14 MB | TensorFlow SavedModel directory |

Use `--formats all` to export everything at once.
## CLI Reference

| Flag | Default | Description |
|---|---|---|
| `--checkpoint` | `checkpoints/checkpoint_best.pt` | Source checkpoint |
| `--outdir` | `exports` | Output directory |
| `--formats` | `onnx_fp16` | Space-separated list of formats (or `all`) |
| `--tol` | `1e-4` | Base tolerance for numerical validation |
| `--device` | `auto` | Device for the PyTorch reference model |
| `--fp16_io` | off | Convert model I/O to FP16 as well (see below) |
## FP16 I/O Behavior

By default, ONNX FP16 exports keep model inputs and outputs in FP32 while converting internal weights and activations to FP16. This preserves full coordinate precision (latitude, longitude, week) and keeps `LayerNormalization` in FP32 for numerical stability.

Pass `--fp16_io` to convert the I/O tensors to FP16 as well. This is slightly smaller but increases numerical divergence from the PyTorch reference.
| Mode | Max diff (typical) | Notes |
|---|---|---|
| FP16 weights, FP32 I/O (default) | ~0.013 | Best accuracy, recommended |
| Full FP16 (`--fp16_io`) | ~0.013 | Slightly lossy input precision |
## Export Wrapper

All exported models use a unified interface:

- Input: `(batch, 3)` float tensor; columns are `[latitude, longitude, week]`
- Output: `(batch, n_species)` float tensor of sigmoid probabilities

This differs from the raw PyTorch model, which takes three separate tensors. The wrapper applies sigmoid internally so outputs are directly interpretable as probabilities.
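The wrapper can be pictured as a thin module along these lines. This is a sketch, not the actual export code: the class name is invented, and it assumes the raw model takes `(lat, lon, week)` and returns a dict with a `species_logits` entry, as in the PyTorch example further down.

```python
import torch
import torch.nn as nn


class ExportWrapper(nn.Module):
    """Hypothetical sketch: pack the three scalar inputs into one
    (batch, 3) tensor and apply sigmoid so outputs are probabilities."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x):
        # x: (batch, 3) float tensor with columns [latitude, longitude, week]
        out = self.model(x[:, 0], x[:, 1], x[:, 2], return_env=False)
        return torch.sigmoid(out["species_logits"])  # (batch, n_species)
```

A wrapper like this is what gets traced during export, so every downstream runtime sees the same single-tensor interface.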
## Validation

After each conversion, the script runs automatic numerical validation:

- A fixed set of 200 reference inputs (diverse lat/lon/week combinations) is generated
- The PyTorch model produces reference outputs
- The exported model is loaded back and run on the same inputs
- Maximum and mean absolute differences are computed
- If the max diff exceeds the tolerance, the conversion is marked as FAIL
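The comparison step boils down to an elementwise absolute difference against the reference outputs. A minimal sketch (function name and return shape are illustrative, not the script's actual API):

```python
import numpy as np


def validate_outputs(reference, exported, tolerance):
    """Compare exported-model outputs against the PyTorch reference.

    Returns (max_diff, mean_diff, passed); the conversion is marked
    FAIL when max_diff exceeds the tolerance.
    """
    diff = np.abs(np.asarray(reference) - np.asarray(exported))
    max_diff = float(diff.max())
    mean_diff = float(diff.mean())
    return max_diff, mean_diff, max_diff <= tolerance
```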
Tolerances are format-aware:
| Format | Effective tolerance |
|---|---|
| ONNX FP32 | `tol` (1e-4) |
| ONNX FP16 (FP32 I/O, default) | 0.06 |
| ONNX FP16 (`--fp16_io`) | 0.05 |
| TFLite FP32 | `tol` (1e-4) |
| TFLite FP16 | `tol` × 10 |
| TFLite INT8 | `tol` × 100 |
| TF SavedModel | `tol` (1e-4) |
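The table above can be expressed as a simple lookup. This is only an illustration of the rule, the format keys are made up, and the actual logic lives in `convert.py`:

```python
def effective_tolerance(fmt, tol=1e-4):
    """Illustrative mapping of the format-aware tolerances: FP16 ONNX
    uses fixed absolute bounds; lossy TFLite formats scale the base tol."""
    fixed = {"onnx_fp16": 0.06, "onnx_fp16_io": 0.05}  # hypothetical keys
    scale = {"tflite_fp16": 10, "tflite_int8": 100}
    if fmt in fixed:
        return fixed[fmt]
    return tol * scale.get(fmt, 1)
```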
## Dependencies
ONNX formats require:
TFLite and TF SavedModel additionally require:
**TFLite runtime requirement**

This model uses GELU, which maps to `tf.Erf` during conversion. TFLite therefore emits Select TF Ops (`FlexErf`) for FP32/FP16 exports. Run these models with a TensorFlow Lite runtime that includes the Flex delegate (`SELECT_TF_OPS`).
The script will print a clear error message if a required package is missing — it does not fail silently.
## Output Structure

```
exports/
├── geomodel.onnx         # ONNX FP32
├── geomodel_fp16.onnx    # ONNX FP16
├── geomodel.tflite       # TFLite FP32
├── geomodel_fp16.tflite  # TFLite FP16
├── geomodel_int8.tflite  # TFLite INT8
├── saved_model/          # TF SavedModel
├── labels.txt            # Species vocabulary (copied from checkpoint dir)
└── MODEL_LICENSE.txt     # Model weights license (CC BY-SA 4.0)
```
## Running Exported Models

All exported formats share the same interface:

- Input: `(batch, 3)` float32 tensor; columns are `[latitude, longitude, week]`
- Output: `(batch, n_species)` float32 tensor of sigmoid probabilities
### Loading labels

```python
def load_labels(path="exports/labels.txt"):
    """Load species labels from a tab-separated labels.txt."""
    labels = []
    with open(path) as f:
        for line in f:
            parts = line.strip().split("\t")
            labels.append({"code": parts[0], "sci": parts[1], "common": parts[2]})
    return labels
```
### PyTorch (.pt)

```python
import torch

from model.model import BirdNETGeoModel

checkpoint = torch.load("checkpoints/checkpoint_best.pt", map_location="cpu",
                        weights_only=False)
cfg = checkpoint["model_config"]
model = BirdNETGeoModel(
    n_species=cfg["n_species"],
    n_env_features=cfg["n_env_features"],
    model_scale=cfg["model_scale"],
    coord_harmonics=cfg.get("coord_harmonics", 8),
    week_harmonics=cfg.get("week_harmonics", 8),
)
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()

lat = torch.tensor([52.5])
lon = torch.tensor([13.4])
week = torch.tensor([22.0])
with torch.no_grad():
    logits = model(lat, lon, week, return_env=False)["species_logits"]
probs = torch.sigmoid(logits)  # (1, n_species)

labels = load_labels("checkpoints/labels.txt")
top = probs[0].topk(10)
for i, idx in enumerate(top.indices):
    print(f"{i+1}. {labels[idx]['common']}: {top.values[i]:.3f}")
```
### ONNX

```python
import numpy as np
import onnxruntime as ort

session = ort.InferenceSession("exports/geomodel_fp16.onnx")
inputs = np.array([[52.5, 13.4, 22.0]], dtype=np.float32)  # (batch, 3)
probs = session.run(None, {"input": inputs})[0]            # (batch, n_species)

labels = load_labels()
top_k = 10
top_indices = np.argsort(probs[0])[::-1][:top_k]
for i, idx in enumerate(top_indices):
    print(f"{i+1}. {labels[idx]['common']}: {probs[0][idx]:.3f}")
```
### TFLite

```python
import numpy as np
import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path="exports/geomodel_fp16.tflite")
interpreter.allocate_tensors()
input_detail = interpreter.get_input_details()[0]
output_detail = interpreter.get_output_details()[0]

inputs = np.array([[52.5, 13.4, 22.0]], dtype=np.float32)
interpreter.set_tensor(input_detail["index"], inputs)
interpreter.invoke()
probs = interpreter.get_tensor(output_detail["index"])  # (1, n_species)

labels = load_labels()
top_k = 10
top_indices = np.argsort(probs[0])[::-1][:top_k]
for i, idx in enumerate(top_indices):
    print(f"{i+1}. {labels[idx]['common']}: {probs[0][idx]:.3f}")
```
**Batch inference**

All formats support batched inputs. Stack multiple `[lat, lon, week]` rows into a single `(N, 3)` array to predict many locations at once.
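Building such a batch is plain array stacking; the helper name below is illustrative:

```python
import numpy as np


def make_batch(points):
    """Stack (lat, lon, week) tuples into a float32 (N, 3) array,
    the input shape shared by all exported formats."""
    return np.asarray(points, dtype=np.float32)


batch = make_batch([
    (52.5, 13.4, 22.0),   # Berlin, week 22
    (40.7, -74.0, 22.0),  # New York, week 22
])
# Feed `batch` to any exported model, e.g. with ONNX Runtime:
# probs = session.run(None, {"input": batch})[0]  # (2, n_species)
```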