import torch
import torch.nn as nn
import torch.nn.functional as F


class FocalLoss(nn.Module):
    """Unified Focal Loss for binary, multi-class, and multi-label classification.

    Focal loss scales the cross-entropy of each example by the modulating
    factor ``(1 - p_t) ** gamma`` so that well-classified (easy) examples
    contribute less and training focuses on hard examples
    (Lin et al., "Focal Loss for Dense Object Detection", 2017).
    """

    def __init__(self, gamma=2, alpha=None, reduction='mean', task_type='binary', num_classes=None):
        """
        :param gamma: Focusing parameter; strength of the modulating factor (1 - p_t)^gamma.
        :param alpha: Balancing factor. Scalar (weight of the positive class for
                      binary/multi-label, uniform class weight for multi-class) or a
                      list/tensor of per-class weights for multi-class. None disables balancing.
        :param reduction: 'none' | 'mean' | 'sum'.
        :param task_type: 'binary', 'multi-class', or 'multi-label'.
        :param num_classes: Number of classes for multi-class. May be None, in which
                            case it is inferred from the logits at forward time.
        """
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.alpha = alpha
        self.reduction = reduction
        self.task_type = task_type
        self.num_classes = num_classes

        # Normalize list-valued alpha to a tensor once, so forward() only has to
        # distinguish scalar vs. tensor. Scalar alpha is broadcast at forward time.
        if task_type == 'multi-class' and isinstance(alpha, list):
            self.alpha = torch.Tensor(alpha)

    def forward(self, inputs, targets):
        """
        Compute the focal loss for the configured task type.

        :param inputs: Logits of shape (batch_size, num_classes)
                       (binary may also use (batch_size,)).
        :param targets: Ground truth labels.
                        binary: (batch_size,) in {0, 1};
                        multi-class: (batch_size,) class indices;
                        multi-label: (batch_size, num_classes) multi-hot.
        :raises ValueError: If ``task_type`` is not one of the supported values.
        """
        if self.task_type == 'binary':
            return self.binary_focal_loss(inputs, targets)
        elif self.task_type == 'multi-class':
            return self.multi_class_focal_loss(inputs, targets)
        elif self.task_type == 'multi-label':
            return self.multi_label_focal_loss(inputs, targets)
        else:
            raise ValueError(
                f"Unsupported task_type '{self.task_type}'. Use 'binary', 'multi-class', or 'multi-label'.")

    def binary_focal_loss(self, inputs, targets):
        """Focal loss for binary classification (elementwise on logits)."""
        probs = torch.sigmoid(inputs)
        targets = targets.float()  # BCE-with-logits requires float targets

        # Per-element cross entropy, computed from logits for numerical stability.
        bce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')

        # p_t = p for positives, (1 - p) for negatives.
        p_t = probs * targets + (1 - probs) * (1 - targets)
        focal_weight = (1 - p_t) ** self.gamma

        # alpha weights positives by alpha and negatives by (1 - alpha).
        if self.alpha is not None:
            alpha_t = self.alpha * targets + (1 - self.alpha) * (1 - targets)
            bce_loss = alpha_t * bce_loss

        loss = focal_weight * bce_loss

        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        return loss

    def multi_class_focal_loss(self, inputs, targets):
        """Focal loss for multi-class classification.

        Returns a per-sample loss: with gamma=0 and alpha=None this reduces
        exactly to ``F.cross_entropy`` (same 'mean'/'sum'/'none' semantics).
        With reduction='none' the result has shape (batch_size,).
        """
        # log_softmax avoids log(softmax(x)) == log(0) -> -inf/NaN for
        # confident logits (numerical-stability fix over softmax + log).
        log_probs = F.log_softmax(inputs, dim=1)

        # Per-sample cross entropy on the target class; gather replaces the
        # one-hot mask (and removes the crash when num_classes was None).
        ce_loss = -log_probs.gather(1, targets.unsqueeze(1)).squeeze(1)
        p_t = ce_loss.neg().exp()  # probability assigned to the true class

        focal_weight = (1 - p_t) ** self.gamma
        loss = focal_weight * ce_loss

        if self.alpha is not None:
            if isinstance(self.alpha, torch.Tensor):
                # Per-class weights: pick the weight of each sample's true class.
                alpha_t = self.alpha.to(inputs.device).gather(0, targets)
            else:
                # Scalar alpha: uniform class weight (previously crashed on .to()).
                alpha_t = torch.full_like(p_t, float(self.alpha))
            loss = alpha_t * loss

        # Reducing the per-sample loss fixes the old 'mean', which averaged over
        # batch * num_classes mostly-zero entries and silently shrank the loss
        # by a factor of num_classes.
        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        return loss

    def multi_label_focal_loss(self, inputs, targets):
        """Focal loss for multi-label classification (elementwise on logits)."""
        probs = torch.sigmoid(inputs)
        targets = targets.float()  # accept integer multi-hot targets (parity with binary path)

        bce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')

        # p_t = p for active labels, (1 - p) for inactive ones.
        p_t = probs * targets + (1 - probs) * (1 - targets)
        focal_weight = (1 - p_t) ** self.gamma

        if self.alpha is not None:
            alpha_t = self.alpha * targets + (1 - self.alpha) * (1 - targets)
            bce_loss = alpha_t * bce_loss

        loss = focal_weight * bce_loss

        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        return loss