blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_v1.hpp Source File

blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_v1.hpp Source File#

Composable Kernel: blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_v1.hpp Source File
blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_v1.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
7
8namespace ck {
9
10// Naive pipeline with lowest resource request per WGP
11// GlobalPrefetchStages: 2
12// LocalPreFillStages: 1
13// LocalPreFetchStages: 1
14// LocalSharedMemoryBuffer: 1
15
16template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
17 index_t ThreadBlockSize,
18 index_t ScaleBlockSize,
19 typename ADataType,
20 typename AScaleDataType,
21 typename BDataType,
22 typename BScaleDataType,
23 typename ATileDesc,
24 typename BTileDesc,
25 typename AMmaTileDesc,
26 typename BMmaTileDesc,
27 index_t ABlockTransferSrcScalarPerVector,
28 index_t BBlockTransferSrcScalarPerVector,
29 index_t MPerBlock,
30 index_t NPerBlock,
31 index_t KPerBlock,
32 index_t MPerXDL,
33 index_t NPerXDL,
34 index_t MRepeat, // MXdlPerWave
35 index_t NRepeat, // NXdlPerWave
36 index_t KPack>
40
41template <index_t ThreadBlockSize,
42 index_t ScaleBlockSize,
43 typename ADataType,
44 typename AScaleDataType,
45 typename BDataType,
46 typename BScaleDataType,
47 typename ATileDesc,
48 typename BTileDesc,
49 typename AMmaTileDesc,
50 typename BMmaTileDesc,
51 index_t ABlockTransferSrcScalarPerVector,
52 index_t BBlockTransferSrcScalarPerVector,
53 index_t MPerBlock,
54 index_t NPerBlock,
55 index_t KPerBlock,
56 index_t MPerXDL,
57 index_t NPerXDL,
58 index_t MRepeat, // MXdlPerWave
59 index_t NRepeat, // NXdlPerWave
60 index_t KPack>
62 ThreadBlockSize,
63 ScaleBlockSize,
64 ADataType,
65 AScaleDataType,
66 BDataType,
67 BScaleDataType,
68 ATileDesc,
69 BTileDesc,
70 AMmaTileDesc,
71 BMmaTileDesc,
72 ABlockTransferSrcScalarPerVector,
73 BBlockTransferSrcScalarPerVector,
74 MPerBlock,
75 NPerBlock,
76 KPerBlock,
77 MPerXDL,
78 NPerXDL,
79 MRepeat,
80 NRepeat,
81 KPack>
83 ADataType,
84 BDataType,
85 ATileDesc,
86 BTileDesc,
87 AMmaTileDesc,
88 BMmaTileDesc,
89 ABlockTransferSrcScalarPerVector,
90 BBlockTransferSrcScalarPerVector,
91 MPerBlock,
92 NPerBlock,
93 KPerBlock,
94 MPerXDL,
95 NPerXDL,
96 MRepeat,
97 NRepeat,
98 KPack>
99
100{
101
103 ADataType,
104 BDataType,
105 ATileDesc,
106 BTileDesc,
107 AMmaTileDesc,
108 BMmaTileDesc,
109 ABlockTransferSrcScalarPerVector,
110 BBlockTransferSrcScalarPerVector,
111 MPerBlock,
112 NPerBlock,
113 KPerBlock,
114 MPerXDL,
115 NPerXDL,
116 MRepeat,
117 NRepeat,
118 KPack>;
119 using Base::A_K1;
120 using Base::I0;
121 using Base::I1;
122 using Base::KRepeat;
123 using Base::MWaves;
124 using Base::NWaves;
125 using Base::WaveSize;
126 using Base::xdlops_gemm;
127 using typename Base::HotLoopInstList;
128
137 using Base::GetWaveIdx;
140
143
144 using Base::AMmaKStride;
145 using Base::APackedSize;
146 using Base::BMmaKStride;
147 using Base::BPackedSize;
148 using Base::KThreadChunk;
149
150 using Base::KXdlPack;
151 using Base::MXdlPack;
152 using Base::NXdlPack;
153
154 using AccType = typename Base::AccType;
155 using Tuple5 = typename Base::Tuple5;
158
// Compile-time pipeline parameters.
// PrefetchStages is the threshold BlockHasHotloop() compares num_loop against.
// PrefillStages, GlobalBufferNum and HotloopLocalBufSwitch are not referenced in
// the visible part of this listing.
159 static constexpr index_t PrefetchStages = 2;
160 static constexpr index_t PrefillStages = 1;
161 static constexpr index_t GlobalBufferNum = 2;
162 static constexpr index_t HotloopLocalBufSwitch = MRepeat % 2 == 0 ? 0 : 1;
163
// Evaluated left to right: ((XRepeat / XXdlPack) * KRepeat) / KXdlPack.
// Presumably the number of scale loads issued per prefetch: Run() calls
// a_scale_thread_copy.Run / b_scale_thread_copy.Run once per iteration of
// static_for<0, XRepeat / XXdlPack> x static_for<0, KRepeat / KXdlPack>.
164 static constexpr auto num_buffer_load_a_scale = MRepeat / MXdlPack * KRepeat / KXdlPack;
165 static constexpr auto num_buffer_load_b_scale = NRepeat / NXdlPack * KRepeat / KXdlPack;
// NOTE(review): the initializer of async_vmcnt (listing line 167) is missing from
// this extract.
166 static constexpr auto async_vmcnt =
// Immediate passed to __builtin_amdgcn_s_waitcnt() in Run().
// 3952 == 0xF70; async_vmcnt % 16 fills bits [3:0]; async_vmcnt / 16 * 16384
// places the remaining bits at bit 14 and above.
// NOTE(review): presumably the s_waitcnt layout with a split vmcnt field and the
// other counters left at their maximum - confirm against the target ISA guide.
168 static constexpr auto async_vmcnt_encoding = 3952 + async_vmcnt % 16 + async_vmcnt / 16 * 16384;
169
170 static constexpr auto ScalesPerKBlockSize =
171 KPerBlock / ScaleBlockSize; // How many mx-vectors per K block
172
173 //> How many mx-vectors in each row/col are processed in one call to xdlops_gemm.Run()
174 static constexpr auto ScalesPerXdlopsRun =
175 (APackedSize * KPack * xdlops_gemm.K0PerXdlops) / ScaleBlockSize;
176
177 //> How many scales a thread must read to accommodate one call to xdlops_gemm.Run()
178 static constexpr auto ScalesPerXdlopsRunPerThread =
179 ScalesPerXdlopsRun / xdlops_gemm.mfma_instr.num_input_blks;
180
// Number of mx_scale_t elements that fit in one AScaleDataType / BScaleDataType.
// NOTE(review): listing line 181 (presumably the declaration of mx_scale_t) is
// missing from this extract.
182 static constexpr auto scale_pack_size_a = sizeof(AScaleDataType) / sizeof(mx_scale_t);
183 static constexpr auto scale_pack_size_b = sizeof(BScaleDataType) / sizeof(mx_scale_t);
// The K x M (resp. K x N) xdl pack must be a whole multiple of the packed scale width.
184 static_assert(KXdlPack * MXdlPack % scale_pack_size_a == 0,
185 "A scale pack data type too large!");
186 static_assert(KXdlPack * NXdlPack % scale_pack_size_b == 0,
187 "B scale pack data type too large!");
190
191 __host__ static constexpr bool BlockHasHotloop(index_t num_loop)
192 {
193 return num_loop > PrefetchStages;
194 }
195
196 __host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
197 {
198 return num_loop % 2 == 0 ? TailNumber::Even : TailNumber::Odd;
199 }
200
// Emits a fixed sequence of __builtin_amdgcn_sched_group_barrier hints in three
// sections (B global, A global, A local). Each hint names an instruction class by
// mask and a count; per the inline comments in this function the masks are
// 0x008 = MFMA, 0x020 = VMEM read, 0x100 = DS read, 0x200 = DS write.
// The function takes no arguments and only uses compile-time constants.
// NOTE(review): listing lines 206, 209 and 223 (the tail of the
// num_buffer_load_inst_b expression and the static_for headers of the first two
// sections) are missing from this extract.
201 __device__ static constexpr auto HotLoopScheduler()
202 {
// Instruction counts taken from the hot-loop instruction list of the base class.
203 constexpr auto num_ds_read_inst_a = HotLoopInstList::A_LDS_Read_Inst_Num;
204 constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num;
205 constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num * MWaves +
// Base number of MFMAs grouped ahead of each VMEM read in the B section.
207 constexpr auto mfma_interleave = MPerXDL == 32 ? 1 : 2;
208 // B global
// Per iteration: a group of MFMAs (doubled when both MPerBlock and NPerBlock are
// at least 128) followed by one VMEM read. The loop index is unused.
210 ignore = i;
211 if constexpr(MPerBlock >= 128 && NPerBlock >= 128)
212 {
213 __builtin_amdgcn_sched_group_barrier(0x008, 2 * mfma_interleave, 0);
214 }
215 else
216 {
217 __builtin_amdgcn_sched_group_barrier(0x008, mfma_interleave, 0);
218 }
219 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
220 });
221
222 // A global
// Per iteration: MFMA, DS write, MFMA, VMEM read. The loop index is unused.
224 ignore = i;
225 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
226 __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
227 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
228 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
229 });
230
231 // A local
// Per iteration: one MFMA followed by DS reads. With MPerXDL == 32 the trip count
// is halved and two DS reads are grouped per MFMA instead of one.
232 static_for<0, MPerXDL == 32 ? num_ds_read_inst_a / 2 : num_ds_read_inst_a, 1>{}(
233 [&](auto i) {
234 ignore = i;
235 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
236 __builtin_amdgcn_sched_group_barrier(0x100, MPerXDL == 32 ? 2 : 1, 0); // DS read
237 });
238 }
239
// Runs the pipelined, scaled block GEMM over num_loop K iterations and
// accumulates into c_thread_buf.
//
// Structure visible in this listing:
//   1. Prologue: load the first A tile into a_block_buf and the first B tile
//      directly into b_thread_bufs(I0), prefetch the A/B scales into slot I0 of
//      the per-thread scale buffers, wait on the outstanding loads, read A from
//      a_block_buf into a_thread_buf, clear C.
//   2. Main loop (only when HasMainLoop): each do-while pass calls LoopFunc twice
//      with the compute/memory slots swapped ((I0, I1) then (I1, I0)) and
//      advances i by 2 while i < num_loop - 2.
//   3. Tail: TailNumber::Even loads one more tile into slot I1 and computes
//      slots I0 then I1; TailNumber::Odd computes slot I0 only.
//
// Template parameters:
//   HasMainLoop - compile-time switch for section 2.
//   TailNum     - selects the Even or Odd tail; any other value runs no tail.
// Parameters:
//   a_* / b_*          - descriptors, copy objects, buffers and per-iteration
//                        window steps for the A and B tiles.
//   b_block_bufs       - unused (explicitly ignored); B is loaded straight into
//                        per-thread buffers.
//   c_thread_buf       - accumulator; cleared here before the first MFMA.
//   a_scale_* / b_scale_* - descriptors, copy objects and buffers for the scales.
//   num_loop           - number of K iterations.
//
// NOTE(review): this extract lacks several listing lines (for example 289, 291,
// 297, 299, 359, 368-380, 403, 472-491, 534, 563, 580, 689), so the declarations
// of a_thread_buf, b_thread_buf, the scale thread buffers and vectors, and the
// heads of some calls are not visible below.
240 template <bool HasMainLoop,
241 TailNumber TailNum,
242 typename AGridDesc,
243 typename ABlockDesc,
244 typename ABlockTransfer,
245 typename AGridBuffer,
246 typename ABlockBuffer,
247 typename ABlockTransferStep,
248 typename BGridDesc,
249 typename BBlockDesc,
250 typename BBlockTransfer,
251 typename BGridBuffer,
252 typename BBlockBuffer,
253 typename BBlockTransferStep,
254 typename CThreadBuffer,
255 typename AScaleGridBuffer,
256 typename AScaleGridDesc,
257 typename AScaleThreadTransfer,
258 typename BScaleGridBuffer,
259 typename BScaleGridDesc,
260 typename BScaleThreadTransfer>
261 __device__ void Run(
262 // ABlockCopy
263 const AGridDesc& a_grid_desc,
264 const ABlockDesc& a_block_desc,
265 ABlockTransfer& a_blockwise_copy,
266 const AGridBuffer& a_grid_buf,
267 ABlockBuffer& a_block_buf,
268 const ABlockTransferStep& a_block_copy_step,
269 // BBlockCopy
270 const BGridDesc& b_grid_desc,
271 const BBlockDesc& b_block_desc,
272 BBlockTransfer& b_blockwise_copy,
273 const BGridBuffer& b_grid_buf,
274 BBlockBuffer& b_block_bufs,
275 const BBlockTransferStep& b_block_copy_step,
276 // CThread
277 CThreadBuffer& c_thread_buf,
278 // A and B scales
279 const AScaleGridDesc& a_scale_grid_desc,
280 AScaleThreadTransfer& a_scale_thread_copy,
281 const AScaleGridBuffer& a_scale_grid_buf,
282 const BScaleGridDesc& b_scale_grid_desc,
283 BScaleThreadTransfer& b_scale_thread_copy,
284 const BScaleGridBuffer& b_scale_grid_buf,
285 index_t num_loop) const
286 {
// b_block_bufs is part of the interface but is never read or written here.
287 ignore = b_block_bufs;
288 __builtin_amdgcn_sched_barrier(0);
290 a_thread_desc_.GetElementSpaceSize());
292 b_thread_desc_.GetElementSpaceSize());
293
// Two B thread buffers, indexed by I0 / I1 below.
294 StaticallyIndexedArray<decltype(b_thread_buf), Number<2>{}> b_thread_bufs;
295 constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0, I0);
296
298 a_scale_thread_desc.GetElementSpaceSize());
300 b_scale_thread_desc.GetElementSpaceSize());
301
// Two scale buffers per operand, indexed the same way as b_thread_bufs.
302 StaticallyIndexedArray<decltype(a_scale_thread_buf), Number<2>{}> a_scale_thread_bufs;
303 StaticallyIndexedArray<decltype(b_scale_thread_buf), Number<2>{}> b_scale_thread_bufs;
304
305 // Global prefetch 1
// A goes to a_block_buf, B goes to thread buffer slot I0; both source windows
// are then advanced by one step.
306 a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_buf);
307 b_blockwise_copy.Run(
308 b_grid_desc, b_grid_buf, b_block_desc, b_block_origin_idx, b_thread_bufs(I0));
309
310 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
311 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
312 __builtin_amdgcn_sched_barrier(0);
313
314 // Prefetch a_scales
// The inner loop steps the second window index by 1 per load; after it the window
// moves by MWaves in the first index and rewinds the second index.
315 static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) {
316 static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
317 a_scale_thread_copy.Run(a_scale_grid_desc,
318 a_scale_grid_buf,
320 make_tuple(m0, k0, I0),
321 a_scale_thread_bufs(I0));
322
323 a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
324 make_multi_index(0, I1, 0));
325 });
326 a_scale_thread_copy.MoveSrcSliceWindow(
327 a_scale_grid_desc, make_multi_index(MWaves, -KRepeat / KXdlPack, 0));
328 });
329
330 // restore row id and advance to the next set of scales
331 a_scale_thread_copy.MoveSrcSliceWindow(
332 a_scale_grid_desc,
333 make_multi_index(-MWaves * MRepeat / MXdlPack, KRepeat / KXdlPack, 0));
334
335 // Prefetch b_scales
// Same traversal as for the A scales, with NRepeat / NXdlPack / NWaves.
336 static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) {
337 static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
338 b_scale_thread_copy.Run(b_scale_grid_desc,
339 b_scale_grid_buf,
341 make_tuple(n0, k0, I0),
342 b_scale_thread_bufs(I0));
343
344 b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
345 make_multi_index(0, I1, 0));
346 });
347 b_scale_thread_copy.MoveSrcSliceWindow(
348 b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0));
349 });
350
351 // restore col id and advance to the next set of scales
352 // NWaves * NPerXDL * NRepeat == NPerBlock
353 b_scale_thread_copy.MoveSrcSliceWindow(
354 b_scale_grid_desc,
355 make_multi_index(-NWaves * NRepeat / NXdlPack, KRepeat / KXdlPack, 0));
356
357 // Local prefetch 1, sync the async load
358 __builtin_amdgcn_s_waitcnt(async_vmcnt_encoding);
// Read A from a_block_buf into a_thread_buf for every (m0, k, chunk).
// NOTE(review): the call head and most arguments (listing lines 368-380) are
// missing from this extract.
360 static_for<0, MRepeat, 1>{}([&](auto m0) {
361 static_for<0, KRepeat, 1>{}([&](auto k) {
362 constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize *
363 (APackedSize * KPack / xdlops_gemm.K1PerXdlops);
364 static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}(
365 [&](auto chunk) {
366 constexpr auto a_k_step_chunk =
367 k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
370 I0,
372 I0,
374 a_block_buf,
377 I0,
379 k,
381 a_thread_buf);
382 });
383 });
384 });
385
386 // Initialize C
387 c_thread_buf.Clear();
388 __builtin_amdgcn_sched_barrier(0);
389 // main body
390 if constexpr(HasMainLoop)
391 {
392 // loop over k with the step KPerBlock
393 index_t i = 0;
394 do
395 {
// One pipeline step: load the next tile and scales into slot scale_mem_buf while
// the MFMAs consume B and the scales from slot scale_comp_buf.
396 auto LoopFunc = [&](auto scale_comp_buf, auto scale_mem_buf) {
397 b_blockwise_copy.Run(b_grid_desc,
398 b_grid_buf,
399 b_block_desc,
400 b_block_origin_idx,
401 b_thread_bufs(scale_mem_buf));
402
404 a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_buf);
405 // Prefetch a_scales
406 static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) {
407 static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
408 a_scale_thread_copy.Run(a_scale_grid_desc,
409 a_scale_grid_buf,
411 make_tuple(m0, k0, I0),
412 a_scale_thread_bufs(scale_mem_buf));
413
414 a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
415 make_multi_index(0, I1, 0));
416 });
417 a_scale_thread_copy.MoveSrcSliceWindow(
418 a_scale_grid_desc, make_multi_index(MWaves, -KRepeat / KXdlPack, 0));
419 });
420
421 // restore row id and advance to the next set of scales
422 a_scale_thread_copy.MoveSrcSliceWindow(
423 a_scale_grid_desc,
424 make_multi_index(-MWaves * MRepeat / MXdlPack, KRepeat / KXdlPack, 0));
425
426 // Prefetch b_scales
427 static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) {
428 static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
429 b_scale_thread_copy.Run(b_scale_grid_desc,
430 b_scale_grid_buf,
432 make_tuple(n0, k0, I0),
433 b_scale_thread_bufs(scale_mem_buf));
434
435 b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
436 make_multi_index(0, I1, 0));
437 });
438 b_scale_thread_copy.MoveSrcSliceWindow(
439 b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0));
440 });
441
442 // restore col id and advance to the next set of scales
443 // NWaves * NPerXDL * NRepeat == NPerBlock
444 b_scale_thread_copy.MoveSrcSliceWindow(
445 b_scale_grid_desc,
446 make_multi_index(-NWaves * NRepeat / NXdlPack, KRepeat / KXdlPack, 0));
447
448 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
449 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
450
// Compute on slot scale_comp_buf. Each repeat index is split into a major part
// (index / XdlPack) and a minor part (index % XdlPack).
451 static_for<0, MRepeat, 1>{}([&](auto m0) {
452 constexpr auto im_major = m0 / MXdlPack;
453 constexpr auto im_minor = m0 % MXdlPack;
454 static_for<0, KRepeat, 1>{}([&](auto k0) {
455 constexpr auto ik_major = k0 / KXdlPack;
456 constexpr auto ik_minor = k0 % KXdlPack;
457 static_for<0, NRepeat, 1>{}([&](auto n0) {
458 constexpr auto in_major = n0 / NXdlPack;
459 constexpr auto in_minor = n0 % NXdlPack;
460
// Scale offsets depend only on the major indices.
461 constexpr index_t a_scale_offset =
462 a_scale_thread_desc.CalculateOffset(
463 make_tuple(im_major, ik_major, I0));
464 constexpr index_t b_scale_offset =
465 b_scale_thread_desc.CalculateOffset(
466 make_tuple(in_major, ik_major, I0));
467
468 static_assert(0 < ScalesPerXdlopsRunPerThread,
469 "Must have at least one scale per Xdlops "
470 "per Thread.");
471
473 a_scale_thread_vec;
475 b_scale_thread_vec;
476
477 // Pack scale_thread_buf into scale_thread_vec
479 a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
480 a_scale_thread_bufs(
481 scale_comp_buf)[Number<a_scale_offset + s>{}];
482 });
483
485 b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
486 b_scale_thread_bufs(
487 scale_comp_buf)[Number<b_scale_offset + s>{}];
488 });
489
492
// Gather KPack elements of A (single thread buffer) and of B (slot
// scale_comp_buf) into the MFMA input vectors.
493 static_for<0, KPack, 1>{}([&](auto ik) {
494 a_thread_vec.template AsType<ComputeTypeA>()(ik) =
495 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
496 make_tuple(im_major, I0, im_minor, k0, ik))>{}];
497 b_thread_vec.template AsType<ComputeTypeB>()(ik) = b_thread_bufs
498 [scale_comp_buf][Number<b_thread_desc_.CalculateOffset(
499 make_tuple(in_major, I0, in_minor, k0, ik))>{}];
500 });
501
502 using mfma_input_type_a =
503 typename vector_type<ComputeTypeA,
504 xdlops_gemm.K1PerXdlops /
505 APackedSize>::type;
506
507 using mfma_input_type_b =
508 typename vector_type<ComputeTypeB,
509 xdlops_gemm.K1PerXdlops /
510 BPackedSize>::type;
511
512 using mfma_scale_input_type_a =
513 typename vector_type<AScaleDataType,
515 using mfma_scale_input_type_b =
516 typename vector_type<BScaleDataType,
518
519 constexpr index_t c_offset = c_thread_desc_.CalculateOffset(
520 make_tuple(im_major, in_major, im_minor, in_minor, 0));
521
522 // MFMA accumulation
// The two template arguments are built from the minor indices of the packed
// scale vectors.
523 xdlops_gemm.template Run<ik_minor * MXdlPack + im_minor,
524 ik_minor * NXdlPack + in_minor>(
525 a_thread_vec.template AsType<mfma_input_type_a>(),
526 a_scale_thread_vec.template AsType<mfma_scale_input_type_a>(),
527 b_thread_vec.template AsType<mfma_input_type_b>(),
528 b_scale_thread_vec.template AsType<mfma_scale_input_type_b>(),
529 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
530 });
531 });
532 });
533
535
// Refill a_thread_buf from a_block_buf for the next pipeline step.
536 static_for<0, MRepeat, 1>{}([&](auto m0) {
537 static_for<0, KRepeat, 1>{}([&](auto k) {
538 constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize *
539 (APackedSize * KPack / xdlops_gemm.K1PerXdlops);
540 static_for<0,
541 xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk),
542 1>{}([&](auto chunk) {
543 constexpr auto a_k_step_chunk =
544 k_step +
545 chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
548 I0,
550 I0,
552 a_block_buf,
555 I0,
557 k,
559 a_thread_buf);
560 });
561 });
562 });
564 __builtin_amdgcn_sched_barrier(0);
565 };
566
// Two steps per pass with the compute/memory slots swapped.
567 LoopFunc(I0, I1);
568 LoopFunc(I1, I0);
569
570 i += 2;
571 } while(i < (num_loop - 2));
572 }
573
574 // tail
575 if constexpr(TailNum == TailNumber::Even)
576 {
// Even tail: fetch one last tile and its scales into slot I1, compute slot I0,
// then wait, reload a_thread_buf and compute slot I1. The scale windows are not
// restored after these final prefetches.
577 b_blockwise_copy.Run(
578 b_grid_desc, b_grid_buf, b_block_desc, b_block_origin_idx, b_thread_bufs(I1));
579
581 a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_buf);
582 // Prefetch a_scales
583 static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) {
584 static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
585 a_scale_thread_copy.Run(a_scale_grid_desc,
586 a_scale_grid_buf,
588 make_tuple(m0, k0, I0),
589 a_scale_thread_bufs(I1));
590
591 a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
592 make_multi_index(0, I1, 0));
593 });
594 a_scale_thread_copy.MoveSrcSliceWindow(
595 a_scale_grid_desc, make_multi_index(MWaves, -KRepeat / KXdlPack, 0));
596 });
597
598 // Prefetch b_scales
599 static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) {
600 static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
601 b_scale_thread_copy.Run(b_scale_grid_desc,
602 b_scale_grid_buf,
604 make_tuple(n0, k0, I0),
605 b_scale_thread_bufs(I1));
606
607 b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
608 make_multi_index(0, I1, 0));
609 });
610 b_scale_thread_copy.MoveSrcSliceWindow(
611 b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0));
612 });
613
// Compute on slot I0.
614 static_for<0, MRepeat, 1>{}([&](auto m0) {
615 constexpr auto im_major = m0 / MXdlPack;
616 constexpr auto im_minor = m0 % MXdlPack;
617 static_for<0, KRepeat, 1>{}([&](auto k0) {
618 constexpr auto ik_major = k0 / KXdlPack;
619 constexpr auto ik_minor = k0 % KXdlPack;
620 static_for<0, NRepeat, 1>{}([&](auto n0) {
621 constexpr auto in_major = n0 / NXdlPack;
622 constexpr auto in_minor = n0 % NXdlPack;
623
624 constexpr index_t a_scale_offset =
625 a_scale_thread_desc.CalculateOffset(make_tuple(im_major, ik_major, I0));
626 constexpr index_t b_scale_offset =
627 b_scale_thread_desc.CalculateOffset(make_tuple(in_major, ik_major, I0));
628
629 static_assert(0 < ScalesPerXdlopsRunPerThread,
630 "Must have at least one scale per Xdlops "
631 "per Thread.");
632
635
636 // Pack scale_thread_buf into scale_thread_vec
638 a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
639 a_scale_thread_bufs(I0)[Number<a_scale_offset + s>{}];
640 });
641
643 b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
644 b_scale_thread_bufs(I0)[Number<b_scale_offset + s>{}];
645 });
646
649
650 static_for<0, KPack, 1>{}([&](auto ik) {
651 a_thread_vec.template AsType<ComputeTypeA>()(ik) =
652 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
653 make_tuple(im_major, I0, im_minor, k0, ik))>{}];
654 b_thread_vec.template AsType<ComputeTypeB>()(ik) =
655 b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
656 make_tuple(in_major, I0, in_minor, k0, ik))>{}];
657 });
658
659 using mfma_input_type_a =
660 typename vector_type<ComputeTypeA,
661 xdlops_gemm.K1PerXdlops / APackedSize>::type;
662
663 using mfma_input_type_b =
664 typename vector_type<ComputeTypeB,
665 xdlops_gemm.K1PerXdlops / BPackedSize>::type;
666
667 using mfma_scale_input_type_a =
669 using mfma_scale_input_type_b =
671
672 constexpr index_t c_offset = c_thread_desc_.CalculateOffset(
673 make_tuple(im_major, in_major, im_minor, in_minor, 0));
674
675 // MFMA accumulation
676 xdlops_gemm.template Run<ik_minor * MXdlPack + im_minor,
677 ik_minor * NXdlPack + in_minor>(
678 a_thread_vec.template AsType<mfma_input_type_a>(),
679 a_scale_thread_vec.template AsType<mfma_scale_input_type_a>(),
680 b_thread_vec.template AsType<mfma_input_type_b>(),
681 b_scale_thread_vec.template AsType<mfma_scale_input_type_b>(),
682 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
683 });
684 });
685
// NOTE(review): commented-out leftover; SwitchM is not defined anywhere in the
// visible listing.
686 // constexpr auto lds_buf = m0.value >= SwitchM ? I1 : I0;
687 });
// Wait for the last loads before reading A from a_block_buf again.
688 __builtin_amdgcn_s_waitcnt(async_vmcnt_encoding);
690
691 static_for<0, MRepeat, 1>{}([&](auto m0) {
692 static_for<0, KRepeat, 1>{}([&](auto k) {
693 constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize *
694 (APackedSize * KPack / xdlops_gemm.K1PerXdlops);
695 static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}(
696 [&](auto chunk) {
697 constexpr auto a_k_step_chunk =
698 k_step +
699 chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
702 I0,
704 I0,
706 a_block_buf,
709 I0,
711 k,
713 a_thread_buf);
714 });
715 });
716 });
717 __builtin_amdgcn_sched_barrier(0);
718
// Compute on slot I1 (the tile fetched at the top of this tail).
719 static_for<0, MRepeat, 1>{}([&](auto m0) {
720 constexpr auto im_major = m0 / MXdlPack;
721 constexpr auto im_minor = m0 % MXdlPack;
722 static_for<0, KRepeat, 1>{}([&](auto k0) {
723 constexpr auto ik_major = k0 / KXdlPack;
724 constexpr auto ik_minor = k0 % KXdlPack;
725 static_for<0, NRepeat, 1>{}([&](auto n0) {
726 constexpr auto in_major = n0 / NXdlPack;
727 constexpr auto in_minor = n0 % NXdlPack;
728
729 constexpr index_t a_scale_offset =
730 a_scale_thread_desc.CalculateOffset(make_tuple(im_major, ik_major, I0));
731 constexpr index_t b_scale_offset =
732 b_scale_thread_desc.CalculateOffset(make_tuple(in_major, ik_major, I0));
733
734 static_assert(0 < ScalesPerXdlopsRunPerThread,
735 "Must have at least one scale per Xdlops "
736 "per Thread.");
737
740
741 // Pack scale_thread_buf into scale_thread_vec
743 a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
744 a_scale_thread_bufs(I1)[Number<a_scale_offset + s>{}];
745 });
746
748 b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
749 b_scale_thread_bufs(I1)[Number<b_scale_offset + s>{}];
750 });
751
754
755 static_for<0, KPack, 1>{}([&](auto ik) {
756 a_thread_vec.template AsType<ComputeTypeA>()(ik) =
757 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
758 make_tuple(im_major, I0, im_minor, k0, ik))>{}];
759 b_thread_vec.template AsType<ComputeTypeB>()(ik) =
760 b_thread_bufs[I1][Number<b_thread_desc_.CalculateOffset(
761 make_tuple(in_major, I0, in_minor, k0, ik))>{}];
762 });
763
764 using mfma_input_type_a =
765 typename vector_type<ComputeTypeA,
766 xdlops_gemm.K1PerXdlops / APackedSize>::type;
767
768 using mfma_input_type_b =
769 typename vector_type<ComputeTypeB,
770 xdlops_gemm.K1PerXdlops / BPackedSize>::type;
771
772 using mfma_scale_input_type_a =
774 using mfma_scale_input_type_b =
776
777 constexpr index_t c_offset = c_thread_desc_.CalculateOffset(
778 make_tuple(im_major, in_major, im_minor, in_minor, 0));
779
780 // MFMA accumulation
781 xdlops_gemm.template Run<ik_minor * MXdlPack + im_minor,
782 ik_minor * NXdlPack + in_minor>(
783 a_thread_vec.template AsType<mfma_input_type_a>(),
784 a_scale_thread_vec.template AsType<mfma_scale_input_type_a>(),
785 b_thread_vec.template AsType<mfma_input_type_b>(),
786 b_scale_thread_vec.template AsType<mfma_scale_input_type_b>(),
787 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
788 });
789 });
790 });
791 }
792 else if constexpr(TailNum == TailNumber::Odd)
793 {
// Odd tail: no further loads; compute the tile already held in slot I0.
794 static_for<0, MRepeat, 1>{}([&](auto m0) {
795 constexpr auto im_major = m0 / MXdlPack;
796 constexpr auto im_minor = m0 % MXdlPack;
797 static_for<0, KRepeat, 1>{}([&](auto k0) {
798 constexpr auto ik_major = k0 / KXdlPack;
799 constexpr auto ik_minor = k0 % KXdlPack;
800 static_for<0, NRepeat, 1>{}([&](auto n0) {
801 constexpr auto in_major = n0 / NXdlPack;
802 constexpr auto in_minor = n0 % NXdlPack;
803
804 constexpr index_t a_scale_offset =
805 a_scale_thread_desc.CalculateOffset(make_tuple(im_major, ik_major, I0));
806 constexpr index_t b_scale_offset =
807 b_scale_thread_desc.CalculateOffset(make_tuple(in_major, ik_major, I0));
808
809 static_assert(0 < ScalesPerXdlopsRunPerThread,
810 "Must have at least one scale per Xdlops "
811 "per Thread.");
812
815
816 // Pack scale_thread_buf into scale_thread_vec
818 a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
819 a_scale_thread_bufs(I0)[Number<a_scale_offset + s>{}];
820 });
821
823 b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
824 b_scale_thread_bufs(I0)[Number<b_scale_offset + s>{}];
825 });
826
829
830 static_for<0, KPack, 1>{}([&](auto ik) {
831 a_thread_vec.template AsType<ComputeTypeA>()(ik) =
832 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
833 make_tuple(im_major, I0, im_minor, k0, ik))>{}];
834 b_thread_vec.template AsType<ComputeTypeB>()(ik) =
835 b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
836 make_tuple(in_major, I0, in_minor, k0, ik))>{}];
837 });
838
839 using mfma_input_type_a =
840 typename vector_type<ComputeTypeA,
841 xdlops_gemm.K1PerXdlops / APackedSize>::type;
842
843 using mfma_input_type_b =
844 typename vector_type<ComputeTypeB,
845 xdlops_gemm.K1PerXdlops / BPackedSize>::type;
846
847 using mfma_scale_input_type_a =
849 using mfma_scale_input_type_b =
851
852 constexpr index_t c_offset = c_thread_desc_.CalculateOffset(
853 make_tuple(im_major, in_major, im_minor, in_minor, 0));
854
855 // MFMA accumulation
856 xdlops_gemm.template Run<ik_minor * MXdlPack + im_minor,
857 ik_minor * NXdlPack + in_minor>(
858 a_thread_vec.template AsType<mfma_input_type_a>(),
859 a_scale_thread_vec.template AsType<mfma_scale_input_type_a>(),
860 b_thread_vec.template AsType<mfma_input_type_b>(),
861 b_scale_thread_vec.template AsType<mfma_scale_input_type_b>(),
862 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
863 });
864 });
865 });
866 }
867 }
868
869 // TODO: make this field protected when a_scale_thread_copy_ is moved
870 // here
873 Number<KRepeat / KXdlPack>{},
875
876 // TODO: make this field protected when b_scale_thread_copy_ is moved
877 // here
880 Number<KRepeat / KXdlPack>{},
882
883 protected:
884 using Base::a_thread_copy_;
885 using Base::a_thread_desc_;
886 using Base::b_thread_copy_;
887 using Base::b_thread_desc_;
888 using Base::c_thread_desc_;
889};
890
891} // namespace ck
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
__host__ __device__ constexpr auto make_static_buffer(Number< N >)
Definition static_buffer.hpp:186
typename detail::StaticallyIndexedArrayImpl< T, N >::type StaticallyIndexedArray
Definition utility/statically_indexed_array.hpp:45
int32_t index_t
Definition ck.hpp:299
integral_constant< index_t, N > Number
Definition number.hpp:12
TailNumber
Definition blkgemmpipe_scheduler.hpp:31
@ Even
Definition blkgemmpipe_scheduler.hpp:34
@ Odd
Definition blkgemmpipe_scheduler.hpp:33
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
BlockGemmPipelineScheduler
Definition blkgemmpipe_scheduler.hpp:25
@ Intrawave
Definition blkgemmpipe_scheduler.hpp:26
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__device__ void block_sync_lds()
Definition synchronization.hpp:16
ck::BlockwiseGemmXdlops_pipeline_hotloop_inst< BlockSize, MPerBlock, NPerBlock, KPerBlock, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, A_K1, B_K1, A_K1, B_K1, MRepeat, NRepeat, MPerXDL, NPerXDL, xdlops_gemm.KPerXdlops,(packed_size_v< ComputeTypeA > > 1||packed_size_v< ComputeTypeB > > 1)> HotLoopInstList
Definition blockwise_gemm_mx_pipeline_xdlops_base.hpp:88
__host__ __device__ BlockwiseGemmXdlops_mx_pipeline_base(Tuple5 a_origin=CalculateAThreadOriginDataIndex(), Tuple5 b_origin=CalculateBThreadOriginDataIndex())
Definition blockwise_gemm_mx_pipeline_xdlops_base.hpp:204
__device__ void Run(const AGridDesc &a_grid_desc, const ABlockDesc &a_block_desc, ABlockTransfer &a_blockwise_copy, const AGridBuffer &a_grid_buf, ABlockBuffer &a_block_buf, const ABlockTransferStep &a_block_copy_step, const BGridDesc &b_grid_desc, const BBlockDesc &b_block_desc, BBlockTransfer &b_blockwise_copy, const BGridBuffer &b_grid_buf, BBlockBuffer &b_block_bufs, const BBlockTransferStep &b_block_copy_step, CThreadBuffer &c_thread_buf, const AScaleGridDesc &a_scale_grid_desc, AScaleThreadTransfer &a_scale_thread_copy, const AScaleGridBuffer &a_scale_grid_buf, const BScaleGridDesc &b_scale_grid_desc, BScaleThreadTransfer &b_scale_thread_copy, const BScaleGridBuffer &b_scale_grid_buf, index_t num_loop) const
Definition blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_v1.hpp:261
BlockwiseGemmXdlops_mx_pipeline_base< ThreadBlockSize, ADataType, BDataType, ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack > Base
Definition blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_v1.hpp:102
Definition blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_v1.hpp:38
Unsigned representation of a conventional biased Float32 exponent.
Definition utility/e8m0.hpp:26
Definition functional2.hpp:33
Definition dtype_vector.hpp:10