深入理解 Megatron-LM 中的 Full CUDA Graph:MoE 模型训练加速的关键技术

 

从Partial CUDA Graph到Full CUDA Graph

在上一篇深入理解 Megatron-LM 中的 Partial CUDA Graph:MoE 模型训练加速的关键技术中,我们分析了在MoE训练流程中,影响CUDA Graph兼容性的主要瓶颈是MoE部分的动态行为。在本文中,我们将深入探讨,为了使MoE部分能够无缝地融入CUDA Graph,并最终实现包含前向与反向传播在内的完整iteration的Full CUDA Graph捕获,需要对其结构做出哪些关键性改进。

具体而言,MoE与CUDA Graph的兼容性亟需解决以下三项核心技术难题:

  1. 排除TE grouped GEMM中的CPU-GPU同步环节,实现与CUDA Graph的原生兼容;
  2. 对HybridEP策略进行重构,使其运行逻辑与CUDA Graph协同无障碍; 实际上,完成上述两项优化后,在MoE强制负载均衡(load balancing)的设定下,已能够实现真正意义上的Full CUDA Graph。然而,由于HybridEP与CUDA Graph兼容性提升带来的静态缓冲区约束,我们还需进一步实现高效的Expert权重分发及梯度聚合机制,以充分释放CUDA Graph的性能潜力。

如何让TE启动grouped GEMM不需要GPU-CPU同步?

在MoE中,由于每个experts分到的token数量不同,实际上我们是在做grouped GEMM,通常我们会调用transformer engine (TE)的grouped GEMM,也就是在调用TE的pytorch.GroupedLinear类。 一般来说,TE实现grouped GEMM主要是通过两个后端实现方式。

  1. 调用cuBLAS作为后端实现方式。但是当前cuBLAS实际上并没有真正实现单个kernel launch grouped GEMM的形式,实际上TE是循环调用多次cuBLAS kernel并分发到不同CUDA stream上实现的(见下图的左图)。
  2. 另外一个后端实现方式是cutlass。cutlass是有单个kernel launch grouped GEMM的支持的(见下图的右图)。 grouped GEMM backend

我们可以通过环境变量NVTE_USE_CUTLASS_GROUPED_GEMM来控制使用cuBLAS还是cutlass。但是目前至少在TE 2.12版本里面,这个环境变量只对H系列GPU起作用,在B系列上即使设置了这个环境变量也还是会fallback到使用cuBLAS中。在H系列芯片中,multi-stream的方式做grouped GEMM更快,但是在B系列芯片上,cuBLAS更快,因此我们后面的实现都考虑是cuBLAS作为后端。

TE 2.12启动grouped GEMM会有GPU-CPU同步主要是torch.split这个API带来的。在TE 2.12版本版本的代码中,有这么一项

inputmats = torch.split(cast_if_needed(inp_view, activation_dtype), m_splits_list)

torch.split要求m_splits_list这一项必须在CPU上。torch.split(inp_view, m_splits_list) 会将输入 tensor inp_view 按照 m_splits_list 中指定的大小切分成多个子 tensor,每个子 tensor 对应一个 expert 的输入。

m_splits_list就是tokens_per_expert,因此在Megatron-LM中,需要将tokens_per_expert从GPU转到CPU list。tokens_per_expert需要通过 .cpu().tolist() 将GPU上的tensor转换为CPU上的Python list。这个操作会触发GPU-CPU同步。

从另外一个层面上看,在TE 2.12版本中,m_splits虽然作为了参数传到到C++层面的te_general_grouped_gemm,但是其实并没有用到,因为信息已经通过torch.split编码在 A[i].shape[0] 中。 但是在CUDA Graph的要求下,我们不能有GPU-CPU同步。因此我们需要把m_splits 作为 GPU tensor 传给 kernel,通过GPU操作完成类似torch.split的功能。

具体实现可以参考这个文件,此文件完成了TE 2.12还没有做到事情,·1)实现了在B系列上支持cuBLAS作为后端,2)不需要CPU-GPU同步来完成类似torch.split的功能。

对于第2点,其核心思想是,在启动 CUTLASS Grouped GEMM前,用一个轻量级 GPU kernel (setGroupedGemmArguments_fp16bf16) 来读取 GPU 上的 m_splits 并配置 GEMM 参数,整个过程无需 CPU 参与,因此不会触发 CPU-GPU 同步,使得 MoE 层可以完全被 CUDA Graph 捕获。

__global__ void setGroupedGemmArguments_fp16bf16(int num_experts, const int64_t *gemm_m_per_expert,
                                        int gemm_n, int gemm_k, ElementA *ptr_A, ElementD *ptr_D,
                                        UnderlyingProblemShape *problem_sizes,
                                        ElementA **ptr_A_list, StrideA *stride_A_list, StrideB *stride_B_list,
                                        ElementD **ptr_D_list, StrideD *stride_D_list) {
  uint64_t m_offset = 0;
  if (threadIdx.x == 0 && blockIdx.x == 0) {  // 只用一个线程执行
    for (int expert_id = 0; expert_id < num_experts; expert_id++) {
      int gemm_m = int(gemm_m_per_expert[expert_id]);  // <-- 直接从GPU读取m_splits
      problem_sizes[expert_id] = cute::make_shape(gemm_m, gemm_n, gemm_k);

      ptr_A_list[expert_id] = ptr_A + m_offset * gemm_k;  // 计算每个expert的A指针
      stride_A_list[expert_id] = cute::make_stride(int64_t(gemm_k), _1{}, _0{});
      // ...
      ptr_D_list[expert_id] = ptr_D + m_offset * gemm_n;  // 计算每个expert的D指针
      stride_D_list[expert_id] = cute::make_stride(int64_t(gemm_n), _1{}, _0{});

      m_offset += gemm_m;  // 累加偏移
    }
  }
}

对于第一点(实现了在B系列上支持cuBLAS作为后端),相关函数的函数名为generic_moe_gemm_kernelLauncher_fp16bf16,主要分为 4 个主要阶段:

  • 阶段 1:定义CUTLASS 类型。使用 CUTLASS 的 Builder 模式定义高性能 Grouped GEMM kernel 的所有类型参数。
// 定义 GEMM 问题形状
using ProblemShape = cutlass::gemm::GroupProblemShape<Shape<int, int, int>>;  // <M,N,K>

// 配置矩阵类型和布局
using ElementA = ElementInput;  // FP16 或 BF16
using LayoutA = cutlass::layout::RowMajor;
using ElementAccumulator = float;  // 累加器用 FP32

// 核心配置:针对 SM100 (Blackwell) 架构
using ArchTag = cutlass::arch::Sm100;
using MmaTileShape = Shape<_256, _256, Int<128 / sizeof(ElementA)>>;  // Tile 大小
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmSm100;  // TMA调度

// 构建 GEMM kernel
using GemmGrouped = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel2SM>;
  • 阶段 2:Workspace 内存布局。在预分配的 GPU workspace 中划分内存区域,用于存储各 expert 的参数。
// 在 workspace 中分配各种指针数组和 stride 数组
auto ptr_A_list = ...;     // 每个 expert 的 A 矩阵指针
auto ptr_D_list = ...;     // 每个 expert 的 D(输出)矩阵指针
auto stride_A_list = ...;  // 每个 expert 的 A stride
auto stride_B_list = ...;  // 每个 expert 的 B stride
auto stride_D_list = ...;  // 每个 expert 的 D stride
auto problem_sizes = ...;  // 每个 expert 的 GEMM shape (M, N, K)
  • 阶段 3:启动参数设置 Kernel。启动一个小型 GPU kernel,在 GPU 上读取 gemm_m_per_expert(m_splits),并填充:
  • problem_sizes[i] = (M_i, N, K) —— 每个 expert 的 GEMM 形状
  • ptr_A_list[i] —— 每个 expert 的输入指针
  • ptr_D_list[i] —— 每个 expert 的输出指针
  • 各种 stride
setGroupedGemmArguments_fp16bf16<<<1, 32, 0, stream>>>(
    num_experts, gemm_m_per_expert,  // <-- m_splits (GPU tensor)
    gemm_n, gemm_k, ptr_A, ptr_D, problem_sizes,
    ptr_A_list, stride_A_list, stride_B_list,
    ptr_D_list, stride_D_list);
  • 阶段 4:启动 CUTLASS Grouped GEMM。使用阶段 3 在 GPU 上设置好的参数,启动 CUTLASS Grouped GEMM kernel。
// 构建 CUTLASS 参数
args = typename GemmGrouped::Arguments{
    cutlass::gemm::GemmUniversalMode::kGrouped,
    {num_experts, problem_sizes, nullptr},  // 问题形状(在 GPU 上)
    {ptr_A_list, stride_A_list, ptr_B_list, stride_B_list},  // 输入
    {fusion_args, nullptr, stride_D_list, ptr_D_list, stride_D_list},  // 输出
    hw_info, scheduler
};

// 初始化并运行
gemm.initialize(args, workspace + offset);
gemm.run(stream);  // 执行 Grouped GEMM

HybridEP支持CUDA Graph后有什么限制?

Megatron的Full CUDA Graph支持

另外一个问题时,Megatron的Full CUDA Graph是否支持PP?,实际上Full CUDA Graph捕获的是整个 forward_backward_func 的执行,而不是单个 layer 或单个 stage。forward_backward_func 内部会处理所有 PP stages 的前向传播、反向传播以及 microbatch 的调度(如 1F1B 调度)。因此一个 CUDA Graph 就包含了所有 PP stages 的计算,所有 microbatches 的处理以及PP stages 之间的通信(send/recv)。

        if curr_iteration == self.cuda_graph_warmup_steps:
            logger.info(f'Capture CUDA graph for {training_str}!!!')
            torch.distributed.barrier()
            assert FullCudaGraphWrapper.cuda_graph[training_str] is None
            FullCudaGraphWrapper.cuda_graph[training_str] = torch.cuda.CUDAGraph()
            # ... 注册 RNG states
            with torch.cuda.graph(
                FullCudaGraphWrapper.cuda_graph[training_str],
                stream=capture_stream,
                capture_error_mode="thread_local",
            ):
                # 捕获整个 forward_backward_func,包含所有 PP stages
                FullCudaGraphWrapper.result[training_str] = self.forward_backward_func(
                    *args, **kwargs
                )

同时我们也能看到,在数据读取上是支持多stage PP的:

    def data_read(self, data_iterator, model, training, num_microbatches):
        """Read all microbatch inputs from Dataloader and copy to static buffers."""
        if not isinstance(model, list) or len(model) == 1:
            # 单 stage 场景(无 PP 或 PP size = 1)
            assert not isinstance(data_iterator, list) or len(data_iterator) == 1
            # ... 处理单个 data_iterator
        else:
            # 多 stage 场景(PP size > 1)
            assert isinstance(data_iterator, list) and len(data_iterator) == len(model)
            data_list = []
            for i in range(len(model)):
                if data_iterator[i] is not None:
                    # 为每个 PP stage 分别读取 microbatch 数据
                    data_list_i = []
                    for b in range(num_microbatches):
                        data_list_i.append(...)
                    data_list.append(iter(data_list_i))
                else:
                    data_list.append(None)

因此无论 PP 有多少个 stage,只会创建 2 个 CUDA Graph:1 个用于 training,1 个用于 validation。

    curr_iteration = {'training': 0, 'validation': 0}
    cuda_graph = {'training': None, 'validation': None}
    result = {'training': None, 'validation': None}

Expert权重分发与梯度收集机制

sequenceDiagram
    participant Input as Hidden States
    participant Router
    participant Planner as Offloading Planner
    participant ExpertDisp as Expert Dispatcher
    participant TokenDisp as Token Dispatcher
    participant Experts
    participant Output

    Input->>Router: hidden_states
    Router->>Planner: routing_map, probs
    Planner->>Planner: gen_offloading_plan()
    Planner->>ExpertDisp: expert_offloading_map
    ExpertDisp->>Experts: dispatch weights to echo experts
    Input->>TokenDisp: hidden_states, rerouted_probs
    TokenDisp->>Experts: dispatched tokens
    Experts->>TokenDisp: expert outputs
    TokenDisp->>Output: combined output

如何决定home experts复制到哪些redundant expert slots?

gen_intermediate是为了计算每个EP rank的空闲容量和每个local expert的计算溢出量:

def gen_intermediate(count_tokens_per_expert_from_ep_rank, ...):
    # 步骤1: 计算每个EP rank的token总数和平均值
    count_tokens_per_ep_rank = count_tokens_per_expert.view(num_ep_ranks, -1).sum(dim=1)
    avg_tokens_per_ep_rank = count_tokens_per_ep_rank.sum() // num_ep_ranks
    
    # 步骤2: 计算spare容量 = max(0, avg - current)
    # 负载低于平均的EP rank有空闲容量接收tokens
    deviation = count_tokens_per_ep_rank - avg_tokens_per_ep_rank
    capacity_spare_per_ep_rank = torch.relu(-deviation)
    
    # 步骤3: 计算spillover(溢出量)
    # 关键思路:对每个EP rank内的专家按token数排序,
    # 累积求和后超过平均值的部分就是spillover
    count_tokens_sorted, indices_sorted = count_tokens_per_expert.view(num_ep_ranks, -1).sort(dim=1)
    spillover_cumsum = (count_tokens_sorted.cumsum(dim=1) - avg_tokens_per_ep_rank).clamp(min=0)
    # 从cumsum转回每个专家的spillover
    count_spillover_sorted = torch.cat([spillover_cumsum[:, :1], 
                                         torch.diff(spillover_cumsum, dim=1)], dim=1)

capacity_spare_per_ep_rank是一个EP Rank层级的变量,表示当前EP Rank超出了多少token需要重新分配,这个很好理解:

                    avg = 350
                        │
EP0: 500 ■■■■■■■■■■■■■■■│■■■■  超额 150 → spillover
EP1: 200 ■■■■■■         │      空闲 150 → spare capacity  
EP2: 300 ■■■■■■■■■      │      空闲  50 → spare capacity
EP3: 400 ■■■■■■■■■■■■   │■     超额  50 → spillover
                        │

okk。现在我们知道了本地EP rank应该要offload多少token出来,但是一个EP rank有多个local expert,我们怎么知道每个expert应该要分多少token出来。使得这个数刚好等于本地的capacity_spare_per_ep_rank?

计算spillover的直观理解:假设一个 EP rank 有 4 个专家,tokens 分布为 [50, 100, 150, 200],平均值为 250(注意这个平均值是所有EP rank的平均值):

  • 排序后: [50, 100, 150, 200]
  • 累积和: [50, 150, 300, 500]
  • 减去平均值后: [-200, -100, 50, 250]
  • clamp(min=0): [0, 0, 50, 250]
  • 差分得到spillover: [0, 0, 50, 200]。 我们看到,token数总和为1000,

一个要思考的问题为什么我们这里要先做累积和然后做差分得到spillover?直接tokens 分布减去平均值可以吗? 我们先考虑直接减法有什么问题:

spillover = (tokens_per_expert - avg_per_expert).clamp(min=0)
# avg_per_expert = 250 / 4 = 62.5

# 结果:
专家0: (50 - 62.5).clamp(0) = 0
专家1: (100 - 62.5).clamp(0) = 37.5
专家2: (150 - 62.5).clamp(0) = 87.5
专家3: (200 - 62.5).clamp(0) = 137.5
───────────────────────────────────
总 spillover: 262.5 tokens  ← 超过了需要的 250!

这样的问题是这样计算会导致 spillover 过多,因为每个专家独立判断”超额”。

如果我们采用前缀和加上差分的思想:

# Step 1: 排序(从小到大)
sorted_tokens = [50, 100, 150, 200]

# Step 2: 累积和
cumsum = [50, 150, 300, 500]

# Step 3: 减去平均值并 clamp
spillover_cumsum = ([50, 150, 300, 500] - 250).clamp(min=0)
                 = [-200, -100, 50, 250].clamp(min=0)
                 = [0, 0, 50, 250]

# Step 4: 差分得到每个专家的 spillover
spillover = [0, 0-0, 50-0, 250-50]
          = [0, 0, 50, 200]

 spillover: 250 tokens   正好等于超额部分

其核心思想是当累积和超过平均值时,才开始 spillover;同时差分将”整体超额”按贡献分配给各专家。

tokens:     50      100      150      200
            │        │        │        │
cumsum:     50 ──── 150 ──── 300 ──── 500
            │        │        │        │
            │        │        │        │
avg=250 ────│────────│────────┼────────│────
            │        │        ↑        ↑
            │        │     超出50   超出250
            │        │        │        │
spillover:  0        0       50      200
                             └────────┘
                             差分: 250-50=200

另外,我们排序为了让负载轻的专家优先保留自己的 tokens:

未排序 [200, 50, 150, 100]:
cumsum = [200, 250, 400, 500]
spillover_cumsum = [0, 0, 150, 250]  ← 从专家0开始就接近平均值
spillover = [0, 0, 150, 100]

排序后 [50, 100, 150, 200]:
cumsum = [50, 150, 300, 500]
spillover_cumsum = [0, 0, 50, 250]   ← 更晚才超过平均值
spillover = [0, 0, 50, 200]

排序后,小专家贡献的 spillover 更少,大专家贡献更多,这样可以更公平的负载均衡。

有了这个数值之后,我们使用贪心分配算法one_shot_greedy_assignment:来求解这个问题:

def one_shot_greedy_assignment(count_tokens_per_chunk, capacity_per_bucket):
    """
    使用区间重叠计算实现一次性贪心分配
    
    原理:
    - 将所有token chunks看作连续区间 [start, end)
    - 将所有bucket容量看作连续区间 [start, end)
    - 计算每对 (chunk, bucket) 的区间重叠作为分配量
    """
    # 计算累积和得到区间端点
    chunks_cumsum = torch.cumsum(count_tokens_per_chunk, dim=0)
    buckets_cumsum = torch.cumsum(capacity_per_bucket, dim=0)
    
    chunk_start = chunks_cumsum - count_tokens_per_chunk  # [0, c1, c1+c2, ...]
    chunk_end = chunks_cumsum                              # [c1, c1+c2, ...]
    bucket_start = buckets_cumsum - capacity_per_bucket   
    bucket_end = buckets_cumsum
    
    # 计算重叠
    overlap_start = torch.maximum(chunk_start[:, None], bucket_start[None, :])
    overlap_end = torch.minimum(chunk_end[:, None], bucket_end[None, :])
    assignment = (overlap_end - overlap_start).clamp(min=0)

举个例子:chunks = [100, 150], buckets = [80, 120]

chunks:  |---100---|---150---|
buckets: |--80--|----120----|

cumsum:  [100, 250] 和 [80, 200]
chunk区间: [0,100), [100,250)
bucket区间: [0,80), [80,200)

重叠矩阵:
         bucket0   bucket1
chunk0   [0,80)∩[0,100)=80   [80,100)∩[100,250)=0  → [80, 0]
chunk1   [0,80)∩[100,250)=0  [80,200)∩[100,250)=100 → [0, 100]
# 每个专家的 spillover(已计算好)
count_spillover_per_home_expert = [0, 80, 0, 0, 50, 100, 0, 30]
# 专家编号:                        0   1   2   3   4    5   6   7

# 每个 EP rank 的 spare capacity
capacity_spare_per_ep_rank = [0, 120, 60, 0]
# EP rank:                    0    1   2   3
# 对 spillover 降序排序
count_spillover_sorted = [100, 80, 50, 30, 0, 0, 0, 0]
indices_spillover_sort = [5,   1,  4,  7, 0, 2, 3, 6]  # 原始专家编号

# 对 spare capacity 降序排序  
capacity_spare_sorted = [120, 60, 0, 0]
indices_spare_sort = [1,   2,  0, 3]  # 原始 EP rank 编号

为什么排序? 把最大的 spillover 和最大的 capacity 排在前面,方便贪心分配。

这是核心算法,使用区间重叠技术:

def one_shot_greedy_assignment(chunks, buckets):
    """
    chunks = spillover_sorted = [100, 80, 50, 30, 0, 0, 0, 0]  (待分配的量)
    buckets = capacity_sorted = [120, 60, 0, 0]                 (可接收的容量)
    """
chunks 累积和:  [100, 180, 230, 260, 260, 260, 260, 260]
buckets 累积和: [120, 180, 180, 180]

chunks 区间:
  chunk 0: [0, 100)      spillover=100
  chunk 1: [100, 180)    spillover=80
  chunk 2: [180, 230)    spillover=50
  chunk 3: [230, 260)    spillover=30
  chunk 4-7: 空 (spillover=0)

buckets 区间:
  bucket 0: [0, 120)     capacity=120
  bucket 1: [120, 180)   capacity=60
  bucket 2-3: 空 (capacity=0)
                0        100       120       180       230      260
                │         │         │         │         │        │
chunks:         │◄─chunk0─┼─────────►◄chunk1─►│◄chunk2─►│◄chunk3►│
                │   100   │         │   80    │   50    │   30   │
                │         │         │         │         │        │
buckets:        │◄──────bucket0────►│◄bucket1►│         │        │
                │        120        │   60    │         │        │
                │         │         │         │         │        │

重叠矩阵 [8 experts × 4 ep_ranks]:

              bucket0   bucket1   bucket2   bucket3
              [0,120)   [120,180) [180,180) [180,180)
chunk0 [0,100)    100        0        0        0
chunk1 [100,180)   20       60        0        0
chunk2 [180,230)    0        0        0        0    (bucket1 结束于 180)
chunk3 [230,260)    0        0        0        0
chunk4-7           0        0        0        0

assignment_sorted = 
[[100,  0, 0, 0],
 [ 20, 60, 0, 0],
 [  0,  0, 0, 0],
 [  0,  0, 0, 0],
 [  0,  0, 0, 0],
 [  0,  0, 0, 0],
 [  0,  0, 0, 0],
 [  0,  0, 0, 0]]

当 num_spare_experts_per_ep_rank > 1 时,每个 EP rank 可以接收多个专家的 offload:

假设 num_spare_experts_per_ep_rank = 2:

EP1 有 2 个 spare experts,可以接收 2 个不同专家的 offload
→ 从 assignment 矩阵的 EP1 列中选 top-2 个最大值
→ 得到 2 个 (home_expert, offload_amount) 对

gen_assignment / one_shot_greedy_assignment 只决定了:

  • 专家级别的分配:”专家 A 要 offload 100 tokens 到 spare expert X” 但它没有决定:
  • EP rank 级别的分配:”专家 A 的 100 tokens 中,多少来自 EP0,多少来自 EP1…” 因为同一个专家可能收到来自多个 EP ranks 的 tokens!
                     专家 A (home expert)
                     总 spillover = 100
                           │
         ┌─────────────────┼─────────────────┐
         │                 │                 │
      来自 EP0          来自 EP1          来自 EP2
       30 tokens        50 tokens        20 tokens
         │                 │                 │
         └─────────────────┼─────────────────┘
                           │
                           ▼
                   需要决定怎么分配!

两阶段分配流程

┌─────────────────────────────────────────────────────────────────────────────┐
│                          两阶段分配流程                                       │
├─────────────────────────────────────────────────────────────────────────────┤
│                                                                              │
│  输入: count_tokens_per_expert_from_ep_rank [ep_size, num_experts]           │
│        count_tokens_from_home_expert_to_spare_expert [num_experts, num_spare]│
│                                                                              │
│  ┌─────────────────────────────────────────────────────────────────┐        │
│  │ Phase 1: breadth_first_allocation (广度优先)                      │        │
│  │                                                                   │        │
│  │ 策略:按比例分配,让每个 EP rank 公平贡献                          │        │
│  │                                                                   │        │
│  │ EP0 贡献: 30/100 * capacity = 30% 的 offload 量                   │        │
│  │ EP1 贡献: 50/100 * capacity = 50% 的 offload 量                   │        │
│  │ EP2 贡献: 20/100 * capacity = 20% 的 offload 量                   │        │
│  │                                                                   │        │
│  │ 使用 floor() 取整,可能有剩余                                      │        │
│  └─────────────────────────────────────────────────────────────────┘        │
│                              │                                               │
│                              ▼                                               │
│  ┌─────────────────────────────────────────────────────────────────┐        │
│  │ Phase 2: depth_first_allocation (深度优先)                        │        │
│  │                                                                   │        │
│  │ 处理 Phase 1 的取整误差导致的剩余容量                              │        │
│  │ 使用区间重叠算法填满剩余空间                                        │        │
│  └─────────────────────────────────────────────────────────────────┘        │
│                              │                                               │
│                              ▼                                               │
│  输出: 每个 EP rank 具体 offload 多少 tokens 到每个 spare expert             │
│                                                                              │
└─────────────────────────────────────────────────────────────────────────────┘
3 个 EP ranks,专家 A 的 tokens 分布:
  - 来自 EP0: 30 tokens
  - 来自 EP1: 50 tokens  
  - 来自 EP2: 20 tokens
  - 总计: 100 tokens

gen_assignment 决定: 专家 A offload 80 tokens 到 spare expert X

Phase 1: breadth_first_allocation(按比例分配)

# 找到主要供应者(argmax)

idx_supplier = argmax([专家A的offload量]) = 专家A


# 计算每个 EP rank 的贡献比例
count_tokens_rel = [30, 50, 20] # 各 EP rank 发给专家 A 的 tokens
probs_proportional = [30/100, 50/100, 20/100] = [0.3, 0.5, 0.2]

# 按比例分配 capacity=80

count_tokens_ideal = [0.3*80, 0.5*80, 0.2*80] = [24, 40, 16]

# 取整(floor)
count_tokens_floors = [24, 40, 16] # 正好,无余数

结果是:

EP0 offload 24 tokens 到 spare X
EP1 offload 40 tokens 到 spare X
EP2 offload 16 tokens 到 spare X
─────────────────────────────────
总计: 80 tokens ✓

如果有取整误差?怎么办?

假设 capacity = 83(不能被整除)

count_tokens_ideal = [0.3*83, 0.5*83, 0.2*83] = [24.9, 41.5, 16.6]
count_tokens_floors = floor([24.9, 41.5, 16.6]) = [24, 41, 16]

总计 = 24 + 41 + 16 = 81 < 83

剩余容量 = 83 - 81 = 2 tokens

Phase 2: depth_first_allocation(填充剩余)

# 剩余容量
capacity_spare_remaining = 2

# 各 EP rank 剩余可 offload 的 tokens
# EP0: 30 - 24 = 6 tokens 还没 offload
# EP1: 50 - 41 = 9 tokens 还没 offload
# EP2: 20 - 16 = 4 tokens 还没 offload

# 使用区间重叠贪心分配
# 按 EP rank 顺序(EP0 → EP1 → EP2)填充剩余容量 2
# EP0 再贡献 2 tokens(因为 6 > 2)

second_pass_offload = [2, 0, 0]

最终结果是:

EP0 offload: 24 + 2 = 26 tokens
EP1 offload: 41 + 0 = 41 tokens
EP2 offload: 16 + 0 = 16 tokens
─────────────────────────────────
总计: 83 tokens ✓

为什么叫为什么叫”广度优先”和”深度优先”?

            专家A        专家B        专家C
              │            │            │
   ┌──────────┼──────────┐ │ ┌──────────┼──────────┐
   │          │          │ │ │          │          │
  EP0        EP1        EP2 ... ...

广度优先 (Breadth-First):
  先按比例分配所有专家的所有 EP ranks
  → 横向扫描,保证公平性

深度优先 (Depth-First):
  处理剩余,一个专家一个专家地填满
  → 纵向扫描,最大化利用率

两阶段设计保证了:

  1. 公平性:各 EP rank 按比例贡献(广度优先)
  2. 完整性:剩余容量被完全利用(深度优先)

每个EP rank都运行gen_offloading_plan算法,得到的结果是一样的吗?

每个 EP rank 都独立运行这个算法。关键在于输入的构成:

# Step 1: 局部信息 - 每个 EP rank 不同
tokens_per_expert_local = routing_map.sum(dim=0).long()  # 本 EP rank 的 token 分布

# Step 2: 全局信息 - 通过 AllGather 收集,所有 EP ranks 相同
tokens_per_expert_per_ep_rank = gather_from_sequence_parallel_region(
    tokens_per_expert_local,
    group=self.ep_group,
).reshape(ep_size, num_experts)  # [ep_size, num_experts]

# Step 3: 调用 gen_offloading_plan
rerouted_routing_map, rerouted_probs, expert_offloading_map = gen_offloading_plan(
    map_token_to_expert=routing_map,                      # 局部:每个 EP rank 不同
    probs_routing=probs,                                   # 局部:每个 EP rank 不同
    count_tokens_per_expert_from_ep_rank=tokens_per_expert_per_ep_rank,  # 全局:相同
    ep_rank=ep_rank,                                       # 不同:0, 1, 2, ...
    num_ep_ranks=ep_size,                                  # 相同
    ...
)

gen_offloading_plan内部,expert_offloading_map 所有 EP ranks 相同,因为这个映射只依赖于 count_tokens_per_expert_from_ep_rank(全局 token 分布),与局部 routing_map 无关。

# gen_offloading_plan 内部
count_tokens_from_home_expert_to_spare_expert = gen_assignment(
    count_tokens_per_expert_from_ep_rank,  # 全局相同的输入
    ...
)
map_home_expert_to_spare = count_tokens_from_home_expert_to_spare_expert > 0

但是rerouted_routing_maprerouted_probs每个 EP rank 不同,因为每个 EP rank 只重路由自己的 tokens。

# reroute_tokens_triton 使用局部信息
map_token_to_all_experts, probs_rerouted = reroute_tokens_triton(
    map_token_to_expert,                          # 局部 routing_map
    probs_routing,                                # 局部 probs
    count_tokens_offloaded_from_ep_rank[ep_rank], # 根据 ep_rank 选择
    count_tokens_offloaded_to_spare[ep_rank],     # 根据 ep_rank 选择
    ...
)