0
点赞
收藏
分享

微信扫一扫

Adversarial Training 的 PyTorch 实现

寒羽鹿 2022-07-27 阅读 159


def at_loss(embedder, encoder, clf, batch, perturb_norm_length=5.0):
    """Compute the adversarial-training loss for one batch.

    Runs a clean forward pass and takes the gradient of its cross-entropy
    loss with respect to the embeddings. That gradient is normalized with
    ``get_normalized_vector`` (defined outside this snippet) and scaled by
    ``perturb_norm_length``. The result is added to a fresh embedding of
    ``batch``, and the function returns the classifier's cross-entropy on
    those perturbed embeddings.

    Args:
        embedder: Callable mapping ``batch`` to embeddings. Per the original
            shape comments, the layout is ``[seq_len, batch, hidden_dim]``.
        encoder: Called as ``encoder(embeddings, batch)``. Element ``[0]`` of
            its result is passed to ``clf``.
        clf: Produces the logits that are fed to ``F.cross_entropy``.
        batch: Object whose ``labels`` attribute holds the class targets.
        perturb_norm_length: Scale applied to the normalized perturbation.

    Returns:
        Scalar cross-entropy loss on the perturbed embeddings. It is still
        attached to the autograd graph, so the caller can backpropagate
        through it.
    """
    import torch
    import torch.nn.functional as F

    embedded = embedder(batch)  # [seq_len, batch, hidden_dim]
    if not embedded.requires_grad:
        # Frozen embeddings give a leaf tensor with no grad tracking. Enable
        # tracking so the gradient w.r.t. the embeddings can still be taken.
        embedded.requires_grad_()
    ce = F.cross_entropy(clf(encoder(embedded, batch)[0]), batch.labels)

    # Differentiate w.r.t. the embeddings only. Unlike ce.backward(), this
    # does not add the clean-loss gradient into the parameters' .grad fields.
    (grad,) = torch.autograd.grad(ce, embedded)

    # The perturbation direction is treated as a constant, and it is
    # normalized in batch-first layout.
    d = grad.detach().transpose(0, 1).contiguous()  # [batch, seq_len, hidden]
    d = get_normalized_vector(d)
    d = d.transpose(0, 1).contiguous()  # [seq_len, batch, hidden]

    # Re-embed the batch and add the fixed adversarial perturbation.
    perturbed = embedder(batch) + perturb_norm_length * d
    return F.cross_entropy(clf(encoder(perturbed, batch)[0]), batch.labels)

摘自 https://github.com/DevSinghSachan/ssl_text_classification


举报

相关推荐

0 条评论