24template <
typename ALayout,
29 typename AScaleDataType,
31 typename BScaleDataType,
34 typename GemmAccDataType,
35 typename CShuffleDataType,
36 typename AElementwiseOperation,
37 typename BElementwiseOperation,
38 typename CElementwiseOperation,
51 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
52 typename ABlockTransferThreadClusterArrangeOrder,
53 typename ABlockTransferSrcAccessOrder,
54 index_t ABlockTransferSrcVectorDim,
55 index_t ABlockTransferSrcScalarPerVector,
56 index_t ABlockTransferDstScalarPerVector_AK1,
58 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
59 typename BBlockTransferThreadClusterArrangeOrder,
60 typename BBlockTransferSrcAccessOrder,
61 index_t BBlockTransferSrcVectorDim,
62 index_t BBlockTransferSrcScalarPerVector,
63 index_t BBlockTransferDstScalarPerVector_BK1,
65 index_t CShuffleMXdlPerWavePerShuffle,
66 index_t CShuffleNXdlPerWavePerShuffle,
67 typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
68 typename CDEShuffleBlockTransferScalarPerVectors,
72 bool NSwizzle =
false,
73 bool IsInputGemm =
true,
74 bool MulRoutedWeight =
true,
76 typename ComputeTypeA = ADataType,
77 typename ComputeTypeB = BDataType>
89 AElementwiseOperation,
90 BElementwiseOperation,
91 CElementwiseOperation>
97 template <index_t NXdlPerWave_>
111 AElementwiseOperation,
112 BElementwiseOperation,
113 CElementwiseOperation,
126 ABlockTransferThreadClusterLengths_AK0_M_AK1,
127 ABlockTransferThreadClusterArrangeOrder,
128 ABlockTransferSrcAccessOrder,
129 ABlockTransferSrcVectorDim,
130 ABlockTransferSrcScalarPerVector,
131 ABlockTransferDstScalarPerVector_AK1,
134 BBlockTransferThreadClusterLengths_BK0_N_BK1,
135 BBlockTransferThreadClusterArrangeOrder,
136 BBlockTransferSrcAccessOrder,
137 BBlockTransferSrcVectorDim,
138 BBlockTransferSrcScalarPerVector,
139 BBlockTransferDstScalarPerVector_BK1,
142 CShuffleMXdlPerWavePerShuffle,
143 CShuffleNXdlPerWavePerShuffle,
144 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
145 CDEShuffleBlockTransferScalarPerVectors,
158 using Argument =
typename GridwiseGemm64::Argument;
168 template <
typename Gr
idwiseGemm>
169 float RunImp(
const typename GridwiseGemm::Argument& arg,
172 if(stream_config.log_level_ > 0)
177 if(!GridwiseGemm::CheckValidity(arg))
179 throw std::runtime_error(
"wrong! GridwiseGemm has invalid setting");
183 std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N);
187 index_t k_grain = arg.KBatch * KPerBlock;
188 index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;
190 const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
192 const auto RunKernel = [&](
const auto& kernel) {
193 if(stream_config.flush_cache)
196 std::array<std::size_t, NumDTensor> DsSize;
200 const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(
201 arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0);
202 const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(
203 arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0);
206 a_grid_desc_ak0_m_ak1.GetElementSpaceSize() *
sizeof(ADataType);
208 b_grid_desc_bk0_n_bk1.GetElementSpaceSize() *
sizeof(BDataType);
210 const auto ds_grid_desc_m_n = GridwiseGemm::MakeDsGridDescriptor_M_N(
211 arg_.M, arg_.MPadded, arg_.N, arg_.NPadded, arg_.StrideDs);
215 DsSize[i] = ds_grid_desc_m_n[i].GetElementSpaceSize() *
sizeof(DDataType);
220 stream_config.rotating_count,
224 rotating_mem.Print();
226 auto run_flush_cache = [&]() {
233 hipGetErrorString(hipMemsetAsync(arg_.p_c_grid,
235 arg_.M * arg_.N *
sizeof(CDataType),
236 stream_config.stream_id_));
251 hipGetErrorString(hipMemsetAsync(arg.p_c_grid,
253 arg.M * arg.N *
sizeof(CDataType),
254 stream_config.stream_id_));
257 stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
262 constexpr index_t minimum_occupancy =
265 MPerBlock * NPerBlock * KPerBlock *
sizeof(ADataType) <= 128 * 128 * 64 * 2)
270 constexpr auto MemoryDataOp =
273 if(has_main_k_block_loop)
279 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Odd)
301 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Odd)
322 throw std::runtime_error(
"todo: only v1 & v3 support now");
329 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Odd)
350 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Odd)
380 return Run(*
dynamic_cast<const Argument*
>(p_arg), stream_config);
413 if(arg.N % NPerBlock != 0 || arg.K % KPerBlock != 0)
430 reinterpret_cast<const typename GridwiseGemm32::Argument&
>(arg));
443 const void* p_sorted_expert_ids,
444 const void* p_max_token_id,
446 const void* p_a_scale,
448 const void* p_b_scale,
449 std::array<const void*, NumDTensor> p_ds,
460 std::array<index_t, NumDTensor> StrideDs,
463 AElementwiseOperation a_element_op,
464 BElementwiseOperation b_element_op,
465 CElementwiseOperation c_element_op)
468 static_cast<const index_t*
>(p_sorted_expert_ids),
469 static_cast<const index_t*
>(p_max_token_id),
470 static_cast<const ADataType*
>(p_a),
471 static_cast<const AScaleDataType*
>(p_a_scale),
472 static_cast<const BDataType*
>(p_b),
473 static_cast<const BScaleDataType*
>(p_b_scale),
475 static_cast<CDataType*
>(p_c),
497 const void* p_a_scale,
499 const void* p_b_scale,
500 std::array<const void*, NumDTensor> p_ds,
509 std::array<ck::index_t, NumDTensor> StrideDs,
512 AElementwiseOperation a_element_op,
513 BElementwiseOperation b_element_op,
514 CElementwiseOperation c_element_op)
override
516 return std::make_unique<Argument>(
nullptr,
519 static_cast<const ADataType*
>(p_a),
520 static_cast<const AScaleDataType*
>(p_a_scale),
521 static_cast<const BDataType*
>(p_b),
522 static_cast<const BScaleDataType*
>(p_b_scale),
524 static_cast<CDataType*
>(p_c),
545 return std::make_unique<Invoker>(
Invoker{});
551 auto str = std::stringstream();
553 std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
557 std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
565 str <<
"DeviceMoeGEmmMx"
568 << std::string(ALayout::name)[0]
569 << std::string(BLayout::name)[0]
570 << std::string(CLayout::name)[0]
575 << MPerBlock<<
"x"<<NPerBlock<<
"x"<<KPerBlock <<
", "
577 << MPerXDL<<
"x"<<NPerXDL <<
", "
579 << MXdlPerWave<<
"x" << NXdlPerWave<<
", "
581 << ABlockTransferSrcScalarPerVector<<
"x"<<BBlockTransferSrcScalarPerVector<<
", "
582 <<
"BlkGemmPipelineScheduler: "
583 << BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] <<
", "
584 <<
"BlkGemmPipelineVersion: "
585 << BlkGemmPipelineVersionToString[BlkGemmPipelineVer] <<
", "
586 <<
"BlkGemmPipelinePrefetchStages: "
587 << GridwiseGemm64::BlockwiseGemmPipe::PrefetchStages;
#define INVOKER_RUN3_IMPL
Definition device_base.hpp:114
#define GET_NXDL_PER_WAVE_IMPL
Definition device_base.hpp:81
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
Definition convolution_backward_data_specialization.hpp:8
std::string getGemmSpecializationString(const GemmSpecialization &s)
Definition gemm_specialization.hpp:32
GemmSpecialization
Definition gemm_specialization.hpp:11
@ MKPadding
Definition gemm_specialization.hpp:18
@ KPadding
Definition gemm_specialization.hpp:16
@ MNKPadding
Definition gemm_specialization.hpp:20
@ NKPadding
Definition gemm_specialization.hpp:19
Definition convolution_backward_data_specialization.hpp:7
void flush_icache()
Definition flush_cache.hpp:383
float launch_and_time_kernel_with_preprocess(const StreamConfig &stream_config, PreProcessFunc preprocess, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, GemmArgs &gemm_args, Args... args)
Definition flush_cache.hpp:398
__global__ void kernel_moe_mxgemm(typename GridwiseGemm::Argument karg)
Definition gridwise_moe_mx_gemm_bns.hpp:48
int32_t index_t
Definition ck.hpp:299
@ Set
Definition ck.hpp:278
@ AtomicAdd
Definition ck.hpp:279
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
BlockGemmPipelineVersion
Definition blkgemmpipe_scheduler.hpp:12
@ v2
Definition blkgemmpipe_scheduler.hpp:15
@ v3
Definition blkgemmpipe_scheduler.hpp:16
@ v5
Definition blkgemmpipe_scheduler.hpp:18
@ v4
Definition blkgemmpipe_scheduler.hpp:17
@ v1
Definition blkgemmpipe_scheduler.hpp:14
__global__ void kernel_moe_mxgemm_2lds(typename GridwiseGemm::Argument karg)
Definition gridwise_moe_mx_gemm.hpp:90
@ Even
Definition blkgemmpipe_scheduler.hpp:34
@ Odd
Definition blkgemmpipe_scheduler.hpp:33
typename tuple_element< I, TTuple >::type tuple_element_t
Definition utility/tuple.hpp:208
bool is_xdl_wmma_supported()
Definition host_utility/device_prop.hpp:76
__device__ constexpr index_t get_warp_size()
Definition get_id.hpp:10
BlockGemmPipelineScheduler
Definition blkgemmpipe_scheduler.hpp:25
@ Intrawave
Definition blkgemmpipe_scheduler.hpp:26
@ Interwave
Definition blkgemmpipe_scheduler.hpp:27
constexpr index_t packed_size_v
Definition data_type.hpp:411
bool is_bf16_atomic_supported()
Definition host_utility/device_prop.hpp:108
Definition ck/stream_config.hpp:10
Definition gridwise_moe_mx_gemm_bpreshuffle.hpp:174
ck::GridwiseMoeGemmMX_BPreshuffle< ALayout, BLayout, DsLayout, CLayout, ADataType, AScaleDataType, BDataType, BScaleDataType, GemmAccDataType, CShuffleDataType, DsDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, ScaleBlockSize, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, math::max(2, NXdlPerWave_), ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVectors, BlkGemmPipeSched, BlkGemmPipelineVer, ActivationOP, NSwizzle, IsInputGemm, MulRoutedWeight, IndexType, ComputeTypeA, ComputeTypeB >::CheckValidity static __host__ constexpr bool CheckValidity(const Argument &karg)
Definition gridwise_moe_mx_gemm_bpreshuffle.hpp:1068
Definition functional2.hpp:33
Definition device_base.hpp:197
Definition device_gemm_multiple_d.hpp:167
Definition device_moe_mx_gemm_bpreshuffle.hpp:167
float RunImp(const typename GridwiseGemm::Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_moe_mx_gemm_bpreshuffle.hpp:169
INVOKER_RUN3_IMPL float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_moe_mx_gemm_bpreshuffle.hpp:377
Definition device_moe_mx_gemm_bpreshuffle.hpp:92
typename GridwiseGemm64::Argument Argument
Definition device_moe_mx_gemm_bpreshuffle.hpp:158
static GET_NXDL_PER_WAVE_IMPL constexpr auto NXdlPerWave64
Definition device_moe_mx_gemm_bpreshuffle.hpp:94
GridwiseMoeGemmMX_BPreshuffle< ALayout, BLayout, DsLayout, CLayout, ADataType, AScaleDataType, BDataType, BScaleDataType, GemmAccDataType, CShuffleDataType, DsDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, ScaleBlockSize, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, math::max(2, NXdlPerWave_), ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVectors, BlkGemmPipeSched, BlkGemmPipelineVer, ActivationOP, NSwizzle, IsInputGemm, MulRoutedWeight, IndexType, ComputeTypeA, ComputeTypeB > GridwiseGemmBase
Definition device_moe_mx_gemm_bpreshuffle.hpp:98
static constexpr index_t APackedSize
Definition device_moe_mx_gemm_bpreshuffle.hpp:160
static constexpr index_t NumDTensor
Definition device_moe_mx_gemm_bpreshuffle.hpp:96
static constexpr bool IsValidCompilationParameter()
Definition device_moe_mx_gemm_bpreshuffle.hpp:384
static constexpr index_t BPackedSize
Definition device_moe_mx_gemm_bpreshuffle.hpp:161
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_moe_mx_gemm_bpreshuffle.hpp:543
static bool IsSupportedArgument(const Argument &arg)
Definition device_moe_mx_gemm_bpreshuffle.hpp:390
GridwiseGemmBase< NXdlPerWave32 > GridwiseGemm32
Definition device_moe_mx_gemm_bpreshuffle.hpp:156
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_moe_mx_gemm_bpreshuffle.hpp:437
int GetPreShuffleParameters() override
Definition device_moe_mx_gemm_bpreshuffle.hpp:163
std::string GetTypeString() const override
Definition device_moe_mx_gemm_bpreshuffle.hpp:549
static auto MakeInvoker()
Definition device_moe_mx_gemm_bpreshuffle.hpp:493
GridwiseGemmBase< math::max(NXdlPerWave64, 1)> GridwiseGemm64
Definition device_moe_mx_gemm_bpreshuffle.hpp:155
static constexpr auto NXdlPerWave32
Definition device_moe_mx_gemm_bpreshuffle.hpp:95
static auto MakeArgument(const void *p_sorted_token_ids, const void *p_sorted_expert_ids, const void *p_max_token_id, const void *p_a, const void *p_a_scale, const void *p_b, const void *p_b_scale, std::array< const void *, NumDTensor > p_ds, void *p_c, index_t NumTokens, index_t TopK, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideScaleA, index_t StrideB, index_t StrideScaleB, std::array< index_t, NumDTensor > StrideDs, index_t StrideC, index_t KBatch, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op)
Definition device_moe_mx_gemm_bpreshuffle.hpp:442
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_a_scale, const void *p_b, const void *p_b_scale, std::array< const void *, NumDTensor > p_ds, void *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideScaleA, index_t StrideB, index_t StrideScaleB, std::array< ck::index_t, NumDTensor > StrideDs, index_t StrideC, index_t KBatch, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op) override
Definition device_moe_mx_gemm_bpreshuffle.hpp:496
Definition flush_cache.hpp:174