def get_args(argv: "list[str] | None" = None) -> argparse.Namespace:
    """Parse command-line arguments for training.

    Args:
        argv: Argument strings to parse (without the program name). When
            ``None`` (the default), argparse reads ``sys.argv[1:]``, which is
            the behavior of a plain ``get_args()`` call. Passing an explicit
            list lets tests and other callers drive the parser without
            touching ``sys.argv``.

    Returns:
        The populated ``argparse.Namespace``; attribute names match the
        option names below without the leading ``--``.

    Raises:
        SystemExit: Raised by argparse when ``--data_path_train`` is missing,
            a value fails type conversion, a value is outside the declared
            ``choices``, or ``-h``/``--help`` is requested.
    """
    parser = argparse.ArgumentParser(description="Train STM32N6 audio classifier")

    # Dataset selection and sampling.
    parser.add_argument("--data_path_train", type=str, required=True, help="Path to train dataset")
    parser.add_argument("--max_samples", type=int, default=None, help="Max samples per class")
    parser.add_argument("--upsample_ratio", type=float, default=0.5, help="Upsample ratio for minority classes")

    # Audio input and spectrogram parameters.
    parser.add_argument("--sample_rate", type=int, default=24000, help="Audio sample rate (Hz)")
    parser.add_argument("--num_mels", type=int, default=64, help="Number of mel bins")
    parser.add_argument("--spec_width", type=int, default=256, help="Spectrogram width (frames)")
    parser.add_argument("--fft_length", type=int, default=512, help="FFT length")
    # The default is written as a float: argparse applies ``type`` only to
    # string defaults, so an int default would leak through as an int.
    parser.add_argument("--chunk_duration", type=float, default=3.0, help="Audio chunk duration (seconds)")
    parser.add_argument("--max_duration", type=int, default=30, help="Max audio duration (seconds)")
    parser.add_argument(
        "--audio_frontend",
        type=str,
        default="hybrid",
        choices=["precomputed", "hybrid", "raw", "librosa", "tf", "mfcc", "log_mel"],
    )
    parser.add_argument("--mag_scale", type=str, default="pwl", choices=["pcen", "pwl", "db", "none"])
    parser.add_argument("--n_mfcc", type=int, default=20, help="Number of MFCC coefficients (mfcc frontend only)")

    # Model size options.
    parser.add_argument("--embeddings_size", type=int, default=256, help="Embeddings layer size")
    parser.add_argument("--alpha", type=float, default=1.0, help="Width multiplier")
    parser.add_argument("--depth_multiplier", type=int, default=1, help="Depth multiplier")
    parser.add_argument("--frontend_trainable", action="store_true", default=False)

    # Optimization, regularization and loss.
    parser.add_argument("--mixup_alpha", type=float, default=0.2, help="Mixup alpha")
    parser.add_argument("--mixup_probability", type=float, default=0.25, help="Mixup batch fraction")
    parser.add_argument("--batch_size", type=int, default=32, help="Batch size")
    parser.add_argument("--epochs", type=int, default=50, help="Number of epochs")
    parser.add_argument("--learning_rate", type=float, default=0.001, help="Initial learning rate")
    parser.add_argument("--dropout", type=float, default=0.5, help="Dropout rate before classifier head")
    parser.add_argument("--optimizer", type=str, default="adam", choices=["adam", "sgd", "adamw"], help="Optimizer")
    parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay (adamw only)")
    parser.add_argument(
        "--loss",
        type=str,
        default="auto",
        choices=["auto", "focal"],
        help="Loss function. 'auto' selects based on mixup; 'focal' uses focal loss.",
    )
    parser.add_argument("--focal_gamma", type=float, default=2.0, help="Focal loss gamma (focusing parameter)")
    parser.add_argument("--val_split", type=float, default=0.2, help="Validation split ratio")
    parser.add_argument("--checkpoint_path", type=str, default="checkpoints/best_model.keras")
    parser.add_argument("--label_smoothing", type=float, default=0.0, help="Label smoothing factor (0 = off)")

    # Optional architecture variants.
    parser.add_argument("--use_se", action="store_true", default=False, help="Add SE channel attention to each block")
    parser.add_argument("--se_reduction", type=int, default=4, help="SE channel reduction factor")
    parser.add_argument(
        "--use_inverted_residual", action="store_true", default=False, help="Use inverted residual blocks"
    )
    parser.add_argument("--expansion_factor", type=int, default=6, help="Expansion factor for inverted residuals")
    parser.add_argument(
        "--use_attention_pooling", action="store_true", default=False, help="Use attention pooling instead of GAP"
    )

    # Augmentation, gradient clipping and class weighting.
    parser.add_argument("--spec_augment", action="store_true", default=False, help="Enable SpecAugment")
    parser.add_argument("--freq_mask_max", type=int, default=8, help="Max frequency mask width (bins)")
    parser.add_argument("--time_mask_max", type=int, default=25, help="Max time mask width (frames)")
    parser.add_argument("--grad_clip", type=float, default=0.0, help="Max gradient norm for clipping (0 = disabled)")
    parser.add_argument(
        "--class_weights",
        type=str,
        default="none",
        choices=["none", "balanced"],
        help="Class weighting strategy ('none' or 'balanced' inverse-frequency)",
    )

    # Run-mode switches.
    parser.add_argument(
        "--mixed_precision", action="store_true", default=False, help="Enable FP16 mixed precision training"
    )
    parser.add_argument("--resume", action="store_true", default=False, help="Resume training from checkpoint")
    parser.add_argument("--deterministic", action="store_true", default=False, help="Enable deterministic mode")
    parser.add_argument("--seed", type=int, default=42, help="Random seed (used with --deterministic)")
    parser.add_argument(
        "--tune", action="store_true", default=False, help="Run Optuna hyperparameter search instead of single training"
    )
    parser.add_argument("--n_trials", type=int, default=20, help="Number of Optuna trials (used with --tune)")
    parser.add_argument(
        "--qat",
        action="store_true",
        default=False,
        help="Quantization-aware fine-tuning (requires pretrained --checkpoint_path)",
    )

    # ``argv=None`` makes argparse fall back to ``sys.argv[1:]``.
    return parser.parse_args(argv)