36template <
typename GridwiseGemm,
37 typename AGridDesc_AK0_M_K1,
38 typename BGridDesc_BK0_N_K1,
39 typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
40 typename ComputePtrOffsetOfBatch,
41 bool HasMainKBlockLoop,
46#if CK_USE_LAUNCH_BOUNDS
50 typename GridwiseGemm::Argument karg,
51 const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1,
52 const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1,
53 const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
54 c_grid_desc_mblock_mperblock_nblock_nperblock,
55 const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
58#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
59 if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
61 const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z);
62 const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block);
71 __shared__
char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
72 GridwiseGemm::template Run<AGridDesc_AK0_M_K1,
74 CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
76 CGlobalMemoryDataOperation,
77 TailNum>(karg.p_a_grid + a_batch_offset,
78 karg.p_b_grid + b_batch_offset,
79 karg.p_c_grid + e_batch_offset,
82 a_grid_desc_ak0_m_ak1,
83 b_grid_desc_bk0_n_bk1,
84 c_grid_desc_mblock_mperblock_nblock_nperblock,
89 ignore = a_grid_desc_ak0_m_ak1;
90 ignore = b_grid_desc_bk0_n_bk1;
91 ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
92 ignore = compute_ptr_offset_of_batch;
97template <
typename GridwiseGemm,
98 typename AGridDesc_AK0_M_K1,
99 typename BGridDesc_BK0_N_K1,
100 typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
101 typename ComputePtrOffsetOfBatch,
102 bool HasMainKBlockLoop,
107#if CK_USE_LAUNCH_BOUNDS
111 typename GridwiseGemm::Argument karg,
112 const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1,
113 const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1,
114 const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
115 c_grid_desc_mblock_mperblock_nblock_nperblock,
116 const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
119#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
120 if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
123 const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z);
124 const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block);
135 __shared__
char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
136 __shared__
char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
138 GridwiseGemm::template Run_2Lds<AGridDesc_AK0_M_K1,
140 CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
142 CGlobalMemoryDataOperation,
143 TailNum>(karg.p_a_grid + a_batch_offset,
144 karg.p_b_grid + b_batch_offset,
145 karg.p_c_grid + e_batch_offset,
149 a_grid_desc_ak0_m_ak1,
150 b_grid_desc_bk0_n_bk1,
151 c_grid_desc_mblock_mperblock_nblock_nperblock,
156 ignore = a_grid_desc_ak0_m_ak1;
157 ignore = b_grid_desc_bk0_n_bk1;
158 ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
159 ignore = compute_ptr_offset_of_batch;
170 typename WeiDataType,
171 typename OutDataType,
172 typename AccDataType,
173 typename InElementwiseOperation,
174 typename WeiElementwiseOperation,
175 typename OutElementwiseOperation,
186 typename ABlockTransferThreadClusterLengths_K0_M_K1,
187 typename ABlockTransferThreadClusterArrangeOrder,
188 typename ABlockTransferSrcAccessOrder,
192 bool ABlockLdsAddExtraM,
193 typename BBlockTransferThreadClusterLengths_K0_N_K1,
194 typename BBlockTransferThreadClusterArrangeOrder,
195 typename BBlockTransferSrcAccessOrder,
199 bool BBlockLdsAddExtraN,
200 index_t CShuffleMXdlPerWavePerShuffle,
201 index_t CShuffleNXdlPerWavePerShuffle,
202 typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
203 index_t CBlockTransferScalarPerVector_NWaveNPerXdl,
206 typename ComputeTypeA = InDataType,
207 typename ComputeTypeB = ComputeTypeA>
216 InElementwiseOperation,
217 WeiElementwiseOperation,
218 OutElementwiseOperation,
259 ConvBackwardWeightSpecialization>{};
261 template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type =
false>
266 const std::array<ck::index_t, NDimSpatial> lengths{1};
267 const std::array<ck::index_t, NDimSpatial + 3> strides{1, 1, 1, 1};
268 const std::array<ck::index_t, NDimSpatial> params{1};
286 template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type =
false>
291 const std::array<ck::index_t, NDimSpatial> lengths{1, 1};
292 const std::array<ck::index_t, NDimSpatial + 3> strides{1, 1, 1, 1, 1};
293 const std::array<ck::index_t, NDimSpatial> params{1, 1};
311 template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type =
false>
316 const std::array<ck::index_t, NDimSpatial> lengths{1, 1, 1};
317 const std::array<ck::index_t, NDimSpatial + 3> strides{1, 1, 1, 1, 1, 1};
318 const std::array<ck::index_t, NDimSpatial> params{1, 1, 1};
342 template <index_t NXdlPerWave_>
366 ABlockTransferThreadClusterLengths_K0_M_K1,
367 ABlockTransferThreadClusterArrangeOrder,
368 ABlockTransferSrcAccessOrder,
369 ABlockTransferSrcVectorDim,
370 ABlockTransferSrcScalarPerVector,
371 ABlockTransferDstScalarPerVector_K1,
374 BBlockTransferThreadClusterLengths_K0_N_K1,
375 BBlockTransferThreadClusterArrangeOrder,
376 BBlockTransferSrcAccessOrder,
377 BBlockTransferSrcVectorDim,
378 BBlockTransferSrcScalarPerVector,
379 BBlockTransferDstScalarPerVector_K1,
382 CShuffleMXdlPerWavePerShuffle,
383 CShuffleNXdlPerWavePerShuffle,
384 CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
385 CBlockTransferScalarPerVector_NWaveNPerXdl,
400 template <
typename Gr
idwiseGemm>
403 constexpr int dynamic_smem_size = 0;
404 constexpr index_t minimum_occupancy =
406 int max_occupancy = 0;
417 ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
433 ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
440 return std::max(1, max_occupancy);
467 WeiDataType* p_wei_grid,
468 const OutDataType* p_out_grid,
469 const std::array<index_t, NDimSpatial + 3>& b_g_n_c_wis_lengths,
470 const std::array<index_t, NDimSpatial + 3>& b_g_n_c_wis_strides,
471 const std::array<index_t, NDimSpatial + 3>& e_g_k_c_xs_lengths,
472 const std::array<index_t, NDimSpatial + 3>& e_g_k_c_xs_strides,
473 const std::array<index_t, NDimSpatial + 3>& a_g_n_k_wos_lengths,
474 const std::array<index_t, NDimSpatial + 3>& a_g_n_k_wos_strides,
475 const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
476 const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
477 const std::array<ck::index_t, NDimSpatial>& input_left_pads,
478 const std::array<ck::index_t, NDimSpatial>& input_right_pads,
481 InElementwiseOperation in_element_op,
482 WeiElementwiseOperation wei_element_op,
483 OutElementwiseOperation out_element_op,
498 Conv_G_{b_g_n_c_wis_lengths[0]},
499 Conv_N_{b_g_n_c_wis_lengths[1]},
500 Conv_K_{e_g_k_c_xs_lengths[1]},
501 Conv_C_{b_g_n_c_wis_lengths[2]},
513 e_g_k_c_xs_lengths.begin(), NDimSpatial +
I3, 1, std::multiplies<>()) *
516 constexpr index_t spatial_offset = 3;
517 std::copy(begin(b_g_n_c_wis_lengths) + spatial_offset,
518 end(b_g_n_c_wis_lengths),
520 std::copy(begin(e_g_k_c_xs_lengths) + spatial_offset,
521 end(e_g_k_c_xs_lengths),
523 std::copy(begin(a_g_n_k_wos_lengths) + spatial_offset,
524 end(a_g_n_k_wos_lengths),
527#if !DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS
531 std::tie(gemmM, gemmN, gemmK) =
534 const auto grid_size =
541 const auto k_batch_max =
static_cast<index_t>((gemmK - 1) / K0PerBlock);
546 std::cout <<
"[SPLIT-K AUTODEDUCE] k_batch max value: " << k_batch_max
548 std::cout <<
"[SPLIT-K AUTODEDUCE] Final k_batch value: " <<
k_batch_
560 .template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
571 conv_filter_dilations,
588 std::multiplies<>{});
638 std::cout <<
"arg.a_grid_desc_kbatch_k0_m_k1_{"
644 std::cout <<
"arg.b_grid_desc_kbatch_k0_n_k1_{"
654 template <
typename Gr
idwiseGemm>
664 typename GridwiseGemm::Argument gemm_arg{
668 std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(
669 gemm_arg.M, gemm_arg.N, gemm_arg.KBatch, arg.
Conv_G_);
673 index_t k_grain = gemm_arg.KBatch * K0PerBlock;
674 index_t K_split = (gemm_arg.K + k_grain - 1) / k_grain * K0PerBlock;
675 const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
677 const auto num_k_per_block =
680 const auto clear_workspace = [&]() {
688 const auto Run = [&](
const auto& kernel) {
689 if(stream_config.flush_cache)
691 typename GridwiseGemm::Argument gemm_arg_ = gemm_arg;
694 stream_config.rotating_count,
695 gemm_arg_.M * gemm_arg_.K *
sizeof(
ADataType),
696 gemm_arg_.K * gemm_arg_.N *
sizeof(
BDataType));
697 rotating_mem.Print();
699 auto run_flush_cache = [&]() {
738 constexpr index_t minimum_occupancy =
741 if(has_main_k_block_loop)
747 if(gemm_arg.KBatch > 1)
755 ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
769 ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
779 if(gemm_arg.KBatch > 1)
781 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::One)
789 ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
796 else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
805 ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
813 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
815 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Two)
823 ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
832 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3)
834 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
843 ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
852 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4)
854 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
863 ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
872 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5)
874 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
883 ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
892 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6)
894 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Six)
902 ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
911 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7)
913 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
922 ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
933 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::One)
941 ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
948 else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
957 ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
965 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
967 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Two)
975 ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
984 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3)
986 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
995 ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
1004 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4)
1006 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
1015 ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
1024 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5)
1026 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
1035 ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
1044 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6)
1046 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Six)
1054 ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
1063 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7)
1065 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
1074 ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
1088 if(gemm_arg.KBatch > 1)
1090 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Odd)
1098 ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
1113 ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
1123 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Odd)
1131 ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
1146 ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
1157 if(gemm_arg.KBatch > 1)
1159 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Odd)
1167 ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
1182 ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
1192 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Odd)
1200 ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
1215 ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
1230 if(gemm_arg.KBatch > 1)
1238 ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
1252 ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
1268 return Run(*
dynamic_cast<const Argument*
>(p_arg), stream_config);
1280#if DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS
1302 std::cout <<
"ComputeDataType for A and B should be same while using TF32"
1313 typename GridwiseGemm64::Argument gemm_arg{
1314 nullptr,
nullptr,
nullptr, GemmM, GemmN, GemmK,
I0,
I0,
I0, arg.
k_batch_};
1316 const auto num_k_loop = gemm_arg.AK0 / (K0PerBlock / K1);
1319 if(num_k_loop <= GridwiseGemm64::BlockwiseGemmPipe::PrefetchStages)
1334 typename GridwiseGemm32::Argument gemm_arg{
1335 nullptr,
nullptr,
nullptr, GemmM, GemmN, GemmK,
I0,
I0,
I0, arg.
k_batch_};
1337 const auto num_k_loop = gemm_arg.AK0 / (K0PerBlock / K1);
1340 if(num_k_loop <= GridwiseGemm64::BlockwiseGemmPipe::PrefetchStages)
1366 if constexpr(NDimSpatial == 1)
1373 else if constexpr(NDimSpatial == 2)
1381 else if constexpr(NDimSpatial == 3)
1394 if constexpr(ConvBackwardWeightSpecialization ==
1398 for(
int i = 0; i < NDimSpatial; i++)
1407 if(!(ABlockTransferSrcVectorDim == 1 && BBlockTransferSrcVectorDim == 1 &&
1408 arg.
Conv_K_ % ABlockTransferSrcScalarPerVector == 0 &&
1409 arg.
Conv_C_ % BBlockTransferSrcScalarPerVector == 0))
1415 if(!(arg.
Conv_C_ % CBlockTransferScalarPerVector_NWaveNPerXdl == 0))
1439 WeiDataType* p_wei_grid,
1440 const OutDataType* p_out_grid,
1441 const std::array<index_t, NDimSpatial + 3>& b_g_n_c_wis_lengths,
1442 const std::array<index_t, NDimSpatial + 3>& b_g_n_c_wis_strides,
1443 const std::array<index_t, NDimSpatial + 3>& e_g_k_c_xs_lengths,
1444 const std::array<index_t, NDimSpatial + 3>& e_g_k_c_xs_strides,
1445 const std::array<index_t, NDimSpatial + 3>& a_g_n_k_wos_lengths,
1446 const std::array<index_t, NDimSpatial + 3>& a_g_n_k_wos_strides,
1447 const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
1448 const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
1449 const std::array<ck::index_t, NDimSpatial>& input_left_pads,
1450 const std::array<ck::index_t, NDimSpatial>& input_right_pads,
1451 InElementwiseOperation in_element_op,
1452 WeiElementwiseOperation wei_element_op,
1453 OutElementwiseOperation out_element_op,
1459 b_g_n_c_wis_lengths,
1460 b_g_n_c_wis_strides,
1463 a_g_n_k_wos_lengths,
1464 a_g_n_k_wos_strides,
1465 conv_filter_strides,
1466 conv_filter_dilations,
1479 std::unique_ptr<BaseArgument>
1482 const void* p_out_grid,
1483 const std::array<index_t, NDimSpatial + 3>& b_g_n_c_wis_lengths,
1484 const std::array<index_t, NDimSpatial + 3>& b_g_n_c_wis_strides,
1485 const std::array<index_t, NDimSpatial + 3>& e_g_k_c_xs_lengths,
1486 const std::array<index_t, NDimSpatial + 3>& e_g_k_c_xs_strides,
1487 const std::array<index_t, NDimSpatial + 3>& a_g_n_k_wos_lengths,
1488 const std::array<index_t, NDimSpatial + 3>& a_g_n_k_wos_strides,
1489 const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
1490 const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
1491 const std::array<ck::index_t, NDimSpatial>& input_left_pads,
1492 const std::array<ck::index_t, NDimSpatial>& input_right_pads,
1493 InElementwiseOperation in_element_op,
1494 WeiElementwiseOperation wei_element_op,
1495 OutElementwiseOperation out_element_op,
1498 return std::make_unique<Argument>(
static_cast<const InDataType*
>(p_in_grid),
1499 static_cast<WeiDataType*
>(p_wei_grid),
1500 static_cast<const OutDataType*
>(p_out_grid),
1501 b_g_n_c_wis_lengths,
1502 b_g_n_c_wis_strides,
1505 a_g_n_k_wos_lengths,
1506 a_g_n_k_wos_strides,
1507 conv_filter_strides,
1508 conv_filter_dilations,
1521 return std::make_unique<Invoker>(
Invoker{});
1526 auto str = std::stringstream();
1529 str <<
"DeviceGroupedConvBwdWeight_Xdl_CShuffleV3"
1531 << BlockSize <<
", "
1532 << MPerBlock <<
", "
1533 << NPerBlock <<
", "
1534 << K0PerBlock <<
", "
1537 << MXdlPerWave <<
", "
1538 << NXdlPerWave <<
", "
1539 << ABlockTransferSrcScalarPerVector <<
", "
1540 << ABlockTransferDstScalarPerVector_K1 <<
", "
1541 << BBlockTransferSrcScalarPerVector <<
", "
1542 << BBlockTransferDstScalarPerVector_K1 <<
", "
1543 << CShuffleMXdlPerWavePerShuffle <<
", "
1544 << CShuffleNXdlPerWavePerShuffle <<
", "
1545 << CBlockTransferScalarPerVector_NWaveNPerXdl
#define CK_MAX_THREAD_PER_BLOCK
Definition ck.hpp:30
#define GET_NXDL_PER_WAVE_IMPL
Definition device_base.hpp:81
#define INVOKER_RUN_IMPL
Definition device_base.hpp:94
void hip_check_error(hipError_t x)
Definition host_utility/hip_check_error.hpp:10
float launch_and_time_kernel_with_preprocess(const StreamConfig &stream_config, PreProcessFunc preprocess, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:91
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
Definition convolution_backward_data_specialization.hpp:8
__global__ void kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3_2lds(typename GridwiseGemm::Argument karg, const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1, const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock, const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, const index_t num_k_per_block)
Definition device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp:109
__global__ void kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg, const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1, const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock, const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, const index_t num_k_per_block)
Definition device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp:51
auto get_bwd_weight_gemm_sizes(const std::array< index_t, NDimSpatial+3 > &a_g_n_k_wos_lengths, const std::array< index_t, NDimSpatial+3 > &e_g_k_c_xs_lengths)
Definition split_k_utils.hpp:55
ConvolutionBackwardWeightSpecialization
Definition convolution_backward_weight_specialization.hpp:13
@ Filter1x1Stride1Pad0
Definition convolution_backward_weight_specialization.hpp:15
constexpr bool is_GNWC_GKXC_GNWK()
Definition device_grouped_conv_utils.hpp:23
constexpr bool is_GNDHWC_GKZYXC_GNDHWK()
Definition device_grouped_conv_utils.hpp:88
constexpr bool is_NHWGC_GKYXC_NHWGK()
Definition device_grouped_conv_utils.hpp:40
ck::index_t get_best_occupancy_k_batch_value(int max_occupancy, ck::index_t grid_size)
Definition split_k_utils.hpp:30
GemmSpecialization
Definition gemm_specialization.hpp:11
@ Default
Definition gemm_specialization.hpp:13
constexpr bool is_NDHWGC_GKZYXC_NDHWGK()
Definition device_grouped_conv_utils.hpp:80
std::string getConvBackwardWeightSpecializationString(const ConvolutionBackwardWeightSpecialization &s)
Definition convolution_backward_weight_specialization.hpp:21
ck::index_t calculate_mn_grid_size(ck::index_t gemmM, ck::index_t gemmN)
Definition split_k_utils.hpp:84
constexpr bool is_GNHWC_GKYXC_GNHWK()
Definition device_grouped_conv_utils.hpp:48
Definition convolution_backward_data_specialization.hpp:7
void flush_icache()
Definition flush_cache.hpp:383
float launch_and_time_kernel_with_preprocess(const StreamConfig &stream_config, PreProcessFunc preprocess, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, GemmArgs &gemm_args, Args... args)
Definition flush_cache.hpp:398
int32_t index_t
Definition ck.hpp:299
InMemoryDataOperationEnum
Definition ck.hpp:277
@ Set
Definition ck.hpp:278
@ AtomicAdd
Definition ck.hpp:279
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
BlockGemmPipelineVersion
Definition blkgemmpipe_scheduler.hpp:12
@ v2
Definition blkgemmpipe_scheduler.hpp:15
@ v3
Definition blkgemmpipe_scheduler.hpp:16
@ v4
Definition blkgemmpipe_scheduler.hpp:17
@ v1
Definition blkgemmpipe_scheduler.hpp:14
integral_constant< index_t, N > Number
Definition number.hpp:12
TailNumber
Definition blkgemmpipe_scheduler.hpp:31
@ One
Definition blkgemmpipe_scheduler.hpp:37
@ Seven
Definition blkgemmpipe_scheduler.hpp:43
@ Even
Definition blkgemmpipe_scheduler.hpp:34
@ Odd
Definition blkgemmpipe_scheduler.hpp:33
@ Four
Definition blkgemmpipe_scheduler.hpp:40
@ Two
Definition blkgemmpipe_scheduler.hpp:38
@ Full
Definition blkgemmpipe_scheduler.hpp:49
@ Three
Definition blkgemmpipe_scheduler.hpp:39
@ Five
Definition blkgemmpipe_scheduler.hpp:41
@ Six
Definition blkgemmpipe_scheduler.hpp:42
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
__device__ uint32_t amd_wave_read_first_lane(uint32_t value)
Definition amd_wave_read_first_lane.hpp:100
bool is_xdl_wmma_supported()
Definition host_utility/device_prop.hpp:76
__device__ constexpr index_t get_warp_size()
Definition get_id.hpp:10
bool EnvIsEnabled(EnvVar)
Definition utility/env.hpp:140
bool is_tf32_supported()
Definition host_utility/device_prop.hpp:132
constexpr bool is_same_v
Definition type.hpp:283
BlockGemmPipelineScheduler
Definition blkgemmpipe_scheduler.hpp:25
@ Intrawave
Definition blkgemmpipe_scheduler.hpp:26
typename remove_reference< T >::type remove_reference_t
Definition type.hpp:292
auto accumulate_n(ForwardIterator first, Size count, T init, BinaryOperation op) -> decltype(std::accumulate(first, std::next(first, count), init, op))
Definition library/utility/numeric.hpp:11
int64_t long_index_t
Definition ck.hpp:300
bool is_gfx11_supported()
Definition host_utility/device_prop.hpp:60
bool is_bf16_atomic_supported()
Definition host_utility/device_prop.hpp:108
Definition ck/stream_config.hpp:10
Definition gridwise_gemm_xdl_cshuffle_conv_v3.hpp:66
ck::GridwiseGemm_xdl_cshuffle_conv_v3< tensor_layout::gemm::RowMajor, tensor_layout::gemm::ColumnMajor, tensor_layout::gemm::RowMajor, ADataType, BDataType, AccDataType, CDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, MPerBlock, NPerBlock, K0PerBlock, K1, K1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, false, ABlockLdsAddExtraM, BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_K1, false, BBlockLdsAddExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CBlockTransferScalarPerVector_NWaveNPerXdl, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB >::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock __host__ static __device__ constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc &c_grid_desc_m_n, index_t MBlock, index_t NBlock)
Definition gridwise_gemm_xdl_cshuffle_conv_v3.hpp:644
ck::GridwiseGemm_xdl_cshuffle_conv_v3< tensor_layout::gemm::RowMajor, tensor_layout::gemm::ColumnMajor, tensor_layout::gemm::RowMajor, ADataType, BDataType, AccDataType, CDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, MPerBlock, NPerBlock, K0PerBlock, K1, K1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, false, ABlockLdsAddExtraM, BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_K1, false, BBlockLdsAddExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CBlockTransferScalarPerVector_NWaveNPerXdl, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB >::CalculateNBlock static __host__ auto CalculateNBlock(index_t N)
Definition gridwise_gemm_xdl_cshuffle_conv_v3.hpp:153
ck::GridwiseGemm_xdl_cshuffle_conv_v3< tensor_layout::gemm::RowMajor, tensor_layout::gemm::ColumnMajor, tensor_layout::gemm::RowMajor, ADataType, BDataType, AccDataType, CDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, MPerBlock, NPerBlock, K0PerBlock, K1, K1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, false, ABlockLdsAddExtraM, BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_K1, false, BBlockLdsAddExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CBlockTransferScalarPerVector_NWaveNPerXdl, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB >::CalculateMBlock static __host__ auto CalculateMBlock(index_t M)
Definition gridwise_gemm_xdl_cshuffle_conv_v3.hpp:148
Definition tensor_operation/gpu/device/tensor_layout.hpp:31
Definition tensor_operation/gpu/device/tensor_layout.hpp:26
Definition split_k_arg.hpp:11
index_t k_batch_
Definition split_k_arg.hpp:12
Definition device_base.hpp:197
Definition device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp:399
static int GetMaxOccupancy()
Definition device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp:401
int max_occupancy_
Definition device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp:461
ActiveWorkgroupsPerCU()
Definition device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp:443
Definition device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp:465
WeiElementwiseOperation c_element_op_
Definition device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp:615
const std::array< ck::index_t, NDimSpatial > & input_left_pads_
Definition device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp:626
BGridDesc_K0_N_K1 b_grid_desc_kbatch_k0_n_k1_
Definition device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp:603
const std::array< ck::index_t, NDimSpatial > & conv_filter_strides_
Definition device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp:625
const index_t Conv_N_
Definition device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp:619
const BDataType * p_b_grid_
Definition device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp:600
CDataType * p_c_grid_
Definition device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp:601
long_index_t c_space_size_bytes
Definition device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp:628
index_t M01_
Definition device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp:610
OutElementwiseOperation a_element_op_
Definition device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp:613
const ADataType * p_a_grid_
Definition device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp:599
index_t N01_
Definition device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp:611
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock_
Definition device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp:605
const index_t Conv_K_
Definition device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp:620
const std::array< ck::index_t, NDimSpatial > & input_right_pads_
Definition device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp:627
std::array< ck::index_t, NDimSpatial > output_spatial_lengths_
Definition device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp:624
const index_t Conv_C_
Definition device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp:621
const index_t Conv_G_
Definition device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp:618
Argument(const InDataType *p_in_grid, WeiDataType *p_wei_grid, const OutDataType *p_out_grid, const std::array< index_t, NDimSpatial+3 > &b_g_n_c_wis_lengths, const std::array< index_t, NDimSpatial+3 > &b_g_n_c_wis_strides, const std::array< index_t, NDimSpatial+3 > &e_g_k_c_xs_lengths, const std::array< index_t, NDimSpatial+3 > &e_g_k_c_xs_strides, const std::array< index_t, NDimSpatial+3 > &a_g_n_k_wos_lengths, const std::array< index_t, NDimSpatial+3 > &a_g_n_k_wos_strides, const std::array< ck::index_t, NDimSpatial > &conv_filter_strides, const std::array< ck::index_t, NDimSpatial > &conv_filter_dilations, const std::array< ck::index_t, NDimSpatial > &input_left_pads, const std::array< ck::index_t, NDimSpatial > &input_right_pads, const ck::index_t M01, const ck::index_t N01, InElementwiseOperation in_element_op, WeiElementwiseOperation wei_element_op, OutElementwiseOperation out_element_op, ck::index_t split_k)
Definition device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp:466
ComputePtrOffsetOfStridedBatch< I1, I1, I0 > compute_ptr_offset_of_batch_
Definition device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp:608
std::array< ck::index_t, NDimSpatial > input_spatial_lengths_
Definition device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp:622
AGridDesc_K0_M_K1 a_grid_desc_kbatch_k0_m_k1_
Definition device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp:602
InElementwiseOperation b_element_op_
Definition device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp:614
std::array< ck::index_t, NDimSpatial > filter_spatial_lengths_
Definition device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp:623
CGridDesc_M_N c_grid_desc_m_n_
Definition device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp:604
Definition device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp:633
INVOKER_RUN_IMPL float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp:1265
float RunImp(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp:655
DeviceOp::Argument Argument
Definition device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp:634
void ShowInfo(const Argument &arg)
Definition device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp:636
Definition device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp:221
static auto I5
Definition device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp:247
DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 DeviceOp
Definition device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp:226
remove_cvref_t< decltype(ABCGridDescs{}[I2])> CGridDesc_M_N
Definition device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp:340
remove_cvref_t< decltype(ABCGridDescs{}[I0])> AGridDesc_K0_M_K1
Definition device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp:338
InDataType ABDataType
Definition device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp:240
static auto GetABCGridDesc()
Definition device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp:262
static bool IsSupportedArgument(const Argument &arg)
Definition device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp:1278
static constexpr auto K1Number
Definition device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp:250
GridwiseGemm_xdl_cshuffle_conv_v3< tensor_layout::gemm::RowMajor, tensor_layout::gemm::ColumnMajor, tensor_layout::gemm::RowMajor, ADataType, BDataType, AccDataType, CDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, MPerBlock, NPerBlock, K0PerBlock, K1, K1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, false, ABlockLdsAddExtraM, BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_K1, false, BBlockLdsAddExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CBlockTransferScalarPerVector_NWaveNPerXdl, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB > GridwiseGemmBase
Definition device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp:343
static auto I0
Definition device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp:242
static GET_NXDL_PER_WAVE_IMPL constexpr auto NXdlPerWave64
Definition device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp:228
remove_cvref_t< decltype(ABCGridDescs{}[I1])> BGridDesc_K0_N_K1
Definition device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp:339
GridwiseGemmBase< NXdlPerWave32 > GridwiseGemm32
Definition device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp:391
static auto I2
Definition device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp:244
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp:1519
static constexpr bool IsValidCompilationParameter()
Definition device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp:1272
decltype(GridwiseGemm64::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( CGridDesc_M_N{}, 1, 1)) CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
Definition device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp:394
std::string GetTypeString() const override
Definition device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp:1524
static constexpr auto conv_to_gemm_transformer
Definition device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp:252
static auto MakeArgument(const InDataType *p_in_grid, WeiDataType *p_wei_grid, const OutDataType *p_out_grid, const std::array< index_t, NDimSpatial+3 > &b_g_n_c_wis_lengths, const std::array< index_t, NDimSpatial+3 > &b_g_n_c_wis_strides, const std::array< index_t, NDimSpatial+3 > &e_g_k_c_xs_lengths, const std::array< index_t, NDimSpatial+3 > &e_g_k_c_xs_strides, const std::array< index_t, NDimSpatial+3 > &a_g_n_k_wos_lengths, const std::array< index_t, NDimSpatial+3 > &a_g_n_k_wos_strides, const std::array< ck::index_t, NDimSpatial > &conv_filter_strides, const std::array< ck::index_t, NDimSpatial > &conv_filter_dilations, const std::array< ck::index_t, NDimSpatial > &input_left_pads, const std::array< ck::index_t, NDimSpatial > &input_right_pads, InElementwiseOperation in_element_op, WeiElementwiseOperation wei_element_op, OutElementwiseOperation out_element_op, const ck::index_t split_k)
Definition device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp:1438
static auto I4
Definition device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp:246
GridwiseGemmBase< math::max(NXdlPerWave64, 1)> GridwiseGemm64
Definition device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp:390
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_in_grid, void *p_wei_grid, const void *p_out_grid, const std::array< index_t, NDimSpatial+3 > &b_g_n_c_wis_lengths, const std::array< index_t, NDimSpatial+3 > &b_g_n_c_wis_strides, const std::array< index_t, NDimSpatial+3 > &e_g_k_c_xs_lengths, const std::array< index_t, NDimSpatial+3 > &e_g_k_c_xs_strides, const std::array< index_t, NDimSpatial+3 > &a_g_n_k_wos_lengths, const std::array< index_t, NDimSpatial+3 > &a_g_n_k_wos_strides, const std::array< ck::index_t, NDimSpatial > &conv_filter_strides, const std::array< ck::index_t, NDimSpatial > &conv_filter_dilations, const std::array< ck::index_t, NDimSpatial > &input_left_pads, const std::array< ck::index_t, NDimSpatial > &input_right_pads, InElementwiseOperation in_element_op, WeiElementwiseOperation wei_element_op, OutElementwiseOperation out_element_op, const ck::index_t split_k) override
Definition device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp:1480
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp:1432
static auto I1
Definition device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp:243
OutDataType ADataType
Definition device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp:231
InElementwiseOperation BElementwiseOperation
Definition device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp:236
OutElementwiseOperation AElementwiseOperation
Definition device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp:235
static auto I3
Definition device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp:245
InDataType BDataType
Definition device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp:232
WeiElementwiseOperation CElementwiseOperation
Definition device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp:237
static constexpr GemmSpecialization GemmSpec
Definition device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp:249
static auto MakeInvoker()
Definition device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp:1477
static constexpr auto NXdlPerWave32
Definition device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp:229
WeiDataType CDataType
Definition device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp:233
decltype(GetABCGridDesc< NDimSpatial >()) ABCGridDescs
Definition device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp:336
Definition device_grouped_conv_bwd_weight.hpp:29
Definition flush_cache.hpp:299
#define CK_ENV(name)
Definition utility/env.hpp:129