triton.language.sync_block_set
1. 函数概述
显式的核心间同步指令,用于协调 Cube-Vector 架构中不同核心间的执行顺序和数据一致性。
2. sync_block_set 操作
2.1 函数概述
生产者核心完成任务后,向消费者发送同步信号。
triton.language.sync_block_set(sender, receiver, event_id, _builder=None)
2.2 规格
2.2.1 参数说明
参数 |
类型 |
默认值 |
含义说明 |
|---|---|---|---|
|
|
必需 |
发送方核心类型:"cube" 或 "vector" |
|
|
必需 |
接收方核心类型:"cube" 或 "vector" |
|
|
必需 |
事件ID,用于区分不同的同步点 |
|
- |
|
保留参数,暂不支持外部调用 |
2.2.2 特殊限制说明
sender和receiver不能相同,不能自己给自己发信号event_id必须在 0-15 范围内(共16个独立事件)
3. sync_block_wait 操作
3.1 函数概述
消费者核心等待生产者的同步信号。
triton.language.sync_block_wait(sender, receiver, event_id, _builder=None)
3.2 规格
3.2.1 参数说明
参数 |
类型 |
默认值 |
含义说明 |
|---|---|---|---|
|
|
必需 |
发送方核心类型:"cube" 或 "vector" |
|
|
必需 |
接收方核心类型:"cube" 或 "vector" |
|
|
必需 |
等待的事件ID |
|
- |
|
保留参数,暂不支持外部调用 |
3.2.2 特殊限制说明
sender和receiver不能相同event_id必须与对应sync_block_set使用的 ID 一致
4. sync_block_all 操作
4.1 函数概述
全局屏障同步,让所有指定类型的核心同步到同一点。
triton.language.sync_block_all(mode, event_id, _builder=None)
4.2 规格
4.2.1 参数说明
参数 |
类型 |
默认值 |
含义说明 |
|---|---|---|---|
|
|
必需 |
同步模式:"all_cube"、"all_vector" 或 "all" |
|
|
必需 |
全局同步事件ID |
|
- |
|
保留参数,暂不支持外部调用 |
4.2.2 特殊限制说明
mode必须为 "all_cube"、"all_vector" 或 "all" 之一event_id必须在 0-15 范围内
5. 使用方法
5.1 基础使用示例
import triton
import triton.language as tl
import triton.language.ascend as al
@triton.jit
def sync_example():
# Cube 核心计算并通知 Vector
with al.Scope(core_mode="cube"):
# ... 执行 Cube 计算 ...
tl.sync_block_set("cube", "vector", 0)
# Vector 核心等待 Cube 完成
with al.Scope(core_mode="vector"):
tl.sync_block_wait("cube", "vector", 0)
# ... 执行 Vector 计算 ...
5.2 Flash Attention 流水线示例
@triton.jit
def flash_attention_fwd(q_ptr, k_ptr, v_ptr, o_ptr, ...):
acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)
with al.Scope(core_mode="cube"):
for start_n in range(0, N, BLOCK_N):
qk = tl.dot(q, k)
tl.sync_block_set("cube", "vector", 0)
tl.sync_block_wait("vector", "cube", 1)
pv = tl.dot(p, v)
tl.sync_block_set("cube", "vector", 2)
with al.Scope(core_mode="vector"):
for start_n in range(0, N, BLOCK_N):
tl.sync_block_wait("cube", "vector", 0)
m_new, l_new, softmax_out = _softmax(qk, m_prev, l_prev)
tl.sync_block_set("vector", "cube", 1)
tl.sync_block_wait("cube", "vector", 2)
acc = _update_output(pv, softmax_out, acc)
with al.Scope(core_mode="cube"):
tl.sync_block_all("all", 0)
tl.store(o_ptr + offsets, acc)