# triton.language.dot_scaled ## 1. OP 概述 简介:**计算以缩放格式表示两个矩阵块的矩阵乘积** ```python triton.language.dot_scaled(lhs, lhs_scale, lhs_format, rhs, rhs_scale, rhs_format, acc=None, lhs_k_pack=True, rhs_k_pack=True, out_dtype=triton.language.float32, _semantic=None) ``` ## 2. OP 规格 ### 2.1 参数说明 | 参数名 | 类型 | 说明 | | ------------- | ----------------- | -------------------------------------------------------------- | | `lhs` | `tensor` | 左矩阵张量的基指针(支持bf16、fp16格式) | | `lhs_scale` | `tensor` | 左矩阵缩放张量的基指针(支持int8格式) | | `lhs_format` | `string` | 左矩阵张量的存放格式 (支持"bf16"和"fp16") | | `rhs` | `tensor` | 右矩阵张量的基指针 (支持bf16、fp16格式) | | `rhs_scale` | `tensor` | 右矩阵缩放张量的基指针(支持int8格式) | | `rhs_format` | `string` | 右矩阵张量的存放格式 (支持"bf16"和"fp16") | | `acc` | `tensor` | 累积张量 | | `lhs_k_pack` | `(bool, optional)` | true 沿 K 维度打包
false 沿 M 维度打包
| | `rhs_k_pack ` | `(bool, optional)` | true 沿 K 维度打包
false 沿 N 维度打包
| | `_semantic` | - | 保留参数,暂不支持外部调用 | 返回值: `out`:tensor类型,计算缩放矩阵乘后输出的值 ### 2.2 支持规格 #### 2.2.1 DataType 支持 | | fp4 | fp8 | bf16 | fp16 | | ------------- | --------- | -------- | -------- | -------- | | GPU | √ | √ | √ | √ | | Ascend A2/A3 | × | × | √ | √ | 结论: 1、Ascend 对比 GPU 缺失fp4、fp8的支持能力(硬件限制)。 2、缩放张量的值为int8,GPU上为uint8。 #### 2.2.2 Shape 支持 | | 支持维度范围 | | ------ | --------------- | | GPU | 可支持 2~3维 tensor | | Ascend | 可支持 2~3维 tensor | 结论:在 Shape 方面,GPU 与 Ascend 平台无差异,lhs/rhs矩阵均支持 2 至 3 维张量,但scale矩阵只支持2维。 ### 2.3 特殊限制说明 1、由于不支持fp8,左右矩阵不支持fp4、fp8格式,Ascend 对比 GPU 缺失lhs_k_pack、rhs_k_pack的矩阵解压缩支持能力(硬件限制)。 2、输入矩阵lhs、rhs推荐输入范围为[-5, 5],超过可能会出现极值inf。 3、由于硬件存在对齐要求,需要限制scale矩阵做broadcast的倍数,至少应为16 4、当前支持的缩放矩阵格式为int8,社区为uint8 ### 2.4 使用方法 以下示例实现了对输入张量 `x` 做就地绝对值计算: ```python@triton.jit def dot_scale_kernel(a_base, stride_a0: tl.constexpr, stride_a1: tl.constexpr, a_scale, b_base, stride_b0: tl.constexpr, stride_b1: tl.constexpr, b_scale, out, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, type_a: tl.constexpr, type_b: tl.constexpr): PACKED_BLOCK_K_A: tl.constexpr = BLOCK_K PACKED_BLOCK_K_B: tl.constexpr = BLOCK_K str_a0: tl.constexpr = stride_a0 a_ptr = a_base + tl.arange(0, BLOCK_M)[:, None] * stride_a0 + tl.arange(0, str_a0)[None, :] * stride_a1 b_ptr = b_base + tl.arange(0, PACKED_BLOCK_K_B)[:, None] * stride_b0 + tl.arange(0, BLOCK_N)[None, :] * stride_b1 a = tl.load(a_ptr) b = tl.load(b_ptr) SCALE_BLOCK_K: tl.constexpr = BLOCK_K // 32 accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) if a_scale is not None: scale_a_ptr = a_scale + tl.arange(0, BLOCK_M)[:, None] * SCALE_BLOCK_K + tl.arange(0, SCALE_BLOCK_K)[None, :] a_scale = tl.load(scale_a_ptr) if b_scale is not None: scale_b_ptr = b_scale + tl.arange(0, BLOCK_N)[:, None] * SCALE_BLOCK_K + tl.arange(0, SCALE_BLOCK_K)[None, :] b_scale = tl.load(scale_b_ptr) accumulator = tl.dot_scaled(a, a_scale, type_a, b, b_scale, type_b, acc=accumulator, out_dtype=tl.float32) out_ptr = out + tl.arange(0, BLOCK_M)[:, None] * BLOCK_N + tl.arange(0, BLOCK_N)[None, :] tl.store(out_ptr, accumulator.to(a.dtype)) x = torch.randn(shape, dtype=torch.bfloat16, device="npu") y = torch.randn(shape, dtype=torch.bfloat16, device="npu") M, K = shape[0], shape[1] scale_x = torch.randint(min_scale - 128, max_scale - 127, (M, K // 32), dtype=torch.int8, device="npu") scale_y = torch.randint(min_scale - 128, max_scale - 127, (N, K // 32), dtype=torch.int8, device="npu") type_a, type_b = "bf16", "bf16" pgm = dot_scale_kernel[(1,)](x, *x.stride(), scale_x, y, *y.stride(), scale_y, z, M, N, K, type_a, type_b) ```