RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor)


 

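Solution: the error means the input tensor is on the GPU while the convolution weights are still on the CPU. The input, the weight and the bias must all be on the same device, so move each of them there explicitly: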

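import torch
import torch.nn.functional as F

def conv_2(batch, channel=256):
    mydevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  ### aimin li, added
    outs = []
    for x in batch:
        x = x.to(mydevice)  # .to() returns a new tensor, so the result must be reassigned
        in_channel = x.shape[-3]  # channel dimension of a (N, C, H, W) or (C, H, W) input
        # weight and bias must live on the same device as the input
        weights = torch.ones(channel, in_channel, 2, 2).to(mydevice)
        b = torch.ones(channel).to(mydevice)
        out = F.conv2d(input=x, weight=weights, bias=b, stride=1, padding=1)
        outs.append(out)
    return outs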
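If the convolution is an `nn.Conv2d` layer instead of a call to `F.conv2d`, the same error usually means the model itself was never moved to the GPU. The minimal sketch below (layer sizes are arbitrary, chosen only for illustration) shows both pitfalls: `Tensor.to()` returns a new tensor and leaves the original untouched, while `nn.Module.to()` moves the module's weight and bias in place.

```python
import torch
import torch.nn as nn

# Pick the GPU when available, otherwise run everything on the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Tensor.to() is NOT in-place: it returns a converted tensor, so reassign the result
t = torch.zeros(2)
t.to(torch.float64)        # result discarded, t is still float32
t = t.to(torch.float64)    # correct: t is now float64

# nn.Module.to() moves the module's parameters (weight and bias) in place
conv = nn.Conv2d(in_channels=3, out_channels=256, kernel_size=2, stride=1, padding=1)
conv.to(device)

x = torch.ones(1, 3, 4, 4).to(device)  # input on the same device as the weights
out = conv(x)
print(out.shape)  # torch.Size([1, 256, 5, 5])
```

With a 2x2 kernel, stride 1 and padding 1, each 4-pixel spatial side becomes (4 + 2 - 2) / 1 + 1 = 5.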

 


