『NLP Study Notes』Cross Entropy Loss: From Hard Truncation and Softening to Focal Loss

幺幺零 · 2022-04-13


一. The Binary Classification Model
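As a starting point, the binary classification model outputs a probability y_hat = sigmoid(z) and is trained with the standard binary cross entropy L = -y*log(y_hat) - (1-y)*log(1-y_hat). The snippet below is a minimal sketch of that baseline; the variable names and random data are illustrative, not code from the original post.

# Minimal sketch of the plain binary cross entropy baseline
# (illustrative variable names and random data).
import torch

z = torch.randn(5)                       # raw logits from some model
y_hat = torch.sigmoid(z)                 # predicted probability of the positive class
y = torch.randint(0, 2, (5,)).float()    # binary labels in {0, 1}
bce = -(y * torch.log(y_hat) + (1 - y) * torch.log(1 - y_hat)).mean()
print(bce.item())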

二. Modified Cross Entropy Loss (Hard Truncation)

Hard truncation modifies the cross entropy by multiplying it with a 0/1 weight λ(y, ŷ): once a sample is already classified correctly with margin m (true-class probability above m, the other classes below 1 − m), its loss is masked out, so training concentrates on the samples that are still hard.

import torch
import numpy as np
import torch.nn as nn

SEED = 0
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
np.random.seed(SEED)

# Hard unit step: theta(t) = 0 for t < 0, 1 for t > 0 (0.5 at t = 0)
theta = lambda t: (torch.sign(t) + 1.) / 2.
# Sharp sigmoid: with a huge slope it behaves almost exactly like the hard step above
sigmoid = lambda t: torch.sigmoid(1e9 * t)


class Loss(nn.Module):
    def __init__(self, theta, num_classes=2, reduction='mean', margin=0.5):
        super().__init__()
        self.theta = theta
        self.num_classes = num_classes
        self.reduction = reduction
        self.m = margin

    def forward(self, pred, y):
        '''
        pred: 2-D [batch, num_classes]. Softmax probabilities (not log-probabilities).
        y:    1-D [batch]. Class indices (not one-hot).
        '''
        # 2-D one-hot encoding of the labels, cast to the same dtype as pred
        y_onehot = torch.tensor(np.eye(self.num_classes)[y], dtype=pred.dtype)
        # Element-wise truncation weight lambda(y, pred): 0 for entries that are already
        # confidently correct (true class above margin m, other classes below 1 - m),
        # 1 for entries that are still hard
        lambda_y_pred = 1 - self.theta(y_onehot - self.m) * self.theta(pred - self.m) - self.theta(
            1 - self.m - y_onehot) * self.theta(1 - self.m - pred)

        # Per-sample 0/1 weight: 1 if any entry of that sample is still hard, else 0
        weight = torch.sign(torch.sum(lambda_y_pred, dim=1)).unsqueeze(0)
        cel = y_onehot * torch.log(pred)  # + (1 - y_onehot) * torch.log(1 - pred)
        if self.reduction == 'sum':
            return -torch.sum(torch.mm(weight, cel).squeeze(0))
            # return -torch.sum(lambda_y_pred * cel)
        else:
            return -torch.mean(torch.mm(weight, cel).squeeze(0))
            # return -torch.mean(torch.sum(lambda_y_pred * cel, dim=0))


y_pred = torch.randn(3, 4)
y_pred_softmax = nn.Softmax(dim=1)(y_pred)  # dim=1: each row sums to 1
# In PyTorch, a tensor method whose name ends with an underscore is in-place:
# it modifies the tensor itself instead of returning a new tensor. Here clamp_
# squeezes every element into [1e-8, 0.999999] so log(pred) never sees 0 or 1.
y_pred_softmax.clamp_(1e-8, 0.999999)
label = torch.tensor([0, 2, 2])
loss_fn = Loss(theta, 4, reduction='mean', margin=0.6)

print(loss_fn(y_pred_softmax, label).item())  # .item() returns the Python scalar held by a one-element tensor

三. Softening the Loss
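The hard 0/1 weight above is not differentiable, so the truncation itself cannot adapt during training. The softened version replaces the hard step theta with a sigmoid of finite steepness and multiplies the per-element soft weight into the cross entropy directly instead of hardening it with torch.sign. The sketch below reuses y_pred_softmax and label from the script above; the steepness K = 10 and the standalone function are illustrative choices, not code from the original post.

# Minimal sketch of the softened truncation (K and the function name are illustrative).
def soft_truncated_ce(pred, y, num_classes, margin=0.6, K=10.):
    # Soft step: a sigmoid of finite steepness K instead of the hard 0/1 step
    soft_theta = lambda t: torch.sigmoid(K * t)
    y_onehot = torch.eye(num_classes, dtype=pred.dtype)[y]
    # Soft truncation weight, close to 0 for confidently correct entries, close to 1 otherwise
    lam = (1 - soft_theta(y_onehot - margin) * soft_theta(pred - margin)
             - soft_theta(1 - margin - y_onehot) * soft_theta(1 - margin - pred))
    cel = y_onehot * torch.log(pred)
    # Weight each entry's cross entropy by the soft lambda, keeping everything differentiable
    return -torch.mean(torch.sum(lam * cel, dim=1))

print(soft_truncated_ce(y_pred_softmax, label, num_classes=4).item())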

四. Focal Loss
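Focal Loss takes this one step further: instead of masking (or softly masking) easy samples, it down-weights them smoothly with the factor (1 - p_t)^γ, where p_t is the predicted probability of the true class. Below is a minimal multi-class sketch, again reusing y_pred_softmax and label from above; γ = 2 and the omission of the α class-balancing factor are illustrative choices, not code from the original post.

# Minimal multi-class focal loss sketch: FL = -(1 - p_t)^gamma * log(p_t)
def focal_loss(pred, y, num_classes, gamma=2.0):
    y_onehot = torch.eye(num_classes, dtype=pred.dtype)[y]
    p_t = torch.sum(y_onehot * pred, dim=1)   # probability assigned to the true class
    return -torch.mean((1 - p_t) ** gamma * torch.log(p_t))

print(focal_loss(y_pred_softmax, label, num_classes=4).item())

Compared with hard truncation, the factor (1 - p_t)^γ decays smoothly to 0 as a sample becomes easy, so no sample is dropped abruptly and the weighting remains differentiable everywhere.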
