深入理解 Megatron-LM 中的 Full CUDA Graph:MoE 模型训练加速的关键技术

 

从 Partial CUDA Graph 到 Full CUDA Graph

在上一篇深入理解 Megatron-LM 中的 Partial CUDA Graph:MoE 模型训练加速的关键技术中,我们分析了 MoE 训练流程中影响 CUDA Graph 兼容性的主要瓶颈在于 MoE 部分的动态行为 —— 每个 expert 在每次 iteration 中接收的 token 数量都不同,导致 kernel 参数、缓冲区大小和通信模式无法在捕获时固定。 本文将深入探讨如何逐一攻克这些障碍,使 MoE 部分能够无缝融入 CUDA Graph,并最终实现覆盖前向与反向传播的完整 iteration 的 Full CUDA Graph 捕获。与 Partial CUDA Graph(仅捕获 Attention、MLP 等确定性模块)不同,Full CUDA Graph 会将整个训练 iteration(包括 MoE 层)作为一个完整的 GPU 执行图进行捕获和重放,从而最大程度地消除 CPU 开销。 为了解决 MoE 与 CUDA Graph 的兼容性,首先需要攻克以下两个核心技术难题:

  1. 消除 TE grouped GEMM 中的 CPU-GPU 同步:当前 TE 在启动 grouped GEMM 时依赖 torch.split,该操作要求将 GPU 上的 tokens_per_expert 转移至 CPU,这会触发隐式同步,破坏 CUDA Graph 的捕获流程。
  2. 重构 HybridEP 策略以兼容 CUDA Graph:HybridEP 的 token 分发逻辑仍包含动态内存分配和条件分支,需要重构为静态缓冲区 + 固定执行路径模式。这一点其实在 HybridEP 2025 年底的几个 commit 中都有所涉及,我们可以认为现在的 HybridEP 版本已经兼容 CUDA Graph 了。

实际上,完成上述两项优化后,在 MoE 强制负载均衡(force load balancing)设定下,已经可以实现真正意义上的 Full CUDA Graph。然而,在实际训练中并不会采用 force load balancing;而由于 CUDA Graph 需要静态特性,我们也无法在每个 EP Rank 上为极端不均衡的情况预留过大的缓冲区。因此,我们还需要结合 MoE 的负载均衡算法,并实现高效的 Expert 权重分发与梯度聚合机制,同时保持与 CUDA Graph 的兼容,以便在真实场景下安全、高效地使用 Full CUDA Graph。

在接下来的内容中,我们会调整行文顺序:先讨论 CUDA Graph 兼容的 MoE 负载均衡算法,再讨论如何实现高效的 Expert 权重分发与梯度聚合机制,接着介绍 Megatron Full CUDA Graph 的整体实现机制,最后介绍如何解决 TE grouped GEMM 的 CPU-GPU 同步问题。

MoE 负载规划算法

为了实现 MoE 的负载均衡,我们引入了 redundant expert slots(冗余专家槽位)的概念。每个 EP rank 需要预先分配固定数量的 expert 计算槽位。当某些 EP rank 的负载过重(即分配到的 token 数远超平均值)时,可以将部分 token offload 到负载较轻的 EP rank 的空闲槽位上处理。这就需要一套机制来决定:哪些 expert 的权重需要被复制到哪些 redundant expert slots 上,这个机制我们称之为 expert dispatch。我们可以在后文中看到,实际上 expert dispatch 和 token dispatch 非常像,甚至都可以复用 HybridEP 这套通信库。

下图展示了前向过程中整个 expert dispatch 和 token dispatch 的工作流程:

sequenceDiagram
    participant Input as Hidden States
    participant Router
    participant Planner as Offloading Planner
    participant ExpertDisp as Expert Dispatcher
    participant TokenDisp as Token Dispatcher
    participant Experts
    participant Output

    Input->>Router: hidden_states
    Router->>Planner: routing_map, probs
    Planner->>Planner: gen_offloading_plan()
    Planner->>ExpertDisp: expert_offloading_map
    ExpertDisp->>Experts: dispatch weights to echo experts
    Input->>TokenDisp: hidden_states, rerouted_probs
    TokenDisp->>Experts: dispatched tokens
    Experts->>TokenDisp: expert outputs
    TokenDisp->>Output: combined output

假设现在我们已经有了一个 router 生成的路由方案,即我们知道每个 EP rank 上每个 home expert 会收到多少 token。我们现在要做 MoE 的负载规划,即决定每个 EP rank 上的每个 home expert 要将其权重分发到哪个 EP rank 的 redundant expert slots 上,并同时决定原本路由到该 home expert 的 token 中有多少要改为路由到这个 redundant expert slot 上。 这里有三个层级的考虑:

  1. 第一个层级是要考虑每个 EP Rank 上的空闲容量(spare capacity)。空闲容量表示一个 EP rank 还能额外处理多少 token,我们肯定不能将太多的 token 重路由到一个 EP rank 上,我们希望转移后每个 EP Rank 上的计算量更加均衡。
  2. 第二个层级是 home expert 上的,我们想要计算每个 home expert 的溢出量(spillover)。溢出量表示一个 home expert 有多少 token 需要被转移到其他 EP rank 上处理,在这个层级我们希望尽可能减少 expert 权重分发的通信量。比如说,假设某个 home expert 只被安排了转移 1 个 token 的计算到其他 redundant expert slot 上,但是我们却需要将整个 expert 权重分发到这个 redundant expert slot 上,显然这是非常不合算的。
  3. 第三个层级是专家权重的分配。有了每个 EP rank 的 spare capacity 和每个 home expert 的 spillover 之后,每个 home expert 的权重应该被分配到哪个 EP rank 的 spare slot 上。

我们一步步来解决这些问题。

基于前缀和的分摊算法

我们在这一节中考虑第一层级和第二层级的事情。首先考虑每个 EP rank 上的 spare capacity 如何计算:

def gen_intermediate(count_tokens_per_expert_from_ep_rank, ...):
    # 步骤1: 计算每个EP rank的token总数和平均值
    count_tokens_per_ep_rank = count_tokens_per_expert.view(num_ep_ranks, -1).sum(dim=1)
    avg_tokens_per_ep_rank = count_tokens_per_ep_rank.sum() // num_ep_ranks
    
    # 步骤2: 计算spare容量 = max(0, avg - current)
    # 负载低于平均的EP rank有空闲容量接收tokens
    deviation = count_tokens_per_ep_rank - avg_tokens_per_ep_rank
    capacity_spare_per_ep_rank = torch.relu(-deviation)
    
    # 步骤3: 计算spillover(溢出量)
    # 关键思路:对每个EP rank内的专家按token数排序,
    # 累积求和后超过平均值的部分就是spillover
    count_tokens_sorted, indices_sorted = count_tokens_per_expert.view(num_ep_ranks, -1).sort(dim=1)
    spillover_cumsum = (count_tokens_sorted.cumsum(dim=1) - avg_tokens_per_ep_rank).clamp(min=0)
    # 从cumsum转回每个专家的spillover
    count_spillover_sorted = torch.cat([spillover_cumsum[:, :1], 
                                         torch.diff(spillover_cumsum, dim=1)], dim=1)

空闲容量的计算非常简单:超过均值的部分要转移到其他 EP rank,没有达到均值的 EP rank 就有空闲容量来承载转移。下面通过一个直观的例子来理解:假设有 4 个 EP ranks,每个 EP rank 上有 4 个 home experts(共 16 个 experts,编号 0-15),token 总数为 1000,平均每个 EP rank 应处理 250 个 token:

                avg = 250
                    │
EP0: 500 ■■■■■■■■■■│■■■■■■■■■■  超额 250 → 转移到其他EP rank
EP1: 350 ■■■■■■■   │■■■           超额 100 → 转移到其他EP rank
EP2: 100 ■■        │              空闲 150 → spare capacity
EP3:  50 ■         │              空闲 200 → spare capacity
                    │

现在我们知道了每个 EP rank 整体需要 offload 多少 token。接下来考虑第二个问题:如何设计一种分摊算法,将 EP rank 层级的溢出量分摊到各个 expert 上,使得各 expert 的 spillover 之和恰好等于该 EP rank 的总溢出量? 注意在这个问题里面,我们希望将 expert 权重分发的代价尽可能小,尽可能不要出现分发了一次 expert 权重但是只为了转移一个 token 的情况。

为此,我们算法设计的思想是:负载较轻的专家优先保留自己的 tokens,而负载较重的专家承担更多的 spillover。分摊算法如下:以上面的 EP0 为例,假设其有 4 个 experts(编号 0-3),tokens 分布为 [50, 100, 150, 200](总和 500),avg_tokens_per_ep_rank 为 250:

# Step 1: 排序(从小到大)
sorted_tokens = [50, 100, 150, 200]

# Step 2: 累积和
cumsum = [50, 150, 300, 500]

# Step 3: 减去平均值并 clamp
spillover_cumsum = ([50, 150, 300, 500] - 250).clamp(min=0)
                 = [-200, -100, 50, 250].clamp(min=0)
                 = [0, 0, 50, 250]

# Step 4: 差分得到每个专家的 spillover
spillover = [0, 0-0, 50-0, 250-50]
          = [0, 0, 50, 200]

# 总 spillover: 250 tokens  ← 正好等于超额部分!

我们看到,这 4 个专家的 token 总和为 500,总 spillover 为 0+0+50+200=250,恰好等于该 EP rank 的超额量(500-250=250)。 这里的排序步骤很重要,它确保负载较轻的专家优先保留自己的 tokens,而负载较重的专家承担更多的 spillover,这更加符合我们设计算法的直觉。

# 未排序 [200, 50, 150, 100]:
cumsum = [200, 250, 400, 500]
spillover_cumsum = [0, 0, 150, 250]  # ← 从专家0开始就接近平均值
spillover = [0, 0, 150, 100]

# 排序后 [50, 100, 150, 200]:
cumsum = [50, 150, 300, 500]
spillover_cumsum = [0, 0, 50, 250]   # ← 更晚才超过平均值
spillover = [0, 0, 50, 200]

类似地,对 EP1(experts 4-7,tokens 分布 [50, 80, 100, 120],总和 350)应用同样的算法:cumsum 为 [50, 130, 230, 350],减去 250 后 clamp 得到 [0, 0, 0, 100],差分得 spillover = [0, 0, 0, 100]。也就是说,EP1 超额的 100 个 token 全部由 expert 7(120 tokens,最大)承担。

至此,我们得到所有 16 个 experts 的 spillover 与每个 EP rank 的 spare capacity,可以直接送入下一节的匹配算法:

count_spillover_per_home_expert = [0, 0, 50, 200, 0, 0, 0, 100, 0, 0, 0, 0, 0, 0, 0, 0]
# 专家编号:                        0  1  2   3   4  5  6   7   8  9 10 11 12 13 14 15
# (experts 0-3 在 EP0,4-7 在 EP1,8-11 在 EP2,12-15 在 EP3)

capacity_spare_per_ep_rank = [0, 0, 150, 200]
# EP rank:                    0  1   2    3

基于区间重叠的匹配算法

有了每个 EP rank 的 spare capacity 和每个 home expert 的 spillover 之后,下一步是决定:每个 home expert 的溢出 token 应该被分配到哪个 EP rank 的 redundant expert slot 上。

我们首先考虑一个简单一点的问题:home expert 到 EP rank 的映射问题。我们通过一个贪心分配算法 one_shot_greedy_assignment 来求解。它的核心思想是把每个 home expert 的 spillover 和每个 EP rank 的 spare capacity 都视为一维数轴上的连续区间,通过计算区间重叠来确定分配方案。

我们沿用上一节得到的输入:

# 每个专家的 spillover(来自上一节的计算)
count_spillover_per_home_expert = [0, 0, 50, 200, 0, 0, 0, 100, 0, 0, 0, 0, 0, 0, 0, 0]
# 专家编号:                        0  1  2   3   4  5  6   7   8  9 10 11 12 13 14 15

# 每个 EP rank 的 spare capacity
capacity_spare_per_ep_rank = [0, 0, 150, 200]
# EP rank:                    0  1   2    3

我们分别对 spillover 和 spare capacity 降序排序,这里降序排序的目的是”大的优先匹配”,将最大的 spillover 和最大的 capacity 排在前面,确保大块的溢出量优先被大容量的 EP rank 吸收,避免碎片化分配。

# 对 spillover 降序排序
count_spillover_sorted = [200, 100, 50, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
indices_spillover_sort = [3,   7,   2, 0, 1, 4, 5, 6, 8, 9,10,11,12,13,14,15]  # 原始专家编号

# 对 spare capacity 降序排序
capacity_spare_sorted = [200, 150, 0, 0]
indices_spare_sort    = [3,   2,   0, 1]  # 原始 EP rank 编号

接下来,我们直接进行区间重叠即可:

# 输入:
chunks  = spillover_sorted = [200, 100, 50, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]  # 待分配的量
buckets = capacity_sorted  = [200, 150, 0, 0]                                       # 可接收的容量

# 算法运行过程:
chunks 累积和:  [200, 300, 350, 350, 350, 350, ...]
buckets 累积和: [200, 350, 350, 350]

chunks 区间:
  chunk 0: [0,   200)    spillover=200
  chunk 1: [200, 300)    spillover=100
  chunk 2: [300, 350)    spillover=50
  chunk 3-15:  (spillover=0)

buckets 区间:
  bucket 0: [0,   200)   capacity=200
  bucket 1: [200, 350)   capacity=150
  bucket 2-3:  (capacity=0)

可视化图如下:

                0                 200       300      350
                                                   
chunks:         │◄──────chunk0────►│◄chunk1─►│◄chunk2►│
                       200           100      50   
                                                   
buckets:        │◄──────bucket0───►│◄────bucket1─────►│
                       200               150        
                                                   

我们可以进一步得到输出的 assignment 矩阵,这个矩阵的大小是 [num_total_experts, num_ep_ranks],表示所有 EP rank 上的 home experts 应该 offload 多少 token 到其他 EP rank。注意此时 assignment 矩阵考虑了所有 EP rank 上的 home experts。

                  bucket0    bucket1    bucket2    bucket3
                  [0,200)    [200,350)  [350,350)  [350,350)
chunk0 [0,200)      200          0          0          0
chunk1 [200,300)      0        100          0          0
chunk2 [300,350)      0         50          0          0   (bucket1 结束于 350)
chunk3-15             0          0          0          0

assignment_sorted =
[[200,   0, 0, 0],
 [  0, 100, 0, 0],
 [  0,  50, 0, 0],
 [  0,   0, 0, 0],
 ...                        # 余下的行全部为 0
 [  0,   0, 0, 0]]

indices_spillover_sortindices_spare_sort 反映射回原始编号,可以得到:

  • sorted chunk 0(= 原始 expert 3)→ sorted bucket 0(= EP3):200 tokens
  • sorted chunk 1(= 原始 expert 7)→ sorted bucket 1(= EP2):100 tokens
  • sorted chunk 2(= 原始 expert 2)→ sorted bucket 1(= EP2):50 tokens

但是上面的算法还有一个小问题:它只告诉了我们所有 EP rank 上的每个 home expert 应该 offload 多少 token 到其他 EP rank 上,却还没有考虑每个 EP rank 上 redundant expert slots 的数量限制——每个 EP rank 上能容纳的 expert 数量是有限的。

关于这个问题,我们直接用贪心选择 top-k 即可。假设每个 EP rank 有 2 个 redundant expert slots,对于每个 EP rank,我们直接在上述 assignment 矩阵中每一列选择 top 2 即可。

最终方案是:

  • EP2 上的 2 个 redundant expert slots 分别接受 expert 7 的 100 个 tokens 和 expert 2 的 50 个 tokens;
  • EP3 上的 1 个 redundant expert slot 接受 expert 3 的 200 个 tokens(剩余的 1 个 slot 闲置)。

贪心分配算法

回到上一节的例子,我们已经得到 expert 7 需要 offload 100 个 token 到 EP2 上的一个 redundant expert slot。但还有一个问题没有解决:在 Megatron-LM 当前的实现中,EP 和 DP 的并行度数值是相等的——非 MoE 部分走正常的 DP,而 MoE 部分走 EP。因为并行度数值一样,我们不妨都用 EP rank 来说明。我们知道,原本分配给 expert 7 的 tokens 来源于前一个 attention 阶段的多个 EP rank。不妨假设位于 EP1 上的 expert 7 总共收到了 120 个 token,分别是来自 EP0 的 50 个 tokens、EP1 的 30 个 tokens、EP2 的 25 个 tokens,以及 EP3 的 15 个 tokens。

4  EP ranks专家 7  tokens 分布:
  - 来自 EP0: 50 tokens
  - 来自 EP1: 30 tokens (home rank)
  - 来自 EP2: 25 tokens
  - 来自 EP3: 15 tokens
  - 总计: 120 tokens

我们在前一阶段的基于区间重叠的匹配算法中已知,我们需要将 EP1 上 expert 7 的 100 个 tokens 分配到 EP2 的其中一个 redundant expert slot 上,那么应该从这 4 个 EP rank 各抽多少呢?

                          专家 7 (home: EP1)
                            tokens = 120
                                  
              ┌───────────┬───────┴───────┬───────────┐
                                                   
           来自 EP0    来自 EP1        来自 EP2     来自 EP3
           50 tokens   30 tokens       25 tokens   15 tokens
                                                   
              └───────────┴───────┬───────┴───────────┘
                                  
                                  
                  需要决定 offload  100 tokens
                  应该从 4  EP rank 各抽多少

方法很简单:首先类似 BFS,每个 EP rank 按比例公平分配;然后对于由取整误差带来的剩余部分,同样用区间重叠算法补上即可。

输入: 
  - count_tokens_per_expert_from_ep_rank [ep_size, num_experts]
  - count_tokens_from_home_expert_to_spare_expert [num_experts, num_spare]

┌────────────────────┐
   Phase1 广度优先分配  
├────────────────────┤
 按比例分配每个EP rank公平贡献
   - EP0: 50/120  41.7%
   - EP1: 30/120 = 25.0%
   - EP2: 25/120  20.8%
   - EP3: 15/120 = 12.5%
 使用 floor() 取整可能有剩余
└──────────┬─────────┘
           
      剩余容量 (取整误差)
           
┌────────────────────┐
   Phase2 深度优先补充  
├────────────────────┤
 处理取整误差的剩余容量
 区间重叠算法贪心填充剩余空间
└──────────┬─────────┘
           
        输出:
        - 每个EP rank具体offload多少tokens到每个spare expert

主要的计算过程如下。首先是按比例分配:

# 找到主要供应者(argmax):tokens 来源最多的 EP rank 是 EP0
idx_supplier = argmax([50, 30, 25, 15]) = EP0

# 计算每个 EP rank 的贡献比例
count_tokens_rel    = [50, 30, 25, 15]                              # 各 EP rank 发给专家 7 的 tokens
probs_proportional  = [50/120, 30/120, 25/120, 15/120]
                     [0.417,  0.250,  0.208,  0.125]

# 按比例分配 capacity=100
count_tokens_ideal  = [50/120*100, 30/120*100, 25/120*100, 15/120*100]
                     [41.67,      25.00,      20.83,      12.50]

# 取整(floor)
count_tokens_floors = floor([41.67, 25.00, 20.83, 12.50]) = [41, 25, 20, 12]

总计 = 41 + 25 + 20 + 12 = 98 < 100
剩余容量 = 100 - 98 = 2 tokens

由于按比例分配后 floor 取整自然产生了 2 个 token 的余数,下面我们把这 2 个 token 也补到某些 EP rank 上。

# 剩余容量
capacity_spare_remaining = 100 - 98 = 2

# 各 EP rank 剩余可 offload 的 tokens
# EP0: 50 - 41 = 9 tokens 还没 offload
# EP1: 30 - 25 = 5 tokens 还没 offload
# EP2: 25 - 20 = 5 tokens 还没 offload
# EP3: 15 - 12 = 3 tokens 还没 offload

# 使用区间重叠贪心分配
# 按 EP rank 顺序(EP0 → EP1 → EP2 → EP3)填充剩余容量 2
# EP0 还能再贡献 9 > 2,因此 2 个 token 全部由 EP0 补齐

second_pass_offload = [2, 0, 0, 0]

最终结果如下:

EP0 offload: 41 + 2 = 43 tokens
EP1 offload: 25 + 0 = 25 tokens
EP2 offload: 20 + 0 = 20 tokens
EP3 offload: 12 + 0 = 12 tokens
──────────────────────────────
总计: 100 tokens ✓

至此,我们完整地设计了一个基于贪心的负载均衡算法。

Token 重路由

最后,基于得到的负载均衡结果,我们写一个简单的 Triton kernel 来修改 token 的路由表即可:

# Step 7: Launch Triton kernel with permute map
max_tokens = num_tokens
BLOCK_SIZE = triton.next_power_of_2(max_tokens)
grid = (num_spare_experts,)

# Outputs of the kernel: map_token_to_all_experts, map_permute
reroute_tokens_w_permute_map_kernel[grid](
    indices_token_sorted, idx_expert_for_offload, count_tokens_offloading_to_spare,
    offset_cumulative, map_token_to_all_experts, map_permute, num_tokens, num_experts, num_spare_experts, BLOCK_SIZE
)

# ....

# 每个 offloading expert 一个 block 并行处理
idx_flat = indices_token * num_total_experts + idx_source_expert
tl.store(map_rerouted_ptr + idx_flat, False, mask=mask_valid)  # 清除原始路由

idx_flat_rerouted = indices_token * num_total_experts + idx_offload_col
tl.store(map_rerouted_ptr + idx_flat_rerouted, True, mask=mask_valid)  # 设置新路由

另外还有一个小细节,当整体负载其实已经比较均衡时,即使算出来某些 echo slot 要接收少量 token,通信代价可能得不偿失。所以我们加了一个阈值过滤,直接把部分 spare expert slot 的分配全部置零:

capacity_remaining = max_allowed_load - count_tokens_after_offloading
idx_safe_steps = torch.searchsorted(count_tokens_cumsum, capacity_remaining.unsqueeze(1), ...)
...
mask_column = mask_original_order.any(dim=0)
count_tokens_from_home_expert_to_spare_expert = torch.where(mask_column.unsqueeze(0), ..., zeros)

CUDA Graph 兼容的 MoE 负载均衡算法

为了使该 MoE 负载均衡算法与 CUDA Graph 兼容,我们需要把这套算法实现在 GPU 上,避免 CPU 操作,具体技巧包括:

  1. 尽可能使用 PyTorch 内置的 GPU 张量算子来实现,大部分场景下是足够应付的。
  2. 如果遇到需要精细控制的地方,那么就手写 Triton kernel,启动一个单线程即可。
  3. 最后用 @torch.compile 将这些操作进一步做 kernel fusion,减少 kernel launch 次数以及中间 tensor 的分配。

总体架构如下:

gen_offloading_plan()    @torch.compile 整体优化
    
    ├── gen_intermediate()                PyTorch (cumsum, sort, relu...)
    
    ├── gen_assignment()                  PyTorch
       └── approx_bin_packing_triton()  Triton kernel (默认)
            one_shot_greedy_assignment()   PyTorch (可选)
    
    ├── breadth_first_allocation()        PyTorch
    ├── depth_first_allocation()          PyTorch
    
    └── reroute_tokens_triton()          Triton kernel + PyTorch 混合
        └── reroute_tokens_w_permute_map_kernel   Triton kernel

注意前一节说的 one_shot_greedy_assignment 这个 MoE 负载分配算法。如果是更加复杂的 MoE 负载均衡算法,比如这里的 approx_bin_packing_triton,我们也可以实现到 Triton kernel 里面。

专家权重分发与梯度收集

有了上述 MoE 负载均衡算法之后,我们需要实现一套通信机制来完成 expert dispatch。在前向过程中,我们需要根据 MoE 负载均衡算法给出的结果将专家权重(主要是 fc1 矩阵和 fc2 矩阵)dispatch 到对应的 redundant expert slot 上;在反向传播中,我们首先在 redundant expert slots 上计算梯度,然后需要将梯度 combine 回到 home expert 上。expert dispatch 和 token dispatch 在通信模式上非常像,为了快速实现,我们可以复用 HybridEP 来做 expert dispatch。

实际上,expert dispatch 比 token dispatch 更加简单一些,比如不需要做 permute 和 unpermute。我们可以为 expert dispatch 实现更简单的通信库,这样速度会更快一些。

在这一节中,我们首先介绍基于 NCCL all-to-all 路径的 expert dispatch,然后介绍如何利用 Megatron 的重计算功能实现 redundant expert slots 的层间复用;接下来介绍如何使用 HybridEP 实现 CUDA Graph 兼容的 expert dispatch;最后,我们介绍一些实现中的小技巧。

基于 NCCL alltoall 路径的 expert dispatch

在这一小节中,我们首先不考虑 CUDA Graph 的兼容性问题,并假设我们已经实现了一个基于 NCCL all-to-all 通信的 expert dispatcher(不是 HybridEP),这个其实和 token dispatcher 基本类似。 首先,在前向过程中,我们直接使用 expert dispatcher 来进行权重分发:

# 直接分发 fc1 权重
fc1_expert_dispatch_metadata = self.expert_dispatcher.preprocess(expert_offloading_map)
dispatched_fc1_weights = self.expert_dispatcher.expert_dispatch(
    fc1_expert_dispatch_metadata,
    *fc1_expert_weights,
)
# 设置到 spare slots(引用,非复制)
self.experts.set_expert_weights("fc1", dispatched_fc1_weights, self.echo_expert_indices)

# fc2 同理
# ...

# 常规的 Token dispatch → Expert 计算 → Token combine
output, mlp_bias = dispatch_and_compute(hidden_states, probs, metadata)

dispatched_fc1_weights 是通过 alltoall 通信从 home experts 复制到 redundant expert slots 的权重副本,然后将 redundant expert slots 的权重矩阵引用设置为 dispatched_fc1_weights。 那么梯度是怎么回收到 home experts 的呢?其实就是在 self.expert_dispatcher.expert_dispatch 内部的 torch.autograd.Function 里自定义反向传播函数 backward 即可。在这个 torch.autograd.Functionforward 里面,我们做的是把 home expert 的权重通过 all-to-all 通信发送到它需要到达的 EP rank 上:

# forward: home weights → dispatch_with_permute → dispatched weights
(dispatched_weight, ..., handle) = buffer.dispatch_with_permute(
    hidden=weight_tensor,      # home expert 权重(当作"token"来发送)
    routing_map=routing_map,   # expert_offloading_map
    ...
)

这个时候就有另外一个问题:当负载比较均衡时,某些 redundant expert slot 不会被分配任何 home expert,这个时候要怎么办呢? 技巧还是用一个 torch.autograd.Function 插入到计算图中作为一个节点,然后在 forwardbackward 中自定义一个空矩阵即可。我们可以这么做:在 set_expert_weights 中,判断某个 redundant expert slot 是否被分配了某个 home expert:

def set_expert_weights(self, module, expert_weights, expert_indices):
    for i, expert_index in enumerate(expert_indices):
        if expert_weights[i].numel() == 0:          # ← 没收到任何权重
            setattr(expert_layer, f"weight{expert_index}",
                    DummyFunction.apply(expert_weights[i], weight_shape))
        else:                                         # ← 正常收到了权重
            setattr(expert_layer, f"weight{expert_index}",
                    expert_weights[i])

然后,DummyFunction 就是一个计算图里面的节点,定义如下:

class DummyFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, weight_shape):
        ctx.input = x           # 保存空的输入 tensor(numel=0)
        dummy_weight = torch.empty(weight_shape, dtype=x.dtype, device=x.device)
        return dummy_weight     # 返回一个正确 shape 但内容随机的"占位"权重

    @staticmethod
    def backward(ctx, grad_output):
        return torch.zeros_like(ctx.input), None   # 梯度 = 0

我们还有一个问题没有解决,即如何确保梯度累积不会算错。我们知道,home expert 的梯度来源于两处:一处是本地 expert 计算得到的;另一处是本地 home expert 被复制到远端后,在远端计算梯度再 combine 收回到 home expert。这两个来源的梯度在不同时间、不同机制下产生,如何确保它们被正确累加而不重复? 我们首先要了解两个东西。一个是 param.grad,这个是 PyTorch autograd 反向传播中自动累加的;另一个是 param.main_grad,这个是 optimizer 使用的。一般情况下,如果 .grad 属性非 None,并且没有设置 grad_added_to_main_grad 这个标记,.grad 就会在 PyTorch DDP 的 backward hook 中被搬运到 .main_grad 中。


                    param.grad                    param.main_grad
                                                      
   PyTorch autograd             DDP backward hook            Optimizer
   自动填充                      搬运                          使用
                                                      
                  ┌─────────┐    搬运条件:          ┌──────────┐
                   .grad    ──────────────────→  main_grad   optimizer.step()
                  └─────────┘  if .grad != None    └──────────┘
                                AND not grad_added_to_main_grad

注意到在上述讨论中,home expert 的两个来源的梯度都是通过计算图里面的 autograd 实现的;因此,通过 PyTorch 的 DDP 本身,我们就可以正确设置 .main_grad 这个梯度。

alltoall 路径:
                     .grad                          main_grad
  home GEMM grad ──→ ████ ──┐                    ┌→ ████████████
  echo combine grad ████ ──┤  DDP hook 搬运     
                            └──────────────────→─┘

Megatron 的重计算

上一小节的实现有一个问题:每一层的 redundant expert slot 之间并没有被复用,造成更大的显存压力。本节我们介绍如何基于 Megatron 的重计算机制实现 redundant expert slots 的显存层间复用。 我们首先来理解 Megatron 的重计算功能,特别是 checkpointCheckpointWithoutOutput 两个函数。我们知道,重计算就是在前向传播过程中舍弃掉部分中间变量,在反向过程中重新跑一次前向计算来获得这些中间值。我们来看看 Megatron 的标准 checkpoint 是怎么用的。假设有这样的计算流程:

# A → B → C → D → E → loss

如果不使用 checkpoint,那么前向传播和反向传播是这样的:

Forward:  保存 A, B, C, D, E     显存峰值 = 5份
Backward: 直接反向传播

如果我们使用 checkpoint:

from megatron.core.tensor_parallel import checkpoint

# 把 B, C 包在 checkpoint 里
def compute_bc(a):
    b = fn_b(a)
    c = fn_c(b)
    return c

c = checkpoint(compute_bc, False, a)
d = fn_d(c)
e = fn_e(d)

# Forward:  保存 A, [C], D, E       显存 = 4份(B 不保存)
# Backward: 到 C 时重算 B

在前向传播时,torch 会保存输入 x,然后在 torch.no_grad() 模式下执行 compute_bc 得到 c,中间结果 b 不在计算图中,用完即弃;在反向传播时,当 c 的梯度到达时会触发 hook,hook 调用 _recompute() 重新执行 compute_bc(),这次是在 torch.enable_grad() 下实现的,然后用重建的计算图计算梯度。这个时候显存中会多出 b 的显存,但是只在 backward 的这一段时间内存在。 checkpoint 的原理是在计算图中插入一个 torch.autograd.Function 作为节点,这个节点保存了两样东西:一是保存的输入,存在 ctx 里面;二是要重跑的函数。反向传播时,梯度传到哪里,就在哪个节点重算。 比如,假如我们设置两个 checkpoint:

# 假设计算流程:A → B → C → D → E → F → G → loss
# 我们在两个地方分别做 checkpoint

def compute_bc(a):
    b = fn_b(a)
    c = fn_c(b)
    return c

def compute_ef(d):
    e = fn_e(d)
    f = fn_f(e)
    return f

c = checkpoint(compute_bc, False, a)    # ckpt1:保存输入 a,丢弃 b
d = fn_d(c)                             # 正常计算,保存 c 和 d
f = checkpoint(compute_ef, False, d)    # ckpt2:保存输入 d,丢弃 e
g = fn_g(f)
loss = loss_fn(g)

前向传播的时候,构造的计算图如下:

a ──→ [Ckpt1Function] ──→ c ──→ fn_d ──→ d ──→ [Ckpt2Function] ──→ f ──→ fn_g ──→ g ──→ loss
       保存了: a                                   保存了: d
       函数: compute_bc                            函数: compute_ef
       (b 没保存)                                  (e 没保存)

CheckpointWithoutOutput 则是在 checkpoint 的基础上,把输出的显存占用也去掉了。在下面的例子里面:

from megatron.core.tensor_parallel import CheckpointWithoutOutput

def compute_bc(a):
    b = fn_b(a)
    c = fn_c(b)
    return c

ckpt = CheckpointWithoutOutput()
c = ckpt.checkpoint(compute_bc, a)          # 前向执行,保存输入 a,中间变量 b 不保存
d = fn_d(c)
e = fn_e(d)
ckpt.discard_output_and_register_recompute(d)  # 丢弃 c 的数据,把重算挂在 d 上

# Forward:  保存 A, [D], E          显存 = 3份(B 不保存,C 也被释放!)
# Backward: d 的梯度到达时 → 重算 compute_bc → 恢复 C → 继续反向传播

我们可以看到,使用 CheckpointWithoutOutput 的话多了一步 ckpt.discard_output_and_register_recompute(hook_tensor)。这里的意思是在 hook tensor d 上挂一个 backward hook,确保在反向传播时把 tensor c 重算出来。 所以我们可以总结:标准 checkpoint 虽然丢弃了中间变量,但输出必须保留;而 CheckpointWithoutOutput 更进一步,连输出也丢弃,等到反向传播时再通过 hook 重新计算恢复。 到这里,我们就可以推理出 CheckpointWithoutOutput 大致的实现,内部走的是 CheckpointWithoutOutputFunction.apply(),前向通过 torch.no_grad() 绕开计算图,同时保持 detach 后的输入,供反向时重算,同时把 ctx 挂到外部对象上:

@staticmethod
def forward(ctx, run_function, checkpoint_without_output_obj, *args):
    with torch.no_grad():          # 不构建计算图 → 中间变量用完即弃
        outputs = run_function(*args)

    # 保存输入(detach 后),供反向时重算
    # detached_args[i] 是 args[i] 的独立副本,is_leaf=True
    # 有独立的 .grad 存储,和原张量不共享 autograd 关系
    detached_args = tuple(
        arg.detach().requires_grad_(arg.requires_grad) if isinstance(arg, torch.Tensor) else arg
        for arg in args
    )
    ctx.detached_args = detached_args
    checkpoint_without_output_obj.ctx = ctx   # 把 ctx 挂到外部对象上
    return outputs

然后在前向传播结束后,调用方可以自定义 hook 触发的时机:

def discard_output_and_register_recompute(self, hook_tensor):
    # 第一步:把输出的 storage 大小 resize 到 0
    # 只释放数据内存,保留 tensor 的元信息(shape, dtype, device, strides)
    for output in self.outputs:
        output.untyped_storage().resize_(0)    # ← 内存释放!但 tensor 对象还活着

    # 第二步:在 hook_tensor 上注册一个 backward hook
    # 当 hook_tensor 的梯度计算完成时,触发 _recompute
    hook_tensor.register_hook(self._recompute)

这里有两个很巧妙的设计:

  1. untyped_storage().resize_(0) 这个命令不是删除 tensor 对象,而是把底层存储清空到 0 字节。tensor 的”壳”(shape、dtype、device 等元信息)还在,下游模块持有的引用不会失效,只是此时访问数据会出错。
  2. hook_tensor 的选择规则是:当 hook_tensor 的梯度被计算出来时,hook 被触发,但这并不意味着必须是下一个 tensor,我们可以选择任意一个符合条件的 tensor。这一点很好,我们可以精心挑选合适的时间点触发 hook,比如实现”计算通信重叠”。例如在 TransformerLayer 中,前向是这样的:
# 前向:layernorm → attention → ...
output = self.input_layernorm_checkpoint.checkpoint(self.input_layernorm, hidden_states)
attention_output = self.self_attention(output, ...)

# 丢弃 layernorm 的输出,把重算挂在 attention_output 上
self.input_layernorm_checkpoint.discard_output_and_register_recompute(attention_output)

反向传播时,梯度先到达 attention_output,此时 hook 触发重算,恢复 layernorm_output,然后梯度继续流经 layernorm_output。时序恰好吻合。 接下来我们来看一下,当 hook_tensor 的梯度到达时,_recompute 是怎么实现的:

def _recompute(self, _):
    inputs = self.ctx.detached_args

    # 用影子 leaf 在 enable_grad 下重新执行前向函数 → 构建计算图
    # 这张图的 leaf 是 detached_args(影子),不是原始 *args
    with torch.enable_grad():
        outputs = self.run_function(*inputs)

    # 关键:把重算结果的数据"塞回"原来的 tensor 壳里
    with torch.no_grad():
        for output, recomputation_output in zip(self.outputs, outputs):
            # 先把 storage 恢复到正确大小
            output.untyped_storage().resize_(recomputation_output.untyped_storage().size())
            # 再把数据复制进去
            output.untyped_storage().copy_(recomputation_output.untyped_storage())

    # 把重算的输出(带计算图的版本)存到 ctx 上,供 backward 使用
    self.ctx.outputs = outputs

我们看到核心操作是 resizecopy_,这样可以把之前 resize 为 0 的存储重新扩大,然后把重算的数据复制进去即可。那么在backward我们做的事情是:

@staticmethod
def backward(ctx, *output_grads):
    inputs  = ctx.detached_args          # 影子 leaf
    outputs = ctx.outputs                # 挂在影子图上的输出

    # 沿着影子图把梯度写到影子 leaf 的 .grad 上
    torch.autograd.backward(outputs, output_grads)

    # 取出影子 leaf 的 .grad,作为"input 的梯度"返回给外层 autograd
    grads = tuple(inp.grad if torch.is_tensor(inp) else None for inp in inputs)
    return (None, None) + grads

因此反向传播的时候梯度流向如下:

外层 autograd:         ... ── args[i] ──[ Function]── outputs ── ...
                                
                                  return  grads[i] 喂回外层
                                
内部影子图:   detached_args[i]  ──►  new_outputs
              (is_leaf, 独立)
            `torch.autograd.backward` 只会往 detached_args[i].grad 
            不会碰到外层的 args[i]也不会碰到任何 nn.Parameter  .grad

这个时候,”内部反传”和”外部反传”是两张独立的图,通过 return grads 这一个出口相连。所以内部用 torch.autograd.backward() 是安全。但是问题是,这在内部反传的时候,因为 detach 出的副本不是 nn.Parameter,是没有 main_grad / grad_added_to_main_grad 这些 TE / DDP 约定属性。

基于重计算实现 redundant expert slots 的显存层间复用

接下来,我们来看一下如何利用重计算来实现 redundant expert slots 的层间复用。在前向过程中,我们利用 CheckpointWithoutOutput 来包住 fc1 的权重分发操作:

fc1_expert_checkpoint = CheckpointWithoutOutput(only_calculate_input_grad=True)
fc2_expert_checkpoint = CheckpointWithoutOutput(only_calculate_input_grad=True)

# 用 CheckpointWithoutOutput 包住权重分发操作
dispatched_fc1_weights = fc1_expert_checkpoint.checkpoint(
    partial(self.expert_dispatcher.expert_dispatch, fc1_expert_dispatch_metadata),
    *fc1_expert_weights,          # 输入:home expert 的权重
)
# 把分发后的权重设置到 spare expert slots 上,这里不是复制
self.experts.set_expert_weights("fc1", dispatched_fc1_weights, self.echo_expert_indices)

# ... fc2 同理 ...

# 常规的 Token dispatch → Expert 计算 → Token combine
output, mlp_bias = dispatch_and_compute(hidden_states, probs, metadata)

# 关键!丢弃权重分发的输出,把重算挂在 MoE 层的最终 output 上
fc1_expert_checkpoint.discard_output_and_register_recompute(output)
fc2_expert_checkpoint.discard_output_and_register_recompute(output)

dispatched_fc1_weights 是通过 HybridEP 通信从 home experts 复制到 redundant expert slots 的权重副本,然后将 redundant expert slots 的权重矩阵引用设置为 dispatched_fc1_weights。 这些副本在前向计算完成后不再需要,在运行 discard_output_and_register_recompute 后存储就会被丢弃。因此,实际上 redundant expert slots 是层间复用的。 因为 redundant expert slots 是层间复用的,所以这里 hook 触发的 _recompute 里面恢复数据,然后在 backward 里面完成梯度收集。这里进行重计算的目的是减少显存占用,代价是 fc1 和 fc2 在重计算的时候各多了一次 all-to-all 通信。 据此,我们可以写出基于 alltoall 的 expert dispatch 在使用重计算时的前向传播和反向传播:

home_weights (W0, W1, W2, W3)
      
      
[CheckpointWithoutOutputFunction]        torch.no_grad() 下执行
   内部: [permute  all_to_all  sort_chunks]
   保存: ctx.detached_args = (W0, W1, W2, W3)
   不建图: dispatch 操作不进入 autograd
      
      
dispatched_weights (W4', W5')             grad_fnno_grad 下产生
      
       setattr
      
[grouped GEMM]  expert_output  [token combine]  output
                                                      
      discard: W4'.storage.resize_(0)                 │
      register_hook(_recompute) ──────────────────────┘

反向传播如下:

L/output 到达
      
      
_recompute hook 触发:
  with torch.enable_grad():
    W4'_new, W5'_new = [permute  all_to_all  sort_chunks](*inputs)
     这次有 grad_fn
  W4'.storage.copy_(W4'_new.storage)   恢复数据
  ctx.outputs = (W4'_new, W5'_new)     保存带 grad_fn 的版本
      
      
L/output 继续反向
      
      
[grouped GEMM backward]  (grad_accum_fusion=False)
      
      ├─→ L/W0..W3  W0.grad..W3.grad            来源 1
      ├─→ L/W4', ∂L/∂W5'
              
              
      [CheckpointWithoutOutputFunction.backward]
        only_calculate_input_grad=True:
        torch.autograd.grad(outputs=(W4'_new, W5'_new),
                            inputs=(W0, W1, W2, W3),
                            grad_outputs=(L/W4', ∂L/∂W5'))
              
              
        [sort_chunks backward  all_to_all backward  permute backward]
              
              
        返回梯度  累加到 W0.grad..W3.grad           来源 2
              
              
      DDP hook: W0.grad  W0.main_grad (搬运)

基于 SyncFree HybridEP 路径的 expert dispatch

上面基于 alltoall 的 expert dispatch 有很多和 CUDA Graph 不兼容的地方,比如出现运行在 CPU 上的条件执行语句 expert_weights[i].numel() == 0。为了和 CUDA Graph 兼容,我们可以使用基于 SyncFree HybridEP 的 expert dispatch。 在发送端,我们将所有 home expert 的权重进行打包。假设有 4 个 home expert,4 个 expert 的权重全部被 stack 进 weight_tensor,即使只有 1 个需要 offload。HybridEP 的 dispatch_with_permute 会根据 routing_map 只发送需要 offload 的 expert 的数据,但 permute 操作仍然要处理整个 weight_tensor

# 所有 home expert 的权重都被打包进输入 tensor
# 因为复用了 HybridEP 的基础设施,这里设置 chunk 是为了传输效率
weight_tensor = torch.stack(weight_list, dim=0).reshape(num_local_home_experts, -1)
weight_tensor = weight_tensor.reshape(num_local_home_experts * num_chunks_per_weight, weight_chunk_size)

# routing_map 扩展到 chunk 级别
routing_map = (
    routing_map.reshape(num_local_home_experts, 1, num_total_experts)
    .expand(-1, num_chunks_per_weight, -1)
    .reshape(num_local_home_experts * num_chunks_per_weight, num_total_experts)
).contiguous()

# dispatch,固定输出大小
(dispatched_weight, ...) = buffer.dispatch_with_permute(
    hidden=weight_tensor,
    routing_map=routing_map,
    num_permuted_tokens=num_dispatched_weights * num_chunks_per_weight,  # 固定!
    ...
)

这里打包了不需要 offload 的权重,虽然不会发送,但仍然参与了 permute,这里存在进一步优化的空间。

注意我们这里设置了chunk size来获得更好的传输效率。因为我们是复用HybridEP的,然后DeepSeek V3里面hidden dim就是8192,因此我们尽可能让每个chunk尽可能接近8192个元素:

n = max(0, round(math.log2(8192 / config.hidden_size)))
self.weight_chunk_size = config.hidden_size * (2 ** n)

在接收端,我们设置了固定大小的输出 buffer,然后均匀切分即可。虽然这里没有收到权重的 redundant expert slots 未被初始化,但是也占了 buffer。

# num_permuted_tokens 是固定值,不依赖运行时数据
num_permuted_tokens = num_dispatched_weights * num_chunks_per_weight

# dispatch 返回固定大小的 tensor
(dispatched_weight, ...) = buffer.dispatch_with_permute(
    hidden=weight_tensor,
    num_permuted_tokens=num_permuted_tokens,  # ← 固定!
    ...
)

# 均匀切分 → 每个 spare slot 都得到固定 shape 的 tensor
dispatched_weight_list = [
    weight.reshape(weight_shape) 
    for weight in dispatched_weight.chunk(num_dispatched_weights, dim=0)
]

另外一个问题是如何正确累加 home expert 两个来源的梯度。为了与 CUDA Graph 兼容,我们使用了 Transformer Engine 中 grouped GEMM 的 gradient_accumulation_fusion 参数和 wgrad_accumulation_mask 参数。在 TE 的 grouped GEMM 反向传播中,可以开启 gradient_accumulation_fusion,直接把权重梯度写入 main_grad(跳过 .grad)。但问题是 grouped GEMM 一次性算出所有 expert(home + spare)的权重梯度,spare slot 的梯度不应该写入 main_grad(因为 spare slot 没有自己的”home 权重”要更新)。于是我们可以用另外一个参数 wgrad_accumulation_mask 来控制:

wgrad_accumulation_mask = [True] * num_home_experts + [False] * num_echo_local_experts
#                          home experts: 融合写入 main_grad    spare slots: 不融合

举个例子,我们可以实现这样的效果:

grouped GEMM backward 对每个 expert 的权重梯度:

Expert 0 (home): mask=True   L/W0 直接写入 W0.main_grad 
Expert 1 (home): mask=True   L/W1 直接写入 W1.main_grad 
Expert 2 (home): mask=True   L/W2 直接写入 W2.main_grad 
Expert 3 (home): mask=True   L/W3 直接写入 W3.main_grad 
Expert 4 (spare): mask=False  L/W4 存到 W4.grad W4  dispatch 来的副本
Expert 5 (spare): mask=False  L/W5 存到 W5.grad W5  dispatch 来的副本

对于 home experts,因为梯度已经直接融合到 main_grad,TE 会默认设置 grad_added_to_main_grad = True。因此这种情况下梯度是按如下方式进行累加的:

grouped GEMM backward:
  Expert 0-3 (home, mask=True):
     L/W 直接融合写入 main_grad   来源 1 
     grad_added_to_main_grad = True

  Expert 4-5 (spare, mask=False):
     L/W 存到 .grad

HybridEPExpertDispatch.backward:
  spare 的梯度  combine_with_unpermute (反向 all-to-all)
   收集回 home weight
   weight.main_grad.add_(wgrad)      来源 2 
   grad_added_to_main_grad = True再次确认
   return None不走 autograd 累加

DDP hook:
  grad_added_to_main_grad = True
   跳过不做搬运

我们可以对比 alltoall 路径和 SyncFree HybridEP 路径下的 expert dispatcher 分别是怎么累加梯度的:

                        alltoall 路径              SyncFree HybridEP 路径
─────────────────────────────────────────────────────────────────────────
gradient_accumulation    关闭 (False)              开启 ( mask)
_fusion

来源1 (home计算)          .grad                    main_grad (融合)
  梯度去向                                          mask=True  expert

来源2 (echo计算)          .grad (autograd反向)      main_grad (手动add_)
  梯度去向               标准 all-to-all 反向        combine_with_unpermute

合并方式                 autograd 自动累加到 .grad   两个来源各自直接写 main_grad
                          DDP hook 搬到 main_grad

DDP hook                 执行搬运                    跳过已在 main_grad 
  (grad  main_grad)     (.grad  main_grad)

CUDA Graph 兼容           ( .tolist() )        
─────────────────────────────────────────────────────────────────────────

我们也可以这样对比两者的区别:

alltoall 路径:
                     .grad                          main_grad
  home GEMM grad ──→ ████ ──┐                    ┌→ ████████████
  echo combine grad ████ ──┤  DDP hook 搬运     
                            └──────────────────→─┘

SyncFree HybridEP 路径:
                     .grad        main_grad
  home GEMM grad ──────────────→ ████████ (融合直写, mask=True)
  echo combine grad────────────→ ████████ (手动 add_)
                                     
                              DDP hook 跳过

我们可以看到这里有一个解决 CUDA Graph 兼容性的通用策略:把原本运行在 CPU 上的条件判断变成 GPU 上的 mask tensor(也就是 wgrad_accumulation_mask 这个参数)。据此,我们就可以实现一个 sync-free 的基于 HybridEP 的 expert dispatcher。同样,我们可以写出前向传播时的计算图:

home_weights (W0, W1, W2, W3)
      
      
[HybridEPExpertDispatch.forward]         自定义 autograd.Function
   stack  chunk  dispatch_with_permute (all-to-all)
   保存: ctx.handle, ctx.expert_weights = (W0, W1, W2, W3)
      
      
dispatched_weights (W4', W5')             grad_fn  HybridEPExpertDispatch
      
       setattr
      
[grouped GEMM]                           wgrad_accumulation_mask 控制融合
  home (mask=True):  W0@t0..W3@t3
  spare (mask=False): W4'@t4, W5'@t5
      
      
expert_output  [token combine]  output

以及反向传播时的计算图:

L/output
      
      
[token combine backward]
      
      
[grouped GEMM backward]
      
      ├─→ W0..W3 (mask=True):
         L/W  直接融合写入 main_grad             来源 1 
         W0.grad_added_to_main_grad = True
      
      ├─→ W4', W5' (mask=False):
         L/W4' → W4'.grad
         L/W5' → W5'.grad
              
              
      [HybridEPExpertDispatch.backward]
        stack 梯度  combine_with_unpermute (反向 all-to-all)
         weight_grad_list
        for W, wgrad in zip((W0,W1,W2,W3), weight_grad_list):
            W.main_grad.add_(wgrad)                   来源 2 
            W.grad_added_to_main_grad = True
        return None   不通过 autograd 传梯度
              
              
      DDP hook: grad_added_to_main_grad = True  跳过

基于重计算实现 HybridEP expert dispatcher下redundant expert slots 的显存层间复用

类似基于alltoall的expert dispatcher,HybridEP expert dispatcher也可以使用重计算来实现redundant expert slots 的显存层间复用。 但是这样会存在一些问题需要解决。 注意到在上一章中,我们发现HybirdEP反向传播的时候,是直接写到main_grad。但是如前面所述,CheckpointWithoutOutputFunction的使用了detach,detached args是不会有main_grad这个属性的,会有问题。因此我们对CheckpointWithoutOutputFunction的前向过程进行了修改,对leaf的向量不detach,只detach非leaf向量,这样能保留leaf的向量的main_grad:

    @staticmethod
    def forward(ctx, run_function, checkpoint_without_output_obj, *args):
        """Forward pass."""
        
        with torch.no_grad(), fwd_ctx:
            outputs = run_function(*args)
        # Skip detach of leaf nodes, since we want to access main_grad of leaf nodes.
        detached_args = tuple(
            arg if (isinstance(arg, torch.Tensor) and arg.is_leaf) else
            (arg.detach().requires_grad_(arg.requires_grad) if isinstance(arg, torch.Tensor) else arg)
            for arg in args
        )
        ctx.detached_args = detached_args
        ctx.only_calculate_input_grad = checkpoint_without_output_obj.only_calculate_input_grad
        # ctx.save_for_backward(*detached_args)
        # the CheckpointWithoutOutput object is passed in, then it can access the saved input
        # tensors later for recomputation
        checkpoint_without_output_obj.ctx = ctx
        return outputs

同时,我们在_recompute里面,非 leaf 又 detach 了一次,leaf 依旧保持原身:

def _recompute(self):
    # 注意这里对非 leaf 又 detach 了一次(防止二次挂图),leaf 依旧保持原身
    inputs = tuple(
        arg if (isinstance(arg, torch.Tensor) and arg.is_leaf)
        else _detach_with_grad(arg)
        for arg in self.ctx.detached_args
    )
    with torch.enable_grad():
        new_outputs = self.run_function(*inputs)
    self.ctx.outputs = new_outputs

因此在在这张”半影子图”里,原始 nn.Parameter 就是图上的 leaf。run_function 内部(比如 HybridEPExpertDispatch.backward)可以直接 weight.main_grad.add_(...)

同样地,对于backward我们也要修改。原本backward使用的是torch.autograd.backward, 会累加到.grad一次。因为我们现在不detach了,在外层autograd时候又会加一次,因为我们用only_calculate_input_grad这个参数控制,如果为true就用torch.autograd.grad, 这个是不会累加到.grad的。

@staticmethod
def backward(ctx, *output_grads):
    inputs  = ctx.detached_args
    outputs = ctx.outputs

    valid_outputs, valid_grads = _filter_nones(outputs, output_grads)

    if ctx.only_calculate_input_grad:
        # 纯函数式:只"返回"梯度,不往任何 .grad / main_grad 里写
        tensor_inputs = [x for x in inputs
                         if torch.is_tensor(x) and x.requires_grad]
        grads = torch.autograd.grad(
            outputs=valid_outputs,
            inputs=tensor_inputs,
            grad_outputs=valid_grads,
            allow_unused=True,
        )
        grad_map = {id(x): g for x, g in zip(tensor_inputs, grads)}
        input_grads = tuple(
            grad_map.get(id(x)) if torch.is_tensor(x) else x
            for x in inputs
        )
    else:
        # 兼容路径:沿用 backward(),此时调用方必须自己保证不会造成 double accumulation
        # (典型做法:run_function 内部已经把 wgrad 写到 main_grad、并 return None,
        #  所以这条反传里对应 weight 那一支根本不产生 .grad 写入)
        torch.autograd.backward(valid_outputs, grad_tensors=valid_grads)
        input_grads = tuple(
            x.grad if torch.is_tensor(x) else x for x in inputs
        )

    return (None, None) + input_grads

因此反向传播的时候,整体的流向图如下:

外层 autograd:   ... ── args[i] ──[ Function]── outputs ── ...
                          
                           return  input_grads[i] 喂回外层由外层累加到 .grad/main_grad
                          
内部半影子图:   ctx.detached_args[i]  ──►  new_outputs
                (leaf 就是原 nn.Parameter 本身)
      `torch.autograd.grad` 只是"算出并返回"梯度
      绝不触碰 inputs[*].grad也不触发 TE  main_grad 写入
      ──► 不会和外层 autograd 争抢 nn.Parameter

我们可以比较这两种方式的差异:

                                A. 全部 detach              B. 跳过 leaf detach
────────────────────────────────────────────────────────────────────────────────────────
ctx.detached_args             影子副本                    原始 nn.Parameter 本体
nn.Parameter 的身份             (is_leaf=True 
                                不是 nn.Parameter)

访问 weight.main_grad /         不能                        
grad_added_to_main_grad

内部反传 API                    torch.autograd.backward     torch.autograd.grad
                                (outputs, grads)            (outputs, inputs, grads)

内部反传副作用                  但只写到影子副本上      若用 backward() 会写到真
( .grad / main_grad)           安全                      nn.Parameter  与外层
                                                            autograd 冲突

梯度交给外层的方式              读影子 leaf  .grad         torch.autograd.grad
                                 return                   的返回值直接 return

是否会 double-accumulate        不会                        grad() 路径不会
                                                            backward() 路径取决于
                                                            run_function 是否已自行处理

适用场景                        checkpoint 的函数内部       checkpoint 的函数需要通过
                                不需要触达权重对象本身      main_grad  fused wgrad
                                                            累加HybridEP / SyncFree 
────────────────────────────────────────────────────────────────────────────────────────

综合对比

同样地,我们也可以使用重计算来实现 sync-free HybridEP 的 expert dispatcher。我们有两个维度来看待 expert dispatcher:通信手段选择 all-to-all 或者 HybridEP,以及是否开启重计算。组合后,我们可以得到下面的对比分析表:

                    alltoall            alltoall            HybridEP           HybridEP
                     Ckpt              Ckpt              Ckpt             Ckpt
                    (组合一)            (组合二)            (组合三)           (组合四)
───────────────────────────────────────────────────────────────────────────────────────
autograd 节点       标准 PyTorch ops     CkptFunc         HybridEP           CkptFunc 
(dispatch)                              标准 ops           ExpertDispatch      HybridEPExpertDispatch

来源1 去向          .grad               .grad              main_grad(融合)     main_grad(融合)

来源2 去向          .grad               .grad              main_grad(手动)     main_grad(手动)
                    (autograd链)        (autograd链)       (combine+add_)     (combine+add_)

合并位置            .grad               .grad              main_grad           main_grad

DDP hook            执行搬运             执行搬运            跳过               跳过

discard+recompute                                                        
显存层间复用

反向 all-to-all     autograd 自动       autograd 自动       手动 combine        手动 combine
(梯度收集)          反向 all_to_all      反向 all_to_all

CUDA Graph                                                               
兼容

额外通信次数        0                   +2 (重新dispatch)   0                  +2 (重新dispatch)
(相比无 Ckpt)
───────────────────────────────────────────────────────────────────────────────────────

技巧:显式 metadata 对象传递

另外有一个关于 metadata 的数据结构小技巧。在 expert dispatcher 中,我们使用 metadata 的方式来保存状态并进行参数传递。在之前,状态是通过实例变量的形式直接保存在结构体的 self 下面的:

# 状态存储在 dispatcher 实例上(self.xxx)
class MoEFlexTokenDispatcher:
	def dispatch_preprocess(self, hidden_states, routing_map, probs):
		self.hidden_shape = hidden_states.shape # 存在 self 上
		self.token_probs = ...
		self.handle = ...

	def token_combine(self, hidden_states):
		return hidden_states.view(self.hidden_shape) # 从 self 读取

现在是通过显式 metadata 对象传递:

  class MoEFlexTokenDispatcher:
      def dispatch_preprocess(self, hidden_states, probs, metadata):
          metadata.hidden_shape = hidden_states.shape    # 存在 metadata 上
          metadata.token_probs = ...

      def token_combine(self, hidden_states, metadata):
          return hidden_states.view(metadata.hidden_shape)  # 从 metadata 读取

metadata 的本质是将”无状态的逻辑执行流程”和”有状态的数据结构”分离,这样的好处有很多。第一个是同一段逻辑的多次并发调用。比如在 expert dispatch 中我们需要分别传递 fc1 和 fc2 的状态,只需要把它们分别保留到各自的 metadata 里面,这样就可以用同一个 dispatcher 实现多次甚至是并发调用:

# 分别保留 fc1 和 fc2 的状态到各自的 metadata 中
fc1_expert_dispatch_metadata = self.expert_dispatcher.preprocess(expert_offloading_map)
fc2_expert_dispatch_metadata = self.expert_dispatcher.preprocess(expert_offloading_map)
# ...
# 然后用同一个 expert dispatcher,但分别传入各自的 metadata
dispatched_fc1_weights = self.expert_dispatcher.expert_dispatch(fc1_expert_dispatch_metadata, *fc1_expert_weights)
dispatched_fc2_weights = self.expert_dispatcher.expert_dispatch(fc2_expert_dispatch_metadata, *fc2_expert_weights)

如果用 self 存状态,第二次 preprocess 会覆盖第一次的状态(比如 self.handle),导致 fc1 的 dispatch 信息丢失。而 metadata 对象让两次调用的状态完全独立。 第二个好处是更好的数据结构生命周期管理,特别是在使用 CUDA Graph 和重计算的时候。在 CUDA Graph 捕获时,中间状态需要在 graph 外部被引用和管理。metadata 作为显式对象,生命周期由调用方控制,比隐式的 self 状态更容易被 CUDA Graph 正确追踪和重放,因此 metadata 的模式对 CUDA Graph 更加友好。

重设让redundant expert slots并非参数

还有一点需要注意,redundant expert slots实际上并不是真正的模型参数nn.Parameter,只是一个占位符,因此我们需要进行一些配置来改变它们。 首先看一下初始化的逻辑:

self.experts = build_module(
    self.submodules.experts,
    num_echo_local_experts+self.num_home_experts,  # ← 多创建了 echo slot
    ...
)
self.echo_expert_indices = list(
    range(self.num_home_experts, num_echo_local_experts + self.num_home_experts)
)
...
self.experts.free_expert_parameters(self.echo_expert_indices)

build_module 会按 num_home + num_echo 创建出 weight0, weight1, ..., weightN 全部为 nn.Parameter(调 torch.empty 分配显存)。但其中 只有 weight0..weight_{home-1} 是真正的 home 权重,后面的 weightN_home..weightN 只是”占位槽”,它们在前向时会被 set_expert_weights 替换成从其他 EP rank dispatch 过来的权重副本。因此需要从 nn.Module 的参数注册表里摘掉。

def free_expert_parameters(self, expert_indices: List[int]):
    """Free echo expert parameters."""
    to_free_weight_names = [f'weight{i}' for i in expert_indices]
    for module in [self.linear_fc1, self.linear_fc2]:
        # Clear all parameters in the module
        for name, param in list(module.named_parameters()):
            if name in to_free_weight_names:
                delattr(module, name)
                module._parameters.pop(name, None)

Megatron Full CUDA Graph 的实现机制

前面我们已经讨论了 MoE 负载均衡算法与 expert dispatch 的 CUDA Graph 兼容实现,本节来看看 Megatron-LM 是如何实现 Full CUDA Graph 的。

捕获粒度:整个 forward_backward_func

一个自然的问题是:Megatron 的 Full CUDA Graph 是否支持 Pipeline Parallelism(PP)?

答案是支持的。Full CUDA Graph 捕获的是整个 forward_backward_func 的执行,而不是单个 layer 或单个 stage。forward_backward_func 内部会处理所有 PP stages 的前向传播、反向传播,以及 microbatch 的调度(如 1F1B 调度)。因此,一个 CUDA Graph 就包含了

  • 所有 PP stages 的计算
  • 所有 microbatches 的处理
  • PP stages 之间的通信(send/recv)

具体来看,CUDA Graph 的捕获发生在 warmup 步骤完成之后。在捕获时,所有进程通过 torch.distributed.barrier() 同步,然后使用 torch.cuda.graph() 上下文管理器捕获整个 forward_backward_func 的执行:

        if curr_iteration == self.cuda_graph_warmup_steps:
            logger.info(f'Capture CUDA graph for {training_str}!!!')
            torch.distributed.barrier()
            assert FullCudaGraphWrapper.cuda_graph[training_str] is None
            FullCudaGraphWrapper.cuda_graph[training_str] = torch.cuda.CUDAGraph()
            # ... 注册 RNG states
            with torch.cuda.graph(
                FullCudaGraphWrapper.cuda_graph[training_str],
                stream=capture_stream,
                capture_error_mode="thread_local",
            ):
                # 捕获整个 forward_backward_func,包含所有 PP stages
                FullCudaGraphWrapper.result[training_str] = self.forward_backward_func(
                    *args, **kwargs
                )

数据读取:支持多 stage PP

同时我们也能看到,data_read 方法在数据读取上天然支持多 stage PP。它会根据 PP stage 数量,为每个 stage 分别读取 microbatch 数据,并将其复制到静态缓冲区中(这是 CUDA Graph replay 所必需的):

    def data_read(self, data_iterator, model, training, num_microbatches):
        """Read all microbatch inputs from Dataloader and copy to static buffers."""
        if not isinstance(model, list) or len(model) == 1:
            # 单 stage 场景(无 PP 或 PP size = 1)
            assert not isinstance(data_iterator, list) or len(data_iterator) == 1
            # ... 处理单个 data_iterator
        else:
            # 多 stage 场景(PP size > 1)
            assert isinstance(data_iterator, list) and len(data_iterator) == len(model)
            data_list = []
            for i in range(len(model)):
                if data_iterator[i] is not None:
                    # 为每个 PP stage 分别读取 microbatch 数据
                    data_list_i = []
                    for b in range(num_microbatches):
                        data_list_i.append(...)
                    data_list.append(iter(data_list_i))
                else:
                    data_list.append(None)

因此,无论 PP 有多少个 stage,整个系统只会创建 2 个 CUDA Graph:1 个用于 training,1 个用于 validation。这种设计极大地简化了 CUDA Graph 的管理复杂度。

Device-Initiated Grouped GEMM 消除同步

TE 的两种 grouped GEMM 后端

在 MoE 架构中,Router 会将不同的 token 路由到不同的 expert,因此每个 expert 接收到的 token 数量各不相同。这意味着我们不能简单地使用一个统一大小的矩阵乘法(GEMM),而是需要执行一组大小各异的 GEMM——即 grouped GEMM在 Megatron-LM 中,通常通过调用 Transformer Engine(TE)的 grouped GEMM 来完成这一计算,具体是调用 TE 的 pytorch.GroupedLinear 类。

TE 实现 grouped GEMM 主要通过两种后端:

  1. cuBLAS 后端:当前 cuBLAS 实际上并没有提供真正的单次 kernel launch grouped GEMM 接口。TE 的做法是循环调用多次 cuBLAS kernel,并将它们分发到不同的 CUDA stream 上并行执行(见下图左侧)。这种 multi-stream 方式虽然能实现并行,但我们观察到在 B 系列 GPU 上 kernel launch 的 overhead 依然很大,不同 stream 之间并没有很好地 overlap 起来。
  2. CUTLASS 后端:CUTLASS 原生支持单次 kernel launch 的 grouped GEMM(见下图右侧),所有 expert 的计算在一次 kernel 调用中完成,减少了 launch 开销。 grouped GEMM backend

我们可以通过环境变量 NVTE_USE_CUTLASS_GROUPED_GEMM 来控制使用 cuBLAS 还是 CUTLASS 后端。但在 TE 2.12 版本中,这个环境变量仅对 H 系列 GPU(Hopper)生效,在 B 系列(Blackwell)上即使设置了也会 fallback 到 cuBLAS。值得注意的是,在 H 系列芯片上 multi-stream 方式的性能反而更优,而在 B 系列芯片上 cuBLAS 同样更快。因此,我们后续讨论的实现方案以 cuBLAS 作为后端

GPU-CPU 同步的根因:torch.split

尽管 cuBLAS 后端在性能上有优势,但 TE 2.12 在启动 grouped GEMM 时会触发一次关键的 GPU-CPU 同步,其根源在于 torch.split 这个 API。在 TE 2.12 的代码中,有这样一行:

inputmats = torch.split(cast_if_needed(inp_view, activation_dtype), m_splits_list)

torch.split 要求 m_splits_list 参数必须是 CPU 上的 Python list。它会将输入 tensor inp_view 按照 m_splits_list 中指定的大小切分成多个子 tensor,每个子 tensor 对应一个 expert 的输入。

m_splits_list 就是 tokens_per_expert——每个 expert 分到的 token 数量。这个信息最初是在 GPU 上通过 routing 计算得到的,因此在 Megatron-LM 中,需要通过 .cpu().tolist() 将 GPU tensor 转换为 CPU 上的 Python list。正是这个 .cpu() 操作触发了隐式的 GPU-CPU 同步——CPU 必须等待 GPU 上所有先前提交的操作完成,才能读取到正确的 tokens_per_expert 值。

从另一个角度看,在 TE 2.12 版本中,m_splits 虽然作为参数传递到了 C++ 层面的 te_general_grouped_gemm,但实际上并未被使用——因为切分信息已经通过 torch.split 编码在了 A[i].shape[0] 中。这意味着 m_splits 这个变量是完全冗余的

解决方案:GPU 端参数设置 kernel

在 CUDA Graph 的要求下,我们不能有任何 GPU-CPU 同步。因此,我们需要想办法将 torch.split 完全搬到 GPU 上。解决思路是m_splits 始终保留在 GPU 上,通过一个轻量级 GPU kernel 完成原本由 torch.split 承担的参数配置工作

具体实现可以参考这个文件,它完成了 TE 2.12 尚未实现的两项关键功能:

  1. 在 B 系列(Blackwell/SM100)上支持 CUTLASS 作为 grouped GEMM 后端
  2. 消除 CPU-GPU 同步,完全在 GPU 端完成类似 torch.split 的功能

对于第 2 点,核心思想是在启动 CUTLASS Grouped GEMM 之前,先启动一个轻量级 GPU kernel(setGroupedGemmArguments_fp16bf16),它直接从 GPU 内存中读取 m_splits 并配置每个 expert 的 GEMM 参数。整个过程无需 CPU 参与,因此不会触发 CPU-GPU 同步,使得 MoE 层可以完全被 CUDA Graph 捕获:

__global__ void setGroupedGemmArguments_fp16bf16(int num_experts, const int64_t *gemm_m_per_expert,
                                        int gemm_n, int gemm_k, ElementA *ptr_A, ElementD *ptr_D,
                                        UnderlyingProblemShape *problem_sizes,
                                        ElementA **ptr_A_list, StrideA *stride_A_list, StrideB *stride_B_list,
                                        ElementD **ptr_D_list, StrideD *stride_D_list) {
  uint64_t m_offset = 0;
  if (threadIdx.x == 0 && blockIdx.x == 0) {  // 只用一个线程执行
    for (int expert_id = 0; expert_id < num_experts; expert_id++) {
      int gemm_m = int(gemm_m_per_expert[expert_id]);  // <-- 直接从GPU读取m_splits
      problem_sizes[expert_id] = cute::make_shape(gemm_m, gemm_n, gemm_k);

      ptr_A_list[expert_id] = ptr_A + m_offset * gemm_k;  // 计算每个expert的A指针
      stride_A_list[expert_id] = cute::make_stride(int64_t(gemm_k), _1{}, _0{});
      // ...
      ptr_D_list[expert_id] = ptr_D + m_offset * gemm_n;  // 计算每个expert的D指针
      stride_D_list[expert_id] = cute::make_stride(int64_t(gemm_n), _1{}, _0{});

      m_offset += gemm_m;  // 累加偏移
    }
  }
}

对于第 1 点(在 B 系列上实现 CUTLASS grouped GEMM),相关函数 generic_moe_gemm_kernelLauncher_fp16bf16 的实现主要分为 4 个阶段:

  • 阶段 1:定义 CUTLASS 类型。使用 CUTLASS 的 Builder 模式定义高性能 Grouped GEMM kernel 的所有类型参数,包括矩阵元素类型、布局、Tile 大小和调度策略等。
// 定义 GEMM 问题形状
using ProblemShape = cutlass::gemm::GroupProblemShape<Shape<int, int, int>>;  // <M,N,K>

// 配置矩阵类型和布局
using ElementA = ElementInput;  // FP16 或 BF16
using LayoutA = cutlass::layout::RowMajor;
using ElementAccumulator = float;  // 累加器用 FP32

// 核心配置:针对 SM100 (Blackwell) 架构
using ArchTag = cutlass::arch::Sm100;
using MmaTileShape = Shape<_256, _256, Int<128 / sizeof(ElementA)>>;  // Tile 大小
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmSm100;  // TMA调度

// 构建 GEMM kernel
using GemmGrouped = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel2SM>;
  • 阶段 2:Workspace 内存布局。在预分配的 GPU workspace 中划分内存区域,用于存储各 expert 的 GEMM 参数。这些参数将在阶段 3 中由 GPU kernel 填充。
// 在 workspace 中分配各种指针数组和 stride 数组
auto ptr_A_list = ...;     // 每个 expert 的 A 矩阵指针
auto ptr_D_list = ...;     // 每个 expert 的 D(输出)矩阵指针
auto stride_A_list = ...;  // 每个 expert 的 A stride
auto stride_B_list = ...;  // 每个 expert 的 B stride
auto stride_D_list = ...;  // 每个 expert 的 D stride
auto problem_sizes = ...;  // 每个 expert 的 GEMM shape (M, N, K)
  • 阶段 3:启动参数设置 Kernel。这是消除 CPU-GPU 同步的关键步骤。启动一个轻量级 GPU kernel,在 GPU 上直接读取 gemm_m_per_expert(即 m_splits),并填充每个 expert 的 GEMM 配置:
  • problem_sizes[i] = (M_i, N, K) —— 每个 expert 的 GEMM 形状
  • ptr_A_list[i] —— 每个 expert 的输入指针
  • ptr_D_list[i] —— 每个 expert 的输出指针
  • 各种 stride
setGroupedGemmArguments_fp16bf16<<<1, 32, 0, stream>>>(
    num_experts, gemm_m_per_expert,  // <-- m_splits (GPU tensor)
    gemm_n, gemm_k, ptr_A, ptr_D, problem_sizes,
    ptr_A_list, stride_A_list, stride_B_list,
    ptr_D_list, stride_D_list);
  • 阶段 4:启动 CUTLASS Grouped GEMM。使用阶段 3 在 GPU 上设置好的参数,启动 CUTLASS Grouped GEMM kernel。由于所有参数都已在 GPU 端准备就绪,整个过程不需要任何 CPU-GPU 同步。
// 构建 CUTLASS 参数
args = typename GemmGrouped::Arguments{
    cutlass::gemm::GemmUniversalMode::kGrouped,
    {num_experts, problem_sizes, nullptr},  // 问题形状(在 GPU 上)
    {ptr_A_list, stride_A_list, ptr_B_list, stride_B_list},  // 输入
    {fusion_args, nullptr, stride_D_list, ptr_D_list, stride_D_list},  // 输出
    hw_info, scheduler
};

// 初始化并运行
gemm.initialize(args, workspace + offset);
gemm.run(stream);  // 执行 Grouped GEMM