fused_moegemm_kernel.hpp Source File

fused_moegemm_kernel.hpp Source File

Composable Kernel: fused_moegemm_kernel.hpp Source File
fused_moegemm_kernel.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
9#include <string>
10#include <type_traits>
11
12// clang-format off
13// [indexing implementation-1]
14// using M_a as constexpr block_size to partition all tokens into different slices
15// each slice maps to one expert, and one expert can have multiple slices
16// e.g. num_experts = 6, topk=3, M_a = 4, input_tokens = 5
17// before sort, topk_ids is : [[0, 3, 5], [2, 3, 5], [1, 3, 5], [1, 2, 3], [1, 3, 5]]
18// tok-0 tok-1 tok-2 tok-3 tok-4
19// topk_weight is : [[a, b, c], [d, e, f], [g, h, i], [j, k, l], [m, n, o]] (some float number)
20//
21// token_id_per_expert is : [[0], [2, 3, 4], [1, 3], [0, 1, 2, 3, 4], [], [0, 1, 2, 4]]
22// (only for reference) exp-0 exp-1 exp-2 exp-3 exp-4 exp-5
23// weight_id_per_expert is: [[a], [g, j, m], [d, k], [b, e, h, l, n], [], [c, f, i, o]]
24//
25// max_num_tokens_padded : topk * input_tokens + num_experts * M_a - topk (updated)
26// * this could be larger than actual, since actual tokens are on GPU
27//
28// sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6, 0, 1, 2, 4]
29// |- exp-0 -|- exp-1 -|- exp-2 -|- exp-3 -|- exp-4 -|- exp-5 -|
30// sorted_weight_ptr : [a, *, *, *, g, j, m, *, d, k, *, *, b, e, h, l, n, *, *, *, *, *, *, *, c, f, i, o]
31//
32// * length is max_num_tokens_padded, actual used size is given by num_tokens_post_padded_ptr
33//
34// * Note on token_id_per_expert/sorted_token_ids_ptr data:
35// currently we do not have topk information from the data of token_id_per_expert/sorted_token_ids_ptr.
36// In some cases (like smooth-quant), we need topk information to index into the quantized tokens, since each
37// expert applies its own smooth-quant scale. So we modify the number stored inside token_id_per_expert/sorted_token_ids_ptr
38//
39// 32bit 0........23 24.....31 bit
40// (data) -> (token_id | topk_id)
41// the low 24 bits hold the token id, the top 8 bits hold the topk id
42//
43// the input after smooth-quant is [token, topk, hidden_dim], originally it is [token, hidden_dim]
44// the input scale for token is [topk, token, 1], the smooth-quant scale for first gemm is [expert, interm_dim]
45//
46// sorted_expert_ids_ptr : [0, 1, 2, 3, 3, 4, 5]
47// * length is (max_num_tokens_padded + block_size - 1) / block_size
48//
49// num_tokens_post_padded_ptr : [28]
50// num_sorted_tiles_ptr : [7]
51//
52// * different from vLLM
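53// 1) token_id stored in sorted_token_ids_ptr is the actual token_id, not the token_id*top_K expanded id
54// 2) need sorted_weight_ptr
55// 3) use num_sorted_tiles_ptr, already divided by M_a (NOTE(review): the UseUK kernel path divides the value read from num_sorted_tiles_ptr by Block_M0 again; confirm which unit is stored)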
56//
57// * below used for indexing
58// 1) sorted_token_ids_ptr [max_num_tokens_padded]
59// 2) sorted_weight_ptr
60// 3) sorted_expert_ids_ptr
61// 4) num_tokens_post_padded_ptr/num_sorted_tiles_ptr (select one)
62//
63// max_num_tokens_padded: topk_ids.numel() + num_experts * (block_size - 1)
64//
65// [indexing implementation-2]
66// before sort, topk_ids is : [[0, 3, 5], [2, 3, 5], [1, 3, 5], [1, 2, 3], [1, 3, 5]]
67// tok-0 tok-1 tok-2 tok-3 tok-4
68// topk_weight is : [[a, b, c], [d, e, f], [g, h, i], [j, k, l], [m, n, o]] (some float number)
69//
70// we generate the original row/col id as
71// topk_rc_ids : [[0, 5, A], [1, 6, B], [2, 7, C], [3, 8, D], [4, 9, E]]
72// let x be one element of above, we can get:
73// topk_row_id(token_id) = x % num_tokens(5)
74// topk_col_id(topk_idx) = x / num_tokens
75// topk_row_id/col_id can be used to access original topk_ids/topk_weight
76//
77// token_id_per_expert is : [[0], [2, 3, 4], [1, 3], [0, 1, 2, 3, 4], [], [0, 1, 2, 4]]
78// (only for reference) exp-0 exp-1 exp-2 exp-3 exp-4 exp-5
79// weight_id_per_expert is: [[a], [g, j, m], [d, k], [b, e, h, l, n], [], [c, f, i, o]]
80//
81// we can get permuted_rc_ids:
82// [[0], [2, 3, 4], [1, 8], [5, 6, 7, D, 9], [], [A, B, C, E]]
83//
84//
85// clang-format on
86//
87namespace ck_tile {
88
89// m: num_tokens (or token*input-batch)
90// k: hidden_size
91// n: intermediate_size used between 2 FC (TP slice this)
92// e: num expert
93// if doing pre-shuffle
94// nr : n / Block_Nr
95// kr : k / Block_Kr
96// w : flattened 1d wave buffer
98{
99 const void* a_ptr; // [m, k], input token
100 const void* a_scale_ptr; // [m, 1], token scale
101 const void* g_ptr; // [e, n, k]/[e, 2*n, k], pre-shuffle([e, nr, kr, w])
102 const void* d_ptr; // [e, n, k], pre-shuffle([e, nr, kr, w])
103 const void* g_scale_ptr; // [e, 1, n], gate(up) scale
104 const void* d_scale_ptr; // [e, 1, k], down scale
105 const void* y_smooth_scale_ptr; // [e, 1, n], smooth-quant-scale for 2nd gemm input
106 void* o_ptr; // [m, k], output token
107
108 const void* sorted_token_ids_ptr; // [max_num_tokens_padded]
109 const void* sorted_weight_ptr; // [max_num_tokens_padded]
110 const void* sorted_expert_ids_ptr; // [(max_num_tokens_padded + block_size - 1) / block_size]
111 const void* num_sorted_tiles_ptr; // [1]
112
114 index_t intermediate_size; // n / TP, for Gate/UP/Down
115 index_t num_tokens; // input number of tokens for current iteration
116 index_t num_experts; // number of groups
117 index_t topk; // need this?
118
119 index_t stride_token; // for input/output, stride for each row, should >= hidden_size
120};
121
122// This is scatter/gather b2b group-gemm
123template <typename Partitioner_, typename Pipeline_, typename Epilogue_>
125{
128 using Epilogue = remove_cvref_t<Epilogue_>; // TODO: not used
129 // static constexpr index_t kBlockPerCu = Pipeline::kBlockPerCu;
130 // static_assert(kBlockPerCu > 0);
131
132 using BlockShape = typename Pipeline::BlockShape; // this is FusedMoeGemmShape
133 static constexpr index_t kBlockSize = BlockShape::BlockSize;
134
135 using ADataType = typename Pipeline::Problem::ADataType;
136 using GDataType = typename Pipeline::Problem::GDataType;
137 using DDataType = typename Pipeline::Problem::DDataType;
138 using AccDataType = typename Pipeline::Problem::AccDataType;
139 using ODataType = typename Pipeline::Problem::ODataType;
140 using AScaleDataType = typename Pipeline::Problem::AScaleDataType;
141 using GScaleDataType = typename Pipeline::Problem::GScaleDataType;
142 using DScaleDataType = typename Pipeline::Problem::DScaleDataType;
143 using YSmoothScaleDataType = typename Pipeline::Problem::YSmoothScaleDataType;
144 using TopkWeightDataType = typename Pipeline::Problem::TopkWeightDataType;
145 using IndexDataType = typename Pipeline::Problem::IndexDataType;
146 using YDataType = typename Pipeline::Problem::YDataType;
147
148 using Traits = typename Pipeline::Problem::Traits;
149 static constexpr bool UseUK = true;
150
151 static constexpr bool IsGateOnly = Traits::IsGateOnly;
152 static constexpr bool UseSmoothQuant = Traits::UseSmoothQuant;
153 static constexpr bool PadHiddenSize = Traits::PadHiddenSize;
154 static constexpr bool PadIntermediateSize = Traits::PadIntermediateSize;
155
156 // clang-format off
157 template <typename T> struct t2s;
158 template <> struct t2s<float> { static constexpr const char * name = "fp32"; };
159 template <> struct t2s<fp16_t> { static constexpr const char * name = "fp16"; };
160 template <> struct t2s<bf16_t> { static constexpr const char * name = "bf16"; };
161 template <> struct t2s<fp8_t> { static constexpr const char * name = "fp8"; };
162 template <> struct t2s<bf8_t> { static constexpr const char * name = "bf8"; };
163 template <> struct t2s<int8_t> { static constexpr const char * name = "int8"; };
164 // clang-format on
165
166 CK_TILE_HOST static std::string GetName()
167 {
168#define _SS_ std::string
169#define _TS_ std::to_string
170 // clang-format off
171 using S_ = BlockShape;
172
173 auto prec_str = [&] () {
174 std::string base_str = _SS_(t2s<ADataType>::name);
175 if (!std::is_same_v<ADataType, GDataType>) {
176 base_str += _SS_("_") + _SS_(t2s<GDataType>::name);
177 }
178 return base_str;
179 }();
180
181 return _SS_("fused_moe_") + _SS_(prec_str) + "_" + (IsGateOnly ? "g1u0_":"g1u1_") +
182 _TS_(S_::Block_M0) + "x" + _TS_(S_::Block_N0) + "x" + _TS_(S_::Block_K0) + "x" + _TS_(S_::Block_N1) + "_" +
183 _TS_(S_::WarpPerBlock_M0) + "x" + _TS_(S_::WarpPerBlock_N0) + "x" + _TS_(S_::WarpPerBlock_K0) + "_" +
184 _TS_(S_::Warp_M0) + "x" + _TS_(S_::Warp_N0) + "x" + _TS_(S_::Warp_K0) + "_" + _SS_(Pipeline::name);
185#undef _SS_
186#undef _TS_
187 // clang-format on
188 }
189
191 {
192 const void* a_ptr; // [m, k], input token
193 const void* a_scale_ptr; // [m, 1], token scale
194 const void* g_ptr; // [e, n, k]/[e, 2*n, k], pre-shuffle([e, nr, kr, w])
195 const void* d_ptr; // [e, n, k], pre-shuffle([e, nr, kr, w])
196 const void* g_scale_ptr; // [e, 1, n], gate(up) scale
197 const void* d_scale_ptr; // [e, 1, k], down scale
198 const void* y_smooth_scale_ptr; // [e, 1, n], smooth-quant-scale for 2nd gemm input
199 void* o_ptr; // [m, k], output token
200
202 const void* sorted_weight_ptr;
205
207 index_t intermediate_size; // n / TP, for Gate/Up/Down
208 index_t num_tokens; // input number of tokens for current iteration
209 index_t num_experts; // number of groups
210 index_t topk; // need this?
211
212 index_t stride_token; // for input/output, stride for each row, should >= hidden_size
213 };
214
215 // TODO: switch karg based on
218
219 CK_TILE_HOST static constexpr Kargs MakeKargs(const Hargs& hargs)
220 {
221 // TODO: hargs/kargs not guranteed to be the same
222 return bit_cast<Kargs>(hargs);
223 }
224
225 CK_TILE_HOST static constexpr auto GridSize(const Hargs& hargs)
226 {
227 constexpr index_t block_m = BlockShape::Block_M0;
228 int max_num_tokens_padded =
229 hargs.topk * hargs.num_tokens + hargs.num_experts * block_m - hargs.topk;
230 // printf("xxx max_num_tokens_padded:%d\n", max_num_tokens_padded);
231 return Partitioner::GridSize(max_num_tokens_padded, hargs.intermediate_size);
232 }
233
234 CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
235
236 CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return Pipeline::GetSmemSize(); }
237
239 {
240 if constexpr(UseUK)
241 {
242 __shared__ CK_TILE_LDS_ADDR char smem[GetSmemSize()];
243 IndexDataType num_sorted_tiles = amd_wave_read_first_lane(
244 *reinterpret_cast<const IndexDataType*>(kargs.num_sorted_tiles_ptr));
245
246 num_sorted_tiles = num_sorted_tiles / BlockShape::Block_M0;
247
248 const auto [sorted_tile_id, intermediate_tile_id] =
249 Partitioner{}(num_sorted_tiles, kargs.intermediate_size);
250 // if(threadIdx.x == 0)
251 // printf("bid:%d,%d, num_sorted_tiles:%d, sorted_tile_id:%d(%d),
252 // intermediate_tile_id:%d\n", static_cast<int>(blockIdx.x),
253 // static_cast<int>(blockIdx.y), num_sorted_tiles, sorted_tile_id, sorted_tile_id >=
254 // num_sorted_tiles? 1 : 0, intermediate_tile_id);
255 if(sorted_tile_id >= num_sorted_tiles)
256 return;
257
258 Pipeline{}(kargs, smem, sorted_tile_id, intermediate_tile_id);
259 }
260 else
261 {
262 // allocate LDS
263 // __shared__ char smem_ptr[GetSmemSize()];
264 IndexDataType num_sorted_tiles = amd_wave_read_first_lane(
265 *reinterpret_cast<const IndexDataType*>(kargs.num_sorted_tiles_ptr));
266 constexpr index_t hidden_radio_0 = IsGateOnly ? 1 : 2;
267
268 index_t nr_0 = kargs.intermediate_size / BlockShape::Block_Nr0;
269 index_t kr_0 = kargs.hidden_size / BlockShape::Block_Kr0;
270 index_t nr_1 = kargs.hidden_size / BlockShape::Block_Nr1; // should be same as kr_0
271 index_t kr_1 =
272 kargs.intermediate_size / BlockShape::Block_Kr1; // should be same as nr_0
273
274 index_t expert_stride_0 = kargs.intermediate_size * hidden_radio_0 * kargs.hidden_size;
275 index_t expert_stride_1 = kargs.intermediate_size * kargs.hidden_size;
276
277 __shared__ CK_TILE_LDS_ADDR ADataType smem[GetSmemSize()];
278
279 // note this is in unit of tile, need multiple tile size to get the index
280 const auto [sorted_tile_id, intermediate_tile_id] =
281 Partitioner{}(num_sorted_tiles, kargs.intermediate_size);
282 if(sorted_tile_id >= num_sorted_tiles)
283 return;
284
285 const IndexDataType expert_id =
286 amd_wave_read_first_lane(reinterpret_cast<const IndexDataType*>(
287 kargs.sorted_expert_ids_ptr)[sorted_tile_id]);
288
289 // index along intermediate_size
290 // index_t hidden_idx = __builtin_amdgcn_readfirstlane(intermediate_tile_id *
291 // BlockShape::Block_N0);
292 index_t interm_idx_nr =
293 amd_wave_read_first_lane(intermediate_tile_id * BlockShape::Block_Nr0);
294
295 const auto a_coord = Pipeline::GetACoord(); // 2d thread offset, [i_row, i_col]
296 const auto sorted_token_id =
297 a_coord[number<0>{}] + sorted_tile_id * BlockShape::Block_M0;
298
299 index_t token_id =
300 reinterpret_cast<const index_t*>(kargs.sorted_token_ids_ptr)[sorted_token_id];
301#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
302 token_id &= 0xffffff;
303#endif
304 auto topk_weight = reinterpret_cast<const TopkWeightDataType*>(
305 kargs.sorted_weight_ptr)[sorted_token_id];
306
307 const auto a_window = [&]() {
308 // A is already pre-padded in previous kernel
309 const ADataType* a_ptr = reinterpret_cast<const ADataType*>(kargs.a_ptr);
311 a_ptr,
312 make_tuple(kargs.num_tokens, kargs.hidden_size),
313 make_tuple(kargs.stride_token, 1),
315 number<1>{});
316
317 // gather is here use indexing transform
318 const auto a_gather_view_ = transform_tensor_view(
319 a_view_,
324
325 const auto a_window_ = make_tile_window(
326 a_gather_view_,
328 {0, 0});
329 return a_window_;
330 }();
331
332 // TODO: gtile using NSub to have less register pressure
333 const auto g_window = [&]() {
334 const GDataType* g_ptr = reinterpret_cast<const GDataType*>(kargs.g_ptr) +
335 static_cast<long_index_t>(expert_id) * expert_stride_0 +
336 interm_idx_nr * kr_0 * BlockShape::Block_W0;
338 g_ptr,
340 make_tuple(kr_0 * BlockShape::Block_W0, number<BlockShape::Block_W0>{}, 1),
342 number<1>{});
343 const auto g_view_1_ =
344 pad_tensor_view(g_view_,
349
350 const auto g_window_ = make_tile_window(g_view_1_,
354 {0, 0, 0});
355 return g_window_;
356 }();
357
358 const auto d_window = [&]() {
359 const DDataType* d_ptr = reinterpret_cast<const DDataType*>(kargs.d_ptr) +
360 static_cast<long_index_t>(expert_id) * expert_stride_1 +
361 interm_idx_nr * BlockShape::Block_W1;
362 // note interm_idx_nr is along the gemm-k dim of 2nd gemm
363
365 d_ptr,
366 make_tuple(nr_1, kr_1, BlockShape::Block_W1),
367 make_tuple(kr_1 * BlockShape::Block_W1, BlockShape::Block_W1, 1),
369 number<1>{});
370 const auto d_view_1_ =
371 pad_tensor_view(d_view_,
376
377 const auto d_window_ = make_tile_window(d_view_1_,
381 {0, 0, 0});
382 return d_window_;
383 }();
384
385 auto o_window = [&]() {
386 ODataType* o_ptr = reinterpret_cast<ODataType*>(kargs.o_ptr);
389 o_ptr,
390 make_tuple(kargs.num_tokens, kargs.hidden_size),
391 make_tuple(kargs.stride_token, 1),
393 number<1>{});
394
395 // gather is here
396 auto o_scatter_view_ = transform_tensor_view(
397 o_view_,
402
403 auto o_window_ = make_tile_window(
404 o_scatter_view_,
406 {0, 0});
407 return o_window_;
408 }();
409
410 // do compute yeah
411 Pipeline{}(a_window,
412 g_window,
413 d_window,
414 o_window,
415 topk_weight,
416 smem,
417 kargs.hidden_size,
418 kargs.intermediate_size,
419 kargs.stride_token);
420 }
421 }
422};
423
424} // namespace ck_tile
#define _TS_
#define _SS_
#define CK_TILE_DEVICE
Definition config.hpp:41
#define CK_TILE_HOST
Definition config.hpp:40
#define CK_TILE_LDS_ADDR
Definition config.hpp:58
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_HOST_DEVICE constexpr auto make_naive_tensor_view(DataType *__restrict__ p, const tuple< Lengths... > &lengths, const tuple< Strides... > &strides, number< GuaranteedLastDimensionVectorLength >=number<-1 >{}, number< GuaranteedLastDimensionVectorStride >=number<-1 >{})
Definition tensor_view.hpp:471
CK_TILE_HOST_DEVICE constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition coordinate_transform.hpp:1558
__device__ uint32_t amd_wave_read_first_lane(uint16_t v)
Definition tile/core/arch/amd_buffer_addressing.hpp:35
@ atomic_add
Definition arch.hpp:58
int8_t int8_t
Definition int8.hpp:20
bfloat16_t bf16_t
Definition bfloat16.hpp:113
_Float16 fp16_t
Definition half.hpp:110
_BitInt(8) fp8_t
Definition float8.hpp:204
CK_TILE_HOST_DEVICE constexpr Y bit_cast(const X &x)
Definition bit_cast.hpp:11
int64_t long_index_t
Definition integer.hpp:11
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_DEVICE constexpr auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition null_tile_window.hpp:75
CK_TILE_HOST_DEVICE constexpr auto pad_tensor_view(const TensorView &tensor_view, const TileLengths &tile_lengths, DoPads)
Definition tensor_view.hpp:530
unsigned _BitInt(8) bf8_t
Definition float8.hpp:206
@ global
Definition arch.hpp:48
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr auto transform_tensor_view(const OldTensorView &old_tensor_view, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_view.hpp:511
CK_TILE_HOST_DEVICE constexpr auto make_indexing_transform(const UpLength &up_lengths, const Indices &indices)
Definition coordinate_transform.hpp:1680
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
Definition fused_moegemm_kernel.hpp:98
const void * a_ptr
Definition fused_moegemm_kernel.hpp:99
index_t num_tokens
Definition fused_moegemm_kernel.hpp:115
const void * sorted_expert_ids_ptr
Definition fused_moegemm_kernel.hpp:110
void * o_ptr
Definition fused_moegemm_kernel.hpp:106
const void * num_sorted_tiles_ptr
Definition fused_moegemm_kernel.hpp:111
const void * a_scale_ptr
Definition fused_moegemm_kernel.hpp:100
const void * g_scale_ptr
Definition fused_moegemm_kernel.hpp:103
const void * sorted_weight_ptr
Definition fused_moegemm_kernel.hpp:109
const void * d_ptr
Definition fused_moegemm_kernel.hpp:102
const void * d_scale_ptr
Definition fused_moegemm_kernel.hpp:104
index_t topk
Definition fused_moegemm_kernel.hpp:117
index_t num_experts
Definition fused_moegemm_kernel.hpp:116
const void * sorted_token_ids_ptr
Definition fused_moegemm_kernel.hpp:108
const void * g_ptr
Definition fused_moegemm_kernel.hpp:101
index_t intermediate_size
Definition fused_moegemm_kernel.hpp:114
index_t hidden_size
Definition fused_moegemm_kernel.hpp:113
const void * y_smooth_scale_ptr
Definition fused_moegemm_kernel.hpp:105
index_t stride_token
Definition fused_moegemm_kernel.hpp:119
Definition fused_moegemm_kernel.hpp:191
index_t topk
Definition fused_moegemm_kernel.hpp:210
void * o_ptr
Definition fused_moegemm_kernel.hpp:199
const void * sorted_expert_ids_ptr
Definition fused_moegemm_kernel.hpp:203
index_t intermediate_size
Definition fused_moegemm_kernel.hpp:207
index_t hidden_size
Definition fused_moegemm_kernel.hpp:206
const void * y_smooth_scale_ptr
Definition fused_moegemm_kernel.hpp:198
const void * a_ptr
Definition fused_moegemm_kernel.hpp:192
index_t num_tokens
Definition fused_moegemm_kernel.hpp:208
const void * g_scale_ptr
Definition fused_moegemm_kernel.hpp:196
const void * d_ptr
Definition fused_moegemm_kernel.hpp:195
index_t num_experts
Definition fused_moegemm_kernel.hpp:209
const void * a_scale_ptr
Definition fused_moegemm_kernel.hpp:193
const void * sorted_weight_ptr
Definition fused_moegemm_kernel.hpp:202
const void * g_ptr
Definition fused_moegemm_kernel.hpp:194
index_t stride_token
Definition fused_moegemm_kernel.hpp:212
const void * num_sorted_tiles_ptr
Definition fused_moegemm_kernel.hpp:204
const void * d_scale_ptr
Definition fused_moegemm_kernel.hpp:197
const void * sorted_token_ids_ptr
Definition fused_moegemm_kernel.hpp:201
static constexpr const char * name
Definition fused_moegemm_kernel.hpp:160
static constexpr const char * name
Definition fused_moegemm_kernel.hpp:162
static constexpr const char * name
Definition fused_moegemm_kernel.hpp:158
static constexpr const char * name
Definition fused_moegemm_kernel.hpp:159
static constexpr const char * name
Definition fused_moegemm_kernel.hpp:161
static constexpr const char * name
Definition fused_moegemm_kernel.hpp:163
Definition fused_moegemm_kernel.hpp:157
Definition fused_moegemm_kernel.hpp:125
static constexpr bool UseUK
Definition fused_moegemm_kernel.hpp:149
typename Pipeline::Problem::ADataType ADataType
Definition fused_moegemm_kernel.hpp:135
CK_TILE_DEVICE void operator()(Kargs kargs) const
Definition fused_moegemm_kernel.hpp:238
typename Pipeline::Problem::GDataType GDataType
Definition fused_moegemm_kernel.hpp:136
typename Pipeline::Problem::Traits Traits
Definition fused_moegemm_kernel.hpp:148
static constexpr bool PadIntermediateSize
Definition fused_moegemm_kernel.hpp:154
typename Pipeline::Problem::TopkWeightDataType TopkWeightDataType
Definition fused_moegemm_kernel.hpp:144
remove_cvref_t< Partitioner_ > Partitioner
Definition fused_moegemm_kernel.hpp:126
typename Pipeline::Problem::DDataType DDataType
Definition fused_moegemm_kernel.hpp:137
static CK_TILE_HOST_DEVICE constexpr index_t GetSmemSize()
Definition fused_moegemm_kernel.hpp:236
remove_cvref_t< Pipeline_ > Pipeline
Definition fused_moegemm_kernel.hpp:127
static constexpr bool UseSmoothQuant
Definition fused_moegemm_kernel.hpp:152
static constexpr index_t kBlockSize
Definition fused_moegemm_kernel.hpp:133
static CK_TILE_HOST constexpr auto BlockSize()
Definition fused_moegemm_kernel.hpp:234
typename Pipeline::Problem::DScaleDataType DScaleDataType
Definition fused_moegemm_kernel.hpp:142
typename Pipeline::Problem::AScaleDataType AScaleDataType
Definition fused_moegemm_kernel.hpp:140
typename Pipeline::Problem::GScaleDataType GScaleDataType
Definition fused_moegemm_kernel.hpp:141
typename Pipeline::Problem::AccDataType AccDataType
Definition fused_moegemm_kernel.hpp:138
static CK_TILE_HOST constexpr auto GridSize(const Hargs &hargs)
Definition fused_moegemm_kernel.hpp:225
typename Pipeline::Problem::ODataType ODataType
Definition fused_moegemm_kernel.hpp:139
FusedMoeGemmHostArgs Hargs
Definition fused_moegemm_kernel.hpp:217
typename Pipeline::Problem::IndexDataType IndexDataType
Definition fused_moegemm_kernel.hpp:145
static constexpr bool IsGateOnly
Definition fused_moegemm_kernel.hpp:151
static constexpr bool PadHiddenSize
Definition fused_moegemm_kernel.hpp:153
typename Pipeline::Problem::YSmoothScaleDataType YSmoothScaleDataType
Definition fused_moegemm_kernel.hpp:143
remove_cvref_t< Epilogue_ > Epilogue
Definition fused_moegemm_kernel.hpp:128
static CK_TILE_HOST constexpr Kargs MakeKargs(const Hargs &hargs)
Definition fused_moegemm_kernel.hpp:219
typename Pipeline::Problem::YDataType YDataType
Definition fused_moegemm_kernel.hpp:146
typename Pipeline::BlockShape BlockShape
Definition fused_moegemm_kernel.hpp:132
static CK_TILE_HOST std::string GetName()
Definition fused_moegemm_kernel.hpp:166
FusedMoeGemmKargs Kargs
Definition fused_moegemm_kernel.hpp:216
Definition tile/core/container/sequence.hpp:49