blockwise_gemm_pipeline_xdlops_v3.hpp Source File

blockwise_gemm_pipeline_xdlops_v3.hpp Source File

Composable Kernel: blockwise_gemm_pipeline_xdlops_v3.hpp Source File
blockwise_gemm_pipeline_xdlops_v3.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
7
8namespace ck {
9
10// Compute optimized pipeline
11// GlobalPrefetchStages: 2
12// LocalPreFillStages: 1
13// LocalPreFetchStages: 1
14// LocalSharedMemoryBuffer: 1
15
// Primary template of the v3 blockwise XDL GEMM pipeline. The scheduler variant
// (BlkGemmPipelineVer) comes first, followed by the same type/tile parameters that the
// specialization later in this file takes. The last parameter is spelled KPacks here and KPack
// in the specialization; template parameter names need not match, so this is cosmetic only.
// NOTE(review): original lines 36-38, which hold this template's struct declaration, are
// missing from this rendered listing.
16template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
17 index_t BlockSize,
18 typename ADataType,
19 typename BDataType,
20 typename ComputeDataType,
21 typename AccDataType,
22 typename ATileDesc,
23 typename BTileDesc,
24 typename AMmaTileDesc,
25 typename BMmaTileDesc,
26 index_t ABlockTransferSrcScalarPerVector,
27 index_t BBlockTransferSrcScalarPerVector,
28 index_t MPerBlock,
29 index_t NPerBlock,
30 index_t KPerBlock,
31 index_t MPerXDL,
32 index_t NPerXDL,
33 index_t MRepeat,
34 index_t NRepeat,
35 index_t KPacks>
39
// Partial specialization of BlockwiseGemmXdlops_pipeline_v3 for a single scheduler value.
// NOTE(review): the line carrying that scheduler argument (original line 61) is missing from
// this rendered listing. The page's cross-reference index lists
// BlockGemmPipelineScheduler::Intrawave, so presumably this is the Intrawave specialization;
// confirm against the header.
// Base is BlockwiseGemmXdlops_pipeline_base instantiated with the same template arguments
// (see the cross-reference entry for Base). The struct re-exports the Base members it uses.
// NOTE(review): in this part of the struct the listing also omits original lines 81, 102,
// 127-137, 139-140 and 145. Further lines are missing inside HotLoopScheduler() and Run().
// Statements that span those gaps are incomplete as shown.
40template <index_t BlockSize,
41 typename ADataType,
42 typename BDataType,
43 typename ComputeDataType,
44 typename AccDataType,
45 typename ATileDesc,
46 typename BTileDesc,
47 typename AMmaTileDesc,
48 typename BMmaTileDesc,
49 index_t ABlockTransferSrcScalarPerVector,
50 index_t BBlockTransferSrcScalarPerVector,
51 index_t MPerBlock,
52 index_t NPerBlock,
53 index_t KPerBlock,
54 index_t MPerXDL,
55 index_t NPerXDL,
56 index_t MRepeat,
57 index_t NRepeat,
58 index_t KPack
59 // ,bool TransposeC //disable transposec right now...
60 >
62 BlockSize,
63 ADataType,
64 BDataType,
65 ComputeDataType,
66 AccDataType,
67 ATileDesc,
68 BTileDesc,
69 AMmaTileDesc,
70 BMmaTileDesc,
71 ABlockTransferSrcScalarPerVector,
72 BBlockTransferSrcScalarPerVector,
73 MPerBlock,
74 NPerBlock,
75 KPerBlock,
76 MPerXDL,
77 NPerXDL,
78 MRepeat,
79 NRepeat,
80 KPack>
82 ADataType,
83 BDataType,
84 ComputeDataType,
85 AccDataType,
86 ATileDesc,
87 BTileDesc,
88 AMmaTileDesc,
89 BMmaTileDesc,
90 ABlockTransferSrcScalarPerVector,
91 BBlockTransferSrcScalarPerVector,
92 MPerBlock,
93 NPerBlock,
94 KPerBlock,
95 MPerXDL,
96 NPerXDL,
97 MRepeat,
98 NRepeat,
99 KPack>
100
101{
103 ADataType,
104 BDataType,
105 ComputeDataType,
106 AccDataType,
107 ATileDesc,
108 BTileDesc,
109 AMmaTileDesc,
110 BMmaTileDesc,
111 ABlockTransferSrcScalarPerVector,
112 BBlockTransferSrcScalarPerVector,
113 MPerBlock,
114 NPerBlock,
115 KPerBlock,
116 MPerXDL,
117 NPerXDL,
118 MRepeat,
119 NRepeat,
120 KPack>;
121 using Base::I0;
122 using Base::I1;
123 using Base::KRepeat;
124 using Base::xdlops_gemm;
125 using typename Base::HotLoopInstList;
126
138
141
142 using Base::AMmaKStride;
143 using Base::BMmaKStride;
144
146
// Pipeline depth constants. BlockHasHotloop() compares num_loop against PrefetchStages.
// PrefillStages and GlobalBufferNum are not read by any code visible here. Presumably the
// caller that owns this pipeline reads them; confirm.
147 static constexpr index_t PrefetchStages = 2;
148 static constexpr index_t PrefillStages = 1;
149 static constexpr index_t GlobalBufferNum = 1;
150
// Returns true when num_loop > PrefetchStages, i.e. when num_loop exceeds 2.
// Presumably callers use this value as Run()'s HasMainLoop argument; confirm at the call site.
151 __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop)
152 {
153 return num_loop > PrefetchStages;
154 }
155
// Always returns TailNumber::Full and ignores num_loop. Run() has a tail only for
// TailNumber::Full.
156 __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
157 {
158 ignore = num_loop;
159 return TailNumber::Full;
160 }
161
// Emits instruction-scheduling hints for one hot-loop iteration of Run().
// It uses __builtin_amdgcn_sched_group_barrier(mask, count, sync_id), and its only effect is on
// how the compiler orders instructions. There is no return statement.
// On gfx11 and gfx12 targets the whole body is compiled out, so no hints are emitted there.
// Group masks used below, per the inline comments: 0x008 MFMA, 0x020 VMEM read,
// 0x100 DS (LDS) read, 0x200 DS (LDS) write.
162 __device__ static constexpr auto HotLoopScheduler()
163 {
164#if !defined(__gfx11__) && !defined(__gfx12__)
165 // A/B split schedule
166 // compiler is likely to use ds_read2 when instruction width smaller than 16bytes
// NOTE(review): the remaining parts of the two conditional expressions below (original lines
// 169-170 and 173-174) are missing from this listing.
167 constexpr auto num_ds_read_inst_a =
168 HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16
171 constexpr auto num_ds_read_inst_b =
172 HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16
175
176 constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num;
177 constexpr auto num_ds_write_inst_b = HotLoopInstList::B_LDS_Write_Inst_Num;
178
179 constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num;
180 constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num;
181
182 constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num;
183 constexpr auto mfma_cycle = HotLoopInstList::C_MFMA_Inst_Cycle;
184
// Issue cost assumed per DS read: 8 cycles for a 16-byte read, 4 cycles otherwise.
185 constexpr auto ds_read_a_issue_cycle =
186 HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4;
187 constexpr auto ds_read_b_issue_cycle =
188 HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16 ? 8 : 4;
// DS reads placed next to each MFMA:
// ceil((mfma_cycle - 4) / (2 * issue_cycle)), written as integer ceiling division.
189 constexpr auto ds_read_a_mfma_rate =
190 (mfma_cycle - 4 + 2 * ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle);
191 constexpr auto ds_read_b_mfma_rate =
192 (mfma_cycle - 4 + 2 * ds_read_b_issue_cycle - 1) / (2 * ds_read_b_issue_cycle);
193
// Number of MFMA slots needed to cover all DS reads at that rate (ceiling division).
194 constexpr auto num_dsread_a_mfma =
195 (num_ds_read_inst_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate;
196 constexpr auto num_dsread_b_mfma =
197 (num_ds_read_inst_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate;
198
199 // stage 1
200 // Separate this part?
201 // constexpr auto num_mfma_per_ds_read = sizeof(ComputeDataTypeBuf) / sizeof(ADataType) >
202 // sizeof(ComputeDataTypeBuf) / sizeof(BDataType)
203 // ? sizeof(ComputeDataTypeBuf) / sizeof(ADataType)
204 // : sizeof(ComputeDataTypeBuf) / sizeof(BDataType);
// Stage 1 reserves the MFMAs that stage 2 needs for DS reads. It then splits the remaining
// MFMAs evenly over all A and B buffer loads. Each operand's DS writes are split evenly over
// that operand's buffer loads.
// NOTE(review): these divisions truncate, so leftover MFMAs or DS writes are not placed in any
// group. The code also assumes that num_buffer_load_inst_a/b are non-zero and that
// num_mfma_per_issue >= num_dswrite_per_issue_a/b, since that difference is a group size below.
205 constexpr auto num_mfma_stage1 = num_mfma_inst - (num_dsread_a_mfma + num_dsread_b_mfma);
206 constexpr auto num_mfma_per_issue =
207 num_mfma_stage1 / (num_buffer_load_inst_a + num_buffer_load_inst_b);
208 constexpr auto num_dswrite_per_issue_a = num_ds_write_inst_a / num_buffer_load_inst_a;
209 constexpr auto num_dswrite_per_issue_b = num_ds_write_inst_b / num_buffer_load_inst_b;
210
// Each A issue slot gets num_dswrite_per_issue_a (DS write, MFMA) pairs, then one VMEM read,
// then the rest of the slot's MFMAs. The loop header (original line 211) is not shown;
// presumably it iterates over the A buffer loads.
212 ignore = i;
213 static_for<0, num_dswrite_per_issue_a, 1>{}([&](auto idswrite) {
214 ignore = idswrite;
215 __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
216 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
217 });
218 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
219 __builtin_amdgcn_sched_group_barrier(
220 0x008, num_mfma_per_issue - num_dswrite_per_issue_a, 0); // MFMA
221 });
// The same pattern for B issue slots. The loop header (original line 222) is not shown.
223 ignore = i;
224 static_for<0, num_dswrite_per_issue_b, 1>{}([&](auto idswrite) {
225 ignore = idswrite;
226 __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
227 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
228 });
229 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
230 __builtin_amdgcn_sched_group_barrier(
231 0x008, num_mfma_per_issue - num_dswrite_per_issue_b, 0); // MFMA
232 });
233
234 // stage 2
// Each of the num_dsread_a_mfma slots issues a group of DS reads followed by one MFMA. The
// loop headers (original lines 235 and 251) are not shown. The B section below repeats this
// for B.
// NOTE(review): the guard tests how many reads are left *after* slot i. Whenever
// num_ds_read_inst_a is not a multiple of ds_read_a_mfma_rate, the second-to-last slot
// already takes the else-branch. Example: 10 reads at rate 3 give 4 slots sized 3, 3, 1, 1
// (8 reads) instead of 3, 3, 3, 1. Testing
// (num_ds_read_inst_a - i * ds_read_a_mfma_rate) >= ds_read_a_mfma_rate would give the
// intended split. The same applies to B. These are scheduling hints only, so this affects
// interleaving, not results.
236 if constexpr((num_ds_read_inst_a - (i + 1) * ds_read_a_mfma_rate) >=
237 ds_read_a_mfma_rate)
238 {
239 __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
240 }
241 else
242 {
243 __builtin_amdgcn_sched_group_barrier(0x100,
244 num_ds_read_inst_a - (num_dsread_a_mfma - 1) *
245 ds_read_a_mfma_rate,
246 0); // DS read
247 }
248 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
249 });
250
252 if constexpr((num_ds_read_inst_b - (i + 1) * ds_read_b_mfma_rate) >=
253 ds_read_b_mfma_rate)
254 {
255 __builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); // DS read
256 }
257 else
258 {
259 __builtin_amdgcn_sched_group_barrier(0x100,
260 num_ds_read_inst_b - (num_dsread_b_mfma - 1) *
261 ds_read_b_mfma_rate,
262 0); // DS read
263 }
264 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
265 });
266#endif
267 }
268
// Runs the whole loop for this block and accumulates xdlops_gemm products of A and B into
// c_thread_buf, which is cleared first. The visible sequence is:
//   prologue : RunRead the first source tile, RunWrite it into a_block_buf / b_block_buf
//              ("Local prefill"), RunRead the second tile, clear C, then copy the first tile
//              from the block buffers into the per-thread buffers a_thread_buf / b_thread_buf.
//   main loop: compiled only when HasMainLoop. Each iteration:
//                - RunWrites the tile previously read from the grid;
//                - RunReads the next tile;
//                - runs one MFMA pass over the per-thread buffers;
//                - refills the per-thread buffers from the block buffers.
//              It is a do/while, so it runs at least once and then repeats while
//              i < num_loop - 1.
//   tail     : compiled only when TailNum == TailNumber::Full; one more MFMA pass.
// With HasMainLoop and num_loop >= 2, that totals num_loop MFMA passes.
// Every RunRead is followed by MoveSrcSliceWindow with a_/b_block_copy_step.
// NOTE(review): without HasMainLoop only the first tile is accumulated, because the tile from
// "Global prefetch 2" is never written to the block buffers. That is correct for
// num_loop == 1 but not for num_loop == 2, which BlockHasHotloop() also maps to "no hot loop".
// Presumably callers reject num_loop <= PrefetchStages for this pipeline; confirm.
// NOTE(review): in both paths one more grid read is issued than the number of tiles
// accumulated. The last read goes one window past the final tile, and that relies on the grid
// buffer tolerating the read (e.g. bounds-checked loads); confirm.
269 template <bool HasMainLoop,
270 TailNumber TailNum,
271 typename AGridDesc,
272 typename ABlockDesc,
273 typename ABlockTransfer,
274 typename AGridBuffer,
275 typename ABlockBuffer,
276 typename ABlockTransferStep,
277 typename BGridDesc,
278 typename BBlockDesc,
279 typename BBlockTransfer,
280 typename BGridBuffer,
281 typename BBlockBuffer,
282 typename BBlockTransferStep,
283 typename CThreadBuffer>
284 __device__ void Run(const AGridDesc& a_grid_desc,
285 const ABlockDesc& a_block_desc,
286 ABlockTransfer& a_blockwise_copy,
287 const AGridBuffer& a_grid_buf,
288 ABlockBuffer& a_block_buf,
289 const ABlockTransferStep& a_block_copy_step,
290 const BGridDesc& b_grid_desc,
291 const BBlockDesc& b_block_desc,
292 BBlockTransfer& b_blockwise_copy,
293 const BGridBuffer& b_grid_buf,
294 BBlockBuffer& b_block_buf,
295 const BBlockTransferStep& b_block_copy_step,
296 CThreadBuffer& c_thread_buf,
297 index_t num_loop) const
298 {
// Scheduling barrier with mask 0: no instruction may be moved across this point.
299 __builtin_amdgcn_sched_barrier(0);
301 a_thread_desc_.GetElementSpaceSize());
303 b_thread_desc_.GetElementSpaceSize());
// NOTE(review): original lines 300 and 302 are not shown. They presumably declare
// a_thread_buf / b_thread_buf, sized by the element-space sizes above.
304
305 // Global prefetch 1
306 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
307 b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
308
309 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
310 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
311
312 // Local prefill 1
313 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
314 b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
315
316 // Global prefetch 2
317 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
318 b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
319
320 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
321 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
322
323 // Initialize C
324 c_thread_buf.Clear();
325
326 // Local prefetch 1
// For every k0, copies each m0 slice of a_block_buf and each n0 slice of b_block_buf into
// a_thread_buf / b_thread_buf at (m0 or n0, I0, k0, I0).
// NOTE(review): original lines 327, 330-331, 333, 338-339 and 341 are not shown; they hold
// part of each copy call. The cross-reference index lists block_sync_lds(), which is
// presumably used on line 327 to order this read after the prefill; confirm.
328 static_for<0, KRepeat, 1>{}([&](auto k0) {
329 static_for<0, MRepeat, 1>{}([&](auto m0) {
332 a_block_buf,
334 make_tuple(m0, I0, k0, I0),
335 a_thread_buf);
336 });
337 static_for<0, NRepeat, 1>{}([&](auto n0) {
340 b_block_buf,
342 make_tuple(n0, I0, k0, I0),
343 b_thread_buf);
344 });
345 });
346
347 __builtin_amdgcn_sched_barrier(0);
348
349 // main body
350 if constexpr(HasMainLoop)
351 {
// Iteration counter. The do/while below runs max(1, num_loop - 1) times.
352 index_t i = 0;
353 do
354 {
// NOTE(review): original line 355 is not shown. There is a single local buffer
// ("LocalSharedMemoryBuffer: 1" in the file header comment), so the RunWrite below must not
// start before the previous per-thread refill has finished reading it. Presumably line 355 is
// that barrier (block_sync_lds appears in the cross-reference index); confirm.
356
357 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
358 b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
359
360 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
361 b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
362
363 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
364 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
365
// One MFMA pass: for each (k0, m0, n0), gather KPack A and B values from the per-thread
// buffers into a_thread_vec / b_thread_vec and accumulate xdlops_gemm.Run into the C slot at
// (m0, n0). Those vectors are declared on original lines 369-370, and part of the
// mfma_input_type alias is on line 382; none of these lines are shown.
366 static_for<0, KRepeat, 1>{}([&](auto k0) {
367 static_for<0, MRepeat, 1>{}([&](auto m0) {
368 static_for<0, NRepeat, 1>{}([&](auto n0) {
371
372 static_for<0, KPack, 1>{}([&](auto ik) {
373 a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
374 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
375 make_tuple(m0, I0, k0, ik))>{}];
376 b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
377 b_thread_buf[Number<b_thread_desc_.CalculateOffset(
378 make_tuple(n0, I0, k0, ik))>{}];
379 });
380
381 using mfma_input_type =
383 xdlops_gemm.K1PerXdlops>::type;
384
385 constexpr index_t c_offset =
386 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
387
388 xdlops_gemm.Run(
389 a_thread_vec.template AsType<mfma_input_type>(),
390 b_thread_vec.template AsType<mfma_input_type>(),
391 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
392 });
393 });
394 });
395
397
// Refill the per-thread buffers from the tile written by RunWrite at the top of this
// iteration.
// NOTE(review): original line 396 is not shown; it is presumably the barrier that makes that
// write visible, so confirm. Lines 400-401, 403, 408-409 and 411 are also missing here.
398 static_for<0, KRepeat, 1>{}([&](auto k0) {
399 static_for<0, MRepeat, 1>{}([&](auto m0) {
402 a_block_buf,
404 make_tuple(m0, I0, k0, I0),
405 a_thread_buf);
406 });
407 static_for<0, NRepeat, 1>{}([&](auto n0) {
410 b_block_buf,
412 make_tuple(n0, I0, k0, I0),
413 b_thread_buf);
414 });
415 });
416
// NOTE(review): original line 417 is not shown, and HotLoopScheduler() is not called anywhere
// in the visible code. Presumably the call is on that line; confirm.
418 __builtin_amdgcn_sched_barrier(0);
419
420 i += 1;
421 } while(i < (num_loop - 1));
422 }
423 // tail
// Final MFMA pass over the per-thread buffers, i.e. the tile most recently copied from the
// block buffers. The body matches the main-loop pass; original lines 429-430 (the vector
// declarations) are not shown.
424 if constexpr(TailNum == TailNumber::Full)
425 {
426 static_for<0, KRepeat, 1>{}([&](auto k0) {
427 static_for<0, MRepeat, 1>{}([&](auto m0) {
428 static_for<0, NRepeat, 1>{}([&](auto n0) {
431
432 static_for<0, KPack, 1>{}([&](auto ik) {
433 a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
434 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
435 make_tuple(m0, I0, k0, ik))>{}];
436 b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
437 b_thread_buf[Number<b_thread_desc_.CalculateOffset(
438 make_tuple(n0, I0, k0, ik))>{}];
439 });
440
441 using mfma_input_type =
442 typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
443
444 constexpr index_t c_offset =
445 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
446
447 xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
448 b_thread_vec.template AsType<mfma_input_type>(),
449 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
450 });
451 });
452 });
453 // Let's leak last MFMA block to epilogue region, cover the potential lds-shuffle
454 // latency
455 // __builtin_amdgcn_sched_barrier(0);
456 }
457 }
458
// Base members re-exported for use in Run().
459 protected:
460 using Base::a_thread_copy_;
461 using Base::a_thread_desc_;
462 using Base::b_thread_copy_;
463 using Base::b_thread_desc_;
464 using Base::c_thread_desc_;
465};
466
467} // namespace ck
Definition ck.hpp:268
__host__ __device__ constexpr auto make_static_buffer(Number< N >)
Definition static_buffer.hpp:186
int32_t index_t
Definition ck.hpp:299
integral_constant< index_t, N > Number
Definition number.hpp:12
TailNumber
Definition blkgemmpipe_scheduler.hpp:31
@ Full
Definition blkgemmpipe_scheduler.hpp:49
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
BlockGemmPipelineScheduler
Definition blkgemmpipe_scheduler.hpp:25
@ Intrawave
Definition blkgemmpipe_scheduler.hpp:26
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__device__ void block_sync_lds()
Definition synchronization.hpp:16
__host__ __device__ BlockwiseGemmXdlops_pipeline_base(Tuple4 a_origin=CalculateAThreadOriginDataIndex(), Tuple4 b_origin=CalculateBThreadOriginDataIndex())
Constructor for BlockwiseGemmXdlops_pipeline_base.
Definition blockwise_gemm_pipeline_xdlops_base.hpp:222
__host__ static __device__ constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:280
__host__ static __device__ constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:239
static constexpr auto xdlops_gemm
Definition blockwise_gemm_pipeline_xdlops_base.hpp:54
conditional_t< std::is_same< ComputeDataType, ck::tf32_t >::value, float, ComputeDataType > ComputeDataTypeBuf
Definition blockwise_gemm_pipeline_xdlops_base.hpp:57
static constexpr BMmaTileDesc b_block_desc_n0_n1_n2_k
Definition blockwise_gemm_pipeline_xdlops_base.hpp:360
static constexpr auto I1
Definition blockwise_gemm_pipeline_xdlops_base.hpp:37
__host__ static __device__ constexpr auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:266
__host__ static __device__ constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:294
static constexpr index_t AMmaKStride
Definition blockwise_gemm_pipeline_xdlops_base.hpp:60
__host__ static __device__ constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:253
ck::BlockwiseGemmXdlops_pipeline_hotloop_inst< BlockSize, MPerBlock, NPerBlock, KPerBlock, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, A_K1, B_K1, A_K1, B_K1, MRepeat, NRepeat, MPerXDL, NPerXDL, xdlops_gemm.KPerXdlops > HotLoopInstList
Definition blockwise_gemm_pipeline_xdlops_base.hpp:82
__host__ __device__ constexpr auto & GetCThreadBuffer()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:111
static constexpr auto I0
Definition blockwise_gemm_pipeline_xdlops_base.hpp:36
static __device__ auto CalculateCThreadOriginDataIndex(Number< m0 >, Number< n0 >, Number< xdlops_i >, Number< blk_i >)
Definition blockwise_gemm_pipeline_xdlops_base.hpp:160
static __device__ auto CalculateCThreadOriginDataIndex8D(Number< m0 >, Number< n0 >, Number< xdlops_i >, Number< blk_i >)
Definition blockwise_gemm_pipeline_xdlops_base.hpp:189
static constexpr index_t KRepeat
Definition blockwise_gemm_pipeline_xdlops_base.hpp:64
static constexpr AMmaTileDesc a_block_desc_m0_m1_m2_k
Definition blockwise_gemm_pipeline_xdlops_base.hpp:359
static constexpr index_t BMmaKStride
Definition blockwise_gemm_pipeline_xdlops_base.hpp:61
__host__ static __device__ constexpr auto MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_G_M_N &c_grid_desc_g_m_n)
Definition blockwise_gemm_pipeline_xdlops_base.hpp:341
__host__ static __device__ constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:307
__host__ static __device__ constexpr auto MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N &c_grid_desc_m_n)
Definition blockwise_gemm_pipeline_xdlops_base.hpp:324
__device__ void Run(const AGridDesc &a_grid_desc, const ABlockDesc &a_block_desc, ABlockTransfer &a_blockwise_copy, const AGridBuffer &a_grid_buf, ABlockBuffer &a_block_buf, const ABlockTransferStep &a_block_copy_step, const BGridDesc &b_grid_desc, const BBlockDesc &b_block_desc, BBlockTransfer &b_blockwise_copy, const BGridBuffer &b_grid_buf, BBlockBuffer &b_block_buf, const BBlockTransferStep &b_block_copy_step, CThreadBuffer &c_thread_buf, index_t num_loop) const
Definition blockwise_gemm_pipeline_xdlops_v3.hpp:284
BlockwiseGemmXdlops_pipeline_base< BlockSize, ADataType, BDataType, ComputeDataType, AccDataType, ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack > Base
Definition blockwise_gemm_pipeline_xdlops_v3.hpp:102
Definition blockwise_gemm_pipeline_xdlops_v3.hpp:37
Definition functional2.hpp:33
Definition dtype_vector.hpp:10