基于PyTorch的基本操作#

这个例子和前一个例子几乎一样,唯一的不同在于我们将NumPy换成了PyTorch。因为PyTorch已经支持了NumPy的大部分功能,而且在CPU上进行计算时,速度并不慢于numpy,而且还支持可微分、可函数变换、可GPU加速等功能,所以大家写数值代码时完全可以直接基于Pytorch而非Numpy。

import torch
import scipy
import matplotlib.pyplot as plt

操作张量#

x=torch.linspace(0,1,100)
a=2
b=3
noise=0.1*torch.randn(100)
y=a*x+b+noise

使用Scipy进行线性回归#

slope, intercept, r, p, se =scipy.stats.linregress(x,y)

使用Matplotlib画图#

plt.plot(x,y,label="data")
plt.plot(x,intercept+slope*x,label="fit")
plt.legend()
plt.savefig("linear_fit.pdf")
_images/pytorch_example1_numpy_7_0.png