gemm_pipeline_ag_bg_cr_base.hpp Source File

gemm_pipeline_ag_bg_cr_base.hpp Source File

Composable Kernel: gemm_pipeline_ag_bg_cr_base.hpp Source File
gemm_pipeline_ag_bg_cr_base.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
8
9namespace ck_tile {
10
11template <typename Problem, typename Policy>
13{
19
24
25 static constexpr index_t MPerBlock = BlockGemmShape::kM;
26 static constexpr index_t NPerBlock = BlockGemmShape::kN;
27 static constexpr index_t KPerBlock = BlockGemmShape::kK;
28#if defined(__gfx950__)
29 static constexpr bool is_a_load_tr = std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
30 static constexpr bool is_b_load_tr = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
31#else
32 static constexpr bool is_a_load_tr = false;
33 static constexpr bool is_b_load_tr = false;
34#endif
35
36 CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; }
37
38 template <typename DstBlockTile, typename SrcTileWindow, typename DramTileWindowStep>
39 CK_TILE_DEVICE void GlobalPrefetch(DstBlockTile& dst_block_tile,
40 SrcTileWindow& dram_tile_window,
41 const DramTileWindowStep& dram_tile_window_step) const
42 {
43 load_tile(dst_block_tile, dram_tile_window);
44 move_tile_window(dram_tile_window, dram_tile_window_step);
45 }
46
47 template <typename DstBlockWindow, typename SrcTileWindow, typename DramTileWindowStep>
48 CK_TILE_DEVICE void GlobalPrefetchAsync(DstBlockWindow& dst_block_window,
49 SrcTileWindow& dram_tile_window,
50 const DramTileWindowStep& dram_tile_window_step) const
51 {
52 async_load_tile(dst_block_window, dram_tile_window);
53 move_tile_window(dram_tile_window, dram_tile_window_step);
54 }
55
56 template <typename DstTileWindow, typename SrcBlockTile, typename ElementFunction>
57 CK_TILE_DEVICE void LocalPrefill(DstTileWindow& lds_tile_window,
58 const SrcBlockTile& src_block_tile,
59 const ElementFunction& element_func) const
60 {
61 const auto block_tile_tmp = tile_elementwise_in(element_func, src_block_tile);
62 store_tile(lds_tile_window, block_tile_tmp);
63 }
64
65 template <typename DstTileWindow, typename SrcBlockTile>
66 CK_TILE_DEVICE void LocalPrefill(DstTileWindow& lds_tile_window,
67 const SrcBlockTile& src_block_tile) const
68 {
69 store_tile(lds_tile_window, src_block_tile);
70 }
71
72 template <typename DstBlockTile, typename SrcTileWindow, bool LoadTranspose = false>
73 CK_TILE_DEVICE void LocalPrefetch(DstBlockTile& dst_block_tile,
74 const SrcTileWindow& lds_tile_window,
76 {
77 if constexpr(LoadTranspose)
78 dst_block_tile = load_tile_transpose(lds_tile_window);
79 else
80 load_tile(dst_block_tile, lds_tile_window);
81 }
82
83 CK_TILE_DEVICE auto GetABLdsTensorViews(void* p_smem) const
84 {
85 // A tile in LDS
86 ADataType* __restrict__ p_a_lds = static_cast<ADataType*>(p_smem);
87 constexpr auto a_lds_block_desc = Policy::template MakeALdsBlockDescriptor<Problem>();
88 auto a_lds_block = make_tensor_view<address_space_enum::lds>(p_a_lds, a_lds_block_desc);
89
90 // TODO: LDS alignment should come from Policy!
91 constexpr index_t a_lds_block_space_size_aligned = integer_least_multiple(
92 sizeof(ADataType) * a_lds_block_desc.get_element_space_size(), 16);
93
94 // B tile in LDS
95 BDataType* __restrict__ p_b_lds = static_cast<BDataType*>(
96 static_cast<void*>(static_cast<char*>(p_smem) + a_lds_block_space_size_aligned));
97 constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor<Problem>();
98 auto b_lds_block = make_tensor_view<address_space_enum::lds>(p_b_lds, b_lds_block_desc);
99
100 return make_tuple(std::move(a_lds_block), std::move(b_lds_block));
101 }
102
103 template <typename DramBlockWindowTmp,
104 typename std::enable_if_t<is_detected<is_tuple, DramBlockWindowTmp>::value, bool>* =
105 nullptr>
106 CK_TILE_DEVICE constexpr auto CopyADramWindow(const DramBlockWindowTmp& dram_block_window_tmp,
107 const array<index_t, 2>& offset = {0, 0}) const
108 {
109 constexpr bool is_col_major = std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
110
111 using YPerTile = std::conditional_t<is_col_major, number<KPerBlock>, number<MPerBlock>>;
112 using XPerTile = std::conditional_t<is_col_major, number<MPerBlock>, number<KPerBlock>>;
113 // A DRAM tile window for load
114 auto a_copy_dram_window = generate_tuple(
115 [&](auto idx) {
116 return make_tile_window(
117 dram_block_window_tmp[number<idx>{}].get_bottom_tensor_view(),
118 make_tuple(YPerTile{}, XPerTile{}),
119 dram_block_window_tmp[number<idx>{}].get_window_origin() + offset,
120 Policy::template MakeADramTileDistribution<Problem>());
121 },
122 number<DramBlockWindowTmp::size()>{});
123 return std::move(a_copy_dram_window);
124 }
125
126 template <typename DramBlockWindowTmp,
127 typename std::enable_if_t<!is_detected<is_tuple, DramBlockWindowTmp>::value, bool>* =
128 nullptr>
129 CK_TILE_DEVICE constexpr auto CopyADramWindow(const DramBlockWindowTmp& dram_block_window_tmp,
130 const array<index_t, 2>& offset = {0, 0}) const
131 {
132 constexpr bool is_col_major = std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
133
134 using YPerTile = std::conditional_t<is_col_major, number<KPerBlock>, number<MPerBlock>>;
135 using XPerTile = std::conditional_t<is_col_major, number<MPerBlock>, number<KPerBlock>>;
136 // A DRAM tile window for load
137 auto a_copy_dram_window =
138 make_tile_window(dram_block_window_tmp.get_bottom_tensor_view(),
139 make_tuple(YPerTile{}, XPerTile{}),
140 dram_block_window_tmp.get_window_origin() + offset,
141 Policy::template MakeADramTileDistribution<Problem>());
142
143 return std::move(a_copy_dram_window);
144 }
145
146 template <typename DramBlockWindowTmp,
147 typename std::enable_if_t<is_detected<is_tuple, DramBlockWindowTmp>::value, bool>* =
148 nullptr>
149 CK_TILE_DEVICE constexpr auto CopyBDramWindow(const DramBlockWindowTmp& dram_block_window_tmp,
150 const array<index_t, 2>& offset = {0, 0}) const
151 {
152 constexpr bool is_row_major = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
153
154 using YPerTile = std::conditional_t<is_row_major, number<KPerBlock>, number<NPerBlock>>;
155 using XPerTile = std::conditional_t<is_row_major, number<NPerBlock>, number<KPerBlock>>;
156 // A DRAM tile window for load
157 auto a_copy_dram_window = generate_tuple(
158 [&](auto idx) {
159 return make_tile_window(
160 dram_block_window_tmp[number<idx>{}].get_bottom_tensor_view(),
161 make_tuple(YPerTile{}, XPerTile{}),
162 dram_block_window_tmp[number<idx>{}].get_window_origin() + offset,
163 Policy::template MakeBDramTileDistribution<Problem>());
164 },
165 number<DramBlockWindowTmp::size()>{});
166 return std::move(a_copy_dram_window);
167 }
168
169 template <typename DramBlockWindowTmp,
170 typename std::enable_if_t<!is_detected<is_tuple, DramBlockWindowTmp>::value, bool>* =
171 nullptr>
172 CK_TILE_DEVICE constexpr auto CopyBDramWindow(const DramBlockWindowTmp& dram_block_window_tmp,
173 const array<index_t, 2>& offset = {0, 0}) const
174 {
175 constexpr bool is_row_major = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
176
177 using YPerTile = std::conditional_t<is_row_major, number<KPerBlock>, number<NPerBlock>>;
178 using XPerTile = std::conditional_t<is_row_major, number<NPerBlock>, number<KPerBlock>>;
179 // A DRAM tile window for load
180 auto a_copy_dram_window =
181 make_tile_window(dram_block_window_tmp.get_bottom_tensor_view(),
182 make_tuple(YPerTile{}, XPerTile{}),
183 dram_block_window_tmp.get_window_origin() + offset,
184 Policy::template MakeBDramTileDistribution<Problem>());
185
186 return std::move(a_copy_dram_window);
187 }
188
// Builds the three tile windows used for operand A and returns them as a tuple:
//   1. a_copy_dram_window - produced by CopyADramWindow from the caller's window and `offset`;
//   2. a_copy_lds_window  - a window over `a_lds_block_view` at origin {0, 0}, created
//                           without a tile distribution;
//   3. a_lds_gemm_window  - a window over the same view, shape and origin, created with
//                           the tile distribution selected below.
//
// Parameters:
//   a_dram_block_window_tmp - source window (or tuple of windows) forwarded to CopyADramWindow.
//   a_lds_block_view        - tensor view that both LDS-side windows are built on.
//   (unnamed)               - only the type ALdsLoadTileDistr is used; the argument value is ignored.
//   offset                  - forwarded unchanged to CopyADramWindow.
//
// NOTE(review): this listing has lost lines inside both lambdas (listing numbers jump
// 200 -> 202 -> 204 and 208 -> 211), so the branch bodies are incomplete as shown here.
// Restore them from the original header before compiling.
189 template <typename ADramBlockWindowTmp, typename ALdsTensorView, typename ALdsLoadTileDistr>
190 CK_TILE_DEVICE constexpr auto GetAWindows(const ADramBlockWindowTmp& a_dram_block_window_tmp,
191 const ALdsTensorView& a_lds_block_view,
192 const ALdsLoadTileDistr&,
193 const array<index_t, 2>& offset = {0, 0}) const
194 {
195 // A DRAM tile window for load
196 auto a_copy_dram_window = CopyADramWindow(a_dram_block_window_tmp, offset);
197
198 // A LDS tile window for store
// The window shape is chosen at compile time from is_a_load_tr (branch bodies missing here).
199 auto a_lds_shape = []() {
200 if constexpr(is_a_load_tr)
202 else
204 }();
205 auto a_copy_lds_window = make_tile_window(a_lds_block_view, a_lds_shape, {0, 0});
206
// Distribution for the GEMM-side window: when is_a_load_tr is false this is a
// default-constructed ALdsLoadTileDistr; when true it is derived from
// ALdsLoadTileDistr::DstrEncode and Problem::ADataType via TransposedDstrEncode
// (the start of that expression is missing from this listing).
207 auto a_lds_load_tile_distr = []() {
208 if constexpr(is_a_load_tr)
211 typename ALdsLoadTileDistr::DstrEncode,
212 typename Problem::ADataType>::TransposedDstrEncode{});
213 else
214 return ALdsLoadTileDistr{};
215 }();
216 auto a_lds_gemm_window =
217 make_tile_window(a_lds_block_view, a_lds_shape, {0, 0}, a_lds_load_tile_distr);
218
219 return make_tuple(std::move(a_copy_dram_window),
220 std::move(a_copy_lds_window),
221 std::move(a_lds_gemm_window));
222 }
223
// Builds the three tile windows used for operand B and returns them as a tuple:
//   1. b_copy_dram_window - produced by CopyBDramWindow from the caller's window and `offset`;
//   2. b_copy_lds_window  - a window over `b_lds_block_view` at origin {0, 0}, created
//                           without a tile distribution;
//   3. b_lds_gemm_window  - a window over the same view, shape and origin, created with
//                           the tile distribution selected below.
//
// Parameters:
//   b_dram_block_window_tmp - source window (or tuple of windows) forwarded to CopyBDramWindow.
//   b_lds_block_view        - tensor view that both LDS-side windows are built on.
//   (unnamed)               - only the type BLdsLoadTileDistr is used; the argument value is ignored.
//   offset                  - forwarded unchanged to CopyBDramWindow.
//
// NOTE(review): this listing has lost lines inside both lambdas (listing numbers jump
// 237 -> 239 -> 241 and 245 -> 248), so the branch bodies are incomplete as shown here.
// Restore them from the original header before compiling.
224 template <typename BDramBlockWindowTmp, typename BLdsTensorView, typename BLdsLoadTileDistr>
225 CK_TILE_DEVICE constexpr auto GetBWindows(const BDramBlockWindowTmp& b_dram_block_window_tmp,
226 const BLdsTensorView& b_lds_block_view,
227 const BLdsLoadTileDistr&,
228 const array<index_t, 2>& offset = {0, 0}) const
229 {
230 // B DRAM tile window for load
231 auto b_copy_dram_window = CopyBDramWindow(b_dram_block_window_tmp, offset);
232
233 // TODO: Do we really need those two tile windows???
234 // They're exactly same...
235 // B LDS tile window for store
// The window shape is chosen at compile time from is_b_load_tr (branch bodies missing here).
236 auto b_lds_shape = []() {
237 if constexpr(is_b_load_tr)
239 else
241 }();
242 auto b_copy_lds_window = make_tile_window(b_lds_block_view, b_lds_shape, {0, 0});
243
// Distribution for the GEMM-side window: when is_b_load_tr is false this is a
// default-constructed BLdsLoadTileDistr; when true it is derived from
// BLdsLoadTileDistr::DstrEncode and Problem::BDataType via TransposedDstrEncode
// (the start of that expression is missing from this listing).
244 auto b_lds_load_tile_distr = []() {
245 if constexpr(is_b_load_tr)
248 typename BLdsLoadTileDistr::DstrEncode,
249 typename Problem::BDataType>::TransposedDstrEncode{});
250 else
251 return BLdsLoadTileDistr{};
252 }();
253 auto b_lds_gemm_window =
254 make_tile_window(b_lds_block_view, b_lds_shape, {0, 0}, b_lds_load_tile_distr);
255
256 return make_tuple(std::move(b_copy_dram_window),
257 std::move(b_copy_lds_window),
258 std::move(b_lds_gemm_window));
259 }
260};
261
262} // namespace ck_tile
#define CK_TILE_DEVICE
Definition config.hpp:41
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
@ offset
Definition coordinate_transform.hpp:26
CK_TILE_DEVICE auto async_load_tile(LdsTileWindow_ &&lds_tile, const TileWindow_ &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={})
Definition load_tile.hpp:119
CK_TILE_DEVICE auto tile_elementwise_in(const InElementFunc &in_element_func, const InTensor &... in_dstr_tensors)
Definition tile_elementwise.hpp:40
CK_TILE_HOST_DEVICE constexpr auto make_tensor_view(DataType *__restrict__ p, const tensor_descriptor< Ts... > &desc)
Definition tensor_view.hpp:452
TransposeTileDistributionTraits< TileDistributionEncoding_, DataType_, Policy, true > InputTileDistributionTraits
Definition load_tile_transpose.hpp:343
constant< b > bool_constant
Definition tile/core/numeric/integral_constant.hpp:43
CK_TILE_DEVICE auto load_tile_transpose(const tile_window_with_static_distribution< BottomTensorView_, WindowLengths_, TileDistribution_, NumCoord > &tile_window)
transpose loads tile from a tensor and returns the resulting tensor with a new (transposed) tile dist...
Definition load_tile_transpose.hpp:403
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_DEVICE constexpr auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition null_tile_window.hpp:75
CK_TILE_HOST_DEVICE constexpr auto generate_tuple(F &&f, number< N >)
Definition tile/core/container/tuple.hpp:429
CK_TILE_HOST_DEVICE constexpr auto integer_least_multiple(X x, Y y)
Definition tile/core/numeric/math.hpp:155
CK_TILE_DEVICE void move_tile_window(null_tile_window< WindowLengths > &, const typename null_tile_window< WindowLengths >::BottomTensorIndex &)
Definition null_tile_window.hpp:95
CK_TILE_DEVICE void store_tile(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition store_tile.hpp:23
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition tile_distribution.hpp:480
CK_TILE_DEVICE auto load_tile(const TileWindow_ &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={})
Definition load_tile.hpp:22
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
Definition gemm_pipeline_ag_bg_cr_base.hpp:13
CK_TILE_DEVICE constexpr auto GetBWindows(const BDramBlockWindowTmp &b_dram_block_window_tmp, const BLdsTensorView &b_lds_block_view, const BLdsLoadTileDistr &, const array< index_t, 2 > &offset={0, 0}) const
Definition gemm_pipeline_ag_bg_cr_base.hpp:225
remove_cvref_t< std::tuple_element_t< number< 0 >{}, BsLayout > > BLayout
Definition gemm_pipeline_ag_bg_cr_base.hpp:23
remove_cvref_t< std::tuple_element_t< number< 0 >{}, BsDataType > > BDataType
Definition gemm_pipeline_ag_bg_cr_base.hpp:22
remove_cvref_t< typename Problem::BsDataTypeTuple > BsDataType
Definition gemm_pipeline_ag_bg_cr_base.hpp:15
remove_cvref_t< typename Problem::AsDataTypeTuple > AsDataType
Definition gemm_pipeline_ag_bg_cr_base.hpp:14
CK_TILE_DEVICE auto GetABLdsTensorViews(void *p_smem) const
Definition gemm_pipeline_ag_bg_cr_base.hpp:83
CK_TILE_DEVICE void LocalPrefill(DstTileWindow &lds_tile_window, const SrcBlockTile &src_block_tile) const
Definition gemm_pipeline_ag_bg_cr_base.hpp:66
static constexpr index_t NPerBlock
Definition gemm_pipeline_ag_bg_cr_base.hpp:26
remove_cvref_t< typename Problem::AsLayoutTuple > AsLayout
Definition gemm_pipeline_ag_bg_cr_base.hpp:16
static CK_TILE_HOST_DEVICE constexpr auto TransposeC()
Definition gemm_pipeline_ag_bg_cr_base.hpp:36
remove_cvref_t< std::tuple_element_t< number< 0 >{}, AsLayout > > ALayout
Definition gemm_pipeline_ag_bg_cr_base.hpp:21
static constexpr bool is_a_load_tr
Definition gemm_pipeline_ag_bg_cr_base.hpp:32
CK_TILE_DEVICE constexpr auto CopyADramWindow(const DramBlockWindowTmp &dram_block_window_tmp, const array< index_t, 2 > &offset={0, 0}) const
Definition gemm_pipeline_ag_bg_cr_base.hpp:106
CK_TILE_DEVICE void GlobalPrefetchAsync(DstBlockWindow &dst_block_window, SrcTileWindow &dram_tile_window, const DramTileWindowStep &dram_tile_window_step) const
Definition gemm_pipeline_ag_bg_cr_base.hpp:48
static constexpr bool is_b_load_tr
Definition gemm_pipeline_ag_bg_cr_base.hpp:33
remove_cvref_t< typename Problem::BsLayoutTuple > BsLayout
Definition gemm_pipeline_ag_bg_cr_base.hpp:17
CK_TILE_DEVICE void LocalPrefetch(DstBlockTile &dst_block_tile, const SrcTileWindow &lds_tile_window, bool_constant< LoadTranspose >={}) const
Definition gemm_pipeline_ag_bg_cr_base.hpp:73
CK_TILE_DEVICE constexpr auto CopyBDramWindow(const DramBlockWindowTmp &dram_block_window_tmp, const array< index_t, 2 > &offset={0, 0}) const
Definition gemm_pipeline_ag_bg_cr_base.hpp:149
remove_cvref_t< typename Problem::BlockGemmShape > BlockGemmShape
Definition gemm_pipeline_ag_bg_cr_base.hpp:18
static constexpr index_t MPerBlock
Definition gemm_pipeline_ag_bg_cr_base.hpp:25
CK_TILE_DEVICE void LocalPrefill(DstTileWindow &lds_tile_window, const SrcBlockTile &src_block_tile, const ElementFunction &element_func) const
Definition gemm_pipeline_ag_bg_cr_base.hpp:57
CK_TILE_DEVICE constexpr auto GetAWindows(const ADramBlockWindowTmp &a_dram_block_window_tmp, const ALdsTensorView &a_lds_block_view, const ALdsLoadTileDistr &, const array< index_t, 2 > &offset={0, 0}) const
Definition gemm_pipeline_ag_bg_cr_base.hpp:190
remove_cvref_t< std::tuple_element_t< number< 0 >{}, AsDataType > > ADataType
Definition gemm_pipeline_ag_bg_cr_base.hpp:20
static constexpr index_t KPerBlock
Definition gemm_pipeline_ag_bg_cr_base.hpp:27
CK_TILE_DEVICE void GlobalPrefetch(DstBlockTile &dst_block_tile, SrcTileWindow &dram_tile_window, const DramTileWindowStep &dram_tile_window_step) const
Definition gemm_pipeline_ag_bg_cr_base.hpp:39
A fixed-size array container similar to std::array with additional utilities.
Definition tile/core/container/array.hpp:43
Definition coordinate_transform.hpp:1392