import torch
import torch.nn as nn
import torch.nn.functional as F


class FocalLoss(nn.Module):
    """Unified Focal Loss for binary, multi-class, and multi-label classification.

    The loss down-weights well-classified examples by the modulating factor
    ``(1 - p_t) ** gamma`` and optionally re-balances classes with ``alpha``.
    """

    def __init__(self, gamma=2, alpha=None, reduction='mean', task_type='binary', num_classes=None):
        """
        :param gamma: Focusing parameter, controls the strength of the modulating
            factor (1 - p_t)^gamma.
        :param alpha: Balancing factor. A scalar, or a list/tuple/tensor of
            class-wise weights. If None, no class balancing is used.
            For 'binary' / 'multi-label' the weight is ``alpha`` for positive
            targets and ``1 - alpha`` for negative targets. For 'multi-class' a
            sequence/tensor is indexed by the target class, and a scalar is
            applied uniformly to every sample.
        :param reduction: 'mean' | 'sum'; any other value (e.g. 'none') returns
            the unreduced element-wise loss.
        :param task_type: 'binary', 'multi-class', or 'multi-label'. Checked when
            ``forward`` is called.
        :param num_classes: Number of classes for multi-class classification.
            Optional: when None it is taken from ``inputs.shape[1]``.
        """
        super().__init__()
        self.gamma = gamma
        self.reduction = reduction
        self.task_type = task_type
        self.num_classes = num_classes

        # Sequences are converted for every task type so that class-wise alpha
        # also works (via broadcasting) for multi-label inputs.
        if isinstance(alpha, (list, tuple)):
            alpha = torch.as_tensor(alpha, dtype=torch.get_default_dtype())
        self.alpha = alpha

    def forward(self, inputs, targets):
        """
        Compute the Focal Loss for the configured task type.

        :param inputs: Predictions (logits) from the model. Shape:
            - binary: (batch_size,) or (batch_size, 1)
            - multi-label: (batch_size, num_classes)
            - multi-class: (batch_size, num_classes)
        :param targets: Ground truth labels. Shape:
            - binary: same number of elements as ``inputs``
            - multi-label: (batch_size, num_classes), multi-hot
            - multi-class: (batch_size,), integer class indices
        :raises ValueError: if ``task_type`` is not one of the supported values.
        """
        if self.task_type == 'binary':
            return self.binary_focal_loss(inputs, targets)
        if self.task_type == 'multi-class':
            return self.multi_class_focal_loss(inputs, targets)
        if self.task_type == 'multi-label':
            return self.multi_label_focal_loss(inputs, targets)
        raise ValueError(
            f"Unsupported task_type '{self.task_type}'. "
            "Use 'binary', 'multi-class', or 'multi-label'.")

    def _reduce(self, loss):
        """Apply ``self.reduction``; unknown values return ``loss`` unreduced."""
        if self.reduction == 'mean':
            return loss.mean()
        if self.reduction == 'sum':
            return loss.sum()
        return loss

    def _sigmoid_focal_loss(self, inputs, targets):
        """Element-wise sigmoid focal loss shared by the binary and multi-label paths.

        ``targets`` must already be a floating-point tensor broadcastable
        against ``inputs``.
        """
        probs = torch.sigmoid(inputs)
        bce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')

        # p_t is the probability the model assigns to the true label.
        p_t = probs * targets + (1 - probs) * (1 - targets)
        focal_weight = (1 - p_t) ** self.gamma

        if self.alpha is not None:
            alpha = self.alpha
            if isinstance(alpha, torch.Tensor):
                # A tensor alpha may live on a different device than the logits.
                alpha = alpha.to(inputs.device)
            alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
            bce_loss = alpha_t * bce_loss

        return self._reduce(focal_weight * bce_loss)

    def binary_focal_loss(self, inputs, targets):
        """Focal loss for binary classification.

        Targets are cast to float. If their shape differs from ``inputs`` but
        the element count matches (e.g. logits ``(B, 1)`` with labels ``(B,)``),
        they are reshaped to the logits' shape.
        """
        targets = targets.float()
        if targets.shape != inputs.shape and targets.numel() == inputs.numel():
            targets = targets.reshape(inputs.shape)
        return self._sigmoid_focal_loss(inputs, targets)

    def multi_class_focal_loss(self, inputs, targets):
        """Focal loss for multi-class classification.

        The unreduced loss has shape ``(batch_size, num_classes)`` and is
        non-zero only in each sample's target column; consequently 'mean'
        averages over ``batch_size * num_classes`` elements.
        """
        num_classes = self.num_classes if self.num_classes is not None else inputs.shape[1]

        # log_softmax keeps the log-probabilities finite. softmax followed by log
        # can underflow to log(0) = -inf, and 0 * -inf on the non-target columns
        # would poison the loss with NaN.
        log_probs = F.log_softmax(inputs, dim=1)
        probs = log_probs.exp()

        targets_one_hot = F.one_hot(targets, num_classes=num_classes).float()

        # Cross-entropy, kept per class (zero everywhere except the target column).
        ce_loss = -targets_one_hot * log_probs

        # p_t for each sample, used for the modulating factor.
        p_t = torch.sum(probs * targets_one_hot, dim=1)
        focal_weight = (1 - p_t) ** self.gamma

        if self.alpha is not None:
            if isinstance(self.alpha, torch.Tensor) and self.alpha.dim() > 0:
                # Class-wise weights: pick the weight of each sample's target class.
                alpha_t = self.alpha.to(inputs.device).gather(0, targets)
                ce_loss = alpha_t.unsqueeze(1) * ce_loss
            else:
                # Scalar alpha: uniform weight on every sample.
                ce_loss = self.alpha * ce_loss

        return self._reduce(focal_weight.unsqueeze(1) * ce_loss)

    def multi_label_focal_loss(self, inputs, targets):
        """Focal loss for multi-label classification.

        Integer or boolean multi-hot targets are cast to the logits' dtype;
        floating-point targets are used as given.
        """
        if not targets.is_floating_point():
            targets = targets.to(inputs.dtype)
        return self._sigmoid_focal_loss(inputs, targets)