triton.language.sync_block_set

1. 函数概述

显式的核心间同步指令,用于协调 Cube-Vector 架构中不同核心间的执行顺序和数据一致性。

2. sync_block_set 操作

2.1 函数概述

生产者核心完成任务后,向消费者发送同步信号。

triton.language.sync_block_set(sender, receiver, event_id, _builder=None)

2.2 规格

2.2.1 参数说明

参数

类型

默认值

含义说明

sender

str

必需

发送方核心类型:"cube" 或 "vector"

receiver

str

必需

接收方核心类型:"cube" 或 "vector"

event_id

int

必需

事件ID,用于区分不同的同步点

_builder

-

None

保留参数,暂不支持外部调用

2.2.2 特殊限制说明

  1. senderreceiver 不能相同,不能自己给自己发信号

  2. event_id 必须在 0-15 范围内(共16个独立事件)

3. sync_block_wait 操作

3.1 函数概述

消费者核心等待生产者的同步信号。

triton.language.sync_block_wait(sender, receiver, event_id, _builder=None)

3.2 规格

3.2.1 参数说明

参数

类型

默认值

含义说明

sender

str

必需

发送方核心类型:"cube" 或 "vector"

receiver

str

必需

接收方核心类型:"cube" 或 "vector"

event_id

int

必需

等待的事件ID

_builder

-

None

保留参数,暂不支持外部调用

3.2.2 特殊限制说明

  1. senderreceiver 不能相同

  2. event_id 必须与对应 sync_block_set 使用的 ID 一致

4. sync_block_all 操作

4.1 函数概述

全局屏障同步,让所有指定类型的核心同步到同一点。

triton.language.sync_block_all(mode, event_id, _builder=None)

4.2 规格

4.2.1 参数说明

参数

类型

默认值

含义说明

mode

str

必需

同步模式:"all_cube"、"all_vector" 或 "all"

event_id

int

必需

全局同步事件ID

_builder

-

None

保留参数,暂不支持外部调用

4.2.2 特殊限制说明

  1. mode 必须为 "all_cube"、"all_vector" 或 "all" 之一

  2. event_id 必须在 0-15 范围内

5. 使用方法

5.1 基础使用示例

import triton
import triton.language as tl
import triton.language.ascend as al

@triton.jit
def sync_example():
    # Cube 核心计算并通知 Vector
    with al.Scope(core_mode="cube"):
        # ... 执行 Cube 计算 ...
        tl.sync_block_set("cube", "vector", 0)

    # Vector 核心等待 Cube 完成
    with al.Scope(core_mode="vector"):
        tl.sync_block_wait("cube", "vector", 0)
        # ... 执行 Vector 计算 ...

5.2 Flash Attention 流水线示例

@triton.jit
def flash_attention_fwd(q_ptr, k_ptr, v_ptr, o_ptr, ...):
    acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)

    with al.Scope(core_mode="cube"):
        for start_n in range(0, N, BLOCK_N):
            qk = tl.dot(q, k)
            tl.sync_block_set("cube", "vector", 0)
            tl.sync_block_wait("vector", "cube", 1)
            pv = tl.dot(p, v)
            tl.sync_block_set("cube", "vector", 2)

    with al.Scope(core_mode="vector"):
        for start_n in range(0, N, BLOCK_N):
            tl.sync_block_wait("cube", "vector", 0)
            m_new, l_new, softmax_out = _softmax(qk, m_prev, l_prev)
            tl.sync_block_set("vector", "cube", 1)
            tl.sync_block_wait("cube", "vector", 2)
            acc = _update_output(pv, softmax_out, acc)

    with al.Scope(core_mode="cube"):
        tl.sync_block_all("all", 0)

    tl.store(o_ptr + offsets, acc)