pytorch 重复采样 与 非重复采样

Just_Esme

关注

阅读 71

2022-07-27


import torch
import torch.nn.functional as F
from torch.autograd import *

a = Variable(torch.FloatTensor([[0,0,0,0,0,0,90,100]]))
b=F.softmax(a,-1)

print(b.multinomial()) # 7 或 6
print(b.multinomial(2)) # 6,7 或 7,6
print(b.multinomial(2,True)) # 7,7 或 7,6 或 6,7 或 6,6

也可以试试
WeightedRandomSampler
主要是replace也就是True False那个参数决定采样数据是否重复


精彩评论(0)

0 0 举报