block_fmha_bwd_pipeline_problem.hpp Source File

block_fmha_bwd_pipeline_problem.hpp Source File

Composable Kernel: block_fmha_bwd_pipeline_problem.hpp Source File
block_fmha_bwd_pipeline_problem.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
7
8namespace ck_tile {
9
10template <typename QDataType_,
11 typename KDataType_,
12 typename VDataType_,
13 typename GemmDataType_,
14 typename LSEDataType_,
15 typename AccDataType_,
16 typename DDataType_,
17 typename BiasDataType_,
18 typename RandValOutputDataType_,
19 typename ODataType_,
20 typename OGradDataType_,
21 typename QGradDataType_,
22 typename KGradDataType_,
23 typename VGradDataType_,
24 typename BiasGradDataType_,
25 typename BlockFmhaShape_,
26 bool kIsGroupMode_,
27 bool kIsDeterministic_,
28 typename FmhaMask_,
29 typename FmhaDropout_,
30 bool kUseTrLoad_,
31 typename Traits_>
33{
53
54 static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size();
55 static constexpr bool kIsGroupMode = kIsGroupMode_;
56 static constexpr bool kIsDeterministic = kIsDeterministic_;
57 static constexpr bool kUseTrLoad = kUseTrLoad_;
58
59 // attributes from traits
60 static constexpr index_t kPadHeadDimQ = Traits::kPadHeadDimQ;
61 static constexpr index_t kPadHeadDimV = Traits::kPadHeadDimV;
62 static constexpr auto BiasEnum = Traits::BiasEnum;
63 static constexpr bool kHasBiasGrad = Traits::kHasBiasGrad;
64 static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
65};
66
67template <typename ODataType_,
68 typename OGradDataType_,
69 typename DDataType_,
70 index_t kBlockSize_,
71 index_t kVHeaddim_,
72 bool kIsGroupMode_,
73 typename Traits_>
75{
80
81 static_assert(0 < kBlockSize_ && kBlockSize_ % get_warp_size() == 0,
82 "kBlockSize should be divisible by get_warp_size()");
83
84 static constexpr index_t kBlockSize = kBlockSize_;
85 static constexpr index_t kVHeaddim = kVHeaddim_;
86 static constexpr bool kIsGroupMode = kIsGroupMode_;
87
88 // attributes from traits
89 static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
90 static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
91 static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
92};
93
94template <typename AccDataType_,
95 typename QGradDataType_,
96 index_t kBlockSize_,
97 index_t kM0_,
98 index_t kN0_,
99 index_t kQKHeaddim_,
100 bool kIsGroupMode_,
101 bool kIsDeterministic_,
102 typename Traits_>
104{
108
109 static_assert(0 < kBlockSize_ && kBlockSize_ % get_warp_size() == 0,
110 "kBlockSize should be divisible by get_warp_size()");
111
112 static constexpr index_t kBlockSize = kBlockSize_;
113 static constexpr index_t kM0 = kM0_;
114 static constexpr index_t kN0 = kN0_;
115 static constexpr index_t kQKHeaddim = kQKHeaddim_;
116 static constexpr bool kIsGroupMode = kIsGroupMode_;
117 static constexpr bool kIsDeterministic = kIsDeterministic_;
118
119 // attributes from traits
120 static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
121 static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ;
122 static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
123};
124
125} // namespace ck_tile
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_HOST_DEVICE constexpr index_t get_warp_size()
Definition arch.hpp:63
int32_t index_t
Definition integer.hpp:9
Definition block_fmha_bwd_pipeline_problem.hpp:104
static constexpr index_t kBlockPerCu
Definition block_fmha_bwd_pipeline_problem.hpp:122
static constexpr index_t kM0
Definition block_fmha_bwd_pipeline_problem.hpp:113
static constexpr index_t kQKHeaddim
Definition block_fmha_bwd_pipeline_problem.hpp:115
static constexpr bool kIsGroupMode
Definition block_fmha_bwd_pipeline_problem.hpp:116
static constexpr bool kPadHeadDimQ
Definition block_fmha_bwd_pipeline_problem.hpp:121
remove_cvref_t< QGradDataType_ > QGradDataType
Definition block_fmha_bwd_pipeline_problem.hpp:106
static constexpr index_t kBlockSize
Definition block_fmha_bwd_pipeline_problem.hpp:112
remove_cvref_t< AccDataType_ > AccDataType
Definition block_fmha_bwd_pipeline_problem.hpp:105
static constexpr bool kPadSeqLenQ
Definition block_fmha_bwd_pipeline_problem.hpp:120
static constexpr bool kIsDeterministic
Definition block_fmha_bwd_pipeline_problem.hpp:117
remove_cvref_t< Traits_ > Traits
Definition block_fmha_bwd_pipeline_problem.hpp:107
static constexpr index_t kN0
Definition block_fmha_bwd_pipeline_problem.hpp:114
Definition block_fmha_bwd_pipeline_problem.hpp:75
remove_cvref_t< ODataType_ > ODataType
Definition block_fmha_bwd_pipeline_problem.hpp:76
remove_cvref_t< OGradDataType_ > OGradDataType
Definition block_fmha_bwd_pipeline_problem.hpp:77
static constexpr bool kPadHeadDimV
Definition block_fmha_bwd_pipeline_problem.hpp:90
remove_cvref_t< Traits_ > Traits
Definition block_fmha_bwd_pipeline_problem.hpp:79
static constexpr index_t kVHeaddim
Definition block_fmha_bwd_pipeline_problem.hpp:85
static constexpr index_t kBlockPerCu
Definition block_fmha_bwd_pipeline_problem.hpp:91
static constexpr bool kPadSeqLenQ
Definition block_fmha_bwd_pipeline_problem.hpp:89
static constexpr bool kIsGroupMode
Definition block_fmha_bwd_pipeline_problem.hpp:86
remove_cvref_t< DDataType_ > DDataType
Definition block_fmha_bwd_pipeline_problem.hpp:78
static constexpr index_t kBlockSize
Definition block_fmha_bwd_pipeline_problem.hpp:84
Definition block_fmha_bwd_pipeline_problem.hpp:33
remove_cvref_t< BiasGradDataType_ > BiasGradDataType
Definition block_fmha_bwd_pipeline_problem.hpp:48
remove_cvref_t< Traits_ > Traits
Definition block_fmha_bwd_pipeline_problem.hpp:52
remove_cvref_t< FmhaMask_ > FmhaMask
Definition block_fmha_bwd_pipeline_problem.hpp:50
remove_cvref_t< GemmDataType_ > GemmDataType
Definition block_fmha_bwd_pipeline_problem.hpp:37
remove_cvref_t< KGradDataType_ > KGradDataType
Definition block_fmha_bwd_pipeline_problem.hpp:46
remove_cvref_t< DDataType_ > DDataType
Definition block_fmha_bwd_pipeline_problem.hpp:40
remove_cvref_t< BiasDataType_ > BiasDataType
Definition block_fmha_bwd_pipeline_problem.hpp:41
remove_cvref_t< FmhaDropout_ > FmhaDropout
Definition block_fmha_bwd_pipeline_problem.hpp:51
static constexpr auto BiasEnum
Definition block_fmha_bwd_pipeline_problem.hpp:62
remove_cvref_t< QGradDataType_ > QGradDataType
Definition block_fmha_bwd_pipeline_problem.hpp:45
static constexpr bool kIsGroupMode
Definition block_fmha_bwd_pipeline_problem.hpp:55
remove_cvref_t< QDataType_ > QDataType
Definition block_fmha_bwd_pipeline_problem.hpp:34
remove_cvref_t< KDataType_ > KDataType
Definition block_fmha_bwd_pipeline_problem.hpp:35
static constexpr bool kHasBiasGrad
Definition block_fmha_bwd_pipeline_problem.hpp:63
static constexpr index_t kPadHeadDimV
Definition block_fmha_bwd_pipeline_problem.hpp:61
remove_cvref_t< LSEDataType_ > LSEDataType
Definition block_fmha_bwd_pipeline_problem.hpp:38
static constexpr index_t kBlockSize
Definition block_fmha_bwd_pipeline_problem.hpp:54
static constexpr bool kUseTrLoad
Definition block_fmha_bwd_pipeline_problem.hpp:57
remove_cvref_t< VDataType_ > VDataType
Definition block_fmha_bwd_pipeline_problem.hpp:36
remove_cvref_t< OGradDataType_ > OGradDataType
Definition block_fmha_bwd_pipeline_problem.hpp:44
remove_cvref_t< AccDataType_ > AccDataType
Definition block_fmha_bwd_pipeline_problem.hpp:39
remove_cvref_t< BlockFmhaShape_ > BlockFmhaShape
Definition block_fmha_bwd_pipeline_problem.hpp:49
static constexpr index_t kBlockPerCu
Definition block_fmha_bwd_pipeline_problem.hpp:64
remove_cvref_t< ODataType_ > ODataType
Definition block_fmha_bwd_pipeline_problem.hpp:43
remove_cvref_t< VGradDataType_ > VGradDataType
Definition block_fmha_bwd_pipeline_problem.hpp:47
remove_cvref_t< RandValOutputDataType_ > RandValOutputDataType
Definition block_fmha_bwd_pipeline_problem.hpp:42
static constexpr index_t kPadHeadDimQ
Definition block_fmha_bwd_pipeline_problem.hpp:60
static constexpr bool kIsDeterministic
Definition block_fmha_bwd_pipeline_problem.hpp:56