BlockFmhaFwdSplitKVPipelineQRKSVS< Problem_, Policy_ > Struct Template Reference
Public Types |
Public Member Functions |
Static Public Member Functions |
Static Public Attributes |
List of all members
ck_tile::BlockFmhaFwdSplitKVPipelineQRKSVS< Problem_, Policy_ > Struct Template Reference
#include <block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp>
Public Types | |
| using | Problem = remove_cvref_t<Problem_> |
| using | Policy = remove_cvref_t<Policy_> |
| using | QDataType = remove_cvref_t<typename Problem::QDataType> |
| using | KDataType = remove_cvref_t<typename Problem::KDataType> |
| using | VDataType = remove_cvref_t<typename Problem::VDataType> |
| using | SaccDataType = remove_cvref_t<typename Problem::SaccDataType> |
| using | SMPLComputeDataType = remove_cvref_t<typename Problem::SMPLComputeDataType> |
| using | BiasDataType = remove_cvref_t<typename Problem::BiasDataType> |
| using | LSEDataType = remove_cvref_t<typename Problem::LSEDataType> |
| using | PDataType = remove_cvref_t<typename Problem::PDataType> |
| using | OaccDataType = remove_cvref_t<typename Problem::OaccDataType> |
| using | ODataType = remove_cvref_t<typename Problem::ODataType> |
| using | AttentionVariant = remove_cvref_t<typename Problem::AttentionVariant> |
| using | FmhaMask = remove_cvref_t<typename Problem::FmhaMask> |
| using | BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape> |
| using | VLayout = remove_cvref_t<typename BlockFmhaShape::VLayout> |
Public Member Functions | |
| template<typename QDramBlockWindowTmp, typename KDramBlockWindowLengths, typename KPageBlockNavigator, typename VDramBlockWindowLengths, typename VPageBlockNavigator, typename BiasDramBlockWindowTmp, typename LSEaccDramBlockWindowTmp, typename QElementFunction, typename KElementFunction, typename VElementFunction, typename BiasElementFunction, typename LSEaccElementFunction, typename SAccElementFunction, typename PComputeElementFunction, typename OAccElementFunction, typename PositionEncoding, typename AttentionVariantParams, typename BlockIndices> | |
| CK_TILE_HOST_DEVICE auto | operator() (const QDramBlockWindowTmp &q_dram_block_window_tmp, const QElementFunction &q_element_func, const KDramBlockWindowLengths &k_dram_block_window_lengths, const KPageBlockNavigator &k_page_block_navigator, const KElementFunction &k_element_func, const VDramBlockWindowLengths &v_dram_block_window_lengths, const VPageBlockNavigator &v_page_block_navigator, const VElementFunction &v_element_func, const BiasDramBlockWindowTmp &bias_dram_block_window_tmp, const BiasElementFunction &bias_element_func, LSEaccDramBlockWindowTmp &lse_acc_dram_window_tmp, const LSEaccElementFunction &lse_acc_element_func, const SAccElementFunction &s_acc_element_func, const PComputeElementFunction &p_compute_element_func, const OAccElementFunction &o_acc_element_func, index_t num_splits, index_t i_split, FmhaMask mask, PositionEncoding position_encoding, float scale_s, const AttentionVariant &variant, const AttentionVariantParams &variant_params, const BlockIndices &block_indices, index_t kv_l2p_offset, void *smem_ptr) const |
| template<typename QDramBlockWindowTmp, typename KDramBlockWindowLengths, typename KPageBlockNavigator, typename VDramBlockWindowLengths, typename VPageBlockNavigator, typename BiasDramBlockWindowTmp, typename LSEaccDramBlockWindowTmp, typename PositionEncoding, typename AttentionVariantParams, typename BlockIndices> | |
| CK_TILE_HOST_DEVICE auto | operator() (const QDramBlockWindowTmp &q_dram_block_window_tmp, const KDramBlockWindowLengths &k_dram_block_window_lengths, const KPageBlockNavigator &k_page_block_navigator, const VDramBlockWindowLengths &v_dram_block_window_lengths, const VPageBlockNavigator &v_page_block_navigator, const BiasDramBlockWindowTmp &bias_dram_block_window_tmp, LSEaccDramBlockWindowTmp &lse_acc_dram_block_window_tmp, index_t num_splits, index_t i_split, FmhaMask mask, PositionEncoding position_encoding, float scale_s, const AttentionVariant &variant, const AttentionVariantParams &variant_params, const BlockIndices &block_indices, index_t kv_l2p_offset, void *smem_ptr) const |
Static Public Member Functions | |
| static CK_TILE_HOST_DEVICE constexpr ck_tile::index_t | GetSmemSize () |
Static Public Attributes | |
| static constexpr bool | kQLoadOnce = true |
| static constexpr index_t | kBlockSize = Problem::kBlockSize |
| static constexpr index_t | kM0 = BlockFmhaShape::kM0 |
| static constexpr index_t | kN0 = BlockFmhaShape::kN0 |
| static constexpr index_t | kK0 = BlockFmhaShape::kK0 |
| static constexpr index_t | kN1 = BlockFmhaShape::kN1 |
| static constexpr index_t | kK1 = BlockFmhaShape::kK1 |
| static constexpr index_t | kQKHeaddim = BlockFmhaShape::kQKHeaddim |
| static constexpr index_t | kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim |
| static constexpr bool | kIsGroupMode = Problem::kIsGroupMode |
| static constexpr bool | kPadSeqLenQ = Problem::kPadSeqLenQ |
| static constexpr bool | kPadSeqLenK = Problem::kPadSeqLenK |
| static constexpr bool | kPadHeadDimQ = Problem::kPadHeadDimQ |
| static constexpr bool | kPadHeadDimV = Problem::kPadHeadDimV |
| static constexpr bool | kHasLogitsSoftCap = Problem::kHasLogitsSoftCap |
| static constexpr auto | BiasEnum = Problem::BiasEnum |
| static constexpr bool | kStoreLSE = Problem::kStoreLSE |
| static constexpr bool | kIsPagedKV = Problem::kIsPagedKV |
| static constexpr bool | kHasUnevenSplits = Problem::kHasUnevenSplits |
| static constexpr index_t | kAlignmentQ |
| static constexpr index_t | kAlignmentK |
| static constexpr index_t | kAlignmentV |
| static constexpr index_t | kAlignmentOacc |
| static constexpr index_t | kAlignmentBias |
| static constexpr index_t | kBlockPerCu |
| static constexpr const char * | name = "qr" |
Member Typedef Documentation
◆ AttentionVariant
template<typename Problem_, typename Policy_ = BlockFmhaFwdSplitKVPipelineQRKSVSDefaultPolicy>
| using ck_tile::BlockFmhaFwdSplitKVPipelineQRKSVS< Problem_, Policy_ >::AttentionVariant = remove_cvref_t<typename Problem::AttentionVariant> |
◆ BiasDataType
template<typename Problem_, typename Policy_ = BlockFmhaFwdSplitKVPipelineQRKSVSDefaultPolicy>
| using ck_tile::BlockFmhaFwdSplitKVPipelineQRKSVS< Problem_, Policy_ >::BiasDataType = remove_cvref_t<typename Problem::BiasDataType> |
◆ BlockFmhaShape
template<typename Problem_, typename Policy_ = BlockFmhaFwdSplitKVPipelineQRKSVSDefaultPolicy>
| using ck_tile::BlockFmhaFwdSplitKVPipelineQRKSVS< Problem_, Policy_ >::BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape> |
◆ FmhaMask
template<typename Problem_, typename Policy_ = BlockFmhaFwdSplitKVPipelineQRKSVSDefaultPolicy>
| using ck_tile::BlockFmhaFwdSplitKVPipelineQRKSVS< Problem_, Policy_ >::FmhaMask = remove_cvref_t<typename Problem::FmhaMask> |
◆ KDataType
template<typename Problem_, typename Policy_ = BlockFmhaFwdSplitKVPipelineQRKSVSDefaultPolicy>
| using ck_tile::BlockFmhaFwdSplitKVPipelineQRKSVS< Problem_, Policy_ >::KDataType = remove_cvref_t<typename Problem::KDataType> |
◆ LSEDataType
template<typename Problem_, typename Policy_ = BlockFmhaFwdSplitKVPipelineQRKSVSDefaultPolicy>
| using ck_tile::BlockFmhaFwdSplitKVPipelineQRKSVS< Problem_, Policy_ >::LSEDataType = remove_cvref_t<typename Problem::LSEDataType> |
◆ OaccDataType
template<typename Problem_, typename Policy_ = BlockFmhaFwdSplitKVPipelineQRKSVSDefaultPolicy>
| using ck_tile::BlockFmhaFwdSplitKVPipelineQRKSVS< Problem_, Policy_ >::OaccDataType = remove_cvref_t<typename Problem::OaccDataType> |
◆ ODataType
template<typename Problem_, typename Policy_ = BlockFmhaFwdSplitKVPipelineQRKSVSDefaultPolicy>
| using ck_tile::BlockFmhaFwdSplitKVPipelineQRKSVS< Problem_, Policy_ >::ODataType = remove_cvref_t<typename Problem::ODataType> |
◆ PDataType
template<typename Problem_, typename Policy_ = BlockFmhaFwdSplitKVPipelineQRKSVSDefaultPolicy>
| using ck_tile::BlockFmhaFwdSplitKVPipelineQRKSVS< Problem_, Policy_ >::PDataType = remove_cvref_t<typename Problem::PDataType> |
◆ Policy
template<typename Problem_, typename Policy_ = BlockFmhaFwdSplitKVPipelineQRKSVSDefaultPolicy>
| using ck_tile::BlockFmhaFwdSplitKVPipelineQRKSVS< Problem_, Policy_ >::Policy = remove_cvref_t<Policy_> |
◆ Problem
template<typename Problem_, typename Policy_ = BlockFmhaFwdSplitKVPipelineQRKSVSDefaultPolicy>
| using ck_tile::BlockFmhaFwdSplitKVPipelineQRKSVS< Problem_, Policy_ >::Problem = remove_cvref_t<Problem_> |
◆ QDataType
template<typename Problem_, typename Policy_ = BlockFmhaFwdSplitKVPipelineQRKSVSDefaultPolicy>
| using ck_tile::BlockFmhaFwdSplitKVPipelineQRKSVS< Problem_, Policy_ >::QDataType = remove_cvref_t<typename Problem::QDataType> |
◆ SaccDataType
template<typename Problem_, typename Policy_ = BlockFmhaFwdSplitKVPipelineQRKSVSDefaultPolicy>
| using ck_tile::BlockFmhaFwdSplitKVPipelineQRKSVS< Problem_, Policy_ >::SaccDataType = remove_cvref_t<typename Problem::SaccDataType> |
◆ SMPLComputeDataType
template<typename Problem_, typename Policy_ = BlockFmhaFwdSplitKVPipelineQRKSVSDefaultPolicy>
| using ck_tile::BlockFmhaFwdSplitKVPipelineQRKSVS< Problem_, Policy_ >::SMPLComputeDataType = remove_cvref_t<typename Problem::SMPLComputeDataType> |
◆ VDataType
template<typename Problem_, typename Policy_ = BlockFmhaFwdSplitKVPipelineQRKSVSDefaultPolicy>
| using ck_tile::BlockFmhaFwdSplitKVPipelineQRKSVS< Problem_, Policy_ >::VDataType = remove_cvref_t<typename Problem::VDataType> |
◆ VLayout
template<typename Problem_, typename Policy_ = BlockFmhaFwdSplitKVPipelineQRKSVSDefaultPolicy>
| using ck_tile::BlockFmhaFwdSplitKVPipelineQRKSVS< Problem_, Policy_ >::VLayout = remove_cvref_t<typename BlockFmhaShape::VLayout> |
Member Function Documentation
◆ GetSmemSize()
template<typename Problem_, typename Policy_ = BlockFmhaFwdSplitKVPipelineQRKSVSDefaultPolicy>
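| static CK_TILE_HOST_DEVICE constexpr ck_tile::index_t ck_tile::BlockFmhaFwdSplitKVPipelineQRKSVS< Problem_, Policy_ >::GetSmemSize ( ) |
inline static constexpr |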
◆ operator()() [1/2]
template<typename Problem_, typename Policy_ = BlockFmhaFwdSplitKVPipelineQRKSVSDefaultPolicy>
template<typename QDramBlockWindowTmp, typename KDramBlockWindowLengths, typename KPageBlockNavigator, typename VDramBlockWindowLengths, typename VPageBlockNavigator, typename BiasDramBlockWindowTmp, typename LSEaccDramBlockWindowTmp, typename PositionEncoding, typename AttentionVariantParams, typename BlockIndices>
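| CK_TILE_HOST_DEVICE auto ck_tile::BlockFmhaFwdSplitKVPipelineQRKSVS< Problem_, Policy_ >::operator() (const QDramBlockWindowTmp &q_dram_block_window_tmp, const KDramBlockWindowLengths &k_dram_block_window_lengths, const KPageBlockNavigator &k_page_block_navigator, const VDramBlockWindowLengths &v_dram_block_window_lengths, const VPageBlockNavigator &v_page_block_navigator, const BiasDramBlockWindowTmp &bias_dram_block_window_tmp, LSEaccDramBlockWindowTmp &lse_acc_dram_block_window_tmp, index_t num_splits, index_t i_split, FmhaMask mask, PositionEncoding position_encoding, float scale_s, const AttentionVariant &variant, const AttentionVariantParams &variant_params, const BlockIndices &block_indices, index_t kv_l2p_offset, void *smem_ptr) const |
inline |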
◆ operator()() [2/2]
template<typename Problem_, typename Policy_ = BlockFmhaFwdSplitKVPipelineQRKSVSDefaultPolicy>
template<typename QDramBlockWindowTmp, typename KDramBlockWindowLengths, typename KPageBlockNavigator, typename VDramBlockWindowLengths, typename VPageBlockNavigator, typename BiasDramBlockWindowTmp, typename LSEaccDramBlockWindowTmp, typename QElementFunction, typename KElementFunction, typename VElementFunction, typename BiasElementFunction, typename LSEaccElementFunction, typename SAccElementFunction, typename PComputeElementFunction, typename OAccElementFunction, typename PositionEncoding, typename AttentionVariantParams, typename BlockIndices>
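| CK_TILE_HOST_DEVICE auto ck_tile::BlockFmhaFwdSplitKVPipelineQRKSVS< Problem_, Policy_ >::operator() (const QDramBlockWindowTmp &q_dram_block_window_tmp, const QElementFunction &q_element_func, const KDramBlockWindowLengths &k_dram_block_window_lengths, const KPageBlockNavigator &k_page_block_navigator, const KElementFunction &k_element_func, const VDramBlockWindowLengths &v_dram_block_window_lengths, const VPageBlockNavigator &v_page_block_navigator, const VElementFunction &v_element_func, const BiasDramBlockWindowTmp &bias_dram_block_window_tmp, const BiasElementFunction &bias_element_func, LSEaccDramBlockWindowTmp &lse_acc_dram_window_tmp, const LSEaccElementFunction &lse_acc_element_func, const SAccElementFunction &s_acc_element_func, const PComputeElementFunction &p_compute_element_func, const OAccElementFunction &o_acc_element_func, index_t num_splits, index_t i_split, FmhaMask mask, PositionEncoding position_encoding, float scale_s, const AttentionVariant &variant, const AttentionVariantParams &variant_params, const BlockIndices &block_indices, index_t kv_l2p_offset, void *smem_ptr) const |
inline |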
TODO: perform this check only in the first/last iteration, without increasing code size.
NOTICE: the bias might be a materialized mask containing -inf values; this case needs consideration.
Member Data Documentation
◆ BiasEnum
template<typename Problem_, typename Policy_ = BlockFmhaFwdSplitKVPipelineQRKSVSDefaultPolicy>
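| static constexpr auto ck_tile::BlockFmhaFwdSplitKVPipelineQRKSVS< Problem_, Policy_ >::BiasEnum = Problem::BiasEnum |
static constexpr |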
◆ kAlignmentBias
template<typename Problem_, typename Policy_ = BlockFmhaFwdSplitKVPipelineQRKSVSDefaultPolicy>
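| static constexpr index_t ck_tile::BlockFmhaFwdSplitKVPipelineQRKSVS< Problem_, Policy_ >::kAlignmentBias |
static constexpr |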
Initial value:
=
kPadSeqLenK ? 1 : Policy::template GetAlignmentBias<Problem>()
static constexpr bool kPadSeqLenK
Definition block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp:64
◆ kAlignmentK
template<typename Problem_, typename Policy_ = BlockFmhaFwdSplitKVPipelineQRKSVSDefaultPolicy>
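| static constexpr index_t ck_tile::BlockFmhaFwdSplitKVPipelineQRKSVS< Problem_, Policy_ >::kAlignmentK |
static constexpr |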
Initial value:
=
kPadHeadDimQ ? 1 : Policy::template GetAlignmentK<Problem>()
static constexpr index_t kPadHeadDimQ
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp:52
◆ kAlignmentOacc
template<typename Problem_, typename Policy_ = BlockFmhaFwdSplitKVPipelineQRKSVSDefaultPolicy>
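| static constexpr index_t ck_tile::BlockFmhaFwdSplitKVPipelineQRKSVS< Problem_, Policy_ >::kAlignmentOacc |
static constexpr |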
Initial value:
=
kPadHeadDimV ? 1 : Policy::template GetAlignmentOacc<Problem>()
static constexpr bool kPadHeadDimV
Definition block_fmha_fwd_splitkv_combine_pipeline.hpp:65
◆ kAlignmentQ
template<typename Problem_, typename Policy_ = BlockFmhaFwdSplitKVPipelineQRKSVSDefaultPolicy>
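| static constexpr index_t ck_tile::BlockFmhaFwdSplitKVPipelineQRKSVS< Problem_, Policy_ >::kAlignmentQ |
static constexpr |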
Initial value:
=
kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ<Problem>()
◆ kAlignmentV
template<typename Problem_, typename Policy_ = BlockFmhaFwdSplitKVPipelineQRKSVSDefaultPolicy>
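| static constexpr index_t ck_tile::BlockFmhaFwdSplitKVPipelineQRKSVS< Problem_, Policy_ >::kAlignmentV |
static constexpr |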
Initial value:
= []() {
    if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
        return kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
    else
        return kPadSeqLenK ? 1 : Policy::template GetAlignmentV<Problem>();
}()
static constexpr index_t kPadHeadDimV
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp:53
remove_cvref_t< Policy_ > Policy
Definition block_fmha_fwd_appendkv_pipeline.hpp:16
remove_cvref_t< Problem_ > Problem
Definition block_fmha_fwd_appendkv_pipeline.hpp:15
◆ kBlockPerCu
template<typename Problem_, typename Policy_ = BlockFmhaFwdSplitKVPipelineQRKSVSDefaultPolicy>
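| static constexpr index_t ck_tile::BlockFmhaFwdSplitKVPipelineQRKSVS< Problem_, Policy_ >::kBlockPerCu |
static constexpr |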
Initial value:
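= []() {
    if constexpr(Problem::kBlockPerCu != -1)
        return Problem::kBlockPerCu;
    else
    {
        if constexpr(kQKHeaddim <= 32)
        {
            return 2;
        }
        else if constexpr(kQKHeaddim <= 64)
        {
            return 3;
        }
        else if constexpr(kQKHeaddim <= 128)
        {
            if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
                return 1;
            else
                return 2;
        }
        else if constexpr(kQKHeaddim <= 256)
        {
            return 1;
        }
        else
        {
            return 1;
        }
    }
}()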
@ ELEMENTWISE_BIAS
Definition block_attention_bias_enum.hpp:14
static constexpr index_t kQKHeaddim
Definition block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp:46
static constexpr index_t kQKHeaddim
Definition block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp:44
static constexpr auto BiasEnum
Definition block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp:55
◆ kBlockSize
template<typename Problem_, typename Policy_ = BlockFmhaFwdSplitKVPipelineQRKSVSDefaultPolicy>
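| static constexpr index_t ck_tile::BlockFmhaFwdSplitKVPipelineQRKSVS< Problem_, Policy_ >::kBlockSize = Problem::kBlockSize |
static constexpr |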
◆ kHasLogitsSoftCap
template<typename Problem_, typename Policy_ = BlockFmhaFwdSplitKVPipelineQRKSVSDefaultPolicy>
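| static constexpr bool ck_tile::BlockFmhaFwdSplitKVPipelineQRKSVS< Problem_, Policy_ >::kHasLogitsSoftCap = Problem::kHasLogitsSoftCap |
static constexpr |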
◆ kHasUnevenSplits
template<typename Problem_, typename Policy_ = BlockFmhaFwdSplitKVPipelineQRKSVSDefaultPolicy>
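| static constexpr bool ck_tile::BlockFmhaFwdSplitKVPipelineQRKSVS< Problem_, Policy_ >::kHasUnevenSplits = Problem::kHasUnevenSplits |
static constexpr |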
◆ kIsGroupMode
template<typename Problem_, typename Policy_ = BlockFmhaFwdSplitKVPipelineQRKSVSDefaultPolicy>
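| static constexpr bool ck_tile::BlockFmhaFwdSplitKVPipelineQRKSVS< Problem_, Policy_ >::kIsGroupMode = Problem::kIsGroupMode |
static constexpr |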
◆ kIsPagedKV
template<typename Problem_, typename Policy_ = BlockFmhaFwdSplitKVPipelineQRKSVSDefaultPolicy>
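| static constexpr bool ck_tile::BlockFmhaFwdSplitKVPipelineQRKSVS< Problem_, Policy_ >::kIsPagedKV = Problem::kIsPagedKV |
static constexpr |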
◆ kK0
template<typename Problem_, typename Policy_ = BlockFmhaFwdSplitKVPipelineQRKSVSDefaultPolicy>
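| static constexpr index_t ck_tile::BlockFmhaFwdSplitKVPipelineQRKSVS< Problem_, Policy_ >::kK0 = BlockFmhaShape::kK0 |
static constexpr |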
◆ kK1
template<typename Problem_, typename Policy_ = BlockFmhaFwdSplitKVPipelineQRKSVSDefaultPolicy>
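| static constexpr index_t ck_tile::BlockFmhaFwdSplitKVPipelineQRKSVS< Problem_, Policy_ >::kK1 = BlockFmhaShape::kK1 |
static constexpr |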
◆ kM0
template<typename Problem_, typename Policy_ = BlockFmhaFwdSplitKVPipelineQRKSVSDefaultPolicy>
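| static constexpr index_t ck_tile::BlockFmhaFwdSplitKVPipelineQRKSVS< Problem_, Policy_ >::kM0 = BlockFmhaShape::kM0 |
static constexpr |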
◆ kN0
template<typename Problem_, typename Policy_ = BlockFmhaFwdSplitKVPipelineQRKSVSDefaultPolicy>
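| static constexpr index_t ck_tile::BlockFmhaFwdSplitKVPipelineQRKSVS< Problem_, Policy_ >::kN0 = BlockFmhaShape::kN0 |
static constexpr |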
◆ kN1
template<typename Problem_, typename Policy_ = BlockFmhaFwdSplitKVPipelineQRKSVSDefaultPolicy>
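| static constexpr index_t ck_tile::BlockFmhaFwdSplitKVPipelineQRKSVS< Problem_, Policy_ >::kN1 = BlockFmhaShape::kN1 |
static constexpr |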
◆ kPadHeadDimQ
template<typename Problem_, typename Policy_ = BlockFmhaFwdSplitKVPipelineQRKSVSDefaultPolicy>
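| static constexpr bool ck_tile::BlockFmhaFwdSplitKVPipelineQRKSVS< Problem_, Policy_ >::kPadHeadDimQ = Problem::kPadHeadDimQ |
static constexpr |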
◆ kPadHeadDimV
template<typename Problem_, typename Policy_ = BlockFmhaFwdSplitKVPipelineQRKSVSDefaultPolicy>
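| static constexpr bool ck_tile::BlockFmhaFwdSplitKVPipelineQRKSVS< Problem_, Policy_ >::kPadHeadDimV = Problem::kPadHeadDimV |
static constexpr |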
◆ kPadSeqLenK
template<typename Problem_, typename Policy_ = BlockFmhaFwdSplitKVPipelineQRKSVSDefaultPolicy>
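| static constexpr bool ck_tile::BlockFmhaFwdSplitKVPipelineQRKSVS< Problem_, Policy_ >::kPadSeqLenK = Problem::kPadSeqLenK |
static constexpr |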
◆ kPadSeqLenQ
template<typename Problem_, typename Policy_ = BlockFmhaFwdSplitKVPipelineQRKSVSDefaultPolicy>
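| static constexpr bool ck_tile::BlockFmhaFwdSplitKVPipelineQRKSVS< Problem_, Policy_ >::kPadSeqLenQ = Problem::kPadSeqLenQ |
static constexpr |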
◆ kQKHeaddim
template<typename Problem_, typename Policy_ = BlockFmhaFwdSplitKVPipelineQRKSVSDefaultPolicy>
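| static constexpr index_t ck_tile::BlockFmhaFwdSplitKVPipelineQRKSVS< Problem_, Policy_ >::kQKHeaddim = BlockFmhaShape::kQKHeaddim |
static constexpr |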
◆ kQLoadOnce
template<typename Problem_, typename Policy_ = BlockFmhaFwdSplitKVPipelineQRKSVSDefaultPolicy>
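| static constexpr bool ck_tile::BlockFmhaFwdSplitKVPipelineQRKSVS< Problem_, Policy_ >::kQLoadOnce = true |
static constexpr |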
◆ kStoreLSE
template<typename Problem_, typename Policy_ = BlockFmhaFwdSplitKVPipelineQRKSVSDefaultPolicy>
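| static constexpr bool ck_tile::BlockFmhaFwdSplitKVPipelineQRKSVS< Problem_, Policy_ >::kStoreLSE = Problem::kStoreLSE |
static constexpr |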
◆ kSubQKHeaddim
template<typename Problem_, typename Policy_ = BlockFmhaFwdSplitKVPipelineQRKSVSDefaultPolicy>
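| static constexpr index_t ck_tile::BlockFmhaFwdSplitKVPipelineQRKSVS< Problem_, Policy_ >::kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim |
static constexpr |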
◆ name
template<typename Problem_, typename Policy_ = BlockFmhaFwdSplitKVPipelineQRKSVSDefaultPolicy>
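| static constexpr const char * ck_tile::BlockFmhaFwdSplitKVPipelineQRKSVS< Problem_, Policy_ >::name = "qr" |
static constexpr |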
The documentation for this struct was generated from the following file: