import torch.nn.functional as F
from torch.autograd import Variable

def at_loss(embedder, encoder, clf, batch, perturb_norm_length=5.0):
    # Forward pass on the clean embeddings; retain their gradient.
    embedded = embedder(batch)  # [seq_len, batch, hidden_dim]
    embedded.retain_grad()
    ce = F.cross_entropy(clf(encoder(embedded, batch)[0]), batch.labels)
    ce.backward()
    # Normalize the embedding gradient per example to get the perturbation direction.
    d = embedded.grad.data.transpose(0, 1).contiguous()  # [batch, seq_len, hidden]
    d = get_normalized_vector(d)  # defined elsewhere in the repository
    d = d.transpose(0, 1).contiguous()  # [seq_len, batch, hidden]
    # Re-embed the batch and add the scaled perturbation, detached from the graph.
    d = embedder(batch) + perturb_norm_length * Variable(d)
    # Adversarial loss on the perturbed embeddings.
    loss = F.cross_entropy(clf(encoder(d, batch)[0]), batch.labels)
    return loss
Excerpted from https://github.com/DevSinghSachan/ssl_text_classification
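The helper get_normalized_vector is provided elsewhere in that repository. A minimal sketch, assuming it performs per-example L2 normalization over the flattened sequence and hidden dimensions (a common choice in adversarial-training code; the body below is an assumption, not the repo's exact implementation):

import torch

def get_normalized_vector(d):
    # Hypothetical sketch: rescale each example's gradient ([batch, seq_len, hidden])
    # to unit L2 norm over its flattened non-batch dimensions.
    flat = d.view(d.size(0), -1)
    norm = torch.sqrt(1e-6 + (flat ** 2).sum(dim=1, keepdim=True))
    return (flat / norm).view_as(d)

In training, this adversarial term is usually added to the clean cross-entropy loss; because at_loss already calls backward() internally, accumulated gradients are typically zeroed before the combined loss is backpropagated.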