0
点赞
收藏
分享

微信扫一扫

线性回归简洁实现代码(李沐动手学)

善解人意的娇娇 2022-01-26 阅读 78

讲解上b站听,此处贴pycharm代码
从零实现代码链接:https://blog.csdn.net/tongjingqi_/article/details/122690408?spm=1001.2014.3001.5502

import numpy as np
import torch
from torch.utils import data
from d2l import torch as d2l
true_w=torch.tensor([2,-3.4])
true_b=4.2
features,labels=d2l.synthetic_data(true_w,true_b,1000)

def load_array(data_arrays,batch_size,is_train=True):
    """构造一个Pytorch数据可迭代对象"""
    dataset=data.TensorDataset(*data_arrays)
    return data.DataLoader(dataset,batch_size,shuffle=is_train)

batch_size=10
data_iter=load_array((features,labels),batch_size)
next(iter(data_iter))

from torch import nn
net=nn.Sequential(nn.Linear(2,1))

net[0].weight.data.normal_(0,0.01)
net[0].bias.data.fill_(0)

loss=nn.MSELoss()
trainer=torch.optim.SGD(net.parameters(),lr=0.03)

num_epochs=3
for epoch in range(num_epochs):
    for X,y in data_iter:
        l=loss(net(X),y)
        trainer.zero_grad()
        l.backward()
        trainer.step()
    l=loss(net(features),labels)
    print(f'epoch{epoch+1},loss{l:f}')

举报

相关推荐

0 条评论