blockwise_gemm_pipeline_xdlops_b_preshuffle_v1.hpp Source File

blockwise_gemm_pipeline_xdlops_b_preshuffle_v1.hpp Source File

Composable Kernel: blockwise_gemm_pipeline_xdlops_b_preshuffle_v1.hpp Source File
blockwise_gemm_pipeline_xdlops_b_preshuffle_v1.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
7
8namespace ck {
9
10// Compute optimized pipeline
11// GlobalPrefetchStages: 2
12// LocalPreFillStages: 1
13// LocalPreFetchStages: 1
14// LocalSharedMemoryBuffer: 1
15
16template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
17 index_t BlockSize,
18 typename ADataType,
19 typename BDataType,
20 typename ComputeDataType,
21 typename AccDataType,
22 typename ATileDesc,
23 typename BTileDesc,
24 typename AMmaTileDesc,
25 typename BMmaTileDesc,
26 index_t ABlockTransferSrcScalarPerVector,
27 index_t BBlockTransferSrcScalarPerVector,
28 index_t MPerBlock,
29 index_t NPerBlock,
30 index_t KPerBlock,
31 index_t MPerXDL,
32 index_t NPerXDL,
33 index_t MRepeat,
34 index_t NRepeat,
35 index_t KPacks>
39
40template <index_t BlockSize,
41 typename ADataType,
42 typename BDataType,
43 typename ComputeDataType,
44 typename AccDataType,
45 typename ATileDesc,
46 typename BTileDesc,
47 typename AMmaTileDesc,
48 typename BMmaTileDesc,
49 index_t ABlockTransferSrcScalarPerVector,
50 index_t BBlockTransferSrcScalarPerVector,
51 index_t MPerBlock,
52 index_t NPerBlock,
53 index_t KPerBlock,
54 index_t MPerXDL,
55 index_t NPerXDL,
56 index_t MRepeat,
57 index_t NRepeat,
58 index_t KPack
59 // ,bool TransposeC //disable transposec right now...
60 >
62 BlockSize,
63 ADataType,
64 BDataType,
65 ComputeDataType,
66 AccDataType,
67 ATileDesc,
68 BTileDesc,
69 AMmaTileDesc,
70 BMmaTileDesc,
71 ABlockTransferSrcScalarPerVector,
72 BBlockTransferSrcScalarPerVector,
73 MPerBlock,
74 NPerBlock,
75 KPerBlock,
76 MPerXDL,
77 NPerXDL,
78 MRepeat,
79 NRepeat,
80 KPack>
82 ADataType,
83 BDataType,
84 ComputeDataType,
85 AccDataType,
86 ATileDesc,
87 BTileDesc,
88 AMmaTileDesc,
89 BMmaTileDesc,
90 ABlockTransferSrcScalarPerVector,
91 BBlockTransferSrcScalarPerVector,
92 MPerBlock,
93 NPerBlock,
94 KPerBlock,
95 MPerXDL,
96 NPerXDL,
97 MRepeat,
98 NRepeat,
99 KPack>
100
101{
103 ADataType,
104 BDataType,
105 ComputeDataType,
106 AccDataType,
107 ATileDesc,
108 BTileDesc,
109 AMmaTileDesc,
110 BMmaTileDesc,
111 ABlockTransferSrcScalarPerVector,
112 BBlockTransferSrcScalarPerVector,
113 MPerBlock,
114 NPerBlock,
115 KPerBlock,
116 MPerXDL,
117 NPerXDL,
118 MRepeat,
119 NRepeat,
120 KPack>;
121 using Base::A_K1;
122 using Base::B_K1;
123 using Base::I0;
124 using Base::I1;
125 using Base::KGroup;
126 using Base::KRepeat;
127 using Base::xdlops_gemm;
128 using typename Base::HotLoopInstList;
129
142
143 using Base::AMmaKStride;
144 using Base::BMmaKStride;
145 using Base::MWaves;
146 using Base::WaveSize;
147
148 static constexpr index_t PrefetchStages = 2;
149 static constexpr index_t PrefillStages = 1;
150 static constexpr index_t GlobalBufferNum = 2;
151
152 template <typename TileDesc_M0_M1_M2_K>
153 __host__ __device__ static constexpr auto MakeAGemmMmaTileDescriptor(const TileDesc_M0_M1_M2_K&)
154 {
155 constexpr index_t M0 = TileDesc_M0_M1_M2_K{}.GetLength(Number<0>{});
156 constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{});
157 constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{});
158 constexpr index_t K2 = KPack / KGroup;
159 constexpr index_t K1 = WaveSize / NPerXDL;
160 constexpr index_t K0 = KRepeat * KGroup;
161
163 TileDesc_M0_M1_M2_K{},
171 }
172
173 static constexpr auto a_block_desc_m0_m1_m2_k0_k1_k2 =
175
// Reports whether the K loop is long enough to need Run()'s unrolled main loop:
// true only when num_loop is strictly greater than PrefetchStages (2).
// Presumably the caller passes the result as Run()'s HasMainLoop template argument.
// When it is false, all K-iterations are retired by Run()'s Even/Odd tail code.
176 __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop)
177 {
178 return PrefetchStages < num_loop;
179 }
180
// Classifies the K-loop trip count by parity. Run()'s main loop retires iterations
// in pairs (two LoopFunc calls per pass). The leftover work is therefore either
// two MFMA passes (TailNumber::Even) or a single one (TailNumber::Odd).
181 __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
182 {
183 return (num_loop % 2 != 0) ? TailNumber::Odd : TailNumber::Even;
184 }
185
186 __device__ static constexpr auto HotLoopScheduler()
187 {
188 constexpr auto num_ds_read_inst_a = HotLoopInstList::A_LDS_Read_Inst_Num;
189 constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num;
190 constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num * MWaves;
191 constexpr auto mfma_interleave = MPerXDL == 32 ? 1 : 2;
192 // B global
194 ignore = i;
195 if constexpr(MPerBlock >= 128 && NPerBlock >= 128)
196 {
197 __builtin_amdgcn_sched_group_barrier(0x008, 2 * mfma_interleave, 0);
198 }
199 else
200 {
201 __builtin_amdgcn_sched_group_barrier(0x008, mfma_interleave, 0);
202 }
203 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
204 });
205
206 // A global
208 ignore = i;
209 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
210 __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
211 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
212 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
213 });
214
215 // A local
216 static_for<0, MPerXDL == 32 ? num_ds_read_inst_a / 2 : num_ds_read_inst_a, 1>{}(
217 [&](auto i) {
218 ignore = i;
219 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
220 __builtin_amdgcn_sched_group_barrier(0x100, MPerXDL == 32 ? 2 : 1, 0); // DS read
221 });
222 }
223
224 template <bool HasMainLoop,
225 TailNumber TailNum,
226 typename AGridDesc,
227 typename ABlockDesc,
228 typename ABlockTransfer,
229 typename AGridBuffer,
230 typename ABlockBuffer,
231 typename ABlockTransferStep,
232 typename BGridDesc,
233 typename BBlockTransfer,
234 typename BGridBuffer,
235 typename BBlockBuffer,
236 typename BBlockTransferStep,
237 typename CThreadBuffer>
238 __device__ void Run(const AGridDesc& a_grid_desc,
239 const ABlockDesc& a_block_desc,
240 ABlockTransfer& a_blockwise_copy,
241 const AGridBuffer& a_grid_buf,
242 ABlockBuffer& a_block_buf,
243 const ABlockTransferStep& a_block_copy_step,
244 const BGridDesc& b_grid_desc,
245 BBlockTransfer& b_blockwise_copy,
246 const BGridBuffer& b_grid_buf,
247 BBlockBuffer& b_block_buf,
248 const BBlockTransferStep& b_block_copy_step,
249 CThreadBuffer& c_thread_buf,
250 index_t num_loop) const
251 {
252 ignore = b_block_buf;
253 __builtin_amdgcn_sched_barrier(0);
255 a_thread_desc_.GetElementSpaceSize());
257 b_thread_desc_.GetElementSpaceSize());
258
259 StaticallyIndexedArray<decltype(b_thread_buf), Number<2>{}> b_thread_bufs;
260 constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0);
261
262 // Global prefetch A1 B1
263 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
264 b_blockwise_copy.Run(b_grid_desc,
265 b_grid_buf,
267 b_block_origin_idx,
268 b_thread_bufs(I0));
269
270 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
271 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
272 __builtin_amdgcn_sched_barrier(0);
273
274 // Local prefill A1
275 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0);
276
277 // Global prefetch A2
278 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
279 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
280
281 // Local prefetch A1
283 static_for<0, MRepeat, 1>{}([&](auto m0) {
284 static_for<0, KRepeat, 1>{}([&](auto k0) {
285 static_for<0, KGroup, 1>{}([&](auto kg0) {
288 a_block_buf,
290 make_tuple(m0, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
291 a_thread_buf);
292 });
293 });
294 });
295
296 // Initialize C
297 c_thread_buf.Clear();
298
299 __builtin_amdgcn_sched_barrier(0);
300
301 // main body
302 if constexpr(HasMainLoop)
303 {
304 index_t i = 0;
305 do
306 {
307 auto LoopFunc = [&](auto mfma_reg_buf, auto local_read_buf) {
308 b_blockwise_copy.Run(b_grid_desc,
309 b_grid_buf,
311 b_block_origin_idx,
312 b_thread_bufs(local_read_buf));
313 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
314
316 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, mfma_reg_buf);
317
318 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, local_read_buf);
319 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
320
321 static_for<0, MRepeat, 1>{}([&](auto m0) {
322 static_for<0, NRepeat, 1>{}([&](auto n0) {
323 static_for<0, KRepeat, 1>{}([&](auto k0) {
326
327 static_for<0, KPack, 1>{}([&](auto ik) {
328 a_thread_vec.template AsType<ComputeDataType>()(ik) =
329 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
330 make_tuple(m0, I0, I0, k0, I0, ik))>{}];
331 b_thread_vec.template AsType<ComputeDataType>()(ik) =
332 b_thread_bufs[mfma_reg_buf]
333 [Number<b_thread_desc_.CalculateOffset(
334 make_tuple(n0, I0, k0, ik))>{}];
335 });
336 using mfma_input_type =
337 typename vector_type<ComputeDataType,
338 xdlops_gemm.K1PerXdlops>::type;
339
340 constexpr index_t c_offset =
341 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
342
343 xdlops_gemm.Run(
344 a_thread_vec.template AsType<mfma_input_type>(),
345 b_thread_vec.template AsType<mfma_input_type>(),
346 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
347 });
348 });
349 });
350
352
353 // loop prefetch copy
354 static_for<0, MRepeat, 1>{}([&](auto m0) {
355 static_for<0, KRepeat, 1>{}([&](auto k0) {
356 static_for<0, KGroup, 1>{}([&](auto kg0) {
357 a_thread_copy_.Run(
360 a_block_buf,
362 make_tuple(m0, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
363 a_thread_buf);
364 });
365 });
366 });
367
369 __builtin_amdgcn_sched_barrier(0);
370 };
371
372 LoopFunc(I0, I1);
373 LoopFunc(I1, I0);
374
375 i += 2;
376 } while(i < (num_loop - 2));
377 }
378 // tail
379 if constexpr(TailNum == TailNumber::Even)
380 {
381 b_blockwise_copy.Run(b_grid_desc,
382 b_grid_buf,
384 b_block_origin_idx,
385 b_thread_bufs(I1));
386
388 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
389
390 static_for<0, MRepeat, 1>{}([&](auto m0) {
391 static_for<0, NRepeat, 1>{}([&](auto n0) {
392 static_for<0, KRepeat, 1>{}([&](auto k0) {
395
396 static_for<0, KPack, 1>{}([&](auto ik) {
397 a_thread_vec.template AsType<ComputeDataType>()(ik) =
398 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
399 make_tuple(m0, I0, I0, k0, I0, ik))>{}];
400 b_thread_vec.template AsType<ComputeDataType>()(ik) =
401 b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
402 make_tuple(n0, I0, k0, ik))>{}];
403 });
404
405 using mfma_input_type =
406 typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
407
408 constexpr index_t c_offset =
409 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
410
411 xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
412 b_thread_vec.template AsType<mfma_input_type>(),
413 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
414 });
415 });
416 });
417
419
420 // tail Local Prefetch A1
421 static_for<0, MRepeat, 1>{}([&](auto m0) {
422 static_for<0, KRepeat, 1>{}([&](auto k0) {
423 static_for<0, KGroup, 1>{}([&](auto kg0) {
424 a_thread_copy_.Run(
427 a_block_buf,
429 make_tuple(m0, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
430 a_thread_buf);
431 });
432 });
433 });
434
435 __builtin_amdgcn_sched_barrier(0);
436
437 static_for<0, MRepeat, 1>{}([&](auto m0) {
438 static_for<0, NRepeat, 1>{}([&](auto n0) {
439 static_for<0, KRepeat, 1>{}([&](auto k0) {
442
443 static_for<0, KPack, 1>{}([&](auto ik) {
444 a_thread_vec.template AsType<ComputeDataType>()(ik) =
445 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
446 make_tuple(m0, I0, I0, k0, I0, ik))>{}];
447 b_thread_vec.template AsType<ComputeDataType>()(ik) =
448 b_thread_bufs[I1][Number<b_thread_desc_.CalculateOffset(
449 make_tuple(n0, I0, k0, ik))>{}];
450 });
451
452 using mfma_input_type =
453 typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
454
455 constexpr index_t c_offset =
456 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
457
458 xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
459 b_thread_vec.template AsType<mfma_input_type>(),
460 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
461 });
462 });
463 });
464 // Let's leak last MFMA block to epilogue region, cover the potential lds-shuffle
465 // latency
466 // __builtin_amdgcn_sched_barrier(0);
467 }
468 else if constexpr(TailNum == TailNumber::Odd)
469 {
470 static_for<0, MRepeat, 1>{}([&](auto m0) {
471 static_for<0, NRepeat, 1>{}([&](auto n0) {
472 static_for<0, KRepeat, 1>{}([&](auto k0) {
475
476 static_for<0, KPack, 1>{}([&](auto ik) {
477 a_thread_vec.template AsType<ComputeDataType>()(ik) =
478 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
479 make_tuple(m0, I0, I0, k0, I0, ik))>{}];
480 b_thread_vec.template AsType<ComputeDataType>()(ik) =
481 b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
482 make_tuple(n0, I0, k0, ik))>{}];
483 });
484
485 using mfma_input_type =
486 typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
487
488 constexpr index_t c_offset =
489 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
490
491 xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
492 b_thread_vec.template AsType<mfma_input_type>(),
493 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
494 });
495 });
496 });
497 }
498 }
499
500 protected:
501 // MRepeat MWave MLane KRepeat KLane KPack
502 // KRepeat -> MRepeat-> Mwave->KLane->MLane->KPack
505
507 ComputeDataType,
509 decltype(a_thread_desc_),
510 Sequence<1, 1, 1, 1, 1, KPack / KGroup>,
512 5,
513 A_K1,
514 A_K1>;
515
517
520
521 static constexpr BTileDesc b_block_desc_n0_n1_k0_k1;
522
524};
525
526} // namespace ck
Definition ck.hpp:268
__host__ __device__ constexpr auto make_static_buffer(Number< N >)
Definition static_buffer.hpp:186
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
typename detail::StaticallyIndexedArrayImpl< T, N >::type StaticallyIndexedArray
Definition utility/statically_indexed_array.hpp:45
int32_t index_t
Definition ck.hpp:299
integral_constant< index_t, N > Number
Definition number.hpp:12
TailNumber
Definition blkgemmpipe_scheduler.hpp:31
@ Even
Definition blkgemmpipe_scheduler.hpp:34
@ Odd
Definition blkgemmpipe_scheduler.hpp:33
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
BlockGemmPipelineScheduler
Definition blkgemmpipe_scheduler.hpp:25
@ Intrawave
Definition blkgemmpipe_scheduler.hpp:26
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__device__ void block_sync_lds()
Definition synchronization.hpp:16
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
__host__ __device__ BlockwiseGemmXdlops_pipeline_base(Tuple4 a_origin=CalculateAThreadOriginDataIndex(), Tuple4 b_origin=CalculateBThreadOriginDataIndex())
Constructor for BlockwiseGemmXdlops_pipeline_base.
Definition blockwise_gemm_pipeline_xdlops_base.hpp:222
__host__ static __device__ constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:280
static constexpr index_t MWaves
Definition blockwise_gemm_pipeline_xdlops_base.hpp:44
__host__ static __device__ constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:239
static constexpr auto c_thread_desc_
Definition blockwise_gemm_pipeline_xdlops_base.hpp:378
static constexpr auto xdlops_gemm
Definition blockwise_gemm_pipeline_xdlops_base.hpp:54
static constexpr index_t KGroup
Definition blockwise_gemm_pipeline_xdlops_base.hpp:67
static constexpr auto I1
Definition blockwise_gemm_pipeline_xdlops_base.hpp:37
__host__ static __device__ constexpr auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:266
__host__ static __device__ constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:294
static constexpr index_t AMmaKStride
Definition blockwise_gemm_pipeline_xdlops_base.hpp:60
static __device__ auto CalculateAThreadOriginDataIndex6D()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:136
static constexpr index_t WaveSize
Definition blockwise_gemm_pipeline_xdlops_base.hpp:46
__host__ static __device__ constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:253
static constexpr index_t B_K1
Definition blockwise_gemm_pipeline_xdlops_base.hpp:51
ck::BlockwiseGemmXdlops_pipeline_hotloop_inst< BlockSize, MPerBlock, NPerBlock, KPerBlock, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, A_K1, B_K1, A_K1, B_K1, MRepeat, NRepeat, MPerXDL, NPerXDL, xdlops_gemm.KPerXdlops > HotLoopInstList
Definition blockwise_gemm_pipeline_xdlops_base.hpp:82
__host__ __device__ constexpr auto & GetCThreadBuffer()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:111
static constexpr auto I0
Definition blockwise_gemm_pipeline_xdlops_base.hpp:36
static __device__ auto CalculateCThreadOriginDataIndex(Number< m0 >, Number< n0 >, Number< xdlops_i >, Number< blk_i >)
Definition blockwise_gemm_pipeline_xdlops_base.hpp:160
static __device__ auto CalculateCThreadOriginDataIndex8D(Number< m0 >, Number< n0 >, Number< xdlops_i >, Number< blk_i >)
Definition blockwise_gemm_pipeline_xdlops_base.hpp:189
static constexpr index_t KRepeat
Definition blockwise_gemm_pipeline_xdlops_base.hpp:64
static constexpr AMmaTileDesc a_block_desc_m0_m1_m2_k
Definition blockwise_gemm_pipeline_xdlops_base.hpp:359
static constexpr index_t A_K1
Definition blockwise_gemm_pipeline_xdlops_base.hpp:50
static constexpr index_t BMmaKStride
Definition blockwise_gemm_pipeline_xdlops_base.hpp:61
__host__ static __device__ constexpr auto MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_G_M_N &c_grid_desc_g_m_n)
Definition blockwise_gemm_pipeline_xdlops_base.hpp:341
__host__ static __device__ constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:307
__host__ static __device__ constexpr auto MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N &c_grid_desc_m_n)
Definition blockwise_gemm_pipeline_xdlops_base.hpp:324
BlockwiseGemmXdlops_pipeline_base< BlockSize, ADataType, BDataType, ComputeDataType, AccDataType, ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack > Base
Definition blockwise_gemm_pipeline_xdlops_b_preshuffle_v1.hpp:102
__device__ void Run(const AGridDesc &a_grid_desc, const ABlockDesc &a_block_desc, ABlockTransfer &a_blockwise_copy, const AGridBuffer &a_grid_buf, ABlockBuffer &a_block_buf, const ABlockTransferStep &a_block_copy_step, const BGridDesc &b_grid_desc, BBlockTransfer &b_blockwise_copy, const BGridBuffer &b_grid_buf, BBlockBuffer &b_block_buf, const BBlockTransferStep &b_block_copy_step, CThreadBuffer &c_thread_buf, index_t num_loop) const
Definition blockwise_gemm_pipeline_xdlops_b_preshuffle_v1.hpp:238
ThreadwiseTensorSliceTransfer_v4< ADataType, ComputeDataType, decltype(a_block_desc_m0_m1_m2_k0_k1_k2), decltype(a_thread_desc_), Sequence< 1, 1, 1, 1, 1, KPack/KGroup >, Sequence< 0, 1, 2, 3, 4, 5 >, 5, A_K1, A_K1 > AThreadCopy
Definition blockwise_gemm_pipeline_xdlops_b_preshuffle_v1.hpp:506
Definition blockwise_gemm_pipeline_xdlops_b_preshuffle_v1.hpp:37
Definition utility/sequence.hpp:43
Definition threadwise_tensor_slice_transfer.hpp:1260
Definition functional2.hpp:33
Definition dtype_vector.hpp:10