grouped_gemm_quant_kernel.hpp Source File#
grouped_gemm_quant_kernel.hpp
Go to the documentation of this file.
#define HIP_CHECK_ERROR(retval_or_funcall)
Definition host_utility/hip_check_error.hpp:21
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
__device__ uint32_t amd_wave_read_first_lane(uint16_t v)
Definition tile/core/arch/amd_buffer_addressing.hpp:35
typename detail::detector< nonesuch, void, Op, Args... >::value_t is_detected
Definition type_traits.hpp:67
auto concat(const Ts &... xs) -> std::enable_if_t<!AllConvertibleToStringView< Ts... >, std::string >
Definition concat.hpp:43
__device__ T * cast_pointer_to_generic_address_space(T CK_CONSTANT_ADDRESS_SPACE *p)
Definition arch.hpp:307
QuantGemmKernelArgs QuantGroupedGemmKernelArgs
Definition grouped_gemm_quant_kernel.hpp:92
CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
Definition tile/core/numeric/type_convert.hpp:29
Struct used to calculate offset tile indices.
Definition gemm_tile_partitioner.hpp:184
static CK_TILE_DEVICE auto GetOffsetedTileIndex(index_t block_start, index_t M, index_t N) noexcept -> const tuple< index_t, index_t >
The function subtracts the block's start (offset) from the raw 1D indices.
Definition gemm_tile_partitioner.hpp:192
Definition gemm_quant_kernel.hpp:272
index_t splitted_k
Definition gemm_quant_kernel.hpp:310
Definition gemm_quant_kernel.hpp:171
Definition gemm_quant_kernel.hpp:195
static CK_TILE_DEVICE auto MakeGemmPadViews(const TensorView &views)
Definition gemm_quant_kernel.hpp:729
static CK_TILE_HOST bool IsSupportedArgument(const QuantGemmKernelArgs &kargs)
Definition gemm_quant_kernel.hpp:313
static CK_TILE_DEVICE auto MakeGemmTileWindows(const PadView &views, const index_t i_m, const index_t i_n)
Definition gemm_quant_kernel.hpp:806
Definition grouped_gemm_quant_kernel.hpp:95
QuantGemmTransKernelArg()=delete
QuantGemmTransKernelArg(QuantGroupedGemmKernelArgs &&karg, index_t bl_start, index_t bl_end)
Definition grouped_gemm_quant_kernel.hpp:101
ck_tile::index_t block_end
Definition grouped_gemm_quant_kernel.hpp:98
ck_tile::index_t block_start
Definition grouped_gemm_quant_kernel.hpp:97
QuantGroupedGemmKernelArgs group_karg
Definition grouped_gemm_quant_kernel.hpp:96
QuantGemmTransKernelArg(QuantGroupedGemmKernelArgs &&karg)
Definition grouped_gemm_quant_kernel.hpp:106
index_t stride_BQ
Definition grouped_gemm_quant_kernel.hpp:81
const void * b_ptr
Definition grouped_gemm_quant_kernel.hpp:64
const void * aq_ptr
Definition grouped_gemm_quant_kernel.hpp:65
index_t stride_B
Definition grouped_gemm_quant_kernel.hpp:79
index_t k_batch
Definition grouped_gemm_quant_kernel.hpp:89
index_t stride_AQ
Definition grouped_gemm_quant_kernel.hpp:80
CK_TILE_HOST QuantGroupedGemmHostArgs(const void *a_ptr_, const void *b_ptr_, void *e_ptr_, const void *aq_ptr_, const void *bq_ptr_, index_t k_batch_, index_t M_, index_t N_, index_t K_, index_t QK_A_, index_t QK_B_, index_t stride_A_, index_t stride_B_, index_t stride_E_, index_t stride_AQ_, index_t stride_BQ_)
Definition grouped_gemm_quant_kernel.hpp:28
index_t stride_A
Definition grouped_gemm_quant_kernel.hpp:78
const void * bq_ptr
Definition grouped_gemm_quant_kernel.hpp:66
index_t stride_C
Definition grouped_gemm_quant_kernel.hpp:86
index_t stride_E
Definition grouped_gemm_quant_kernel.hpp:85
const void * a_ptr
Definition grouped_gemm_quant_kernel.hpp:63
Definition grouped_gemm_quant_kernel.hpp:117
static CK_TILE_HOST_DEVICE constexpr auto GetSmemSize() -> index_t
Definition grouped_gemm_quant_kernel.hpp:297
QuantGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_, QuantType_ > Base
Injects the QuantGemmKernel base class to provide all the functions needed for execution.
Definition grouped_gemm_quant_kernel.hpp:120
static CK_TILE_DEVICE void RunGemmWithPipelineSelection(const ADataType *a_ptr, const BDataType *b_ptr, const AQDataType *aq_ptr, const BQDataType *bq_ptr, CDataType *c_ptr, void *smem_ptr_0, const QuantGroupedGemmKernelArgs &kargs, const typename Base::SplitKBatchOffset &splitk_batch_offset, const index_t block_idx_m, const index_t block_idx_n)
Runs a single GEMM problem cooperatively across the whole workgroup.
Definition grouped_gemm_quant_kernel.hpp:425
remove_cvref_t< typename GemmPipeline::CLayout > CLayout
Definition grouped_gemm_quant_kernel.hpp:129
static CK_TILE_HOST auto GridSize(const std::vector< QuantGroupedGemmHostArgs > &gemm_descs)
Definition grouped_gemm_quant_kernel.hpp:220
static constexpr index_t kBlockSize
Definition grouped_gemm_quant_kernel.hpp:163
remove_cvref_t< typename GemmPipeline::ADataType > ADataType
Specify the data type configurations for A, B, C/E.
Definition grouped_gemm_quant_kernel.hpp:132
static CK_TILE_HOST auto BlockSize() -> dim3
Definition grouped_gemm_quant_kernel.hpp:191
static CK_TILE_HOST auto MaxOccupancyGridSize(const stream_config &s) -> dim3
Get the maximum occupancy grid size for the persistent kernel on the current device.
Definition grouped_gemm_quant_kernel.hpp:209
CK_TILE_DEVICE void Run(const QuantGroupedGemmKernelArgs &kargs, const tuple< index_t, index_t > &block_idx_2d, const index_t block_idx_z) const
Definition grouped_gemm_quant_kernel.hpp:302
remove_cvref_t< typename GemmPipeline::BLayout > BLayout
Definition grouped_gemm_quant_kernel.hpp:128
remove_cvref_t< GemmPipeline_ > GemmPipeline
Definition grouped_gemm_quant_kernel.hpp:123
remove_cvref_t< typename EpiloguePipeline::AccDataType > AccDataType
Definition grouped_gemm_quant_kernel.hpp:135
remove_cvref_t< typename GemmPipeline::ALayout > ALayout
Definition grouped_gemm_quant_kernel.hpp:127
static CK_TILE_HOST const std::string GetName()
Definition grouped_gemm_quant_kernel.hpp:167
remove_cvref_t< typename GemmPipeline::BDataType > BDataType
Definition grouped_gemm_quant_kernel.hpp:133
static CK_TILE_HOST auto GetWorkSpaceSize(const std::vector< QuantGroupedGemmHostArgs > &gemm_descs) -> std::size_t
Definition grouped_gemm_quant_kernel.hpp:181
static CK_TILE_HOST auto GetWorkSpaceSize(index_t group_count) -> std::size_t
Definition grouped_gemm_quant_kernel.hpp:186
remove_cvref_t< TilePartitioner_ > TilePartitioner
Definition grouped_gemm_quant_kernel.hpp:122
CK_TILE_DEVICE void operator()(const void CK_CONSTANT_ADDRESS_SPACE *gemm_descs_const, const index_t group_count) const
Definition grouped_gemm_quant_kernel.hpp:502
QuantGroupedGemmKernel< TilePartitioner, GemmPipeline, EpiloguePipeline, kQuantType > Kernel
Definition grouped_gemm_quant_kernel.hpp:160
static CK_TILE_DEVICE void RunGemmWithPipelineSelection2LDS(const ADataType *a_ptr, const BDataType *b_ptr, const AQDataType *aq_ptr, const BQDataType *bq_ptr, CDataType *c_ptr, void *smem_ptr_0, void *smem_ptr_1, const QuantGroupedGemmKernelArgs &kargs, const typename Base::SplitKBatchOffset &splitk_batch_offset, const index_t block_idx_m, const index_t block_idx_n)
Definition grouped_gemm_quant_kernel.hpp:359
remove_cvref_t< typename detail::get_aq_data_type_or< GemmPipeline, AccDataType >::type > AQDataType
Definition grouped_gemm_quant_kernel.hpp:137
static CK_TILE_HOST auto MakeKargs(const std::vector< QuantGroupedGemmHostArgs > &gemm_descs) -> std::vector< QuantGemmTransKernelArg >
Definition grouped_gemm_quant_kernel.hpp:231
remove_cvref_t< EpiloguePipeline_ > EpiloguePipeline
Definition grouped_gemm_quant_kernel.hpp:124
static constexpr bool UsePersistentKernel
Definition grouped_gemm_quant_kernel.hpp:164
static constexpr auto kQuantType
Definition grouped_gemm_quant_kernel.hpp:142
static CK_TILE_HOST bool IsSupportedArgument(const std::vector< QuantGemmTransKernelArg > &kargs)
Definition grouped_gemm_quant_kernel.hpp:285
OffsettedTile1DPartitioner< TilePartitioner > OffsetTile1DPartitioner
ALayout and ADataType are expected to be scalars, not a tuple.
Definition grouped_gemm_quant_kernel.hpp:159
remove_cvref_t< typename detail::get_bq_data_type_or< GemmPipeline, AccDataType >::type > BQDataType
Definition grouped_gemm_quant_kernel.hpp:139
remove_cvref_t< typename EpiloguePipeline::ODataType > CDataType
Definition grouped_gemm_quant_kernel.hpp:134
Definition ck_tile/host/stream_config.hpp:30
Definition tile/core/container/tuple.hpp:192