model.loss¶
Loss functions for multi-task learning.
Species prediction losses:
- Asymmetric Loss (ASL) for multi-label classification (default)
- BCE with logits
- Focal loss
- Assume-Negative (AN) loss for presence-only data with negative sampling
Environmental prediction: mean squared error (auxiliary task).
ASL (Ridnik et al., 2021) uses separate focusing parameters for positive and negative terms. A hard-thresholding mechanism (probability margin m) shifts the negative probability down before computing the loss, effectively discarding very easy negatives. This is especially effective for species occurrence data, where >99% of labels are 0.
The AN loss implements the "Full Location-Aware Assume Negative" (LAN-full) strategy from Cole et al. (SINR, 2023). It combines:

- Community pseudo-negatives (SLDS): at each observed location, all species not in the observation list are treated as absent.
- Spatial pseudo-negatives (SSDL): for each observed species, a random location from the batch is sampled where it is assumed absent.
Positive samples are up-weighted by λ to compensate for the overwhelming majority of pseudo-negative labels. For computational efficiency with large species vocabularies, only a random subset of M negative species is evaluated per sample.
Classes¶
AssumeNegativeLoss¶
Bases: Module
Assume-Negative loss for presence-only species occurrence data.
Implements the LAN-full strategy: for each sample the loss is computed on the observed positive species (up-weighted by λ) plus a random subset of M "assumed-negative" species. This avoids the O(n_species) per-sample cost when the vocabulary is large (10K+).
The loss is normalised per-species (dividing by n_species) so that its gradient magnitude is comparable to standard BCE regardless of how many species are in the vocabulary or how few positives each sample has. λ controls the relative importance of positives vs negatives; the absolute scale matches BCE.

When negative sampling is active (M > 0), the sampled negative contribution is scaled up by n_neg / n_sampled_neg per sample to approximate the full-vocabulary expectation.
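As a rough illustration of the arithmetic described above, here is a minimal NumPy sketch of the per-sample computation (the function name and explicit loop are hypothetical; the actual module operates on PyTorch tensors and is vectorised):

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def assume_negative_loss(logits, targets, pos_lambda=4.0, neg_samples=1024, rng=None):
    """Sketch of the LAN-full loss: positives up-weighted by pos_lambda,
    a random subset of neg_samples negatives rescaled by n_neg / n_sampled,
    everything normalised by n_species and averaged over the batch."""
    if rng is None:
        rng = np.random.default_rng(0)
    batch, n_species = logits.shape
    p = sigmoid(logits)
    total = 0.0
    for i in range(batch):
        pos = np.flatnonzero(targets[i] == 1)
        neg = np.flatnonzero(targets[i] == 0)
        pos_term = -pos_lambda * np.sum(np.log(p[i, pos]))
        if 0 < neg_samples < len(neg):
            sampled = rng.choice(neg, size=neg_samples, replace=False)
            scale = len(neg) / neg_samples   # rescale sampled sum to full-vocab expectation
        else:
            sampled, scale = neg, 1.0        # neg_samples=0: exact, all negatives
        neg_term = -scale * np.sum(np.log(1.0 - p[i, sampled]))
        total += (pos_term + neg_term) / n_species   # per-species normalisation
    return total / batch
```

With uniform probabilities (all logits zero) the sampled estimate matches the exact loss, which makes the rescaling easy to sanity-check.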
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `pos_lambda` | `float` | Up-weighting factor for positive samples. | `4.0` |
| `neg_samples` | `int` | Number of negative species to sample per example (M). Use 0 to include all negatives (exact but slow for large vocabularies). | `1024` |
| `label_smoothing` | `float` | Smooth binary targets to prevent overconfident predictions. Positive targets become … | `0.0` |
Functions¶
forward(logits, targets)¶
Compute the assume-negative loss.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `logits` | `Tensor` | `(batch, n_species)` raw logits. | required |
| `targets` | `Tensor` | `(batch, n_species)` binary labels (1 = observed, 0 = assumed absent). | required |

Returns:

| Type | Description |
|---|---|
| `Tensor` | Scalar loss. |
MultiTaskLoss¶
Bases: Module
Weighted multi-task loss: species (BCE, ASL, focal, or AN) + environmental (MSE).
Total = species_weight × species_loss + env_weight × env_loss [+ habitat_weight × habitat_species_loss]
When the habitat-species head is enabled, an auxiliary species loss is computed directly on the habitat head's logits (before gating). This gives the habitat head a full-strength learning signal independent of the gate value, which is critical because the gate initially suppresses the habitat contribution (σ(−3) ≈ 0.05).
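The weighted total above is a plain scalar combination; a minimal sketch (the helper name is hypothetical, and the real class computes each term from logits and targets):

```python
def multitask_total(species_loss, env_loss, habitat_loss=None,
                    species_weight=1.0, env_weight=0.5, habitat_weight=0.0):
    """Weighted sum of the per-task losses. The habitat term is only
    added when the habitat-species head is enabled (weight > 0)."""
    total = species_weight * species_loss + env_weight * env_loss
    if habitat_loss is not None and habitat_weight > 0:
        total += habitat_weight * habitat_loss
    return total
```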
Functions¶
__init__(species_weight=1.0, env_weight=0.5, habitat_weight=0.0, pos_weight=None, species_loss='bce', focal_alpha=0.5, focal_gamma=2.0, pos_lambda=4.0, neg_samples=1024, label_smoothing=0.0, asl_gamma_pos=0.0, asl_gamma_neg=2.0, asl_clip=0.05, reduction='mean')¶
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `species_weight` | `float` | Multiplier for species loss. | `1.0` |
| `env_weight` | `float` | Multiplier for environmental loss. | `0.5` |
| `habitat_weight` | `float` | Multiplier for auxiliary habitat-species loss (applied to habitat head logits before gating). Only used when the model returns … | `0.0` |
| `pos_weight` | `Optional[Tensor]` | Positive-class weights for BCE mode (ignored for focal/an/asl). | `None` |
| `species_loss` | `str` | `'bce'` (default), `'asl'` (asymmetric), `'focal'`, or `'an'`. | `'bce'` |
| `focal_alpha` | `float` | Alpha for focal loss (default 0.5 = neutral). | `0.5` |
| `focal_gamma` | `float` | Gamma for focal loss. | `2.0` |
| `pos_lambda` | `float` | λ for assume-negative loss (positive up-weighting, default 4). | `4.0` |
| `neg_samples` | `int` | M for assume-negative loss (number of negative species to sample). | `1024` |
| `label_smoothing` | `float` | Smooth binary targets (AN loss only, 0 = off). | `0.0` |
| `asl_gamma_pos` | `float` | ASL focusing parameter for positive species (default 0). | `0.0` |
| `asl_gamma_neg` | `float` | ASL focusing parameter for negative species (default 2). | `2.0` |
| `asl_clip` | `float` | ASL probability margin for negatives (default 0.05). | `0.05` |
forward(predictions, targets, compute_env_loss=True)¶
Compute weighted multi-task loss.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `predictions` | `Dict[str, Tensor]` | Dict with … | required |
| `targets` | `Dict[str, Tensor]` | Dict with … | required |
| `compute_env_loss` | `bool` | Whether to include the environmental MSE term. | `True` |

Returns:

| Type | Description |
|---|---|
| `Dict[str, Tensor]` | Dict with … |
Functions¶
focal_loss(logits, targets, alpha=0.25, gamma=2.0, reduction='mean')¶
Focal loss for multi-label classification.
Down-weights easy negatives and up-weights hard positives, which is critical for species occurrence data where >99% of labels are 0.
Reference: Lin et al., "Focal Loss for Dense Object Detection" (2017)
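The down-weighting mechanism can be sketched per element in plain Python (a scalar sketch of the standard focal-loss formula, not the module's vectorised implementation):

```python
import math

def focal_loss_elem(logit, target, alpha=0.25, gamma=2.0):
    """Per-element focal loss (Lin et al., 2017).

    p_t is the probability assigned to the true class; the modulating
    factor (1 - p_t)**gamma shrinks the loss of easy, well-classified
    elements towards zero, leaving hard elements dominant."""
    p = 1.0 / (1.0 + math.exp(-logit))
    p_t = p if target == 1 else 1.0 - p
    alpha_t = alpha if target == 1 else 1.0 - alpha
    return -alpha_t * (1.0 - p_t) ** gamma * math.log(p_t)
```

With gamma=0 and alpha=0.5 this reduces to 0.5 × binary cross-entropy, which is why alpha=0.5 is described as neutral above.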
asymmetric_loss(logits, targets, gamma_pos=0.0, gamma_neg=2.0, clip=0.05, reduction='mean')¶
Asymmetric Loss for multi-label classification.
Applies separate focusing parameters for positive and negative terms, plus a probability-margin shift on negatives. This combination aggressively down-weights easy/confident negatives while leaving the positive gradient intact, making it ideal for extreme class imbalance.
Reference: Ridnik et al., "Asymmetric Loss For Multi-Label Classification" (ICCV 2021).
The per-element loss is:

```
L = -[ y · (1 - p)^γ₊ · log(p)
     + (1 - y) · p_m^γ₋ · log(1 - p_m) ]
```

where p_m = max(p − m, 0) is the margin-shifted probability.
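A scalar sketch of this formula (hypothetical helper name; the module version is vectorised over tensors):

```python
import math

def asl_elem(logit, target, gamma_pos=0.0, gamma_neg=2.0, clip=0.05):
    """Per-element Asymmetric Loss. Negatives whose probability falls
    below the margin m (= clip) have p_m = 0 and contribute exactly
    zero loss, i.e. very easy negatives are discarded."""
    p = 1.0 / (1.0 + math.exp(-logit))
    if target == 1:
        return -((1.0 - p) ** gamma_pos) * math.log(p)
    p_m = max(p - clip, 0.0)                 # margin-shifted probability
    return -(p_m ** gamma_neg) * math.log(max(1.0 - p_m, 1e-8))
```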
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `logits` | `Tensor` | | required |
| `targets` | `Tensor` | | required |
| `gamma_pos` | `float` | Focusing parameter for positive (present) species. 0 = no down-weighting of hard positives (recommended). | `0.0` |
| `gamma_neg` | `float` | Focusing parameter for negative (absent) species. Higher values more aggressively suppress easy negatives. | `2.0` |
| `clip` | `float` | Probability margin m. Shifts the negative probability down before loss computation, effectively ignoring very easy negatives. Set 0 to disable. | `0.05` |
| `reduction` | `str` | | `'mean'` |
masked_mse(pred, target)¶
Mean squared error that ignores NaN positions in target.
Environmental feature targets may contain NaN where data was missing. This function computes the MSE only over valid (non-NaN) elements so the model is not penalised for predicting placeholder values.
Predictions are clamped to [-1e4, 1e4] to prevent FP16 overflow from turning into inf² → NaN under AMP.
Returns zero if there are no valid elements in the batch.
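The masking logic is straightforward; a minimal NumPy sketch of the behaviour described above (the clamp bound is the one stated in the text):

```python
import numpy as np

def masked_mse(pred, target, clamp=1e4):
    """MSE over non-NaN positions of target; 0.0 if none are valid."""
    pred = np.clip(pred, -clamp, clamp)   # guard against FP16 overflow under AMP
    valid = ~np.isnan(target)
    if not valid.any():
        return 0.0                        # no valid elements in the batch
    diff = pred[valid] - target[valid]
    return float(np.mean(diff ** 2))
```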
compute_pos_weights(species_targets, smoothing=1.0)¶
Compute positive-class weights for BCE mode (neg/pos ratio with smoothing).
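A NumPy sketch of the smoothed neg/pos ratio (assuming a `(batch, n_species)` binary target matrix; the exact smoothing formula is an assumption consistent with the description):

```python
import numpy as np

def compute_pos_weights(species_targets, smoothing=1.0):
    """Per-species weight (n_neg + s) / (n_pos + s). The smoothing term
    keeps the weight finite for species with no positives in the batch."""
    pos = species_targets.sum(axis=0)             # positives per species
    neg = species_targets.shape[0] - pos          # negatives per species
    return (neg + smoothing) / (pos + smoothing)
```

These weights would typically be passed as `pos_weight` to the BCE species loss so that rare species are not drowned out by the majority of zeros.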