Skip to content

tuner

birdnet_stm32.training.tuner

Optuna hyperparameter tuning for DS-CNN training.

run_tuning(args)

Run Optuna hyperparameter search.

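Creates an Optuna study that maximizes `val_roc_auc`. After all trials it does three things. It prints the best hyperparameters. It saves them as `optuna_best_params.json` in the directory that contains the checkpoint file. Finally, if the best trial's checkpoint exists under `optuna_trials/`, it copies that checkpoint to the main checkpoint path.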

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `args` | `Namespace` | CLI arguments, including `n_trials` and `checkpoint_path`. | *required* |
Source code in birdnet_stm32/training/tuner.py
def run_tuning(args: argparse.Namespace) -> None:
    """Run Optuna hyperparameter search.

    Creates an Optuna study that maximizes val_roc_auc and runs
    ``args.n_trials`` trials. Afterwards it:

    * prints the best value, trial number and hyperparameters;
    * saves them as ``optuna_best_params.json`` in the directory that
      contains ``args.checkpoint_path``;
    * copies the best trial's checkpoint
      (``<checkpoint dir>/optuna_trials/trial_<n>.keras``) to
      ``args.checkpoint_path`` if that file exists, otherwise prints a
      warning so a stale checkpoint is not mistaken for the best model.

    Args:
        args: CLI arguments including n_trials and checkpoint_path.

    Raises:
        ValueError: If no trial completed (e.g. every trial was pruned or
            failed), so there is no best trial to report.
    """
    import json
    import shutil

    study = optuna.create_study(
        direction="maximize",
        study_name="birdnet_stm32_tune",
        pruner=optuna.pruners.MedianPruner(n_startup_trials=3, n_warmup_steps=5),
    )

    study.optimize(lambda trial: _objective(trial, args), n_trials=args.n_trials)

    # study.best_trial raises an opaque ValueError when nothing completed;
    # fail early with an actionable message (same exception type).
    completed = study.get_trials(deepcopy=False, states=(optuna.trial.TrialState.COMPLETE,))
    if not completed:
        raise ValueError(
            f"Optuna tuning finished {len(study.trials)} trial(s) but none completed "
            "(all pruned or failed); no best hyperparameters to report."
        )

    # Query the best trial once; each study.best_* access hits the storage.
    best_trial = study.best_trial
    best_value = best_trial.value

    # Report results
    print("\n" + "=" * 60)
    print("Optuna tuning complete")
    print(f"  Best val_roc_auc: {best_value:.4f}")
    print(f"  Best trial: #{best_trial.number}")
    print("  Best hyperparameters:")
    for key, value in best_trial.params.items():
        print(f"    {key}: {value}")
    print("=" * 60)

    # Save best params as JSON next to the checkpoint file. dirname() is ""
    # for a bare filename, in which case the current directory is used.
    ckpt_dir = os.path.dirname(args.checkpoint_path)
    if ckpt_dir:
        os.makedirs(ckpt_dir, exist_ok=True)
    results_path = os.path.join(ckpt_dir, "optuna_best_params.json")
    best_data = {
        "best_value": best_value,
        "best_trial": best_trial.number,
        "best_params": best_trial.params,
        "n_trials": len(study.trials),
    }
    with open(results_path, "w", encoding="utf-8") as f:
        json.dump(best_data, f, indent=2)
    print(f"Best parameters saved to '{results_path}'")

    # Copy best trial checkpoint to main checkpoint path
    best_trial_ckpt = os.path.join(
        ckpt_dir,
        "optuna_trials",
        f"trial_{best_trial.number}.keras",
    )
    if os.path.isfile(best_trial_ckpt):
        shutil.copy2(best_trial_ckpt, args.checkpoint_path)
        print(f"Best model copied to '{args.checkpoint_path}'")
    else:
        # Without this, args.checkpoint_path could silently hold a model
        # from another trial or an earlier run.
        print(
            f"Warning: best trial checkpoint '{best_trial_ckpt}' not found; "
            f"'{args.checkpoint_path}' was not updated."
        )