device_batched_gemm_multiple_d_dl.hpp Source File

device_batched_gemm_multiple_d_dl.hpp Source File#

Composable Kernel: device_batched_gemm_multiple_d_dl.hpp Source File
device_batched_gemm_multiple_d_dl.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <iostream>
7#include <sstream>
8
18
19namespace ck {
20namespace tensor_operation {
21namespace device {
22
23/*
24 * \tparam ComputePtrOffsetOfBatch Class that computes the base pointer offsets of the A, B, Ds and
25 * E tensors for a given batch index. For example, ComputePtrOffsetOfStridedBatch() computes the
26 * offsets of evenly strided batches, but we can easily extend to other layouts. The returned
27 * offset can be either \p index_t or \p long_index_t. If it returns \p long_index_t, we are not
28 * subject to the 2GB limitation.
29 *
30 * \note Using \p ComputePtrOffsetOfBatch gives us the flexibility that 2 workgroups can compute 2
31 * tiles from different matrices. Keep in mind that these 2 matrices can share the same grid
32 * descriptor (like in BatchedGEMM), or use their own grid descriptors (in GroupedGemm). \link
33 * impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp kernel_gemm_xdlops_v2r3_for_conv3d \endlink for
34 * \link DeviceConv3d \endlink uses the same concept, but currently does NOT encapsulate the
35 * computing of pointer offset into \p ComputePtrOffsetOfStridedBatch.
36 */
37
38template <typename GridwiseGemm,
39 typename ABDataType,
40 typename DsPointer,
41 typename EDataType,
42 typename AElementwiseOperation,
43 typename BElementwiseOperation,
44 typename CDEElementwiseOperation,
45 typename AGridDesc_K0_M0_M1_K1,
46 typename BGridDesc_K0_N0_N1_K1,
47 typename DsGridDesc_M0_M10_M11_N0_N10_N11,
48 typename CGridDesc_M0_M10_M11_N0_N10_N11,
49 typename ComputePtrOffsetOfBatch,
50 typename Block2CTileMap,
51 bool HasMainKBlockLoop,
52 bool HasDoubleTailKBlockLoop>
53__global__ void
54#if CK_USE_LAUNCH_BOUNDS
56#endif
58 const ABDataType* __restrict__ p_a_grid,
59 const ABDataType* __restrict__ p_b_grid,
60 DsPointer p_ds_grid,
61 EDataType* __restrict__ p_e_grid,
62 const index_t batch_count,
63 const AElementwiseOperation a_element_op,
64 const BElementwiseOperation b_element_op,
65 const CDEElementwiseOperation cde_element_op,
66 const AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1,
67 const BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1,
68 const DsGridDesc_M0_M10_M11_N0_N10_N11 ds_grid_desc_m0_m10_m11_n0_n10_n11,
69 const CGridDesc_M0_M10_M11_N0_N10_N11 e_grid_desc_m0_m10_m11_n0_n10_n11,
70 const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
71 const Block2CTileMap block_2_ctile_map)
72{
73#if(defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) || \
74 defined(__gfx103__) || defined(__gfx11__) || defined(__gfx12__))
75
76 const index_t num_blocks_per_batch =
77 __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
78 const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
79
80 const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
81 static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
82 const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
83 static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
84 const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane(
85 static_cast<long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));
86
87 const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
88
89 __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
90
91 DsPointer p_ds_grid_grp;
92
93 static constexpr index_t NumDTensor = DsGridDesc_M0_M10_M11_N0_N10_N11::Size();
94
96 [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; });
97
98 GridwiseGemm::Run(p_a_grid + a_batch_offset,
99 p_b_grid + b_batch_offset,
100 p_ds_grid_grp,
101 p_e_grid + e_batch_offset,
102 p_shared,
103 a_element_op,
104 b_element_op,
105 cde_element_op,
106 a_grid_desc_k0_m0_m1_k1,
107 b_grid_desc_k0_n0_n1_k1,
108 ds_grid_desc_m0_m10_m11_n0_n10_n11,
109 e_grid_desc_m0_m10_m11_n0_n10_n11,
110 block_2_ctile_map,
113#else
114 ignore = p_a_grid;
115 ignore = p_b_grid;
116 ignore = p_ds_grid;
117 ignore = p_e_grid;
118 ignore = batch_count;
119 ignore = a_element_op;
120 ignore = b_element_op;
121 ignore = cde_element_op;
122 ignore = a_grid_desc_k0_m0_m1_k1;
123 ignore = b_grid_desc_k0_n0_n1_k1;
124 ignore = ds_grid_desc_m0_m10_m11_n0_n10_n11;
125 ignore = e_grid_desc_m0_m10_m11_n0_n10_n11;
126 ignore = compute_ptr_offset_of_batch;
127 ignore = block_2_ctile_map;
128
129#endif
130}
131
132template <typename ALayout,
133 typename BLayout,
134 typename DsLayout,
135 typename ELayout,
136 typename ADataType,
137 typename BDataType,
138 typename AccDataType,
139 typename DsDataType,
140 typename EDataType,
141 typename AElementwiseOperation,
142 typename BElementwiseOperation,
143 typename CDEElementwiseOperation,
144 GemmSpecialization GemmSpec,
145 index_t BlockSize,
146 index_t MPerBlock,
147 index_t NPerBlock,
148 index_t K0PerBlock,
149 index_t K1,
150 index_t M1PerThread,
151 index_t N1PerThread,
152 index_t KPerThread,
153 typename M1N1ThreadClusterM1Xs,
154 typename M1N1ThreadClusterN1Xs,
155 typename ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
156 typename ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
157 typename ABlockTransferThreadClusterArrangeOrder,
158 typename ABlockTransferSrcAccessOrder,
159 typename ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
160 typename ABlockTransferSrcVectorTensorContiguousDimOrder,
161 typename ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
162 typename BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
163 typename BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
164 typename BBlockTransferThreadClusterArrangeOrder,
165 typename BBlockTransferSrcAccessOrder,
166 typename BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
167 typename BBlockTransferSrcVectorTensorContiguousDimOrder,
168 typename BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
169 typename CThreadTransferSrcDstAccessOrder,
170 index_t CThreadTransferSrcDstVectorDim,
171 index_t CThreadTransferDstScalarPerVector,
175 bool> = false>
177 BLayout,
178 DsLayout,
179 ELayout,
180 ADataType,
181 BDataType,
182 DsDataType,
183 EDataType,
184 AElementwiseOperation,
185 BElementwiseOperation,
186 CDEElementwiseOperation>
187
188{
190 static constexpr index_t NumDTensor = DsDataType::Size();
191
192 static constexpr auto I0 = Number<0>{};
193 static constexpr auto I1 = Number<1>{};
194 static constexpr auto I2 = Number<2>{};
195 static constexpr auto I3 = Number<3>{};
196 static constexpr auto I4 = Number<4>{};
197 static constexpr auto I5 = Number<5>{};
198
199 static constexpr auto K1Number = Number<K1>{};
200
202 {
203 const index_t K0 = K / K1;
204
205 const auto a_grid_desc_m_k = [&]() {
207 {
208 return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
209 }
211 {
212 return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
213 }
214 }();
215
216 if constexpr(GemmSpec == GemmSpecialization::MNPadding)
217 {
218 const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
219
221 a_grid_desc_m_k,
223 make_right_pad_transform(M, PadM)),
226 }
227 else
228 {
230 a_grid_desc_m_k,
235 }
236 }
237
239 {
240 const index_t K0 = K / K1;
241
242 const auto b_grid_desc_k_n = [&]() {
244 {
245 return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1));
246 }
248 {
249 return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB));
250 }
251 }();
252
253 if constexpr(GemmSpec == GemmSpecialization::MNPadding)
254 {
255 const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
256
258 b_grid_desc_k_n,
260 make_right_pad_transform(N, PadN)),
263 }
264 else
265 {
267 b_grid_desc_k_n,
272 }
273 }
274
275 template <typename ELay>
277 {
278 const auto c_grid_desc_m_n = [&]() {
280 {
281 return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideE, I1));
282 }
284 {
285 return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideE));
286 }
287 }();
288
289 if constexpr(GemmSpec == GemmSpecialization::MNPadding)
290 {
291 const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
292 const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
293
295 c_grid_desc_m_n,
299 }
300 else
301 {
302
304 c_grid_desc_m_n,
308 }
309 }
310
311 static auto MakeDsGridDescriptor_M_N(const std::array<index_t, NumDTensor>& MRaws,
312 const std::array<index_t, NumDTensor>& NRaws,
313 const std::array<index_t, NumDTensor>& DsStride)
314 {
315 return generate_tuple(
316 [&](auto i) {
317 using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
318
319 return DeviceOp::MakeEGridDescriptor_M_N<DLayout>(MRaws[i], NRaws[i], DsStride[i]);
320 },
322 }
323
326 using DsGridDesc_M_N = decltype(MakeDsGridDescriptor_M_N({}, {}, {}));
328
330 {
332 index_t BatchStrideB,
333 std::array<ck::index_t, NumDTensor> BatchStrideDs,
334 index_t BatchStrideE)
335 : BatchStrideA_(BatchStrideA),
336 BatchStrideB_(BatchStrideB),
337 BatchStrideDs_(BatchStrideDs),
338 BatchStrideE_(BatchStrideE)
339 {
340 }
341
342 __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const
343 {
344 return g_idx * static_cast<long_index_t>(BatchStrideA_);
345 }
346
347 __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const
348 {
349 return g_idx * static_cast<long_index_t>(BatchStrideB_);
350 }
351
352 __host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const
353 {
354 std::array<long_index_t, NumDTensor> ds_offset;
355 static_for<0, NumDTensor, 1>{}([&](auto i) {
356 ds_offset[i] = g_idx * static_cast<long_index_t>(BatchStrideDs_[i]);
357 });
358 return ds_offset;
359 }
360
361 __host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const
362 {
363 return g_idx * static_cast<long_index_t>(BatchStrideE_);
364 }
365
366 private:
367 index_t BatchStrideA_;
368 index_t BatchStrideB_;
369 std::array<ck::index_t, NumDTensor> BatchStrideDs_;
370 index_t BatchStrideE_;
371 };
372
373 // GridwiseGemm
376 ADataType,
377 AccDataType,
378 DsDataType,
379 EDataType,
380 AElementwiseOperation,
381 BElementwiseOperation,
382 CDEElementwiseOperation,
387 MPerBlock,
388 NPerBlock,
389 K0PerBlock,
390 K1,
391 M1PerThread,
392 N1PerThread,
393 KPerThread,
394 M1N1ThreadClusterM1Xs,
395 M1N1ThreadClusterN1Xs,
396 ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
397 ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
398 ABlockTransferThreadClusterArrangeOrder,
399 ABlockTransferSrcAccessOrder,
400 ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
401 ABlockTransferSrcVectorTensorContiguousDimOrder,
402 ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
403 BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
404 BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
405 BBlockTransferThreadClusterArrangeOrder,
406 BBlockTransferSrcAccessOrder,
407 BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
408 BBlockTransferSrcVectorTensorContiguousDimOrder,
409 BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
410 CThreadTransferSrcDstAccessOrder,
411 CThreadTransferSrcDstVectorDim,
412 CThreadTransferDstScalarPerVector>;
413
424
425 // Argument
426 struct Argument : public BaseArgument
427 {
428 Argument(const void* p_a_grid,
429 const void* p_b_grid,
430 std::array<const void*, NumDTensor> p_ds_grid,
431 void* p_e_grid,
432 index_t M,
433 index_t N,
434 index_t K,
435 index_t Batch,
436 index_t StrideA,
437 index_t StrideB,
438 std::array<index_t, NumDTensor> StrideDs,
439 index_t StrideE,
440 index_t BatchStrideA,
441 index_t BatchStrideB,
442 const std::array<ck::index_t, NumDTensor>& BatchStrideDs,
443 index_t BatchStrideE,
444 AElementwiseOperation a_element_op,
445 BElementwiseOperation b_element_op,
446 CDEElementwiseOperation cde_element_op)
447 : p_a_grid_{static_cast<const ADataType*>(p_a_grid)},
448 p_b_grid_{static_cast<const BDataType*>(p_b_grid)},
449 p_ds_grid_{},
450 p_e_grid_{static_cast<EDataType*>(p_e_grid)},
451 K_(K),
452 Batch_(Batch),
456 compute_ptr_offset_of_batch_{BatchStrideA, BatchStrideB, BatchStrideDs, BatchStrideE},
458 a_element_op_{a_element_op},
459 b_element_op_{b_element_op},
460 cde_element_op_{cde_element_op}
461 {
466 static_for<0, NumDTensor, 1>{}([&](auto i) {
467 using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
468 using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
469
470 // D pointer
471 p_ds_grid_(i) = static_cast<const DDataType*>(p_ds_grid[i]);
472
473 // D desc
476 });
479
482 {
487
490
493
495 }
496 }
497
498 // private:
499 const ADataType* p_a_grid_;
500 const BDataType* p_b_grid_;
502 EDataType* p_e_grid_;
503
505
506 // Batch
508
513
518
519 // for calculating batch offset
521
523
524 // TODO: unused since gridwise_gemm_dl_v1r3 does NOT support prologue for the time being.
525 AElementwiseOperation a_element_op_;
526 BElementwiseOperation b_element_op_;
527 CDEElementwiseOperation cde_element_op_;
528 };
529
530 // Invoker
531 struct Invoker : public BaseInvoker
532 {
534
535 float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
536 {
537 {
538 std::cout << "arg.a_grid_desc_k0_m0_m1_k1_{"
539 << arg.a_grid_desc_k0_m_k1_.GetLength(I0) << ", "
540 << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
541 << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl;
542
543 std::cout << "arg.b_grid_desc_k0_n0_n1_k1_{"
544 << arg.b_grid_desc_k0_n_k1_.GetLength(I0) << ", "
545 << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", "
546 << arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl;
547
548 std::cout << "arg.e_grid_desc_m_n_{ " << arg.e_grid_desc_m_n_.GetLength(I0) << ", "
549 << arg.e_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
550 }
551
554 {
555 throw std::runtime_error(
556 "wrong! GridwiseGemmDlMultipleD_km_kn_mn has invalid setting");
557 }
558
559 const index_t grid_size =
561 arg.e_grid_desc_m_n_.GetLength(I1)) *
562 arg.Batch_;
563
564 auto launch_kernel = [&](auto has_main_k_block_loop,
565 auto has_double_tail_k_block_loop) {
566 constexpr bool has_main_loop = has_main_k_block_loop.value;
567 constexpr bool has_double_loop = has_double_tail_k_block_loop.value;
568
569 const auto kernel =
570 kernel_gemm_dl_multiple_d<GridwiseGemm,
571 ADataType,
573 EDataType,
574 AElementwiseOperation,
575 BElementwiseOperation,
576 CDEElementwiseOperation,
581 ComputePtrOffsetOfStridedBatch,
583 has_main_loop,
584 has_double_loop>;
585
586 return launch_and_time_kernel(stream_config,
587 kernel,
588 dim3(grid_size),
589 dim3(BlockSize),
590 0,
591 arg.p_a_grid_,
592 arg.p_b_grid_,
593 arg.p_ds_grid_,
594 arg.p_e_grid_,
595 arg.Batch_,
596 arg.a_element_op_,
597 arg.b_element_op_,
598 arg.cde_element_op_,
605 };
606
607 const auto K0 = arg.a_grid_desc_k0_m0_m1_k1_.GetLength(I0);
608 const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K0);
609 const bool has_double_tail_k_block_loop =
611
612 if(has_main_k_block_loop && has_double_tail_k_block_loop)
613 {
614 return launch_kernel(integral_constant<bool, true>{},
615 integral_constant<bool, true>{});
616 }
617 else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
618 {
619 return launch_kernel(integral_constant<bool, true>{},
620 integral_constant<bool, false>{});
621 }
622 else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
623 {
624 return launch_kernel(integral_constant<bool, false>{},
625 integral_constant<bool, true>{});
626 }
627 else
628 {
629 return launch_kernel(integral_constant<bool, false>{},
630 integral_constant<bool, false>{});
631 }
632 }
633
634 // polymorphic
635 float Run(const BaseArgument* p_arg,
636 const StreamConfig& stream_config = StreamConfig{}) override
637 {
638 return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
639 }
640 };
641
// Compile-time hook for validating the template parameters of this instantiation.
// No validation is performed yet, so every instantiation is reported as valid.
static constexpr bool IsValidCompilationParameter()
{
    // TODO: properly implement this check
    constexpr bool is_valid = true;

    return is_valid;
}
647
648 static bool IsSupportedArgument(const Argument& arg)
649 {
650 if(ck::get_device_name() == "gfx906" || ck::is_xdl_supported() ||
652 {
653 bool pass = true;
654 pass = pass && arg.K_ % K1 == 0;
655
658 arg.e_grid_desc_m_n_);
659
660 return pass;
661 }
662 else
663 {
664 return false;
665 }
666 }
667
668 // polymorphic
669 bool IsSupportedArgument(const BaseArgument* p_arg) override
670 {
671 return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
672 }
673
674 static auto MakeArgument(const void* p_a,
675 const void* p_b,
676 std::array<const void*, NumDTensor> p_ds,
677 void* p_e,
678 index_t M,
679 index_t N,
680 index_t K,
681 index_t Batch,
682 index_t StrideA,
683 index_t StrideB,
684 std::array<ck::index_t, NumDTensor> StrideDs,
685 index_t StrideE,
686 index_t BatchStrideA,
687 index_t BatchStrideB,
688 const std::array<ck::index_t, NumDTensor>& BatchStrideDs,
689 index_t BatchStrideE,
690 AElementwiseOperation a_element_op,
691 BElementwiseOperation b_element_op,
692 CDEElementwiseOperation cde_element_op)
693 {
694 return Argument{p_a,
695 p_b,
696 p_ds,
697 p_e,
698 M,
699 N,
700 K,
701 Batch,
702 StrideA,
703 StrideB,
704 StrideDs,
705 StrideE,
706 BatchStrideA,
707 BatchStrideB,
708 BatchStrideDs,
709 BatchStrideE,
710 a_element_op,
711 b_element_op,
712 cde_element_op};
713 }
714
715 static auto MakeInvoker() { return Invoker{}; }
716
717 // polymorphic
718 std::unique_ptr<BaseArgument>
719 MakeArgumentPointer(const void* p_a,
720 const void* p_b,
721 const std::array<const void*, NumDTensor>& p_ds,
722 void* p_e,
723 index_t M,
724 index_t N,
725 index_t K,
726 index_t Batch,
727 index_t StrideA,
728 index_t StrideB,
729 const std::array<ck::index_t, NumDTensor>& StrideDs,
730 index_t StrideE,
731 index_t BatchStrideA,
732 index_t BatchStrideB,
733 const std::array<ck::index_t, NumDTensor>& BatchStrideDs,
734 index_t BatchStrideE,
735 AElementwiseOperation a_element_op,
736 BElementwiseOperation b_element_op,
737 CDEElementwiseOperation cde_element_op) override
738 {
739 return std::make_unique<Argument>(p_a,
740 p_b,
741 p_ds,
742 p_e,
743 M,
744 N,
745 K,
746 Batch,
747 StrideA,
748 StrideB,
749 StrideDs,
750 StrideE,
751 BatchStrideA,
752 BatchStrideB,
753 BatchStrideDs,
754 BatchStrideE,
755 a_element_op,
756 b_element_op,
757 cde_element_op);
758 }
759
760 // polymorphic
761 std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
762 {
763 return std::make_unique<Invoker>(Invoker{});
764 }
765
766 // polymorphic
767 std::string GetTypeString() const override
768 {
769 auto str = std::stringstream();
770
771 // clang-format off
772 str << "DeviceBatchedGemmMultipleD_Dl"
773 << "<"
774 << BlockSize << ", "
775 << MPerBlock << ", "
776 << NPerBlock << ", "
777 << K0PerBlock << ", "
778 << K1 << ", "
779 << M1PerThread << ", "
780 << N1PerThread << ", "
781 << KPerThread
782 << ">";
783 // clang-format on
784
785 return str.str();
786 }
787};
788
789} // namespace device
790} // namespace tensor_operation
791} // namespace ck
#define CK_MIN_BLOCK_PER_CU
Definition ck.hpp:31
#define CK_MAX_THREAD_PER_BLOCK
Definition ck.hpp:30
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
Definition convolution_backward_data_specialization.hpp:8
__global__ void kernel_gemm_dl_multiple_d(const ABDataType *__restrict__ p_a_grid, const ABDataType *__restrict__ p_b_grid, DsPointer p_ds_grid, EDataType *__restrict__ p_e_grid, const index_t batch_count, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CDEElementwiseOperation cde_element_op, const AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1, const BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1, const DsGridDesc_M0_M10_M11_N0_N10_N11 ds_grid_desc_m0_m10_m11_n0_n10_n11, const CGridDesc_M0_M10_M11_N0_N10_N11 e_grid_desc_m0_m10_m11_n0_n10_n11, const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, const Block2CTileMap block_2_ctile_map)
Definition device_batched_gemm_multiple_d_dl.hpp:57
GemmSpecialization
Definition gemm_specialization.hpp:11
@ MNPadding
Definition gemm_specialization.hpp:17
Definition convolution_backward_data_specialization.hpp:7
CK_TILE_HOST float launch_kernel(const stream_config &s, Callables &&... callables)
Definition tile/host/kernel_launch.hpp:173
Definition ck.hpp:268
bool is_xdl_supported()
Definition host_utility/device_prop.hpp:68
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
__device__ index_t get_grid_size()
Definition get_id.hpp:49
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
@ Set
Definition ck.hpp:278
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
__host__ __device__ constexpr auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:37
integral_constant< index_t, N > Number
Definition number.hpp:12
std::string get_device_name()
Definition host_utility/device_prop.hpp:19
typename tuple_element< I, TTuple >::type tuple_element_t
Definition utility/tuple.hpp:208
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
__device__ index_t get_block_1d_id()
Definition get_id.hpp:47
bool is_gfx12_supported()
Definition host_utility/device_prop.hpp:55
__global__ void kernel_gemm_dl_multiple_d(const ABDataType *__restrict__ p_a_grid, const ABDataType *__restrict__ p_b_grid, DsPointer p_ds_grid, EDataType *__restrict__ p_e_grid, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CDEElementwiseOperation cde_element_op, const AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1, const BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1, const DsGridDesc_M0_M10_M11_N0_N10_N11 ds_grid_desc_m0_m10_m11_n0_n10_n11, const CGridDesc_M0_M10_M11_N0_N10_N11 e_grid_desc_m0_m10_m11_n0_n10_n11, const Block2CTileMap block_2_ctile_map)
Definition device_gemm_multiple_d_dl.hpp:39
bool is_gfx103_supported()
Definition host_utility/device_prop.hpp:120
constexpr bool is_same_v
Definition type.hpp:283
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
int64_t long_index_t
Definition ck.hpp:300
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
bool is_gfx11_supported()
Definition host_utility/device_prop.hpp:60
typename std::enable_if< B, T >::type enable_if_t
Definition enable_if.hpp:27
Definition ck/stream_config.hpp:10
Definition gridwise_gemm_dl_multiple_d.hpp:60
Definition utility/sequence.hpp:43
static constexpr value_type value
Definition utility/integral_constant.hpp:13
Definition utility/integral_constant.hpp:20
Definition functional2.hpp:33
Definition device_base.hpp:197
Definition device_batched_gemm_multi_d.hpp:27
Definition device_batched_gemm_multiple_d_dl.hpp:427
EDataType * p_e_grid_
Definition device_batched_gemm_multiple_d_dl.hpp:502
BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1_
Definition device_batched_gemm_multiple_d_dl.hpp:515
ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_
Definition device_batched_gemm_multiple_d_dl.hpp:520
AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_
Definition device_batched_gemm_multiple_d_dl.hpp:509
BElementwiseOperation b_element_op_
Definition device_batched_gemm_multiple_d_dl.hpp:526
CDEElementwiseOperation cde_element_op_
Definition device_batched_gemm_multiple_d_dl.hpp:527
DefaultBlock2CTileMap block_2_ctile_map_
Definition device_batched_gemm_multiple_d_dl.hpp:522
EGridDesc_M0_M10_M11_N0_N10_N11 e_grid_desc_m0_m10_m11_n0_n10_n11_
Definition device_batched_gemm_multiple_d_dl.hpp:517
Argument(const void *p_a_grid, const void *p_b_grid, std::array< const void *, NumDTensor > p_ds_grid, void *p_e_grid, index_t M, index_t N, index_t K, index_t Batch, index_t StrideA, index_t StrideB, std::array< index_t, NumDTensor > StrideDs, index_t StrideE, index_t BatchStrideA, index_t BatchStrideB, const std::array< ck::index_t, NumDTensor > &BatchStrideDs, index_t BatchStrideE, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op)
Definition device_batched_gemm_multiple_d_dl.hpp:428
EGridDesc_M_N e_grid_desc_m_n_
Definition device_batched_gemm_multiple_d_dl.hpp:512
AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1_
Definition device_batched_gemm_multiple_d_dl.hpp:514
const ADataType * p_a_grid_
Definition device_batched_gemm_multiple_d_dl.hpp:499
DsGridDesc_M0_M10_M11_N0_N10_N11 ds_grid_desc_m0_m10_m11_n0_n10_n11_
Definition device_batched_gemm_multiple_d_dl.hpp:516
BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_
Definition device_batched_gemm_multiple_d_dl.hpp:510
const BDataType * p_b_grid_
Definition device_batched_gemm_multiple_d_dl.hpp:500
index_t K_
Definition device_batched_gemm_multiple_d_dl.hpp:504
index_t Batch_
Definition device_batched_gemm_multiple_d_dl.hpp:507
AElementwiseOperation a_element_op_
Definition device_batched_gemm_multiple_d_dl.hpp:525
DsGridDesc_M_N ds_grid_desc_m_n_
Definition device_batched_gemm_multiple_d_dl.hpp:511
GridwiseGemm::DsGridPointer p_ds_grid_
Definition device_batched_gemm_multiple_d_dl.hpp:501
ComputePtrOffsetOfStridedBatch(index_t BatchStrideA, index_t BatchStrideB, std::array< ck::index_t, NumDTensor > BatchStrideDs, index_t BatchStrideE)
Definition device_batched_gemm_multiple_d_dl.hpp:331
__host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const
Definition device_batched_gemm_multiple_d_dl.hpp:347
__host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const
Definition device_batched_gemm_multiple_d_dl.hpp:361
__host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const
Definition device_batched_gemm_multiple_d_dl.hpp:352
__host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const
Definition device_batched_gemm_multiple_d_dl.hpp:342
Definition device_batched_gemm_multiple_d_dl.hpp:532
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_batched_gemm_multiple_d_dl.hpp:535
DeviceBatchedGemmMultipleD_Dl::Argument Argument
Definition device_batched_gemm_multiple_d_dl.hpp:533
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_batched_gemm_multiple_d_dl.hpp:635
Definition device_batched_gemm_multiple_d_dl.hpp:188
static constexpr index_t NumDTensor
Definition device_batched_gemm_multiple_d_dl.hpp:190
static constexpr auto I5
Definition device_batched_gemm_multiple_d_dl.hpp:197
decltype(GridwiseGemm::MakeDefaultBlock2CTileMap(EGridDesc_M_N{})) DefaultBlock2CTileMap
Definition device_batched_gemm_multiple_d_dl.hpp:422
GridwiseGemmDlMultipleD_km_kn_mn< BlockSize, ADataType, AccDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_K0_M_K1, BGridDesc_K0_N_K1, EGridDesc_M_N, MPerBlock, NPerBlock, K0PerBlock, K1, M1PerThread, N1PerThread, KPerThread, M1N1ThreadClusterM1Xs, M1N1ThreadClusterN1Xs, ABlockTransferThreadSliceLengths_K0_M0_M1_K1, ABlockTransferThreadClusterLengths_K0_M0_M1_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, ABlockTransferSrcVectorTensorContiguousDimOrder, ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, BBlockTransferThreadSliceLengths_K0_N0_N1_K1, BBlockTransferThreadClusterLengths_K0_N0_N1_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, BBlockTransferSrcVectorTensorContiguousDimOrder, BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector > GridwiseGemm
Definition device_batched_gemm_multiple_d_dl.hpp:374
static constexpr bool IsValidCompilationParameter()
Definition device_batched_gemm_multiple_d_dl.hpp:642
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_batched_gemm_multiple_d_dl.hpp:761
static constexpr auto I3
Definition device_batched_gemm_multiple_d_dl.hpp:195
static constexpr auto I0
Definition device_batched_gemm_multiple_d_dl.hpp:192
static auto MakeAGridDescriptor_K0_M_K1(index_t M, index_t K, index_t StrideA)
Definition device_batched_gemm_multiple_d_dl.hpp:201
static auto MakeEGridDescriptor_M_N(index_t M, index_t N, index_t StrideE)
Definition device_batched_gemm_multiple_d_dl.hpp:276
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, const std::array< const void *, NumDTensor > &p_ds, void *p_e, index_t M, index_t N, index_t K, index_t Batch, index_t StrideA, index_t StrideB, const std::array< ck::index_t, NumDTensor > &StrideDs, index_t StrideE, index_t BatchStrideA, index_t BatchStrideB, const std::array< ck::index_t, NumDTensor > &BatchStrideDs, index_t BatchStrideE, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op) override
Definition device_batched_gemm_multiple_d_dl.hpp:719
static bool IsSupportedArgument(const Argument &arg)
Definition device_batched_gemm_multiple_d_dl.hpp:648
decltype(MakeDsGridDescriptor_M_N({}, {}, {})) DsGridDesc_M_N
Definition device_batched_gemm_multiple_d_dl.hpp:326
static auto MakeInvoker()
Definition device_batched_gemm_multiple_d_dl.hpp:715
static auto MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t StrideB)
Definition device_batched_gemm_multiple_d_dl.hpp:238
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_batched_gemm_multiple_d_dl.hpp:669
decltype(GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(AGridDesc_K0_M_K1{})) AGridDesc_K0_M0_M1_K1
Definition device_batched_gemm_multiple_d_dl.hpp:414
decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1)) BGridDesc_K0_N_K1
Definition device_batched_gemm_multiple_d_dl.hpp:325
static constexpr auto K1Number
Definition device_batched_gemm_multiple_d_dl.hpp:199
decltype(MakeAGridDescriptor_K0_M_K1(1, 1, 1)) AGridDesc_K0_M_K1
Definition device_batched_gemm_multiple_d_dl.hpp:324
decltype(MakeEGridDescriptor_M_N< ELayout >(1, 1, 1)) EGridDesc_M_N
Definition device_batched_gemm_multiple_d_dl.hpp:327
static constexpr auto I4
Definition device_batched_gemm_multiple_d_dl.hpp:196
decltype(GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(BGridDesc_K0_N_K1{})) BGridDesc_K0_N0_N1_K1
Definition device_batched_gemm_multiple_d_dl.hpp:416
std::string GetTypeString() const override
Definition device_batched_gemm_multiple_d_dl.hpp:767
decltype(GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(EGridDesc_M_N{})) EGridDesc_M0_M10_M11_N0_N10_N11
Definition device_batched_gemm_multiple_d_dl.hpp:420
static constexpr auto I2
Definition device_batched_gemm_multiple_d_dl.hpp:194
static constexpr auto I1
Definition device_batched_gemm_multiple_d_dl.hpp:193
decltype(GridwiseGemm::MakeDsGridDescriptor_M0_M10_M11_N0_N10_N11(DsGridDesc_M_N{})) DsGridDesc_M0_M10_M11_N0_N10_N11
Definition device_batched_gemm_multiple_d_dl.hpp:418
static auto MakeArgument(const void *p_a, const void *p_b, std::array< const void *, NumDTensor > p_ds, void *p_e, index_t M, index_t N, index_t K, index_t Batch, index_t StrideA, index_t StrideB, std::array< ck::index_t, NumDTensor > StrideDs, index_t StrideE, index_t BatchStrideA, index_t BatchStrideB, const std::array< ck::index_t, NumDTensor > &BatchStrideDs, index_t BatchStrideE, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op)
Definition device_batched_gemm_multiple_d_dl.hpp:674
DeviceBatchedGemmMultipleD_Dl DeviceOp
Definition device_batched_gemm_multiple_d_dl.hpp:189
static auto MakeDsGridDescriptor_M_N(const std::array< index_t, NumDTensor > &MRaws, const std::array< index_t, NumDTensor > &NRaws, const std::array< index_t, NumDTensor > &DsStride)
Definition device_batched_gemm_multiple_d_dl.hpp:311