triton.language.load
1. OP 概述
原型:
triton.language.load(
pointer,
mask=None,
other=None,
boundary_check=(),
padding_option='',
cache_modifier='',
eviction_policy='',
volatile=False,
_semantic=None
)
简介:返回一个Tensor/Scalar,其值从GlobalMemory中pointer参数指向的位置加载。
2. OP 规格
2.1 参数说明
参数名 |
类型 |
说明 |
|---|---|---|
|
|
指向GM上待读取数据的指针 |
|
|
可选参数,当且仅当 |
|
|
可选参数,当且仅当 |
|
|
可选参数,当且仅当 |
|
|
可选参数,当且仅当 |
|
|
可选参数,控制NVIDIA PTX上的cache选项,对Ascend硬件无效 |
|
|
控制NVIDIA PTX的eviction策略, 对Ascend硬件无效 |
|
|
控制NVIDIA PTX的volatile选项, 对Ascend硬件无效 |
|
- |
保留参数,暂不支持外部调用 |
当前910代际均还不支持cache_modifier,eviction_policy, volatile等参数
2.2 支持规格
2.2.1 DataType 支持
int8 |
int16 |
int32 |
uint8 |
uint16 |
uint32 |
uint64 |
int64 |
fp16 |
fp32 |
fp64 |
bf16 |
bool |
|
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
GPU |
√ |
√ |
√ |
√ |
√ |
√ |
√ |
√ |
√ |
√ |
√ |
√ |
√ |
Ascend A2/A3 |
√ |
√ |
√ |
× |
× |
× |
× |
√ |
√ |
√ |
× |
√ |
√ |
结论:Ascend 对比 GPU 缺失uint8、uint16、uint32、uint64、fp64的支持能力(硬件限制)。
2.2.2 Shape 支持
支持维度范围 |
|
|---|---|
GPU |
支持scalar和1~5维tensor |
Ascend A2/A3 |
支持scalar和1~5维 tensor |
结论:在 Shape 方面,GPU 与 Ascend 平台无差异,均支持 1 至 5 维张量。
2.2.3 社区约束
若
pointer是一个单指针:此时
tl.load返回一个标量mask和other必须是标量other会隐式类型转换成pointer.dtype.element_ty的数据类型此时不允许传入
boundary_check和padding_option
若
pointer是一个N-Dimensional tensor:此时
tl.load返回一个与pointershape相同的N-Dimensional tensormask和other会隐式广播到和pointer相同的shape此时不允许传入
boundary_check和padding_option
若
pointer来自于tl.make_block_ptr:此时
mask和other必须是None此时可以通过
boundary_check和padding_option设置边界检查和越界补充值
2.3 特殊限制说明
相对社区能力缺失且做不到
Ascend 对比 GPU 缺失uint8、uint16、uint32、uint64、fp64的支持能力(硬件限制)。
差异点 |
描述 |
解决途径 |
|---|---|---|
不支持 |
当前使用的社区分支新增 |
可软件开发支持 |
与分支、循环语句搭配使用时的泛化性问题 |
当前tl.load的 |
大量泛化测试暴露问题,迭代解决 |
2.4 使用方法
以下示例中通过triton_ldst_indirect_07_kernel和triton_ldst_indirect_07_func的配合调用,实现了torch_ldst_indirect_07_func的功能:
@triton.jit
def triton_ldst_indirect_07_kernel(
out_ptr0, in_ptr0, in_ptr1, in_ptr2, stride_in_r,
XS: tl.constexpr, RS: tl.constexpr
):
pid = tl.program_id(0)
in_idx0 = pid * XS + tl.arange(0, XS)
in_idx1 = tl.arange(0, RS)
tmp0 = tl.load(in_ptr0 + in_idx0)
tmp1 = tl.load(in_ptr1 + in_idx1)
in_idx2 = tmp0[:, None] * stride_in_r + tmp1[None, :]
tmp2 = tl.load(in_ptr2 + in_idx2)
out0_idx = in_idx0[:, None] * RS + in_idx1[None, :]
tl.store(out_ptr0 + out0_idx, tmp2)
def triton_ldst_indirect_07_func(xr, xc, x2, xs, rs):
nr = x2.size()[0]
nc = xc.numel()
stride_in_r = x2.stride()[0]
assert nr == xs, "test only single core"
y0 = torch.empty((nr, nc), dtype=x2.dtype, device=x2.device)
triton_ldst_indirect_07_kernel[nr // xs, 1, 1](
y0, xr, xc, x2, stride_in_r, XS = xs, RS = rs)
return y0
def torch_ldst_indirect_07_func(xr, xc, x2):
flatten_idx = (xr[:, None] * x2.stride()[0] + xc[None, :]).flatten()
extracted = x2.flatten()[flatten_idx].reshape([xr.numel(), xc.numel()])
return extracted
DEV = "npu"
DTYPE = torch.float32
offset = 8
N0, N1 = 16, 32
blocksize = 4
lowdimsize = N0
assert N1 >= N0+offset, "N1 must be >= N0+offset"
assert N0 == lowdimsize, "N0 must be == lowdimsize"
xc = offset + torch.arange(0, N0, device=DEV)
xr = torch.arange(0, blocksize, device=DEV)
x2 = torch.randn((blocksize, N1), dtype=DTYPE, device=DEV)
torch_ref = torch_ldst_indirect_07_func(xr, xc, x2)
triton_cal = triton_ldst_indirect_07_func(xr, xc, x2, blocksize, lowdimsize)
torch.testing.assert_close(triton_cal, torch_ref)