rmsnorm2d_fwd_pipeline_model_sensitive_pass.hpp Source File

rmsnorm2d_fwd_pipeline_model_sensitive_pass.hpp Source File

Composable Kernel: rmsnorm2d_fwd_pipeline_model_sensitive_pass.hpp Source File
rmsnorm2d_fwd_pipeline_model_sensitive_pass.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
8#include <string>
9#include <type_traits>
10
11namespace ck_tile {
12
30
31template <typename Problem_, typename Policy_ = Rmsnorm2dFwdPipelineDefaultPolicy>
33{
36
42
45
46 static constexpr bool kHasGamma = !std::is_same_v<GammaDataType, ck_tile::null_type>;
47 static constexpr bool kSaveInvRms = Problem::Traits::kSaveInvRms;
48 static constexpr bool kSaveUnquant = Problem::Traits::kSaveUnquant;
49
50 static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync;
51 static constexpr bool kPadM = false; // TODO - BlockRmsnorm2dFwdProblem::kPadM
52 static constexpr bool kPadN = Problem::Traits::kPadN;
53 static constexpr auto kFusedAdd = Problem::Traits::kFusedAdd;
54 static constexpr auto kFusedQuant = Problem::Traits::kFusedQuant;
55
56 static constexpr const char* name = []() {
57 if constexpr(kNeedCrossWarpSync)
58 return "bpr_op"; // block per row
59 else
60 return "wpr_op"; // warp per row
61 }();
62
64 {
65 return Policy::template GetSmemSize<Problem>();
66 }
67
68 template <typename XWindow,
69 typename XResidualWindow,
70 typename GammaWindow,
71 typename YWindow,
72 typename YResidualWindow,
73 typename InvRmsWindow,
74 typename SmoothScaleWindow,
75 typename YScaleWindow,
76 typename UnquantYWindow,
77 typename Epilogue>
78 CK_TILE_DEVICE auto operator()(const XWindow& x_window_,
79 const XResidualWindow& x_residual_window_,
80 const GammaWindow& gamma_window_,
81 YWindow& y_window_,
82 const YResidualWindow& y_residual_window_,
83 InvRmsWindow& inv_rms_window,
84 const SmoothScaleWindow& sm_scale_window_,
85 YScaleWindow& y_scale_window_,
86 UnquantYWindow& unquant_y_window,
87 ComputeDataType epsilon,
88 ck_tile::index_t row_size,
89 void* smem,
90 Epilogue) const
91 {
92 const auto x_window =
93 make_tile_window(x_window_, Policy::template MakeXBlockTileDistribution<Problem>());
94 const auto gamma_window = make_tile_window(
95 gamma_window_, Policy::template MakeGammaBlockTileDistribution<Problem>());
96 const auto x_residual_window = make_tile_window(
97 x_residual_window_, Policy::template MakeXBlockTileDistribution<Problem>());
98 auto y_residual_window = make_tile_window(
99 y_residual_window_, Policy::template MakeXBlockTileDistribution<Problem>());
100
101 auto reduce_square_sum_func = ReduceOp::SquareAdd{};
102 auto reduce_sum_func = ReduceOp::Add{};
103 auto block_reduce2d = Policy::template GetBlockReduce2d<Problem>();
104 auto block_reduce2d_sync = Policy::template GetBlockReduce2dSync<Problem>();
105 auto block_reduce2d_cross_warp_sync =
106 Policy::template GetBlockReduce2dCrossWarpSync<Problem>();
107
108 auto x = load_tile(x_window);
109 auto x_resi = load_tile(x_residual_window);
110
111 // load gamma (TODO: support no gamma?)
112 const auto gamma = load_tile(gamma_window);
113
114 auto acc = cast_tile<ComputeDataType>(x);
115
118 {
119 [[maybe_unused]] auto pre_out =
120 make_static_distributed_tensor<YResidualDataType>(x.get_tile_distribution());
121
122 sweep_tile(x_resi, [&](auto idx) {
123 // compute x = x_resi + x
124 acc(idx) = type_convert<ComputeDataType>(x_resi(idx)) + acc(idx);
125
126 // To make norm input align with residual output
128 {
129 if constexpr(std::is_same_v<YResidualDataType, ck_tile::bf16_t>)
130 {
131 pre_out(idx) = float_to_bf16<bf16_rounding_mode::standard>(acc(idx));
132 }
133 else
134 {
135 pre_out(idx) = type_convert<YResidualDataType>(acc(idx));
136 }
137 acc(idx) = type_convert<ComputeDataType>(pre_out(idx));
138 }
139 });
141 {
142 store_tile(y_residual_window, pre_out);
143 }
144 }
145
146 // compute mean square each-thread->cross-lane->cross-warp
147 auto square_sum = block_reduce2d.template MakeYBlockTile<decltype(acc)>();
148 set_tile(square_sum, 0);
149 if constexpr((Problem::BlockShape::Repeat_N * Problem::BlockShape::Vector_N) % 2 == 0)
150 {
152 acc,
153 [&](auto idx_0, auto idx_1) {
154 square_sum(idx_0) += acc[idx_0] * acc[idx_0] + acc[idx_1] * acc[idx_1];
155 },
157 }
158 else
159 {
160 square_sum = block_reduce2d(acc,
161 reduce_square_sum_func.GetIdentityValue<ComputeDataType>(),
162 reduce_square_sum_func);
163 }
164 block_reduce2d_sync(square_sum, reduce_sum_func);
165 block_reduce2d_cross_warp_sync(square_sum, smem, reduce_sum_func);
166
167 // compute inv-rms
168 auto inv_rms = tile_elementwise_in(
169 [&](const auto& v_) { return rsqrtf(v_ / row_size + epsilon); }, square_sum);
170
171 if constexpr(kSaveInvRms)
172 store_tile(inv_rms_window, cast_tile<InvRmsDataType>(inv_rms));
173
174 // rmsnorm computation
175 auto rmsn = make_static_distributed_tensor<ComputeDataType>(x.get_tile_distribution());
176 sweep_tile(rmsn, [&, inv_rms_ = inv_rms](auto idx) {
177 constexpr auto i_idx = make_tuple(idx[number<0>{}]);
178 constexpr auto j_idx = make_tuple(idx[number<1>{}]);
179
180 const auto gamma_ = type_convert<ComputeDataType>(gamma[j_idx]);
181
182 if constexpr(std::is_same_v<XDataType, ck_tile::bf16_t>)
183 {
184 const auto tmp0 =
185 float_to_bf16<bf16_rounding_mode::standard>(acc[idx] * inv_rms_[i_idx]);
187 type_convert<ComputeDataType>(tmp0) * gamma_);
188 const auto rmsn_ = type_convert<ComputeDataType>(tmp1);
189 rmsn(idx) = rmsn_;
190 }
191 else
192 {
193 const auto tmp = type_convert<XDataType>(acc[idx] * inv_rms_[i_idx]);
194 const auto rmsn_ = type_convert<ComputeDataType>(tmp) * gamma_;
195 rmsn(idx) = rmsn_;
196 }
197 });
198
200 {
201 if constexpr(kSaveUnquant)
202 {
203 Epilogue{}(
204 unquant_y_window, y_window_, sm_scale_window_, y_scale_window_, rmsn, smem);
205 }
206 else
207 {
208 Epilogue{}(y_window_, sm_scale_window_, y_scale_window_, rmsn, smem);
209 }
210 }
212 {
213 if constexpr(kSaveUnquant)
214 {
215 Epilogue{}(unquant_y_window, y_window_, y_scale_window_, rmsn, smem);
216 }
217 else
218 {
219 Epilogue{}(y_window_, y_scale_window_, rmsn, smem);
220 }
221 }
222 else
223 {
224 Epilogue{}(y_window_, rmsn, nullptr);
225 }
226 }
227};
228} // namespace ck_tile
#define CK_TILE_DEVICE
Definition config.hpp:41
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_HOST_DEVICE constexpr bfloat16_t float_to_bf16(float f, constant< rounding >={})
Definition bfloat16.hpp:284
CK_TILE_DEVICE auto tile_elementwise_in(const InElementFunc &in_element_func, const InTensor &... in_dstr_tensors)
Definition tile_elementwise.hpp:40
CK_TILE_DEVICE void set_tile(DstrTensors &dstr_tensor, const T &value)
Definition tile_elementwise.hpp:95
@ SMOOTH_DYNAMIC_QUANT
Definition rmsnorm2d_fwd_traits.hpp:29
@ DYNAMIC_QUANT
Definition rmsnorm2d_fwd_traits.hpp:30
@ PRE_ADD_STORE
Definition rmsnorm2d_fwd_traits.hpp:14
@ PRE_ADD
Definition rmsnorm2d_fwd_traits.hpp:16
CK_TILE_HOST_DEVICE constexpr void sweep_tile(const F &f, UnpacksPerXDim={})
Definition sweep_tile.hpp:231
CK_TILE_HOST_DEVICE constexpr auto make_static_distributed_tensor(const StaticTileDistribution &)
Definition static_distributed_tensor.hpp:142
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_DEVICE constexpr auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition null_tile_window.hpp:75
CK_TILE_DEVICE auto cast_tile(const SrcTensor &src_tensor)
Definition tile_elementwise.hpp:327
CK_TILE_DEVICE void store_tile(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition store_tile.hpp:23
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
Definition tile/core/numeric/type_convert.hpp:29
CK_TILE_DEVICE auto load_tile(const TileWindow_ &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={})
Definition load_tile.hpp:22
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
Definition reduce_operator.hpp:14
Definition reduce_operator.hpp:40
This T5Pass implements the RMSNorm2d forward pipeline as a variant based on Rmsnorm2dFwdPipelineOnePass… (brief description truncated in this rendering)
Definition rmsnorm2d_fwd_pipeline_model_sensitive_pass.hpp:33
static CK_TILE_HOST_DEVICE constexpr index_t GetSmemSize()
Definition rmsnorm2d_fwd_pipeline_model_sensitive_pass.hpp:63
static constexpr auto kFusedAdd
Definition rmsnorm2d_fwd_pipeline_model_sensitive_pass.hpp:53
static constexpr bool kSaveInvRms
Definition rmsnorm2d_fwd_pipeline_model_sensitive_pass.hpp:47
ck_tile::remove_cvref_t< Problem_ > Problem
Definition rmsnorm2d_fwd_pipeline_model_sensitive_pass.hpp:34
ck_tile::remove_cvref_t< Policy_ > Policy
Definition rmsnorm2d_fwd_pipeline_model_sensitive_pass.hpp:35
ck_tile::remove_cvref_t< typename Problem::ComputeDataType > ComputeDataType
Definition rmsnorm2d_fwd_pipeline_model_sensitive_pass.hpp:39
ck_tile::remove_cvref_t< typename Problem::XDataType > XDataType
Definition rmsnorm2d_fwd_pipeline_model_sensitive_pass.hpp:37
ck_tile::remove_cvref_t< typename Problem::GammaDataType > GammaDataType
Definition rmsnorm2d_fwd_pipeline_model_sensitive_pass.hpp:38
ck_tile::remove_cvref_t< typename Problem::YDataType > YDataType
Definition rmsnorm2d_fwd_pipeline_model_sensitive_pass.hpp:40
static constexpr auto kFusedQuant
Definition rmsnorm2d_fwd_pipeline_model_sensitive_pass.hpp:54
XDataType XResidualDataType
Definition rmsnorm2d_fwd_pipeline_model_sensitive_pass.hpp:43
ck_tile::remove_cvref_t< typename Problem::InvRmsDataType > InvRmsDataType
Definition rmsnorm2d_fwd_pipeline_model_sensitive_pass.hpp:41
CK_TILE_DEVICE auto operator()(const XWindow &x_window_, const XResidualWindow &x_residual_window_, const GammaWindow &gamma_window_, YWindow &y_window_, const YResidualWindow &y_residual_window_, InvRmsWindow &inv_rms_window, const SmoothScaleWindow &sm_scale_window_, YScaleWindow &y_scale_window_, UnquantYWindow &unquant_y_window, ComputeDataType epsilon, ck_tile::index_t row_size, void *smem, Epilogue) const
Definition rmsnorm2d_fwd_pipeline_model_sensitive_pass.hpp:78
static constexpr bool kNeedCrossWarpSync
Definition rmsnorm2d_fwd_pipeline_model_sensitive_pass.hpp:50
static constexpr bool kPadM
Definition rmsnorm2d_fwd_pipeline_model_sensitive_pass.hpp:51
static constexpr bool kSaveUnquant
Definition rmsnorm2d_fwd_pipeline_model_sensitive_pass.hpp:48
static constexpr bool kPadN
Definition rmsnorm2d_fwd_pipeline_model_sensitive_pass.hpp:52
static constexpr const char * name
Definition rmsnorm2d_fwd_pipeline_model_sensitive_pass.hpp:56
XDataType YResidualDataType
Definition rmsnorm2d_fwd_pipeline_model_sensitive_pass.hpp:44
static constexpr bool kHasGamma
Definition rmsnorm2d_fwd_pipeline_model_sensitive_pass.hpp:46
Definition tile/core/container/sequence.hpp:49