这里第5问用到的函数是上一节微分里面写好的,原书把它封装到了一个包里面,封装步骤见本书正文前面还有微分那一节。我就直接cv下来了。
以下是代码部分:
#第二问
import torch
x=torch.arange(4.0,requires_grad=True)
print(x)
y=torch.dot(x,x)
print(y)
y.backward(retain_graph=True)
print("第一次运行反向传播:",x.grad)
y.backward()
print("第二次运行反向传播:",x.grad)
#第三问
x.grad.zero_()
def g(a):
b = a * 2
while b.norm() < 1000:
b = b * 2
if b.sum() > 0:
c = b
else:
c = 100 * b
return c
a = torch.rand(size=(3,3),requires_grad=True)
d = g(a)
d.sum().backward()
a.grad
x.grad.zero_()
b = torch.rand(4)
b.requires_grad=True
e = g(b)
e.backward()
b.grad
#第四问
def practice_control_flow(a):
b = a * a
if b < 0:
return b
else:
if b>=1:
while b < 100:
b = b**2
else:
b = b + 1
return b
x.grad.zero_()
a.grad.zero_()
a = torch.tensor([5.],requires_grad=True)
result=practice_control_flow(a)
result.backward()
print(result,a.grad)
b = torch.tensor([0.1],requires_grad=True)
result2=practice_control_flow(b)
result2.backward()
print(result2,b.grad)
c = torch.tensor([-12.],requires_grad=True)
result3=practice_control_flow(c)
result3.backward()
print(result3,c.grad)
#第五问
%matplotlib inline
import numpy as np
from IPython import display
import torch
from matplotlib import pyplot as plt
#定义f(x)
def f(x):
return 3 * x ** 2 - 4 * x
#一下画图函数后面可能还用的到,用于把几个函数设置到同一个图像中
def use_svg_display():
#使用svg格式(可缩放矢量图,一种图形标识规范,很高清)在Jupyter里面作图
display.set_matplotlib_formats('svg')
#设置matplotlib的图表大小
def set_figsize(figsize=(3.5,2.5)):
use_svg_display()
plt.rcParams['figure.figsize'] = figsize
#设置由matplotlib⽣成图表的轴的属性,legend用于显示图像标签
def set_axes(axes,xlabel,ylabel,xlim,ylim,xscale,yscale,legend):
#设置matplotlib的轴
axes.set_xlabel(xlabel)
axes.set_ylabel(ylabel)
axes.set_xscale(xscale)
axes.set_yscale(yscale)
axes.set_xlim(xlim)
axes.set_ylim(ylim)
if legend:
axes.legend(legend)
#在图像中显示网格
axes.grid()
#xscale='linear'可以使得x轴和y轴刻度值呈线性关系 ='log'则y轴刻度与x轴的满足y=lgx. 默认是'linear'
#定义了plot函数来简洁地绘制多条曲线,因为我们需要在整个书中可视化许多曲线。
def plot(X, Y=None, xlabel=None, ylabel=None, legend=None, xlim=None, ylim=None,
xscale='linear',yscale='linear',fmts=('-','m--','g-.','r:'),figsize=(3.5,2.5),axes=None):
#绘制数据点
if legend is None:
legend=[]
set_figsize(figsize)
axes = axes if axes else plt.gca()
#如果X有一个轴则输出为true
def has_one_axis(X):
return (hasattr(X,"ndim") and X.ndim==1 or isinstance(X,list) and not hasattr(X[0],'__len__'))
if has_one_axis(X):
X = [X]
if Y is None:
X,Y=[[]]*len(X),X
elif has_one_axis(Y):
Y=[Y]
if len(X)!=len(Y):
X=X*len(Y)
axes.cla()
for x,y,fmt in zip(X,Y,fmts):
if(len(X)):
axes.plot(x,y,fmt)
else:
axes.plot(y,fmt)
set_axes(axes,xlabel,ylabel,xlim,ylim,xscale,yscale,legend)
x=torch.arange(-10.,10.,0.1)
x.requires_grad=True
y=torch.sin(x)
y.sum().backward()
plot(x.detach(),[y.detach(),x.grad,],'x','f(x)',legend=['f(x)','Tangent line=f\'(x)'])