PyTorch Function Transforms
This section introduces the function-transform utilities provided by torch.func, which are especially useful when doing scientific computing with PyTorch. We walk through a few examples below.
Automatic vectorization
Below we define a function that computes the outer product of a vector with itself, and then use vmap to apply it to each row of a batch:
import torch
from torch.func import vmap

def f(x):
    return torch.outer(x, x)

x = torch.arange(1., 4.)
print(f(x))          # outer product of a single vector: shape (3, 3)
bx = vmap(f)(f(x))   # f applied to each row of the 3x3 matrix: shape (3, 3, 3)
print(bx)
tensor([[1., 2., 3.],
[2., 4., 6.],
[3., 6., 9.]])
tensor([[[ 1., 2., 3.],
[ 2., 4., 6.],
[ 3., 6., 9.]],
[[ 4., 8., 12.],
[ 8., 16., 24.],
[12., 24., 36.]],
[[ 9., 18., 27.],
[18., 36., 54.],
[27., 54., 81.]]])
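By default vmap maps over the leading dimension of every input, but this can be customized with in_dims (and out_dims). A minimal sketch, using a hypothetical helper weighted_outer introduced only for illustration:

def weighted_outer(x, w):
    return w * torch.outer(x, x)

xs = torch.stack([x, 2 * x, 3 * x])   # a batch of 3 vectors
ws = torch.tensor([1., 10., 100.])    # one scalar weight per vector
print(vmap(weighted_outer)(xs, ws).shape)                                    # maps dim 0 of both args: (3, 3, 3)
print(vmap(weighted_outer, in_dims=(0, None))(xs, torch.tensor(2.)).shape)   # one shared weight: (3, 3, 3)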
This example may be too simple to show the real power of vmap, but the same trick works for almost any function written in PyTorch, including complicated solvers such as the RTE solver we wrote. For example:
import math
import numpy as np
import torch.nn.functional as F

def get_gauss_point(n):
    x, w = np.polynomial.legendre.leggauss(n)
    return torch.from_numpy(x).to(dtype=torch.get_default_dtype()), torch.from_numpy(w).to(dtype=torch.get_default_dtype())

a = 1
c = 1
rho = 1
cv = 1
# spatial grid
N = 400
L = 1
dx = L / N
x = torch.arange(0.5 * dx, L + 0.5 * dx, dx)
Tini = vmap(lambda cas: (1 + 0.1 * torch.sin(2 * math.pi * x + torch.sin(cas))) * (1 + torch.sin(cas)))(torch.tensor([1, 2, 3, 4]))
sigma = vmap(lambda cas: (1 + 0.1 * torch.sin(2 * math.pi * x + torch.sin(cas + 1))) * (1 + torch.sin(cas)))(torch.tensor([1, 2, 3, 4]))

def solve(Tini, sigma):
    CFL = 0.8
    dt = CFL * dx  # time step
    dtc = dt * c
    ddtc = 1 / dtc
    datarecord_sigma = sigma.clone()[None, :]
    # velocity grid & angle
    Nvx = 8
    mu, wmu = get_gauss_point(Nvx)
    # distribution function
    T = Tini
    I = 0.5 * a * c * Tini**4
    I = I.repeat(Nvx, 1)
    I = F.pad(I[None, ...], (1, 1), mode='circular')[0]
    I0 = wmu @ I  # energy
    sigma = sigma.repeat(Nvx // 2, 1)
    t = 1.0
    Nt = int(t / dt)
    list_T = []
    list_E = []
    for loop in range(Nt):  # time-stepping loop
        I_out = I.clone()
        T_out = T.clone()
        I0_out = I0.clone()
        index = slice(1, -1)
        index_add1 = slice(2, None)
        index_sub1 = slice(None, -2)
        # streaming, positive vx
        lv = slice(Nvx // 2, None)
        coe = mu[lv]
        I[lv, index] = I_out[lv, index] - dt / dx * coe[..., None] * (
            I_out[lv, index] - I_out[lv, index_sub1]) + dt * sigma * (
            (0.5 * a * c * T_out**4).repeat(Nvx // 2, 1) - I_out[lv, index])
        # streaming, negative vx
        lv = slice(0, Nvx // 2)
        coe = mu[lv]
        I[lv, index] = I_out[lv, index] - dt / dx * coe[..., None] * (
            I_out[lv, index_add1] - I_out[lv, index]) + dt * sigma * (
            (0.5 * a * c * T_out**4).repeat(Nvx // 2, 1) - I_out[lv, index])
        I = F.pad(I[None, ..., 1:-1], (1, 1), mode='circular')[0]
        T = T_out + dt / cv * sigma[0, :] * (I0_out[index] - a * c * T_out**4)
        I0 = wmu @ I
        if loop % 10 == 0:
            list_E.append(I0[..., 1:-1].clone())
            list_T.append(T.clone())
    datarecord_E = torch.stack(list_E, dim=0)
    datarecord_T = torch.stack(list_T, dim=0)
    return datarecord_E, datarecord_T, datarecord_sigma

E, T, sigma = vmap(solve)(Tini, sigma)
print(E.shape, T.shape, sigma.shape)
torch.Size([4, 50, 400]) torch.Size([4, 50, 400]) torch.Size([4, 1, 400])
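Two variations can be useful here; both are given only as untested sketches. Recent PyTorch releases accept a chunk_size keyword so that vmap processes the batch in smaller chunks (trading speed for peak memory), and vmap calls can be nested, for example to sweep every combination of initial temperature and opacity. Note that the cell above rebinds the name sigma to the solver output, so the sketch rebuilds the opacity cases first:

# rebuild the (4, 400) opacity cases (the name sigma was rebound to the solver output above)
sigma0 = vmap(lambda cas: (1 + 0.1 * torch.sin(2 * math.pi * x + torch.sin(cas + 1))) * (1 + torch.sin(cas)))(torch.tensor([1, 2, 3, 4]))
# chunk_size (assumed available in recent PyTorch) runs 2 cases at a time to limit memory
E_c, T_c, s_c = vmap(solve, chunk_size=2)(Tini, sigma0)
# nested vmap: every Tini case against every sigma case, leading shape (4, 4, ...)
E_grid, T_grid, s_grid = vmap(lambda t: vmap(lambda s: solve(t, s))(sigma0))(Tini)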
Automatic differentiation
The simplest transform is grad, which computes the derivative of a function. It assumes the function returns a single scalar and differentiates that scalar with respect to the input.
import torch
from torch.func import grad

x = torch.randn([])  # try changing the shape and see what happens
fx = lambda x: torch.sin(x)
cos_x = grad(fx)(x)
torch.cos(x), cos_x
(tensor(0.9999), tensor(0.9999))
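Because grad returns an ordinary function, it composes with itself: a second derivative is just grad applied twice.

d2_sin = grad(grad(torch.sin))     # second derivative of sin, i.e. -sin
x0 = torch.randn([])
print(d2_sin(x0), -torch.sin(x0))  # the two values should agree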
If the output is not a single scalar, grad no longer works directly. In that case there are several options, depending on which derivative we actually want. For the sin example, suppose the input is a vector and all we want is the derivative of each output element with respect to the corresponding input element; then we can combine grad with vmap:
x = torch.randn([10,])  # try changing the shape and see what happens
fx = lambda x: torch.sin(x)
cos_x = vmap(grad(fx))(x)  # grad is vectorized over the 10 elements by vmap
torch.cos(x), cos_x
(tensor([ 0.3396, 0.8663, 0.8830, 0.7084, 0.9351, -0.1285, -0.0045, 0.8071,
-0.1174, 0.9670]),
tensor([ 0.3396, 0.8663, 0.8830, 0.7084, 0.9351, -0.1285, -0.0045, 0.8071,
-0.1174, 0.9670]))
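The same vmap(grad(...)) pattern also gives per-sample gradients of a model's parameters (this is what the per_sample_grads tutorial linked at the end does). A minimal sketch, where the small linear model and squared-error loss are chosen here purely for illustration:

from torch.func import functional_call

model = torch.nn.Linear(3, 1)
params = dict(model.named_parameters())

def sample_loss(p, xi, yi):
    pred = functional_call(model, p, (xi,))
    return ((pred - yi) ** 2).sum()

xs, ys = torch.randn(8, 3), torch.randn(8, 1)
# differentiate w.r.t. the parameter dict, map over the 8 samples only
per_sample_grads = vmap(grad(sample_loss), in_dims=(None, 0, 0))(params, xs, ys)
print(per_sample_grads['weight'].shape)  # torch.Size([8, 1, 3])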
Another situation is when we really do want the derivative of every output element with respect to every input element, i.e. the Jacobian matrix.
from torch.func import jacrev, jacfwd, hessian

x = torch.randn([10,])  # try changing the shape and see what happens
fx = lambda x: torch.sin(x)
cos_x = jacrev(fx)(x)  # full Jacobian of fx at x: shape (10, 10)
print(cos_x.shape, x.shape, fx(x).shape)
torch.cos(x), cos_x
torch.Size([10, 10]) torch.Size([10]) torch.Size([10])
(tensor([ 0.9992, 0.7487, 0.9434, 0.9032, 0.9164, 0.6165, 0.5111, 1.0000,
0.7079, -0.1347]),
tensor([[ 0.9992, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, -0.0000],
[ 0.0000, 0.7487, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, -0.0000],
[ 0.0000, 0.0000, 0.9434, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, -0.0000],
[ 0.0000, 0.0000, 0.0000, 0.9032, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, -0.0000],
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.9164, 0.0000, 0.0000, 0.0000,
0.0000, -0.0000],
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.6165, 0.0000, 0.0000,
0.0000, -0.0000],
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.5111, 0.0000,
0.0000, -0.0000],
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.0000,
0.0000, -0.0000],
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.7079, -0.0000],
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, -0.1347]]))
from torch.func import jacrev, jacfwd, hessian

x = torch.randn([10,2])  # try changing the shape and see what happens
fx = lambda x: torch.sin(x)
hessian_x = hessian(fx)(x)  # Hessian of the elementwise map: shape (10, 2, 10, 2, 10, 2)
print(hessian_x.shape, x.shape)
torch.Size([10, 2, 10, 2, 10, 2]) torch.Size([10, 2])
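The shape blows up because fx is elementwise, so its Hessian carries the input shape twice in addition to the output shape. For a scalar-valued function, hessian returns the familiar square matrix:

g = lambda v: torch.sin(v).sum()   # scalar-valued function of a 5-vector
v = torch.randn(5)
print(hessian(g)(v).shape)         # torch.Size([5, 5])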
A more common situation is a mixture of the two cases above. For example, suppose we use a neural network to fit the solution of the two-dimensional Euler equations. The network is a vector-to-vector map, and we want \(\frac{\partial f}{\partial (x,t)}\), a Jacobian with 6 entries (3 outputs by 2 inputs). At the same time the network is fed a batch of 1000 inputs, so we need the Jacobians of all 1000 samples at once.
x = torch.randn([1000, 2])
Net = torch.nn.Linear(2, 3)
y = Net(x)
y2 = vmap(Net)(x)
print(f"{x.shape=},{y.shape=},{y2.shape=}")

# per-sample Jacobians via reverse-mode and forward-mode AD
dy_dx_1 = vmap(jacrev(Net))(x)
dy_dx_2 = vmap(jacfwd(Net))(x)
print(f"{dy_dx_1.shape=},{dy_dx_2.shape=}")

def myhessian(f):
    return jacfwd(jacrev(f))

def myhessian2(f):
    return jacrev(jacfwd(f))

# per-sample Hessians: the built-in hessian vs. the two compositions above
ddy_dxx = vmap(hessian(Net))(x)
ddy_dxx2 = vmap(myhessian(Net))(x)
ddy_dxx3 = vmap(myhessian2(Net))(x)
print(f"{ddy_dxx.shape=},{ddy_dxx2.shape=},{ddy_dxx3.shape=}")
x.shape=torch.Size([1000, 2]),y.shape=torch.Size([1000, 3]),y2.shape=torch.Size([1000, 3])
dy_dx_1.shape=torch.Size([1000, 3, 2]),dy_dx_2.shape=torch.Size([1000, 3, 2])
ddy_dxx.shape=torch.Size([1000, 3, 2, 2]),ddy_dxx2.shape=torch.Size([1000, 3, 2, 2]),ddy_dxx3.shape=torch.Size([1000, 3, 2, 2])
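For PDE residuals one usually needs contractions of these objects rather than the full tensors. For instance, the sum of the pure second derivatives of each output (a Laplacian-style contraction over the two inputs) can be read off the batched Hessian computed above:

# trace over the two input dimensions of each (3, 2, 2) Hessian block
laplacian = ddy_dxx.diagonal(dim1=-2, dim2=-1).sum(-1)
print(laplacian.shape)  # torch.Size([1000, 3])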
How to choose between jacrev and jacfwd
From the PyTorch documentation on jacrev and jacfwd:
These two functions compute the same values (up to machine numerics), but differ in their implementation: jacfwd uses forward-mode automatic differentiation, which is more efficient for “tall” Jacobian matrices, while jacrev uses reverse-mode, which is more efficient for “wide” Jacobian matrices. For matrices that are near-square, jacfwd probably has an edge over jacrev.
And for hessian:
To implement hessian, we could have used jacfwd(jacrev(f)) or jacrev(jacfwd(f)) or any other composition of the two. But forward-over-reverse is typically the most efficient. That’s because in the inner Jacobian computation we’re often differentiating a function with a wide Jacobian (maybe like a loss function 𝑓:ℝⁿ→ℝ), while in the outer Jacobian computation we’re differentiating a function with a square Jacobian (since ∇𝑓:ℝⁿ→ℝⁿ), which is where forward-mode wins out.
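In other words, the better choice depends on the input and output dimensions of the function. A small sketch (the layer sizes here are arbitrary, and actual timings depend on hardware):

tall = torch.nn.Linear(16, 2048)    # R^16 -> R^2048: "tall" Jacobian, forward mode (jacfwd) tends to win
wide = torch.nn.Linear(2048, 16)    # R^2048 -> R^16: "wide" Jacobian, reverse mode (jacrev) tends to win
print(jacfwd(tall)(torch.randn(16)).shape)    # torch.Size([2048, 16])
print(jacrev(wide)(torch.randn(2048)).shape)  # torch.Size([16, 2048])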
More
https://pytorch.org/tutorials/beginner/basics/autogradqs_tutorial.html
https://pytorch.org/tutorials/intermediate/per_sample_grads.html
https://pytorch.org/tutorials/intermediate/jacobians_hessians.html