Cross Entropy Loss: From Hard Truncation and Softening to Focal Loss
1. The Binary Classification Model
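For a binary classifier with predicted probability $\hat{y} \in (0, 1)$ and label $y \in \{0, 1\}$, the baseline that the following sections modify is the standard cross entropy:

$$L_{ce} = -y \log \hat{y} - (1 - y)\log(1 - \hat{y})$$

Every sample contributes to this loss no matter how confidently it is already classified; the rest of the post is about reducing the contribution of such easy samples.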
2. Corrected Cross Entropy Loss (Hard Truncation)
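The snippet below implements the hard truncation as a per-entry weight $\lambda$, where $m$ is the margin and $\theta$ is the unit step function:

$$\lambda(y, \hat{y}) = 1 - \theta(y - m)\,\theta(\hat{y} - m) - \theta(1 - m - y)\,\theta(1 - m - \hat{y})$$

$\lambda = 0$ exactly when an entry already sits on the right side of the margin ($\hat{y} > m$ where $y = 1$, $\hat{y} < 1 - m$ where $y = 0$), so easy entries are cut out of the loss entirely.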
```python
import numpy as np
import torch
import torch.nn as nn

SEED = 0
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
np.random.seed(SEED)

# Unit step function: 0 for t < 0, 1 for t > 0 (0.5 exactly at t = 0).
theta = lambda t: (torch.sign(t) + 1.) / 2.
# Steep sigmoid: a smooth stand-in for theta, used when softening the loss.
sigmoid = lambda t: torch.sigmoid(1e9 * t)


class Loss(nn.Module):
    def __init__(self, theta, num_classes=2, reduction='mean', margin=0.5):
        super().__init__()
        self.theta = theta
        self.num_classes = num_classes
        self.reduction = reduction
        self.m = margin

    def forward(self, pred, y):
        '''
        pred: 2-D [batch, num_classes], softmax probabilities (not log-probabilities)
        y:    1-D [batch], class indices (not one-hot)
        '''
        y_onehot = torch.eye(self.num_classes, dtype=pred.dtype)[y]
        # Per-entry truncation weight lambda(y, pred): 0 for entries that are
        # already on the right side of the margin, 1 otherwise.
        lambda_y_pred = (1 - self.theta(y_onehot - self.m) * self.theta(pred - self.m)
                           - self.theta(1 - self.m - y_onehot) * self.theta(1 - self.m - pred))
        # A sample is dropped (weight 0) only if every one of its entries
        # clears the margin; otherwise it keeps full weight 1.
        weight = torch.sign(torch.sum(lambda_y_pred, dim=1))  # [batch]
        # Per-sample cross entropy -sum_c y_c * log(pred_c).
        cel = -torch.sum(y_onehot * torch.log(pred), dim=1)   # [batch]
        if self.reduction == 'sum':
            return torch.sum(weight * cel)
        return torch.mean(weight * cel)  # averaged over the batch


y_pred = torch.randn(3, 4)
y_pred_softmax = nn.Softmax(dim=1)(y_pred)
y_pred_softmax.clamp_(1e-8, 0.999999)  # keep log() finite
label = torch.tensor([0, 2, 2])
loss_fn = Loss(theta, 4, reduction='mean', margin=0.6)
print(loss_fn(y_pred_softmax, label).item())
```
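Because $\theta$ is a hard step, a sample whose prediction clears the margin gets weight $0$: it contributes neither loss nor gradient, and nothing prevents it from drifting back across the margin as training continues. The next two sections down-weight easy samples smoothly instead of cutting them off.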
3. Softened Loss
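The `sigmoid` lambda defined above is a smooth stand-in for $\theta$, but at a scale of $10^9$ it is still effectively a hard step. A minimal sketch of a softened version, assuming a moderate steepness $K$ (the value $K = 10$ and the helper name `soft_truncated_ce` are illustrative choices, not something fixed by the post):

```python
# A softened variant: replace the hard step theta with a sigmoid of moderate
# steepness K, so easy entries get a small nonzero weight instead of exactly 0
# and keep receiving a little gradient. K = 10 is an assumed value here.
K = 10.
soft_theta = lambda t: torch.sigmoid(K * t)

def soft_truncated_ce(pred, y, num_classes, m=0.6):
    """pred: [batch, num_classes] softmax probabilities; y: [batch] class indices."""
    y_onehot = torch.eye(num_classes, dtype=pred.dtype)[y]
    # Same lambda as in the hard-truncated loss, but smooth in pred.
    lam = (1 - soft_theta(y_onehot - m) * soft_theta(pred - m)
             - soft_theta(1 - m - y_onehot) * soft_theta(1 - m - pred))
    cel = -(y_onehot * torch.log(pred))      # nonzero only at the true class
    return torch.mean(torch.sum(lam * cel, dim=1))

# Reusing y_pred_softmax and label from the snippet above.
print(soft_truncated_ce(y_pred_softmax, label, 4).item())
```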
4. Focal Loss
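Focal Loss (Lin et al., 2017) takes the same idea to its smooth conclusion: rather than a near-binary weight, it scales the cross entropy by $(1 - p_t)^\gamma$, where $p_t$ is the predicted probability of the true class, so the weight decays continuously as a sample gets easier:

$$FL(p_t) = -\alpha\,(1 - p_t)^{\gamma} \log(p_t)$$

A minimal multi-class sketch with the paper's common defaults $\alpha = 0.25$, $\gamma = 2$ (the helper name and interface mirror the snippets above and are illustrative):

```python
# Minimal Focal Loss sketch: FL(p_t) = -alpha * (1 - p_t)^gamma * log(p_t).
# alpha = 0.25 and gamma = 2.0 are the defaults from Lin et al. (2017),
# used here purely for illustration.
def focal_loss(pred, y, num_classes, alpha=0.25, gamma=2.0):
    """pred: [batch, num_classes] softmax probabilities; y: [batch] class indices."""
    y_onehot = torch.eye(num_classes, dtype=pred.dtype)[y]
    p_t = torch.sum(y_onehot * pred, dim=1)   # probability of the true class
    return torch.mean(-alpha * (1. - p_t) ** gamma * torch.log(p_t))

# Reusing y_pred_softmax and label from the snippets above.
print(focal_loss(y_pred_softmax, label, 4).item())
```

With $\gamma = 0$ and $\alpha = 1$ this reduces to plain cross entropy; larger $\gamma$ pushes the weight of well-classified samples toward zero, which is the smooth counterpart of the hard truncation above.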