block_universal_gemm_ar_flatbr_bquant_cr.hpp Source File

block_universal_gemm_ar_flatbr_bquant_cr.hpp Source File

Composable Kernel: block_universal_gemm_ar_flatbr_bquant_cr.hpp Source File
block_universal_gemm_ar_flatbr_bquant_cr.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
8
9namespace ck_tile {
10
11// A is block window on shared memory
12// BQ (scale tensor) is block distributed tensor.
13// Consecutive QuantGroupSize elements of B are quantized with a separate scale.
14// B is block window on block distributed tensor.
15// C is block distributed tensor
16template <typename Problem_, typename BlockPolicy_>
18{
28
29 static_assert(QuantGroupSize::kM == 1, "only N/K blocks for BQuant preshuffle kernel!");
30 static_assert(QuantGroupSize::kN == 1, "no block for N supported yet!");
31
32 static constexpr auto I0 = number<0>();
33 static constexpr auto I1 = number<1>();
34 static constexpr auto I2 = number<2>();
35 static constexpr auto idxM = I0;
36 static constexpr auto idxN = I1;
37 static constexpr auto idxK = I2;
41
42 static constexpr auto config = BlockPolicy::template GetWarpGemmMWarpNWarp<Problem>();
43
44 static constexpr auto warp_size = get_warp_size();
45
46 using WG = remove_cvref_t<decltype(config.template at<0>())>;
47
48 static constexpr index_t MWarp = config.template at<1>();
49 static constexpr index_t NWarp = config.template at<2>();
50
51 static constexpr index_t MPerBlock = BlockGemmShape::kM;
52 static constexpr index_t KPerBlock = BlockGemmShape::kK;
53
54 static constexpr index_t kBlockSize = Problem::kBlockSize;
55
56 static constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
57 static constexpr index_t NIterPerWarp =
58 BlockTile::at(idxN) / (WarpTile::at(idxN) * BlockWarps::at(idxN));
59 static constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
60
61 static constexpr auto MIter_2nd_last =
62 (MIterPerWarp >= 2) ? MIterPerWarp - 2 : MIterPerWarp - 1;
63
64 static constexpr index_t KPerBlockBQ = KPerBlock / QuantGroupSize::kK;
65
66 static constexpr index_t QScalesPerBlockRow =
67 integer_divide_ceil(KPerBlock, QuantGroupSize::kK);
68 static constexpr index_t QScalesPerWarpGemmRow =
69 integer_divide_ceil(WG::kK, QuantGroupSize::kK);
70
72 static constexpr index_t DsReadPreload = 2; // default 2, preload 2 ds read
73
77
78 template <typename T>
79 CK_TILE_DEVICE static float cvt_scale_to_fp32(T& scale)
80 {
81 float scale_reg_f = 0.f;
82 if constexpr(std::is_same_v<BQDataType, ck_tile::fp8_t>)
83 {
84 scale_reg_f = element_wise::amd_assembly_fp8_to_fp32(static_cast<uint32_t>(scale));
85 }
86 else if constexpr(std::is_same_v<BQDataType, ck_tile::bf8_t>)
87 {
88 scale_reg_f = element_wise::amd_assembly_bf8_to_fp32(static_cast<uint32_t>(scale));
89 }
90 else if constexpr(std::is_same_v<BQDataType, float>)
91 {
92 scale_reg_f = ck_tile::bit_cast<float>(scale);
93 }
94 else
95 {
96 static_assert(false, "BQDataType must be float, fp8_t or bf8_t.");
97 }
98 return scale_reg_f;
99 }
100
101 CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
102 {
103 constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
110
111 constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
112 c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
113
114 constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
115
116 auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
117 return c_block_tensor;
118 }
119
120 // C += A * B
121 template <typename CBlockTensor,
122 typename ABlockTensor,
123 typename BFlatBlockTensor,
124 typename BQBlockTensor,
125 typename ABlockWindow>
126 CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
127 ABlockTensor& a_warp_tensor,
128 BFlatBlockTensor& b_warp_tensor,
129 BQBlockTensor& bq_block_tensor,
130 ABlockWindow& a_warp_windows) const
131 {
132 using CWarpDstr = typename WG::CWarpDstr;
133 using AccTensor = typename WG::CWarpTensor;
134
135 constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
136
138 c_acc;
139
140 auto zero_accumulators = [&] {
141 static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
142 static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
143 static_for<0, (WG::kM * WG::kN) / warp_size, 1>{}([&](auto i) {
144 c_acc(mIter)(nIter).get_thread_buffer()[i] = 0.0f;
145 }); // make sure WG::CWarpTensor exposes a clear/zero
146 });
147 });
148 };
149 static_for<0, QScalesPerBlockRow, 1>{}([&](auto kQScale) {
150 zero_accumulators();
151 static_for<0, KIterPerQScale, 1>{}([&](auto kIterInQScale) {
152 constexpr auto kIter = kQScale * KIterPerQScale + kIterInQScale;
153 static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
154 constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
155 static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
156 // warp GEMM
157 WG{}(c_acc(mIter)(nIter),
158 a_warp_tensor(number<AwarpIter>{}),
159 b_warp_tensor(nIter)(number<kIter>{}));
160 });
161 __builtin_amdgcn_sched_barrier(0x7F6);
162 // preload next A from lds
163 if constexpr((kIter * MIterPerWarp + mIter) <
165 {
166 constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
167 constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
168 a_warp_tensor(number<AwarpIter>{}) =
169 load_tile(a_warp_windows(number<AmIter>{})(number<AkIter>{}));
170 }
171 // barrier
172 // Could be deleted
173 if constexpr((mIter == MIter_2nd_last))
174 {
176 }
177 });
178 });
179 static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
180 static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
181 constexpr auto tbuf_offset =
182 number<typename CBlockTensor::ThreadTensorDesc{}.calculate_offset(
184 c_warp_y_index_zeros)) /
185 CBlockTensor::PackedSize>{};
186
187 constexpr index_t reg_offset = nIter * KPerBlockBQ + kQScale;
188
189 auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset];
190 float scale_reg_f = cvt_scale_to_fp32(scale_reg);
191
192 static_for<0, WG::kM * WG::kN / warp_size, 1>{}([&](auto c_row) {
193 auto& c_ref = c_block_tensor.get_thread_buffer()[tbuf_offset + c_row];
194 const auto acc_val = c_acc(mIter)(nIter).get_thread_buffer()[c_row];
195 c_ref = c_ref + acc_val * scale_reg_f;
196 });
197 });
198 });
199 });
200 }
201};
202
203} // namespace ck_tile
#define CK_TILE_DEVICE
Definition config.hpp:41
CK_TILE_HOST_DEVICE constexpr auto make_embed_tile_distribution_encoding(OuterDstr, InnerDstr)
Definition tile_distribution_encoding.hpp:457
CK_TILE_DEVICE float amd_assembly_fp8_to_fp32(uint32_t src)
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:258
CK_TILE_DEVICE float amd_assembly_bf8_to_fp32(uint32_t src)
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:265
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_HOST_DEVICE constexpr index_t get_warp_size()
Definition arch.hpp:63
CK_TILE_HOST_DEVICE constexpr Y bit_cast(const X &x)
Definition bit_cast.hpp:11
CK_TILE_DEVICE void block_sync_lds()
Definition arch.hpp:282
CK_TILE_HOST_DEVICE constexpr auto make_static_distributed_tensor(const StaticTileDistribution &)
Definition static_distributed_tensor.hpp:142
CK_TILE_HOST_DEVICE constexpr auto merge_sequences(Seqs...)
Definition tile/core/container/sequence.hpp:826
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_HOST_DEVICE constexpr auto integer_divide_ceil(X x, Y y)
Definition tile/core/numeric/math.hpp:149
typename uniform_sequence_gen< NSize, I >::type uniform_sequence_gen_t
Definition tile/core/container/sequence.hpp:1026
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition tile_distribution.hpp:480
CK_TILE_DEVICE auto load_tile(const TileWindow_ &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={})
Definition load_tile.hpp:22
tuple_array< T, N > statically_indexed_array
Definition tile/core/container/statically_indexed_array.hpp:16
unsigned int uint32_t
Definition stdint.h:126
Definition block_universal_gemm_ar_flatbr_bquant_cr.hpp:18
static constexpr index_t NWarp
Definition block_universal_gemm_ar_flatbr_bquant_cr.hpp:49
remove_cvref_t< typename Problem::BQDataType > BQDataType
Definition block_universal_gemm_ar_flatbr_bquant_cr.hpp:23
static constexpr index_t KIterPerWarp
Definition block_universal_gemm_ar_flatbr_bquant_cr.hpp:59
static constexpr index_t QScalesPerWarpGemmRow
Definition block_universal_gemm_ar_flatbr_bquant_cr.hpp:68
static constexpr index_t MWarp
Definition block_universal_gemm_ar_flatbr_bquant_cr.hpp:48
static constexpr index_t m_preload
Definition block_universal_gemm_ar_flatbr_bquant_cr.hpp:74
static CK_TILE_DEVICE constexpr auto MakeCBlockTile()
Definition block_universal_gemm_ar_flatbr_bquant_cr.hpp:101
remove_cvref_t< typename BlockGemmShape::BlockWarps > BlockWarps
Definition block_universal_gemm_ar_flatbr_bquant_cr.hpp:39
remove_cvref_t< typename Problem::QuantGroupSize > QuantGroupSize
Definition block_universal_gemm_ar_flatbr_bquant_cr.hpp:27
remove_cvref_t< Problem_ > Problem
Definition block_universal_gemm_ar_flatbr_bquant_cr.hpp:19
static constexpr auto idxK
Definition block_universal_gemm_ar_flatbr_bquant_cr.hpp:37
remove_cvref_t< BlockPolicy_ > BlockPolicy
Definition block_universal_gemm_ar_flatbr_bquant_cr.hpp:20
remove_cvref_t< typename Problem::CDataType > CDataType
Definition block_universal_gemm_ar_flatbr_bquant_cr.hpp:24
remove_cvref_t< typename Problem::ADataType > ADataType
Definition block_universal_gemm_ar_flatbr_bquant_cr.hpp:21
remove_cvref_t< typename BlockGemmShape::WarpTile > WarpTile
Definition block_universal_gemm_ar_flatbr_bquant_cr.hpp:40
static constexpr auto I1
Definition block_universal_gemm_ar_flatbr_bquant_cr.hpp:33
static constexpr index_t KPerBlock
Definition block_universal_gemm_ar_flatbr_bquant_cr.hpp:52
static constexpr index_t MIterPerWarp
Definition block_universal_gemm_ar_flatbr_bquant_cr.hpp:56
static CK_TILE_DEVICE float cvt_scale_to_fp32(T &scale)
Definition block_universal_gemm_ar_flatbr_bquant_cr.hpp:79
static constexpr auto I0
Definition block_universal_gemm_ar_flatbr_bquant_cr.hpp:32
static constexpr index_t QScalesPerBlockRow
Definition block_universal_gemm_ar_flatbr_bquant_cr.hpp:66
remove_cvref_t< typename Problem::ComputeDataType > ComputeDataType
Definition block_universal_gemm_ar_flatbr_bquant_cr.hpp:25
static constexpr auto I2
Definition block_universal_gemm_ar_flatbr_bquant_cr.hpp:34
remove_cvref_t< typename Problem::BDataType > BDataType
Definition block_universal_gemm_ar_flatbr_bquant_cr.hpp:22
CK_TILE_DEVICE void operator()(CBlockTensor &c_block_tensor, ABlockTensor &a_warp_tensor, BFlatBlockTensor &b_warp_tensor, BQBlockTensor &bq_block_tensor, ABlockWindow &a_warp_windows) const
Definition block_universal_gemm_ar_flatbr_bquant_cr.hpp:126
remove_cvref_t< decltype(config.template at< 0 >())> WG
Definition block_universal_gemm_ar_flatbr_bquant_cr.hpp:46
remove_cvref_t< typename BlockGemmShape::BlockTile > BlockTile
Definition block_universal_gemm_ar_flatbr_bquant_cr.hpp:38
static constexpr index_t MPerBlock
Definition block_universal_gemm_ar_flatbr_bquant_cr.hpp:51
static constexpr index_t kBlockSize
Definition block_universal_gemm_ar_flatbr_bquant_cr.hpp:54
static constexpr index_t KPerBlockBQ
Definition block_universal_gemm_ar_flatbr_bquant_cr.hpp:64
static constexpr auto config
Definition block_universal_gemm_ar_flatbr_bquant_cr.hpp:42
static constexpr index_t NIterPerWarp
Definition block_universal_gemm_ar_flatbr_bquant_cr.hpp:57
static constexpr index_t KIterPerQScale
Definition block_universal_gemm_ar_flatbr_bquant_cr.hpp:71
static constexpr auto idxM
Definition block_universal_gemm_ar_flatbr_bquant_cr.hpp:35
static constexpr index_t DsReadPreload
Definition block_universal_gemm_ar_flatbr_bquant_cr.hpp:72
static constexpr auto idxN
Definition block_universal_gemm_ar_flatbr_bquant_cr.hpp:36
static constexpr auto warp_size
Definition block_universal_gemm_ar_flatbr_bquant_cr.hpp:44
static constexpr auto MIter_2nd_last
Definition block_universal_gemm_ar_flatbr_bquant_cr.hpp:61
remove_cvref_t< typename Problem::BlockGemmShape > BlockGemmShape
Definition block_universal_gemm_ar_flatbr_bquant_cr.hpp:26
Definition tile/core/utility/functional.hpp:43
Definition tile_distribution_encoding.hpp:26
Definition tile/core/container/tuple.hpp:192