在PyTorch、JAX、MLX乃至其他现代深度学习框架中,VJP(向量-雅可比乘积)和JVP(雅可比-向量乘积)是实现自动微分的核心机制。最近在学习一些关于敏感度分析,了解到了相关的知识。这里记录一下这两个概念,以及它们是如何驱动“自动求导”过程的。
基本定义
不失去一般性,对于一个函数
\[y=f(x),\ x \in \mathbb{R}^n, y \in \mathbb{R}^m\]它的雅可比矩阵是:
\[J_f = \frac{\partial f}{\partial x} \in \mathbb{R}^{m \times n}\]JVP:Jacobian-Vector Product(正向微分)的定义是,给定一个向量\(v \in \mathbb{R}^n\), JVP计算
\[J_f(x)\cdot v\]表示在当前点\(x\), 往方向\(v\)移动一小段距离,输出如何变化,也就是正向微分,乘法的时间复杂度是\(O(mn^2)\).
实际上正向微分就是导数的微分形式。举一个例子,比如\(y=sin(x) \cdot x\),我们写出导数的微分形式:
\[dy = cos(x)\cdot x \cdot dx + sin(x) \cdot dx\]对应于代码,我们相当于在每个操作中维护一个pair:
(x, dx) # 当前值及其方向导数
JVP输入另外一对pair:
y = np.sin(x) * x
dy = np.cos(x) * dx * x + np.sin(x) * dx
然后在计算图里前向传播即可,完整的代码如下:
def f(x): # y = sin(x) * x
return np.sin(x) * x
def jvp(x, dx): # 输入当前值及其方向导数
y = np.sin(x) * x
dy = np.cos(x) * dx * x + np.sin(x) * dx
return y, dy
VJP:Vector-Jacobian Product(反向微分)则相反,其定义是,给定一个向量\(u \in \mathbb{R}^m\), JVP计算:
\[u^\mathrm T \cdot J_f(x)\]也就是雅可比左乘一个转置向量。实际上反向微分就是导数的梯度形式, 乘法的时间复杂度是\(O(m^2n)\).
比如在神经网络中,记最后的loss为\(L\) (通常是标量),将其梯度初始化为1。
我们可以反向传播每个操作的梯度。比如在前向传播过程中,某个操作还是\(y=sin(x) \cdot x\),我们现在已经知道了从上层链路传来的梯度\(\frac{\partial L}{\partial y}\), 我们根据链式法则计算\(\frac{\partial L}{\partial x}\):
\[\frac{\partial L}{\partial x} = \frac{\partial L}{\partial y}\cdot \frac{\partial y}{\partial x}=\frac{\partial L}{\partial y}\cdot (cos(x)\cdot x + sin(x))\]用代码表示就是:
def vjp(x, y_grad): # y_grad 是从上层链路传来的
x_grad = y_grad * (np.sin(x) + x * np.cos(x))
return x_grad
实际上,正向微分就是敏感度分析,输入变量往某个方向移动一小段距离,输出如何变化,可以快速获取任意方向上的导数值。其整体的时间复杂度和前向传播类似,适合用于小输入的实时求导。
PyTorch中自定义算子的前向微分和后向微分
实际上写过PyTorch自定义算子的同学就知道,自定义算子里面backward
函数就是就是后向微分。但是前向微分比较少用。实际上,Pytorch以vjp功能(也就是反向传播功能)为主,这在神经网络训练场景下(最后的往往是一个标量的损失函数,维度较小)非常实用。JAX则提供了完备的高阶微分系统(包括前向微分和后向微分),因此在科学计算领域非常受欢迎。
后面PyTorch向JAX看齐,也提供了前向微分功能,也就是jvp
接口,我们可通过自定义一个平方函数,来学习怎么使用(下面的代码在PyTorch 2.7.1中可运行),注意同时要定义setup_context
接口。
#%%
import torch
print(torch.__version__)
#%%
class MySquare(torch.autograd.Function):
@staticmethod
def forward(x):
return x ** 2
@staticmethod
def setup_context(ctx, inputs, output):
x, = inputs # Get the input tensor
ctx.x = x
@staticmethod
def backward(ctx, grad_output):
x = ctx.x
# VJP: ∂(x^2)/∂x = 2x
return grad_output * 2 * x
@staticmethod
def jvp(ctx, grad_x):
"""JVP method for forward-mode AD"""
x = ctx.x # Access input saved for jvp
# JVP: d(x^2)/dx * dx = 2x * dx
return 2 * x * grad_x
#%%
x = torch.tensor(3.0, requires_grad=True)
def square_fn(x):
return MySquare.apply(x)
# Example 1: Forward and Backward (VJP)
print("=== Forward and Backward (VJP) ===")
# Forward
y = square_fn(x)
print("Forward:", y) # 9.0
# Backward (VJP)
y.backward()
print("Gradient (VJP):", x.grad) # 6.0
#%%
# Example 2: Using torch.func.jvp for forward-mode AD
print("\n=== Forward-mode AD using torch.func.jvp ===")
from torch.func import jvp
x = torch.tensor(3.0)
x_tangent = torch.tensor(2.0) # direction vector dx
# Compute JVP using torch.func.jvp
result = jvp(square_fn, (x,), (x_tangent,))
primal_out = result[0]
tangent_out = result[1]
print("Primal output:", primal_out) # 9.0
print("Tangent output:", tangent_out) # 12.0
# %%
PyTorch中现有算子的前向微分
当然对于Pytorch的现有现有算子,我们可以用torch.autograd.functional.jvp
实现自动前向微分。
from torch.autograd.functional import jvp
# 定义一个函数 f(x): R^2 → R^2
def f(x):
# x 是形状为 (2,) 的张量
return torch.stack([
x[0] * x[1], # x₁ * x₂
torch.sin(x[1]) # sin(x₂)
])
# 选择一点 x,以及一个方向向量 v
x = torch.tensor([2.0, 3.0], requires_grad=True)
v = torch.tensor([1.0, 0.0]) # 只沿第一个方向扰动
# 调用 jvp(Jacobian-vector product)
y, jvp_result = jvp(f, (x,), (v,))
# 打印结果
print("f(x) =", y)
print("JVP: df/dx · v =", jvp_result)
# 结果
# f(x) = tensor([6.0000, 0.1411])
# JVP: df/dx · v = tensor([3., 0.])
解释:
\[\begin{align} f_1(x) &= x_0 \cdot x_1 = 2 \cdot 3 = 6 \\ f_2(x) &= \sin(x_1) = \sin(3) ≈ 0.1411 \end{align}\]雅可比矩阵(2x2)为:
\[J_f(x) = \begin{bmatrix} \frac{\partial f_1}{\partial x_0} & \frac{\partial f_1}{\partial x_1} \\ 0 & \cos(x_1) \end{bmatrix} \begin{bmatrix} 3 & 2 \\ 0 & \cos(3) \end{bmatrix}\]方向向量 \(v = [1, 0]\),所以:
\[J_f \cdot v = \begin{bmatrix} 3 \\ 0 \end{bmatrix}\]利用前向微分实现Hessian 向量积
Hessian 向量积(HVP,Hessian-Vector Product)是机器学习和优化中非常重要的一个工具,通常来说,构造Hessian矩阵是一个非常代价非常大的操作,但是很多时候我们并不需要显式构造Hessian矩阵,只需要Hessian矩阵和一个向量的积,我们可以用前向微分实现。
首先,Hessian 向量积衡量的是,函数在某方向的二阶变化率,也就是该方向的“曲率”。
设函数:
\[f : \mathbb{R}^n \rightarrow \mathbb{R}\]我们有梯度\(\nabla f(x) \in \mathbb{R}^n\)和Hessian 矩阵\(\nabla^2 f(x) \in \mathbb{R}^{n \times n}\)。给定方向向量\(v \in \mathbb{R}^n\),Hessian 向量积定义为:
\[\text{HVP}(x, v) = \nabla^2 f(x) \cdot v\]它是 Hessian 矩阵作用于一个向量的结果,反映在方向 v 上的二阶导数信息。
相比于梯度,梯度告诉你函数往哪个方向最陡;Hessian 向量积告诉你函数在某个方向是否“凹”或“凸”,也就是二阶导数的投影,这是用于估计优化路径曲率的关键工具。
那么如何快速计算HVP呢?
我们只需要对梯度再做一次前向微分(JVP,方向导数)即可,
\[\nabla^2 f(x) \cdot v = \frac{d}{d\epsilon} \nabla f(x + \epsilon v)\big|_{\epsilon = 0}\]换句话说:\(HVP = JVP(\nabla f(x), v)\). 可以理解为先做了一次反向微分得到梯度,然后对梯度再做一次前向微分。
我们可以在PyTorch中实现:
import torch
from torch.autograd.functional import jvp
# 定义函数 f: R^n → R(标量)
def f(x):
return (x[0]**2 + 3 * x[1]**2 + x[0]*x[1]) # 可换成任意标量函数
x = torch.tensor([1.0, 2.0], requires_grad=True)
v = torch.tensor([1.0, -1.0]) # 要计算的方向向量
# 第一步:定义 grad f(x)
def grad_f(x):
return torch.autograd.grad(f(x), x, create_graph=True)[0]
# 第二步:对 grad f(x) 做 JVP,得到 Hessian-vector product
_, hvp = jvp(grad_f, (x,), (v,))
print("Hessian-vector product (HVP):", hvp)
# Hessian-vector product (HVP): tensor([ 1., -5.])
HVP在很多地方都有用:
- 在二阶优化中,用于实现牛顿法、共轭梯度法;
- 在神经网络训练中,使用 HVP 实现曲率感知更新;
- 在灵敏度分析中,实现梯度变化趋势分析;
- 强化学习,实现Trust Region Policy Optimization (TRPO)。
前向微分与后向微分在敏感度分析中的区别
涉及敏感度分析(sensitivity analysis),有两种常见的方法:
方法 | 表达式 | 本质 | 精度 | 实现方式 |
---|---|---|---|---|
自动微分法(JVP) | \(\dot{y} = J(x) \cdot \dot{x}\) | 真正的一阶导数 | ✅ 精确 | 使用 torch.autograd.functional.jvp |
梯度乘扰动法 | \(\delta y \approx \nabla y(x)^\top \cdot \delta x\) | 一阶泰勒展开近似 | ⚠️ 有截断误差 | 手动乘梯度和扰动 |
import torch
from torch.autograd.functional import jvp
# 非线性函数
def f(x): return torch.sin(x)
#%%
x = torch.tensor([1.0], requires_grad=True)
dx = torch.tensor([0.1]) # 扰动方向
# 方法 1:JVP
_, dy_jvp = jvp(f, (x,), (dx,))
# 方法 2:梯度乘扰动
y = f(x)
y.backward()
grad = x.grad
dy_approx = grad * dx
# 方法 3:真实变化值
y2 = f(x + dx)
delta_y_true = y2 - y
print("Approximate result:", dy_approx.item())
print("JVP result:", dy_jvp.item()) # ≈ cos(1.0) * 0.1
print("True delta_y:", delta_y_true.item())
# Approximate result: 0.05403023585677147
# JVP result: 0.05403023585677147
# True delta_y: 0.04973644018173218
#%%
dx = torch.tensor([0.5])
_, dy_jvp = jvp(f, (x,), (dx,))
y = f(x)
y.backward()
grad = x.grad
dy_approx = grad * dx
y2 = f(x + dx)
delta_y_true = y2 - y
print("Approximate result:", dy_approx.item())
print("JVP result:", dy_jvp.item()) # ≈ cos(1.0) * 0.1
print("True delta_y:", delta_y_true.item())
y2 = f(x + dx)
delta_y_true = y2 - y
# Approximate result: 0.5403023362159729
# JVP result: 0.27015116810798645
# True delta_y: 0.1560240387916565
总结
这篇文章介绍了自动微分的两个核心概念:JVP(正向微分)和VJP(反向微分)。JVP计算雅可比矩阵与向量的乘积,适合输入维度小的场景;VJP计算向量与雅可比矩阵的乘积,适合输出维度小的场景如神经网络训练。PyTorch主要基于VJP实现反向传播,而JAX同时提供了完备的JVP和VJP支持。理解这两个概念有助于更好地掌握现代深度学习框架的自动微分原理。 最后,该文章涉及到的代码都放在这供大家参考。