triton.language.get_element
1. OP 概述
简介:根据给定的索引,从输入张量中读取单个元素。 原型:
triton.language.get_element(
src,
indice,
_builder=None,
_generator=None
)→ scalar
可以作为tensor的成员函数调用,如x.get_element(...),与get_element(x, ...)等效。
2. OP 规格
2.1 参数说明
参数名 |
类型 |
说明 |
|---|---|---|
|
|
要被访问的源张量 |
|
|
用于指定元素位置的索引 |
|
- |
保留参数,暂不支持外部调用 |
|
- |
保留参数,暂不支持外部调用 |
返回值:
scalar:与 src 张量元素类型相同的标量值
2.2 支持规格
2.2.1 DataType 支持
int8 |
int16 |
int32 |
uint8 |
uint16 |
uint32 |
uint64 |
int64 |
fp16 |
fp32 |
bf16 |
bool |
|
|---|---|---|---|---|---|---|---|---|---|---|---|---|
Ascend A2/A3 |
√ |
√ |
√ |
√ |
√ |
√ |
√ |
√ |
√ |
√ |
√ |
× |
2.2.2 Shape 支持
支持任意形状的张量,但需满足:
indice 的长度必须与 src 张量的维度数相同。
2.3 特殊限制说明
无特殊限制
2.4 使用方法
以下示例实现了get_element的调用:
def index_select_manual_kernel(in_ptr, indices_ptr, out_ptr, dim,
g_stride: tl.constexpr, indice_length: tl.constexpr,
g_block: tl.constexpr, g_block_sub: tl.constexpr,
other_block: tl.constexpr):
"""
Manual implementation using tl.get_element and tl.insert_slice.
"""
g_begin = tl.program_id(0) * g_block
for goffs in range(0, g_block, g_block_sub):
g_idx = tl.arange(0, g_block_sub) + g_begin + goffs
g_mask = g_idx < indice_length
indices = tl.load(indices_ptr + g_idx, g_mask, other=0)
for other_offset in range(0, g_stride, other_block):
tmp_buf = tl.zeros((g_block_sub, other_block), in_ptr.dtype.element_ty)
other_idx = tl.arange(0, other_block) + other_offset
other_mask = other_idx < g_stride
# Manual gather: iterate over each index
for i in range(0, g_block_sub):
gather_offset = tl.get_element(indices, (i,)) * g_stride
val = tl.load(in_ptr + gather_offset + other_idx, other_mask)
tmp_buf = tl.insert_slice(tmp_buf, val[None, :],
offsets=(i, 0), sizes=(1, other_block), strides=(1, 1))
tl.store(out_ptr + g_idx[:, None] * g_stride + other_idx[None, :],
tmp_buf, g_mask[:, None] & other_mask[None, :])
3. 语义GAP
无语义差异