mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp Source File

mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp Source File#

Composable Kernel: mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp Source File
mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
11
12namespace ck_tile {
13
14template <typename ADataType_,
15 typename BDataType_,
16 typename CDataType_,
17 typename BlockGemmShape_,
18 typename Traits_,
20 bool HasHotLoop_ = true,
22 typename ComputeDataType_ = ADataType_>
24 ADataType_,
25 CDataType_,
26 BlockGemmShape_,
27 Traits_,
28 Scheduler_,
29 HasHotLoop_,
30 TailNum_,
31 ComputeDataType_>
32{
33 using BlockGemmShape = BlockGemmShape_;
34
35 // using QuantType = BDataType_;
36
37 static constexpr index_t flatNPerWarp = BlockGemmShape::flatNPerWarp;
38
39 static constexpr int ScaleGranularityK = 32;
40
41 static constexpr int ContinuousKPerThread = 32; // it's fixed for fp4
42 static constexpr int MXdlPack = 2; // it's fixed for fp4
43 static constexpr int NXdlPack = 2; // it's fixed for fp4
44 static constexpr int KXdlPack = 2;
45 // static constexpr index_t flatKPerWarp = BlockGemmShape::flatKPerWarp * KXdlPack;
46 static constexpr index_t flatKPerWarp = 64 * ContinuousKPerThread;
47};
48
49template <typename Problem, typename PipelinePolicy = MXF4FlatmmPipelineAgBgCrPolicy>
51{
53
58
60 static_assert(sizeof(ADataType) >= sizeof(BDataType));
61
65
68
69 static constexpr auto config =
70 BlockFlatmm::BlockPolicy::template GetWarpGemmMWarpNWarp<Problem>();
71
72 using WG = remove_cvref_t<decltype(config.template at<0>())>;
73
74 static constexpr index_t DsWritePreIssue = 3; // default 2, ds write at MIter - 2
75 static constexpr index_t DsReadPreload = 4; // default 4 for MXFP4 (MXdlPack * KXdlPack)
76
77 static constexpr index_t BlockSize = Problem::kBlockSize;
78 static constexpr index_t WaveSize = get_warp_size();
79
80 static constexpr index_t kMPerBlock = BlockGemmShape::kM;
81 static constexpr index_t kNPerBlock = BlockGemmShape::kN;
82 static constexpr index_t kKPerBlock = BlockGemmShape::kK;
83
84 static constexpr index_t flatKPerWarp = Problem::flatKPerWarp;
85 static constexpr index_t flatNPerWarp = Problem::flatNPerWarp;
86
87 static constexpr index_t GetVectorSizeA() { return 32; } /* fixed for fp4 shuffle layout*/
88 static constexpr index_t GetVectorSizeB() { return 32; } /* fixed for fp4 shuffle layout*/
89 static constexpr index_t GetVectorSizeC() { return Problem::VectorSizeC; }
90
91 static constexpr bool kPadM = Problem::kPadM;
92 static constexpr bool kPadN = Problem::kPadN;
93 static constexpr bool kPadK = Problem::kPadK;
94
95 // static constexpr index_t kLdsAlignmentInBytes = 16;
96 static constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
97 static constexpr bool UsePersistentKernel = Problem::Traits::UsePersistentKernel;
98
99 static constexpr auto I0 = number<0>();
100 static constexpr auto I1 = number<1>();
101 static constexpr auto I2 = number<2>();
102 static constexpr auto idxM = I0;
103 static constexpr auto idxN = I1;
104 static constexpr auto idxK = I2;
108
109 static constexpr index_t MWarp = config.template at<1>();
110 static constexpr index_t NWarp = config.template at<2>();
111
112 static constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WG::kM);
113 static constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WG::kN);
114 static constexpr index_t KIterPerWarp = kKPerBlock / WG::kK;
115
118
121
124
125 static constexpr index_t MXdlPack = Problem::MXdlPack;
126 static constexpr index_t NXdlPack = Problem::NXdlPack;
127 static constexpr index_t KXdlPack = Problem::KXdlPack;
128
129 static constexpr index_t AK1 = Problem::VectorLoadSize / sizeof(ADataType) * APackedSize;
130 static constexpr index_t BK1 = Problem::VectorLoadSize / sizeof(BDataType) * BPackedSize;
131
135
136 static constexpr bool HasHotLoop = Problem::HasHotLoop;
137 static constexpr auto TailNum = Problem::TailNum;
138
139 static constexpr index_t mfma_per_wg = 1; // 950 only
140
141 static constexpr index_t dsread_per_wg =
142 WG::kM * WG::kK * sizeof(ADataType) / APackedSize / WaveSize / Problem::VectorLoadSize;
143 static_assert((WG::kM * WG::kK * sizeof(ADataType) / APackedSize / WaveSize) %
144 Problem::VectorLoadSize ==
145 0);
146
151 static constexpr index_t Aload_rep = dswrite_rep;
152
153 static constexpr index_t Bload_num_perK = kNPerBlock * WG::kK / NWarp / BK1 / WaveSize;
154 static constexpr index_t ScaleBload_K1 = NXdlPack * KXdlPack; // fixed for fp4
155 static constexpr index_t ScaleBload_num =
158 static constexpr index_t HalfMIter = (MIterPerWarp + 1) / 2;
159 static constexpr index_t Bload_rep = (Bload_num_perK + HalfMIter - 1) / HalfMIter;
160
164
165 // For the basic GEMM pipeline, DoubleSmemBuffer is naturally set to false.
166 static constexpr bool DoubleSmemBuffer = false;
167
// Emits the scheduling hints (__builtin_amdgcn_sched_group_barrier) for one M iteration.
//
// For each of the mfma_perM_perK MFMA slots, one MFMA group (mask 0x008) is emitted,
// followed by the data-movement groups assigned to that slot: DS write (0x200),
// VMEM read (0x020) or DS read (0x100).
//
// Parameters:
//   dsread_perM  - number of DS-read groups to place in this M iteration
//   dswrite_perM - number of DS-write groups to place in this M iteration
//   load_perM    - number of VMEM-read groups to place in this M iteration
//
// NOTE(review): mfma_perM_perK is not declared in the visible part of this listing;
// it is presumably a class-level constant - confirm against the full header.
// NOTE(review): inst_order holds NIterPerWarp * 10 entries and is filled and indexed
// without a bounds check here; callers are presumably expected to keep the total
// instruction count within that capacity - confirm.
168 CK_TILE_HOST_DEVICE static constexpr auto
169 SchedulerPerM(index_t dsread_perM, index_t dswrite_perM, index_t load_perM)
170 {
171 // Init inst order
// max_data_inst is the largest of the three counts; sum_data_inst is their total.
// round_data_inst is the total divided by mfma_perM_perK, rounded up, i.e. how many
// data groups each MFMA slot may have to carry.
172 index_t max_data_inst = dsread_perM > load_perM
173 ? (dsread_perM > dswrite_perM ? dsread_perM : dswrite_perM)
174 : (load_perM > dswrite_perM ? load_perM : dswrite_perM);
175 index_t sum_data_inst = dsread_perM + load_perM + dswrite_perM;
176 index_t round_data_inst = (sum_data_inst + mfma_perM_perK - 1) / mfma_perM_perK;
177
// inst_order encoding: 0 = nothing, 1 = DS write, 2 = VMEM read, 3 = DS read.
178 index_t inst_order[NIterPerWarp * 10];
179 _Pragma("unroll") for(int idx = 0; idx < NIterPerWarp * 10; idx++) { inst_order[idx] = 0; }
180
// Interleave the three kinds round-robin (write, load, read) until every count
// has been consumed; kinds that run out early are simply skipped.
181 index_t index = 0;
182 _Pragma("unroll") for(int j = 0; j < max_data_inst; j++)
183 {
184 if(dswrite_perM > j)
185 {
186 inst_order[index] = 1;
187 index++;
188 }
189 if(load_perM > j)
190 {
191 inst_order[index] = 2;
192 index++;
193 }
194 if(dsread_perM > j)
195 {
196 inst_order[index] = 3;
197 index++;
198 }
199 }
200
201 // Schedule IGLP
202 _Pragma("unroll") for(int j = 0; j < mfma_perM_perK; j++)
203 {
// Map MFMA slot j to a column of inst_order:
//   j == 0 -> 0
//   j == 1 -> 1 when mfma_perM_perK == 2, otherwise mfma_perM_perK - 2
//   j == 2 -> mfma_perM_perK - 1
//   else   -> mfma_perM_perK - j
204 index_t inst_idx = 0;
205 if(j == 0)
206 ;
207 else if(j == 1)
208 inst_idx = mfma_perM_perK == 2 ? 1 : mfma_perM_perK - 2;
209 else if(j == 2)
210 inst_idx = mfma_perM_perK - 1;
211 else
212 inst_idx = mfma_perM_perK - j;
213
214 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
215
// Even rounds read entry (inst_idx + r * mfma_perM_perK); odd rounds read the
// mirrored entry ((r + 1) * mfma_perM_perK - 1 - inst_idx) of the same round.
216 _Pragma("unroll") for(int r = 0; r < round_data_inst; r++)
217 {
218 if(r % 2 == 0)
219 {
220 if(inst_order[inst_idx + r * mfma_perM_perK] == 1)
221 {
222 __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
223 }
224 if(inst_order[inst_idx + r * mfma_perM_perK] == 2)
225 {
226 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
227 }
228 if(inst_order[inst_idx + r * mfma_perM_perK] == 3)
229 {
230 __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
231 }
232 }
233 else
234 {
235 if(inst_order[(r + 1) * mfma_perM_perK - 1 - inst_idx] == 1)
236 {
237 __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
238 }
239 if(inst_order[(r + 1) * mfma_perM_perK - 1 - inst_idx] == 2)
240 {
241 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
242 }
243 if(inst_order[(r + 1) * mfma_perM_perK - 1 - inst_idx] == 3)
244 {
245 __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
246 }
247 }
248 }
249 }
250 }
251
// Emits the scheduling hints for one hot-loop iteration.
//
// Walks every (kIter, mIter) pair, derives how many DS reads, DS writes and buffer
// loads belong to that M iteration, and passes the three counts to SchedulerPerM.
// After the loops it optionally adds one extra VMEM-read group and finishes with
// __builtin_amdgcn_sched_barrier(0).
//
// NOTE(review): dswrite_num_perK, dswrite_rep, dswrite_kIter, dswrite_mIter and
// Aload_num_perK are not declared in the visible part of this listing; they are
// presumably class-level constants - confirm against the full header.
252 CK_TILE_HOST_DEVICE static constexpr auto HotLoopScheduler()
253 {
254 // The key point of pipeline optimization is balancing the workload over time.
255 // instruction schedule example(128X256X256, 1X4, 16X16X128):
256 // Iter MNK MFMA ds_read ds_write A_load b_load
257 // -1 M6N0: 57 - 8 - -
258 // -1 M6N1: 58 1 - - -
259 // -1 M6N2: 59 - - 7 -
260 // -1 M6N3: 60 2 - - -
261 // -1 M7N0: 61 - - - -
262 // -1 M7N1: 62 3 - - -
263 // -1 M7N2: 63 - - 8 -
264 // -1 M7N3: 64 4 - - -
265 // 0 M0N0K0: 1 - - - 1
266 // 0 M0N1: 2 5 - - -
267 // 0 M0N2: 3 - - - 2
268 // 0 M0N3: 4 6 - - -
269 // 0 M1N0: 5 - - - 3
270 // 0 M1N1: 6 7 - - -
271 // 0 M1N2: 7 - - - 4
272 // 0 M1N3: 8 8 - - -
273 // 0 M2N0: 9 - - - 5
274 // 0 M2N1: 10 9 - - -
275 // 0 M2N2: 11 - - - 6
276 // 0 M2N3: 12 10 - - -
277 // 0 M3N0: 13 - 1 - 7
278 // 0 M3N1: 14 11 - - -
279 // 0 M3N2: 15 - - - 8
280 // 0 M3N3: 16 12 - - -
281 // 0 M4N0: 17 - 2 - -
282 // 0 M4N1: 18 13 - - -
283 // 0 M4N2: 19 - - 1 -
284 // 0 M4N3: 20 14 - - -
285 // 0 M5N0: 21 - 3 - -
286 // 0 M5N1: 22 15 - - -
287 // 0 M5N2: 23 - - 2 -
288 // 0 M5N3: 24 16 - - -
289 // 0 M6N0: 25 - 4 - -
290 // 0 M6N1: 26 17 - - -
291 // 0 M6N2: 27 - - 3 -
292 // 0 M6N3: 28 18 - - -
293 // 0 M7N0: 29 - - - -
294 // 0 M7N1: 30 19 - - -
295 // 0 M7N2: 31 - - 4 -
296 // 0 M7N3: 32 20 - - -
297 // 0 M0N0K1: 33 - - - 9
298 // 0 M0N1: 34 21 - - -
299 // 0 M0N2: 35 - - - 10
300 // 0 M0N3: 36 22 - - -
301 // 0 M1N0: 37 - - - 11
302 // 0 M1N1: 38 23 - - -
303 // 0 M1N2: 39 - - - 12
304 // 0 M1N3: 40 24 - - -
305 // 0 M2N0: 41 - - - 13
306 // 0 M2N1: 42 25 - - -
307 // 0 M2N2: 43 - - - 14
308 // 0 M2N3: 44 26 - - -
309 // 0 M3N0: 45 - 5 - 15
310 // 0 M3N1: 46 27 - - -
311 // 0 M3N2: 47 - - - 16
312 // 0 M3N3: 48 28 - - -
313 // 0 M4N0: 49 - 6 - -
314 // 0 M4N1: 50 29 - - -
315 // 0 M4N2: 51 - - 5 -
316 // 0 M4N3: 52 30 - - -
317 // 0 M5N0: 53 - 7 - -
318 // 0 M5N1: 54 31 - - -
319 // 0 M5N2: 55 - - 6 -
320 // 0 M5N3: 56 32 - - -
321 // 0 M6N0: 57 - 8 - -
322 // 0 M6N1: 58 1 - - -
323 // 0 M6N2: 59 - - 7 -
324 // 0 M6N3: 60 2 - - -
325 // 0 M7N0: 61 - - - -
326 // 0 M7N1: 62 3 - - -
327 // 0 M7N2: 63 - - 8 -
328 // 0 M7N3: 64 4 - - -
329
330 _Pragma("unroll") for(int kIter = 0; kIter < KIterPerWarp; kIter++)
331 {
332 _Pragma("unroll") for(int mIter = 0; mIter < MIterPerWarp; mIter++)
333 {
334 index_t dsread_perM = 0;
335 index_t dswrite_perM = 0;
336 index_t load_perM = 0;
337
338 // Calculate ds_read number per M
339 dsread_perM = dsread_per_wg;
340
341 // Calculate ds_write number per M
342 if(mIter == 0)
343 {
// NOTE(review): listing lines 345-346 (the condition and the true-branch value of
// this conditional expression) are missing from this extraction.
344 dswrite_perM =
347 : 0;
348 }
// The last (DsWritePreIssue - 1) M iterations issue no DS write.
349 else if(mIter >= MIterPerWarp - DsWritePreIssue + 1)
350 {
351 dswrite_perM = 0;
352 }
353 else
354 {
// NOTE(review): listing line 357 (the true-branch value) is missing from this extraction.
355 dswrite_perM = (dswrite_num_perK -
356 (MIterPerWarp - DsWritePreIssue - mIter) * dswrite_rep) > 0
358 : 0;
359 }
360 // Add ds write when ds write data > needed
// When dswrite_num_perK is 0, exactly one (kIter, mIter) slot is given one DS write.
361 if(dswrite_num_perK == 0 && kIter == (KIterPerWarp - 1 - dswrite_kIter))
362 {
363 if(mIter == MIterPerWarp - 1 - dswrite_mIter)
364 dswrite_perM = 1;
365 }
366
367 // Calculate buffer_load number per M
// The first HalfMIter iterations count both the A-load and the B-load term;
// the remaining iterations count the A-load term only.
368 if(mIter < HalfMIter)
369 {
370 load_perM =
371 ((Aload_num_perK - (MIterPerWarp - 1 - mIter) * Aload_rep) > 0 ? Aload_rep
372 : 0) +
373 ((Bload_num_perK - (HalfMIter - 1 - mIter) * Bload_rep) > 0 ? Bload_rep
374 : 0);
375 }
376 else
377 {
378 load_perM = (Aload_num_perK - (MIterPerWarp - 1 - mIter) * Aload_rep) > 0
379 ? Aload_rep
380 : 0;
381 }
382 // if((kIter % KPerScaleLoad == 0) && (mIter == 0))
383 // {
384 // load_perM = load_perM + 1;
385 // }
386 SchedulerPerM(dsread_perM, dswrite_perM, load_perM);
387 }
388 }
389 // Add Aload when Aload data > needed
390 if(Aload_num_perK == 0)
391 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
// Close the scheduling region set up by the group barriers above.
392 __builtin_amdgcn_sched_barrier(0);
393 }
394
396 {
397 _Pragma("unroll") for(int kIter = 0; kIter < KIterPerWarp; kIter++)
398 {
399 _Pragma("unroll") for(int mIter = 0; mIter < MIterPerWarp; mIter++)
400 {
401 index_t dsread_perM = 0;
402 index_t dswrite_perM = 0;
403 index_t load_perM = 0;
404
405 // Calculate ds_read number per M
406 dsread_perM = dsread_per_wg;
407
408 // Calculate ds_write number per M
409 if(mIter == 0)
410 {
411 dswrite_perM =
414 : 0;
415 }
416 else if(mIter >= MIterPerWarp - DsWritePreIssue + 1)
417 {
418 dswrite_perM = 0;
419 }
420 else
421 {
422 dswrite_perM = (dswrite_num_perK -
423 (MIterPerWarp - DsWritePreIssue - mIter) * dswrite_rep) > 0
425 : 0;
426 }
427 // Add ds write when ds write data > needed
428 if(dswrite_num_perK == 0 && kIter == (KIterPerWarp - 1 - dswrite_kIter))
429 {
430 if(mIter == MIterPerWarp - 1 - dswrite_mIter)
431 dswrite_perM = 1;
432 }
433
434 // Calculate buffer_load number per M
435 if(mIter < HalfMIter)
436 {
437 load_perM =
438 ((Bload_num_perK - (HalfMIter - 1 - mIter) * Bload_rep) > 0 ? Bload_rep
439 : 0);
440 }
441 SchedulerPerM(dsread_perM, dswrite_perM, load_perM);
442 }
443 }
444 __builtin_amdgcn_sched_barrier(0);
445 }
446
448 {
449 _Pragma("unroll") for(int kIter = 0; kIter < KIterPerWarp; kIter++)
450 {
451 _Pragma("unroll") for(int mIter = 0; mIter < MIterPerWarp; mIter++)
452 {
453 index_t dsread_perM = 0;
454 index_t dswrite_perM = 0;
455 index_t load_perM = 0;
456
457 // Calculate ds_read number per M
458 if((kIter * MIterPerWarp + mIter) < (KIterPerWarp * MIterPerWarp - m_preload))
459 dsread_perM = dsread_per_wg;
460
461 SchedulerPerM(dsread_perM, dswrite_perM, load_perM);
462 }
463 }
464 // __builtin_amdgcn_sched_barrier(0);
465 }
466
468 {
469 return PipelinePolicy::template MakeADramTileDistribution<Problem>();
470 }
471
472 template <typename ADramBlockWindowTmp,
473 typename AElementFunction,
474 typename BFlatBlockWindowTmp,
475 typename ScaleADramBlockWindowTmp,
476 typename ScaleBDramBlockWindowTmp>
477 CK_TILE_HOST_DEVICE auto operator()(ADramBlockWindowTmp a_copy_dram_window,
478 const AElementFunction& a_element_func,
479 const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
480 const ScaleADramBlockWindowTmp& scale_a_window,
481 const ScaleBDramBlockWindowTmp& scale_b_window,
482 index_t num_loop,
483 void* p_smem_ping,
484 void* p_smem_pong) const
485 {
486#ifndef __gfx950__
487 static_assert(false, "Only gfx950 is supported for MXFP4 flatmm pipeline now.");
488#endif
489 static_assert(
490 std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>>,
491 "wrong!");
492
493 static_assert(kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}],
494 "wrong!");
495 static_assert(kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
496 "wrong!");
497
498 constexpr auto MIter_2nd_last = (MIterPerWarp >= 2) ? MIterPerWarp - 2 : MIterPerWarp - 1;
499 const index_t iMWarp = get_warp_id() / NWarp;
500 // const index_t iNWarp = get_warp_id() % NWarp;
501
502 using CWarpDstr = typename WG::CWarpDstr;
503 using CWarpTensor = typename WG::CWarpTensor;
504
505 constexpr auto c_warp_y_lengths =
506 to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
507 constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
508
509 __builtin_amdgcn_sched_barrier(0);
510
511 // A tile in LDS
512 ADataType* p_a_lds_ping = static_cast<ADataType*>(p_smem_ping);
513 ADataType* p_a_lds_pong = static_cast<ADataType*>(p_smem_pong);
514
515 constexpr auto a_lds_block_desc =
516 PipelinePolicy::template MakeMXFP4_ALdsBlockDescriptor<Problem>();
517
518 auto a_lds_block_ping =
519 make_tensor_view<address_space_enum::lds>(p_a_lds_ping, a_lds_block_desc);
520 auto a_lds_block_pong =
521 make_tensor_view<address_space_enum::lds>(p_a_lds_pong, a_lds_block_desc);
522
523 auto a_copy_lds_window_ping =
524 make_tile_window(a_lds_block_ping,
526 {0, 0},
527 PipelinePolicy::template MakeADramTileDistribution<Problem>());
528 auto a_copy_lds_window_pong =
529 make_tile_window(a_lds_block_pong,
531 {0, 0},
532 PipelinePolicy::template MakeADramTileDistribution<Problem>());
533
534 // ping-pong window for A LDS
535 auto a_warp_window_ping_tmp =
536 make_tile_window(a_lds_block_ping,
538 {iMWarp * WG::kM, 0},
539 PipelinePolicy::template MakeMXF4_ALDS_TileDistribution<Problem>());
540 auto a_warp_window_pong_tmp =
541 make_tile_window(a_lds_block_pong,
543 {iMWarp * WG::kM, 0},
544 PipelinePolicy::template MakeMXF4_ALDS_TileDistribution<Problem>());
545
547 statically_indexed_array<decltype(a_warp_window_ping_tmp), KIterPerWarp>,
549 a_warp_windows_ping;
550
552 statically_indexed_array<decltype(a_warp_window_pong_tmp), KIterPerWarp>,
554 a_warp_windows_pong;
555
556 static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
557 static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
558 a_warp_windows_ping(mIter)(kIter) = a_warp_window_ping_tmp;
559 a_warp_windows_pong(mIter)(kIter) = a_warp_window_pong_tmp;
560
561 auto packed_m_idx = mIter / number<MXdlPack>{};
562 auto packed_m_rank = mIter % number<MXdlPack>{};
563
565 a_warp_windows_ping(mIter)(kIter),
566 {packed_m_idx * MXdlPack * MPerBlockPerIter + packed_m_rank * WG::kM,
567 kIter * KPerBlockPerIter});
569 a_warp_windows_pong(mIter)(kIter),
570 {packed_m_idx * MXdlPack * MPerBlockPerIter + packed_m_rank * WG::kM,
571 kIter * KPerBlockPerIter});
572 });
573 });
574
575 // Block GEMM
576 auto block_flatmm = BlockFlatmm();
577 // Acc register tile
578 auto c_block_tile = block_flatmm.MakeCBlockTile();
579
580 // B flat DRAM window for load
581 auto b_flat_distribution =
582 PipelinePolicy::template MakeMXFP4_BFlatDramTileDistribution<Problem>();
583
584 auto b_flat_dram_window = make_tile_window(
585 b_flat_dram_block_window_tmp.get_bottom_tensor_view(), // from kernel gemm_pad_views
587 b_flat_dram_block_window_tmp.get_window_origin(),
588 b_flat_distribution);
589
590 using MXFP4_B_Buffer = decltype(load_tile(b_flat_dram_window));
591 // use v4i32 as the data type between basic blocks to avoid unpack and repack operations.
592 using V4UInt_B_Buffer = thread_buffer<uint32_t, 4>;
593 union UnionBuf
594 {
595 V4UInt_B_Buffer u = 0;
596 MXFP4_B_Buffer mxfp4;
597 } ub;
598
599 // pingpong buffer for B
601 statically_indexed_array<decltype(b_flat_dram_window), KIterPerWarp>,
603 b_flat_dram_windows;
606 b_warp_tensor_ping;
609 b_warp_tensor_pong;
610
611 // pingpong buffer for Scale A and Scale B
612 auto scale_a_dram_window = make_tile_window(
613 scale_a_window.get_bottom_tensor_view(),
614 make_tuple(number<MWarp * WG::kM>{}, number<64 / WG::kM>{}),
615 scale_a_window.get_window_origin(),
616 PipelinePolicy::template MakeMXFP4_ScaleA_FlatDramTileDistribution<Problem>());
617
618 auto scale_b_dram_window = make_tile_window(
619 scale_b_window.get_bottom_tensor_view(),
620 make_tuple(number<NWarp * WG::kN>{}, number<64 / WG::kN>{}),
621 scale_b_window.get_window_origin(),
622 PipelinePolicy::template MakeMXFP4_ScaleB_DramTileDistribution<Problem>());
623
624 // ping pong buffer for scale A
626 statically_indexed_array<decltype(scale_a_dram_window), KIterPerWarp / KXdlPack>,
628 scale_a_dram_windows;
629 statically_indexed_array<statically_indexed_array<decltype(load_tile(scale_a_dram_window)),
632 scale_a_tile_tensor_ping;
633 statically_indexed_array<statically_indexed_array<decltype(load_tile(scale_a_dram_window)),
636 scale_a_tile_tensor_pong;
637
638 // ping pong buffer for scale B
640 statically_indexed_array<decltype(scale_b_dram_window), KIterPerWarp / KXdlPack>,
642 scale_b_dram_windows;
643 statically_indexed_array<statically_indexed_array<decltype(load_tile(scale_b_dram_window)),
646 scale_b_tile_tensor_ping;
647 statically_indexed_array<statically_indexed_array<decltype(load_tile(scale_b_dram_window)),
650 scale_b_tile_tensor_pong;
651
652 // HEAD
653 // Prefetch A0
654 auto a_block_tile = load_tile(a_copy_dram_window);
655 // move A window to next k
656 move_tile_window(a_copy_dram_window, {0, kKPerBlock});
657
658 // prefetch B
659 static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
660 static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
661 auto packed_n_idx = nIter / number<NXdlPack>{};
662 auto packed_n_rank = nIter % number<NXdlPack>{};
663
664 b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
665 move_tile_window(b_flat_dram_windows(nIter)(kIter),
666 {packed_n_idx * NXdlPack * NFlatPerBlockPerIter + packed_n_rank,
667 kIter * KFlatPerBlockPerIter});
668
669 ub.mxfp4 = load_tile(b_flat_dram_windows(nIter)(kIter));
670 b_warp_tensor_ping(nIter)(kIter) = ub.u;
671 });
672 });
673 // move B window to next flat K
674 move_tile_window(b_flat_dram_window, {0, KIterPerWarp * KFlatPerBlockPerIter});
675
676 // prefetch Scale A
677 static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) {
678 static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
679 scale_a_dram_windows(mIter_pack)(kIter_pack) = scale_a_dram_window;
680 move_tile_window(scale_a_dram_windows(mIter_pack)(kIter_pack),
681 {mIter_pack * MWarp * WG::kM, kIter_pack * (64 / WG::kM)});
682
683 scale_a_tile_tensor_ping(mIter_pack)(kIter_pack) =
684 load_tile(scale_a_dram_windows(mIter_pack)(kIter_pack));
685 });
686 });
687 // move Scale A window to next K
688 move_tile_window(scale_a_dram_window, {0, kKPerBlock / (32 * KXdlPack)});
689
690 // prefetch Scale B
691 static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) {
692 static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
693 scale_b_dram_windows(nIter_pack)(kIter_pack) = scale_b_dram_window;
694 move_tile_window(scale_b_dram_windows(nIter_pack)(kIter_pack),
695 {nIter_pack * NWarp * WG::kN, kIter_pack * (64 / WG::kN)});
696
697 scale_b_tile_tensor_ping(nIter_pack)(kIter_pack) =
698 load_tile(scale_b_dram_windows(nIter_pack)(kIter_pack));
699 });
700 });
701 // move Scale B window to next K
702 move_tile_window(scale_b_dram_window, {0, kKPerBlock / (32 * KXdlPack)});
703
704 // A_Lds_TileDist may differ from ADramTileDistribution
705 auto a_block_tile_transformed = tile_elementwise_in(a_element_func, a_block_tile);
706 store_tile(a_copy_lds_window_ping, a_block_tile_transformed);
707
708 __builtin_amdgcn_sched_barrier(0);
709
710 // Prefetch A1
711 a_block_tile = load_tile(a_copy_dram_window);
712 // move A window to next k
713 move_tile_window(a_copy_dram_window, {0, kKPerBlock});
714
715 // initialize C
716 tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
717
719
720 using MXFP4_A_Buffer_ping =
721 decltype(load_tile(a_warp_windows_ping(number<0>{})(number<0>{})));
722 // use v4i32 as the data type between basic blocks to avoid unpack and repack operations.
723 using V4UInt_A_Buffer = thread_buffer<uint32_t, 4>;
724 union UnionBuf_A_ping
725 {
726 V4UInt_A_Buffer u = 0;
727 MXFP4_A_Buffer_ping mxfp4;
728 } ua_ping;
729
730 using MXFP4_A_Buffer_pong =
731 decltype(load_tile(a_warp_windows_pong(number<0>{})(number<0>{})));
732 union UnionBuf_A_pong
733 {
734 V4UInt_A_Buffer u = 0;
735 MXFP4_A_Buffer_pong mxfp4;
736 } ua_pong;
737
738 // preload A00,A10... from lds
740
741 static_for<0, m_preload, 1>{}([&](auto loadIter) {
742 constexpr auto mIter = loadIter % MXdlPack;
743 constexpr auto kIter = loadIter / MXdlPack;
744
745 ua_ping.mxfp4 = load_tile(a_warp_windows_ping(number<mIter>{})(number<kIter>{}));
746 a_warp_tensor(loadIter) = ua_ping.u;
747 });
748 __builtin_amdgcn_sched_barrier(0);
749
750 // MAIN LOOP
751 index_t iCounter = (num_loop - 1) / 2;
752 while(iCounter > 0)
753 {
754 // prefetch B(2i+1)
755 static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
756 static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
757 auto packed_n_idx = nIter / number<NXdlPack>{};
758 auto packed_n_rank = nIter % number<NXdlPack>{};
759
760 b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
762 b_flat_dram_windows(nIter)(kIter),
763 {packed_n_idx * NXdlPack * NFlatPerBlockPerIter + packed_n_rank,
764 kIter * KFlatPerBlockPerIter});
765
766 ub.mxfp4 = load_tile(b_flat_dram_windows(nIter)(kIter));
767 b_warp_tensor_pong(nIter)(kIter) = ub.u;
768 });
769 });
770
771 // prefetch Scale A and Scale B (2i+1)
772 static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) {
773 static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
774 scale_a_dram_windows(mIter_pack)(kIter_pack) = scale_a_dram_window;
775 move_tile_window(scale_a_dram_windows(mIter_pack)(kIter_pack),
776 {mIter_pack * MWarp * WG::kM, kIter_pack * (64 / WG::kM)});
777
778 scale_a_tile_tensor_pong(mIter_pack)(kIter_pack) =
779 load_tile(scale_a_dram_windows(mIter_pack)(kIter_pack));
780 });
781 });
782
783 static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) {
784 static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
785 scale_b_dram_windows(nIter_pack)(kIter_pack) = scale_b_dram_window;
786 move_tile_window(scale_b_dram_windows(nIter_pack)(kIter_pack),
787 {nIter_pack * NWarp * WG::kN, kIter_pack * (64 / WG::kN)});
788
789 scale_b_tile_tensor_pong(nIter_pack)(kIter_pack) =
790 load_tile(scale_b_dram_windows(nIter_pack)(kIter_pack));
791 });
792 });
793
794 // Prefill A(2i+1)
795 a_block_tile_transformed = tile_elementwise_in(a_element_func, a_block_tile);
796 store_tile(a_copy_lds_window_pong, a_block_tile_transformed);
797
798 // Prefetch A(2i+2)
799 a_block_tile = load_tile(a_copy_dram_window);
800 // move A window to next k
801 move_tile_window(a_copy_dram_window, {0, kKPerBlock});
802
803 // GEMM 2i
804 static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
805 static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) {
806 static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) {
807 static_for<0, KXdlPack, 1>{}([&](auto ikxdl) {
808 static_for<0, MXdlPack, 1>{}([&](auto imxdl) {
809 constexpr auto AwarpIter = imxdl + ikxdl * MXdlPack;
810 static_for<0, NXdlPack, 1>{}([&](auto inxdl) {
811 // read C warp tensor from C block tensor
812 CWarpTensor c_warp_tensor;
813 c_warp_tensor.get_thread_buffer() =
814 c_block_tile.get_y_sliced_thread_data(
816 sequence<mIter_pack * MXdlPack + imxdl,
817 nIter_pack * NXdlPack + inxdl>{},
818 c_warp_y_index_zeros),
819 merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
820
821 UnionBuf_A_ping ua_compute;
822 ua_compute.u = a_warp_tensor(number<AwarpIter>{});
823
824 UnionBuf ub_compute;
825 ub_compute.u =
826 b_warp_tensor_ping(nIter_pack * number<NXdlPack>{} + inxdl)(
827 kIter_pack * number<KXdlPack>{} + ikxdl);
828 // warp GEMM
829 WG{}.template
830 operator()<ikxdl * MXdlPack + imxdl, ikxdl * NXdlPack + inxdl>(
831 c_warp_tensor,
832 ua_compute.mxfp4,
833 ub_compute.mxfp4,
834 scale_a_tile_tensor_ping(mIter_pack)(kIter_pack)
835 .get_thread_buffer()[0],
836 scale_b_tile_tensor_ping(nIter_pack)(kIter_pack)
837 .get_thread_buffer()[0]);
838
839 // write C warp tensor into C block tensor
840 c_block_tile.set_y_sliced_thread_data(
841 merge_sequences(sequence<mIter_pack * MXdlPack + imxdl,
842 nIter_pack * NXdlPack + inxdl>{},
843 c_warp_y_index_zeros),
844 merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
845 c_warp_tensor.get_thread_buffer());
846 });
847 // preload next A from lds
848 constexpr auto addr = (mIter_pack * MXdlPack + imxdl) % 2 +
849 (kIter_pack * KXdlPack + ikxdl) * 2 +
850 (mIter_pack * MXdlPack + imxdl) / 2 * 4 +
851 m_preload;
852 if constexpr(addr < (KIterPerWarp * MIterPerWarp) &&
853 (nIter_pack == NIterPerWarp / NXdlPack - 1))
854 {
855 constexpr auto AmIter = addr % 2 + addr / 4 * 2;
856 constexpr auto AkIter = addr / 2 % 2;
857 ua_ping.mxfp4 = load_tile(
858 a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
859 a_warp_tensor(number<AwarpIter>{}) = ua_ping.u;
860 }
861
862 // barrier
863 if constexpr(kIter_pack * KXdlPack + ikxdl == KIterPerWarp - 1 &&
864 mIter_pack * MXdlPack + imxdl == MIter_2nd_last)
865 {
867 }
868 });
869 });
870 });
871 });
872 });
873
874 // move B window to next flat K
875 move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
876 move_tile_window(scale_a_dram_window, {0, kKPerBlock / (32 * KXdlPack)});
877 move_tile_window(scale_b_dram_window, {0, kKPerBlock / (32 * KXdlPack)});
878
879 static_for<0, m_preload, 1>{}([&](auto loadIter) {
880 constexpr auto mIter = loadIter % MXdlPack;
881 constexpr auto kIter = loadIter / MXdlPack;
882 ua_pong.mxfp4 = load_tile(a_warp_windows_pong(number<mIter>{})(number<kIter>{}));
883 a_warp_tensor(loadIter) = ua_pong.u; // reload a_warp_tensor with pong buffer
884 });
886
887 // Next K
888
889 // prefetch B(2i+2)
890 static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
891 static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
892 auto packed_n_idx = nIter / number<NXdlPack>{};
893 auto packed_n_rank = nIter % number<NXdlPack>{};
894
895 b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
897 b_flat_dram_windows(nIter)(kIter),
898 {packed_n_idx * NXdlPack * NFlatPerBlockPerIter + packed_n_rank,
899 kIter * KFlatPerBlockPerIter});
900
901 ub.mxfp4 = load_tile(b_flat_dram_windows(nIter)(kIter));
902 b_warp_tensor_ping(nIter)(kIter) = ub.u;
903 });
904 });
905
906 // prefetch Scale A and Scale B (2i+2)
907 static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) {
908 static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
909 scale_a_dram_windows(mIter_pack)(kIter_pack) = scale_a_dram_window;
910 move_tile_window(scale_a_dram_windows(mIter_pack)(kIter_pack),
911 {mIter_pack * MWarp * WG::kM, kIter_pack * (64 / WG::kM)});
912
913 scale_a_tile_tensor_ping(mIter_pack)(kIter_pack) =
914 load_tile(scale_a_dram_windows(mIter_pack)(kIter_pack));
915 });
916 });
917
918 static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) {
919 static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
920 scale_b_dram_windows(nIter_pack)(kIter_pack) = scale_b_dram_window;
921 move_tile_window(scale_b_dram_windows(nIter_pack)(kIter_pack),
922 {nIter_pack * NWarp * WG::kN, kIter_pack * (64 / WG::kN)});
923
924 scale_b_tile_tensor_ping(nIter_pack)(kIter_pack) =
925 load_tile(scale_b_dram_windows(nIter_pack)(kIter_pack));
926 });
927 });
928
929 // Prefill A(2i+2)
930 a_block_tile_transformed = tile_elementwise_in(a_element_func, a_block_tile);
931 store_tile(a_copy_lds_window_ping, a_block_tile_transformed);
932
933 // Prefetch A(2i+3)
934 a_block_tile = load_tile(a_copy_dram_window);
935 // move A window to next k
936 move_tile_window(a_copy_dram_window, {0, kKPerBlock});
937
938 // GEMM 2i+1
939 static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
940 static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) {
941 static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) {
942 static_for<0, KXdlPack, 1>{}([&](auto ikxdl) {
943 static_for<0, MXdlPack, 1>{}([&](auto imxdl) {
944 constexpr auto AwarpIter = imxdl + ikxdl * MXdlPack;
945 static_for<0, NXdlPack, 1>{}([&](auto inxdl) {
946 // read C warp tensor from C block tensor
947 CWarpTensor c_warp_tensor;
948 c_warp_tensor.get_thread_buffer() =
949 c_block_tile.get_y_sliced_thread_data(
951 sequence<mIter_pack * MXdlPack + imxdl,
952 nIter_pack * NXdlPack + inxdl>{},
953 c_warp_y_index_zeros),
954 merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
955
956 UnionBuf_A_pong ua_compute;
957 ua_compute.u = a_warp_tensor(number<AwarpIter>{});
958
959 UnionBuf ub_compute;
960 ub_compute.u =
961 b_warp_tensor_pong(nIter_pack * number<NXdlPack>{} + inxdl)(
962 kIter_pack * number<KXdlPack>{} + ikxdl);
963
964 // warp GEMM
965 WG{}.template
966 operator()<ikxdl * MXdlPack + imxdl, ikxdl * NXdlPack + inxdl>(
967 c_warp_tensor,
968 ua_compute.mxfp4,
969 ub_compute.mxfp4,
970 scale_a_tile_tensor_pong(mIter_pack)(kIter_pack)
971 .get_thread_buffer()[0], // scale A
972 scale_b_tile_tensor_pong(nIter_pack)(kIter_pack)
973 .get_thread_buffer()[0]); // scale B
974
975 // write C warp tensor into C block tensor
976 c_block_tile.set_y_sliced_thread_data(
977 merge_sequences(sequence<mIter_pack * MXdlPack + imxdl,
978 nIter_pack * NXdlPack + inxdl>{},
979 c_warp_y_index_zeros),
980 merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
981 c_warp_tensor.get_thread_buffer());
982 });
983 // preload next A from lds
984 constexpr auto addr = (mIter_pack * MXdlPack + imxdl) % 2 +
985 (kIter_pack * KXdlPack + ikxdl) * 2 +
986 (mIter_pack * MXdlPack + imxdl) / 2 * 4 +
987 m_preload;
988 if constexpr(addr < (KIterPerWarp * MIterPerWarp) &&
989 (nIter_pack == NIterPerWarp / NXdlPack - 1))
990 {
991 constexpr auto AmIter = addr % 2 + addr / 4 * 2;
992 constexpr auto AkIter = addr / 2 % 2;
993 ua_pong.mxfp4 = load_tile(
994 a_warp_windows_pong(number<AmIter>{})(number<AkIter>{}));
995 a_warp_tensor(number<AwarpIter>{}) = ua_pong.u;
996 }
997
998 // barrier
999 if constexpr(kIter_pack * KXdlPack + ikxdl == KIterPerWarp - 1 &&
1000 mIter_pack * MXdlPack + imxdl == MIter_2nd_last)
1001 {
1003 }
1004 });
1005 });
1006 });
1007 });
1008 });
1009
1010 // move B window to next flat K
1011 move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
1012 move_tile_window(scale_a_dram_window, {0, kKPerBlock / (32 * KXdlPack)});
1013 move_tile_window(scale_b_dram_window, {0, kKPerBlock / (32 * KXdlPack)});
1014
1015 static_for<0, m_preload, 1>{}([&](auto loadIter) {
1016 constexpr auto mIter = loadIter % MXdlPack;
1017 constexpr auto kIter = loadIter / MXdlPack;
1018 ua_ping.mxfp4 = load_tile(a_warp_windows_ping(number<mIter>{})(number<kIter>{}));
1019 a_warp_tensor(loadIter) = ua_ping.u; // reload a_warp_tensor with ping buffer
1020 });
1022
1023 iCounter--;
1024 }
1025
1026 // TAIL
1027 if constexpr(TailNum == TailNumber::Even)
1028 {
1029 // prefetch B(loopK)
1030 static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
1031 static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
1032 auto packed_n_idx = nIter / number<NXdlPack>{};
1033 auto packed_n_rank = nIter % number<NXdlPack>{};
1034
1035 b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
1036
1038 b_flat_dram_windows(nIter)(kIter),
1039 {packed_n_idx * NXdlPack * NFlatPerBlockPerIter + packed_n_rank,
1040 kIter * KFlatPerBlockPerIter});
1041
1042 ub.mxfp4 = load_tile(b_flat_dram_windows(nIter)(kIter));
1043 b_warp_tensor_pong(nIter)(kIter) = ub.u;
1044 });
1045 });
1046
1047 // prefetch Scale A and Scale B (2i+1)
1048 static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) {
1049 static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
1050 scale_a_dram_windows(mIter_pack)(kIter_pack) = scale_a_dram_window;
1051 move_tile_window(scale_a_dram_windows(mIter_pack)(kIter_pack),
1052 {mIter_pack * MWarp * WG::kM, kIter_pack * (64 / WG::kM)});
1053
1054 scale_a_tile_tensor_pong(mIter_pack)(kIter_pack) =
1055 load_tile(scale_a_dram_windows(mIter_pack)(kIter_pack));
1056 });
1057 });
1058
1059 static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) {
1060 static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
1061 scale_b_dram_windows(nIter_pack)(kIter_pack) = scale_b_dram_window;
1062 move_tile_window(scale_b_dram_windows(nIter_pack)(kIter_pack),
1063 {nIter_pack * NWarp * WG::kN, kIter_pack * (64 / WG::kN)});
1064
1065 scale_b_tile_tensor_pong(nIter_pack)(kIter_pack) =
1066 load_tile(scale_b_dram_windows(nIter_pack)(kIter_pack));
1067 });
1068 });
1069
1070 // Prefill A(loopK)
1071 a_block_tile_transformed = tile_elementwise_in(a_element_func, a_block_tile);
1072 store_tile(a_copy_lds_window_pong, a_block_tile_transformed);
1073
1074 // GEMM loopK-1
1075 static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
1076 static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) {
1077 static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) {
1078 static_for<0, KXdlPack, 1>{}([&](auto ikxdl) {
1079 static_for<0, MXdlPack, 1>{}([&](auto imxdl) {
1080 constexpr auto AwarpIter = imxdl + ikxdl * MXdlPack;
1081 static_for<0, NXdlPack, 1>{}([&](auto inxdl) {
1082 // read C warp tensor from C block tensor
1083 CWarpTensor c_warp_tensor;
1084 c_warp_tensor.get_thread_buffer() =
1085 c_block_tile.get_y_sliced_thread_data(
1087 sequence<mIter_pack * MXdlPack + imxdl,
1088 nIter_pack * NXdlPack + inxdl>{},
1089 c_warp_y_index_zeros),
1090 merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
1091
1092 UnionBuf_A_ping ua_compute;
1093 ua_compute.u = a_warp_tensor(number<AwarpIter>{});
1094
1095 UnionBuf ub_compute;
1096 ub_compute.u =
1097 b_warp_tensor_ping(nIter_pack * number<NXdlPack>{} + inxdl)(
1098 kIter_pack * number<KXdlPack>{} + ikxdl);
1099
1100 // warp GEMM
1101 WG{}.template
1102 operator()<ikxdl * MXdlPack + imxdl, ikxdl * NXdlPack + inxdl>(
1103 c_warp_tensor,
1104 ua_compute.mxfp4,
1105 ub_compute.mxfp4,
1106 scale_a_tile_tensor_ping(mIter_pack)(kIter_pack)
1107 .get_thread_buffer()[0], // scale A
1108 scale_b_tile_tensor_ping(nIter_pack)(kIter_pack)
1109 .get_thread_buffer()[0]); // scale B
1110
1111 // write C warp tensor into C block tensor
1112 c_block_tile.set_y_sliced_thread_data(
1113 merge_sequences(sequence<mIter_pack * MXdlPack + imxdl,
1114 nIter_pack * NXdlPack + inxdl>{},
1115 c_warp_y_index_zeros),
1116 merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
1117 c_warp_tensor.get_thread_buffer());
1118 });
1119 // preload next A from lds
1120 constexpr auto addr = (mIter_pack * MXdlPack + imxdl) % 2 +
1121 (kIter_pack * KXdlPack + ikxdl) * 2 +
1122 (mIter_pack * MXdlPack + imxdl) / 2 * 4 +
1123 m_preload;
1124 if constexpr(addr < (KIterPerWarp * MIterPerWarp) &&
1125 (nIter_pack == NIterPerWarp / NXdlPack - 1))
1126 {
1127 constexpr auto AmIter = addr % 2 + addr / 4 * 2;
1128 constexpr auto AkIter = addr / 2 % 2;
1129 ua_ping.mxfp4 = load_tile(
1130 a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
1131 a_warp_tensor(number<AwarpIter>{}) = ua_ping.u;
1132 }
1133
1134 // barrier
1135 if constexpr(kIter_pack * KXdlPack + ikxdl == KIterPerWarp - 1 &&
1136 mIter_pack * MXdlPack + imxdl == MIter_2nd_last)
1137 {
1139 }
1140 });
1141 });
1142 });
1143 });
1144 });
1145
1146 static_for<0, m_preload, 1>{}([&](auto loadIter) {
1147 constexpr auto mIter = loadIter % MXdlPack;
1148 constexpr auto kIter = loadIter / MXdlPack;
1149 ua_pong.mxfp4 = load_tile(a_warp_windows_pong(number<mIter>{})(number<kIter>{}));
1150 a_warp_tensor(loadIter) = ua_pong.u; // reload a_warp_tensor with pong buffer
1151 });
1152
1154
1155 // GEMM loopK
1156 static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
1157 static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) {
1158 static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) {
1159 static_for<0, KXdlPack, 1>{}([&](auto ikxdl) {
1160 static_for<0, MXdlPack, 1>{}([&](auto imxdl) {
1161 constexpr auto AwarpIter = imxdl + ikxdl * MXdlPack;
1162 static_for<0, NXdlPack, 1>{}([&](auto inxdl) {
1163 // read C warp tensor from C block tensor
1164 CWarpTensor c_warp_tensor;
1165 c_warp_tensor.get_thread_buffer() =
1166 c_block_tile.get_y_sliced_thread_data(
1168 sequence<mIter_pack * MXdlPack + imxdl,
1169 nIter_pack * NXdlPack + inxdl>{},
1170 c_warp_y_index_zeros),
1171 merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
1172
1173 UnionBuf_A_pong ua_compute;
1174 ua_compute.u = a_warp_tensor(number<AwarpIter>{});
1175
1176 UnionBuf ub_compute;
1177 ub_compute.u =
1178 b_warp_tensor_pong(nIter_pack * number<NXdlPack>{} + inxdl)(
1179 kIter_pack * number<KXdlPack>{} + ikxdl);
1180 // warp GEMM
1181 WG{}.template
1182 operator()<ikxdl * MXdlPack + imxdl, ikxdl * NXdlPack + inxdl>(
1183 c_warp_tensor,
1184 ua_compute.mxfp4,
1185 ub_compute.mxfp4,
1186 scale_a_tile_tensor_pong(mIter_pack)(kIter_pack)
1187 .get_thread_buffer()[0], // scale A
1188 scale_b_tile_tensor_pong(nIter_pack)(kIter_pack)
1189 .get_thread_buffer()[0]); // scale B
1190
1191 // write C warp tensor into C block tensor
1192 c_block_tile.set_y_sliced_thread_data(
1193 merge_sequences(sequence<mIter_pack * MXdlPack + imxdl,
1194 nIter_pack * NXdlPack + inxdl>{},
1195 c_warp_y_index_zeros),
1196 merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
1197 c_warp_tensor.get_thread_buffer());
1198 });
1199 // preload next A from lds
1200 constexpr auto addr = (mIter_pack * MXdlPack + imxdl) % 2 +
1201 (kIter_pack * KXdlPack + ikxdl) * 2 +
1202 (mIter_pack * MXdlPack + imxdl) / 2 * 4 +
1203 m_preload;
1204 if constexpr(addr < (KIterPerWarp * MIterPerWarp) &&
1205 (nIter_pack == NIterPerWarp / NXdlPack - 1))
1206 {
1207 constexpr auto AmIter = addr % 2 + addr / 4 * 2;
1208 constexpr auto AkIter = addr / 2 % 2;
1209 ua_pong.mxfp4 = load_tile(
1210 a_warp_windows_pong(number<AmIter>{})(number<AkIter>{}));
1211 a_warp_tensor(number<AwarpIter>{}) = ua_pong.u;
1212 }
1213
1214 // barrier
1215 if constexpr(kIter_pack * KXdlPack + ikxdl == KIterPerWarp - 1 &&
1216 mIter_pack * MXdlPack + imxdl == MIter_2nd_last)
1217 {
1219 }
1220 });
1221 });
1222 });
1223 });
1224 });
1226 }
1227 else if constexpr(TailNum == TailNumber::Odd)
1228 {
1229 // GEMM loopK
1230 static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
1231 static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) {
1232 static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) {
1233 static_for<0, KXdlPack, 1>{}([&](auto ikxdl) {
1234 static_for<0, MXdlPack, 1>{}([&](auto imxdl) {
1235 constexpr auto AwarpIter = imxdl + ikxdl * MXdlPack;
1236 static_for<0, NXdlPack, 1>{}([&](auto inxdl) {
1237 // read C warp tensor from C block tensor
1238 CWarpTensor c_warp_tensor;
1239 c_warp_tensor.get_thread_buffer() =
1240 c_block_tile.get_y_sliced_thread_data(
1242 sequence<mIter_pack * MXdlPack + imxdl,
1243 nIter_pack * NXdlPack + inxdl>{},
1244 c_warp_y_index_zeros),
1245 merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
1246
1247 UnionBuf_A_ping ua_compute;
1248 ua_compute.u = a_warp_tensor(number<AwarpIter>{});
1249
1250 UnionBuf ub_compute;
1251 ub_compute.u =
1252 b_warp_tensor_ping(nIter_pack * number<NXdlPack>{} + inxdl)(
1253 kIter_pack * number<KXdlPack>{} + ikxdl);
1254
1255 // warp GEMM
1256 WG{}.template
1257 operator()<ikxdl * MXdlPack + imxdl, ikxdl * NXdlPack + inxdl>(
1258 c_warp_tensor,
1259 ua_compute.mxfp4,
1260 ub_compute.mxfp4,
1261 scale_a_tile_tensor_ping(mIter_pack)(kIter_pack)
1262 .get_thread_buffer()[0], // scale A
1263 scale_b_tile_tensor_ping(nIter_pack)(kIter_pack)
1264 .get_thread_buffer()[0]); // scale B
1265
1266 // write C warp tensor into C block tensor
1267 c_block_tile.set_y_sliced_thread_data(
1268 merge_sequences(sequence<mIter_pack * MXdlPack + imxdl,
1269 nIter_pack * NXdlPack + inxdl>{},
1270 c_warp_y_index_zeros),
1271 merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
1272 c_warp_tensor.get_thread_buffer());
1273 });
1274 // preload next A from lds
1275 constexpr auto addr = (mIter_pack * MXdlPack + imxdl) % 2 +
1276 (kIter_pack * KXdlPack + ikxdl) * 2 +
1277 (mIter_pack * MXdlPack + imxdl) / 2 * 4 +
1278 m_preload;
1279 if constexpr(addr < (KIterPerWarp * MIterPerWarp) &&
1280 (nIter_pack == NIterPerWarp / NXdlPack - 1))
1281 {
1282 constexpr auto AmIter = addr % 2 + addr / 4 * 2;
1283 constexpr auto AkIter = addr / 2 % 2;
1284 ua_ping.mxfp4 = load_tile(
1285 a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
1286 a_warp_tensor(number<AwarpIter>{}) = ua_ping.u;
1287 }
1288
1289 // barrier
1290 if constexpr(kIter_pack * KXdlPack + ikxdl == KIterPerWarp - 1 &&
1291 mIter_pack * MXdlPack + imxdl == MIter_2nd_last)
1292 {
1294 }
1295 });
1296 });
1297 });
1298 });
1299 });
1301 }
1302
1303 return c_block_tile;
1304 }
1305
1306 template <typename ADramBlockWindowTmp,
1307 typename BFlatBlockWindowTmp,
1308 typename ScaleADramBlockWindowTmp,
1309 typename ScaleBDramBlockWindowTmp>
1310 CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
1311 const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
1312 const ScaleADramBlockWindowTmp& scale_a_flat_window_tmp,
1313 const ScaleBDramBlockWindowTmp& scale_b_flat_window_tmp,
1314 index_t num_loop,
1315 void* p_smem_ping,
1316 void* p_smem_pong) const
1317 {
1318 return operator()(
1319 a_dram_block_window_tmp,
1320 [](const ADataType & a) { return a; },
1321 b_flat_dram_block_window_tmp,
1322 scale_a_flat_window_tmp,
1323 scale_b_flat_window_tmp,
1324 num_loop,
1325 p_smem_ping,
1326 p_smem_pong);
1327 }
1328};
1329
1330} // namespace ck_tile
#define CK_TILE_DEVICE
Definition config.hpp:41
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_HOST_DEVICE constexpr index_t get_warp_size()
Definition arch.hpp:63
CK_TILE_DEVICE auto tile_elementwise_in(const InElementFunc &in_element_func, const InTensor &... in_dstr_tensors)
Definition tile_elementwise.hpp:40
CK_TILE_HOST_DEVICE constexpr auto make_tensor_view(DataType *__restrict__ p, const tensor_descriptor< Ts... > &desc)
Definition tensor_view.hpp:452
TailNumber
Definition gemm_pipeline_ag_bg_cr_scheduler.hpp:21
@ Even
Definition gemm_pipeline_ag_bg_cr_scheduler.hpp:24
@ Odd
Definition gemm_pipeline_ag_bg_cr_scheduler.hpp:23
@ Full
Definition gemm_pipeline_ag_bg_cr_scheduler.hpp:39
CK_TILE_DEVICE index_t get_warp_id(bool_constant< ReturnSgpr >={})
Definition arch.hpp:104
CK_TILE_DEVICE void tile_elementwise_inout(const InOutElementFunc &inout_element_func, InOutDstrTensors &... inout_dstr_tensors)
Definition tile_elementwise.hpp:23
CK_TILE_DEVICE void block_sync_lds()
Definition arch.hpp:282
CK_TILE_HOST_DEVICE constexpr auto merge_sequences(Seqs...)
Definition tile/core/container/sequence.hpp:826
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_DEVICE constexpr auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition null_tile_window.hpp:75
CK_TILE_HOST_DEVICE constexpr auto to_sequence(tuple< number< Is >... >)
Definition tile/core/container/sequence.hpp:1055
CK_TILE_DEVICE void move_tile_window(null_tile_window< WindowLengths > &, const typename null_tile_window< WindowLengths >::BottomTensorIndex &)
Definition null_tile_window.hpp:95
typename uniform_sequence_gen< NSize, I >::type uniform_sequence_gen_t
Definition tile/core/container/sequence.hpp:1026
CK_TILE_DEVICE void store_tile(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition store_tile.hpp:23
int32_t index_t
Definition integer.hpp:9
GemmPipelineScheduler
Definition gemm_pipeline_ag_bg_cr_scheduler.hpp:14
@ Intrawave
Definition gemm_pipeline_ag_bg_cr_scheduler.hpp:16
CK_TILE_DEVICE auto load_tile(const TileWindow_ &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={})
Definition load_tile.hpp:22
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
tuple_array< T, N > statically_indexed_array
Definition tile/core/container/statically_indexed_array.hpp:16
const GenericPointer< typename T::ValueType > T2 T::AllocatorType & a
Definition pointer.h:1517
Definition flatmm_pipeline_agmem_bgmem_creg_v1.hpp:47
Definition gemm_pipeline_problem.hpp:323
Definition mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:51
static constexpr index_t kMPerBlock
Definition mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:80
static constexpr index_t BlockSize
Definition mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:77
static constexpr auto idxK
Definition mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:104
FlatmmPipelineAGmemBGmemCRegV1< Problem, PipelinePolicy > Underlying
Definition mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:52
static constexpr index_t dsread_per_wg
Definition mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:141
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp &a_dram_block_window_tmp, const BFlatBlockWindowTmp &b_flat_dram_block_window_tmp, const ScaleADramBlockWindowTmp &scale_a_flat_window_tmp, const ScaleBDramBlockWindowTmp &scale_b_flat_window_tmp, index_t num_loop, void *p_smem_ping, void *p_smem_pong) const
Definition mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:1310
static constexpr index_t mfma_perM_perK
Definition mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:161
static constexpr index_t BK1
Definition mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:130
remove_cvref_t< decltype(PipelinePolicy::template GetBlockFlatmm< Problem >())> BlockFlatmm
Definition mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:66
static constexpr index_t DsWritePreIssue
Definition mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:74
ADataType ComputeType
Definition mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:59
static constexpr auto idxM
Definition mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:102
static constexpr index_t KPerBlockPerIter
Definition mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:120
static constexpr index_t APackedSize
Definition mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:122
static CK_TILE_HOST_DEVICE constexpr auto SchedulerPerM(index_t dsread_perM, index_t dswrite_perM, index_t load_perM)
Definition mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:169
static constexpr auto config
Definition mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:69
static constexpr index_t KPerScaleLoad
Definition mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:157
static constexpr index_t GetVectorSizeC()
Definition mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:89
remove_cvref_t< typename BlockGemmShape::BlockWarps > BlockWarps
Definition mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:106
static constexpr index_t KIterPerWarp
Definition mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:114
static constexpr index_t GetVectorSizeB()
Definition mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:88
static constexpr index_t dswrite_kIter
Definition mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:163
remove_cvref_t< typename Problem::BDataType > BDataType
Definition mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:55
static constexpr index_t dsread_num_perK
Definition mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:147
static constexpr index_t Bload_num_perK
Definition mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:153
remove_cvref_t< typename BlockGemmShape::WarpTile > WarpTile
Definition mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:107
static CK_TILE_HOST_DEVICE constexpr auto HotLoopScheduler()
Definition mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:252
remove_cvref_t< decltype(config.template at< 0 >())> WG
Definition mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:72
static constexpr index_t HalfMIter
Definition mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:158
static constexpr index_t NXdlPack
Definition mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:126
static CK_TILE_HOST_DEVICE constexpr auto Last2ndHotLoopScheduler()
Definition mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:395
static constexpr index_t GetVectorSizeA()
Definition mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:87
static constexpr index_t BPackedSize
Definition mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:123
static constexpr index_t MWarp
Definition mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:109
static CK_TILE_HOST_DEVICE constexpr auto LastHotLoopScheduler()
Definition mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:447
static constexpr index_t ScaleBload_num
Definition mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:155
static constexpr bool kPadN
Definition mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:92
static constexpr index_t dswrite_rep
Definition mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:149
static constexpr index_t NumWaveGroups
Definition mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:96
static constexpr index_t NFlatPerBlockPerIter
Definition mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:117
remove_cvref_t< typename Problem::CLayout > CLayout
Definition mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:64
static constexpr index_t mfma_per_wg
Definition mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:139
static constexpr bool kPadK
Definition mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:93
static constexpr auto I1
Definition mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:100
static constexpr index_t WaveSize
Definition mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:78
static constexpr index_t flatKPerWarp
Definition mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:84
static constexpr index_t NWarp
Definition mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:110
static constexpr index_t MIterPerWarp
Definition mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:112
static constexpr index_t NIterPerWarp
Definition mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:113
static constexpr auto I0
Definition mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:99
static constexpr index_t MPerBlockPerIter
Definition mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:119
CK_TILE_HOST_DEVICE auto operator()(ADramBlockWindowTmp a_copy_dram_window, const AElementFunction &a_element_func, const BFlatBlockWindowTmp &b_flat_dram_block_window_tmp, const ScaleADramBlockWindowTmp &scale_a_window, const ScaleBDramBlockWindowTmp &scale_b_window, index_t num_loop, void *p_smem_ping, void *p_smem_pong) const
Definition mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:477
static constexpr index_t KXdlPack
Definition mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:127
remove_cvref_t< typename Problem::ADataType > ADataType
Definition mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:54
static constexpr index_t ScaleBload_K1
Definition mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:154
static constexpr bool kPadM
Definition mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:91
static constexpr index_t MXdlPack
Definition mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:125
static constexpr index_t KFlatPerBlockPerIter
Definition mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:116
static CK_TILE_HOST_DEVICE constexpr auto GetADramTileDistribution()
Definition mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:467
static constexpr index_t kKPerBlock
Definition mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:82
remove_cvref_t< typename Problem::CDataType > CDataType
Definition mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:56
static constexpr auto I2
Definition mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:101
static constexpr bool HasHotLoop
Definition mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:136
static constexpr index_t kNPerBlock
Definition mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:81
remove_cvref_t< typename Problem::BLayout > BLayout
Definition mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:63
remove_cvref_t< typename BlockGemmShape::BlockTile > BlockTile
Definition mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:105
remove_cvref_t< typename Problem::BlockGemmShape > BlockGemmShape
Definition mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:57
static constexpr index_t AK1
Definition mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:129
static constexpr index_t m_preload
Definition mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:132
static constexpr auto TailNum
Definition mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:137
remove_cvref_t< typename Problem::ALayout > ALayout
Definition mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:62
static constexpr bool UsePersistentKernel
Definition mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:97
static constexpr index_t dswrite_mIter
Definition mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:162
static constexpr index_t Aload_rep
Definition mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:151
static constexpr index_t flatNPerWarp
Definition mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:85
static constexpr index_t DsReadPreload
Definition mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:75
static constexpr index_t Bload_rep
Definition mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:159
static constexpr index_t Aload_num_perK
Definition mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:150
static constexpr bool DoubleSmemBuffer
Definition mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:166
static constexpr auto idxN
Definition mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:103
static constexpr index_t dswrite_num_perK
Definition mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:148
Definition mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:32
BlockGemmShape_ BlockGemmShape
Definition mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:33
static constexpr int ScaleGranularityK
Definition mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:39
static constexpr int KXdlPack
Definition mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:44
static constexpr index_t flatNPerWarp
Definition mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:37
static constexpr int MXdlPack
Definition mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:42
static constexpr int NXdlPack
Definition mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:43
static constexpr index_t flatKPerWarp
Definition mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:46
static constexpr int ContinuousKPerThread
Definition mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:41
static constexpr int PackedSize
Definition tile/core/numeric/numeric.hpp:82
Definition tile/core/container/sequence.hpp:49
Definition tile/core/utility/functional.hpp:43
Definition tile/core/utility/debug.hpp:67