23template <
typename GridwiseGemm,
25 typename FloatDsPointer,
27 typename AElementwiseOperation,
28 typename BElementwiseOperation,
29 typename CDEElementwiseOperation,
30 typename AGridDesc_AK0_M_AK1,
31 typename BGridDesc_BK0_N_BK1,
32 typename DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
33 typename EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
34 typename Block2ETileMap,
35 bool HasMainKBlockLoop>
37#if CK_USE_LAUNCH_BOUNDS
41 const FloatAB* __restrict__ p_a_grid,
42 const FloatAB* __restrict__ p_b_grid,
43 FloatDsPointer p_ds_grid,
44 FloatE* __restrict__ p_e_grid,
45 const AElementwiseOperation a_element_op,
46 const BElementwiseOperation b_element_op,
47 const CDEElementwiseOperation cde_element_op,
48 const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
49 const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
50 const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
51 ds_grid_desc_mblock_mperblock_nblock_nperblock,
52 const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
53 e_grid_desc_mblock_mperblock_nblock_nperblock,
54 const Block2ETileMap block_2_etile_map)
56#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
57 if constexpr(GridwiseGemm::template IsValidCompilationParameter<>())
59 __shared__
char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
61 GridwiseGemm::template Run<HasMainKBlockLoop, InMemoryDataOperationEnum::Set>(
70 a_grid_desc_ak0_m_ak1,
71 b_grid_desc_bk0_n_bk1,
72 ds_grid_desc_mblock_mperblock_nblock_nperblock,
73 e_grid_desc_mblock_mperblock_nblock_nperblock,
84 ignore = a_grid_desc_ak0_m_ak1;
85 ignore = b_grid_desc_bk0_n_bk1;
86 ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock;
87 ignore = e_grid_desc_mblock_mperblock_nblock_nperblock;
88 ignore = block_2_etile_map;
115 typename AccDataType,
116 typename CShuffleDataType,
119 typename AElementwiseOperation,
120 typename BElementwiseOperation,
121 typename CDEElementwiseOperation,
134 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
135 typename ABlockTransferThreadClusterArrangeOrder,
136 typename ABlockTransferSrcAccessOrder,
137 index_t ABlockTransferSrcVectorDim,
138 index_t ABlockTransferSrcScalarPerVector,
139 index_t ABlockTransferDstScalarPerVector_AK1,
140 bool ABlockLdsExtraM,
141 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
142 typename BBlockTransferThreadClusterArrangeOrder,
143 typename BBlockTransferSrcAccessOrder,
144 index_t BBlockTransferSrcVectorDim,
145 index_t BBlockTransferSrcScalarPerVector,
146 index_t BBlockTransferDstScalarPerVector_BK1,
147 bool BBlockLdsExtraN,
148 index_t CShuffleMXdlPerWavePerShuffle,
149 index_t CShuffleNXdlPerWavePerShuffle,
150 typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
151 index_t CDEBlockTransferScalarPerVector_NPerBlock,
152 typename ComputeDataType = ADataType,
162 AElementwiseOperation,
163 BElementwiseOperation,
164 CDEElementwiseOperation,
185 const std::vector<index_t>& a_ms_ks_strides_vec)
187 assert(a_ms_ks_lengths_vec.size() == NumDimM + NumDimK &&
188 a_ms_ks_strides_vec.size() == NumDimM + NumDimK);
190 const auto to_tuple = [&](
auto& vec,
auto num) {
201 constexpr auto kDimIds =
211 const auto a_grid_desc_ms_ks =
221 return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
226 const std::vector<index_t>& b_ns_ks_strides_vec)
228 assert(b_ns_ks_lengths_vec.size() == NumDimN + NumDimK &&
229 b_ns_ks_strides_vec.size() == NumDimN + NumDimK);
231 const auto to_tuple = [&](
auto& vec,
auto num) {
242 constexpr auto kDimIds =
252 const auto b_grid_desc_ns_ks =
262 return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
267 const std::vector<index_t>& e_ms_ns_strides_vec)
269 assert(e_ms_ns_lengths_vec.size() == NumDimM + NumDimN &&
270 e_ms_ns_strides_vec.size() == NumDimM + NumDimN);
272 const auto to_tuple = [&](
auto& vec,
auto num) {
283 constexpr auto nDimIds =
293 const auto e_grid_desc_ms_ns =
303 return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw);
307 const std::array<std::vector<index_t>,
NumDTensor>& ds_ms_ns_lengths_vec,
308 const std::array<std::vector<index_t>,
NumDTensor>& ds_ms_ns_strides_vec)
313 ds_ms_ns_strides_vec[i]);
324 template <index_t NXdlPerWave_>
333 AElementwiseOperation,
334 BElementwiseOperation,
335 CDEElementwiseOperation,
336 NumGemmKPrefetchStage,
347 ABlockTransferThreadClusterLengths_AK0_M_AK1,
348 ABlockTransferThreadClusterArrangeOrder,
349 ABlockTransferSrcAccessOrder,
350 ABlockTransferSrcVectorDim,
351 ABlockTransferSrcScalarPerVector,
352 ABlockTransferDstScalarPerVector_AK1,
355 BBlockTransferThreadClusterLengths_BK0_N_BK1,
356 BBlockTransferThreadClusterArrangeOrder,
357 BBlockTransferSrcAccessOrder,
358 BBlockTransferSrcVectorDim,
359 BBlockTransferSrcScalarPerVector,
360 BBlockTransferDstScalarPerVector_BK1,
363 CShuffleMXdlPerWavePerShuffle,
364 CShuffleNXdlPerWavePerShuffle,
365 CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
366 CDEBlockTransferScalarPerVector_NPerBlock,
393 const void* p_b_grid,
394 std::array<const void*, NumDTensor> p_ds_grid,
396 const std::vector<index_t>& a_ms_ks_lengths,
397 const std::vector<index_t>& a_ms_ks_strides,
398 const std::vector<index_t>& b_ns_ks_lengths,
399 const std::vector<index_t>& b_ns_ks_strides,
400 const std::array<std::vector<index_t>,
NumDTensor>& ds_ms_ns_lengths,
401 const std::array<std::vector<index_t>,
NumDTensor>& ds_ms_ns_strides,
402 const std::vector<index_t>& e_ms_ns_lengths,
403 const std::vector<index_t>& e_ms_ns_strides,
404 AElementwiseOperation a_element_op,
405 BElementwiseOperation b_element_op,
406 CDEElementwiseOperation cde_element_op)
407 :
p_a_grid_{static_cast<const ADataType*>(p_a_grid)},
408 p_b_grid_{static_cast<const BDataType*>(p_b_grid)},
410 p_e_grid_{static_cast<EDataType*>(p_e_grid)},
429 p_ds_grid_(i) =
static_cast<const DDataType*
>(p_ds_grid[i]);
458 [&](
auto i) { std::cout <<
"Ds[M, N]: " <<
ds_grid_desc_m_n_[i] << std::endl; });
507 template <
typename Gr
idwiseGemm>
516 throw std::runtime_error(
517 "wrong! GridwiseGemmMultipleD_xdl_cshuffle has invalid setting");
519 auto e_grid_desc_mblock_mperblock_nblock_nperblock =
520 GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
523 auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
524 GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
532 auto launch_kernel = [&](
auto has_main_k_block_loop) {
533 constexpr bool has_main_loop = has_main_k_block_loop.value;
538 typename GridwiseGemm::DsGridPointer,
540 AElementwiseOperation,
541 BElementwiseOperation,
542 CDEElementwiseOperation,
564 ds_grid_desc_mblock_mperblock_nblock_nperblock,
565 e_grid_desc_mblock_mperblock_nblock_nperblock,
569 if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
575 return launch_kernel(integral_constant<bool, false>{});
585 return Run(*
dynamic_cast<const Argument*
>(p_arg), stream_config);
629 static_assert((ABlockTransferSrcVectorDim == 1 || ABlockTransferSrcVectorDim == 2) &&
630 (BBlockTransferSrcVectorDim == 1 || BBlockTransferSrcVectorDim == 2),
633 const bool valid_a_vector_size =
635 const bool valid_a_access_dim_m =
637 const bool valid_a_access_dim_k =
639 const bool valid_a_access_dim =
640 valid_a_access_dim_m || valid_a_access_dim_k || ABlockTransferSrcScalarPerVector == 1;
641 if(!(valid_a_vector_size && valid_a_access_dim))
646 const bool valid_b_vector_size =
648 const bool valid_b_access_dim_n =
650 const bool valid_b_access_dim_k =
652 const bool valid_b_access_dim =
653 valid_b_access_dim_n || valid_b_access_dim_k || BBlockTransferSrcScalarPerVector == 1;
654 if(!(valid_b_vector_size && valid_b_access_dim))
659 bool valid_ds_access =
true;
661 const bool valid_d_vector_size =
664 const bool valid_d_access_dim =
666 if(!(valid_d_vector_size && valid_d_access_dim))
668 valid_ds_access =
false;
676 const bool valid_e_vector_size =
679 const bool valid_e_access_dim =
680 arg.
e_continous_dim_ == 1 || CDEBlockTransferScalarPerVector_NPerBlock == 1;
681 if(!(valid_e_vector_size && valid_e_access_dim))
697 std::array<const void*, NumDTensor> p_ds,
699 const std::vector<index_t>& a_ms_ks_lengths,
700 const std::vector<index_t>& a_ms_ks_strides,
701 const std::vector<index_t>& b_ns_ks_lengths,
702 const std::vector<index_t>& b_ns_ks_strides,
703 const std::array<std::vector<index_t>,
NumDTensor>& ds_ms_ns_lengths,
704 const std::array<std::vector<index_t>,
NumDTensor>& ds_ms_ns_strides,
705 const std::vector<index_t>& e_ms_ns_lengths,
706 const std::vector<index_t>& e_ms_ns_strides,
707 AElementwiseOperation a_element_op,
708 BElementwiseOperation b_element_op,
709 CDEElementwiseOperation cde_element_op)
731 std::unique_ptr<BaseArgument>
734 std::array<const void*, NumDTensor> p_ds,
736 const std::vector<index_t>& a_ms_ks_lengths,
737 const std::vector<index_t>& a_ms_ks_strides,
738 const std::vector<index_t>& b_ns_ks_lengths,
739 const std::vector<index_t>& b_ns_ks_strides,
740 const std::array<std::vector<index_t>,
NumDTensor>& ds_ms_ns_lengths,
741 const std::array<std::vector<index_t>,
NumDTensor>& ds_ms_ns_strides,
742 const std::vector<index_t>& e_ms_ns_lengths,
743 const std::vector<index_t>& e_ms_ns_strides,
744 AElementwiseOperation a_element_op,
745 BElementwiseOperation b_element_op,
746 CDEElementwiseOperation cde_element_op)
override
748 return std::make_unique<Argument>(p_a,
768 return std::make_unique<Invoker>(
Invoker{});
774 auto str = std::stringstream();
777 str <<
"DeviceContractionMultipleD_Xdl_CShuffle"
788 << ABlockTransferSrcVectorDim <<
", "
789 << BBlockTransferSrcVectorDim
#define CK_MIN_BLOCK_PER_CU
Definition ck.hpp:31
#define CK_MAX_THREAD_PER_BLOCK
Definition ck.hpp:30
#define GET_NXDL_PER_WAVE_IMPL
Definition device_base.hpp:81
#define INVOKER_RUN_IMPL
Definition device_base.hpp:94
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
Definition convolution_backward_data_specialization.hpp:8
auto CalculateMaxRead(const std::vector< index_t > &lengths, const std::vector< index_t > &strides)
Definition device_contraction_utils.hpp:33
GemmSpecialization
Definition gemm_specialization.hpp:11
Definition convolution_backward_data_specialization.hpp:7
bool is_lds_direct_load_supported()
Definition host_utility/device_prop.hpp:101
__global__ void kernel_contraction_multiple_d_xdl_cshuffle(const FloatAB *__restrict__ p_a_grid, const FloatAB *__restrict__ p_b_grid, FloatDsPointer p_ds_grid, FloatE *__restrict__ p_e_grid, const index_t batch_count, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CDEElementwiseOperation cde_element_op, const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock, const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock, const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, const Block2ETileMap block_2_etile_map)
Definition device_batched_contraction_multiple_d_xdl_cshuffle.hpp:41
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
constexpr Tuple< Args &... > tie(Args &... args) noexcept
Definition utility/tuple.hpp:218
integral_constant< index_t, N > Number
Definition number.hpp:12
typename tuple_element< I, TTuple >::type tuple_element_t
Definition utility/tuple.hpp:208
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
bool is_xdl_wmma_supported()
Definition host_utility/device_prop.hpp:76
__device__ constexpr index_t get_warp_size()
Definition get_id.hpp:10
__host__ __device__ constexpr auto get_container_subset(const Array< T, N > &arr, Sequence< Is... >)
Definition utility/container_helper.hpp:346
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
LoopScheduler
Definition loop_scheduler.hpp:15
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
constexpr LoopScheduler make_default_loop_scheduler()
Definition loop_scheduler.hpp:20
Definition ck/stream_config.hpp:10
Definition gridwise_gemm_multiple_d_xdl_cshuffle.hpp:78
ck::GridwiseGemmMultipleD_xdl_cshuffle< ADataType, BDataType, ComputeDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock, LoopSched >::MakeDefaultBGridDescriptor_BK0_N_BK1 __host__ static __device__ constexpr auto MakeDefaultBGridDescriptor_BK0_N_BK1(const BGridDesc_N_K &b_grid_desc_n_k)
Definition gridwise_gemm_multiple_d_xdl_cshuffle.hpp:207
ck::GridwiseGemmMultipleD_xdl_cshuffle< ADataType, BDataType, ComputeDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock, LoopSched >::MakeDefaultAGridDescriptor_AK0_M_AK1 __host__ static __device__ constexpr auto MakeDefaultAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K &a_grid_desc_m_k)
Definition gridwise_gemm_multiple_d_xdl_cshuffle.hpp:190
ck::GridwiseGemmMultipleD_xdl_cshuffle< ADataType, BDataType, ComputeDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock, LoopSched >::CheckValidity __host__ static __device__ constexpr bool CheckValidity(const AGridDesc_M_K &a_grid_desc_m_k, const BGridDesc_N_K &b_grid_desc_n_k, const DsGridDesc_M_N &ds_grid_desc_m_n, const EGridDesc_M_N &e_grid_desc_m_n, const Block2ETileMap &, index_t k_batch=1)
Definition gridwise_gemm_multiple_d_xdl_cshuffle.hpp:334
ck::GridwiseGemmMultipleD_xdl_cshuffle< ADataType, BDataType, ComputeDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock, LoopSched >::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock __host__ static __device__ constexpr auto MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const EGridDesc_M_N &e_grid_desc_m_n)
Definition gridwise_gemm_multiple_d_xdl_cshuffle.hpp:224
ck::GridwiseGemmMultipleD_xdl_cshuffle< ADataType, BDataType, ComputeDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock, LoopSched >::MakeDefaultBlock2ETileMap __host__ static __device__ constexpr auto MakeDefaultBlock2ETileMap(const EGridDesc_M_N &e_grid_desc_m_n)
Definition gridwise_gemm_multiple_d_xdl_cshuffle.hpp:257
ck::GridwiseGemmMultipleD_xdl_cshuffle< ADataType, BDataType, ComputeDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock, LoopSched >::DsGridPointer decltype(MakeDsGridPointer()) DsGridPointer
Definition gridwise_gemm_multiple_d_xdl_cshuffle.hpp:411
ck::GridwiseGemmMultipleD_xdl_cshuffle< ADataType, BDataType, ComputeDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock, LoopSched >::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock __host__ static __device__ constexpr auto MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DsGridDesc_M_N &ds_grid_desc_m_n)
Definition gridwise_gemm_multiple_d_xdl_cshuffle.hpp:245
Definition utility/sequence.hpp:43
typename conditional< kHasContent, type0, type1 >::type type
Definition utility/sequence.hpp:271
Definition utility/integral_constant.hpp:20
Definition functional2.hpp:33
Definition device_base.hpp:197
Definition device_contraction_multiple_d_xdl_cshuffle.hpp:391
index_t a_continous_dim_
Definition device_contraction_multiple_d_xdl_cshuffle.hpp:491
BGridDesc_N_K b_grid_desc_n_k_
Definition device_contraction_multiple_d_xdl_cshuffle.hpp:471
Argument(const void *p_a_grid, const void *p_b_grid, std::array< const void *, NumDTensor > p_ds_grid, void *p_e_grid, const std::vector< index_t > &a_ms_ks_lengths, const std::vector< index_t > &a_ms_ks_strides, const std::vector< index_t > &b_ns_ks_lengths, const std::vector< index_t > &b_ns_ks_strides, const std::array< std::vector< index_t >, NumDTensor > &ds_ms_ns_lengths, const std::array< std::vector< index_t >, NumDTensor > &ds_ms_ns_strides, const std::vector< index_t > &e_ms_ns_lengths, const std::vector< index_t > &e_ms_ns_strides, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op)
Definition device_contraction_multiple_d_xdl_cshuffle.hpp:392
index_t e_continous_dim_
Definition device_contraction_multiple_d_xdl_cshuffle.hpp:494
index_t b_continous_dim_
Definition device_contraction_multiple_d_xdl_cshuffle.hpp:492
CDEElementwiseOperation cde_element_op_
Definition device_contraction_multiple_d_xdl_cshuffle.hpp:488
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_
Definition device_contraction_multiple_d_xdl_cshuffle.hpp:477
index_t e_max_write_elems_
Definition device_contraction_multiple_d_xdl_cshuffle.hpp:499
EGridDesc_M_N e_grid_desc_m_n_
Definition device_contraction_multiple_d_xdl_cshuffle.hpp:473
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_
Definition device_contraction_multiple_d_xdl_cshuffle.hpp:480
AElementwiseOperation a_element_op_
Definition device_contraction_multiple_d_xdl_cshuffle.hpp:486
DsGridDesc_M_N ds_grid_desc_m_n_
Definition device_contraction_multiple_d_xdl_cshuffle.hpp:472
void Print() const
Definition device_contraction_multiple_d_xdl_cshuffle.hpp:453
const ADataType * p_a_grid_
Definition device_contraction_multiple_d_xdl_cshuffle.hpp:464
Block2ETileMap block_2_etile_map_
Definition device_contraction_multiple_d_xdl_cshuffle.hpp:483
index_t b_max_read_elems_
Definition device_contraction_multiple_d_xdl_cshuffle.hpp:497
const BDataType * p_b_grid_
Definition device_contraction_multiple_d_xdl_cshuffle.hpp:465
EDataType * p_e_grid_
Definition device_contraction_multiple_d_xdl_cshuffle.hpp:467
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_
Definition device_contraction_multiple_d_xdl_cshuffle.hpp:476
AGridDesc_M_K a_grid_desc_m_k_
Definition device_contraction_multiple_d_xdl_cshuffle.hpp:470
std::array< index_t, NumDTensor > ds_continous_dim_
Definition device_contraction_multiple_d_xdl_cshuffle.hpp:493
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock_
Definition device_contraction_multiple_d_xdl_cshuffle.hpp:479
GridwiseGemm64::DsGridPointer p_ds_grid_
Definition device_contraction_multiple_d_xdl_cshuffle.hpp:466
index_t a_max_read_elems_
Definition device_contraction_multiple_d_xdl_cshuffle.hpp:496
BElementwiseOperation b_element_op_
Definition device_contraction_multiple_d_xdl_cshuffle.hpp:487
std::array< index_t, NumDTensor > ds_max_read_elems_
Definition device_contraction_multiple_d_xdl_cshuffle.hpp:498
Definition device_contraction_multiple_d_xdl_cshuffle.hpp:504
DeviceOp::Argument Argument
Definition device_contraction_multiple_d_xdl_cshuffle.hpp:505
float RunImp(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_contraction_multiple_d_xdl_cshuffle.hpp:508
INVOKER_RUN_IMPL float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_contraction_multiple_d_xdl_cshuffle.hpp:582
Definition device_contraction_multiple_d_xdl_cshuffle.hpp:166
static auto MakeEGridDescriptor_M_N(const std::vector< index_t > &e_ms_ns_lengths_vec, const std::vector< index_t > &e_ms_ns_strides_vec)
Definition device_contraction_multiple_d_xdl_cshuffle.hpp:266
DeviceContractionMultipleD_Xdl_CShuffle DeviceOp
Definition device_contraction_multiple_d_xdl_cshuffle.hpp:167
decltype(MakeBGridDescriptor_N_K({}, {})) BGridDesc_N_K
Definition device_contraction_multiple_d_xdl_cshuffle.hpp:319
remove_cvref_t< decltype(GridwiseGemm64::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( DsGridDesc_M_N{}))> DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
Definition device_contraction_multiple_d_xdl_cshuffle.hpp:378
decltype(MakeEGridDescriptor_M_N({}, {})) EGridDesc_M_N
Definition device_contraction_multiple_d_xdl_cshuffle.hpp:321
static auto MakeBGridDescriptor_N_K(const std::vector< index_t > &b_ns_ks_lengths_vec, const std::vector< index_t > &b_ns_ks_strides_vec)
Definition device_contraction_multiple_d_xdl_cshuffle.hpp:225
static auto MakeArgument(const void *p_a, const void *p_b, std::array< const void *, NumDTensor > p_ds, void *p_e, const std::vector< index_t > &a_ms_ks_lengths, const std::vector< index_t > &a_ms_ks_strides, const std::vector< index_t > &b_ns_ks_lengths, const std::vector< index_t > &b_ns_ks_strides, const std::array< std::vector< index_t >, NumDTensor > &ds_ms_ns_lengths, const std::array< std::vector< index_t >, NumDTensor > &ds_ms_ns_strides, const std::vector< index_t > &e_ms_ns_lengths, const std::vector< index_t > &e_ms_ns_strides, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op)
Definition device_contraction_multiple_d_xdl_cshuffle.hpp:695
remove_cvref_t< decltype(GridwiseGemm64::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( EGridDesc_M_N{}))> EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
Definition device_contraction_multiple_d_xdl_cshuffle.hpp:381
GridwiseGemmMultipleD_xdl_cshuffle< ADataType, BDataType, ComputeDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock, LoopSched > GridwiseGemmBase
Definition device_contraction_multiple_d_xdl_cshuffle.hpp:325
remove_cvref_t< decltype(GridwiseGemm64::MakeDefaultBGridDescriptor_BK0_N_BK1( BGridDesc_N_K{}))> BGridDesc_BK0_N_BK1
Definition device_contraction_multiple_d_xdl_cshuffle.hpp:375
static constexpr auto NXdlPerWave32
Definition device_contraction_multiple_d_xdl_cshuffle.hpp:171
decltype(MakeAGridDescriptor_M_K({}, {})) AGridDesc_M_K
Definition device_contraction_multiple_d_xdl_cshuffle.hpp:318
static GET_NXDL_PER_WAVE_IMPL constexpr auto NXdlPerWave64
Definition device_contraction_multiple_d_xdl_cshuffle.hpp:170
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_contraction_multiple_d_xdl_cshuffle.hpp:690
static auto MakeAGridDescriptor_M_K(const std::vector< index_t > &a_ms_ks_lengths_vec, const std::vector< index_t > &a_ms_ks_strides_vec)
Definition device_contraction_multiple_d_xdl_cshuffle.hpp:184
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_contraction_multiple_d_xdl_cshuffle.hpp:766
GridwiseGemmBase< NXdlPerWave32 > GridwiseGemm32
Definition device_contraction_multiple_d_xdl_cshuffle.hpp:369
GridwiseGemmBase< math::max(NXdlPerWave64, 1)> GridwiseGemm64
Definition device_contraction_multiple_d_xdl_cshuffle.hpp:368
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, std::array< const void *, NumDTensor > p_ds, void *p_e, const std::vector< index_t > &a_ms_ks_lengths, const std::vector< index_t > &a_ms_ks_strides, const std::vector< index_t > &b_ns_ks_lengths, const std::vector< index_t > &b_ns_ks_strides, const std::array< std::vector< index_t >, NumDTensor > &ds_ms_ns_lengths, const std::array< std::vector< index_t >, NumDTensor > &ds_ms_ns_strides, const std::vector< index_t > &e_ms_ns_lengths, const std::vector< index_t > &e_ms_ns_strides, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op) override
Definition device_contraction_multiple_d_xdl_cshuffle.hpp:732
std::string GetTypeString() const override
Definition device_contraction_multiple_d_xdl_cshuffle.hpp:772
static constexpr auto matrix_padder
Definition device_contraction_multiple_d_xdl_cshuffle.hpp:180
static constexpr index_t NumDTensor
Definition device_contraction_multiple_d_xdl_cshuffle.hpp:173
remove_cvref_t< decltype(MakeDsGridDescriptor_M_N({{}}, {{}}))> DsGridDesc_M_N
Definition device_contraction_multiple_d_xdl_cshuffle.hpp:320
static constexpr auto I1
Definition device_contraction_multiple_d_xdl_cshuffle.hpp:176
static auto MakeDsGridDescriptor_M_N(const std::array< std::vector< index_t >, NumDTensor > &ds_ms_ns_lengths_vec, const std::array< std::vector< index_t >, NumDTensor > &ds_ms_ns_strides_vec)
Definition device_contraction_multiple_d_xdl_cshuffle.hpp:306
remove_cvref_t< decltype(GridwiseGemm64::MakeDefaultAGridDescriptor_AK0_M_AK1( AGridDesc_M_K{}))> AGridDesc_AK0_M_AK1
Definition device_contraction_multiple_d_xdl_cshuffle.hpp:372
remove_cvref_t< decltype(GridwiseGemm64::MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))> Block2ETileMap
Definition device_contraction_multiple_d_xdl_cshuffle.hpp:386
static constexpr auto I2
Definition device_contraction_multiple_d_xdl_cshuffle.hpp:177
static bool IsSupportedArgument(const Argument &arg)
Definition device_contraction_multiple_d_xdl_cshuffle.hpp:589
static constexpr auto I3
Definition device_contraction_multiple_d_xdl_cshuffle.hpp:178
static constexpr auto I0
Definition device_contraction_multiple_d_xdl_cshuffle.hpp:175
static auto MakeInvoker()
Definition device_contraction_multiple_d_xdl_cshuffle.hpp:728
Definition device_contraction_multiple_d.hpp:39
Definition matrix_padder.hpp:180