block_gemm_areg_bsmem_creg_v2r1.hpp Source File

block_gemm_areg_bsmem_creg_v2r1.hpp Source File

Composable Kernel: block_gemm_areg_bsmem_creg_v2r1.hpp Source File
block_gemm_areg_bsmem_creg_v2r1.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
8
9namespace ck_tile {
10
11// A is block distributed tensor
12// B is block window on shared memory
13// C is block distributed tensor
14template <typename Problem_, typename Policy_ = BlockGemmARegBSmemCRegV2DefaultPolicy>
16{
23
24 static constexpr index_t kBlockSize = Problem::kBlockSize;
25
26 // C += A * B
27 template <typename CBlockTensor, typename ABlockTensorTmp, typename BBlockWindowTmp>
28 CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
29 const ABlockTensorTmp& a_block_tensor_tmp,
30 const BBlockWindowTmp& b_block_window_tmp) const
31 {
32 static_assert(
33 std::is_same_v<ADataType, remove_cv_t<typename ABlockTensorTmp::DataType>> &&
34 std::is_same_v<BDataType, remove_cv_t<typename BBlockWindowTmp::DataType>> &&
35 std::is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>,
36 "wrong!");
37
38 constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}];
39 constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}];
40 constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}];
41
42 static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN &&
43 KPerBlock == BlockGemmShape::kK,
44 "wrong!");
45
46 constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
47
48 using WG = remove_cvref_t<decltype(config.template at<0>())>;
49
50 constexpr index_t MWarp = config.template at<1>();
51 constexpr index_t NWarp = config.template at<2>();
52
53 constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
54 constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
55 constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
56
57 constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp;
58 constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp;
59
60 const index_t iNWarp = get_warp_id() % NWarp;
61
62 constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
69
70 constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
71 c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
72
73 // construct A-block-tensor from A-block-tensor-tmp
74 // FIXME: need method to check a_block_tensor and a_block_tensor_tmp have equivalent
75 // distribution
78
79 a_block_tensor.get_thread_buffer() = a_block_tensor_tmp.get_thread_buffer();
80
81 // construct B-warp-window
82 auto b_warp_window_tmp = make_tile_window(
83 b_block_window_tmp.get_bottom_tensor_view(),
85 b_block_window_tmp.get_window_origin() + multi_index<2>{iNWarp * WG::kN, 0},
86 make_static_tile_distribution(typename WG::BWarpDstrEncoding{}));
87
89 statically_indexed_array<decltype(b_warp_window_tmp), KIterPerWarp>,
90 NIterPerWarp>
91 b_warp_windows;
92
93 static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
94 static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
95 b_warp_windows(nIter)(kIter) = b_warp_window_tmp;
96
97 move_tile_window(b_warp_windows(nIter)(kIter),
98 {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
99 });
100 });
101
102 // check C-block-distribution
103 static_assert(
104 std::is_same_v<remove_cvref_t<decltype(c_block_dstr_encode)>,
105 remove_cvref_t<decltype(CBlockTensor::get_tile_distribution()
106 .get_static_tile_distribution_encoding())>>,
107 "wrong!");
108
109 using AWarpDstr = typename WG::AWarpDstr;
110 using CWarpDstr = typename WG::CWarpDstr;
111
112 using AWarpTensor = typename WG::AWarpTensor;
113 using CWarpTensor = typename WG::CWarpTensor;
114
115 constexpr auto a_warp_y_lengths =
116 to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
117 constexpr auto c_warp_y_lengths =
118 to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
119
120 constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
121 constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
122
124 statically_indexed_array<decltype(load_tile(decltype(b_warp_window_tmp){})),
125 KIterPerWarp>,
126 NIterPerWarp>
127 b_warp_tensors;
128
129 static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
130 static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
131 b_warp_tensors(nIter)(kIter) = load_tile(b_warp_windows(nIter)(kIter));
132 });
133 });
134
135 // hot loop:
136 static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
137 static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
138 // read B warp tensor (preloaded from the B block window above)
139 const auto b_warp_tensor = b_warp_tensors(nIter)(kIter);
140
141 static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
142 // read A warp tensor from A block tensor
143 AWarpTensor a_warp_tensor;
144
145 a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
146 merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
147 merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
148
149 // read C warp tensor from C block tensor
150 CWarpTensor c_warp_tensor;
151
152 c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
153 merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
154 merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
155
156 // warp GEMM
157 WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
158 // WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor_array[nIter]);
159
160 // write C warp tensor into C block tensor
161 c_block_tensor.set_y_sliced_thread_data(
162 merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
163 merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
164 c_warp_tensor.get_thread_buffer());
165 });
166 });
167 });
168
169 __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
172 __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
173 __builtin_amdgcn_sched_group_barrier(0x008, 2, 0); // MFMA
174 });
175 });
176 }
177
178 template <index_t MPerBlock = BlockGemmShape::kM, index_t KPerBlock = BlockGemmShape::kK>
180 {
181 constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
182
183 using WG = remove_cvref_t<decltype(config.template at<0>())>;
184
185 constexpr index_t MWarp = config.template at<1>();
186 constexpr index_t NWarp = config.template at<2>();
187
188 constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
189 constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
190
191 constexpr auto a_block_outer_dstr_encoding =
198
199 constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
200 a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{});
201
202 return make_static_tile_distribution(a_block_dstr_encode);
203 }
204
205 CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
206 {
207 constexpr index_t MPerBlock = BlockGemmShape::kM;
208 constexpr index_t NPerBlock = BlockGemmShape::kN;
209
210 constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
211
212 using WG = remove_cvref_t<decltype(config.template at<0>())>;
213
214 constexpr index_t MWarp = config.template at<1>();
215 constexpr index_t NWarp = config.template at<2>();
216
217 constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
218 constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
219 // constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
220
221 constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
228
229 constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
230 c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
231 constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
232 auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
233 return c_block_tensor;
234 }
235
236 // C = A * B
237 template <typename ABlockTensorTmp, typename BBlockWindowTmp>
238 CK_TILE_DEVICE auto operator()(const ABlockTensorTmp& a_block_tensor_tmp,
239 const BBlockWindowTmp& b_block_window_tmp) const
240 {
241 auto c_block_tensor = MakeCBlockTile();
242 operator()(c_block_tensor, a_block_tensor_tmp, b_block_window_tmp);
243 return c_block_tensor;
244 }
245};
246
247} // namespace ck_tile
#define CK_TILE_DEVICE
Definition config.hpp:41
CK_TILE_HOST_DEVICE constexpr auto make_embed_tile_distribution_encoding(OuterDstr, InnerDstr)
Definition tile_distribution_encoding.hpp:457
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_DEVICE index_t get_warp_id(bool_constant< ReturnSgpr >={})
Definition arch.hpp:104
CK_TILE_HOST_DEVICE constexpr auto make_static_distributed_tensor(const StaticTileDistribution &)
Definition static_distributed_tensor.hpp:142
CK_TILE_HOST_DEVICE constexpr auto merge_sequences(Seqs...)
Definition tile/core/container/sequence.hpp:826
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_DEVICE constexpr auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition null_tile_window.hpp:75
array< index_t, N > multi_index
Definition tile/core/container/multi_index.hpp:17
CK_TILE_HOST_DEVICE constexpr auto to_sequence(tuple< number< Is >... >)
Definition tile/core/container/sequence.hpp:1055
CK_TILE_DEVICE void move_tile_window(null_tile_window< WindowLengths > &, const typename null_tile_window< WindowLengths >::BottomTensorIndex &)
Definition null_tile_window.hpp:95
typename uniform_sequence_gen< NSize, I >::type uniform_sequence_gen_t
Definition tile/core/container/sequence.hpp:1026
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition tile_distribution.hpp:480
CK_TILE_DEVICE auto load_tile(const TileWindow_ &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={})
Definition load_tile.hpp:22
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
tuple_array< T, N > statically_indexed_array
Definition tile/core/container/statically_indexed_array.hpp:16
Definition block_gemm_areg_bsmem_creg_v2r1.hpp:16
remove_cvref_t< typename Problem::CDataType > CDataType
Definition block_gemm_areg_bsmem_creg_v2r1.hpp:21
remove_cvref_t< typename Problem::BlockGemmShape > BlockGemmShape
Definition block_gemm_areg_bsmem_creg_v2r1.hpp:22
CK_TILE_DEVICE void operator()(CBlockTensor &c_block_tensor, const ABlockTensorTmp &a_block_tensor_tmp, const BBlockWindowTmp &b_block_window_tmp) const
Definition block_gemm_areg_bsmem_creg_v2r1.hpp:28
remove_cvref_t< typename Problem::ADataType > ADataType
Definition block_gemm_areg_bsmem_creg_v2r1.hpp:19
static CK_TILE_DEVICE constexpr auto MakeCBlockTile()
Definition block_gemm_areg_bsmem_creg_v2r1.hpp:205
remove_cvref_t< typename Problem::BDataType > BDataType
Definition block_gemm_areg_bsmem_creg_v2r1.hpp:20
static CK_TILE_DEVICE constexpr auto MakeABlockTileDistribution()
Definition block_gemm_areg_bsmem_creg_v2r1.hpp:179
CK_TILE_DEVICE auto operator()(const ABlockTensorTmp &a_block_tensor_tmp, const BBlockWindowTmp &b_block_window_tmp) const
Definition block_gemm_areg_bsmem_creg_v2r1.hpp:238
remove_cvref_t< Policy_ > Policy
Definition block_gemm_areg_bsmem_creg_v2r1.hpp:18
static constexpr index_t kBlockSize
Definition block_gemm_areg_bsmem_creg_v2r1.hpp:24
remove_cvref_t< Problem_ > Problem
Definition block_gemm_areg_bsmem_creg_v2r1.hpp:17
Definition tile/core/utility/functional.hpp:43
Definition tile_distribution_encoding.hpp:26
Definition tile/core/container/tuple.hpp:192