device_batched_gemm_gemm_wmma_cshuffle_v3.hpp Source File

device_batched_gemm_gemm_wmma_cshuffle_v3.hpp Source File#

Composable Kernel: device_batched_gemm_gemm_wmma_cshuffle_v3.hpp Source File
device_batched_gemm_gemm_wmma_cshuffle_v3.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <iostream>
7#include <sstream>
8#include <numeric>
9#include <initializer_list>
10#include <cstdlib>
11
12#include "ck/ck.hpp"
23
24namespace ck {
25namespace tensor_operation {
26namespace device {
27
// Device kernel entry point for the batched GEMM+GEMM WMMA op.
//
// Each block derives the batch it works on from its 1-D block id: the grid size is divided
// by arg.batch_count to get the number of blocks per batch, and the block id is divided by
// that count. The per-batch offsets returned by arg.compute_base_ptr_of_batch are added to
// the A/B0/B1/C base pointers, and the shifted pointers, grid descriptors, element-wise ops
// and tile map are forwarded to GridwiseOp::Run<HasMainKBlockLoop, TailNum>.
28template <typename DeviceOp, typename GridwiseOp, bool HasMainKBlockLoop, TailNumber TailNum>
29__global__ void
30#if CK_USE_LAUNCH_BOUNDS
32#endif
33 kernel_batched_gemm_gemm_wmma_cshuffle_v3(typename DeviceOp::RawArg arg)
34{
// The real body is compiled only when __HIP_DEVICE_COMPILE__ is not defined, or when the
// device target defines __gfx11__ or __gfx12__.
35#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
36
// Shared-memory buffer sized by the gridwise op; handed to GridwiseOp::Run below.
37 __shared__ char p_shared[GridwiseOp::GetSharedMemoryNumberOfByte()];
// Integer division: blocks per batch first, then the batch index of this block.
// NOTE(review): assumes the grid size is a multiple of arg.batch_count — confirm against
// the launcher.
// NOTE(review): the __builtin_amdgcn_readfirstlane wrappers presumably keep these values
// wave-uniform — confirm.
38 const index_t num_blocks_per_batch =
39 __builtin_amdgcn_readfirstlane(get_grid_size() / arg.batch_count);
40 const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
41
// Per-tensor offsets for batch g_idx, kept in long_index_t.
42 const long_index_t a_batch_offset =
43 __builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetABasePtr(g_idx)));
44 const long_index_t b0_batch_offset =
45 __builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetB0BasePtr(g_idx)));
46 const long_index_t b1_batch_offset =
47 __builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetB1BasePtr(g_idx)));
48 const long_index_t c_batch_offset =
49 __builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetCBasePtr(g_idx)));
50
// Offsets are applied with pointer arithmetic on the typed grid pointers.
51 GridwiseOp::template Run<HasMainKBlockLoop, TailNum>(
52 arg.p_a_grid + a_batch_offset,
53 arg.p_b0_grid + b0_batch_offset,
54 arg.p_b1_grid + b1_batch_offset,
55 arg.p_c_grid + c_batch_offset,
56 p_shared,
57 arg.a_grid_desc,
58 arg.b0_grid_desc,
59 arg.b1_grid_desc,
60 arg.c_grid_desc_mblock_mperblock_nblock_nperblock,
61 arg.a_element_op,
62 arg.b0_element_op,
63 arg.acc_element_op,
64 arg.b1_element_op,
65 arg.c_element_op,
66 arg.block_2_ctile_map);
67#else
// Other device targets: the kernel does no work; arg is only assigned to `ignore`,
// presumably to silence the unused-parameter warning.
68 ignore = arg;
69#endif // (!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)
70}
71
72// Computes C = A * B0 * B1
73// MN = MK * KL * LN
74// ^^^^^^ (Acc0)
75// ^^^^^^^^^^^ (Acc1)
76template <typename ALayout,
77 typename B0layout,
78 typename B1Layout,
79 typename CLayout,
80 typename ADataType,
81 typename B0DataType,
82 typename B1DataType,
83 typename CDataType,
84 typename AccDataType,
85 typename CShuffleDataType,
86 typename AElementwiseOperation,
87 typename B0ElementwiseOperation,
88 typename AccElementwiseOperation,
89 typename B1ElementwiseOperation,
90 typename CElementwiseOperation,
91 GemmSpecialization GemmSpec,
92 ck::index_t BlockSize,
93 ck::index_t MPerBlock,
94 ck::index_t LPerBlock, // Gemm0NPerBlock
95 ck::index_t KPerBlock, // Gemm0KPerBlock
96 ck::index_t NPerBlock, // Gemm1NPerBlock
97 ck::index_t LTilePerBlock, // Gemm1KPerBlock
98 ck::index_t AK1,
99 ck::index_t BK1,
100 ck::index_t L1, // B1K1
101 ck::index_t MPerWmma, // Gemm0/1 MPerWmma
102 ck::index_t LPerWmma, // Gemm0/1 NPerWmma
103 ck::index_t MRepeat, // Gemm0/1 MWmmaPerWave or Mrepeat
104 ck::index_t LRepeat, // Gemm0 NWmmaPerWave or Nrepeat
105 ck::index_t NRepeat, // Gemm1 NWmmaPerWave or Nrepeat
106 typename ABlockTransferThreadClusterLengths_K0_M_K1,
107 typename ABlockTransferThreadClusterArrangeOrder,
108 typename ABlockTransferSrcAccessOrder,
109 ck::index_t ABlockTransferSrcVectorDim,
110 ck::index_t ABlockTransferSrcScalarPerVector,
111 ck::index_t ABlockTransferDstScalarPerVector_K1,
112 bool ABlockLdsAddExtraM,
113 typename B0BlockTransferThreadClusterLengths_K0_L_K1,
114 typename B0BlockTransferThreadClusterArrangeOrder,
115 typename B0BlockTransferSrcAccessOrder,
116 ck::index_t B0BlockTransferSrcVectorDim,
117 ck::index_t B0BlockTransferSrcScalarPerVector,
118 ck::index_t B0BlockTransferDstScalarPerVector_K1,
119 bool B0BlockLdsAddExtraL,
120 typename B1BlockTransferThreadClusterLengths_L0_N_L1,
121 typename B1BlockTransferThreadClusterArrangeOrder,
122 typename B1BlockTransferSrcAccessOrder,
123 ck::index_t B1BlockTransferSrcVectorDim,
124 ck::index_t B1BlockTransferSrcScalarPerVector,
125 ck::index_t B1BlockTransferDstScalarPerVector_L1,
126 bool B1BlockLdsAddExtraN,
127 index_t CShuffleMRepeatPerShuffle,
128 index_t CShuffleNRepeatPerShuffle,
129 typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
130 index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
134 B0layout,
135 B1Layout,
136 CLayout,
137 ADataType,
138 B0DataType,
139 B1DataType,
140 CDataType,
141 AElementwiseOperation,
142 B0ElementwiseOperation,
143 AccElementwiseOperation,
144 B1ElementwiseOperation,
145 CElementwiseOperation>
146{
148
149 static constexpr auto I0 = Number<0>{};
150
151 // To match XDL implementation NPerWmma (A.k.a Gemm1 NPerWmma) is set equal
152 // to LPerWmma (A.k.a Gemm0 NPerWmma).
153 static constexpr index_t NPerWmma = LPerWmma;
154
155 // TODO: Now that we are no longer using NumDim or TensorSpec, we can probably use a simpler
156 // Transform operator or just not use one at all.
160 GemmSpec,
165
166 __host__ __device__ static auto
167 MakeAGridDescriptor(const std::array<index_t, 3>& a_g_m_k_lengths_vec,
168 const std::array<index_t, 3>& a_g_m_k_strides_vec)
169 {
171 Transform::MakeAGridDescriptor_M_K(a_g_m_k_lengths_vec, a_g_m_k_strides_vec),
172 Number<AK1>{});
173 }
174
175 __host__ __device__ static auto
176 MakeB0GridDescriptor(const std::array<index_t, 3>& b0_g_l_k_lengths_vec,
177 const std::array<index_t, 3>& b0_g_l_k_strides_vec)
178 {
180 Transform::MakeB0GridDescriptor_N_K(b0_g_l_k_lengths_vec, b0_g_l_k_strides_vec),
181 Number<BK1>{});
182 }
183
184 __host__ __device__ static auto
185 MakeB1GridDescriptor(const std::array<index_t, 3>& b1_g_n_l_lengths_vec,
186 const std::array<index_t, 3>& b1_g_n_l_strides_vec)
187 {
189 Transform::MakeB1GridDescriptor_N_K(b1_g_n_l_lengths_vec, b1_g_n_l_strides_vec),
190 Number<L1>{});
191 }
192
193 using AGridDesc = decltype(MakeAGridDescriptor({}, {}));
194 using B0GridDesc = decltype(MakeB0GridDescriptor({}, {}));
195 using B1GridDesc = decltype(MakeB1GridDescriptor({}, {}));
197
199 {
201 index_t BatchStrideB0,
202 index_t BatchStrideB1,
203 index_t BatchStrideC)
204 : BatchStrideA_(BatchStrideA),
205 BatchStrideB0_(BatchStrideB0),
206 BatchStrideB1_(BatchStrideB1),
207 BatchStrideC_(BatchStrideC)
208 {
209 }
210
211 __host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const
212 {
213 return g_idx * static_cast<long_index_t>(BatchStrideA_);
214 }
215
216 __host__ __device__ constexpr long_index_t GetB0BasePtr(index_t g_idx) const
217 {
218 return g_idx * static_cast<long_index_t>(BatchStrideB0_);
219 }
220
221 __host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const
222 {
223 return g_idx * static_cast<long_index_t>(BatchStrideB1_);
224 }
225
226 __host__ __device__ constexpr long_index_t GetCBasePtr(index_t g_idx) const
227 {
228 return g_idx * static_cast<long_index_t>(BatchStrideC_);
229 }
230
231 private:
232 index_t BatchStrideA_;
233 index_t BatchStrideB0_;
234 index_t BatchStrideB1_;
235 index_t BatchStrideC_;
236 };
237
238 // GridwiseOp
240 // DataType Family
241 ADataType,
242 B0DataType,
243 AccDataType, // Acc0DataType
244 B1DataType,
245 AccDataType, // Acc1DataType
246 CShuffleDataType,
247 CDataType,
248 // ElementwiseOp Family
249 AElementwiseOperation,
250 B0ElementwiseOperation,
251 AccElementwiseOperation,
252 B1ElementwiseOperation,
253 CElementwiseOperation,
255 // InMemory Data Descriptor
256 AGridDesc,
260 // Tiling Family
261 MPerBlock,
262 LPerBlock,
263 KPerBlock,
264 AK1,
265 BK1,
266 NPerBlock,
267 LTilePerBlock,
268 L1,
269 MPerWmma,
270 LPerWmma,
271 NPerWmma,
272 MRepeat,
273 LRepeat,
274 NRepeat,
275 // ThreadCluster Family
276 BlockSize,
277 ABlockTransferThreadClusterLengths_K0_M_K1,
278 ABlockTransferThreadClusterArrangeOrder,
279 ABlockTransferSrcAccessOrder,
280 ABlockTransferSrcVectorDim,
281 ABlockTransferSrcScalarPerVector,
282 ABlockTransferDstScalarPerVector_K1,
283 true,
284 ABlockLdsAddExtraM,
285 B0BlockTransferThreadClusterLengths_K0_L_K1,
286 B0BlockTransferThreadClusterArrangeOrder,
287 B0BlockTransferSrcAccessOrder,
288 B0BlockTransferSrcVectorDim,
289 B0BlockTransferSrcScalarPerVector,
290 B0BlockTransferDstScalarPerVector_K1,
291 true,
292 B0BlockLdsAddExtraL,
293 B1BlockTransferThreadClusterLengths_L0_N_L1,
294 B1BlockTransferThreadClusterArrangeOrder,
295 B1BlockTransferSrcAccessOrder,
296 B1BlockTransferSrcVectorDim,
297 B1BlockTransferSrcScalarPerVector,
298 B1BlockTransferDstScalarPerVector_L1,
299 false,
300 B1BlockLdsAddExtraN,
301 CShuffleMRepeatPerShuffle,
302 CShuffleNRepeatPerShuffle,
303 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
304 CShuffleBlockTransferScalarPerVector_NPerBlock,
306 BlkGemmPipeSched,
307 BlkGemmPipelineVer>;
308
309 struct RawArg : public BaseArgument
310 {
311 using arr3 = std::array<ck::index_t, 3>;
312
313 RawArg(const ADataType* p_a_grid_,
314 const B0DataType* p_b0_grid_,
315 const B1DataType* p_b1_grid_,
316 CDataType* p_c_grid_,
317 index_t M_,
318 index_t N_,
319 index_t K_,
320 index_t O_,
321 index_t Batch,
322 index_t StrideA,
323 index_t StrideB0,
324 index_t StrideB1,
325 index_t StrideC,
326 index_t BatchStrideA,
327 index_t BatchStrideB0,
328 index_t BatchStrideB1,
329 index_t BatchStrideC,
330 AElementwiseOperation a_element_op_,
331 B0ElementwiseOperation b0_element_op_,
332 AccElementwiseOperation acc_element_op_,
333 B1ElementwiseOperation b1_element_op_,
334 CElementwiseOperation c_element_op_)
335 : p_a_grid{p_a_grid_},
336 p_b0_grid{p_b0_grid_},
337 p_b1_grid{p_b1_grid_},
338 p_c_grid{p_c_grid_},
339 M{M_},
340 N{N_},
341 K{K_},
342 O{O_},
343 batch_count{Batch},
344 a_element_op{a_element_op_},
345 b0_element_op{b0_element_op_},
346 acc_element_op{acc_element_op_},
347 b1_element_op{b1_element_op_},
348 c_element_op{c_element_op_},
349 compute_base_ptr_of_batch{BatchStrideA, BatchStrideB0, BatchStrideB1, BatchStrideC}
350 {
351
353 a_g_m_k_strides = arr3{BatchStrideA, StrideA, 1}; // A layout [batch_count, M, K]
354
356 b0_g_n_k_strides = arr3{BatchStrideB0, StrideB0, 1}; // B0 layout [batch_count, N, K]
357
361 ? arr3{BatchStrideB1, 1, StrideB1} // B1 layout [batch_count, N, O]
362 : arr3{BatchStrideB1, StrideB1, 1}; // B1 layout [batch_count, O, N]
363
365 c_g_m_o_strides = arr3{BatchStrideC, StrideC, 1}; // C layout [batch_count, M, O]
366
374 }
375 // Pointers
376 const ADataType* p_a_grid;
377 const B0DataType* p_b0_grid;
378 const B1DataType* p_b1_grid;
379 CDataType* p_c_grid;
380
381 // Raw Problem Size
387
396
397 AElementwiseOperation a_element_op;
398 B0ElementwiseOperation b0_element_op;
399 AccElementwiseOperation acc_element_op;
400 B1ElementwiseOperation b1_element_op;
401 CElementwiseOperation c_element_op;
402
403 // Grid descriptors and other mem calculators
410
412
414 };
415
416 static bool IsSupportedArgument([[maybe_unused]] const RawArg& arg)
417 {
418 // Print lambda with env check and printf() style formatting.
419 const char* curFunc = __func__;
420 auto print = [&curFunc](const char* format, ...) -> void {
421 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
422 {
423#if defined(__clang__)
424#pragma clang diagnostic push
425#pragma clang diagnostic ignored "-Wformat-nonliteral"
426#endif
427 va_list args;
428 va_start(args, format);
429 std::vfprintf(stdout, format, args);
430 va_end(args);
431#if defined(__clang__)
432#pragma clang diagnostic pop
433#endif
434 std::cout << "In file: " << __FILE__ << ", function: " << curFunc << "\n";
435 }
436 };
437
439 {
440 print("DeviceOp: Arch err\n");
441 return false;
442 }
443
444 if constexpr(std::is_same_v<ADataType, f8_t> || std::is_same_v<ADataType, bf8_t> ||
445 std::is_same_v<B0DataType, f8_t> || std::is_same_v<B0DataType, bf8_t> ||
446 std::is_same_v<B1DataType, f8_t> || std::is_same_v<B1DataType, bf8_t>)
447 {
449 {
450 print("DeviceOp: gfx 11 does not support fp8\n");
451 return false;
452 }
453 }
454
456 {
457 print("DeviceOp: Acc0 Type err\n");
458 return false;
459 }
460
462 {
463 print("DeviceOp: A layout must be Row\n");
464 return false;
465 }
466
468 {
469 print("DeviceOp: B layout must be Column\n");
470 return false;
471 }
472
475 {
476 print("DeviceOp: B1 layout must be Column or Row\n");
477 return false;
478 }
479
481 {
482 print("DeviceOp: C layout must be Row\n");
483 return false;
484 }
485
486 // Other padding modes have not been tested and do not get checked individually.
487 if constexpr(GemmSpec != GemmSpecialization::Default &&
489 {
490 print("Padding mode must be default or MNKO\n");
491 return false;
492 }
493
494 // Per wmma dimensions not equal to 16 are very untested.
495 if constexpr(MPerWmma != 16 || LPerWmma != 16 || NPerWmma != 16)
496 {
497 print("M, L, N per Wmma must be 16\n");
498 return false;
499 }
500
501 if(!GridwiseOp::CheckValidity(arg.a_grid_desc,
502 arg.b0_grid_desc,
503 arg.b1_grid_desc,
504 arg.c_grid_desc_m_n,
505 arg.block_2_ctile_map))
506 {
507 return false;
508 }
509
510 // Check scalar per vector requirement
511 const auto a_extent_lowest = ABlockTransferSrcVectorDim == 2 ? arg.K : arg.M;
512 const auto b0_extent_lowest = B0BlockTransferSrcVectorDim == 2 ? arg.K : arg.N;
513 const auto b1_extent_lowest = B1BlockTransferSrcVectorDim == 2 ? arg.N : arg.O;
514 const auto c_extent_lowest = arg.O;
515
516 if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 &&
517 b0_extent_lowest % B0BlockTransferSrcScalarPerVector == 0 &&
518 b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 &&
519 c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0))
520 {
521 print("DeviceOp: Data Transfer Vector scalar err\n");
522 return false;
523 }
524
525 // Check vector load/store requirement
526 const auto a_stride_lowest =
527 ABlockTransferSrcVectorDim == 2 ? arg.a_g_m_k_strides[2] : arg.a_g_m_k_strides[1];
528 const auto b0_stride_lowest =
529 B0BlockTransferSrcVectorDim == 2 ? arg.b0_g_n_k_strides[2] : arg.b0_g_n_k_strides[1];
530 const auto b1_stride_lowest =
531 B1BlockTransferSrcVectorDim == 2 ? arg.b1_g_o_n_strides[2] : arg.b1_g_o_n_strides[1];
532 const auto c_stride_lowest = arg.c_g_m_o_strides[2];
533
534 if(!(a_stride_lowest == 1 || b0_stride_lowest == 1 || b1_stride_lowest == 1 ||
535 c_stride_lowest == 1))
536 {
537 print("DeviceOp: Data Vectorize transfer err\n");
538 return false;
539 }
540
541 if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MNKOPadding))
542 {
543 return false;
544 }
545
546 return true;
547 }
548
549 // polymorphic
550 bool IsSupportedArgument(const BaseArgument* p_arg) override
551 {
552 return IsSupportedArgument(*dynamic_cast<const RawArg*>(p_arg));
553 }
554
555 struct Invoker : public BaseInvoker
556 {
558
559 float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
560 {
561 const auto M0 = math::integer_divide_ceil(arg.M, MPerBlock);
562 const auto N0 = math::integer_divide_ceil(arg.O, NPerBlock);
563
564 const index_t grid_size = arg.batch_count * M0 * N0;
565
566 auto launch_kernel = [&](auto has_main_k_block_loop, auto tail_number) {
567 constexpr bool has_loop = decltype(has_main_k_block_loop)::value;
568 constexpr TailNumber tn = tail_number;
569
570 const auto kernel =
572
574 stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0, arg);
575 };
576
577 bool HasMainKBlockLoop = GridwiseOp::CalculateHasMainKBlockLoop(arg.K);
579
580 if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
581 {
582 if(HasMainKBlockLoop && TailNum == TailNumber::Full)
583 {
584 return launch_kernel(std::integral_constant<bool, true>{},
585 std::integral_constant<TailNumber, TailNumber::Full>{});
586 }
587 else if(!HasMainKBlockLoop && TailNum == TailNumber::Full)
588 {
589 return launch_kernel(std::integral_constant<bool, false>{},
590 std::integral_constant<TailNumber, TailNumber::Full>{});
591 }
592 else
593 {
594 printf("Invalid HasMainKBlockLoop and TailNum combination for V1!\n");
595 return 0.0f;
596 }
597 }
598 else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
599 {
600 if(HasMainKBlockLoop && TailNum == TailNumber::Full)
601 {
602 return launch_kernel(std::integral_constant<bool, true>{},
603 std::integral_constant<TailNumber, TailNumber::Full>{});
604 }
605 else if(!HasMainKBlockLoop && TailNum == TailNumber::Even)
606 {
607 return launch_kernel(std::integral_constant<bool, false>{},
608 std::integral_constant<TailNumber, TailNumber::Even>{});
609 }
610 else if(!HasMainKBlockLoop && TailNum == TailNumber::Odd)
611 {
612 return launch_kernel(std::integral_constant<bool, false>{},
613 std::integral_constant<TailNumber, TailNumber::Odd>{});
614 }
615 else
616 {
617 printf("Invalid HasMainKBlockLoop and TailNum combination for V3!\n");
618 return 0.0f;
619 }
620 }
621 else
622 {
623 printf("Invalid pipeline version!\n");
624 return 0.0f;
625 }
626 }
627
628 // polymorphic
629 float Run(const BaseArgument* p_arg,
630 const StreamConfig& stream_config = StreamConfig{}) override
631 {
632 return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
633 }
634 };
635
636 // polymorphic
637 std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
638 const void* p_b0,
639 const void* p_b1,
640 void* p_c,
641 ck::index_t M,
642 ck::index_t N,
643 ck::index_t K,
644 ck::index_t O,
645 ck::index_t Batch,
646 ck::index_t StrideA,
647 ck::index_t StrideB0,
648 ck::index_t StrideB1,
649 ck::index_t StrideC,
650 ck::index_t BatchStrideA,
651 ck::index_t BatchStrideB0,
652 ck::index_t BatchStrideB1,
653 ck::index_t BatchStrideC,
654 AElementwiseOperation a_element_op,
655 B0ElementwiseOperation b0_element_op,
656 AccElementwiseOperation acc_element_op,
657 B1ElementwiseOperation b1_element_op,
658 CElementwiseOperation c_element_op) override
659 {
660 return std::make_unique<RawArg>(static_cast<const ADataType*>(p_a),
661 static_cast<const B0DataType*>(p_b0),
662 static_cast<const B1DataType*>(p_b1),
663 static_cast<CDataType*>(p_c),
664 M,
665 N,
666 K,
667 O,
668 Batch,
669 StrideA,
670 StrideB0,
671 StrideB1,
672 StrideC,
673 BatchStrideA,
674 BatchStrideB0,
675 BatchStrideB1,
676 BatchStrideC,
677 a_element_op,
678 b0_element_op,
679 acc_element_op,
680 b1_element_op,
681 c_element_op);
682 }
683
684 static auto MakeInvoker() { return Invoker{}; }
685
686 // polymorphic
687 std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
688 {
689 return std::make_unique<Invoker>(Invoker{});
690 }
691
692 template <typename T>
693 static constexpr const char* DataTypeToString()
694 {
695 if constexpr(std::is_same_v<T, float>)
696 {
697 return "fp32";
698 }
699 else if constexpr(std::is_same_v<T, ck::half_t>)
700 {
701 return "fp16";
702 }
703 else if constexpr(std::is_same_v<T, ck::bhalf_t>)
704 {
705 return "bf16";
706 }
707 else if constexpr(std::is_same_v<T, ck::f8_t>)
708 {
709 return "fp8";
710 }
711 else if constexpr(std::is_same_v<T, ck::bf8_t>)
712 {
713 return "bf8";
714 }
715 else if constexpr(std::is_same_v<T, int32_t>)
716 {
717 return "int32";
718 }
719 else if constexpr(std::is_same_v<T, int8_t>)
720 {
721 return "int8";
722 }
723 else if constexpr(std::is_same_v<T, ck::int4_t>)
724 {
725 return "int4";
726 }
727 else
728 {
729 return "unknown";
730 }
731 }
732
733 // polymorphic
734 std::string GetTypeString() const override
735 {
736 auto str = std::stringstream();
737
738 std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
741
742 std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
748
749 // clang-format off
750 str << "DeviceBatchedGemmGemm_Wmma_CShuffleV3"
751 << "<"
752 << ALayout::name[0]
753 << B0layout::name[0]
754 << B1Layout::name[0]
755 << CLayout::name[0] << ", "
756 << "A " << DataTypeToString<ADataType>() << ", "
757 << "B0 " << DataTypeToString<B0DataType>() << ", "
758 << "B1 " << DataTypeToString<B1DataType>() << ", "
759 << "C " << DataTypeToString<CDataType>() << ", "
760 << "Acc " << DataTypeToString<AccDataType>() << ", "
761 << "Cshuf " << DataTypeToString<CShuffleDataType>() << ", "
762 << BlockSize << ", "
763 << MPerBlock << ", "
764 << LPerBlock << ", "
765 << KPerBlock << ", "
766 << AK1 << ", "
767 << BK1 << ", "
768 << MPerBlock << ", "
769 << NPerBlock << ", "
770 << LTilePerBlock << ", "
771 << L1 << ", "
772 << getGemmSpecializationString(GemmSpec)
773 << ">"
774 << "BlkGemmPipelineScheduler: "
775 << BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
776 << "BlkGemmPipelineVersion: "
777 << BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
778 << "BlkGemmPipelinePrefetchStages: "
779 << GridwiseOp::BlockwiseGemmPipe::PrefetchStages;
780 // clang-format on
781
782 return str.str();
783 }
784};
785
786} // namespace device
787} // namespace tensor_operation
788} // namespace ck
#define CK_MIN_BLOCK_PER_CU
Definition ck.hpp:31
#define CK_MAX_THREAD_PER_BLOCK
Definition ck.hpp:30
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
Definition convolution_backward_data_specialization.hpp:8
std::string getGemmSpecializationString(const GemmSpecialization &s)
Definition gemm_specialization.hpp:32
@ Default
Definition tensor_specialization.hpp:12
GemmSpecialization
Definition gemm_specialization.hpp:11
@ Default
Definition gemm_specialization.hpp:13
@ MNKOPadding
Definition gemm_specialization.hpp:29
__global__ void kernel_batched_gemm_gemm_wmma_cshuffle_v3(typename DeviceOp::RawArg arg)
Definition device_batched_gemm_gemm_wmma_cshuffle_v3.hpp:33
Definition convolution_backward_data_specialization.hpp:7
CK_TILE_HOST float launch_kernel(const stream_config &s, Callables &&... callables)
Definition tile/host/kernel_launch.hpp:173
Definition ck.hpp:268
__device__ index_t get_grid_size()
Definition get_id.hpp:49
int32_t index_t
Definition ck.hpp:299
@ Set
Definition ck.hpp:278
BlockGemmPipelineVersion
Definition blkgemmpipe_scheduler.hpp:12
@ v2
Definition blkgemmpipe_scheduler.hpp:15
@ v3
Definition blkgemmpipe_scheduler.hpp:16
@ v5
Definition blkgemmpipe_scheduler.hpp:18
@ v4
Definition blkgemmpipe_scheduler.hpp:17
@ v1
Definition blkgemmpipe_scheduler.hpp:14
integral_constant< index_t, N > Number
Definition number.hpp:12
TailNumber
Definition blkgemmpipe_scheduler.hpp:31
@ Even
Definition blkgemmpipe_scheduler.hpp:34
@ Odd
Definition blkgemmpipe_scheduler.hpp:33
@ Full
Definition blkgemmpipe_scheduler.hpp:49
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
__device__ index_t get_block_1d_id()
Definition get_id.hpp:47
bool is_gfx12_supported()
Definition host_utility/device_prop.hpp:55
bool EnvIsEnabled(EnvVar)
Definition utility/env.hpp:140
constexpr bool is_same_v
Definition type.hpp:283
BlockGemmPipelineScheduler
Definition blkgemmpipe_scheduler.hpp:25
@ Intrawave
Definition blkgemmpipe_scheduler.hpp:26
@ Interwave
Definition blkgemmpipe_scheduler.hpp:27
int64_t long_index_t
Definition ck.hpp:300
bool is_gfx11_supported()
Definition host_utility/device_prop.hpp:60
const GenericPointer< typename T::ValueType > T2 value
Definition pointer.h:1697
Definition ck/stream_config.hpp:10
Definition gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp:88
remove_cvref_t< decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1))> DefaultBlock2CTileMap
Definition gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp:498
__host__ static __device__ constexpr TailNumber CalculateKBlockLoopTailNum(index_t K)
Definition gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp:462
__host__ static __device__ constexpr bool CalculateHasMainKBlockLoop(index_t K)
Definition gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp:456
__host__ static __device__ constexpr auto MakeDefaultBlock2CTileMap(const CGridDesc_M_N &c_grid_desc_m_n, index_t, index_t)
Definition gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp:488
ck::GridwiseBatchedGemmGemm_wmma_cshuffle_v3< ADataType, B0DataType, AccDataType, B1DataType, AccDataType, CShuffleDataType, CDataType, AElementwiseOperation, B0ElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc, B0GridDesc, B1GridDesc, CGridDesc_M_N, MPerBlock, LPerBlock, KPerBlock, AK1, BK1, NPerBlock, LTilePerBlock, L1, MPerWmma, LPerWmma, NPerWmma, MRepeat, LRepeat, NRepeat, BlockSize, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, true, ABlockLdsAddExtraM, B0BlockTransferThreadClusterLengths_K0_L_K1, B0BlockTransferThreadClusterArrangeOrder, B0BlockTransferSrcAccessOrder, B0BlockTransferSrcVectorDim, B0BlockTransferSrcScalarPerVector, B0BlockTransferDstScalarPerVector_K1, true, B0BlockLdsAddExtraL, B1BlockTransferThreadClusterLengths_L0_N_L1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_L1, false, B1BlockLdsAddExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, Transform::matrix_padder.PadN, BlkGemmPipeSched, BlkGemmPipelineVer >::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
__host__ static __device__ constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc_M_N &c_grid_desc_m_n)
Definition gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp:469
__host__ static __device__ constexpr bool CheckValidity(const AGridDesc &a_grid_desc, const B0GridDesc &b0_grid_desc, const B1GridDesc &b1_grid_desc, const CGridDesc_M_N &c_grid_desc_m_n, const Block2CTileMap &block_2_ctile_map)
Definition gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp:363
ck::GridwiseBatchedGemmGemm_wmma_cshuffle_v3< ADataType, B0DataType, AccDataType, B1DataType, AccDataType, CShuffleDataType, CDataType, AElementwiseOperation, B0ElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc, B0GridDesc, B1GridDesc, CGridDesc_M_N, MPerBlock, LPerBlock, KPerBlock, AK1, BK1, NPerBlock, LTilePerBlock, L1, MPerWmma, LPerWmma, NPerWmma, MRepeat, LRepeat, NRepeat, BlockSize, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, true, ABlockLdsAddExtraM, B0BlockTransferThreadClusterLengths_K0_L_K1, B0BlockTransferThreadClusterArrangeOrder, B0BlockTransferSrcAccessOrder, B0BlockTransferSrcVectorDim, B0BlockTransferSrcScalarPerVector, B0BlockTransferDstScalarPerVector_K1, true, B0BlockLdsAddExtraL, B1BlockTransferThreadClusterLengths_L0_N_L1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_L1, false, B1BlockLdsAddExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, Transform::matrix_padder.PadN, BlkGemmPipeSched, BlkGemmPipelineVer >::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
remove_cvref_t< decltype(MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}))> CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
Definition gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp:495
Definition utility/sequence.hpp:43
Definition transform_contraction_to_gemm_arraybase.hpp:122
__host__ static __device__ auto MakeB1GridDescriptor_N_K(const std::array< index_t, NumDimG+NumDimM+NumDimN > &b1_gs_os_ns_lengths_vec, const std::array< index_t, NumDimG+NumDimM+NumDimN > &b1_gs_os_ns_strides_vec)
Definition transform_contraction_to_gemm_arraybase.hpp:307
__host__ static __device__ auto MakeCGridDescriptor_M_N(const std::array< index_t, NumDimG+NumDimM+NumDimN > &c_gs_ms_os_lengths_vec, const std::array< index_t, NumDimG+NumDimM+NumDimN > &c_gs_ms_os_strides_vec)
Definition transform_contraction_to_gemm_arraybase.hpp:381
__host__ static __device__ auto MakeAGridDescriptor_M_K(const std::array< index_t, NumDimG+NumDimM+NumDimN > &a_gs_ms_ks_lengths_vec, const std::array< index_t, NumDimG+NumDimM+NumDimN > &a_gs_ms_ks_strides_vec)
Definition transform_contraction_to_gemm_arraybase.hpp:162
__host__ static __device__ auto MakeB0GridDescriptor_N_K(const std::array< index_t, NumDimG+NumDimM+NumDimN > &b0_gs_ns_ks_lengths_vec, const std::array< index_t, NumDimG+NumDimM+NumDimN > &b0_gs_ns_ks_strides_vec)
Definition transform_contraction_to_gemm_arraybase.hpp:234
Definition device_base.hpp:197
__host__ __device__ constexpr long_index_t GetB0BasePtr(index_t g_idx) const
Definition device_batched_gemm_gemm_wmma_cshuffle_v3.hpp:216
ComputeBasePtrOfStridedBatch(index_t BatchStrideA, index_t BatchStrideB0, index_t BatchStrideB1, index_t BatchStrideC)
Definition device_batched_gemm_gemm_wmma_cshuffle_v3.hpp:200
__host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const
Definition device_batched_gemm_gemm_wmma_cshuffle_v3.hpp:211
__host__ __device__ constexpr long_index_t GetCBasePtr(index_t g_idx) const
Definition device_batched_gemm_gemm_wmma_cshuffle_v3.hpp:226
__host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const
Definition device_batched_gemm_gemm_wmma_cshuffle_v3.hpp:221
Definition device_batched_gemm_gemm_wmma_cshuffle_v3.hpp:556
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_batched_gemm_gemm_wmma_cshuffle_v3.hpp:629
DeviceOp::RawArg Argument
Definition device_batched_gemm_gemm_wmma_cshuffle_v3.hpp:557
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_batched_gemm_gemm_wmma_cshuffle_v3.hpp:559
Definition device_batched_gemm_gemm_wmma_cshuffle_v3.hpp:310
arr3 a_g_m_k_lengths
Definition device_batched_gemm_gemm_wmma_cshuffle_v3.hpp:388
std::array< ck::index_t, 3 > arr3
Definition device_batched_gemm_gemm_wmma_cshuffle_v3.hpp:311
arr3 b1_g_o_n_strides
Definition device_batched_gemm_gemm_wmma_cshuffle_v3.hpp:393
B1ElementwiseOperation b1_element_op
Definition device_batched_gemm_gemm_wmma_cshuffle_v3.hpp:400
AElementwiseOperation a_element_op
Definition device_batched_gemm_gemm_wmma_cshuffle_v3.hpp:397
GridwiseOp::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock
Definition device_batched_gemm_gemm_wmma_cshuffle_v3.hpp:409
AccElementwiseOperation acc_element_op
Definition device_batched_gemm_gemm_wmma_cshuffle_v3.hpp:399
arr3 a_g_m_k_strides
Definition device_batched_gemm_gemm_wmma_cshuffle_v3.hpp:389
B1GridDesc b1_grid_desc
Definition device_batched_gemm_gemm_wmma_cshuffle_v3.hpp:406
const B1DataType * p_b1_grid
Definition device_batched_gemm_gemm_wmma_cshuffle_v3.hpp:378
arr3 c_g_m_o_lengths
Definition device_batched_gemm_gemm_wmma_cshuffle_v3.hpp:394
index_t N
Definition device_batched_gemm_gemm_wmma_cshuffle_v3.hpp:383
index_t batch_count
Definition device_batched_gemm_gemm_wmma_cshuffle_v3.hpp:386
arr3 b1_g_o_n_lengths
Definition device_batched_gemm_gemm_wmma_cshuffle_v3.hpp:392
arr3 b0_g_n_k_lengths
Definition device_batched_gemm_gemm_wmma_cshuffle_v3.hpp:390
CGridDesc_M_N c_grid_desc_m_n
Definition device_batched_gemm_gemm_wmma_cshuffle_v3.hpp:407
arr3 b0_g_n_k_strides
Definition device_batched_gemm_gemm_wmma_cshuffle_v3.hpp:391
arr3 c_g_m_o_strides
Definition device_batched_gemm_gemm_wmma_cshuffle_v3.hpp:395
ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch
Definition device_batched_gemm_gemm_wmma_cshuffle_v3.hpp:413
B0GridDesc b0_grid_desc
Definition device_batched_gemm_gemm_wmma_cshuffle_v3.hpp:405
RawArg(const ADataType *p_a_grid_, const B0DataType *p_b0_grid_, const B1DataType *p_b1_grid_, CDataType *p_c_grid_, index_t M_, index_t N_, index_t K_, index_t O_, index_t Batch, index_t StrideA, index_t StrideB0, index_t StrideB1, index_t StrideC, index_t BatchStrideA, index_t BatchStrideB0, index_t BatchStrideB1, index_t BatchStrideC, AElementwiseOperation a_element_op_, B0ElementwiseOperation b0_element_op_, AccElementwiseOperation acc_element_op_, B1ElementwiseOperation b1_element_op_, CElementwiseOperation c_element_op_)
Definition device_batched_gemm_gemm_wmma_cshuffle_v3.hpp:313
CElementwiseOperation c_element_op
Definition device_batched_gemm_gemm_wmma_cshuffle_v3.hpp:401
CDataType * p_c_grid
Definition device_batched_gemm_gemm_wmma_cshuffle_v3.hpp:379
GridwiseOp::DefaultBlock2CTileMap block_2_ctile_map
Definition device_batched_gemm_gemm_wmma_cshuffle_v3.hpp:411
index_t K
Definition device_batched_gemm_gemm_wmma_cshuffle_v3.hpp:384
const B0DataType * p_b0_grid
Definition device_batched_gemm_gemm_wmma_cshuffle_v3.hpp:377
AGridDesc a_grid_desc
Definition device_batched_gemm_gemm_wmma_cshuffle_v3.hpp:404
B0ElementwiseOperation b0_element_op
Definition device_batched_gemm_gemm_wmma_cshuffle_v3.hpp:398
index_t O
Definition device_batched_gemm_gemm_wmma_cshuffle_v3.hpp:385
index_t M
Definition device_batched_gemm_gemm_wmma_cshuffle_v3.hpp:382
const ADataType * p_a_grid
Definition device_batched_gemm_gemm_wmma_cshuffle_v3.hpp:376
Definition device_batched_gemm_gemm_wmma_cshuffle_v3.hpp:146
static constexpr const char * DataTypeToString()
Definition device_batched_gemm_gemm_wmma_cshuffle_v3.hpp:693
TransformBatchedContractionContractionToBatchedGemmGemm_Wmma< Sequence< 1, 1, 1, 1, 1 >, Sequence< MPerBlock, LPerBlock, KPerBlock, NPerBlock >, GemmSpec, TensorSpecialization::Default, TensorSpecialization::Default, TensorSpecialization::Default, TensorSpecialization::Default > Transform
Definition device_batched_gemm_gemm_wmma_cshuffle_v3.hpp:157
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_batched_gemm_gemm_wmma_cshuffle_v3.hpp:550
decltype(MakeB0GridDescriptor({}, {})) B0GridDesc
Definition device_batched_gemm_gemm_wmma_cshuffle_v3.hpp:194
GridwiseBatchedGemmGemm_wmma_cshuffle_v3< ADataType, B0DataType, AccDataType, B1DataType, AccDataType, CShuffleDataType, CDataType, AElementwiseOperation, B0ElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc, B0GridDesc, B1GridDesc, CGridDesc_M_N, MPerBlock, LPerBlock, KPerBlock, AK1, BK1, NPerBlock, LTilePerBlock, L1, MPerWmma, LPerWmma, NPerWmma, MRepeat, LRepeat, NRepeat, BlockSize, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, true, ABlockLdsAddExtraM, B0BlockTransferThreadClusterLengths_K0_L_K1, B0BlockTransferThreadClusterArrangeOrder, B0BlockTransferSrcAccessOrder, B0BlockTransferSrcVectorDim, B0BlockTransferSrcScalarPerVector, B0BlockTransferDstScalarPerVector_K1, true, B0BlockLdsAddExtraL, B1BlockTransferThreadClusterLengths_L0_N_L1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_L1, false, B1BlockLdsAddExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, Transform::matrix_padder.PadN, BlkGemmPipeSched, BlkGemmPipelineVer > GridwiseOp
Definition device_batched_gemm_gemm_wmma_cshuffle_v3.hpp:239
__host__ static __device__ auto MakeB1GridDescriptor(const std::array< index_t, 3 > &b1_g_n_l_lengths_vec, const std::array< index_t, 3 > &b1_g_n_l_strides_vec)
Definition device_batched_gemm_gemm_wmma_cshuffle_v3.hpp:185
static constexpr index_t NPerWmma
Definition device_batched_gemm_gemm_wmma_cshuffle_v3.hpp:153
static bool IsSupportedArgument(const RawArg &arg)
Definition device_batched_gemm_gemm_wmma_cshuffle_v3.hpp:416
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_batched_gemm_gemm_wmma_cshuffle_v3.hpp:687
decltype(MakeAGridDescriptor({}, {})) AGridDesc
Definition device_batched_gemm_gemm_wmma_cshuffle_v3.hpp:193
__host__ static __device__ auto MakeAGridDescriptor(const std::array< index_t, 3 > &a_g_m_k_lengths_vec, const std::array< index_t, 3 > &a_g_m_k_strides_vec)
Definition device_batched_gemm_gemm_wmma_cshuffle_v3.hpp:167
static auto MakeInvoker()
Definition device_batched_gemm_gemm_wmma_cshuffle_v3.hpp:684
__host__ static __device__ auto MakeB0GridDescriptor(const std::array< index_t, 3 > &b0_g_l_k_lengths_vec, const std::array< index_t, 3 > &b0_g_l_k_strides_vec)
Definition device_batched_gemm_gemm_wmma_cshuffle_v3.hpp:176
decltype(Transform::MakeCGridDescriptor_M_N({}, {})) CGridDesc_M_N
Definition device_batched_gemm_gemm_wmma_cshuffle_v3.hpp:196
decltype(MakeB1GridDescriptor({}, {})) B1GridDesc
Definition device_batched_gemm_gemm_wmma_cshuffle_v3.hpp:195
DeviceBatchedGemmGemm_Wmma_CShuffleV3 DeviceOp
Definition device_batched_gemm_gemm_wmma_cshuffle_v3.hpp:147
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b0, const void *p_b1, void *p_c, ck::index_t M, ck::index_t N, ck::index_t K, ck::index_t O, ck::index_t Batch, ck::index_t StrideA, ck::index_t StrideB0, ck::index_t StrideB1, ck::index_t StrideC, ck::index_t BatchStrideA, ck::index_t BatchStrideB0, ck::index_t BatchStrideB1, ck::index_t BatchStrideC, AElementwiseOperation a_element_op, B0ElementwiseOperation b0_element_op, AccElementwiseOperation acc_element_op, B1ElementwiseOperation b1_element_op, CElementwiseOperation c_element_op) override
Definition device_batched_gemm_gemm_wmma_cshuffle_v3.hpp:637
std::string GetTypeString() const override
Definition device_batched_gemm_gemm_wmma_cshuffle_v3.hpp:734
static constexpr auto I0
Definition device_batched_gemm_gemm_wmma_cshuffle_v3.hpp:149
Definition device_batched_gemm_gemm.hpp:29
#define CK_ENV(name)
Definition utility/env.hpp:129