RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should be the same

seuleyang


10-14 12:00


 

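Solution: the error means the input tensor is on the GPU (torch.cuda.FloatTensor) while the convolution weight is on the CPU (torch.FloatTensor). Put the input, the weight, and the bias on the same device before calling F.conv2d: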

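import torch
import torch.nn.functional as F

def conv_2(batch, channel=256):
    # Use the GPU if one is available, otherwise fall back to the CPU
    mydevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  ### aimin li, added
    outs = []
    for x in batch:
        x = x.to(mydevice)  # Tensor.to() returns a new tensor, so reassign it
        in_channel = x.shape[1]  # number of input channels, assuming x is laid out as (N, C, H, W)
        weights = torch.ones(channel, in_channel, 2, 2).to(mydevice)
        b = torch.ones(channel).to(mydevice)
        outs.append(F.conv2d(input=x, weight=weights, bias=b, stride=1, padding=1))
    return outs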
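The same mismatch more often shows up with an nn.Conv2d layer whose weights were left on the CPU while only the input was moved to the GPU. A minimal sketch, assuming a hypothetical single-layer model and a random 3-channel 8x8 input (the layer sizes are illustrative, not from the original post):

```python
import torch
import torch.nn as nn

# Use the GPU when available; the sketch also runs on a CPU-only machine.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hypothetical single-layer model, for illustration only.
model = nn.Conv2d(in_channels=3, out_channels=256, kernel_size=2, stride=1, padding=1)
x = torch.randn(1, 3, 8, 8)

# On a GPU machine, moving only the input, e.g. model(x.to(device)),
# raises a RuntimeError like the one in the title.
# Fix: put the model's parameters on the same device as the input.
model = model.to(device)  # nn.Module.to() moves parameters in place and returns the module
x = x.to(device)          # Tensor.to() returns a new tensor, so reassign it

# Quick diagnostic: the weight and the input should report the same device.
print(next(model.parameters()).device, x.device)

out = model(x)
print(out.shape)  # 8x8 input, kernel 2, padding 1, stride 1 -> 9x9 output
```

Note the asymmetry: nn.Module.to() modifies the module in place, while Tensor.to() leaves the original tensor where it was. That is why a bare x.to(mydevice) inside a loop has no effect and must be written as x = x.to(mydevice).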

 


