def get_args(argv: "list[str] | None" = None) -> argparse.Namespace:
    """Parse command-line arguments for training.

    Sensible defaults are chosen so that most users only need to specify
    ``--data_path_train``. Architecture and training features that improve
    accuracy are **on by default**:

    * SE attention, inverted residuals, SpecAugment and balanced class
      weights can be disabled with the matching ``--no_*`` flag.
    * Gradient clipping (``--grad_clip``) and label smoothing
      (``--label_smoothing``) are disabled by setting their value to ``0``.
    * Deterministic seeding is always enabled; ``--seed`` picks the seed.

    Args:
        argv: Argument list to parse (without the program name). ``None``
            (the default) parses ``sys.argv[1:]``, as before; passing an
            explicit list allows programmatic use and testing.

    Returns:
        The parsed namespace. In addition to one attribute per option, it
        carries the derived fields ``use_se``, ``use_inverted_residual``,
        ``spec_augment``, ``class_weights`` (``"balanced"`` or ``"none"``)
        and ``deterministic`` (always ``True``).
    """
    parser = argparse.ArgumentParser(description="Train STM32N6 audio classifier")

    # -- Data -----------------------------------------------------------------
    parser.add_argument("--data_path_train", type=str, required=True, help="Path to train dataset")
    parser.add_argument("--max_classes", type=int, default=None, help="Use top N classes by sample count")
    parser.add_argument("--max_samples", type=int, default=None, help="Max samples per class")
    parser.add_argument("--upsample_ratio", type=float, default=0.5, help="Upsample ratio for minority classes")

    # -- Audio ----------------------------------------------------------------
    parser.add_argument("--sample_rate", type=int, default=24000, help="Audio sample rate (Hz)")
    parser.add_argument("--num_mels", type=int, default=64, help="Number of mel bins")
    parser.add_argument("--spec_width", type=int, default=256, help="Spectrogram width (frames)")
    parser.add_argument("--fft_length", type=int, default=512, help="FFT length")
    # argparse does not apply ``type`` to non-string defaults, so the default
    # must already be a float to keep the attribute's type consistent.
    parser.add_argument("--chunk_duration", type=float, default=3.0, help="Audio chunk duration (seconds)")
    parser.add_argument(
        "--max_duration",
        type=int,
        default=60,
        help=(
            "Maximum seconds to read per file. The loader still reads only the bytes it needs "
            "for the candidate chunks (smart-crop bounded by --max_chunks_per_file)."
        ),
    )
    parser.add_argument(
        "--audio_frontend",
        type=str,
        default="hybrid",
        choices=["hybrid", "raw", "librosa", "mfcc", "log_mel"],
        help="Audio frontend mode",
    )
    parser.add_argument(
        "--mag_scale",
        type=str,
        default="pwl",
        choices=["pcen", "pwl", "db", "none"],
        help="Magnitude scaling mode",
    )
    parser.add_argument("--n_mfcc", type=int, default=20, help="Number of MFCC coefficients (mfcc frontend only)")

    # -- Model architecture ---------------------------------------------------
    parser.add_argument("--embeddings_size", type=int, default=256, help="Embeddings layer size")
    parser.add_argument("--alpha", type=float, default=1.0, help="Width multiplier")
    parser.add_argument("--depth_multiplier", type=int, default=1, help="Depth multiplier")
    parser.add_argument(
        "--frontend_trainable", action="store_true", default=False, help="Make the audio frontend trainable"
    )
    parser.add_argument("--no_se", action="store_true", default=False, help="Disable SE channel attention")
    parser.add_argument("--se_reduction", type=int, default=8, help="SE channel reduction factor")
    parser.add_argument("--no_inverted_residual", action="store_true", default=False, help="Use plain DS blocks")
    parser.add_argument("--expansion_factor", type=int, default=2, help="Expansion factor for inverted residuals")
    parser.add_argument(
        "--use_attention_pooling", action="store_true", default=False, help="Use attention pooling instead of GAP"
    )

    # -- Augmentation ---------------------------------------------------------
    parser.add_argument("--no_spec_augment", action="store_true", default=False, help="Disable SpecAugment")
    parser.add_argument("--freq_mask_max", type=int, default=8, help="Max frequency mask width (bins)")
    parser.add_argument("--time_mask_max", type=int, default=25, help="Max time mask width (frames)")
    parser.add_argument("--mixup_alpha", type=float, default=0.2, help="Mixup alpha")
    parser.add_argument("--mixup_probability", type=float, default=0.25, help="Mixup batch fraction")

    # -- Training -------------------------------------------------------------
    parser.add_argument("--batch_size", type=int, default=32, help="Batch size")
    parser.add_argument("--num_workers", type=int, default=8, help="Parallel data loading workers (0 = sequential)")
    parser.add_argument(
        "--max_chunks_per_file",
        type=int,
        default=3,
        help="Max salient chunks to extract per file open (reduces redundant I/O for long recordings)",
    )
    parser.add_argument(
        "--prefetch_batches",
        type=int,
        default=2,
        help="Loader prefetch queue depth in batches (higher = faster, but more RAM)",
    )
    parser.add_argument("--epochs", type=int, default=50, help="Number of epochs")
    parser.add_argument("--learning_rate", type=float, default=0.001, help="Initial learning rate")
    parser.add_argument("--dropout", type=float, default=0.5, help="Dropout rate before classifier head")
    parser.add_argument("--optimizer", type=str, default="adam", choices=["adam", "sgd", "adamw"], help="Optimizer")
    parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay (adamw only)")
    parser.add_argument(
        "--loss",
        type=str,
        default="auto",
        choices=["auto", "focal"],
        help="Loss function. 'auto' selects based on mixup; 'focal' uses focal loss.",
    )
    parser.add_argument("--focal_gamma", type=float, default=2.0, help="Focal loss gamma (focusing parameter)")
    parser.add_argument("--val_split", type=float, default=0.2, help="Validation split ratio")
    parser.add_argument(
        "--checkpoint_path", type=str, default="checkpoints/best_model.keras", help="Output checkpoint path (.keras)"
    )
    parser.add_argument("--label_smoothing", type=float, default=0.1, help="Label smoothing factor (0 = off)")
    parser.add_argument("--grad_clip", type=float, default=1.0, help="Max gradient norm for clipping (0 = disabled)")
    parser.add_argument(
        "--no_class_weights",
        action="store_true",
        default=False,
        help="Disable balanced inverse-frequency class weighting",
    )
    parser.add_argument(
        "--mixed_precision", action="store_true", default=False, help="Enable FP16 mixed precision training"
    )
    parser.add_argument("--resume", action="store_true", default=False, help="Resume training from checkpoint")
    parser.add_argument("--seed", type=int, default=42, help="Random seed for deterministic training")

    # -- Tuning & QAT --------------------------------------------------------
    parser.add_argument(
        "--tune", action="store_true", default=False, help="Run Optuna hyperparameter search instead of single training"
    )
    parser.add_argument("--n_trials", type=int, default=20, help="Number of Optuna trials (used with --tune)")
    parser.add_argument(
        "--qat",
        action="store_true",
        default=False,
        help="Quantization-aware fine-tuning (requires pretrained --checkpoint_path)",
    )

    # -- Linear probing -------------------------------------------------------
    parser.add_argument(
        "--linear_probe",
        action="store_true",
        default=False,
        help="Freeze backbone and train only the classifier head (requires pretrained --checkpoint_path)",
    )

    # ``None`` makes argparse fall back to sys.argv[1:], preserving the
    # original zero-argument behaviour.
    args = parser.parse_args(argv)

    # Derive positive flags from the --no_* opt-out flags so downstream code
    # can read them without double negatives.
    args.use_se = not args.no_se
    args.use_inverted_residual = not args.no_inverted_residual
    args.spec_augment = not args.no_spec_augment
    args.class_weights = "none" if args.no_class_weights else "balanced"
    args.deterministic = True  # always deterministic; there is no flag to disable it
    return args