Contents
- About This Post
- Environment Setup
- Global Variables
- Model Construction
- Loss Function
- Model Training
- Building the Dataset
- Building the DataLoader
- Training
- Model Evaluation
- Using the Model
- References
Code: https://github.com/iioSnail/MDCSpell_pytorch
About This Post
This post is a PyTorch implementation of the paper MDCSpell: A Multi-task Detector-Corrector Framework for Chinese Spelling Correction.
Paper: https://aclanthology.org/2022.findings-acl.98/
Year: 2022
What the paper does: the authors design a multi-task network based on Transformer and BERT for the CSC (Chinese Spell Checking) task. The two tasks are detecting which characters are wrong and correcting the wrong characters.
Since the authors did not release their code, I tried to implement it myself. My final experimental results are shown in the table below:
Dataset | Model | D-Precision | D-Recall | D-F1 | C-Precision | C-Recall | C-F1
--- | --- | --- | --- | --- | --- | --- | ---
SIGHAN 13 | MDCSpell | 89.1 | 78.3 | 83.4 | 87.5 | 76.8 | 81.8
SIGHAN 13 | MDCSpell (reproduced) | 80.2 | 79.9 | 80.0 | 77.2 | 76.9 | 77.1
SIGHAN 14 | MDCSpell | 70.2 | 68.8 | 69.5 | 69.0 | 67.7 | 68.3
SIGHAN 14 | MDCSpell (reproduced) | 82.8 | 66.6 | 73.8 | 79.9 | 64.3 | 71.2
SIGHAN 15 | MDCSpell | 80.8 | 80.6 | 80.7 | 78.4 | 78.2 | 78.3
SIGHAN 15 | MDCSpell (reproduced) | 86.7 | 76.1 | 81.1 | 72.5 | 82.7 | 77.3
These are my results after training for only 2 epochs, and they are not far from the authors' numbers. If I trained for more epochs, the results might match the paper's.
Environment Setup
try:
    import transformers
except ImportError:
    # transformers is not installed (e.g. in a fresh Colab runtime), so install it first.
    # In a notebook you can also simply run: !pip install transformers
    import os
    os.system("pip install transformers")

import os
import copy
import pickle

import torch
import transformers
from torch import nn
from torch.utils.data import Dataset, DataLoader
from transformers import AutoModel, AutoTokenizer
from tqdm import tqdm
torch.__version__
'1.12.1+cu113'
transformers.__version__
'4.21.3'
Global Variables
# Maximum sentence length. The authors do not specify it, so I pick one based on experience
max_length = 128
# Batch size used by the authors
batch_size = 32
# Number of epochs. The authors do not specify it, so I pick one based on experience
epochs = 10
# Print a log every `log_after_step` steps
log_after_step = 20
# Where the model checkpoint is stored
model_path = './drive/MyDrive/models/MDCSpell/'
os.makedirs(model_path, exist_ok=True)
model_path = model_path + 'MDCSpell-model.pt'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Device:", device)
Device: cuda
Model Construction
The data flow of the Correction Network is as follows:
1. The token sequence, e.g. [CLS] 遇 到 逆 竟 [SEP], is fed to the Word Embedding module to obtain the embedding vectors. In my view, at this point the embedding is only the word embedding, without the position embedding or the segment embedding.
2. The embeddings are then fed into BERT, which adds the position embedding and the segment embedding.
3. Inside BERT, the vectors pass through the stacked Transformer encoder layers, producing the output hidden states H^c.
4. BERT's output H^c is fused with the hidden states H^d coming from the Detection Network next door, roughly H = H^c + H^d. The [CLS] and [SEP] positions are not fused.
5. H is fed to a dense layer for the final prediction.
Details of the Correction Network:
- BERT: the authors use the BERT-base version with 12 Transformer blocks.
- Dense Layer: its input dimension is the word-vector size and its output dimension is the vocabulary size. For example, with a word-vector size of 768 and a vocabulary of 20000, the dense layer is nn.Linear(768, 20000).
- Dense Layer initialization: the dense layer's weight is initialized with the Word Embedding's parameters. The word embedding maps a token index to a word vector, which is conceptually the transpose of what the dense layer does, so the authors reuse the embedding weights to initialize the dense layer. In PyTorch the two weight tensors even have the same shape, so they can be copied directly (see the small sketch below). According to the authors, this speeds up training and makes the model more stable.
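To make the weight tying concrete, here is a tiny standalone sketch. The sizes match hfl/chinese-roberta-wwm-ext, and the variable names are only illustrative; the real model does the same thing in _init_correction_dense_layer further down.
from torch import nn

vocab_size, hidden_size = 21128, 768  # vocabulary and hidden size of hfl/chinese-roberta-wwm-ext
word_embedding = nn.Embedding(vocab_size, hidden_size)  # weight shape: (21128, 768)
dense_layer = nn.Linear(hidden_size, vocab_size)        # weight shape: (21128, 768) as well
# nn.Linear stores its weight as (out_features, in_features), which is exactly the shape
# of the embedding table, so the embedding weights can be copied over directly
dense_layer.weight.data = word_embedding.weight.data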
The data flow of the Detection Network is as follows:
1. The input is the word embedding E obtained from BERT's embedding table. Although the figure does not include the embeddings of [CLS] and [SEP], in my view they need no special handling, since the final prediction never uses these two tokens anyway.
2. The position embedding is added to E. The paper says the Detection Network takes the vector that already contains word embedding + position embedding + segment embedding, which contradicts the figure; I follow the figure here.
3. The result is fed into the Transformer Block, producing the hidden states H^d.
4. On the one hand, H^d is sent to the Correction Network next door for fusion; on the other hand, H^d is fed to the subsequent dense layer to decide which tokens are wrong.
Details of the Detection Network:
- Transformer Block: a 2-layer TransformerEncoder.
- Transformer Block initialization: its parameters are initialized with BERT's weights.
- Dense Layer: the input dimension is the word-vector size and the output dimension is 1; a Sigmoid gives the probability that the token is wrong.
class CorrectionNetwork(nn.Module):

    def __init__(self):
        super(CorrectionNetwork, self).__init__()
        # BERT tokenizer. The authors do not say which Chinese BERT they used, so I pick a commonly used one
        self.tokenizer = AutoTokenizer.from_pretrained("hfl/chinese-roberta-wwm-ext")
        # BERT
        self.bert = AutoModel.from_pretrained("hfl/chinese-roberta-wwm-ext")
        # BERT's word embedding table, which is essentially an nn.Embedding
        self.word_embedding_table = self.bert.get_input_embeddings()
        # Prediction layer. hidden_size is the word-vector size, len(self.tokenizer) is the vocabulary size
        self.dense_layer = nn.Linear(self.bert.config.hidden_size, len(self.tokenizer))

    def forward(self, inputs, word_embeddings, detect_hidden_states):
        """
        Forward pass of the Correction Network.
        :param inputs: the tokenizer output for the Chinese text,
                       containing the token indices, attention_mask, etc.
        :param word_embeddings: the tokens embedded with BERT's word embedding table
        :param detect_hidden_states: the hidden states output by the Detection Network
        :return: the Correction Network's prediction for each token.
        """
        # 1. Forward pass through BERT
        bert_outputs = self.bert(token_type_ids=inputs['token_type_ids'],
                                 attention_mask=inputs['attention_mask'],
                                 inputs_embeds=word_embeddings)
        # 2. Fuse BERT's hidden states with the Detection Network's hidden states
        hidden_states = bert_outputs['last_hidden_state'] + detect_hidden_states
        # 3. Predict the tokens with the final dense layer
        return self.dense_layer(hidden_states)

    def get_inputs_and_word_embeddings(self, sequences, max_length=128):
        """
        Tokenize the Chinese sequences and compute their word embeddings.
        :param sequences: Chinese text sequences, e.g. ["鸡你太美", "哎呦,你干嘛!"]
        :param max_length: maximum text length; shorter texts are padded, longer ones are truncated.
        :return: the tokenizer output and the word embeddings.
        """
        inputs = self.tokenizer(sequences, padding='max_length', max_length=max_length, return_tensors='pt',
                                truncation=True).to(device)
        # Embed the tokens with BERT's word embedding table. These embeddings do not yet
        # include the position embedding or the segment embedding
        word_embeddings = self.word_embedding_table(inputs['input_ids'])
        return inputs, word_embeddings
class DetectionNetwork(nn.Module):

    def __init__(self, position_embeddings, transformer_blocks, hidden_size):
        """
        :param position_embeddings: BERT's position embeddings, essentially an nn.Embedding
        :param transformer_blocks: the first two transformer blocks of BERT, a ModuleList
        :param hidden_size: the word-vector (hidden state) size of BERT
        """
        super(DetectionNetwork, self).__init__()
        self.position_embeddings = position_embeddings
        self.transformer_blocks = transformer_blocks
        # The final prediction layer, which decides which tokens are wrong
        self.dense_layer = nn.Sequential(
            nn.Linear(hidden_size, 1),
            nn.Sigmoid()
        )

    def forward(self, word_embeddings):
        # Length of the token sequence, 128 here
        sequence_length = word_embeddings.size(1)
        # Build the position embeddings
        position_embeddings = self.position_embeddings(torch.LongTensor(range(sequence_length)).to(device))
        # Fuse the word embeddings and the position embeddings
        x = word_embeddings + position_embeddings
        # Pass x through the transformer encoder layers one by one
        for transformer_layer in self.transformer_blocks:
            x = transformer_layer(x)[0]
        # Return the Detection Network's hidden states and its predictions
        hidden_states = x
        return hidden_states, self.dense_layer(hidden_states)
class MDCSpellModel(nn.Module):

    def __init__(self):
        super(MDCSpellModel, self).__init__()
        # Build the Correction Network
        self.correction_network = CorrectionNetwork()
        self._init_correction_dense_layer()

        # Build the Detection Network
        # The position embeddings are taken from BERT
        position_embeddings = self.correction_network.bert.embeddings.position_embeddings
        # As stated in the paper, the Detection Network's transformer uses BERT's weights,
        # so I simply deep-copy the first two transformer layers of BERT
        transformer = copy.deepcopy(self.correction_network.bert.encoder.layer[:2])
        # The word-vector size of BERT
        hidden_size = self.correction_network.bert.config.hidden_size
        # Build the Detection Network
        self.detection_network = DetectionNetwork(position_embeddings, transformer, hidden_size)

    def forward(self, sequences, max_length=128):
        # First compute the word embeddings; both the Correction Network and the Detection Network need them
        inputs, word_embeddings = self.correction_network.get_inputs_and_word_embeddings(sequences, max_length)
        # Forward pass of the Detection Network, producing its hidden states and predictions
        hidden_states, detection_outputs = self.detection_network(word_embeddings)
        # Forward pass of the Correction Network, producing its predictions
        correction_outputs = self.correction_network(inputs, word_embeddings, hidden_states)
        # Return the predictions of both networks.
        # `[PAD]` tokens must not contribute to the loss, so their detection outputs are zeroed out here
        return correction_outputs, detection_outputs.squeeze(2) * inputs['attention_mask']

    def _init_correction_dense_layer(self):
        """
        As mentioned in the paper, initialize the Correction Network's dense layer
        with the weight of the word embedding table
        """
        self.correction_network.dense_layer.weight.data = self.correction_network.word_embedding_table.weight.data
With the model defined, let's give it a quick try:
model = MDCSpellModel().to(device)
correction_outputs, detection_outputs = model(["鸡你太美", "哎呦,你干嘛!"])
print("correction_outputs shape:", correction_outputs.size())
print("detection_outputs shape:", detection_outputs.size())
correction_outputs shape: torch.Size([2, 128, 21128])
detection_outputs shape: torch.Size([2, 128])
Loss Function
Both the Correction Network and the Detection Network use cross entropy as their loss. The two losses are then combined as a weighted sum:
L = λ · L^c + (1 − λ) · L^d
where λ ∈ [0, 1] weights the Correction loss against the Detection loss. The authors found experimentally that λ = 0.85 works best.
class MDCSpellLoss(nn.Module):

    def __init__(self, coefficient=0.85):
        super(MDCSpellLoss, self).__init__()
        # Loss function of the Correction Network
        self.correction_criterion = nn.CrossEntropyLoss(ignore_index=0)
        # Loss function of the Detection Network. Since it is a binary classification, use Binary Cross Entropy
        self.detection_criterion = nn.BCELoss()
        # Weighting coefficient λ
        self.coefficient = coefficient

    def forward(self, correction_outputs, correction_targets, detection_outputs, detection_targets):
        """
        :param correction_outputs: output of the Correction Network, shape (batch_size, sequence_length, vocab_size)
        :param correction_targets: targets of the Correction Network, shape (batch_size, sequence_length)
        :param detection_outputs: output of the Detection Network, shape (batch_size, sequence_length)
        :param detection_targets: targets of the Detection Network, shape (batch_size, sequence_length)
        :return: the weighted sum of the two losses
        """
        # Loss of the Correction Network. Since its output has 3 dimensions,
        # batch_size and sequence_length must be flattened together first
        correction_loss = self.correction_criterion(correction_outputs.view(-1, correction_outputs.size(2)),
                                                    correction_targets.view(-1))
        # Loss of the Detection Network
        detection_loss = self.detection_criterion(detection_outputs, detection_targets)
        # Weighted sum of the two losses
        return self.coefficient * correction_loss + (1 - self.coefficient) * detection_loss
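Just to double-check the shapes, we can feed the loss some random tensors (the batch size, sequence length and vocabulary size below are the ones used throughout this post; the loss value itself is meaningless here):
loss_fn = MDCSpellLoss()
fake_correction_outputs = torch.randn(2, 128, 21128)         # (batch_size, sequence_length, vocab_size)
fake_correction_targets = torch.randint(0, 21128, (2, 128))  # target token ids
fake_detection_outputs = torch.rand(2, 128)                  # probabilities from the Sigmoid
fake_detection_targets = (torch.rand(2, 128) > 0.9).float()  # 0/1 labels, roughly 10% "wrong" tokens
print(loss_fn(fake_correction_outputs, fake_correction_targets, fake_detection_outputs, fake_detection_targets))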
Model Training
How the authors train the model:
- Step 1: train on the Wang271K dataset (synthetic data). Batch size is 32, learning rate is 2e-5.
- Step 2: fine-tune on the SIGHAN training sets. Batch size is 32, learning rate is 1e-5.
The authors do not mention which optimizer they use, but judging from these learning rates it should be Adam.
For step 1 the paper says nearly 3M samples were used, yet Wang271K is the only dataset the authors mention, so I suspect this is a slip: it should be 0.3M samples, not 3M.
The authors first train the model on Wang271K and then fine-tune it on the SIGHAN training sets. Here I skip the fine-tuning and only do the training stage. I use the datasets preprocessed by the ReaLiSe paper, which consist exactly of Wang271K and SIGHAN.
Baidu Netdisk link: https://pan.baidu.com/s/1x67LPiYAjLKhO1_2CI6aOA?pwd=skda
Just download and extract it.
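The rest of this post assumes the extracted files sit under a data/ directory next to the code; a quick way to verify (the file names below are exactly the ones used later):
import os

for name in ["trainall.times2.pkl", "test.sighan13.pkl", "test.sighan14.pkl", "test.sighan15.pkl"]:
    path = os.path.join("data", name)
    print(path, "OK" if os.path.exists(path) else "MISSING")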
Building the Dataset
class CSCDataset(Dataset):

    def __init__(self):
        super(CSCDataset, self).__init__()
        with open("data/trainall.times2.pkl", mode='br') as f:
            train_data = pickle.load(f)
        self.train_data = train_data

    def __getitem__(self, index):
        src = self.train_data[index]['src']
        tgt = self.train_data[index]['tgt']
        return src, tgt

    def __len__(self):
        return len(self.train_data)
train_data = CSCDataset()
train_data.__getitem__(0)
('纽约早盘作为基准的低硫轻油,五月份交割价攀升一点三四美元,来到每桶二十八点二五美元,而上周五曾下挫一豪元以上。',
'纽约早盘作为基准的低硫轻油,五月份交割价攀升一点三四美元,来到每桶二十八点二五美元,而上周五曾下挫一美元以上。')
Building the DataLoader
tokenizer = AutoTokenizer.from_pretrained("hfl/chinese-roberta-wwm-ext")
def collate_fn(batch):
    src, tgt = zip(*batch)
    src, tgt = list(src), list(tgt)
    src_tokens = tokenizer(src, padding='max_length', max_length=128, return_tensors='pt', truncation=True)['input_ids']
    tgt_tokens = tokenizer(tgt, padding='max_length', max_length=128, return_tensors='pt', truncation=True)['input_ids']
    correction_targets = tgt_tokens
    detection_targets = (src_tokens != tgt_tokens).float()
    return src, correction_targets, detection_targets, src_tokens  # src_tokens is needed later to compute Correction precision
train_loader = DataLoader(train_data, batch_size=batch_size, collate_fn=collate_fn, shuffle=True)
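To double-check the collate function, we can peek at one batch from the DataLoader (the shapes assume batch_size=32 and max_length=128 as configured above):
batch_src, batch_correction_targets, batch_detection_targets, batch_src_tokens = next(iter(train_loader))
print(len(batch_src))                   # 32 raw source sentences
print(batch_correction_targets.size())  # torch.Size([32, 128])
print(batch_detection_targets.size())   # torch.Size([32, 128])
print(batch_src_tokens.size())          # torch.Size([32, 128])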
Training
criterion = MDCSpellLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=2e-5)

start_epoch = 0  # which epoch to start from
total_step = 0   # how many parameter updates have been performed in total

# Resume a previous training run if a checkpoint exists
if os.path.exists(model_path):
    if not torch.cuda.is_available():
        checkpoint = torch.load(model_path, map_location='cpu')
    else:
        checkpoint = torch.load(model_path)
    model.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    start_epoch = checkpoint['epoch']
    total_step = checkpoint['total_step']
    print("Resuming training, epoch:", start_epoch)
Resuming training, epoch: 2
model = model.to(device)
model = model.train()
The training code below looks long, but most of it just computes recall and precision. Note that during training, the Detection recall and precision are computed from the Detection Network's own predictions.
total_loss = 0.  # running loss
d_recall_numerator = 0  # numerator of Detection recall
d_recall_denominator = 0  # denominator of Detection recall
d_precision_numerator = 0  # numerator of Detection precision
d_precision_denominator = 0  # denominator of Detection precision
c_recall_numerator = 0  # numerator of Correction recall
c_recall_denominator = 0  # denominator of Correction recall
c_precision_numerator = 0  # numerator of Correction precision
c_precision_denominator = 0  # denominator of Correction precision

for epoch in range(start_epoch, epochs):
    step = 0
    for sequences, correction_targets, detection_targets, correction_inputs in train_loader:
        correction_targets, detection_targets = correction_targets.to(device), detection_targets.to(device)
        correction_inputs = correction_inputs.to(device)

        correction_outputs, detection_outputs = model(sequences)
        loss = criterion(correction_outputs, correction_targets, detection_outputs, detection_targets)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        step += 1
        total_step += 1
        total_loss += loss.detach().item()

        # Compute the Detection recall and precision
        # A score >= 0.5 means the token is predicted as wrong, otherwise as correct
        d_predicts = detection_outputs >= 0.5
        # Number of wrong tokens that the network correctly detected
        d_recall_numerator += d_predicts[detection_targets == 1].sum().item()
        # Number of wrong tokens
        d_recall_denominator += (detection_targets == 1).sum().item()
        # Number of tokens the network predicted as wrong
        d_precision_denominator += d_predicts.sum().item()
        # Among the tokens predicted as wrong, how many are actually wrong
        d_precision_numerator += (detection_targets[d_predicts == 1]).sum().item()

        # Compute the Correction recall and precision
        # Map the outputs to token indices, i.e. reduce correction_outputs from (32, 128, 21128) to (32, 128)
        correction_outputs = correction_outputs.argmax(2)
        # Padding, [CLS] and [SEP] tokens are not evaluated
        correction_outputs[(correction_targets == 0) | (correction_targets == 101) | (correction_targets == 102)] = 0
        # [CLS] and [SEP] in correction_targets must also be set to 0
        correction_targets[(correction_targets == 101) | (correction_targets == 102)] = 0
        # Correction predictions: True means the prediction matches the target, False means it is wrong or needs no prediction
        c_predicts = correction_outputs == correction_targets
        # Number of wrong tokens the network corrected successfully
        c_recall_numerator += c_predicts[detection_targets == 1].sum().item()
        # Number of wrong tokens
        c_recall_denominator += (detection_targets == 1).sum().item()
        # Number of tokens the network changed
        correction_inputs[(correction_inputs == 101) | (correction_inputs == 102)] = 0
        c_precision_denominator += (correction_outputs != correction_inputs).sum().item()
        # Among the tokens the network changed, how many were corrected properly
        c_precision_numerator += c_predicts[correction_outputs != correction_inputs].sum().item()

        if total_step % log_after_step == 0:
            loss = total_loss / log_after_step
            d_recall = d_recall_numerator / (d_recall_denominator + 1e-9)
            d_precision = d_precision_numerator / (d_precision_denominator + 1e-9)
            c_recall = c_recall_numerator / (c_recall_denominator + 1e-9)
            c_precision = c_precision_numerator / (c_precision_denominator + 1e-9)

            print("Epoch {}, "
                  "Step {}/{}, "
                  "Total Step {}, "
                  "loss {:.5f}, "
                  "detection recall {:.4f}, "
                  "detection precision {:.4f}, "
                  "correction recall {:.4f}, "
                  "correction precision {:.4f}".format(epoch, step, len(train_loader), total_step,
                                                       loss,
                                                       d_recall,
                                                       d_precision,
                                                       c_recall,
                                                       c_precision))

            total_loss = 0.
            d_recall_numerator = 0
            d_recall_denominator = 0
            d_precision_numerator = 0
            d_precision_denominator = 0
            c_recall_numerator = 0
            c_recall_denominator = 0
            c_precision_numerator = 0
            c_precision_denominator = 0

            torch.save({
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'epoch': epoch + 1,
                'total_step': total_step,
            }, model_path)
...
Epoch 2, Step 15/8882, Total Step 8900, loss 0.02403, detection recall 0.4118, detection precision 0.8247, correction recall 0.8192, correction precision 0.9485
Epoch 2, Step 35/8882, Total Step 8920, loss 0.03479, detection recall 0.3658, detection precision 0.8055, correction recall 0.8029, correction precision 0.9125
...
Model Evaluation
The model is evaluated on the SIGHAN 2013, 2014 and 2015 test sets. Unlike during training, the Detection precision and recall are computed here from the Correction Network's outputs, because the Detection Network only exists to help train the Correction Network and its predictions are not meaningful at inference time.
model = model.eval()
def eval(test_data):
    d_recall_numerator = 0  # numerator of Detection recall
    d_recall_denominator = 0  # denominator of Detection recall
    d_precision_numerator = 0  # numerator of Detection precision
    d_precision_denominator = 0  # denominator of Detection precision
    c_recall_numerator = 0  # numerator of Correction recall
    c_recall_denominator = 0  # denominator of Correction recall
    c_precision_numerator = 0  # numerator of Correction precision
    c_precision_denominator = 0  # denominator of Correction precision

    progress = tqdm(range(len(test_data)))
    for i in progress:
        src, tgt = test_data[i]['src'], test_data[i]['tgt']
        src_tokens = tokenizer(src, return_tensors='pt', max_length=128, truncation=True)['input_ids'][0][1:-1]
        tgt_tokens = tokenizer(tgt, return_tensors='pt', max_length=128, truncation=True)['input_ids'][0][1:-1]

        # Normally src and tgt should have the same length
        if len(src_tokens) != len(tgt_tokens):
            print("Sample %d is invalid, skipping" % i)
            continue

        correction_outputs, _ = model(src)
        predict_tokens = correction_outputs[0][1:len(src_tokens) + 1].argmax(1).detach().cpu()

        # Number of wrong tokens
        d_recall_denominator += (src_tokens != tgt_tokens).sum().item()
        # Among the wrong tokens, how many does the network also consider wrong
        d_recall_numerator += (predict_tokens != src_tokens)[src_tokens != tgt_tokens].sum().item()
        # Number of tokens the network flags as wrong
        d_precision_denominator += (predict_tokens != src_tokens).sum().item()
        # Among the tokens the network flags as wrong, how many are actually wrong
        d_precision_numerator += (src_tokens != tgt_tokens)[predict_tokens != src_tokens].sum().item()

        # Detection recall, precision and F1 score
        d_recall = d_recall_numerator / (d_recall_denominator + 1e-9)
        d_precision = d_precision_numerator / (d_precision_denominator + 1e-9)
        d_f1_score = 2 * (d_recall * d_precision) / (d_recall + d_precision + 1e-9)

        # Number of wrong tokens
        c_recall_denominator += (src_tokens != tgt_tokens).sum().item()
        # Among the wrong tokens, how many did the network predict correctly
        c_recall_numerator += (predict_tokens == tgt_tokens)[src_tokens != tgt_tokens].sum().item()
        # Number of tokens the network flags as wrong
        c_precision_denominator += (predict_tokens != src_tokens).sum().item()
        # Among the tokens the network flags as wrong, how many were corrected properly
        c_precision_numerator += (predict_tokens == tgt_tokens)[predict_tokens != src_tokens].sum().item()

        # Correction recall, precision and F1 score
        c_recall = c_recall_numerator / (c_recall_denominator + 1e-9)
        c_precision = c_precision_numerator / (c_precision_denominator + 1e-9)
        c_f1_score = 2 * (c_recall * c_precision) / (c_recall + c_precision + 1e-9)

        progress.set_postfix({
            'd_recall': d_recall,
            'd_precision': d_precision,
            'd_f1_score': d_f1_score,
            'c_recall': c_recall,
            'c_precision': c_precision,
            'c_f1_score': c_f1_score,
        })
with open("data/test.sighan13.pkl", mode='br') as f:
sighan13 = pickle.load(f)
eval(sighan13)
100%|██████████| 1000/1000 [00:11<00:00, 90.12it/s, d_recall=0.799, d_precision=0.802, d_f1_score=0.8, c_recall=0.769, c_precision=0.772, c_f1_score=0.771]
with open("data/test.sighan14.pkl", mode='br') as f:
sighan14 = pickle.load(f)
eval(sighan14)
100%|██████████| 1062/1062 [00:12<00:00, 85.48it/s, d_recall=0.666, d_precision=0.828, d_f1_score=0.738, c_recall=0.643, c_precision=0.799, c_f1_score=0.712]
with open("data/test.sighan15.pkl", mode='br') as f:
sighan15 = pickle.load(f)
eval(sighan15)
100%|██████████| 1100/1100 [00:11<00:00, 92.04it/s, d_recall=0.761, d_precision=0.867, d_f1_score=0.811, c_recall=0.725, c_precision=0.827, c_f1_score=0.773]
Using the Model
Finally, let's actually use the model and see how it performs:
def predict(text):
    sequences = [text]
    correction_outputs, _ = model(sequences)
    tokens = correction_outputs[0][1:len(text) + 1].argmax(1)
    return ''.join(tokenizer.convert_ids_to_tokens(tokens))
predict("今天早上我吃了以个火聋果")
'今天早上我吃了一个火聋果'
predict("我是联系时长两年半的个人练习生蔡徐鲲,喜欢唱跳RAP蓝球")
'我是联系时长两年半的个人练习生蔡徐鲲,喜欢唱跳ra##p蓝球[SEP]'
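The raw output keeps WordPiece artifacts such as ra##p and a trailing [SEP]. A small cleanup helper can strip them (just a convenience sketch of my own, not something the paper describes):
def predict_clean(text):
    correction_outputs, _ = model([text])
    token_ids = correction_outputs[0][1:len(text) + 1].argmax(1)
    tokens = tokenizer.convert_ids_to_tokens(token_ids.tolist())
    # Drop special tokens and remove the WordPiece continuation marker "##"
    tokens = [t for t in tokens if t not in ("[CLS]", "[SEP]", "[PAD]")]
    return ''.join(t.replace("##", "") for t in tokens)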
Although the model looks decent on the benchmark numbers, its real-world performance still leaves something to be desired. Chinese text correction really is a hard task T_T!
References
MDCSpell paper: https://aclanthology.org/2022.findings-acl.98/