triton.language.argmax

1. OP 概述

简介:在指定维度上返回最大值所在的下标

triton.language.argmax(input, axis, tie_break_left=True, keep_dims=False)

2. OP 规格

2.1 参数说明

参数名

类型

说明

input

tensor

张量数据,左操作数

axis

int

指定在哪个维度上进行规约

keep_dims

bool

保持规约轴规约后的维度

tie_break_left

bool

如果多个元素有相同的最大值,返回最左侧最大值的下标

返回值: tl.tensor:同input的shape的张量

2.2 支持规格

2.2.1 DataType 支持

uint8

int8

uint16

int16

uint32

int32

uint64

int64

fp16

fp32

bf16

bool/int1

GPU

Ascend A2A3系列

x

×

×

2.2.2 Shape 支持

支持维度范围

GPU

无限制

Ascend

无限制

结论:在 Shape 方面,GPU 与 Ascend 平台无差异。

2.3 特殊限制说明

相对社区能力缺失且做不到

Ascend A3 对比 GPU 缺失uint16、uint32、uint64、fp64的支持

2.4 使用方法

更多示例参考triton-ascend代码仓,ascend/examples/generalization_cases/test_argmax.py

def triton_argmax_1d(in_ptr0, out_ptr1, xnumel, XBLOCK: tl.constexpr):
    xoffset = tl.program_id(0) + tl.arange(0, XBLOCK)
    tmp0 = tl.load(in_ptr0 + xoffset, None)
    tmp4 = tl.argmax(tmp0, 0)
    tl.store(out_ptr1, tmp4, None)

3. 特殊取值情况

对于 tensor[nan,inf] 的情况,返回inf所在的下标