

birdnet_stm32.models

Model architectures, audio frontend layer, magnitude scaling, and inference runners.

Use build_model() to create a model by name:

from birdnet_stm32.models import build_model
model = build_model("dscnn", num_mels=64, spec_width=256, ...)

register_model(name)

Decorator to register a model builder function.

The decorated function must accept keyword arguments and return an uncompiled tf.keras.Model.

Parameters:

Name  Type  Description                           Default
name  str   Canonical model name (e.g. "dscnn").  required

Raises:

Type        Description
ValueError  If a model with the given name is already registered.
Source code in birdnet_stm32/models/__init__.py
def register_model(name: str):
    """Decorator to register a model builder function.

    The decorated function must accept keyword arguments and return an
    uncompiled ``tf.keras.Model``.

    Args:
        name: Canonical model name (e.g. "dscnn").
    """

    def decorator(fn: Callable[..., tf.keras.Model]) -> Callable[..., tf.keras.Model]:
        if name in _MODEL_REGISTRY:
            raise ValueError(f"Model '{name}' is already registered.")
        _MODEL_REGISTRY[name] = fn
        return fn

    return decorator

build_model(name, **kwargs)

Build a model by registered name.

Parameters:

Name      Type  Description                              Default
name      str   Model architecture name (e.g. "dscnn").  required
**kwargs  Any   Forwarded to the model builder.          {}

Returns:

Type            Description
tf.keras.Model  Uncompiled Keras model.

Raises:

Type      Description
KeyError  If no model with the given name is registered.

Source code in birdnet_stm32/models/__init__.py
def build_model(name: str, **kwargs: Any) -> tf.keras.Model:
    """Build a model by registered name.

    Args:
        name: Model architecture name (e.g. "dscnn").
        **kwargs: Forwarded to the model builder.

    Returns:
        Uncompiled Keras model.

    Raises:
        KeyError: If no model with the given name is registered.
    """
    if name not in _MODEL_REGISTRY:
        raise KeyError(f"Unknown model: '{name}'. Available: {list_models()}")
    return _MODEL_REGISTRY[name](**kwargs)
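The lookup-and-forward behaviour can be sketched in isolation. This is a hypothetical stand-in: the registry is pre-filled with two fake builders that echo their keyword arguments as a dict, whereas the real builders return uncompiled tf.keras.Model instances. The "resnet" name is assumed to be unregistered purely for illustration.

```python
from typing import Any, Callable

# Hypothetical stand-in registry with two fake builders that return dicts
# (the package's builders return uncompiled tf.keras.Model instances).
_MODEL_REGISTRY: dict[str, Callable[..., Any]] = {
    "dscnn": lambda **kw: {"arch": "dscnn", **kw},
    "tiny_cnn": lambda **kw: {"arch": "tiny_cnn", **kw},
}


def list_models() -> list[str]:
    return sorted(_MODEL_REGISTRY.keys())


def build_model(name: str, **kwargs: Any) -> Any:
    if name not in _MODEL_REGISTRY:
        raise KeyError(f"Unknown model: '{name}'. Available: {list_models()}")
    # All keyword arguments are forwarded unchanged to the builder.
    return _MODEL_REGISTRY[name](**kwargs)


model = build_model("dscnn", num_mels=64, spec_width=256)

try:
    build_model("resnet")  # not registered in this sketch
except KeyError as err:
    # err.args[0] is the message as constructed; str(err) would add quotes.
    unknown_error = err.args[0]
```

Reading err.args[0] rather than str(err) matters here because str() of a KeyError wraps its message in an extra pair of quotes.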

list_models()

Return all registered model names, sorted alphabetically.

Source code in birdnet_stm32/models/__init__.py
def list_models() -> list[str]:
    """Return all registered model names."""
    return sorted(_MODEL_REGISTRY.keys())
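Since list_models() sorts the registry keys, its output does not depend on the order in which modules registered their builders. A minimal sketch, assuming a hypothetical registry whose insertion order is deliberately not alphabetical, that uses the sorted names to build every registered model with shared keyword arguments, for example as a smoke test. The fake builders return dicts, not Keras models.

```python
from typing import Any, Callable

# Hypothetical registry; insertion order is deliberately not alphabetical.
_MODEL_REGISTRY: dict[str, Callable[..., Any]] = {
    "tiny_cnn": lambda **kw: {"arch": "tiny_cnn", **kw},
    "dscnn": lambda **kw: {"arch": "dscnn", **kw},
}


def list_models() -> list[str]:
    """Return all registered model names (same body as the listing above)."""
    return sorted(_MODEL_REGISTRY.keys())


# Build every registered model with shared kwargs in a deterministic order.
# In the package, build_model(name, **shared_kwargs) performs this same lookup.
shared_kwargs = {"num_mels": 64, "spec_width": 256}
built = {name: _MODEL_REGISTRY[name](**shared_kwargs) for name in list_models()}
```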