block_gemm_asmem_bsmem_creg_v1.hpp Source File

block_gemm_asmem_bsmem_creg_v1.hpp Source File

Composable Kernel: block_gemm_asmem_bsmem_creg_v1.hpp Source File
block_gemm_asmem_bsmem_creg_v1.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
8
9namespace ck_tile {
10
11// A is a block window on shared memory
12// B is a block window on shared memory
13// C is a block distributed tensor
14template <typename Problem_, typename Policy_ = BlockGemmASmemBSmemCRegV1DefaultPolicy>
16{
23
24 static constexpr index_t kBlockSize = Problem::kBlockSize;
25
26 // C += A * B
// Accumulates the block-level product A * B into c_block_tensor in place.
// Every warp reads its current C warp tile out of c_block_tensor, runs the
// warp GEMM WG on it with the matching A and B warp tiles, and writes the
// result back into c_block_tensor.
//
// c_block_tensor: C block distributed tensor; read and updated in place. Its
//                 distribution is presumably the one built by MakeCBlockTile()
//                 (see the slicing note in the hot loop) -- TODO confirm.
// a_block_window: A block window; its lengths are read as (M, K).
// b_block_window: B block window; its length along dimension 0 is read as N,
//                 i.e. B is addressed N-first. Its dimension-1 length is
//                 never read in this function.
//
// NOTE(review): this rendered listing skips source lines 66, 83, 101 and 118
// (the line numbering jumps), so a few arguments/declarations are not shown.
27 template <typename CBlockTensor, typename ABlockWindow, typename BBlockWindow>
28 CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
29 const ABlockWindow& a_block_window,
30 const BBlockWindow& b_block_window) const
31 {
// Operand element types must match the types declared by Problem.
32 static_assert(std::is_same_v<ADataType, typename ABlockWindow::DataType> &&
33 std::is_same_v<BDataType, typename BBlockWindow::DataType> &&
34 std::is_same_v<CDataType, typename CBlockTensor::DataType>,
35 "wrong!");
36
// Block tile extents come from the windows and must match BlockGemmShape.
// NOTE(review): only A's K extent is checked; B's dimension-1 (K) extent is
// not compared against KPerBlock anywhere in this function.
37 constexpr index_t MPerBlock = ABlockWindow{}.get_window_lengths()[number<0>{}];
38 constexpr index_t NPerBlock = BBlockWindow{}.get_window_lengths()[number<0>{}];
39 constexpr index_t KPerBlock = ABlockWindow{}.get_window_lengths()[number<1>{}];
40
41 static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN &&
42 KPerBlock == BlockGemmShape::kK,
43 "wrong!");
44
// The policy supplies the warp GEMM type (element 0) and the number of warps
// along M (element 1) and along N (element 2).
45 constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
46
47 using WG = remove_cvref_t<decltype(config.template at<0>())>;
48
49 constexpr index_t MWarp = config.template at<1>();
50 constexpr index_t NWarp = config.template at<2>();
51
// WG-sized tiles handled per warp along M, N and K. These are integer
// divisions: the block extents are presumably exact multiples -- not checked.
52 constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
53 constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
54 constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
55
// Block-level distance between consecutive iterations of one warp. When the
// divisions above are exact these equal MWarp * WG::kM, NWarp * WG::kN and
// WG::kK respectively.
56 constexpr index_t MPerBlockPerIter = MPerBlock / MIterPerWarp;
57 constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp;
58 constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp;
59
// Linear warp id -> (iMWarp, iNWarp) on an MWarp x NWarp grid, N index fastest.
60 const index_t iMWarp = get_warp_id() / NWarp;
61 const index_t iNWarp = get_warp_id() % NWarp;
62
63 // construct A-warp-window
// View of A for this warp: shifted by iMWarp * WG::kM along M from the block
// window origin and distributed with WG's A encoding.
64 auto a_warp_window_tmp = make_tile_window(
65 a_block_window.get_bottom_tensor_view(),
67 a_block_window.get_window_origin() + multi_index<2>{iMWarp * WG::kM, 0},
68 make_static_tile_distribution(typename WG::AWarpDstrEncoding{}));
69
// The #if 0 branch below is disabled; only the #else branch is compiled.
70#if 0 // FIXME: using array will cause register spill
71 array<array<decltype(a_warp_window_tmp), KIterPerWarp>, MIterPerWarp> a_warp_windows{
72 {a_warp_window_tmp}};
73
74 for(index_t mIter = 0; mIter < MIterPerWarp; mIter++)
75 {
76 for(index_t kIter = 0; kIter < KIterPerWarp; kIter++)
77 {
78 move_tile_window(a_warp_windows(mIter)(kIter),
79 {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
80 }
81 }
82#else
// Active path: one A window per (mIter, kIter), each a copy of
// a_warp_window_tmp moved by (mIter * MPerBlockPerIter, kIter * KPerBlockPerIter).
// Because MPerBlockPerIter spans all MWarp warps (MWarp * WG::kM when exact),
// a warp's M iterations interleave with the other warps' tiles.
84 statically_indexed_array<decltype(a_warp_window_tmp), KIterPerWarp>,
85 MIterPerWarp>
86 a_warp_windows;
87
88 static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
89 static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
90 a_warp_windows(mIter)(kIter) = a_warp_window_tmp;
91
92 move_tile_window(a_warp_windows(mIter)(kIter),
93 {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
94 });
95 });
96#endif
97
98 // construct B-warp-window
// View of B for this warp: shifted by iNWarp * WG::kN along B's dimension 0
// (N) from the block window origin and distributed with WG's B encoding.
99 auto b_warp_window_tmp = make_tile_window(
100 b_block_window.get_bottom_tensor_view(),
102 b_block_window.get_window_origin() + multi_index<2>{iNWarp * WG::kN, 0},
103 make_static_tile_distribution(typename WG::BWarpDstrEncoding{}));
104
// The #if 0 branch below is disabled; only the #else branch is compiled.
105#if 0 // FIXME: using array will cause register spill
106 array<array<decltype(b_warp_window_tmp), KIterPerWarp>, NIterPerWarp> b_warp_windows{
107 {b_warp_window_tmp}};
108
109 for(index_t nIter = 0; nIter < NIterPerWarp; nIter++)
110 {
111 for(index_t kIter = 0; kIter < KIterPerWarp; kIter++)
112 {
113 move_tile_window(b_warp_windows(nIter)(kIter),
114 {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
115 }
116 }
117#else
// Active path: one B window per (nIter, kIter), each a copy of
// b_warp_window_tmp moved by (nIter * NPerBlockPerIter, kIter * KPerBlockPerIter).
119 statically_indexed_array<decltype(b_warp_window_tmp), KIterPerWarp>,
120 NIterPerWarp>
121 b_warp_windows;
122
123 static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
124 static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
125 b_warp_windows(nIter)(kIter) = b_warp_window_tmp;
126
127 move_tile_window(b_warp_windows(nIter)(kIter),
128 {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
129 });
130 });
131#endif
132
// C warp-tile layout from WG: the lengths of its Y dimensions, plus an
// all-zero Y index of the same rank. Both are used below to slice one warp
// tile out of c_block_tensor.
133 using CWarpDstr = typename WG::CWarpDstr;
134 using CWarpTensor = typename WG::CWarpTensor;
135
136 constexpr auto c_warp_y_lengths =
137 to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
138 constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
139
140 // hot loop:
// Loop order is K (outer), M, N (inner). A is loaded once per (kIter, mIter);
// B is loaded once per (kIter, mIter, nIter), so each B warp tile is re-read
// MIterPerWarp times per kIter.
// NOTE(review): hoisting the B loads out of the mIter loop would cut reads at
// the cost of keeping NIterPerWarp B tiles live -- measure before changing.
141 static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
142 static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
143 // read A warp tensor from A block window
144 const auto a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter));
145
146 static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
147 // read B warp tensor from B Block window
148 const auto b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter));
149
150 // read C warp tensor from C block tensor
// The (mIter, nIter) slice starts at Y index (mIter, nIter, 0, ...) with
// lengths (1, 1, c_warp_y_lengths...): the first two Y dimensions of
// c_block_tensor are assumed to be the M/N iteration dimensions -- confirm
// against MakeCBlockTile's outer encoding.
151 CWarpTensor c_warp_tensor;
152
153 c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
154 merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
155 merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
156
157 // warp GEMM
158 WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
159
160 // write C warp tensor into C block tensor
// Written back to exactly the same Y slice it was read from.
161 c_block_tensor.set_y_sliced_thread_data(
162 merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
163 merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
164 c_warp_tensor.get_thread_buffer());
165 });
166 });
167 });
168 }
169
// Builds the C block distributed tensor (element type CDataType) that the
// GEMM overloads accumulate into. Sizes come from BlockGemmShape and the
// policy's warp configuration rather than from operands, so this can be
// called before any windows exist.
//
// The resulting distribution embeds WG::CWarpDstrEncoding inside an outer,
// block-level encoding (c_block_outer_dstr_encoding).
//
// NOTE(review): source lines 186-191, which complete the definition of the
// outer encoding, are missing from this rendered listing.
// NOTE(review): nothing here zeroes the tensor; its initial contents are
// whatever make_static_distributed_tensor produces.
170 CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
171 {
// Same shape/policy derivation as in the C += A * B operator(); the layout
// built here is expected to match the Y slicing that operator performs.
172 constexpr index_t MPerBlock = BlockGemmShape::kM;
173 constexpr index_t NPerBlock = BlockGemmShape::kN;
174
175 constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
176
177 using WG = remove_cvref_t<decltype(config.template at<0>())>;
178
179 constexpr index_t MWarp = config.template at<1>();
180 constexpr index_t NWarp = config.template at<2>();
181
// Per-warp iteration counts along M and N, computed as in operator().
182 constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
183 constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
184
185 constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
192
// Combine the block-level outer encoding with the warp GEMM's C encoding.
193 constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
194 c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
195
196 constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
197
198 auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
199 return c_block_tensor;
200 }
201
202 // C = A * B
203 template <typename ABlockTensorTmp, typename BBlockWindow>
204 CK_TILE_DEVICE auto operator()(const ABlockTensorTmp& a_block_tensor_tmp,
205 const BBlockWindow& b_block_window) const
206 {
207 auto c_block_tensor = MakeCBlockTile();
208 operator()(c_block_tensor, a_block_tensor_tmp, b_block_window);
209 return c_block_tensor;
210 }
211};
212
213} // namespace ck_tile
#define CK_TILE_DEVICE
Definition config.hpp:41
CK_TILE_HOST_DEVICE constexpr auto make_embed_tile_distribution_encoding(OuterDstr, InnerDstr)
Definition tile_distribution_encoding.hpp:457
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_DEVICE index_t get_warp_id(bool_constant< ReturnSgpr >={})
Definition arch.hpp:104
CK_TILE_HOST_DEVICE constexpr auto make_static_distributed_tensor(const StaticTileDistribution &)
Definition static_distributed_tensor.hpp:142
CK_TILE_HOST_DEVICE constexpr auto merge_sequences(Seqs...)
Definition tile/core/container/sequence.hpp:826
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_DEVICE constexpr auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition null_tile_window.hpp:75
array< index_t, N > multi_index
Definition tile/core/container/multi_index.hpp:17
CK_TILE_HOST_DEVICE constexpr auto to_sequence(tuple< number< Is >... >)
Definition tile/core/container/sequence.hpp:1055
CK_TILE_DEVICE void move_tile_window(null_tile_window< WindowLengths > &, const typename null_tile_window< WindowLengths >::BottomTensorIndex &)
Definition null_tile_window.hpp:95
typename uniform_sequence_gen< NSize, I >::type uniform_sequence_gen_t
Definition tile/core/container/sequence.hpp:1026
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition tile_distribution.hpp:480
CK_TILE_DEVICE auto load_tile(const TileWindow_ &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={})
Definition load_tile.hpp:22
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
tuple_array< T, N > statically_indexed_array
Definition tile/core/container/statically_indexed_array.hpp:16
Definition block_gemm_asmem_bsmem_creg_v1.hpp:16
CK_TILE_DEVICE auto operator()(const ABlockTensorTmp &a_block_tensor_tmp, const BBlockWindow &b_block_window) const
Definition block_gemm_asmem_bsmem_creg_v1.hpp:204
remove_cvref_t< typename Problem::BlockGemmShape > BlockGemmShape
Definition block_gemm_asmem_bsmem_creg_v1.hpp:22
CK_TILE_DEVICE void operator()(CBlockTensor &c_block_tensor, const ABlockWindow &a_block_window, const BBlockWindow &b_block_window) const
Definition block_gemm_asmem_bsmem_creg_v1.hpp:28
static constexpr index_t kBlockSize
Definition block_gemm_asmem_bsmem_creg_v1.hpp:24
static CK_TILE_DEVICE constexpr auto MakeCBlockTile()
Definition block_gemm_asmem_bsmem_creg_v1.hpp:170
remove_cvref_t< typename Problem::BDataType > BDataType
Definition block_gemm_asmem_bsmem_creg_v1.hpp:20
remove_cvref_t< Policy_ > Policy
Definition block_gemm_asmem_bsmem_creg_v1.hpp:18
remove_cvref_t< typename Problem::ADataType > ADataType
Definition block_gemm_asmem_bsmem_creg_v1.hpp:19
remove_cvref_t< typename Problem::CDataType > CDataType
Definition block_gemm_asmem_bsmem_creg_v1.hpp:21
remove_cvref_t< Problem_ > Problem
Definition block_gemm_asmem_bsmem_creg_v1.hpp:17
A fixed-size array container similar to std::array with additional utilities.
Definition tile/core/container/array.hpp:43
Definition tile/core/container/sequence.hpp:49
Definition tile/core/utility/functional.hpp:43
Definition tile_distribution_encoding.hpp:26
Definition tile/core/container/tuple.hpp:192