Add model/focal_loss.py

This commit is contained in:
2026-02-27 11:39:04 +08:00
parent 14ce733d36
commit 5a094286e0

135
model/focal_loss.py Normal file
View File

@@ -0,0 +1,135 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
class FocalLoss(nn.Module):
def __init__(self, gamma=2, alpha=None, reduction='mean', task_type='binary', num_classes=None):
"""
Unified Focal Loss class for binary, multi-class, and multi-label classification tasks.
:param gamma: Focusing parameter, controls the strength of the modulating factor (1 - p_t)^gamma
:param alpha: Balancing factor, can be a scalar or a tensor for class-wise weights. If None, no class balancing is used.
:param reduction: Specifies the reduction method: 'none' | 'mean' | 'sum'
:param task_type: Specifies the type of task: 'binary', 'multi-class', or 'multi-label'
:param num_classes: Number of classes (only required for multi-class classification)
"""
super(FocalLoss, self).__init__()
self.gamma = gamma
self.alpha = alpha
self.reduction = reduction
self.task_type = task_type
self.num_classes = num_classes
# Handle alpha for class balancing in multi-class tasks
if task_type == 'multi-class' and alpha is not None and isinstance(alpha, (list, torch.Tensor)):
assert num_classes is not None, "num_classes must be specified for multi-class classification"
if isinstance(alpha, list):
self.alpha = torch.Tensor(alpha)
else:
self.alpha = alpha
def forward(self, inputs, targets):
"""
Forward pass to compute the Focal Loss based on the specified task type.
:param inputs: Predictions (logits) from the model.
Shape:
- binary/multi-label: (batch_size, num_classes)
- multi-class: (batch_size, num_classes)
:param targets: Ground truth labels.
Shape:
- binary: (batch_size,)
- multi-label: (batch_size, num_classes)
- multi-class: (batch_size,)
"""
if self.task_type == 'binary':
return self.binary_focal_loss(inputs, targets)
elif self.task_type == 'multi-class':
return self.multi_class_focal_loss(inputs, targets)
elif self.task_type == 'multi-label':
return self.multi_label_focal_loss(inputs, targets)
else:
raise ValueError(
f"Unsupported task_type '{self.task_type}'. Use 'binary', 'multi-class', or 'multi-label'.")
def binary_focal_loss(self, inputs, targets):
""" Focal loss for binary classification. """
probs = torch.sigmoid(inputs)
targets = targets.float()
# Compute binary cross entropy
bce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
# Compute focal weight
p_t = probs * targets + (1 - probs) * (1 - targets)
focal_weight = (1 - p_t) ** self.gamma
# Apply alpha if provided
if self.alpha is not None:
alpha_t = self.alpha * targets + (1 - self.alpha) * (1 - targets)
bce_loss = alpha_t * bce_loss
# Apply focal loss weighting
loss = focal_weight * bce_loss
if self.reduction == 'mean':
return loss.mean()
elif self.reduction == 'sum':
return loss.sum()
return loss
def multi_class_focal_loss(self, inputs, targets):
""" Focal loss for multi-class classification. """
if self.alpha is not None:
alpha = self.alpha.to(inputs.device)
# Convert logits to probabilities with softmax
probs = F.softmax(inputs, dim=1)
# One-hot encode the targets
targets_one_hot = F.one_hot(targets, num_classes=self.num_classes).float()
# Compute cross-entropy for each class
ce_loss = -targets_one_hot * torch.log(probs)
# Compute focal weight
p_t = torch.sum(probs * targets_one_hot, dim=1) # p_t for each sample
focal_weight = (1 - p_t) ** self.gamma
# Apply alpha if provided (per-class weighting)
if self.alpha is not None:
alpha_t = alpha.gather(0, targets)
ce_loss = alpha_t.unsqueeze(1) * ce_loss
# Apply focal loss weight
loss = focal_weight.unsqueeze(1) * ce_loss
if self.reduction == 'mean':
return loss.mean()
elif self.reduction == 'sum':
return loss.sum()
return loss
def multi_label_focal_loss(self, inputs, targets):
""" Focal loss for multi-label classification. """
probs = torch.sigmoid(inputs)
# Compute binary cross entropy
bce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
# Compute focal weight
p_t = probs * targets + (1 - probs) * (1 - targets)
focal_weight = (1 - p_t) ** self.gamma
# Apply alpha if provided
if self.alpha is not None:
alpha_t = self.alpha * targets + (1 - self.alpha) * (1 - targets)
bce_loss = alpha_t * bce_loss
# Apply focal loss weight
loss = focal_weight * bce_loss
if self.reduction == 'mean':
return loss.mean()
elif self.reduction == 'sum':
return loss.sum()
return loss