model.loss

Loss functions for multi-task learning.

Species prediction
  • Asymmetric Loss (ASL) for multi-label classification (default)
  • BCE with logits
  • Focal loss
  • Assume-Negative (AN) loss for presence-only data with negative sampling

Environmental prediction: mean squared error (auxiliary task).

ASL (Ridnik et al., 2021) uses separate focusing parameters for positive and negative terms. A hard-thresholding mechanism (probability margin m) shifts the negative probability down before computing the loss, effectively discarding very easy negatives. This is especially effective for species occurrence data where >99% of labels are 0.

The AN loss implements the "Full Location-Aware Assume Negative" (LAN-full) strategy from Cole et al. (SINR, 2023). It combines:
  • Community pseudo-negatives (SLDS): at each observed location, all species not in the observation list are treated as absent.
  • Spatial pseudo-negatives (SSDL): for each observed species, a random location from the batch is sampled where it is assumed absent.

Positive samples are up-weighted by λ to compensate for the overwhelming majority of pseudo-negative labels. For computational efficiency with large species vocabularies, only a random subset of M negative species is evaluated per sample.

Classes

AssumeNegativeLoss

Bases: Module

Assume-Negative loss for presence-only species occurrence data.

Implements the LAN-full strategy: for each sample the loss is computed on the observed positive species (up-weighted by λ) plus a random subset of M "assumed-negative" species. This avoids the O(n_species) per-sample cost when the vocabulary is large (10K+).

The loss is normalised per-species (dividing by n_species) so that its gradient magnitude is comparable to standard BCE regardless of how many species are in the vocabulary or how few positives each sample has. λ controls the relative importance of positives vs negatives; the absolute scale matches BCE.

When negative sampling is active (M > 0), the sampled negative contribution is scaled up by n_neg / n_sampled_neg per sample to approximate the full-vocabulary expectation.
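The rescaling above can be illustrated with a single-sample scalar sketch. Everything below (`bce_logit`, `an_loss`) is a hypothetical helper written for this illustration, not the module's tensor implementation:

```python
import math
import random

def bce_logit(logit, y):
    # Numerically stable binary cross-entropy with logits.
    return max(logit, 0.0) - logit * y + math.log1p(math.exp(-abs(logit)))

def an_loss(logits, targets, pos_lambda=4.0, neg_samples=2):
    """Assume-negative loss for one sample, with negative subsampling.

    Illustrative sketch only: positives are up-weighted by pos_lambda,
    a subset of M negatives is evaluated and rescaled by
    n_neg / n_sampled_neg, and the result is normalised per species.
    """
    n_species = len(logits)
    pos_idx = [i for i, y in enumerate(targets) if y == 1]
    neg_idx = [i for i, y in enumerate(targets) if y == 0]
    sampled = random.sample(neg_idx, min(neg_samples, len(neg_idx)))
    pos = sum(pos_lambda * bce_logit(logits[i], 1.0) for i in pos_idx)
    neg = sum(bce_logit(logits[i], 0.0) for i in sampled)
    # Scale sampled negatives toward the full-vocabulary expectation.
    neg *= len(neg_idx) / max(len(sampled), 1)
    return (pos + neg) / n_species
```

Because the sampled-negative term is rescaled, sampling M = 1 of two identical negatives yields the same value as evaluating both.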

Parameters:

  pos_lambda (float, default 4.0)
    Up-weighting factor for positive samples.

  neg_samples (int, default 1024)
    Number of negative species to sample per example (M). Use 0 to include all negatives (exact but slow for large vocabs).

  label_smoothing (float, default 0.0)
    Smooth binary targets to prevent overconfident predictions. Positive targets become 1 - ε, negatives become ε. Set to 0 to disable.
Functions
forward(logits, targets)

Compute the assume-negative loss.

Parameters:

  logits (Tensor, required)
    (batch, n_species) raw logits.

  targets (Tensor, required)
    (batch, n_species) binary labels (1 = observed, 0 = assumed absent).

Returns:

  Tensor
    Scalar loss.

MultiTaskLoss

Bases: Module

Weighted multi-task loss: species (BCE, ASL, focal, or AN) + environmental (MSE).

Total = species_weight × species_loss + env_weight × env_loss [+ habitat_weight × habitat_species_loss]

When the habitat-species head is enabled, an auxiliary species loss is computed directly on the habitat head's logits (before gating). This gives the habitat head a full-strength learning signal independent of the gate value, which is critical because the gate initially suppresses the habitat contribution (σ(-3) ≈ 0.05).
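The weighted combination and the initial gate value can be checked with a scalar sketch. The gate bias of -3 is an assumption for illustration (any bias giving σ(bias) ≈ 0.05 behaves the same); `total_loss` is a hypothetical helper, not the module's code:

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

# With an assumed gate bias of -3, the gated habitat contribution
# starts near zero — hence the ungated auxiliary loss.
initial_gate = sigmoid(-3.0)  # ≈ 0.047

def total_loss(species, env, habitat,
               species_weight=1.0, env_weight=0.5, habitat_weight=0.5):
    """Weighted multi-task total from the formula above (scalar sketch)."""
    return (species_weight * species
            + env_weight * env
            + habitat_weight * habitat)
```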

Functions
__init__(species_weight=1.0, env_weight=0.5, habitat_weight=0.0, pos_weight=None, species_loss='bce', focal_alpha=0.5, focal_gamma=2.0, pos_lambda=4.0, neg_samples=1024, label_smoothing=0.0, asl_gamma_pos=0.0, asl_gamma_neg=2.0, asl_clip=0.05, reduction='mean')

Parameters:

  species_weight (float, default 1.0)
    Multiplier for species loss.

  env_weight (float, default 0.5)
    Multiplier for environmental loss.

  habitat_weight (float, default 0.0)
    Multiplier for auxiliary habitat-species loss (applied to habitat head logits before gating). Only used when the model returns 'habitat_logits'. Default 0 (disabled); 0.5 is a reasonable starting point when --habitat_head is enabled.

  pos_weight (Optional[Tensor], default None)
    Positive-class weights for BCE mode (ignored for focal/an/asl).

  species_loss (str, default 'bce')
    'bce' (default), 'asl' (asymmetric), 'focal', or 'an'.

  focal_alpha (float, default 0.5)
    Alpha for focal loss (0.5 = neutral).

  focal_gamma (float, default 2.0)
    Gamma for focal loss.

  pos_lambda (float, default 4.0)
    λ for assume-negative loss (positive up-weighting).

  neg_samples (int, default 1024)
    M for assume-negative loss (negative species to sample).

  label_smoothing (float, default 0.0)
    Smooth binary targets (AN loss only, 0 = off).

  asl_gamma_pos (float, default 0.0)
    ASL focusing parameter for positive species.

  asl_gamma_neg (float, default 2.0)
    ASL focusing parameter for negative species.

  asl_clip (float, default 0.05)
    ASL probability margin for negatives.
forward(predictions, targets, compute_env_loss=True)

Compute weighted multi-task loss.

Parameters:

  predictions (Dict[str, Tensor], required)
    Dict with 'species_logits' and optionally 'env_pred'.

  targets (Dict[str, Tensor], required)
    Dict with 'species' and 'env_features' tensors.

  compute_env_loss (bool, default True)
    Whether to include the environmental MSE term.

Returns:

  Dict[str, Tensor]
    Dict with 'species', 'env' (if computed), and 'total' losses.

Functions

focal_loss(logits, targets, alpha=0.25, gamma=2.0, reduction='mean')

Focal loss for multi-label classification.

Down-weights easy negatives and up-weights hard positives, which is critical for species occurrence data where >99% of labels are 0.

Reference: Lin et al., "Focal Loss for Dense Object Detection" (2017)
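The per-element focal term can be sketched in scalar form. `focal_element` below is a hypothetical helper for illustration, not the module's tensor implementation:

```python
import math

def focal_element(logit, y, alpha=0.25, gamma=2.0):
    """Per-element focal loss (Lin et al., 2017), scalar sketch.

    p_t is the predicted probability of the true class; the (1 - p_t)^gamma
    factor shrinks the loss for well-classified (easy) elements.
    """
    p = 1.0 / (1.0 + math.exp(-logit))       # sigmoid probability
    p_t = p if y == 1 else 1.0 - p           # prob assigned to true label
    alpha_t = alpha if y == 1 else 1.0 - alpha
    return -alpha_t * (1.0 - p_t) ** gamma * math.log(p_t + 1e-8)
```

An easy negative (confident logit of -4) contributes almost nothing, while a misclassified positive dominates.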

asymmetric_loss(logits, targets, gamma_pos=0.0, gamma_neg=2.0, clip=0.05, reduction='mean')

Asymmetric Loss for multi-label classification.

Applies separate focusing parameters for positive and negative terms, plus a probability-margin shift on negatives. This combination aggressively down-weights easy/confident negatives while leaving the positive gradient intact, making it ideal for extreme class imbalance.

Reference: Ridnik et al., "Asymmetric Loss For Multi-Label Classification" (ICCV 2021).

The per-element loss is:

L = -[ y · (1-p)^γ+ · log(p)
     + (1-y) · p_m^γ- · log(1-p_m) ]

where p_m = max(p - m, 0) is the margin-shifted probability.
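The formula above can be evaluated element-wise in scalar form. `asl_element` is a hypothetical helper written for this illustration, not the module's implementation:

```python
import math

def asl_element(logit, y, gamma_pos=0.0, gamma_neg=2.0, clip=0.05):
    """Per-element Asymmetric Loss (Ridnik et al., 2021), scalar sketch."""
    p = 1.0 / (1.0 + math.exp(-logit))   # sigmoid probability
    p_m = max(p - clip, 0.0)             # margin-shifted negative probability
    eps = 1e-8                           # guard against log(0)
    pos = y * (1.0 - p) ** gamma_pos * math.log(p + eps)
    neg = (1.0 - y) * p_m ** gamma_neg * math.log(1.0 - p_m + eps)
    return -(pos + neg)
```

With the default margin m = 0.05, a very easy negative (p below the margin) contributes exactly zero loss, while hard negatives and positives keep full gradient signal.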

Parameters:

  logits (Tensor, required)
    (batch, n_species) raw logits.

  targets (Tensor, required)
    (batch, n_species) binary labels.

  gamma_pos (float, default 0.0)
    Focusing parameter for positive (present) species. 0 = no down-weighting of hard positives (recommended).

  gamma_neg (float, default 2.0)
    Focusing parameter for negative (absent) species. Higher values more aggressively suppress easy negatives.

  clip (float, default 0.05)
    Probability margin m. Shifts the negative probability down before loss computation, effectively ignoring very easy negatives. Set 0 to disable.

  reduction (str, default 'mean')
    'mean' | 'sum' | 'none'.

masked_mse(pred, target)

Mean squared error that ignores NaN positions in target.

Environmental feature targets may contain NaN where data was missing. This function computes the MSE only over valid (non-NaN) elements so the model is not penalised for predicting placeholder values.

Predictions are clamped to [-1e4, 1e4] to prevent FP16 overflow from turning into inf² → NaN under AMP.

Returns zero if there are no valid elements in the batch.

compute_pos_weights(species_targets, smoothing=1.0)

Compute positive-class weights for BCE mode (neg/pos ratio with smoothing).
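A per-species neg/pos ratio with additive smoothing can be sketched as follows. The exact smoothing form (adding `smoothing` to both counts) is an assumption for illustration; `pos_weights_sketch` is not the module's function:

```python
def pos_weights_sketch(targets, smoothing=1.0):
    """Per-species positive-class weights from a list-of-lists label matrix.

    weight[s] = (n_neg + smoothing) / (n_pos + smoothing), so rare species
    (few positives) get large weights in BCE. Assumed smoothing form.
    """
    n_samples = len(targets)
    n_species = len(targets[0])
    weights = []
    for s in range(n_species):
        n_pos = sum(row[s] for row in targets)
        n_neg = n_samples - n_pos
        weights.append((n_neg + smoothing) / (n_pos + smoothing))
    return weights
```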