为什么近期大模型训练需要用到CUDA Graph
CUDA Graph是NVIDIA推出的一项技术,它可以将一系列GPU操作,包括kernel launches,memory copies预先在capture阶段录制成一个图,然后在replay阶段一次性提交执行。CUDA Graph最大的好处是可以极大减少CPU开销,减少kernel launch latency,这在使用Grace CPU的B系列NV芯片上带来的好处尤其突出。同时GPU可以更好的优化执行顺序,减少运行队列中间的bubble。
针对B系列芯片CPU开销较大的问题,这里以MoE层中fc1的grouped GEMM操作为例进行说明。以下的实验均在B系列芯片上进行。在Transformer Engine中,默认的grouped GEMM其实是通过循环触发多个GEMM kernel,并分配到不同的CUDA stream上执行(如左图所示),本质上不是“真正”的grouped GEMM。而如果在B系列芯片上直接调用cutblass,只需启动一次kernel即可完成真正的grouped GEMM(如右图所示),这种方式能显著减少kernel launch次数,从而大幅降低overhead。同时可以看到,左图在GEMM计算结束后,队列中出现一段bubble才进入激活函数环节,也从侧面反映出CPU带来的额外延迟。值得一提的是,TE默认采用多stream实现grouped GEMM的做法在H系列芯片上通常更优,但迁移到B系列时就未必适合了。
当然,使用 CUDA Graph 也引入了一些新的限制和约束:
- 输入的 Tensor 形状必须保持一致 —— 每次调用时输入 shape 不能变化。
- 禁止 CPU-GPU 同步操作 —— 在 graph 内部不能包含诸如
.item()、.cpu()等涉及同步的指令。 - 不能存在动态控制流 —— 执行路径必须是确定性的,不能依赖于运行时数据分支。
- 内存地址需固定 —— 需采用静态分配的输入缓冲区。
其中,“内存地址需固定”这一点对MoE模型尤为突出。MoE 的核心机制在于 Token Dispatcher 根据路由器(router)的分配结果,动态决定每个 token 分派给哪个 expert。即每个 batch 的路由都可能不同,进而使 AlltoAll 通信所需的数据大小和内存布局发生变化。而 CUDA Graph 要求 capture 阶段所有操作涉及的 Tensor 地址均已确定,这与 MoE 这种动态分布形成冲突。
为了解决这一问题,工程实践中,我们采用“部分捕获”(Partial Capture)的方案:对于可以静态展现的部分我们用 CUDA Graph 捕获,其余涉及动态路径/动态资源分配的部分则按照传统方式执行。这样既能获得 graph 带来的 kernel launch/调度优化,也能兼顾 MoE 动态收发的灵活性。具体配置如下所示:
cuda_graph_impl: transformer_engine
cuda_graph_scope:
- attn # 捕获 attention 层
- moe_router # 捕获 MoE 路由核心
- moe_preprocess # 捕获路由 token 的预处理环节
接下来将详细讲解实现原理与核心操作流程。
TE Capture/Replay 工作流程
整体流程
Megatron-LM 支持两种 CUDA Graph 实现方式:其一为Megatron-LM自己的实现,通过设置 cuda_graph_impl:local 启用。这种方式由 Megatron 原生实现了 graph 的捕获及回放,在 dense 模型场景下,可以一次性录制并复现包含前向与反向传播(不含优化器操作)的完整迭代过程,相关核心代码位于 megatron/core/transformer/cuda_graphs.py,主要涉及 CudaGraphManager 与 _CudaGraphRunner。然而在 MoE 模型情形下,由于需要采用定长 padding,导致无论某个 expert 实际接收到的 token 数量多少,都必须填充到统一的最大容量,这不仅增加了不必要的计算消耗,还带来 AlltoAll 通信带宽的浪费,并且显存也需预留足够的最大 buffer。因此,目前该实现架构并不适合 MoE 场景下的 CUDA Graph 加速需求。
另一种对 CUDA Graph 的支持方式,是通过启用 Transformer Engine(TE)实现的,只需在配置中设置 cuda_graph_impl:transformer_engine。我们在partial cuda graph中使用这套方案。在该方案下,核心流程围绕 TE 的 CUDA Graph 捕获(Capture)与回放(Replay)机制展开。Megatron-LM 集成 TE CUDA Graph 的主要逻辑位于 megatron/core/transformer/cuda_graphs.py 的 TECudaGraphHelper 类内部,其核心方法是调用 TE 提供的 make_graphed_callables() API。从最简单的情况开始开始考虑,假设只有一个microbatch,那么整体的工作流程如下:
# 1. 收集待捕获的层
flattened_callables = []
for layer in model.decoder.layers:
if _layer_is_graphable(layer, config):
flattened_callables.append(layer)
# 2. 构造示例输入
sample_args, kwargs = self._get_cuda_graph_input_data()
# 形如: ((hidden_states_layer0,), (hidden_states_layer1,), ...)
# 3. 捕获阶段
graphs = make_graphed_callables(
tuple(flattened_callables),
sample_args,
kwargs,
_order=order, # 管道并行调度顺序
_reuse_graph_input_output_buffers=True, # TE 2.7+ 支持的 buffer 复用特性
)
# 4. 可cuda replay的 Graph分配给各层对象
for layer in layers:
layer.cuda_graphs = [graphs[idx] for idx in ...]
首先是第一阶段,收集待捕获的层。由于我们是逐层进行捕获,因此首先会依据 config 配置筛选需要捕获 CUDA Graph 的层,依次存入 flattened_callables。注意这个时候虽然我们只捕获MoE layer的attn,moe_router以及moe_preprocess这三部分,但是这里我们还是将这个layer放进来,后面通过后面说的python的异常机制来实现部分捕获。
接着是第二阶段,构造示例输入。在这个阶段需为每一层准备一组示例输入 sample_args 与 kwargs,这个本质上是在调用layer的get_layer_static_inputs函数:
args, kwargs = layer.get_layer_static_inputs(self.seq_length, self.micro_batch_size)
这套静态输入会作为forward graph的static input buffer,同时在它上面模拟forward和backward跑出来的结果会作为backward graph的static input buffer。涉及PP的调度与 buffer 管理会在后文详细介绍。
第三阶段是捕获阶段。在这个阶段我们调用了TE的make_graphed_callables这个API。在 make_graphed_callables 的内部,核心包括以下环节:
- Warmup 阶段:多次执行 forward+backward,让类似
triton.compile这样的操作完成lazy 初始化,确保后续捕获的稳定性。 - Capture 阶段:依顺序对每个可调用对象依次执行 forward,再按逆序依次执行 backward。 捕获过程中将所有 CUDA kernel 操作录入 graph,同时固定输入和中间变量的 tensor 地址, 也就是说 forward 期间所有的中间 tensor buffer 其物理地址在 capture 当下即已确定。
# All captures here share a mempool. To avoid replays corrupting each other's memory, # the safest approach is to capture all passes in the same order they'll run: # fwd 1, fwd 2, ... fwd N, then bwd N, bwd N-1, ... bwd 1. make_graphed_callables最终返回一组包装好的 callable,这些callable本质上是是调用torch.autograd.Funtion的闭包,调用callable的返回值就是有我们自定义grad_fn的tensor。换句话说,我们的fwd_graph和bwd_graph其实是存在闭包里面的。def make_graphed_callables(...): # ... 捕获 graph 的代码 ... # 这些变量会被闭包捕获 static_input_surface = [...] static_outputs = (...) static_grad_outputs = (...) static_grad_inputs = (...) fwd_graph = torch.cuda.CUDAGraph() bwd_graph = torch.cuda.CUDAGraph() # 定义自定义 autograd function(被闭包捕获) class Graphed(torch.autograd.Function): @staticmethod def forward(ctx, skip_fp8_weight_update, *inputs): # 使用闭包捕获的变量 static_input_surface[i].copy_(inputs[i]) fwd_graph.replay() return tuple(o.detach() for o in static_outputs) @staticmethod def backward(ctx, *grads): static_grad_outputs[i].copy_(grads[i]) bwd_graph.replay() return static_grad_inputs # 闭包函数:捕获 Graphed 和其他变量 def functionalized(*user_args, **user_kwargs): # 调用 Graphed.apply() 执行 forward out = Graphed.apply(skip_fp8_weight_update, *func_args) return _tree_unflatten(out, output_unflatten_spec) return functionalized # 返回闭包接着是最后阶段,我们将这些闭包对象分配给各层对象作为一个成员变量,供后续调用。
CUDA Graph在microbatch之间能否复用?
我们接下来思考另一个问题,现在有有多个micro batches,CUDA Graph在microbatch之间能否复用?不妨假设number of microbatches=4。实际上,这个和micro batch的forward和backward调用有关系。
- 第一种情况是,假设每个microbatch完成forward后立刻进行backward(也就是1F1B调度)。
Microbatch: MB0 MB0 MB1 MB1 MB2 MB2 MB3 MB3
操作: FWD BWD FWD BWD FWD BWD FWD BWD
└──┘ └──┘ └──┘ └──┘ └──┘ └──┘ └──┘ └──┘
↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑
立刻 立刻 立刻 立刻 立刻 立刻 立刻 立刻
order = [1, -1, 1, -1, 1, -1, 1, -1]
↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑
fwd bwd fwd bwd fwd bwd fwd bwd
mb0 mb0 mb1 mb1 mb2 mb2 mb3 mb3
这种情况下显然CUDA Graph是可以复用的,同时input buffer也是可以复用的。那么这个时候构建的CUDA Graph的数量是2 × num_layers,其中2分别为是forward和backward。
- 另外一种情况是,所有microbatch都完成forward后,再依次进行backward。
Microbatch: MB0 MB1 MB2 MB3 MB0 MB1 MB2 MB3
操作: FWD FWD FWD FWD BWD BWD BWD BWD
←── 所有 Forward ──→ ←── 所有 Backward ──→
order = [1, 1, 1, 1, -1, -1, -1, -1]
注意到中间变量的地址是存在CUDA Graph内部的,如果microbatch 1复用microbatch 0的CUDA Graph,那么在forward的时候会改掉microbatch 0的中间变量,这样在microbatch 0在backward的时候就没有中间变量就被破坏掉了。这个时候CUDA Graph是不可以被复用的,此时CUDA Graph的数量是2 × num_layers × num_microbatches。 很显然,第二个情况需要临时保存的activation数量太多了,正常Megatron-LM在不开流水线并行的情况下是不可能使用的。时间上我们可以从Megatron-LM的实际代码可以看到不开PP的时候就是第一种情况的1F1B调度:
# megatron/core/pipeline_parallel/schedules.py: forward_backward_no_pipelining
with no_sync_func(): # 前 N-1 个 microbatch 不同步梯度
for i in range(num_microbatches - 1):
output_tensor = forward_step(...) # Forward microbatch i
if not forward_only:
backward_step(...) # Backward microbatch i(立刻!)
# 最后一个 microbatch 在 no_sync 外执行(触发梯度同步)
output_tensor = forward_step(...) # Forward 最后一个
backward_step(...) # Backward 最后一个
流水线并行下的TE CUDA Graph解析
目前,Megatron-LM采用的流水线并行方案是1F1B interleaving,如下图所示:

为便于分析,我们先不考虑VPP,即采用上半图的标准的1F1B方案。这里number_of_microbatches为8,PP为4。结合前述分析,我们可以得出:microbatch 1 和 microbatch 4 之间的 CUDA Graph 不可以复用,但当 microbatch 1 的 backward 完成后,microbatch 5 在执行前向时,完全可以复用microbatch 1 对应的 CUDA Graph。在典型1F1B调度下,每个GPU上同时存在的activation峰值仅与PP相关,因此在理论上,最优情况下CUDA Graph数量下限为 2 × num_layers × PP。
然而,当前实现中CUDA Graph的实际数量不是随PP增长的,而是随num_of_microbatches线性增长。换言之,不同microbatch之间的CUDA Graph没有被复用。因此,目前系统中CUDA Graph的数量实际为 2 × num_layers × num_of_microbatches。
我们知道和num_of_microbatches成正比是一个不太好的事情,这意味和global batch size成正比,容易出现显存爆炸的问题。为了解决这个问题,我们首先需要明确两个概念:
- CUDA Graph本身:即对CUDA操作的录制,此部分显存消耗较小,每个microbatch都有一份独立的graph是没有问题的。
- 中间变量显存(activation):在CUDA Graph捕获阶段产生。这部分显存可以在不同的CUDA Graph之间共享,经优化后实际占用量与PP成正比。这一优化空间及方法我们将在后文详细讨论。
我们回到考虑VPP的情况,假设num_of_microbathes=8, PP=4,VPP=2。我们首先学习Megatron-LM里面是如何表示1F1B interleaving的,可以跟着下面代码的思路:
# Get the PP and VPP scheduling order.
from megatron.core.pipeline_parallel.schedules import (
convert_schedule_table_to_order,
get_pp_rank_microbatches,
get_schedule_table,
)
_, _, num_warmup_microbatches, _ = get_pp_rank_microbatches(
num_microbatches,
self.num_model_chunks,
self.config.microbatch_group_size_per_vp_stage,
False,
)
schedule_table = get_schedule_table(
num_microbatches,
self.num_model_chunks,
self.config.microbatch_group_size_per_vp_stage,
)
order = convert_schedule_table_to_order(
num_warmup_microbatches, self.num_model_chunks, schedule_table
)
首先,假设整个模型有32层(Layer 0-31),那么我们就将整个模型切分成PP × VPP个model chunk,每个PP rank就有VPP个model chunk:
模型: Layer 0-31 (32层)
PP = 4, VPP = 2 (num_model_chunks = 2)
PP Rank 0: [Layer 0-3] ← Chunk 0
[Layer 16-19] ← Chunk 1
PP Rank 1: [Layer 4-7] ← Chunk 0
[Layer 20-23] ← Chunk 1
PP Rank 2: [Layer 8-11] ← Chunk 0
[Layer 24-27] ← Chunk 1
PP Rank 3: [Layer 12-15] ← Chunk 0
[Layer 28-31] ← Chunk 1
对于一个PP rank来说,其实它只能看到本地的VPP=2个model chunk,因此它其实依次这样的操作:选择本地的某一个model chunk进行fwd或者bwd。在Megatron-LM里面,我们用一个数字来表示这样的操作,”绝对值-1”表示本地的哪一个model chunk,正负号表示fwd还是bwd。 比如对于上面的PP Rank 0:
- 1表示[Layer 0-3] (chunk 0)做fwd;
- 2表示[Layer 16-19] (chunk 1) 做fwd;
- -1表示[Layer 0-3] (chunk 0) 做bwd;
- -2表示[Layer 16-19] (chunk 1) 做bwd;
每个PP rank都能拿到一个叫order的序列,只要编排好,每个PP rank按照本地编排的order依次执行,那么整体上就能形成一个完整的流水线:比如上图的Device 1的order就是
order = [ # Warmup (10 forwards) 1, 1, 1, 1, # chunk 0 fwd: mb 1,2,3,4 2, 2, 2, 2, # chunk 1 fwd: mb 1,2,3,4 1, 1, # chunk 0 fwd: mb 5,6 # 1F1B steady state (交替 fwd/bwd) 1, -1, # F_mb7_c0, B_mb1_c1 1, -1, # F_mb8_c0, B_mb2_c1 2, -2, # F_mb5_c1, B_mb1_c0 2, -2, # F_mb6_c1, B_mb2_c0 2, -1, # F_mb7_c1, B_mb3_c1 2, -1, # F_mb8_c1, B_mb4_c1 # Cooldown (全 backward) -2, -2, # B_mb5,6_c0 -1, -1, # B_mb5,6_c1 -2, -2, # B_mb3,4_c0 -1, -1, # B_mb7,8_c1 -2, -2, # B_mb7,8_c0 ]其中我们可以看到分为三个阶段:warmup阶段(全是forward),1F1B交替阶段,以及cooldown阶段(全是backward)。那么对于每个PP rank怎么才能得到这个数组呢?我们跟着上面的代码来就行:
- 首先
get_pp_rank_microbatches告诉我们num_warmup_microbatches的计算公式就是
这里microbatch_group_size_pervp_stage表示每个 virtual stage 连续执行的 microbatch 数量,默认值是PP。代入到上图里计算确实这样。
接着是get_schedule_table这个函数,这个函数是个生成 (microbatch_id, model_chunk_id) 的调度顺序,实际上是每个PP rank上fwd操作的顺序,这个数组每个PP rank看到的都是一样的,从上图中看到的确实是这样。同时这个顺序顺序也是每个PP rank上看到backward操作的顺序。
def get_schedule_table(num_microbatches, num_model_chunks, microbatch_group_size_per_vp_stage):
"""
生成 (microbatch_id, model_chunk_id) 的调度顺序
"""
schedule_table = []
# 按 group 遍历所有 microbatch
for min_microbatch_id_in_group in range(0, num_microbatches, microbatch_group_size_per_vp_stage):
# 对于每个 group:
# 外层循环: model_chunk_id (0, 1, ..., VPP-1)
# 内层循环: microbatch_id (group 内的连续 microbatch)
schedule_table.extend([
(microbatch_id, model_chunk_id)
for model_chunk_id in range(num_model_chunks) # 外层
for microbatch_id in range(group_start, group_end) # 内层
])
return schedule_table
比如参数为num_microbatches=8, num_model_chunks=2, microbatch_group_size=4的时候,
# Group 0: microbatch 0-3
# Group 1: microbatch 4-7
# 遍历过程:
Group 0 (mb 0-3):
chunk 0: mb 0, 1, 2, 3 → (0,0), (1,0), (2,0), (3,0)
chunk 1: mb 0, 1, 2, 3 → (0,1), (1,1), (2,1), (3,1)
Group 1 (mb 4-7):
chunk 0: mb 4, 5, 6, 7 → (4,0), (5,0), (6,0), (7,0)
chunk 1: mb 4, 5, 6, 7 → (4,1), (5,1), (6,1), (7,1)
# 最终 schedule_table:
schedule_table = [
(0,0), (1,0), (2,0), (3,0), # chunk 0, mb 0-3
(0,1), (1,1), (2,1), (3,1), # chunk 1, mb 0-3
(4,0), (5,0), (6,0), (7,0), # chunk 0, mb 4-7
(4,1), (5,1), (6,1), (7,1), # chunk 1, mb 4-7
]
接着是convert_schedule_table_to_order这个函数,每个PP rank的num_warmup_microbatches是不一样的,都是都是全forward(warmup) + forward和backward交替 + 全backward三个阶段。只需要知道warmup有几个microbatches,后面的就能模拟出来了。
# Step 1: schedule_table
schedule_table = [
(0,0), (1,0), (2,0), (3,0), # chunk 0, mb 0-3
(0,1), (1,1), (2,1), (3,1), # chunk 1, mb 0-3
(4,0), (5,0), (6,0), (7,0), # chunk 0, mb 4-7
(4,1), (5,1), (6,1), (7,1), # chunk 1, mb 4-7
]
# _, model_chunk_id_table = zip(*schedule_table)
model_chunk_id_table = [0,0,0,0, 1,1,1,1, 0,0,0,0, 1,1,1,1]
# Step 2: forward_order 和 backward_order
forward_order = [1,1,1,1, 2,2,2,2, 1,1,1,1, 2,2,2,2] # chunk_id + 1
backward_order = [-2,-2,-2,-2, -1,-1,-1,-1, -2,-2,-2,-2, -1,-1,-1,-1] # chunk_id - 2
# Step 3: 生成 order
# Warmup: forward_order[:10]
order = [1,1,1,1, 2,2,2,2, 1,1]
# 1F1B: i = 10 to 15
# i=10: fwd[10]=1, bwd[0]=-2 → [1, -2]
# i=11: fwd[11]=1, bwd[1]=-2 → [1, -2]
# i=12: fwd[12]=2, bwd[2]=-2 → [2, -2]
# i=13: fwd[13]=2, bwd[3]=-2 → [2, -2]
# i=14: fwd[14]=2, bwd[4]=-1 → [2, -1]
# i=15: fwd[15]=2, bwd[5]=-1 → [2, -1]
# Cooldown: backward_order[-10:]
# = [-1,-1, -2,-2,-2,-2, -1,-1,-1,-1]
补充一句,VPP的本质是切更多的流水线阶段,只是将部分流水线阶段划分到同一个 GPU 上。如果我们的 GPU 数量足够,其实完全可以把这些被切分出来的子阶段分别放在不同的 GPU 上,只是受限于现实资源有限,只能将多个阶段合并放置于同一块 GPU罢了。因此从理论分析的角度,Bubble time fraction就是会下降至原先的 1/VPP,因为本质上流水线阶段就是变多了。
有了这个order之后,一个非常简单规则:当 order 中出现一个负数(backward),对应的 forward buffer 就可以被后续的 forward 复用了。这个规则将指导我们下面的显存复用,做到和PP成正比。
更加准确地说,同时因为进入稳定1F1B阶段正数和负数是交替的,是边使用边释放,因此更加准确来说,activation的峰值占用和(PP - rank - 1) × 2 + (VPP - 1) × PP + 1成正比。
fwd graph的Static Input buffer的大小如何做到和PP成正比?
根据这个简单规则简单模拟即可。注意一个model chunk对应多个layer,我们首先获得每个layer的输入应该长什么样子(记为sample_keys), 然后模拟即可:
# Forward: 尝试复用
if consumed_sample_queue.get(sample_keys, []):
# 有可复用的 → 直接复用
reuse_fwd_idx = consumed_sample_queue[sample_keys].pop(0)
sample_args[per_callable_fwd_idx] = sample_args[reuse_fwd_idx] # 复用!
else:
# 没有可复用的 → 生成新的
sample_args[per_callable_fwd_idx] = _get_layer_static_inputs(...)
# Backward: 释放 buffer
for sample_keys, fwd_idx in fwd_sample_queues[chunk][:num_layers]:
consumed_sample_queue[sample_keys].append(fwd_idx) # 标记可复用
这部分其实是TE的make_graphed_callables()的参数_reuse_graph_input_output_buffers对fwd graph的input buffer做的事情。对于bwd graph的input buffer,这部分不需要像fwd那样保存中间结构,直接检查key是否存在然后复用即可。这样就可以做到static Input buffer的大小和PP成正比。当然在Magatron-LM中准备sample args的时候,也要参考TE这样做,不然在运行某个时候这部分大小会和num of microbatches成正比。
中间变量的显存占用如何做到和PP成正比?
Static Input buffer这部分我们可以自己通过模拟pipeline的调度显式控制,使得和PP而不是和num_of_microbatches成正比。
另外一个需要思考的如何让中间变量的显存占用做到和PP成正比。为了实现这个我们需要的使用到memory pool这个功能以及pytorch的make_weak_ref这个功能。
首先,我们需要让所有的graph都共享一个memory pool,
# graph.py:359
mempool = graph_pool_handle() if pool is None else pool
# 所有 graph 捕获时都使用同一个 pool
with _graph_context_wrapper(fwd_graph, pool=mempool): # forward graph
outputs = func(*args, **kwargs)
with _graph_context_wrapper(bwd_graph, pool=mempool): # backward graph
...
因为pytoch的API默认情况下,每个 CUDA Graph 有私有内存池,内存不能互相复用;所有 Graph 共享一个池才有复用的可能性。
┌─────────────────────────────────────────────────────────────┐
│ 默认行为:每个 CUDA Graph 有私有内存池 │
│ ┌─────────┐ ┌─────────┐ ┌─────────┐ │
│ │ Graph 1 │ │ Graph 2 │ │ Graph 3 │ ... │
│ │ Pool 1 │ │ Pool 2 │ │ Pool 3 │ │
│ └─────────┘ └─────────┘ └─────────┘ │
│ 内存不能互相复用 → 总显存 = N * graph_size │
├─────────────────────────────────────────────────────────────┤
│ 共享 Memory Pool:所有 Graph 共享一个池 │
│ ┌─────────────────────────────────────────────────────┐ │
│ │ Shared Memory Pool │ │
│ │ ┌───────┐ ┌───────┐ ┌───────┐ │ │
│ │ │Graph 1│ │Graph 2│ │Graph 3│ ... │ │
│ │ └───────┘ └───────┘ └───────┘ │ │
│ │ 内存可以复用 → 总显存 = max_concurrent * graph_size │ │
│ └─────────────────────────────────────────────────────┘ │
└─────────────────────────────────────────────────────────────┘
在复用了同一个memory pool之后,为了让中间变量释放,我们需要做的是及时触发python的gc自动删除不再需要的中间变量。在这里,我们需要用到Pytorch的make_weak_ref这个功能:
class _WeakRefTensor:
"""只保存 data_ptr,不持有 tensor 引用"""
def __init__(self, data_ptr, dtype, shape):
self._data_ptr = data_ptr # 只记录地址
self.dtype = dtype
self.shape = shape
# 注意:没有保存对原 tensor 的引用!
def make_weak_ref(x):
if isinstance(x, torch.Tensor):
return _WeakRefTensor(x.data_ptr(), x.dtype, x.shape)
# 原 tensor 的引用被丢弃 → PyTorch 认为该内存可以释放
其核心原理是:make_weak_ref 把”持有引用的 tensor 对象”替换成”只记录地址的整数”,原 tensor 失去所有引用后被 Python GC 回收,PyTorch 随之释放其 GPU 内存回 mempool。地址虽然还被记住,但内存已经可以被其他 graph 复用了。
原始状态:
┌─────────────────────────────────────────────┐
│ per_callable_static_outputs[idx] │
│ = (tensor_A, tensor_B, tensor_C) │
│ ↓ ↓ ↓ │
│ ┌────────┐ ┌────────┐ ┌────────┐ │
│ │ Memory │ │ Memory │ │ Memory │ (被占用) │
│ └────────┘ └────────┘ └────────┘ │
└─────────────────────────────────────────────┘
调用 make_weak_ref 后:
┌─────────────────────────────────────────────┐
│ per_callable_static_outputs[idx] │
│ = (_WeakRefTensor, _WeakRefTensor, ...) │
│ ↓ (只保存地址,不持有引用) │
│ ┌────────┐ ┌────────┐ ┌────────┐ │
│ │ Memory │ │ Memory │ │ Memory │ (可复用) │
│ └────────┘ └────────┘ └────────┘ │
│ ↑ PyTorch 认为这些内存不再被使用 │
│ ↑ 可以被后续的 graph capture 复用 │
└─────────────────────────────────────────────┘
因此,当完成一个backward后,我们能将三部分tensor设置为week reference,从而触发python的自动gc:
# graph.py:634-665
# Weak ref the static outputs and static grad inputs that are no longer needed
# in the following steps. These two type of tensors are both in cudagraph
# mempool, so we just deallocate them and let PyTorch's memory allocator
# reuse them elsewhere.
if _reuse_graph_input_output_buffers:
# 1. Weak ref the static outputs of the forward pass of this backward. It's
# no longer needed after the corresponding backward graph is built up.
per_callable_static_outputs[per_callable_bwd_idx] = make_weak_ref(
static_outputs
)
# 2. Weak ref the static grad inputs of the previous backward pass within the
# same chunk.
if previous_per_callable_bwd_idx is not None:
idx = previous_per_callable_bwd_idx
per_callable_static_grad_inputs[idx] = make_weak_ref(
per_callable_static_grad_inputs[idx]
)
previous_per_callable_bwd_idx = per_callable_bwd_idx
# 3. Weak ref the static grad inputs of the previous chunk's last backward
# pass.
# Note: After a chunk's backward pass, we assume Mcore will send the grad
# input to another pipeline parallel rank and that the communication is
# finished before the end of the next chunk's backward pass.
if l_no == 0:
if previous_chunk_last_callable_bwd_idx is not None:
idx = previous_chunk_last_callable_bwd_idx
per_callable_static_grad_inputs[idx] = make_weak_ref(
per_callable_static_grad_inputs[idx]
)
previous_chunk_last_callable_bwd_idx = per_callable_bwd_idx
if ceil(c_id) == c_id:
bwd_idx[m_chunk] += 1
MoE Partial Capture/Replay的工作流程
上面的叙述适用于任意layer,对于dense layer这种可以完全被capture没有太大的问题,对于MoE layer,我们需要根据config配置capture部分操作。为了实现这部分功能,我们基于基于 Python 异常机制实现 graph 边界的动态中断与恢复。首先我们看MoE partial capture/replay的总体工作流程如下:
sequenceDiagram
participant TL as TransformerLayer
participant MoE as MoELayer
participant Router as Router
participant Dispatcher as TokenDispatcher
participant Store as CudagraphTensorStore
Note over TL,Store: CUDA Graph Capture Phase
TL->>MoE: forward(hidden_states)
MoE->>Router: route()
Router-->>MoE: probs, routing_map
MoE->>Dispatcher: preprocess()
Dispatcher-->>MoE: raise PartialCaptureSignal
MoE-->>TL: early_return_outputs
Note over TL,Store: CUDA Graph Replay Phase
TL->>TL: replay graph outputs
TL->>Store: set(hidden_states, probs, routing_map, residual)
TL->>MoE: forward(hidden_states)
MoE->>Store: check tensor_store
Store-->>MoE: skip router/preprocess
MoE->>MoE: dispatch, compute, combine
MoE-->>TL: output
我们分为capture阶段和replay阶段来看如何实现Partial CUDA Graph技术:
- 异常信号机制:在capture阶段使用
MoECudaGraphPartialCaptureSignal实现优雅的提前返回,只捕获静态部分(attention、router、preprocess) - 状态恢复机制:在 replay 阶段,使用
MoECudaGraphTensorStore跳过已计算部分
实现MoE的Partial Capture:装饰器+异常
我们首先自顶向下来看CUDA Graph Capture的过程。
首先是最外层的TransformerLayer, TransformerLayer原本继承自MegatronModule, 现在改成继承自GraphableMegatronModule。GraphableMegatronModule 是支持 CUDA Graph 的 Megatron 模块基类,目前被 TransformerLayer 和 MambaLayer 继承。
在初始化中,我们存储记录每个microbatch的graph callable(就是上文描述过的闭包)以及手动hook(后面会描述):
def __init__(self, config: TransformerConfig, vp_stage: Optional[int] = None):
super().__init__(config)
if config.cuda_graph_impl == "local":
# Local 实现:使用 CudaGraphManager
self.cudagraph_manager = CudaGraphManager(config, vp_stage=vp_stage)
elif config.cuda_graph_impl == "transformer_engine":
# TE 实现:
self.cuda_graphs = [] # 存储每个 microbatch 的 graph callable
self.cuda_graph_manual_hooks = [] # 手动 hooks(参数 all-gather 等)
其__call__函数就会检查是否走 TE CUDA Graph 路径:
# module.py:285-300
def __call__(self, *args, **kwargs):
if self._should_call_local_cudagraph(*args, **kwargs):
# Local 实现
return self.cudagraph_manager(self, args, kwargs)
elif self._should_call_te_cudagraph(*args, **kwargs):
if not self.cuda_graphs:
# Capture 模式
cuda_graph_func = self._te_cuda_graph_capture
else:
# Replay 模式
cuda_graph_func = self._te_cuda_graph_replay
return cuda_graph_func(*args, **kwargs)
# 普通 forward
return super().__call__(*args, **kwargs)
其中,继承GraphableMegatronModule需要自定义_te_cuda_graph_capture分别制定TE CUDA Graph Capture和TE CUDA Graph Replay的时候要怎么做。TransformerLayer的_te_cuda_graph_capture实现如下:
def _te_cuda_graph_capture(self, *args, **kwargs):
# 根据 scope 决定捕获什么
if not self.config.cuda_graph_scope or 'attn' in self.config.cuda_graph_scope:
hidden_states, context = self._forward_attention(*args, **kwargs)
if 'moe_router' in self.config.cuda_graph_scope:
hidden_states = self._forward_mlp(hidden_states) # 会触发异常提前返回
# 收集输出
cuda_graph_outputs = [hidden_states, ...]
return tuple(cuda_graph_outputs)
这里我们就能很清晰看到如何根据配置选择需要执行的范围。如果我们只capture attn部分,那么我们就不会运行_forward_mlp部分,这个时候__call__的调用就没有跑完整个transformer layer了,直接就返回了attn的结果。但是如果我们想capture _forward_mlp的前半部分,本质上这里是可以类似attn一样继续拆分下去的,只要在合适的点退出即可。事实上,我们确实对对MoE 的 forward 被拆分成多个独立的函数:
class MoELayer(BaseMoELayer):
def __init__(self, ...):
# ... 省略其他初始化 ...
@maybe_skip_or_early_return_by_cudagraph("shared_experts_compute")
def shared_experts_compute(self, hidden_states: torch.Tensor):
"""计算 shared expert 的输出(如果配置了的话)"""
shared_expert_output = None
if self.use_shared_expert and not self.shared_expert_overlap:
if self.shared_experts_recompute:
shared_expert_output = tensor_parallel.checkpoint(
self.shared_experts, False, hidden_states
)
else:
shared_expert_output = self.shared_experts(hidden_states)
return shared_expert_output
@maybe_skip_or_early_return_by_cudagraph("route")
def route(self, hidden_states: torch.Tensor):
"""使用 router 计算 token 到 expert 的映射"""
probs, routing_map = self.router(hidden_states)
return probs, routing_map
@maybe_skip_or_early_return_by_cudagraph("preprocess")
def preprocess(self, hidden_states, probs, routing_map):
"""预处理:计算通信 splits,重排 token"""
residual = hidden_states
hidden_states, probs = self.token_dispatcher.dispatch_preprocess(
hidden_states, routing_map, probs
)
return hidden_states, probs, residual
# 以下方法不使用装饰器,因为它们已经在 partial capture 范围之外
def dispatch(self, hidden_states, probs):
"""执行 AlltoAll 通信"""
return self.token_dispatcher.token_dispatch(hidden_states, probs)
def routed_experts_compute(self, hidden_states, probs, residual):
"""在 dispatched tokens 上计算 expert 输出"""
dispatched_input, tokens_per_expert, permuted_probs = (
self.token_dispatcher.dispatch_postprocess(hidden_states, probs)
)
expert_output, mlp_bias = self.experts(dispatched_input, tokens_per_expert)
output = self.token_dispatcher.combine_preprocess(expert_output)
return output, mlp_bias
def combine(self, output, shared_expert_output):
"""合并 routed expert 和 shared expert 的输出"""
output = self.token_dispatcher.token_combine(output)
if shared_expert_output is not None:
output = output + shared_expert_output
return output
理论上我们可以类似刚才TransformerLayer的_te_cuda_graph_capture函数,读取config判断什么时候退出即可,但是这样对代码侵入改动太大了。面对这样的问题,其实就是在某函数上“外包”一层,直接使用装饰器模式:
def maybe_skip_or_early_return_by_cudagraph(step_condition):
"""
step_condition 可以是:
- "shared_experts_compute": 跳过 shared expert 计算
- "route": 跳过 router 计算(或在 capture 时提前返回)
- "preprocess": 跳过 preprocess 计算(或在 capture 时提前返回)
"""
def maybe_raise_signal(moe_layer, **kwargs):
"""Capture 阶段:检查是否需要提前返回"""
if (
moe_layer.config.cuda_graph_impl == "transformer_engine"
and moe_layer.training
and is_graph_capturing()
):
if step_condition == "route" and 'moe_router' in scope and 'moe_preprocess' not in scope:
raise MoECudaGraphPartialCaptureSignal(moe_layer, "route", **kwargs)
elif step_condition == "preprocess" and 'moe_preprocess' in scope:
raise MoECudaGraphPartialCaptureSignal(moe_layer, "preprocess", **kwargs)
def decorator(func):
def wrapped_func(moe_layer, *args, **kwargs):
# 非 cudagraph 路径:直接执行
if not is_graph_capturing() and moe_layer.cudagraph_tensor_store.is_empty():
return func(moe_layer, *args, **kwargs)
# Capture 和 Replay 路径
if step_condition == "route":
if moe_layer.cudagraph_tensor_store.probs is None:
# Capture 阶段
# 执行 router 计算
probs, routing_map = func(moe_layer, *args, **kwargs)
# 可能抛出异常提前返回
maybe_raise_signal(moe_layer, probs=probs, routing_map=routing_map)
else:
# Replay 阶段:从 store 读取,跳过计算
probs = moe_layer.cudagraph_tensor_store.probs
routing_map = moe_layer.cudagraph_tensor_store.routing_map
return probs, routing_map
# 类似处理 "preprocess" 和 "shared_experts_compute"
return wrapped_func
return decorator
注意这个装饰器其在Capture和Replay都有用到,用is_graph_capturing()来区分 Capture vs Replay。我们可以先关注Capture部分,我们发现,这里就是根据config来判断什么时候抛出异常,同时这个异常还会带有信息来帮助我们得到中间结果。下面是MoECudaGraphPartialCaptureSignal这个异常的定义:
# megatron/core/transformer/moe/moe_utils.py
class MoECudaGraphPartialCaptureSignal(Exception):
"""
用于在 CUDA graph capture 阶段从 MoE 层提前返回。
当我们只想部分捕获 MoE 层时,会抛出这个异常。
"""
def __init__(self, moe_layer, return_step: str, **kwargs):
self.moe_layer = moe_layer
self.return_step = return_step # "route" 或 "preprocess"
self.kwargs = kwargs # 保存中间结果
def get_early_return_outputs(self, hidden_states, shared_expert_output):
"""收集作为 CUDA graph 输出的张量"""
if self.return_step == "route":
# moe_router scope: 返回 3 个张量
outputs = [hidden_states, self.kwargs['probs'], self.kwargs['routing_map']]
elif self.return_step == "preprocess":
# moe_preprocess scope: 返回 3 个张量 + dispatcher 属性
outputs = [self.kwargs['hidden_states'], self.kwargs['probs'], self.kwargs['residual']]
# 遍历 dispatcher 的 cudagraph_attrs,收集所有 tensor 属性
for attr_name in self.moe_layer.token_dispatcher.cudagraph_attrs:
# 支持层级属性,如 'shared_experts.gate_score'
hier_attr_name = attr_name.split('.')
attr = self.moe_layer.token_dispatcher
for name in hier_attr_name:
attr = getattr(attr, name, None)
if attr is None:
break
if isinstance(attr, torch.Tensor):
outputs.append(attr)
# 如果有 shared expert output,也加入
if shared_expert_output is not None:
outputs.append(shared_expert_output)
return outputs
因此对于MoE层, 我们只需要在forward里面捕捉一下异常即可:
def forward(self, hidden_states: torch.Tensor):
"""MoE forward: route → dispatch → compute → combine"""
# ECHO 模式使用不同的 forward
if self.config.moe_enable_echo:
return self.echo_forward(hidden_states)
def custom_forward(hidden_states):
try:
# 这三个方法都被装饰器包装
# Capture 阶段:可能在任意位置抛出异常
# Replay 阶段:检查 tensor_store,可能跳过计算
shared_expert_output = self.shared_experts_compute(hidden_states)
probs, routing_map = self.route(hidden_states)
hidden_states, probs, residual = self.preprocess(hidden_states, probs, routing_map)
except MoECudaGraphPartialCaptureSignal as e:
# 捕获 Partial Capture 信号
# 这意味着我们只需要返回中间结果作为 CUDA graph 输出
return e.get_early_return_outputs(hidden_states, shared_expert_output)
# 如果没有抛出异常,继续执行剩余部分
# 这部分在 Capture 阶段不会执行(因为异常提前返回)
# 在 Replay 阶段会执行(从 store 恢复后继续)
dispatched_input, probs = self.dispatch(hidden_states, probs)
output, mlp_bias = self.routed_experts_compute(dispatched_input, probs, residual)
output = self.combine(output, shared_expert_output)
return output, mlp_bias
# 支持激活重计算
if self.moe_layer_recompute:
outputs = tensor_parallel.checkpoint(custom_forward, False, hidden_states)
else:
outputs = custom_forward(hidden_states)
return outputs
另外需要注意一点是,Preprocess 会被 Graph Capture,但是其计算结果 tokens_per_expert, routing_map 等这些被存储为 dispatcher 的属性,因此这些也需要作为CUDA Graph的输出,我们用cudagraph_attrs来标记。
# dispatcher.preprocess() 执行后会设置这些属性:
dispatcher.tokens_per_expert = [100, 50, 200, ...] # 动态计算的!
dispatcher.routing_map = torch.Tensor(...)
dispatcher.input_splits = [...]
# ... 等等
同时我们需要调整DtoH 同步点,需要延迟到 before_ep_alltoall完成DtoH的同步。
class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
def __init__(self, ...):
# ... 省略其他初始化 ...
# 调整 DtoH 同步点:如果使用 moe_preprocess,延迟到 graph 边界之后
if (
config.cuda_graph_impl == "transformer_engine"
and 'moe_preprocess' in config.cuda_graph_scope
):
self.cuda_dtoh_point = "before_ep_alltoall" # 延迟
else:
self.cuda_dtoh_point = "before_permutation_1" # 默认
# 需要被 CUDA graph 捕获的属性列表
self.cudagraph_attrs = [
'tokens_per_expert', # 每个 expert 的 token 数量
'input_splits', # 发送到各 EP rank 的 token 数
'output_splits', # 从各 EP rank 接收的 token 数
'output_splits_tp', # 从各 TP rank 接收的 token 数
'num_out_tokens', # 输出 token 总数
'num_global_tokens_per_local_expert', # 全局 token 分布
'reversed_local_input_permutation_mapping', # 反向 permutation 映射
'routing_map', # token-expert 路由映射
]
self.valid_cudagraph_attrs = None # 实际有效的属性(Cpature时确定)
def set_shared_experts(self, shared_experts):
"""设置 shared expert 时,添加额外的属性"""
super().set_shared_experts(shared_experts)
if shared_experts.use_shared_expert_gate:
self.cudagraph_attrs.append('shared_experts.gate_score')
self.cudagraph_attrs.append('shared_experts.cached_fc1_input')
另外一个区别是是cudagraph_attrs和valid_cudagraph_attrs。cudagraph_attrs 是静态定义的”想要捕获的属性”候选列表,valid_cudagraph_attrs 是运行时验证后”实际存在且是 tensor”的有效列表。这种设计允许不同配置(如是否使用 shared expert gate)复用同一份代码,同时确保 graph 输出结构的一致性。
cudagraph_attrs (候选,静态定义):
┌─────────────────────────────────────────────────────┐
│ tokens_per_expert │
│ input_splits │
│ output_splits │
│ routing_map │
│ shared_experts.gate_score ← 可能不存在 │
│ shared_experts.cached_fc1_input ← 可能不存在 │
└─────────────────────────────────────────────────────┘
│
↓ 运行时过滤
valid_cudagraph_attrs (有效,运行时确定):
┌─────────────────────────────────────────────────────┐
│ tokens_per_expert ✓ 存在且是 tensor │
│ input_splits ✓ 存在且是 tensor │
│ output_splits ✓ 存在且是 tensor │
│ routing_map ✓ 存在且是 tensor │
│ (shared_experts.gate_score 被过滤掉 - 不存在) │
│ (shared_experts.cached_fc1_input 被过滤掉) │
└─────────────────────────────────────────────────────┘
实现MoE的Partial Replay:装饰器+状态恢复
至此,我们已经完全理解了CUDA Graph Capture的流程,做完这套流程后,TransformerLayer下的cuda_graphs属性已经设置好了每个microbatch的graph callable闭包了,我们看一下Replay情况下,TransformerLayer的_te_cuda_graph_replay在做什么事情:
def _te_cuda_graph_replay(self, *args, **kwargs):
# 1. 调用父类 replay
cuda_graph_output = super()._te_cuda_graph_replay(*args, **kwargs)
if 'moe_router' in scope:
# 2. 解析输出
hidden_states, probs, routing_map_or_residual = cuda_graph_output[:3]
# 3. 恢复 dispatcher 属性
for i, attr_name in enumerate(valid_cudagraph_attrs):
setattr(dispatcher, attr_name, cuda_graph_output[3+i])
# 4. 设置 tensor store
self.mlp.cudagraph_tensor_store.set(
hidden_states=hidden_states,
probs=probs,
residual=residual,
)
# 5. 继续 MoE forward
mlp_output = self.mlp(hidden_states)
# 6. 清理 store
self.mlp.cudagraph_tensor_store.clear()
其中父类GraphableMegatronModule就是在根据当前的microbatch index来设置hook(后面会讲到)和调用闭包。
def _te_cuda_graph_replay(self, *args, **kwargs):
"""
CUDA graph replay for this layer and microbatch `self.current_microbatch` using TE
interface. TransformerEngine versions>=1.10 allow keyword arguments with CUDA graph.
However, CUDA graph accepts only Tensor inputs.
Hence, check if the arguments are all tensors.
"""
# ...
cg_index = getattr(self, 'current_microbatch', 0) % len(self.cuda_graphs)
cudagraph_args, cudagraph_kwargs = self._get_te_cuda_graph_replay_args(*args, **kwargs)
for hook, hook_args in self.cuda_graph_manual_hooks:
hook(*hook_args)
return self.cuda_graphs[cg_index](*cudagraph_args, **cudagraph_kwargs)
我们可以看到TransformerLayer的会将闭包的结果保存到一个特殊的类MoECudaGraphTensorStore来统一保存CUDA Graph的结果,接着会调用原先应该执行的mlp函数,接下来的问题是,mlp函数的部分计算已经在CUDA Graph里面算过了,怎么跳过这些部分呢?
答案还是通过装饰器的方式。只不过现在不是抛出异常而是直接从MoECudaGraphTensorStore读取即可:
def maybe_skip_or_early_return_by_cudagraph(step_condition):
"""
step_condition 可以是:
- "shared_experts_compute": 跳过 shared expert 计算
- "route": 跳过 router 计算(或在 capture 时提前返回)
- "preprocess": 跳过 preprocess 计算(或在 capture 时提前返回)
"""
def maybe_raise_signal(moe_layer, **kwargs):
"""Capture 阶段:检查是否需要提前返回"""
if (
moe_layer.config.cuda_graph_impl == "transformer_engine"
and moe_layer.training
and is_graph_capturing()
):
if step_condition == "route" and 'moe_router' in scope and 'moe_preprocess' not in scope:
raise MoECudaGraphPartialCaptureSignal(moe_layer, "route", **kwargs)
elif step_condition == "preprocess" and 'moe_preprocess' in scope:
raise MoECudaGraphPartialCaptureSignal(moe_layer, "preprocess", **kwargs)
def decorator(func):
def wrapped_func(moe_layer, *args, **kwargs):
# 非 cudagraph 路径:直接执行
if not is_graph_capturing() and moe_layer.cudagraph_tensor_store.is_empty():
return func(moe_layer, *args, **kwargs)
# Capture 和 Replay 路径
if step_condition == "route":
if moe_layer.cudagraph_tensor_store.probs is None:
# Capture 阶段
# 执行 router 计算
probs, routing_map = func(moe_layer, *args, **kwargs)
# 可能抛出异常提前返回
maybe_raise_signal(moe_layer, probs=probs, routing_map=routing_map)
else:
# Replay 阶段:从 store 读取,跳过计算
probs = moe_layer.cudagraph_tensor_store.probs
routing_map = moe_layer.cudagraph_tensor_store.routing_map
return probs, routing_map
# 类似处理 "preprocess" 和 "shared_experts_compute"
return wrapped_func
return decorator
至此,我们完成理解了MoE Partial Capture/Replay的工作流程。
为什么要手动设置hook?
最后一个需要注意一点是,Megatron-LM一般还会有一个forward pre-hooks功能,用于在模块前向传播之前执行某些操作,比如开启Distributed Optimizer (类似 ZeRO-1/2)时会更新本地1/N的参数,其他参数需要all-gather操作,这个是放在forward pre-hooks里面的。
正常来说, 触发 forward_pre_hooks是放在module.__call__(input)里面的,现在我们改动了module.__call__(input),执行我们的_te_cuda_graph_capture和_te_cuda_graph_replay,因此我们也要在_te_cuda_graph_replay里面加入这部分功能:
# 1. 收集需要手动触发的 hooks
def setup_manual_hooks(self, make_hook_func):
self.cuda_graph_manual_hooks = []
# 找到所有包含参数的子模块
for submodule in self._get_submodules_under_cudagraphs():
for module in submodule.modules():
if next(module.parameters(recurse=False), None) is not None:
# 为每个有参数的模块创建 hook
self.cuda_graph_manual_hooks.append(
(make_hook_func(), (module,)) # (hook函数, 参数)
)
# 2. 在 graph replay 前手动触发
def _te_cuda_graph_replay(self, *args, **kwargs):
# 手动触发所有 hooks
for hook, hook_args in self.cuda_graph_manual_hooks:
hook(*hook_args) # 等待参数 All-Gather 完成
# 现在参数已经准备好,可以安全地 replay graph
return self.cuda_graphs[cg_index](*cudagraph_args, **cudagraph_kwargs)
其收集的时机是:
# training.py 中的调用顺序
cuda_graph_helper = TECudaGraphHelper(...)
cuda_graph_helper.create_cudagraphs() # 1. 创建 CUDA Graphs
cuda_graph_helper.cuda_graph_set_manual_hooks() # 2. 设置 Manual Hooks
# 然后开始训练循环
如何继承到训练迭代中?
最后我们来看一下,我们这套方法如何融入在正常的训练流程里面。 首先会在训练最开始的阶段创建TECudaGraphHelper:
if args.cuda_graph_impl == "transformer_engine":
cuda_graph_helper = TECudaGraphHelper(
model=model,
config=config,
seq_length=seq_length,
micro_batch_size=micro_batch_size,
optimizers=optimizers,
)
在TECudaGraphHelper的初始化阶段会收集可被partial cuda graph的layer:
class TECudaGraphHelper:
def __init__(self, model, config, seq_length, micro_batch_size, ...):
# 收集 graphable layers
for layer in decoder.layers:
if _layer_is_graphable(layer, config):
self.flattened_callables.append(layer)
在第一个iteration结束时,就可以设置hooks了:
# 在第一个 iteration 结束时
if args.cuda_graph_impl == "transformer_engine":
cuda_graph_helper.cuda_graph_set_manual_hooks()
在运行config里面指定的warmup iteartion后,就可以创建cuda graph并分配到各层了:
def create_cudagraphs(self):
start_time = self._start_capturing()
sample_args, kwargs = self._get_cuda_graph_input_data()
graphs = make_graphed_callables(
tuple(self.flattened_callables),
sample_args,
**kwargs
)
# 分配到各层
for layer in layers:
layer.cuda_graphs = [graphs[...] for batch in range(num_microbatches)]
self._finish_capturing(start_time)
至此,我们完成理解了Megatron-LM里面的partial CUDA Graph的工作流程。
显存占用与 CUDA Graph 数量分析
结合上文分析,下面以表格形式梳理不同配置下 CUDA Graph 的数量以及fwd graph static input buffer以及中间变量的显存大小。其中,CUDA Graph是 CUDA 操作的录制,录制部分占用的显存很小,在实现的时候每个 microbatch有独立的graph;但是fwd graph static input buffer以及中间变量的显存是需要根据PP实现microbatch之间共享的,使得和PP而不是num of microbatches成正比。
| 运行配置 | CUDA Graph 数量 | fwd graph static input buffer显存大小 | 中间变量的显存大小 |
|---|---|---|---|
| 不启用 PP | 2 × num_layers | 与num_layers成正比 | 与num_layers成正比 |
| 启用 PP | 2 × num_layers × num_microbatches | 与num_layers × PP成正比 | 与num_layers × PP成正比 |
这里我们思考一个问题,相比于不使用CUDA Graph,使用partial CUDA Graph实际增加的显存占用在哪里? 通过上面的分析,我们知道相比于不使用CUDA Graph,Partial CUDA Graph 额外增加的显存主要来自:(1) PP 份 Static Input Buffers,(2) PP 份 MoE 中断点张量 (MoECudaGraphTensorStore)。
| 开销类型 | 显存大小估算 | 是否可优化 |
|---|---|---|
| CUDA Graph 对象 | < 100 MB | 较小,可忽略 |
| Static Input Buffers | PP × layer × input_size ~128 MB - 1 GB |
通过 1F1B 复用已优化到 PP 份 |
| MoE 中断点张量 | PP × MoE_layers × store_size ~1-3 GB |
Partial Graph 特有,不可避免 |
| Mempool 碎片 | - | - |
注意到很多地方都和num_layers相关,这是因为我们采用partial cuda graph逐layer进行capture,因此我们需要设置和num_layers成正比的同步点,这提示我们加入我们后面可以尽可能capture更完整的模型部分,我们额外需要的显存也会更小。
使用CUDA Graph时常见的易错点
另外,补充一些使用CUDA Graph常见的易错点。
- CUDA Graph的本质是捕捉某个CUDA stream上若干kernel launches和memory copy操作,它无法捕捉到任何CPU上的操作或逻辑。
比如下述代码片段:
s += a.sum() # s和a为GPU上的tensor
cnt += 1 # cnt是普通的Python计数器
在capture阶段,这两行代码都会被执行,但在replay阶段,s += a.sum()这样的GPU操作会被执行,而cnt += 1这类纯CPU的操作则不会再执行。换句话说,CUDA Graph的replay只会按图执行可捕获的GPU操作,至于代码表面上的其他Python逻辑(如计数器递增),在replay过程中是无效的。
- 诸如if分支等控制流结构,会按capture阶段实际走过的路径,把当时分支里的GPU操作捕捉进图,replay时不会根据新的条件判断重新分支。为了避免动态分支发生在CPU上,可以使用 torch.where 替代 if-else。核心思想是将分支判断搬到GPU上。
