BlockFmhaFwdSplitKVCombinePipeline< Problem_, Policy_ > Struct Template Reference
Public Types |
Public Member Functions |
Static Public Member Functions |
Static Public Attributes |
List of all members
ck_tile::BlockFmhaFwdSplitKVCombinePipeline< Problem_, Policy_ > Struct Template Reference
#include <block_fmha_fwd_splitkv_combine_pipeline.hpp>
Public Types | |
| using | Problem = remove_cvref_t<Problem_> |
| using | Policy = remove_cvref_t<Policy_> |
| using | LSEDataType = remove_cvref_t<typename Problem::LSEDataType> |
| using | OaccDataType = remove_cvref_t<typename Problem::OaccDataType> |
| using | ODataType = remove_cvref_t<typename Problem::ODataType> |
Public Member Functions | |
| template<typename LSEaccDramBlockWindowTmp, typename OaccDramBlockWindowTmp, typename LSEDramBlockWindowTmp, typename LSEElementFunction, typename OaccElementFunction> | |
| CK_TILE_HOST_DEVICE auto | operator() (const LSEaccDramBlockWindowTmp &lse_acc_dram_block_window_tmp, const OaccDramBlockWindowTmp &o_acc_dram_block_window_tmp, LSEDramBlockWindowTmp &lse_dram_window_tmp, const LSEElementFunction &lse_element_func, const OaccElementFunction &o_acc_element_func, index_t num_splits, void *smem_ptr) const |
| template<typename LSEaccDramBlockWindow, typename OaccDramBlockWindow, typename LSEDramBlockWindow> | |
| CK_TILE_HOST_DEVICE auto | operator() (const LSEaccDramBlockWindow &lse_acc_dram_block_window, const OaccDramBlockWindow &o_acc_dram_block_window, LSEDramBlockWindow &lse_dram_block_window, index_t num_splits, void *smem_ptr) const |
Static Public Member Functions | |
| static CK_TILE_HOST_DEVICE constexpr ck_tile::index_t | GetSmemSize () |
Static Public Attributes | |
| static constexpr index_t | kNumWarps = Problem::kNumWarps |
| static constexpr index_t | kBlockSize = Problem::kBlockSize |
| static constexpr index_t | kHeadDimV = Problem::kHeadDimV |
| static constexpr index_t | kM0 = Problem::kM0 |
| static constexpr index_t | kN1 = Problem::kN1 |
| static constexpr bool | kIsGroupMode = Problem::kIsGroupMode |
| static constexpr bool | kPadSeqLenQ = Problem::kPadSeqLenQ |
| static constexpr bool | kPadHeadDimV = Problem::kPadHeadDimV |
| static constexpr bool | kStoreLSE = Problem::kStoreLSE |
| static constexpr index_t | kMaxSplits = Problem::kMaxSplits |
| static constexpr index_t | kAlignmentLSE |
| static constexpr index_t | kAlignmentLSEacc = kAlignmentLSE |
| static constexpr index_t | kAlignmentOacc |
| static constexpr index_t | kAlignmentO |
| static constexpr index_t | kBlockPerCu |
| static constexpr const char * | name = "unused" |
Member Typedef Documentation
◆ LSEDataType
template<typename Problem_, typename Policy_ = BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy>
| using ck_tile::BlockFmhaFwdSplitKVCombinePipeline< Problem_, Policy_ >::LSEDataType = remove_cvref_t<typename Problem::LSEDataType> |
◆ OaccDataType
template<typename Problem_, typename Policy_ = BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy>
| using ck_tile::BlockFmhaFwdSplitKVCombinePipeline< Problem_, Policy_ >::OaccDataType = remove_cvref_t<typename Problem::OaccDataType> |
◆ ODataType
template<typename Problem_, typename Policy_ = BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy>
| using ck_tile::BlockFmhaFwdSplitKVCombinePipeline< Problem_, Policy_ >::ODataType = remove_cvref_t<typename Problem::ODataType> |
◆ Policy
template<typename Problem_, typename Policy_ = BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy>
| using ck_tile::BlockFmhaFwdSplitKVCombinePipeline< Problem_, Policy_ >::Policy = remove_cvref_t<Policy_> |
◆ Problem
template<typename Problem_, typename Policy_ = BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy>
| using ck_tile::BlockFmhaFwdSplitKVCombinePipeline< Problem_, Policy_ >::Problem = remove_cvref_t<Problem_> |
Member Function Documentation
◆ GetSmemSize()
template<typename Problem_, typename Policy_ = BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy>
|
inline static constexpr |
◆ operator()() [1/2]
template<typename Problem_, typename Policy_ = BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy>
template<typename LSEaccDramBlockWindow, typename OaccDramBlockWindow, typename LSEDramBlockWindow>
|
inline |
◆ operator()() [2/2]
template<typename Problem_, typename Policy_ = BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy>
template<typename LSEaccDramBlockWindowTmp, typename OaccDramBlockWindowTmp, typename LSEDramBlockWindowTmp, typename LSEElementFunction, typename OaccElementFunction>
|
inline |
Member Data Documentation
◆ kAlignmentLSE
template<typename Problem_, typename Policy_ = BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy>
|
static constexpr |
Initial value:
=
kPadSeqLenQ ? 1 : Policy::template GetAlignmentLSE<Problem>()
static constexpr bool kPadSeqLenQ
Definition block_fmha_fwd_splitkv_combine_pipeline.hpp:64
◆ kAlignmentLSEacc
template<typename Problem_, typename Policy_ = BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy>
|
static constexpr |
◆ kAlignmentO
template<typename Problem_, typename Policy_ = BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy>
|
static constexpr |
Initial value:
=
kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>()
static constexpr bool kPadHeadDimV
Definition block_fmha_fwd_splitkv_combine_pipeline.hpp:65
◆ kAlignmentOacc
template<typename Problem_, typename Policy_ = BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy>
|
static constexpr |
Initial value:
=
kPadHeadDimV ? 1 : Policy::template GetAlignmentOacc<Problem>()
static constexpr bool kPadHeadDimV
Definition block_fmha_fwd_splitkv_combine_pipeline.hpp:65
◆ kBlockPerCu
template<typename Problem_, typename Policy_ = BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy>
|
static constexpr |
Initial value:
= []() {
if constexpr(Problem::kBlockPerCu != -1)
return Problem::kBlockPerCu;
else
{
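// [branch condition elided in this extract; presumably a compile-time test on kHeadDimV, which the initializer references]
{
constexpr std::array occupancy{3, 3, 3, 3, 3, 1};
return occupancy[detail::log2<kMaxSplits>::value - 2];
}
// [branch condition elided in this extract]
{
constexpr std::array occupancy{3, 3, 3, 3, 2, 1};
return occupancy[detail::log2<kMaxSplits>::value - 2];
}
// [branch condition and body elided in this extract]
{
}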
}
}()
Definition block_fmha_fwd_splitkv_combine_pipeline.hpp:48
static constexpr index_t kHeadDimV
Definition block_fmha_fwd_splitkv_combine_pipeline.hpp:59
Definition block_fmha_fwd_splitkv_combine_pipeline.hpp:13
◆ kBlockSize
template<typename Problem_, typename Policy_ = BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy>
|
static constexpr |
◆ kHeadDimV
template<typename Problem_, typename Policy_ = BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy>
|
static constexpr |
◆ kIsGroupMode
template<typename Problem_, typename Policy_ = BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy>
|
static constexpr |
◆ kM0
template<typename Problem_, typename Policy_ = BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy>
|
static constexpr |
◆ kMaxSplits
template<typename Problem_, typename Policy_ = BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy>
|
static constexpr |
◆ kN1
template<typename Problem_, typename Policy_ = BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy>
|
static constexpr |
◆ kNumWarps
template<typename Problem_, typename Policy_ = BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy>
|
static constexpr |
◆ kPadHeadDimV
template<typename Problem_, typename Policy_ = BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy>
|
static constexpr |
◆ kPadSeqLenQ
template<typename Problem_, typename Policy_ = BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy>
|
static constexpr |
◆ kStoreLSE
template<typename Problem_, typename Policy_ = BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy>
|
static constexpr |
◆ name
template<typename Problem_, typename Policy_ = BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy>
|
static constexpr |
The documentation for this struct was generated from the following file: