Function transforms in PyTorch#

This section introduces the function transforms provided by torch.func, which are particularly useful when using PyTorch for scientific computing. Below we walk through a few examples.

Automatic vectorization (vmap)#

We first define a simple function that returns the outer product of a vector with itself. Wrapping it in vmap maps the function over the leading (batch) dimension of its input, so a whole batch of vectors is processed without an explicit Python loop.

import torch
from torch.func import vmap

def f(x):
    # outer product of a vector with itself
    return torch.outer(x, x)

x = torch.arange(1., 4.)
print(f(x))
# f(x) is a 3x3 matrix; vmap(f) treats its leading dimension as a batch,
# applying f to each row and returning three 3x3 outer products
bx = vmap(f)(f(x))
print(bx)
tensor([[1., 2., 3.],
        [2., 4., 6.],
        [3., 6., 9.]])
tensor([[[ 1.,  2.,  3.],
         [ 2.,  4.,  6.],
         [ 3.,  6.,  9.]],

        [[ 4.,  8., 12.],
         [ 8., 16., 24.],
         [12., 24., 36.]],

        [[ 9., 18., 27.],
         [18., 36., 54.],
         [27., 54., 81.]]])

This example may be too simple to show the real power of vmap, but almost any function written in PyTorch can be transformed in the same way, including complex solvers such as the RTE solver we wrote. For example:

import math
import numpy as np
import torch.nn.functional as F

def get_gauss_point(n):
    # Gauss-Legendre quadrature nodes and weights as torch tensors
    x, w = np.polynomial.legendre.leggauss(n)
    dtype = torch.get_default_dtype()
    return torch.from_numpy(x).to(dtype=dtype), torch.from_numpy(w).to(dtype=dtype)

# physical constants, all set to 1 here
a = 1
c = 1
rho = 1
cv = 1

# spatial grid
N = 400
L = 1
dx = L / N
x = torch.arange(0.5 * dx, L + 0.5 * dx, dx)

# four parameterized initial temperature and opacity profiles (batch dimension = 4 cases)
Tini = vmap(lambda cas: (1 + 0.1 * torch.sin(2 * math.pi * x + torch.sin(cas))) * (1 + torch.sin(cas)))(torch.tensor([1, 2, 3, 4]))
sigma = vmap(lambda cas: (1 + 0.1 * torch.sin(2 * math.pi * x + torch.sin(cas + 1))) * (1 + torch.sin(cas)))(torch.tensor([1, 2, 3, 4]))
    
def solve(Tini,sigma):
    CFL = 0.8
    dt = CFL * dx  # time step
    dtc = dt * c
    ddtc = 1 / dtc

    datarecord_sigma=sigma.clone()[None,:]
    
    # velocity grid & angle
    Nvx = 8
    mu, wmu = get_gauss_point(Nvx)

    # distribution function
    T = Tini
    I = 0.5 * a * c * Tini**4
    I = I.repeat(Nvx, 1)
    I = F.pad(I[None, ...], (1, 1), mode='circular')[0]
    I0 = wmu @ I  # energy: angular quadrature of I
    sigma = sigma.repeat(Nvx // 2, 1)

    t=1.0
    Nt=int(t / dt)
    list_T=[]
    list_E=[]
    for loop in range(Nt):  # time-stepping loop
        I_out = I.clone()
        T_out = T.clone()
        I0_out = I0.clone()

        index = slice(1, -1)
        index_add1 = slice(2, None)
        index_sub1 = slice(None, -2)

        # streaming, positive vx
        lv = slice(Nvx // 2, None)
        coe = mu[lv]
        I[lv, index] = I_out[lv, index] - dt / dx * coe[..., None] * (
            I_out[lv, index] - I_out[lv, index_sub1]) + dt * sigma * (
                (0.5 * a * c * T_out**4).repeat(Nvx // 2, 1) - I_out[lv, index])
        
        # streaming, negative vx
        lv = slice(0, Nvx // 2)
        coe = mu[lv]
        I[lv, index] = I_out[lv, index] - dt / dx * coe[..., None] * (
            I_out[lv, index_add1] - I_out[lv, index]) + dt * sigma * (
                (0.5 * a * c * T_out**4).repeat(Nvx // 2, 1) - I_out[lv, index])

        I = F.pad(I[None, ..., 1:-1], (1, 1), mode='circular')[0]
        T = T_out + dt / cv * sigma[0, :] * (I0_out[index] - a * c * T_out**4)
        I0 = wmu @ I

        if loop%10==0:
            list_E.append(I0[...,1:-1].clone())
            list_T.append(T.clone())
    datarecord_E=torch.stack(list_E,dim=0)
    datarecord_T=torch.stack(list_T,dim=0)
    return datarecord_E,datarecord_T,datarecord_sigma

# vectorize the whole solver over the batch of 4 cases
E, T, sigma = vmap(solve)(Tini, sigma)
print(E.shape, T.shape, sigma.shape)
torch.Size([4, 50, 400]) torch.Size([4, 50, 400]) torch.Size([4, 1, 400])

Automatic differentiation#

The simplest transform is grad, which computes the derivative of a function. It assumes the function returns a scalar and differentiates that scalar with respect to the input.

import torch
from torch.func import grad
x = torch.randn([])  # try changing the shape and see what happens
fx = lambda x: torch.sin(x)
cos_x = grad(fx)(x)

torch.cos(x), cos_x
(tensor(0.9999), tensor(0.9999))

If the output is not a scalar, grad no longer works. Depending on which derivative we actually want, there are several options. For the sin example, suppose the input is a vector and we only want the derivative of each output element with respect to the corresponding input element; then we can use vmap:

x = torch.randn([10,])  # try changing the shape and see what happens
fx = lambda x: torch.sin(x)
cos_x = vmap(grad(fx))(x)  # the scalar grad is mapped over each element of the vector

torch.cos(x), cos_x
(tensor([ 0.3396,  0.8663,  0.8830,  0.7084,  0.9351, -0.1285, -0.0045,  0.8071,
         -0.1174,  0.9670]),
 tensor([ 0.3396,  0.8663,  0.8830,  0.7084,  0.9351, -0.1285, -0.0045,  0.8071,
         -0.1174,  0.9670]))

In another case, we really do want the derivative of every output element with respect to every input element; what we want then is the Jacobian matrix.

from torch.func import jacrev, jacfwd, hessian

x = torch.randn([10,])  # try changing the shape and see what happens
fx = lambda x: torch.sin(x)
cos_x = jacrev(fx)(x)  # full Jacobian of the vector-to-vector map (diagonal here)

print(cos_x.shape, x.shape, fx(x).shape)
torch.cos(x), cos_x
torch.Size([10, 10]) torch.Size([10]) torch.Size([10])
(tensor([ 0.9992,  0.7487,  0.9434,  0.9032,  0.9164,  0.6165,  0.5111,  1.0000,
          0.7079, -0.1347]),
 tensor([[ 0.9992,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
           0.0000, -0.0000],
         [ 0.0000,  0.7487,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
           0.0000, -0.0000],
         [ 0.0000,  0.0000,  0.9434,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
           0.0000, -0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.9032,  0.0000,  0.0000,  0.0000,  0.0000,
           0.0000, -0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.9164,  0.0000,  0.0000,  0.0000,
           0.0000, -0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.6165,  0.0000,  0.0000,
           0.0000, -0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.5111,  0.0000,
           0.0000, -0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  1.0000,
           0.0000, -0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
           0.7079, -0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
           0.0000, -0.1347]]))
from torch.func import jacrev, jacfwd, hessian

x = torch.randn([10, 2])  # try changing the shape and see what happens
fx = lambda x: torch.sin(x)
hessian_x = hessian(fx)(x)  # second derivative of every output w.r.t. every pair of input elements

print(hessian_x.shape, x.shape)
torch.Size([10, 2, 10, 2, 10, 2]) torch.Size([10, 2])

A more common situation is a mixture of the two cases above. For example, suppose we use a neural network to fit the solution of the two-dimensional Euler equations. The network maps a vector to a vector, and we want the Jacobian \(\frac{\partial f}{\partial (x,t)}\), which has 6 entries (3 outputs × 2 inputs). At the same time the network is fed a batch of 1000 samples, so we want to compute the Jacobians of all 1000 samples at once:

x = torch.randn([1000, 2])
Net = torch.nn.Linear(2, 3)
y = Net(x)          # nn.Linear handles the batch dimension itself
y2 = vmap(Net)(x)   # equivalent: vmap over the batch dimension
print(f"{x.shape=},{y.shape=},{y2.shape=}")

# per-sample Jacobians: vmap composes with jacrev / jacfwd
dy_dx_1 = vmap(jacrev(Net))(x)
dy_dx_2 = vmap(jacfwd(Net))(x)
print(f"{dy_dx_1.shape=},{dy_dx_2.shape=}")

def myhessian(f):
    return jacfwd(jacrev(f))
def myhessian2(f):
    return jacrev(jacfwd(f))

# per-sample Hessians: hessian(f) is equivalent to jacfwd(jacrev(f))
ddy_dxx = vmap(hessian(Net))(x)
ddy_dxx2 = vmap(myhessian(Net))(x)
ddy_dxx3 = vmap(myhessian2(Net))(x)

print(f"{ddy_dxx.shape=},{ddy_dxx2.shape=},{ddy_dxx3.shape=}")
x.shape=torch.Size([1000, 2]),y.shape=torch.Size([1000, 3]),y2.shape=torch.Size([1000, 3])
dy_dx_1.shape=torch.Size([1000, 3, 2]),dy_dx_2.shape=torch.Size([1000, 3, 2])
ddy_dxx.shape=torch.Size([1000, 3, 2, 2]),ddy_dxx2.shape=torch.Size([1000, 3, 2, 2]),ddy_dxx3.shape=torch.Size([1000, 3, 2, 2])

How to choose between jacrev and jacfwd#

On the choice between jacrev and jacfwd, the PyTorch documentation explains:

These two functions compute the same values (up to machine numerics), but differ in their implementation: jacfwd uses forward-mode automatic differentiation, which is more efficient for “tall” Jacobian matrices, while jacrev uses reverse-mode, which is more efficient for “wide” Jacobian matrices. For matrices that are near-square, jacfwd probably has an edge over jacrev.
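
To make the tall/wide distinction concrete, here is a minimal sketch (the two linear maps and their sizes are made up for illustration): a "tall" map \(\mathbb{R}^3\to\mathbb{R}^{100}\) and a "wide" map \(\mathbb{R}^{100}\to\mathbb{R}^3\). Both transforms produce the same Jacobian; they differ only in cost, since forward mode scales with the number of inputs and reverse mode with the number of outputs.

import torch
from torch.func import jacrev, jacfwd

tall = torch.nn.Linear(3, 100)   # "tall" Jacobian (100x3): jacfwd needs only 3 forward passes
wide = torch.nn.Linear(100, 3)   # "wide" Jacobian (3x100): jacrev needs only 3 backward passes

x_tall = torch.randn(3)
x_wide = torch.randn(100)

# both modes agree up to machine precision
print(torch.allclose(jacfwd(tall)(x_tall), jacrev(tall)(x_tall)))  # True, shape [100, 3]
print(torch.allclose(jacfwd(wide)(x_wide), jacrev(wide)(x_wide)))  # True, shape [3, 100]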

And for hessian:

To implement hessian, we could have used jacfwd(jacrev(f)) or jacrev(jacfwd(f)) or any other composition of the two. But forward-over-reverse is typically the most efficient. That’s because in the inner Jacobian computation we’re often differentiating a function with a wide Jacobian (maybe like a loss function 𝑓:ℝⁿ→ℝ), while in the outer Jacobian computation we’re differentiating a function with a square Jacobian (since ∇𝑓:ℝⁿ→ℝⁿ), which is where forward-mode wins out.
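
Here is a minimal sketch of that argument for a scalar loss (the quadratic loss below is made up for illustration): the inner jacrev handles the wide \(1\times n\) Jacobian (the gradient), and the outer jacfwd differentiates the square map \(\nabla f:\mathbb{R}^n\to\mathbb{R}^n\). The result matches torch.func.hessian, which is itself built as forward-over-reverse.

import torch
from torch.func import jacrev, jacfwd, hessian

n = 50
A = torch.randn(n, n)

def loss(x):                  # f: R^n -> R, so jacrev(loss) is a wide (1 x n) Jacobian
    return 0.5 * (A @ x).pow(2).sum()

x = torch.randn(n)
H_for_rev = jacfwd(jacrev(loss))(x)   # forward-over-reverse
H_builtin = hessian(loss)(x)
print(torch.allclose(H_for_rev, H_builtin), H_for_rev.shape)  # True torch.Size([50, 50])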

More#

https://pytorch.org/tutorials/beginner/basics/autogradqs_tutorial.html

https://pytorch.org/tutorials/intermediate/per_sample_grads.html

https://pytorch.org/tutorials/intermediate/jacobians_hessians.html