device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp Source File

device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp Source File#

Composable Kernel: device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp Source File
device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <iostream>
7#include <sstream>
8
19
20namespace ck {
21
22// Currently we do not have an elegant way to put the single-LDS-buffer and double-LDS-buffer pipes
23// in the same kernel function. Blockers:
24// 1. Two separate declarations of __shared__ pointers are the key to making sure data accesses
25// operate on two LDS chunks.
26// 2. Occupied __shared__ memory is not released until the whole shader ends, i.e. AB and C may not
27// use the same LDS buffer when we declare __shared__ inside blkgemmpipe.
28template <typename GridwiseGemm,
29 typename BatchedGemmArg,
30 bool HasMainKBlockLoop,
31 InMemoryDataOperationEnum CGlobalMemoryDataOperation,
32 index_t MinimumOccupancy = 1,
34__global__ void
35#if CK_USE_LAUNCH_BOUNDS
36__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
37#endif
39{
40#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
41 if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
42 {
43 __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
44
45 const index_t g_idx = blockIdx.z % karg.Batch;
46 const index_t k_idx = blockIdx.z / karg.Batch;
47
48 const auto a_batch_offset = karg.compute_ptr_offset_of_batch.GetAPtrOffset(g_idx);
49 const auto b_batch_offset = karg.compute_ptr_offset_of_batch.GetBPtrOffset(g_idx);
50 const auto ds_batch_offset = karg.compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
51 const auto c_batch_offset = karg.compute_ptr_offset_of_batch.GetCPtrOffset(g_idx);
52
53 auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, k_idx);
54
55 // populate pointer, desc for Ds
57 // D pointer
58 karg.p_ds_grid(i) = karg.p_ds_grid(i) + ds_batch_offset[i];
59 });
60
61 GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
62 karg.p_a_grid + a_batch_offset + splitk_batch_offset.a_k_split_offset,
63 karg.p_b_grid + b_batch_offset + splitk_batch_offset.b_k_split_offset,
64 karg.p_ds_grid,
65 karg.p_c_grid + c_batch_offset,
66 p_shared,
67 karg,
68 karg.a_element_op,
69 karg.b_element_op,
70 karg.c_element_op);
71 }
72#else
73 ignore = karg;
74#endif // end of if (defined(__gfx9__))
75}
76
77template <typename GridwiseGemm,
78 typename BatchedGemmArg,
79 bool HasMainKBlockLoop,
80 InMemoryDataOperationEnum CGlobalMemoryDataOperation,
81 index_t MinimumOccupancy = 1,
83__global__ void
84#if CK_USE_LAUNCH_BOUNDS
85__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
86#endif
88{
89#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
90 if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
91 {
92 // Passing two LDS pointers is the key to telling the compiler that ds_read/write
93 // operate on different LDS chunks at the same time without an ordering dependency
94 __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
95 __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
96
97 const index_t g_idx = blockIdx.z % karg.Batch;
98 const index_t k_idx = blockIdx.z / karg.Batch;
99
100 const auto a_batch_offset = karg.compute_ptr_offset_of_batch.GetAPtrOffset(g_idx);
101 const auto b_batch_offset = karg.compute_ptr_offset_of_batch.GetBPtrOffset(g_idx);
102 const auto ds_batch_offset = karg.compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
103 const auto c_batch_offset = karg.compute_ptr_offset_of_batch.GetCPtrOffset(g_idx);
104
105 auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, k_idx);
106
107 // populate pointer, desc for Ds
109 // D pointer
110 karg.p_ds_grid(i) = karg.p_ds_grid(i) + ds_batch_offset[i];
111 });
112
113 GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
114 karg.p_a_grid + a_batch_offset + splitk_batch_offset.a_k_split_offset,
115 karg.p_b_grid + b_batch_offset + splitk_batch_offset.b_k_split_offset,
116 karg.p_ds_grid,
117 karg.p_c_grid + c_batch_offset,
118 p_shared_0,
119 p_shared_1,
120 karg,
121 karg.a_element_op,
122 karg.b_element_op,
123 karg.c_element_op);
124 }
125#else
126 ignore = karg;
127#endif // end of if (defined(__gfx9__))
128}
129
130namespace tensor_operation {
131namespace device {
132
133template <typename ALayout,
134 typename BLayout,
135 typename DsLayout,
136 typename CLayout,
137 typename ADataType,
138 typename BDataType,
139 typename DsDataType,
140 typename CDataType,
141 typename GemmAccDataType,
142 typename CShuffleDataType,
143 typename AElementwiseOperation,
144 typename BElementwiseOperation,
145 typename CElementwiseOperation,
146 GemmSpecialization GemmSpec,
147 index_t BlockSize,
148 index_t MPerBlock,
149 index_t NPerBlock,
150 index_t KPerBlock,
151 index_t AK1,
152 index_t BK1,
153 index_t MPerXDL,
154 index_t NPerXDL,
155 index_t MXdlPerWave,
156 index_t NXdlPerWave,
157 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
158 typename ABlockTransferThreadClusterArrangeOrder,
159 typename ABlockTransferSrcAccessOrder,
160 index_t ABlockTransferSrcVectorDim,
161 index_t ABlockTransferSrcScalarPerVector,
162 index_t ABlockTransferDstScalarPerVector_AK1,
163 bool ABlockLdsExtraM,
164 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
165 typename BBlockTransferThreadClusterArrangeOrder,
166 typename BBlockTransferSrcAccessOrder,
167 index_t BBlockTransferSrcVectorDim,
168 index_t BBlockTransferSrcScalarPerVector,
169 index_t BBlockTransferDstScalarPerVector_BK1,
170 bool BBlockLdsExtraN,
171 index_t CShuffleMXdlPerWavePerShuffle,
172 index_t CShuffleNXdlPerWavePerShuffle,
173 typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
174 typename CDEShuffleBlockTransferScalarPerVectors,
177 typename ComputeTypeA = ADataType,
178 typename ComputeTypeB = BDataType,
179 typename LDSTypeA = ComputeTypeA,
180 typename LDSTypeB = ComputeTypeB>
182 : public DeviceBatchedGemmV2MultiD<ALayout,
183 BLayout,
184 DsLayout,
185 CLayout,
186 ADataType,
187 BDataType,
188 DsDataType,
189 CDataType,
190 AElementwiseOperation,
191 BElementwiseOperation,
192 CElementwiseOperation>
193{
195 static constexpr auto NXdlPerWave64 = GetNXdlPerWave<true>();
196 static constexpr auto NXdlPerWave32 = GetNXdlPerWave<false>();
197
198 static constexpr index_t NumDTensor = DsDataType::Size();
199 using CDEShuffleBlockTransferScalarPerVectors_ = CDEShuffleBlockTransferScalarPerVectors;
200 using CDataType_ = CDataType;
201
202 // GridwiseGemm
203 template <index_t NXdlPerWave_>
205 ALayout,
206 BLayout,
207 DsLayout,
208 CLayout,
209 ADataType,
210 BDataType,
211 GemmAccDataType,
212 CShuffleDataType,
213 DsDataType,
214 CDataType,
215 AElementwiseOperation,
216 BElementwiseOperation,
217 CElementwiseOperation,
218 GemmSpec,
219 BlockSize,
220 MPerBlock,
221 NPerBlock,
222 KPerBlock,
223 AK1,
224 BK1,
225 MPerXDL,
226 NPerXDL,
227 MXdlPerWave,
228 NXdlPerWave_,
229 ABlockTransferThreadClusterLengths_AK0_M_AK1,
230 ABlockTransferThreadClusterArrangeOrder,
231 ABlockTransferSrcAccessOrder,
232 ABlockTransferSrcVectorDim,
233 ABlockTransferSrcScalarPerVector,
234 ABlockTransferDstScalarPerVector_AK1,
235 false,
236 ABlockLdsExtraM,
237 BBlockTransferThreadClusterLengths_BK0_N_BK1,
238 BBlockTransferThreadClusterArrangeOrder,
239 BBlockTransferSrcAccessOrder,
240 BBlockTransferSrcVectorDim,
241 BBlockTransferSrcScalarPerVector,
242 BBlockTransferDstScalarPerVector_BK1,
243 false,
244 BBlockLdsExtraN,
245 CShuffleMXdlPerWavePerShuffle,
246 CShuffleNXdlPerWavePerShuffle,
247 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
248 CDEShuffleBlockTransferScalarPerVectors,
249 BlkGemmPipeSched,
250 BlkGemmPipelineVer,
251 ComputeTypeA,
252 ComputeTypeB,
253 LDSTypeA,
254 LDSTypeB>;
258
260 {
263 index_t BatchStrideB,
264 std::array<ck::index_t, NumDTensor> BatchStrideDs,
265 index_t BatchStrideC)
266 : BatchStrideA_(BatchStrideA),
267 BatchStrideB_(BatchStrideB),
268 BatchStrideDs_(BatchStrideDs),
269 BatchStrideC_(BatchStrideC)
270 {
271 }
272
// Offset for batch `g_idx` into the A tensor: BatchStrideA_ * g_idx.
// The kernels in this file add the result to p_a_grid.
// The stride is cast to long_index_t before the multiply, so the product is
// evaluated in the wider type rather than in index_t.
273 __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const
274 {
275 return static_cast<long_index_t>(BatchStrideA_) * g_idx;
276 }
277
// Offset for batch `g_idx` into the B tensor: BatchStrideB_ * g_idx.
// The kernels in this file add the result to p_b_grid.
// The stride is cast to long_index_t before the multiply, so the product is
// evaluated in the wider type rather than in index_t.
278 __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const
279 {
280 return static_cast<long_index_t>(BatchStrideB_) * g_idx;
281 }
282
// Per-D-tensor offsets for batch `g_idx`: entry i is BatchStrideDs_[i] * g_idx,
// with the stride cast to long_index_t before the multiply.
// Returns a std::array<long_index_t, NumDTensor>; the kernels in this file add
// entry i to p_ds_grid(i).
283 __host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const
284 {
// Declared without an initializer; elements are written by the loop below.
285 std::array<long_index_t, NumDTensor> ds_offset_;
286
// NOTE(review): static_for<0, NumDTensor, 1> presumably invokes the lambda for
// each i in [0, NumDTensor) — confirm against its definition in functional2.hpp.
287 static_for<0, NumDTensor, 1>{}([&](auto i) {
288 ds_offset_[i] = static_cast<long_index_t>(BatchStrideDs_[i]) * g_idx;
289 });
290
291 return ds_offset_;
292 }
293
// Offset for batch `g_idx` into the C (output) tensor: BatchStrideC_ * g_idx.
// The kernels in this file add the result to p_c_grid.
// The stride is cast to long_index_t before the multiply, so the product is
// evaluated in the wider type rather than in index_t.
294 __host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const
295 {
296 return static_cast<long_index_t>(BatchStrideC_) * g_idx;
297 }
298
299 private:
300 index_t BatchStrideA_;
301 index_t BatchStrideB_;
302 std::array<ck::index_t, NumDTensor> BatchStrideDs_;
303 index_t BatchStrideC_;
304 };
305
306 template <typename GridwiseGemm>
307 struct ArgumentBase : public GridwiseGemm::Argument
308 {
311
312 ArgumentBase() = default;
313 ArgumentBase(const ADataType* p_a_grid_,
314 const BDataType* p_b_grid_,
315 std::array<const void*, NumDTensor> p_ds_grid_,
316 CDataType* p_e_grid_,
317 index_t M_,
318 index_t N_,
319 index_t K_,
320 index_t StrideA_,
321 index_t StrideB_,
322 std::array<index_t, NumDTensor> StrideDs_,
323 index_t StrideE_,
324 index_t BatchStrideA_,
325 index_t BatchStrideB_,
326 const std::array<ck::index_t, NumDTensor>& BatchStrideDs_,
327 index_t BatchStrideE_,
328 index_t Batch_,
329 AElementwiseOperation a_element_op_,
330 BElementwiseOperation b_element_op_,
331 CElementwiseOperation c_element_op_,
332 index_t KBatch_)
333 : GridwiseGemm::Argument{p_a_grid_,
334 p_b_grid_,
335 p_ds_grid_,
336 p_e_grid_,
337 M_,
338 N_,
339 K_,
340 StrideA_,
341 StrideB_,
342 StrideDs_,
343 StrideE_,
344 KBatch_,
345 a_element_op_,
346 b_element_op_,
347 c_element_op_},
348 Batch{Batch_},
350 BatchStrideA_, BatchStrideB_, BatchStrideDs_, BatchStrideE_}
351 {
352 }
353 };
355
357 {
359 {
360 constexpr int dynamic_smem_size = 0;
361 int max_occupancy = 0;
362
363 constexpr index_t minimum_occupancy = []() {
364 if constexpr(BlkGemmPipeSched == BlockGemmPipelineScheduler::Interwave)
365 {
366 return 2;
367 }
368 else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
369 {
370 return (MPerBlock * NPerBlock / BlockSize <= 128) ? 2 : 1;
371 }
372 else
373 {
374 return 1;
375 }
376 }();
377
378 if(get_warp_size() == 64)
379 {
380 if constexpr(NXdlPerWave64 > 0)
381 {
382 if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
383 {
384 hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor(
385 &max_occupancy,
388 Argument,
389 true,
391 minimum_occupancy>,
392 BlockSize,
393 dynamic_smem_size));
394 }
395 else
396 {
397 hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor(
398 &max_occupancy,
401 Argument,
402 true,
404 minimum_occupancy>,
405 BlockSize,
406 dynamic_smem_size));
407 }
408 }
409 }
410 else
411 {
412 if constexpr(NXdlPerWave32 > 0)
413 {
414 if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
415 {
416 hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor(
417 &max_occupancy,
421 true,
423 minimum_occupancy>,
424 BlockSize,
425 dynamic_smem_size));
426 }
427 else
428 {
429 hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor(
430 &max_occupancy,
434 true,
436 minimum_occupancy>,
437 BlockSize,
438 dynamic_smem_size));
439 }
440 }
441 }
442
443 max_occupancy_ = std::max(1, max_occupancy);
444 }
446 };
447
448 // Invoker
449 struct Invoker : public BaseInvoker
450 {
451 template <typename GridwiseGemm>
452 float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
453 {
454 using BatchGemmArgument = ArgumentBase<GridwiseGemm>;
455 if(stream_config.log_level_ > 0)
456 {
457 arg.Print();
458 }
459
460 if(!GridwiseGemm::CheckValidity(reinterpret_cast<const BatchGemmArgument&>(arg)))
461 {
462 throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
463 }
464
465 index_t gdx, gdy, gdz;
466 std::tie(gdx, gdy, gdz) =
467 GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.Batch * arg.KBatch);
468
469 float ave_time = 0;
470
471 index_t k_grain = arg.KBatch * KPerBlock;
472 index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;
473
474 const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
475
476 const auto Run = [&](const auto& kernel) {
477 if(stream_config.flush_cache)
478 {
479
480 std::array<std::size_t, NumDTensor> DsSize;
481
482 BatchGemmArgument arg_ = reinterpret_cast<const BatchGemmArgument&>(arg);
483
484 const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(
485 arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0);
486 const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(
487 arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0);
488
489 auto size_a_buffer =
490 a_grid_desc_ak0_m_ak1.GetElementSpaceSize() * sizeof(ADataType) * arg.Batch;
491 auto size_b_buffer =
492 b_grid_desc_bk0_n_bk1.GetElementSpaceSize() * sizeof(BDataType) * arg.Batch;
493
494 const auto ds_grid_desc_m_n = GridwiseGemm::MakeDsGridDescriptor_M_N(
495 arg_.M, arg_.MPadded, arg_.N, arg_.NPadded, arg_.StrideDs);
496
497 static_for<0, NumDTensor, 1>{}([&](auto i) {
498 using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
499 DsSize[i] = ds_grid_desc_m_n[i].GetElementSpaceSize() * sizeof(DDataType);
500 });
502 rotating_mem(arg_,
503 stream_config.rotating_count,
504 size_a_buffer,
505 size_b_buffer,
506 DsSize);
507 rotating_mem.Print();
508
509 auto run_flush_cache = [&]() {
510 // flush icache
512 // rotating mem
513 rotating_mem.Next();
514 // clear c mem
515 if(arg_.KBatch > 1)
516 hipGetErrorString(
517 hipMemsetAsync(arg_.p_c_grid,
518 0,
519 arg.Batch * arg_.M * arg_.N * sizeof(CDataType),
520 stream_config.stream_id_));
521 };
522
524 stream_config,
525 run_flush_cache,
526 kernel,
527 dim3(gdx, gdy, gdz),
528 dim3(BlockSize),
529 0,
530 arg_);
531 }
532 else
533 {
534 const auto clear_workspace = [&]() {
535 if(arg.KBatch > 1)
536 hipGetErrorString(
537 hipMemsetAsync(arg.p_c_grid,
538 0,
539 arg.Batch * arg.M * arg.N * sizeof(CDataType),
540 stream_config.stream_id_));
541 };
542
543 BatchGemmArgument arg_ = reinterpret_cast<const BatchGemmArgument&>(arg);
544 ave_time = launch_and_time_kernel_with_preprocess(stream_config,
545 clear_workspace,
546 kernel,
547 dim3(gdx, gdy, gdz),
548 dim3(BlockSize),
549 0,
550 arg_);
551 }
552 };
553
554 constexpr index_t minimum_occupancy = []() {
555 if constexpr(BlkGemmPipeSched == BlockGemmPipelineScheduler::Interwave)
556 {
557 return 2;
558 }
559 else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
560 {
561 return (MPerBlock * NPerBlock / BlockSize <= 128) ? 2 : 1;
562 }
563 else
564 {
565 return 1;
566 }
567 }();
568
569 if(has_main_k_block_loop)
570 {
571 // Tail number always full
572 if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ||
573 BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
574 {
575 if(arg.KBatch > 1)
576 {
579 BatchGemmArgument,
580 true,
582 minimum_occupancy>;
583 Run(kernel);
584 }
585 else
586 {
589 BatchGemmArgument,
590 true,
592 minimum_occupancy>;
593 Run(kernel);
594 }
595 }
596 // Tail number could be One to Seven
597 else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2)
598 {
599 if(arg.KBatch > 1)
600 {
602 {
605 BatchGemmArgument,
606 true,
608 minimum_occupancy,
610 Run(kernel);
611 }
614 {
617 BatchGemmArgument,
618 true,
620 minimum_occupancy,
622 Run(kernel);
623 }
624
625 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
626 {
628 {
631 BatchGemmArgument,
632 true,
634 minimum_occupancy,
636 Run(kernel);
637 }
638 }
639
640 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3)
641 {
644 {
647 BatchGemmArgument,
648 true,
650 minimum_occupancy,
652 Run(kernel);
653 }
654 }
655
656 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4)
657 {
660 {
663 BatchGemmArgument,
664 true,
666 minimum_occupancy,
668 Run(kernel);
669 }
670 }
671
672 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5)
673 {
676 {
679 BatchGemmArgument,
680 true,
682 minimum_occupancy,
684 Run(kernel);
685 }
686 }
687
688 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6)
689 {
691 {
694 BatchGemmArgument,
695 true,
697 minimum_occupancy,
699 Run(kernel);
700 }
701 }
702
703 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7)
704 {
707 {
710 BatchGemmArgument,
711 true,
713 minimum_occupancy,
715 Run(kernel);
716 }
717 }
718 }
719 else
720 {
722 {
725 BatchGemmArgument,
726 true,
728 minimum_occupancy,
730 Run(kernel);
731 }
734 {
737 BatchGemmArgument,
738 true,
740 minimum_occupancy,
742 Run(kernel);
743 }
744
745 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
746 {
748 {
751 BatchGemmArgument,
752 true,
754 minimum_occupancy,
756 Run(kernel);
757 }
758 }
759
760 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3)
761 {
764 {
767 BatchGemmArgument,
768 true,
770 minimum_occupancy,
772 Run(kernel);
773 }
774 }
775
776 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4)
777 {
780 {
783 BatchGemmArgument,
784 true,
786 minimum_occupancy,
788 Run(kernel);
789 }
790 }
791
792 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5)
793 {
796 {
799 BatchGemmArgument,
800 true,
802 minimum_occupancy,
804 Run(kernel);
805 }
806 }
807
808 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6)
809 {
811 {
814 BatchGemmArgument,
815 true,
817 minimum_occupancy,
819 Run(kernel);
820 }
821 }
822
823 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7)
824 {
827 {
830 BatchGemmArgument,
831 true,
833 minimum_occupancy,
835 Run(kernel);
836 }
837 }
838 }
839 }
840 // Tail number could be Odd or Even
841 else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
842 {
843 if(arg.KBatch > 1)
844 {
846 {
849 BatchGemmArgument,
850 true,
852 minimum_occupancy,
854 Run(kernel);
855 }
856 else
857 {
860 BatchGemmArgument,
861 true,
863 minimum_occupancy,
865 Run(kernel);
866 }
867 }
868 else
869 {
871 {
874 BatchGemmArgument,
875 true,
877 minimum_occupancy,
879 Run(kernel);
880 }
881 else
882 {
885 BatchGemmArgument,
886 true,
888 minimum_occupancy,
890 Run(kernel);
891 }
892 }
893 }
894 else
895 {
896 if(arg.KBatch > 1)
897 {
899 {
902 BatchGemmArgument,
903 true,
905 minimum_occupancy,
907 Run(kernel);
908 }
909 else
910 {
913 BatchGemmArgument,
914 true,
916 minimum_occupancy,
918 Run(kernel);
919 }
920 }
921 else
922 {
924 {
927 BatchGemmArgument,
928 true,
930 minimum_occupancy,
932 Run(kernel);
933 }
934 else
935 {
938 BatchGemmArgument,
939 true,
941 minimum_occupancy,
943 Run(kernel);
944 }
945 }
946 }
947 }
948 else
949 {
950 // Tail number always 1
951 if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
952 {
953 if(arg.KBatch > 1)
954 {
957 BatchGemmArgument,
958 false,
960 minimum_occupancy>;
961 Run(kernel);
962 }
963 else
964 {
967 BatchGemmArgument,
968 false,
970 minimum_occupancy>;
971 Run(kernel);
972 }
973 }
974 }
975
976 return ave_time;
977 }
978
980
981 // polymorphic
// Type-erased entry point: downcasts `p_arg` to `const Argument*` and forwards
// the dereferenced argument together with `stream_config` to another Run
// overload, returning that overload's float result.
// NOTE(review): the overload being called is not visible in this listing —
// presumably it is supplied by INVOKER_RUN_IMPL; confirm.
// NOTE(review): the dynamic_cast result is dereferenced without a null check,
// so `p_arg` must really point to an Argument of this device op.
982 float Run(const BaseArgument* p_arg,
983 const StreamConfig& stream_config = StreamConfig{}) override
984 {
985 return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
986 }
987 };
988
// Compile-time validity hook for this device op.
// Currently a placeholder: it unconditionally returns true, so no
// instantiation is rejected here (see the TODO below).
989 static constexpr bool IsValidCompilationParameter()
990 {
991 // TODO: properly implement this check
992 return true;
993 }
994
995 static bool IsSupportedArgument(const Argument& arg)
996 {
998 {
999 return false;
1000 }
1001 if(is_gfx11_supported() && arg.KBatch > 1)
1002 {
1003 return false;
1004 }
1005 if(!is_bf16_atomic_supported() && std::is_same_v<CDataType, ck::bhalf_t> && arg.KBatch > 1)
1006 {
1007 return false;
1008 }
1009
1010 if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding ||
1011 GemmSpec == GemmSpecialization::NKPadding ||
1012 GemmSpec == GemmSpecialization::MNKPadding ||
1013 GemmSpec == GemmSpecialization::KPadding))
1014 {
1015 return false;
1016 }
1017 if(get_warp_size() == 64)
1018 {
1019 if constexpr(NXdlPerWave64 > 0)
1020 {
1022 }
1023 }
1024 else
1025 {
1026 if constexpr(NXdlPerWave32 > 0)
1027 {
1029 reinterpret_cast<const typename GridwiseGemm32::Argument&>(arg));
1030 }
1031 }
1032 return false;
1033 }
1034
1035 // polymorphic
// Type-erased overload: downcasts `p_arg` to `const Argument*` and delegates to
// the static IsSupportedArgument(const Argument&) defined above in this class.
// NOTE(review): the dynamic_cast result is dereferenced without a null check,
// so `p_arg` must really point to an Argument of this device op.
1036 bool IsSupportedArgument(const BaseArgument* p_arg) override
1037 {
1038 return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
1039 }
1040
// Builds an Argument by value from type-erased tensor pointers.
//
// p_a / p_b are static_cast to `const ADataType*` / `const BDataType*`, p_e is
// static_cast to `CDataType*`; p_ds and all sizes, strides and element-wise
// operations are passed through unchanged. KBatch defaults to 1.
//
// Mind the ordering: `Batch` sits before the strides in this parameter list but
// is passed after `BatchStrideE` in the initializer list below, which is the
// position the ArgumentBase constructor in this file declares it at.
1041 static auto MakeArgument(const void* p_a,
1042 const void* p_b,
1043 std::array<const void*, NumDTensor> p_ds,
1044 void* p_e,
1045 index_t M,
1046 index_t N,
1047 index_t K,
1048 index_t Batch,
1049 index_t StrideA,
1050 index_t StrideB,
1051 std::array<index_t, NumDTensor> StrideDs,
1052 index_t StrideE,
1053 index_t BatchStrideA,
1054 index_t BatchStrideB,
1055 const std::array<ck::index_t, NumDTensor>& BatchStrideDs,
1056 index_t BatchStrideE,
1057 AElementwiseOperation a_element_op,
1058 BElementwiseOperation b_element_op,
1059 CElementwiseOperation c_element_op,
1060 index_t KBatch = 1)
1061 {
1062 return Argument{static_cast<const ADataType*>(p_a),
1063 static_cast<const BDataType*>(p_b),
1064 p_ds,
1065 static_cast<CDataType*>(p_e),
1066 M,
1067 N,
1068 K,
1069 StrideA,
1070 StrideB,
1071 StrideDs,
1072 StrideE,
1073 BatchStrideA,
1074 BatchStrideB,
1075 BatchStrideDs,
1076 BatchStrideE,
1077 Batch,
1078 a_element_op,
1079 b_element_op,
1080 c_element_op,
1081 KBatch};
1082 }
1083
// Returns a value-initialized Invoker by value.
1084 static auto MakeInvoker() { return Invoker{}; }
1085
1086 // polymorphic
// Type-erased counterpart of MakeArgument: constructs an Argument with
// std::make_unique and returns it as std::unique_ptr<BaseArgument>.
//
// p_a / p_b are static_cast to `const ADataType*` / `const BDataType*`, p_e is
// static_cast to `CDataType*`; everything else is passed through unchanged.
// KBatch defaults to 1.
//
// Mind the ordering: `Batch` sits before the strides in this parameter list but
// is passed after `BatchStrideE` in the constructor call below.
1087 std::unique_ptr<BaseArgument>
1088 MakeArgumentPointer(const void* p_a,
1089 const void* p_b,
1090 const std::array<const void*, NumDTensor>& p_ds,
1091 void* p_e,
1092 index_t M,
1093 index_t N,
1094 index_t K,
1095 index_t Batch,
1096 index_t StrideA,
1097 index_t StrideB,
1098 const std::array<ck::index_t, NumDTensor>& StrideDs,
1099 index_t StrideE,
1100 index_t BatchStrideA,
1101 index_t BatchStrideB,
1102 const std::array<ck::index_t, NumDTensor>& BatchStrideDs,
1103 index_t BatchStrideE,
1104 AElementwiseOperation a_element_op,
1105 BElementwiseOperation b_element_op,
1106 CElementwiseOperation c_element_op,
1107 index_t KBatch = 1) override
1108 {
1109 return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
1110 static_cast<const BDataType*>(p_b),
1111 p_ds,
1112 static_cast<CDataType*>(p_e),
1113 M,
1114 N,
1115 K,
1116 StrideA,
1117 StrideB,
1118 StrideDs,
1119 StrideE,
1120 BatchStrideA,
1121 BatchStrideB,
1122 BatchStrideDs,
1123 BatchStrideE,
1124 Batch,
1125 a_element_op,
1126 b_element_op,
1127 c_element_op,
1128 KBatch);
1129 }
1130
1131 // polymorphic
// Returns a heap-allocated Invoker as std::unique_ptr<BaseInvoker>.
// NOTE(review): make_unique is handed a temporary `Invoker{}`, so the heap
// object is constructed from that temporary rather than directly;
// `std::make_unique<Invoker>()` would avoid the extra temporary.
1132 std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
1133 {
1134 return std::make_unique<Invoker>(Invoker{});
1135 }
1136
1137 // polymorphic
1138 std::string GetTypeString() const override
1139 {
1140 auto str = std::stringstream();
1141
1142 std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
1145
1146 std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
1152
1153 // clang-format off
1154 str << "DeviceBatchedGemmXdlUniversal"
1155 << "<"
1156 << getGemmSpecializationString(GemmSpec) << ", "
1157 << std::string(ALayout::name)[0]
1158 << std::string(BLayout::name)[0]
1159 << std::string(CLayout::name)[0]
1160 << ">"
1161 << " BlkSize: "
1162 << BlockSize << ", "
1163 << "BlkTile: "
1164 << MPerBlock<<"x"<<NPerBlock<<"x"<<KPerBlock << ", "
1165 << "WaveTile: "
1166 << MPerXDL<<"x"<<NPerXDL << ", "
1167 << "WaveMap: "
1168 << MXdlPerWave<<"x" << NXdlPerWave<<", "
1169 << "VmemReadVec: "
1170 << ABlockTransferSrcScalarPerVector<<"x"<<BBlockTransferSrcScalarPerVector<<", "
1171 << "BlkGemmPipelineScheduler: "
1172 << BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
1173 << "BlkGemmPipelineVersion: "
1174 << BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
1175 << "BlkGemmPipelinePrefetchStages: "
1176 << GridwiseGemm64::BlockwiseGemmPipe::PrefetchStages;
1177 // clang-format on
1178
1179 return str.str();
1180 }
1181
1183 {
1184 static ActiveWorkgroupsPerCU active_workgroups_per_cu;
1185 return active_workgroups_per_cu.max_occupancy_;
1186 }
1187};
1188
1189} // namespace device
1190} // namespace tensor_operation
1191} // namespace ck
#define CK_MAX_THREAD_PER_BLOCK
Definition ck.hpp:30
#define GET_NXDL_PER_WAVE_IMPL
Definition device_base.hpp:81
#define INVOKER_RUN_IMPL
Definition device_base.hpp:94
void hip_check_error(hipError_t x)
Definition host_utility/hip_check_error.hpp:10
float launch_and_time_kernel_with_preprocess(const StreamConfig &stream_config, PreProcessFunc preprocess, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:91
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
std::string getGemmSpecializationString(const GemmSpecialization &s)
Definition gemm_specialization.hpp:32
GemmSpecialization
Definition gemm_specialization.hpp:11
@ MKPadding
Definition gemm_specialization.hpp:18
@ KPadding
Definition gemm_specialization.hpp:16
@ MNKPadding
Definition gemm_specialization.hpp:20
@ NKPadding
Definition gemm_specialization.hpp:19
void flush_icache()
Definition flush_cache.hpp:383
float launch_and_time_kernel_with_preprocess(const StreamConfig &stream_config, PreProcessFunc preprocess, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, GemmArgs &gemm_args, Args... args)
Definition flush_cache.hpp:398
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
InMemoryDataOperationEnum
Definition ck.hpp:277
@ Set
Definition ck.hpp:278
@ AtomicAdd
Definition ck.hpp:279
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
BlockGemmPipelineVersion
Definition blkgemmpipe_scheduler.hpp:12
@ v2
Definition blkgemmpipe_scheduler.hpp:15
@ v3
Definition blkgemmpipe_scheduler.hpp:16
@ v5
Definition blkgemmpipe_scheduler.hpp:18
@ v4
Definition blkgemmpipe_scheduler.hpp:17
@ v1
Definition blkgemmpipe_scheduler.hpp:14
TailNumber
Definition blkgemmpipe_scheduler.hpp:31
@ One
Definition blkgemmpipe_scheduler.hpp:37
@ Seven
Definition blkgemmpipe_scheduler.hpp:43
@ Even
Definition blkgemmpipe_scheduler.hpp:34
@ Odd
Definition blkgemmpipe_scheduler.hpp:33
@ Four
Definition blkgemmpipe_scheduler.hpp:40
@ Two
Definition blkgemmpipe_scheduler.hpp:38
@ Full
Definition blkgemmpipe_scheduler.hpp:49
@ Three
Definition blkgemmpipe_scheduler.hpp:39
@ Five
Definition blkgemmpipe_scheduler.hpp:41
@ Six
Definition blkgemmpipe_scheduler.hpp:42
__global__ void kernel_batched_gemm_xdl_cshuffle_v3_multi_d(BatchedGemmArg karg)
Definition device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp:38
typename tuple_element< I, TTuple >::type tuple_element_t
Definition utility/tuple.hpp:208
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
bool is_xdl_wmma_supported()
Definition host_utility/device_prop.hpp:76
__device__ constexpr index_t get_warp_size()
Definition get_id.hpp:10
__global__ void kernel_batched_gemm_xdl_cshuffle_v3_multi_d_2lds(BatchedGemmArg karg)
Definition device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp:87
BlockGemmPipelineScheduler
Definition blkgemmpipe_scheduler.hpp:25
@ Intrawave
Definition blkgemmpipe_scheduler.hpp:26
@ Interwave
Definition blkgemmpipe_scheduler.hpp:27
int64_t long_index_t
Definition ck.hpp:300
bool is_gfx11_supported()
Definition host_utility/device_prop.hpp:60
bool is_bf16_atomic_supported()
Definition host_utility/device_prop.hpp:108
Definition ck/stream_config.hpp:10
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp:157
Definition functional2.hpp:33
Definition device_base.hpp:197
__host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const
Definition device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp:273
__host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const
Definition device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp:278
ComputePtrOffsetOfStridedBatch(index_t BatchStrideA, index_t BatchStrideB, std::array< ck::index_t, NumDTensor > BatchStrideDs, index_t BatchStrideC)
Definition device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp:262
__host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const
Definition device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp:283
__host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const
Definition device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp:294
Definition device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp:357
int max_occupancy_
Definition device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp:445
ActiveWorkgroupsPerCU()
Definition device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp:358
Definition device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp:308
ArgumentBase(const ADataType *p_a_grid_, const BDataType *p_b_grid_, std::array< const void *, NumDTensor > p_ds_grid_, CDataType *p_e_grid_, index_t M_, index_t N_, index_t K_, index_t StrideA_, index_t StrideB_, std::array< index_t, NumDTensor > StrideDs_, index_t StrideE_, index_t BatchStrideA_, index_t BatchStrideB_, const std::array< ck::index_t, NumDTensor > &BatchStrideDs_, index_t BatchStrideE_, index_t Batch_, AElementwiseOperation a_element_op_, BElementwiseOperation b_element_op_, CElementwiseOperation c_element_op_, index_t KBatch_)
Definition device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp:313
index_t Batch
Definition device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp:309
ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch
Definition device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp:310
Definition device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp:450
float RunImp(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp:452
INVOKER_RUN_IMPL float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp:982
Definition device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp:193
std::string GetTypeString() const override
Definition device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp:1138
static constexpr auto NXdlPerWave32
Definition device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp:196
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp:1036
GridwiseGemmMultiD_xdl_cshuffle_v3< ALayout, BLayout, DsLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, DsDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVectors, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, LDSTypeA, LDSTypeB > GridwiseGemmBase
Definition device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp:204
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp:1132
static constexpr index_t NumDTensor
Definition device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp:198
GridwiseGemm64 GridwiseGemm
Definition device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp:257
static GET_NXDL_PER_WAVE_IMPL constexpr auto NXdlPerWave64
Definition device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp:195
CDataType CDataType_
Definition device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp:200
static constexpr bool IsValidCompilationParameter()
Definition device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp:989
ArgumentBase< GridwiseGemm64 > Argument
Definition device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp:354
GridwiseGemmBase< math::max(NXdlPerWave64, 1)> GridwiseGemm64
Definition device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp:255
CDEShuffleBlockTransferScalarPerVectors CDEShuffleBlockTransferScalarPerVectors_
Definition device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp:199
static ck::index_t GetMaxOccupancy()
Definition device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp:1182
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, const std::array< const void *, NumDTensor > &p_ds, void *p_e, index_t M, index_t N, index_t K, index_t Batch, index_t StrideA, index_t StrideB, const std::array< ck::index_t, NumDTensor > &StrideDs, index_t StrideE, index_t BatchStrideA, index_t BatchStrideB, const std::array< ck::index_t, NumDTensor > &BatchStrideDs, index_t BatchStrideE, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op, index_t KBatch=1) override
Definition device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp:1088
static bool IsSupportedArgument(const Argument &arg)
Definition device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp:995
static auto MakeInvoker()
Definition device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp:1084
GridwiseGemmBase< NXdlPerWave32 > GridwiseGemm32
Definition device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp:256
static auto MakeArgument(const void *p_a, const void *p_b, std::array< const void *, NumDTensor > p_ds, void *p_e, index_t M, index_t N, index_t K, index_t Batch, index_t StrideA, index_t StrideB, std::array< index_t, NumDTensor > StrideDs, index_t StrideE, index_t BatchStrideA, index_t BatchStrideB, const std::array< ck::index_t, NumDTensor > &BatchStrideDs, index_t BatchStrideE, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op, index_t KBatch=1)
Definition device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp:1041
Definition device_batched_gemm_multi_d.hpp:68
Definition flush_cache.hpp:174