fmha_bwd_kernel.hpp Source File

fmha_bwd_kernel.hpp Source File#

Composable Kernel: fmha_bwd_kernel.hpp Source File
fmha_bwd_kernel.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
10
11#include <string>
12#include <type_traits>
13#include <utility>
14#include <variant>
15
16// S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] @ K[seqlen_k, hdim_q]
17// S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1]
18// S''[seqlen_q, seqlen_k] = S'[seqlen_q, seqlen_k] + Bias[seqlen_q, seqlen_k]
19// P[seqlen_q, seqlen_k] = Softmax(S''[seqlen_q, seqlen_k])
20// dV[seqlen_k, hdim_v] = P^T[seqlen_k, seqlen_q] @ dO^T[hdim_v, seqlen_q]
21// dP[seqlen_q, seqlen_k] = dO[seqlen_q, hdim_v] @ V[seqlen_k, hdim_v]
22// D[seqlen_q] = rowsum(dO[seqlen_q, hdim_v] * O[seqlen_q, hdim_v])
23// dS''[seqlen_q, seqlen_k] = P[seqlen_q, seqlen_k] * (dP[seqlen_q, seqlen_k] - D[seqlen_q])
24// dBias[seqlen_q, seqlen_k] = dS'[seqlen_q, seqlen_k] = dS''[seqlen_q, seqlen_k]
25// dK[seqlen_k, hdim_q] = dS'^T[seqlen_k, seqlen_q] @ Q^T[hdim_q, seqlen_q] * Scale[1]
26// dQ[seqlen_q, hdim_q] = dS'[seqlen_q, seqlen_k] @ K^T[hdim_q, seqlen_k] * Scale[1]
27
28namespace ck_tile {
29
30template <typename FmhaPipeline_,
31 typename KGradEpiloguePipeline_,
32 typename VGradEpiloguePipeline_,
33 typename QGradEpiloguePipeline_ = void>
35{
40 static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize;
41 static constexpr ck_tile::index_t kBlockPerCu = FmhaPipeline::kBlockPerCu;
42 static constexpr bool kUseQrQtrDorPipeline =
44 static_assert(!kUseQrQtrDorPipeline || !std::is_same_v<QGradEpiloguePipeline_, void>,
45 "QrQtrDorPipeline needs QGradEpiloguePipeline");
46
62
63 static constexpr bool kIsGroupMode = FmhaPipeline::kIsGroupMode;
64 static constexpr index_t kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ;
65 static constexpr index_t kPadHeadDimV = FmhaPipeline::kPadHeadDimV;
66 static constexpr auto BiasEnum = FmhaPipeline::BiasEnum;
67 static constexpr bool kHasBiasGrad = FmhaPipeline::kHasBiasGrad;
70 static constexpr bool kHasMask = FmhaMask::IsMasking;
71 static constexpr bool kHasDropout = FmhaDropout::IsDropout;
72 static constexpr bool kIsStoreRandval = FmhaDropout::IsStoreRandval;
73 static constexpr bool kIsDeterministic = FmhaPipeline::kIsDeterministic;
74 static constexpr bool kUseTrLoad = FmhaPipeline::kUseTrLoad;
75 static constexpr index_t kMaxSeqLenQ = FmhaPipeline::BlockFmhaShape::kMaxSeqLenQ;
76 static_assert(kUseQrQtrDorPipeline == (kMaxSeqLenQ != 0));
77#if defined(__gfx950__)
78 static constexpr bool kIsAvailable = true;
79#else
80 static constexpr bool kIsAvailable = !kUseTrLoad;
81#endif
82
83 // clang-format off
84 template <typename T> struct t2s;
85 template <> struct t2s<float> { static constexpr const char * name = "fp32"; };
86 template <> struct t2s<ck_tile::fp16_t> { static constexpr const char * name = "fp16"; };
87 template <> struct t2s<ck_tile::bf16_t> { static constexpr const char * name = "bf16"; };
88 // clang-format on
89
90 CK_TILE_HOST static std::string GetName()
91 {
92 // sync with generate.py
93 // clang-format off
94 using bfs = typename FmhaPipeline::BlockFmhaShape;
95 using gbr0 = typename bfs::Gemm0BlockWarps;
96 using gbr1 = typename bfs::Gemm1BlockWarps;
97 using gbr4 = typename bfs::Gemm4BlockWarps;
98 using gwt0 = typename bfs::Gemm0WarpTile;
99 using gwt1 = typename bfs::Gemm1WarpTile;
100 #define _SS_ std::string
101 #define _TS_ std::to_string
102 auto pn = [&] () {
103 std::string n;
104 if (kPadHeadDimQ) n += "d" + _TS_(kPadHeadDimQ);
105 if (kPadHeadDimV) n += "dv"+ _TS_(kPadHeadDimV);
106 return n.empty() ? n : std::string("p") + n; }();
107 return
108 _SS_("fmha_bwd_d") + _TS_(bfs::kQKHeaddim) + "_" + _SS_(t2s<QDataType>::name) +
109 "_" + (kIsGroupMode ? "group" : "batch") + "_" +
110 "b" + _TS_(bfs::kM0) + "x" + _TS_(bfs::kN0) + "x" + _TS_(bfs::kK0) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kK2) + "x" + _TS_(bfs::kK3) + "x" +
111 _TS_(bfs::kK4) + "x" + _TS_(bfs::kQKHeaddim) + "x" + _TS_(bfs::kVHeaddim) + "_" +
112 "r" + _TS_(gbr0::at(ck_tile::number<0>{})) + "x" + _TS_(gbr0::at(ck_tile::number<1>{})) + "x" + _TS_(gbr0::at(ck_tile::number<2>{})) + "_" +
113 "r" + _TS_(gbr1::at(ck_tile::number<0>{})) + "x" + _TS_(gbr1::at(ck_tile::number<1>{})) + "x" + _TS_(gbr1::at(ck_tile::number<2>{})) + "_" +
114 "r" + _TS_(gbr4::at(ck_tile::number<0>{})) + "x" + _TS_(gbr4::at(ck_tile::number<1>{})) + "x" + _TS_(gbr4::at(ck_tile::number<2>{})) + "_" +
115 "w" + _TS_(gwt0::at(ck_tile::number<0>{})) + "x" + _TS_(gwt0::at(ck_tile::number<1>{})) + "x" + _TS_(gwt0::at(ck_tile::number<2>{})) + "_" +
116 "w" + _TS_(gwt1::at(ck_tile::number<0>{})) + "x" + _TS_(gwt1::at(ck_tile::number<1>{})) + "x" + _TS_(gwt1::at(ck_tile::number<2>{})) + "_" +
117 ("o" + _TS_(kBlockPerCu)) + "_" +
118 ("maxq" + _TS_(kMaxSeqLenQ)) +
119 (pn.empty() ? "_npad" : "_" + pn) +
121 (kHasBiasGrad ? "_dbias" : "_ndbias") + (kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kHasDropout ? gwt0::at(ck_tile::number<0>{}) == 16? "_dropout_wg16":"_dropout_wg32" : "_ndropout" ) +
122 (kIsStoreRandval ? "_storerandval" : "" ) + (kIsDeterministic ? "_deterministic" : "_ndeterministic" ) + (kUseTrLoad ? "_trload" : "_ntrload");
123 #undef _SS_
124 #undef _TS_
125 // clang-format on
126 }
127
128 template <ck_tile::index_t I> // to avoid the duplicated base class problem, introduce a template
129 // arg
131 {
132 };
133
134 // kargs use aggregate initialization, so no constructor is provided
135 // inheritance is used to minimize the kargs size
136 // users need to call the MakeKargs() function to create kargs.
178
185
190
192 {
193 // alibi is batch*nhead*1, no matter in batch/group mode, they are the same
194 const void* alibi_slope_ptr;
195 ck_tile::index_t alibi_slope_stride; // stride in batch, or 0 for all batch share same slope
196 };
197
204
209
215
229
231 {
232 void init_dropout(float p_drop, uint64_t seed, uint64_t offset, float raw_scale)
233 {
234 float p_undrop = 1.0 - p_drop;
236 uint8_t(std::floor(p_undrop * std::numeric_limits<uint8_t>::max()));
237 rp_undrop = 1.0 / p_undrop;
238 scale_rp_undrop = rp_undrop * raw_scale;
239
240 this->drop_seed.val = seed;
241 this->drop_offset.val = offset;
243 }
244
245 void init_dropout(float p_drop,
246 const uint64_t* seed_ptr,
247 const uint64_t* offset_ptr,
248 float raw_scale)
249 {
250 float p_undrop = 1.0 - p_drop;
252 uint8_t(std::floor(p_undrop * std::numeric_limits<uint8_t>::max()));
253 rp_undrop = 1.0 / p_undrop;
254 scale_rp_undrop = rp_undrop * raw_scale;
255
256 this->drop_seed.ptr = seed_ptr;
257 this->drop_offset.ptr = offset_ptr;
258 this->is_drop_seed_offset_from_host = false;
259 }
260
261 float rp_undrop = 1;
263 uint8_t p_undrop_in_uint8_t = std::numeric_limits<uint8_t>::max();
264 void* rand_val_ptr = nullptr;
265
268 };
269
274
279
282 std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS,
283 FmhaBwdBatchModeBiasKargs,
284 std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ALIBI,
285 FmhaBwdAlibiKargs,
286 FmhaBwdEmptyKargs<0>>>,
287 std::conditional_t<kHasBiasGrad, FmhaBwdBatchModeBiasGradKargs, FmhaBwdEmptyKargs<1>>,
288 std::conditional_t<kHasMask, FmhaBwdMaskKargs, FmhaBwdEmptyKargs<2>>,
289 std::conditional_t<kHasDropout, FmhaBwdBatchModeDropoutKargs, FmhaBwdEmptyKargs<3>>,
290 std::conditional_t<kIsDeterministic, FmhaBwdDeterministicKargs, FmhaBwdEmptyKargs<4>>
291 {
300 };
301
304 std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS,
305 FmhaBwdCommonBiasKargs,
306 std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ALIBI,
307 FmhaBwdAlibiKargs,
308 FmhaBwdEmptyKargs<0>>>,
309 std::conditional_t<kHasBiasGrad, FmhaBwdCommonBiasGradKargs, FmhaBwdEmptyKargs<1>>,
310 std::conditional_t<kHasMask, FmhaBwdMaskKargs, FmhaBwdEmptyKargs<2>>,
311 std::conditional_t<kHasDropout, FmhaBwdCommonDropoutKargs, FmhaBwdEmptyKargs<3>>,
312 std::conditional_t<kIsDeterministic, FmhaBwdDeterministicKargs, FmhaBwdEmptyKargs<4>>
313 {
316 const int32_t* seqlen_q_ptr; // per-batch actual length [batch]
317 const int32_t* seqlen_k_ptr; // per-batch actual length [batch]
318 const int32_t* cu_seqlen_q_ptr; // cumulative seqlen [batch+1], optional
319 const int32_t* cu_seqlen_k_ptr; // cumulative seqlen [batch+1], optional
320 };
321
322 using Kargs = std::conditional_t<kIsGroupMode, FmhaBwdGroupModeKargs, FmhaBwdBatchModeKargs>;
323
324 // std::variant<> can't take in a list initializer, overload for backward compatibility
325 template <typename... Ts>
326 CK_TILE_HOST static constexpr Kargs
327 MakeKargs(Ts... args, const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
328 {
329 return MakeKargsImpl(
330 args..., std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)));
331 }
332
333 // std::variant<> can't take in a list initializer, overload for backward compatibility
334 template <typename... Ts>
335 CK_TILE_HOST static constexpr Kargs
336 MakeKargs(Ts... args, const std::tuple<const void*, const void*>& drop_seed_offset)
337 {
338 return MakeKargsImpl(
339 args..., std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)));
340 }
341
342 template <bool Cond = !kIsGroupMode>
343 CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
344 MakeKargsImpl(const void* q_ptr,
345 const void* k_ptr,
346 const void* v_ptr,
347 const void* bias_ptr,
348 const void* lse_ptr,
349 const void* do_ptr,
350 const void* d_ptr,
351 void* rand_val_ptr,
352 void* dk_ptr,
353 void* dv_ptr,
354 void* dbias_ptr,
355 void* dq_acc_ptr, // can be dq_acc_ptr for qrqtrdor pipeline
356 ck_tile::index_t seqlen_q,
357 ck_tile::index_t seqlen_k,
358 ck_tile::index_t hdim_q,
359 ck_tile::index_t hdim_v,
360 ck_tile::index_t num_head_q,
361 ck_tile::index_t nhead_ratio_qk,
362 float scale,
363 ck_tile::index_t stride_q,
364 ck_tile::index_t stride_k,
365 ck_tile::index_t stride_v,
366 ck_tile::index_t stride_bias,
367 ck_tile::index_t stride_randval,
368 ck_tile::index_t stride_do,
369 ck_tile::index_t stride_dq_acc,
370 ck_tile::index_t stride_dk,
371 ck_tile::index_t stride_dv,
372 ck_tile::index_t stride_dbias,
373 ck_tile::index_t nhead_stride_q,
374 ck_tile::index_t nhead_stride_k,
375 ck_tile::index_t nhead_stride_v,
376 ck_tile::index_t nhead_stride_bias,
377 ck_tile::index_t nhead_stride_randval,
378 ck_tile::index_t nhead_stride_do,
379 ck_tile::index_t nhead_stride_lsed,
380 ck_tile::index_t nhead_stride_dq_acc,
381 ck_tile::index_t nhead_stride_dk,
382 ck_tile::index_t nhead_stride_dv,
383 ck_tile::index_t nhead_stride_dbias,
384 ck_tile::index_t batch_stride_q,
385 ck_tile::index_t batch_stride_k,
386 ck_tile::index_t batch_stride_v,
387 ck_tile::index_t batch_stride_bias,
388 ck_tile::index_t batch_stride_randval,
389 ck_tile::index_t batch_stride_do,
390 ck_tile::index_t batch_stride_lsed,
391 ck_tile::index_t batch_stride_dq_acc,
392 ck_tile::index_t batch_stride_dk,
393 ck_tile::index_t batch_stride_dv,
394 ck_tile::index_t batch_stride_dbias,
395 ck_tile::index_t split_stride_dq_acc,
396 ck_tile::index_t window_size_left,
397 ck_tile::index_t window_size_right,
398 ck_tile::index_t mask_type,
399 float p_drop,
400 std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
401 drop_seed_offset)
402 {
403 Kargs kargs{{q_ptr,
404 k_ptr,
405 v_ptr,
406 lse_ptr,
407 do_ptr,
408 d_ptr,
409 dq_acc_ptr,
410 dk_ptr,
411 dv_ptr,
412 seqlen_q,
413 seqlen_k,
414 hdim_q,
415 hdim_v,
416 num_head_q,
417 nhead_ratio_qk,
418 scale,
419 static_cast<float>(scale * ck_tile::log2e_v<>),
420 stride_q,
421 stride_k,
422 stride_v,
423 stride_do,
424 stride_dq_acc,
425 stride_dk,
426 stride_dv,
427 nhead_stride_q,
428 nhead_stride_k,
429 nhead_stride_v,
430 nhead_stride_do,
431 nhead_stride_lsed,
432 nhead_stride_dq_acc,
433 nhead_stride_dk,
434 nhead_stride_dv}, // args for common karg
435 {}, // placeholder for bias
436 {}, // placeholder for dbias
437 {}, // placeholder for mask
438 {}, // placeholder for dropout
439 {}, // placeholder for deterministic
440 batch_stride_q,
441 batch_stride_k,
442 batch_stride_v,
443 batch_stride_do,
444 batch_stride_lsed,
445 batch_stride_dq_acc,
446 batch_stride_dk,
447 batch_stride_dv};
448
450 {
451 kargs.bias_ptr = bias_ptr;
452 kargs.stride_bias = stride_bias;
453 kargs.nhead_stride_bias = nhead_stride_bias;
454 kargs.batch_stride_bias = batch_stride_bias;
455 }
456 else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
457 {
458 kargs.alibi_slope_ptr = bias_ptr;
459 kargs.alibi_slope_stride = stride_bias;
460 }
461
462 if constexpr(kHasBiasGrad)
463 {
464 kargs.dbias_ptr = dbias_ptr;
465 kargs.stride_dbias = stride_dbias;
466 kargs.nhead_stride_dbias = nhead_stride_dbias;
467 kargs.batch_stride_dbias = batch_stride_dbias;
468 }
469
470 if constexpr(kHasMask)
471 {
472 kargs.window_size_left = window_size_left;
473 kargs.window_size_right = window_size_right;
474 kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
475 }
476
477 if constexpr(kHasDropout)
478 {
479 if(drop_seed_offset.index() == 0) // seed & offset come from host
480 {
481 const auto& [seed, offset] = std::get<0>(drop_seed_offset);
482 kargs.init_dropout(p_drop, seed, offset, scale);
483 }
484 else // seed & offset come from device
485 {
486 const auto& [seed_ptr, offset_ptr] = std::get<1>(drop_seed_offset);
487 kargs.init_dropout(p_drop,
488 reinterpret_cast<const uint64_t*>(seed_ptr),
489 reinterpret_cast<const uint64_t*>(offset_ptr),
490 scale);
491 }
492
493 if constexpr(kIsStoreRandval)
494 {
495 kargs.rand_val_ptr = rand_val_ptr;
496 kargs.stride_randval = stride_randval;
497 kargs.nhead_stride_randval = nhead_stride_randval;
498 kargs.batch_stride_randval = batch_stride_randval;
499 }
500 }
501
502 if constexpr(kIsDeterministic && !kUseQrQtrDorPipeline)
503 {
504 kargs.split_stride_dq_acc = split_stride_dq_acc;
505 }
506
507 return kargs;
508 }
509
510 template <bool Cond = kIsGroupMode>
511 CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
512 MakeKargsImpl(const void* q_ptr,
513 const void* k_ptr,
514 const void* v_ptr,
515 const void* bias_ptr,
516 const void* lse_ptr,
517 const void* do_ptr,
518 const void* d_ptr,
519 void* rand_val_ptr,
520 void* dk_ptr,
521 void* dv_ptr,
522 void* dbias_ptr,
523 void* dq_acc_ptr,
524 const void* seqstart_q_ptr,
525 const void* seqstart_k_ptr,
526 const void* seqlen_q_ptr,
527 const void* seqlen_k_ptr,
528 const void* cu_seqlen_q_ptr,
529 const void* cu_seqlen_k_ptr,
530 ck_tile::index_t hdim_q,
531 ck_tile::index_t hdim_v,
532 ck_tile::index_t num_head_q,
533 ck_tile::index_t nhead_ratio_qk,
534 float scale,
535 ck_tile::index_t stride_q,
536 ck_tile::index_t stride_k,
537 ck_tile::index_t stride_v,
538 ck_tile::index_t stride_bias,
539 ck_tile::index_t stride_randval,
540 ck_tile::index_t stride_do,
541 ck_tile::index_t stride_dq_acc,
542 ck_tile::index_t stride_dk,
543 ck_tile::index_t stride_dv,
544 ck_tile::index_t stride_dbias,
545 ck_tile::index_t nhead_stride_q,
546 ck_tile::index_t nhead_stride_k,
547 ck_tile::index_t nhead_stride_v,
548 ck_tile::index_t nhead_stride_bias,
549 ck_tile::index_t nhead_stride_randval,
550 ck_tile::index_t nhead_stride_do,
551 ck_tile::index_t nhead_stride_lsed,
552 ck_tile::index_t nhead_stride_dq_acc,
553 ck_tile::index_t nhead_stride_dk,
554 ck_tile::index_t nhead_stride_dv,
555 ck_tile::index_t nhead_stride_dbias,
556 ck_tile::index_t split_stride_dq_acc,
557 ck_tile::index_t window_size_left,
558 ck_tile::index_t window_size_right,
559 ck_tile::index_t mask_type,
560 float p_drop,
561 std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
562 drop_seed_offset)
563 {
564 Kargs kargs{{q_ptr,
565 k_ptr,
566 v_ptr,
567 lse_ptr,
568 do_ptr,
569 d_ptr,
570 dq_acc_ptr,
571 dk_ptr,
572 dv_ptr,
573 -1, // seqlen will be updated by another pointer
574 -1, //
575 hdim_q,
576 hdim_v,
577 num_head_q,
578 nhead_ratio_qk,
579 scale,
580 static_cast<float>(scale * ck_tile::log2e_v<>),
581 stride_q,
582 stride_k,
583 stride_v,
584 stride_do,
585 stride_dq_acc,
586 stride_dk,
587 stride_dv,
588 nhead_stride_q,
589 nhead_stride_k,
590 nhead_stride_v,
591 nhead_stride_do,
592 nhead_stride_lsed,
593 nhead_stride_dq_acc,
594 nhead_stride_dk,
595 nhead_stride_dv}, // args for common karg
596 {}, // placeholder for bias
597 {}, // placeholder for dbias
598 {}, // placeholder for mask
599 {}, // placeholder for dropout
600 {}, // placeholder for deterministic
601 reinterpret_cast<const int32_t*>(seqstart_q_ptr),
602 reinterpret_cast<const int32_t*>(seqstart_k_ptr),
603 reinterpret_cast<const int32_t*>(seqlen_q_ptr),
604 reinterpret_cast<const int32_t*>(seqlen_k_ptr),
605 reinterpret_cast<const int32_t*>(cu_seqlen_q_ptr),
606 reinterpret_cast<const int32_t*>(cu_seqlen_k_ptr)};
607
609 {
610 kargs.bias_ptr = bias_ptr;
611 kargs.stride_bias = stride_bias;
612 kargs.nhead_stride_bias = nhead_stride_bias;
613 }
614 else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
615 {
616 kargs.alibi_slope_ptr = bias_ptr;
617 kargs.alibi_slope_stride = stride_bias;
618 }
619 if constexpr(kHasBiasGrad)
620 {
621 kargs.dbias_ptr = dbias_ptr;
622 kargs.stride_dbias = stride_dbias;
623 kargs.nhead_stride_dbias = nhead_stride_dbias;
624 }
625 if constexpr(kHasMask)
626 {
627 kargs.window_size_left = window_size_left;
628 kargs.window_size_right = window_size_right;
629 kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
630 }
631 if constexpr(kHasDropout)
632 {
633 if(drop_seed_offset.index() == 0) // seed & offset come from host
634 {
635 const auto& [seed, offset] = std::get<0>(drop_seed_offset);
636 kargs.init_dropout(p_drop, seed, offset, scale);
637 }
638 else // seed & offset come from device
639 {
640 const auto& [seed_ptr, offset_ptr] = std::get<1>(drop_seed_offset);
641 kargs.init_dropout(p_drop,
642 reinterpret_cast<const uint64_t*>(seed_ptr),
643 reinterpret_cast<const uint64_t*>(offset_ptr),
644 scale);
645 }
646
647 if constexpr(kIsStoreRandval)
648 {
649 kargs.rand_val_ptr = rand_val_ptr;
650 kargs.stride_randval = stride_randval;
651 kargs.nhead_stride_randval = nhead_stride_randval;
652 }
653 }
654 if constexpr(kIsDeterministic)
655 {
656 kargs.split_stride_dq_acc = split_stride_dq_acc;
657 }
658
659 return kargs;
660 }
661
662 CK_TILE_HOST static constexpr auto
664 {
665 return dim3(
666 kUseQrQtrDorPipeline ? 1 : ck_tile::integer_divide_ceil(seqlen_k_, FmhaPipeline::kN0),
667 nhead_,
668 batch_size_);
669 }
670
671 CK_TILE_DEVICE static constexpr auto GetTileIndex()
672 {
673 const index_t i_block = blockIdx.x;
674 const index_t i_nhead = blockIdx.y;
675 const index_t i_batch = blockIdx.z;
676
677 return ck_tile::make_tuple(i_block, i_nhead, i_batch);
678 }
679
681 {
682 if(is_wave32())
683 {
684 return dim3(kBlockSize / 2);
685 }
686 else
687 {
688 return dim3(kBlockSize);
689 }
690 }
691
693 {
694 return ck_tile::max(FmhaPipeline::GetSmemSize(),
695 KGradEpiloguePipeline::GetSmemSize(),
696 VGradEpiloguePipeline::GetSmemSize());
697 }
698
700 {
701 if constexpr(kIsAvailable)
702 run_(std::move(kargs));
703 }
704
705 CK_TILE_DEVICE void run_(Kargs kargs) const
706 {
707 // allocate LDS
708 __shared__ char smem_ptr[GetSmemSize()];
709
710 // divide problem
711 const auto [i_tile_n, i_nhead, i_batch] = GetTileIndex();
712
713 const index_t i_n0 = amd_wave_read_first_lane(i_tile_n * FmhaPipeline::kN0);
714
715 long_index_t batch_offset_q = 0;
716 long_index_t batch_offset_k = 0;
717 long_index_t batch_offset_v = 0;
718 long_index_t batch_offset_bias = 0;
719 long_index_t batch_offset_randval = 0;
720 long_index_t batch_offset_do = 0;
721 long_index_t batch_offset_lsed = 0;
722 long_index_t batch_offset_dq_acc = 0;
723 long_index_t batch_offset_dk = 0;
724 long_index_t batch_offset_dv = 0;
725 long_index_t batch_offset_dbias = 0;
726
727 if constexpr(kIsGroupMode)
728 {
729 // get starting offset for each batch
730 const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
731 const long_index_t key_start = kargs.seqstart_k_ptr[i_batch];
732
733 batch_offset_q = query_start * kargs.stride_q;
734 batch_offset_k = key_start * kargs.stride_k;
735 batch_offset_v = key_start * kargs.stride_v;
736 batch_offset_do = query_start * kargs.stride_do;
737 batch_offset_lsed = query_start;
738 batch_offset_dq_acc = query_start * kargs.stride_dq_acc;
739 batch_offset_dk = key_start * kargs.stride_dk;
740 batch_offset_dv = key_start * kargs.stride_dv;
742 {
743 batch_offset_bias = query_start * kargs.stride_bias;
744 }
745 if constexpr(kHasBiasGrad)
746 {
747 batch_offset_dbias = query_start * kargs.stride_dbias;
748 }
749 else
750 {
751 batch_offset_dbias = key_start;
752 }
753 if constexpr(kIsStoreRandval)
754 {
755 batch_offset_randval = query_start * kargs.stride_randval;
756 }
757
758 // Priority: cu_seqlen_q_ptr > seqlen_q_ptr > physical_seqlen_q
759 if(kargs.cu_seqlen_q_ptr != nullptr)
760 {
761 kargs.seqlen_q =
762 kargs.cu_seqlen_q_ptr[i_batch + 1] - kargs.cu_seqlen_q_ptr[i_batch];
763 }
764 else
765 {
766 // get real # queries & # keys under group mode
767 const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
768 const ck_tile::index_t physical_seqlen_q =
769 adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
770 kargs.seqlen_q =
771 kargs.seqlen_q_ptr ? kargs.seqlen_q_ptr[i_batch] : physical_seqlen_q;
772 }
773
774 // Priority: cu_seqlen_k_ptr > seqlen_k_ptr > seqstart_k
775 if(kargs.cu_seqlen_k_ptr != nullptr)
776 {
777 kargs.seqlen_k =
778 kargs.cu_seqlen_k_ptr[i_batch + 1] - kargs.cu_seqlen_k_ptr[i_batch];
779 }
780 else if(kargs.seqlen_k_ptr != nullptr)
781 {
782 kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch];
783 }
784 else
785 {
786 const auto adjusted_seqstart_k_ptr = kargs.seqstart_k_ptr + i_batch;
787 kargs.seqlen_k = adjusted_seqstart_k_ptr[1] - adjusted_seqstart_k_ptr[0];
788 }
789
790 // skip if logical lengths are zero
791 if(kargs.seqlen_q == 0 || kargs.seqlen_k == 0)
792 {
793 return;
794 }
795
796 // # of required blocks is different in each groups, terminate unnecessary blocks
797 // earlier
798 if constexpr(!kUseQrQtrDorPipeline)
799 if(kargs.seqlen_k <= i_n0)
800 return;
801 }
802 else
803 {
804 batch_offset_q = static_cast<long_index_t>(i_batch) * kargs.batch_stride_q;
805 batch_offset_k = static_cast<long_index_t>(i_batch) * kargs.batch_stride_k;
806 batch_offset_v = static_cast<long_index_t>(i_batch) * kargs.batch_stride_v;
807 batch_offset_do = static_cast<long_index_t>(i_batch) * kargs.batch_stride_do;
808 batch_offset_lsed = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lsed;
809 batch_offset_dq_acc = static_cast<long_index_t>(i_batch) * kargs.batch_stride_dq_acc;
810 batch_offset_dk = static_cast<long_index_t>(i_batch) * kargs.batch_stride_dk;
811 batch_offset_dv = static_cast<long_index_t>(i_batch) * kargs.batch_stride_dv;
813 {
814 batch_offset_bias = static_cast<long_index_t>(i_batch) * kargs.batch_stride_bias;
815 }
816 if constexpr(kHasBiasGrad)
817 {
818 batch_offset_dbias = static_cast<long_index_t>(i_batch) * kargs.batch_stride_dbias;
819 }
820 if constexpr(kIsStoreRandval)
821 {
822 batch_offset_randval =
823 static_cast<long_index_t>(i_batch) * kargs.batch_stride_randval;
824 }
825 }
826
827 // for simplicity, batch stride we just modify the pointer
828 const QDataType* q_ptr = reinterpret_cast<const QDataType*>(kargs.q_ptr) +
829 static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_q +
830 batch_offset_q;
831 const KDataType* k_ptr =
832 reinterpret_cast<const KDataType*>(kargs.k_ptr) +
833 static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_k +
834 batch_offset_k;
835 const VDataType* v_ptr =
836 reinterpret_cast<const VDataType*>(kargs.v_ptr) +
837 static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_v +
838 batch_offset_v;
839 const LSEDataType* lse_ptr = reinterpret_cast<const LSEDataType*>(kargs.lse_ptr) +
840 static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_lsed +
841 batch_offset_lsed;
842 const DDataType* d_ptr = reinterpret_cast<const DDataType*>(kargs.d_ptr) +
843 static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_lsed +
844 batch_offset_lsed;
845 const OGradDataType* do_ptr = reinterpret_cast<const OGradDataType*>(kargs.do_ptr) +
846 static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_do +
847 batch_offset_do;
848 auto dk_ptr = reinterpret_cast<KGradDataType*>(kargs.dk_ptr) +
849 static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_dk + batch_offset_dk;
850 auto dv_ptr = reinterpret_cast<VGradDataType*>(kargs.dv_ptr) +
851 static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_dv + batch_offset_dv;
852
853 // Q/K/V/LSE/D/dO/dQ/dK/dV DRAM and DRAM window
855 q_ptr,
856 make_tuple(kargs.seqlen_q, kargs.hdim_q),
857 make_tuple(kargs.stride_q, 1),
859 number<1>{});
860 const auto q_dram = pad_tensor_view(
861 q_dram_naive,
863 sequence<false, (kPadHeadDimQ > 0)>{});
864
866 k_ptr,
867 make_tuple(kargs.seqlen_k, kargs.hdim_q),
868 make_tuple(kargs.stride_k, 1),
870 number<1>{});
871 const auto k_dram = pad_tensor_view(
872 k_dram_naive,
874 sequence<false, (kPadHeadDimQ > 0)>{});
875
876 const auto v_dram = [&]() {
878 v_ptr,
879 make_tuple(kargs.seqlen_k, kargs.hdim_v),
880 make_tuple(kargs.stride_v, 1),
882 number<1>{});
883 return pad_tensor_view(
884 v_dram_naive,
886 sequence<false, (kPadHeadDimV > 0)>{});
887 }();
888
889 // lse and d should be fine to read as unpadded data, since they are not on the reduction dimension
891 lse_ptr, make_tuple(kargs.seqlen_q), number<FmhaPipeline::kM0>{});
892
894 d_ptr, make_tuple(kargs.seqlen_q), number<FmhaPipeline::kM0>{});
895
897 do_ptr,
898 make_tuple(kargs.seqlen_q, kargs.hdim_v),
899 make_tuple(kargs.stride_do, 1),
901 number<1>{});
902 const auto do_dram = pad_tensor_view(
903 do_dram_naive,
905 sequence<false, (kPadHeadDimV > 0)>{});
906
907 auto q_dram_window = make_tile_window(
908 q_dram,
910 {0, 0});
911
912 auto k_dram_window = make_tile_window(
913 k_dram,
915 {i_n0, 0});
916
917 auto v_dram_window = make_tile_window(
918 v_dram,
920 {i_n0, 0});
921
922 auto do_dram_window = make_tile_window(
923 do_dram,
925 {0, 0});
926
927 auto dq_dram_window = [&, i_tile_n_ = i_tile_n, i_nhead_ = i_nhead]() {
928 constexpr bool kUseKSplit = !kUseQrQtrDorPipeline && kIsDeterministic;
929 using DType = std::conditional_t<kUseQrQtrDorPipeline, QGradDataType, AccDataType>;
930
931 auto dq_acc_ptr = reinterpret_cast<DType*>(kargs.dq_acc_ptr) + [&]() {
932 if constexpr(kUseKSplit)
933 return static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_dq_acc +
934 static_cast<long_index_t>(i_tile_n_) * kargs.split_stride_dq_acc +
935 batch_offset_dq_acc;
936 else
937 return static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_dq_acc +
938 batch_offset_dq_acc;
939 }();
940
941 constexpr auto DstInMemOp = conditional_expr<kUseKSplit>(
943 const auto dq_acc_dram_naive =
945 dq_acc_ptr,
946 make_tuple(kargs.seqlen_q, kargs.hdim_q),
947 make_tuple(kargs.stride_dq_acc, 1),
949 number<1>{});
950 const auto dq_acc_dram = pad_tensor_view(
951 dq_acc_dram_naive,
953 sequence<false, (kPadHeadDimQ > 0)>{});
954 return make_tile_window(
955 dq_acc_dram,
957 {0, 0});
958 }();
959
960 auto lse_dram_window =
962
963 auto d_dram_window = make_tile_window(d_dram, make_tuple(number<FmhaPipeline::kM0>{}), {0});
964
967 constexpr auto bias_dram_window_lengths =
969 const auto bias_dram_window = [&, i_nhead_ = i_nhead]() {
971 {
972 const BiasDataType* bias_ptr =
973 reinterpret_cast<const BiasDataType*>(kargs.bias_ptr) +
974 static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_bias +
975 batch_offset_bias;
976
977 const auto bias_dram = [&]() {
978 const auto bias_dram_naive = make_naive_tensor_view<address_space_enum::global>(
979 bias_ptr,
980 make_tuple(kargs.seqlen_q, kargs.seqlen_k),
981 make_tuple(kargs.stride_bias, 1),
983 number<1>{});
984
985 return pad_tensor_view(
986 bias_dram_naive, bias_dram_window_lengths, sequence<false, true>{});
987 }();
988
989 return make_tile_window(bias_dram, bias_dram_window_lengths, {0, i_n0});
990 }
991 else
992 {
993 return make_null_tile_window(bias_dram_window_lengths);
994 }
995 }();
996
997 auto dbias_dram_window = [&, i_nhead_ = i_nhead]() {
998 if constexpr(kHasBiasGrad)
999 {
1000 BiasGradDataType* dbias_ptr =
1001 reinterpret_cast<BiasGradDataType*>(kargs.dbias_ptr) +
1002 static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_dbias +
1003 batch_offset_dbias;
1004
1005 auto dbias_dram = [&]() {
1006 const auto dbias_dram_naive =
1008 dbias_ptr,
1009 make_tuple(kargs.seqlen_q, kargs.seqlen_k),
1010 make_tuple(kargs.stride_dbias, 1),
1012 number<1>{});
1013
1014 return pad_tensor_view(
1015 dbias_dram_naive, bias_dram_window_lengths, sequence<false, true>{});
1016 }();
1017
1018 return make_tile_window(dbias_dram, bias_dram_window_lengths, {0, i_n0});
1019 }
1020 else
1021 {
1022 return make_null_tile_window(bias_dram_window_lengths);
1023 }
1024 }();
1025
1026 // Workaround: lambdas cannot capture structured bindings before C++20, so i_batch/i_nhead are copied via init-captures
1027 auto position_encoding = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() {
1029 {
1030 // data loading, shared by entire wg
1031 // TODO: how to use s_read?
1032 AccDataType slope = *(reinterpret_cast<const AccDataType*>(kargs.alibi_slope_ptr) +
1033 i_batch_ * kargs.alibi_slope_stride + i_nhead_);
1034 slope *= ck_tile::log2e_v<>;
1035 if constexpr(kHasMask)
1036 {
1038 kargs.window_size_left,
1039 kargs.window_size_right,
1040 kargs.seqlen_q,
1041 kargs.seqlen_k,
1042 kargs.mask_type);
1043 }
1044 else
1045 {
1047 slope, kargs.seqlen_q, kargs.seqlen_k, AlibiMode::FROM_BOTTOM_RIGHT};
1048 }
1049 }
1050 else
1051 {
1053 }
1054 }();
1055
1056 // dropout
1057 float rp_undrop = 1;
1058 float scale_rp_undrop = 1;
1059 if constexpr(kHasDropout)
1060 {
1061 rp_undrop = kargs.rp_undrop;
1062 scale_rp_undrop = kargs.scale_rp_undrop;
1063 }
1064 auto dropout = [&, i_nhead_ = i_nhead, i_batch_ = i_batch]() {
1065 if constexpr(kHasDropout)
1066 {
1067 return FmhaDropout{i_batch_,
1068 i_nhead_,
1069 kargs.num_head_q,
1070 kargs.is_drop_seed_offset_from_host ? kargs.drop_seed.val
1071 : *kargs.drop_seed.ptr,
1072 kargs.is_drop_seed_offset_from_host ? kargs.drop_offset.val
1073 : *kargs.drop_offset.ptr,
1074 kargs.rp_undrop,
1075 kargs.p_undrop_in_uint8_t};
1076 }
1077 else
1078 {
1079 return FmhaDropout{};
1080 };
1081 }();
1082
1083 auto randval_dram_window = [&, i_nhead_ = i_nhead]() {
1084 constexpr auto randval_dram_window_lengths =
1086 if constexpr(kIsStoreRandval)
1087 {
1088 RandValOutputDataType* rand_val_ptr =
1089 reinterpret_cast<RandValOutputDataType*>(kargs.rand_val_ptr) +
1090 static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_randval +
1091 batch_offset_randval;
1092
1093 const auto randval_dram = [&]() {
1094 const auto randval_dram_naive =
1096 rand_val_ptr,
1097 make_tuple(kargs.seqlen_q, kargs.seqlen_k),
1098 make_tuple(kargs.stride_randval, 1),
1099 number<1>{},
1100 number<1>{});
1101
1102 return pad_tensor_view(
1103 randval_dram_naive, randval_dram_window_lengths, sequence<false, true>{});
1104 }();
1105
1106 return make_tile_window(randval_dram, randval_dram_window_lengths, {0, i_n0});
1107 }
1108 else
1109 {
1110 return make_null_tile_window(randval_dram_window_lengths);
1111 }
1112 }();
1113
1114 FmhaMask mask = [&]() {
1115 if constexpr(kHasMask)
1117 kargs.window_size_left,
1118 kargs.window_size_right,
1119 kargs.seqlen_q,
1120 kargs.seqlen_k,
1122 else
1123 return FmhaMask{kargs.seqlen_q, kargs.seqlen_k};
1124 }();
1125
1126 auto dk_dram = [&]() {
1127 const auto dk_dram_naive = make_naive_tensor_view<address_space_enum::global>(
1128 dk_ptr,
1129 make_tuple(kargs.seqlen_k, kargs.hdim_q),
1130 make_tuple(kargs.stride_dk, 1),
1132 number<1>{});
1133
1134 return pad_tensor_view(
1135 dk_dram_naive,
1137 sequence<false, (kPadHeadDimQ > 0)>{});
1138 }();
1139
1140 auto dv_dram = [&]() {
1141 const auto dv_dram_naive = make_naive_tensor_view<address_space_enum::global>(
1142 dv_ptr,
1143 make_tuple(kargs.seqlen_k, kargs.hdim_v),
1144 make_tuple(kargs.stride_dv, 1),
1146 number<1>{});
1147
1148 return pad_tensor_view(
1149 dv_dram_naive,
1151 sequence<false, (kPadHeadDimV > 0)>{});
1152 }();
1153
1154 auto dk_dram_window = make_tile_window(
1155 dk_dram,
1157 {i_n0, 0});
1158
1159 auto dv_dram_window = make_tile_window(
1160 dv_dram,
1162 {i_n0, 0});
1163 if constexpr(!kUseQrQtrDorPipeline)
1164 {
1165 auto [dk_acc_tile, dv_acc_tile] = FmhaPipeline{}(smem_ptr,
1166 q_dram_window,
1167 k_dram_window,
1168 v_dram_window,
1169 bias_dram_window,
1170 randval_dram_window,
1171 do_dram_window,
1172 lse_dram_window,
1173 d_dram_window,
1174 dq_dram_window,
1175 dbias_dram_window,
1176 mask,
1177 position_encoding,
1178 kargs.raw_scale,
1179 kargs.scale,
1180 rp_undrop,
1181 scale_rp_undrop,
1182 dropout);
1183
1184#if defined(__gfx12__)
1185 // Workaround for a compiler bug (SWDEV-559729): v_wmma instructions can be incorrectly
1186 // placed in divergent branches used to store padded tensors (when some lanes are
1187 // inactive due to padding). Inline asm with dummy dependencies on VGPRs of the tensors
1188 // prevents the compiler doing this.
1189 if constexpr(kPadHeadDimQ > 0)
1190 {
1191 impl::insert_dummy_dep(dk_acc_tile.get_thread_buffer());
1192 }
1193 if constexpr(kPadHeadDimV > 0)
1194 {
1195 impl::insert_dummy_dep(dv_acc_tile.get_thread_buffer());
1196 }
1197#endif
1198
1199 KGradEpiloguePipeline{}(dk_dram_window, dk_acc_tile, nullptr);
1200 VGradEpiloguePipeline{}(dv_dram_window, dv_acc_tile, nullptr);
1201 }
1202 else
1203 {
1204 FmhaPipeline{}(smem_ptr,
1205 q_dram_window,
1206 k_dram_window,
1207 v_dram_window,
1208 bias_dram_window,
1209 randval_dram_window,
1210 do_dram_window,
1211 lse_dram_window,
1212 d_dram_window,
1213 dq_dram_window,
1214 dk_dram_window,
1215 dv_dram_window,
1216 dbias_dram_window,
1220 mask,
1221 position_encoding,
1222 kargs.raw_scale,
1223 kargs.scale,
1224 rp_undrop,
1225 scale_rp_undrop,
1226 dropout);
1227 }
1228 }
1229};
1230
1231template <typename FmhaBwdOGradDotO_>
1233{
1235 static constexpr ck_tile::index_t kBlockSize = FmhaBwdOGradDotO::kBlockSize;
1236 static constexpr ck_tile::index_t kBlockPerCu = FmhaBwdOGradDotO::kBlockPerCu;
1237 static constexpr ck_tile::index_t kM0 = kBlockSize;
1238 static constexpr ck_tile::index_t kVHeaddim = FmhaBwdOGradDotO::kVHeaddim;
1239
1243
1244 static constexpr bool kIsGroupMode = FmhaBwdOGradDotO::kIsGroupMode;
1245 static constexpr bool kPadSeqLenQ = FmhaBwdOGradDotO::kPadSeqLenQ;
1246 static constexpr bool kPadHeadDimV = FmhaBwdOGradDotO::kPadHeadDimV;
1247
1248 // clang-format off
// Compile-time map from an element type to the short type-name string
// ("fp32" / "fp16" / "bf16") that GetName() embeds in the kernel name.
// Only the three specializations below are defined in this listing.
1249 template <typename T> struct t2s;
1250 template <> struct t2s<float> { static constexpr const char * name = "fp32"; };
1251 template <> struct t2s<ck_tile::fp16_t> { static constexpr const char * name = "fp16"; };
1252 template <> struct t2s<ck_tile::bf16_t> { static constexpr const char * name = "bf16"; };
1253 // clang-format on
1254
// Builds the name string of this kernel instantiation. The layout is
//   fmha_bwd_dot_do_o_d<kVHeaddim>_<dtype>_b<kM0>_<group|batch>_o<kBlockPerCu>_<pad>
// where <dtype> is t2s<ODataType>::name and <pad> is "npad" when no padding
// flag is set, otherwise "p" followed by "s" (kPadSeqLenQ) and/or "dv"
// (kPadHeadDimV).
1255 CK_TILE_HOST static std::string GetName()
1256 {
1257 // sync with generate.py
1258 // clang-format off
1259
// Local shorthands, undefined again before the function ends.
1260 #define _SS_ std::string
1261 #define _TS_ std::to_string
// pn: padding suffix; stays empty when neither padding flag is set.
1262 auto pn = [&] () {
1263 std::string n;
1264 if (kPadSeqLenQ) n += "s";
1265 if (kPadHeadDimV) n += "dv";
1266 return n.empty() ? n : std::string("p") + n; }();
1267 return
1268 _SS_("fmha_bwd_dot_do_o_d") + _TS_(kVHeaddim) + "_" + _SS_(t2s<ODataType>::name) +
1269 "_b" + _TS_(kM0) + "_" + (kIsGroupMode ? "group" : "batch") + "_" +
1270 ("o" + _TS_(kBlockPerCu)) + (pn.empty() ? "_npad" : "_" + pn);
1271 #undef _SS_
1272 #undef _TS_
1273 // clang-format on
1274 }
1275
1276 // kargs use aggregate initialization, so no constructor is provided
1277 // use inheritance to minimize kargs size
1278 // users need to call the MakeKargs() function to create kargs.
1297
1304
1306 {
1308 const int32_t* seqlen_q_ptr; // per-batch actual length [batch]
1309 const int32_t* cu_seqlen_q_ptr; // cumulative seqlen [batch+1], optional
1310 };
1311
// Kernel-argument type for this kernel: the group-mode struct when
// kIsGroupMode is true, otherwise the batch-mode struct.
1312 using Kargs = std::
1313 conditional_t<kIsGroupMode, FmhaBwdOGradDotOGroupModeKargs, FmhaBwdOGradDotOBatchModeKargs>;
1314
// Batch-mode overload of MakeKargs; it is only enabled when kIsGroupMode is
// false (Cond defaults to !kIsGroupMode).
// Packs the host-side arguments into a Kargs aggregate and returns it by value.
// The pointers are stored as passed; nothing is dereferenced here.
1315 template <bool Cond = !kIsGroupMode>
1316 CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
1317 MakeKargs(const void* o_ptr,
1318 const void* do_ptr,
1319 void* d_ptr,
1320 float p_undrop,
1321 ck_tile::index_t seqlen_q,
1322 ck_tile::index_t hdim_v,
1323 ck_tile::index_t stride_do,
1324 ck_tile::index_t stride_o,
1325 ck_tile::index_t nhead_stride_do,
1326 ck_tile::index_t nhead_stride_o,
1327 ck_tile::index_t nhead_stride_d,
1328 ck_tile::index_t batch_stride_do,
1329 ck_tile::index_t batch_stride_o,
1330 ck_tile::index_t batch_stride_d)
1331 {
// The inner brace list initializes the first sub-aggregate (presumably the
// common base struct; its definition is not visible in this listing). The
// three trailing values are the batch strides that operator() reads in
// batch mode.
1332 Kargs kargs{{o_ptr,
1333 do_ptr,
1334 d_ptr,
1335 p_undrop,
1336 seqlen_q,
1337 hdim_v,
1338 stride_do,
1339 stride_o,
1340 nhead_stride_do,
1341 nhead_stride_o,
1342 nhead_stride_d},
1343 batch_stride_do,
1344 batch_stride_o,
1345 batch_stride_d};
1346
1347 return kargs;
1348 }
1349
// Group-mode overload of MakeKargs; it is only enabled when kIsGroupMode is
// true (Cond defaults to kIsGroupMode).
// Takes no seqlen_q value: the sequence length is resolved per batch inside
// operator() from seqstart_q_ptr / seqlen_q_ptr / cu_seqlen_q_ptr.
// operator() null-checks seqlen_q_ptr and cu_seqlen_q_ptr, so those two may be
// null; seqstart_q_ptr is dereferenced unconditionally in group mode.
1350 template <bool Cond = kIsGroupMode>
1351 CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
1352 MakeKargs(const void* o_ptr,
1353 const void* do_ptr,
1354 void* d_ptr,
1355 float p_undrop,
1356 const void* seqstart_q_ptr,
1357 const void* seqlen_q_ptr,
1358 const void* cu_seqlen_q_ptr,
1359 ck_tile::index_t hdim_v,
1360 ck_tile::index_t stride_do,
1361 ck_tile::index_t stride_o,
1362 ck_tile::index_t nhead_stride_do,
1363 ck_tile::index_t nhead_stride_o,
1364 ck_tile::index_t nhead_stride_d)
1365 {
// The inner brace list fills the same leading fields as the batch-mode
// overload, with -1 as a placeholder for seqlen_q. The untyped pointers are
// reinterpreted as const int32_t* for the trailing group-mode fields.
1366 Kargs kargs{{o_ptr,
1367 do_ptr,
1368 d_ptr,
1369 p_undrop,
1370 -1, // seqlen will be updated by another pointer
1371 hdim_v,
1372 stride_do,
1373 stride_o,
1374 nhead_stride_do,
1375 nhead_stride_o,
1376 nhead_stride_d},
1377 reinterpret_cast<const int32_t*>(seqstart_q_ptr),
1378 reinterpret_cast<const int32_t*>(seqlen_q_ptr),
1379 reinterpret_cast<const int32_t*>(cu_seqlen_q_ptr)};
1380
1381 return kargs;
1382 }
1383
// Launch grid for this kernel: x = ceil(seqlen_q_ / kM0), y = nhead_,
// z = batch_size_. GetTileIndex() reads the block indices in the same order.
// NOTE(review): the line carrying the function name and parameter list (1385)
// is missing from this listing; the body uses batch_size_, nhead_ and
// seqlen_q_, so those are presumably the parameters - confirm against the
// real header.
1384 CK_TILE_HOST static constexpr auto
1386 {
1387 return dim3(ck_tile::integer_divide_ceil(seqlen_q_, kM0), nhead_, batch_size_);
1388 }
1389
// Returns the tuple (blockIdx.x, blockIdx.y, blockIdx.z), named here
// (i_block, i_nhead, i_batch). This mirrors the x/y/z layout built by GridSize().
1390 CK_TILE_DEVICE static constexpr auto GetTileIndex()
1391 {
1392 const index_t i_block = blockIdx.x;
1393 const index_t i_nhead = blockIdx.y;
1394 const index_t i_batch = blockIdx.z;
1395
1396 return ck_tile::make_tuple(i_block, i_nhead, i_batch);
1397 }
1398
// Thread-block dimensions for the launch: a dim3 built from kBlockSize alone.
1399 CK_TILE_HOST static dim3 BlockSize() { return dim3(kBlockSize); }
1400
// Always reports a shared-memory size of 0 for this kernel.
1401 CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() { return 0; }
1402
// Device entry point of the dO.O kernel body.
// NOTE(review): the signature line (1403) is missing from this listing. The
// body assigns to kargs.seqlen_q, so kargs is presumably taken by value -
// confirm against the real header.
// Steps: locate this block's (tile, head, batch), resolve per-batch offsets
// and the query length, build O / dO / D tensor views and kM0-row tile
// windows, then invoke the FmhaBwdOGradDotO pipeline on them.
1404 {
1405 // divide problem
1406 const auto [i_tile_m, i_nhead, i_batch] = GetTileIndex();
1407
// First query row handled by this block, passed through amd_wave_read_first_lane.
1408 const index_t i_m0 = amd_wave_read_first_lane(i_tile_m * kM0);
1409
1410 long_index_t batch_offset_o = 0;
1411 long_index_t batch_offset_do = 0;
1412 long_index_t batch_offset_d = 0;
1413
1414 if constexpr(kIsGroupMode)
1415 {
1416 // get starting offset for each batch
1417 const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
1418
// O and dO advance by whole rows (row stride); D advances by one element per row.
1419 batch_offset_o = query_start * kargs.stride_o;
1420 batch_offset_do = query_start * kargs.stride_do;
1421 batch_offset_d = query_start;
1422
1423 // Priority: cu_seqlen_q_ptr > seqlen_q_ptr > physical_seqlen_q
1424 if(kargs.cu_seqlen_q_ptr != nullptr)
1425 {
1426 kargs.seqlen_q =
1427 kargs.cu_seqlen_q_ptr[i_batch + 1] - kargs.cu_seqlen_q_ptr[i_batch];
1428 }
1429 else
1430 {
1431 // get real # queries under group mode
1432 const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
1433 const ck_tile::index_t physical_seqlen_q =
1434 adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
1435 kargs.seqlen_q = kargs.seqlen_q_ptr
1436 ? static_cast<ck_tile::index_t>(kargs.seqlen_q_ptr[i_batch])
1437 : physical_seqlen_q;
1438 }
1439
1440 // the # of required blocks differs between groups; terminate unnecessary blocks
1441 // early
1442 if(kargs.seqlen_q <= i_m0)
1443 {
1444 return;
1445 }
1446 }
1447 else
1448 {
// Batch mode: offsets are plain batch strides, widened to long_index_t before multiplying.
1449 batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
1450 batch_offset_do = static_cast<long_index_t>(i_batch) * kargs.batch_stride_do;
1451 batch_offset_d = static_cast<long_index_t>(i_batch) * kargs.batch_stride_d;
1452 }
1453
1454 // for simplicity, the head and batch offsets are folded into the base pointers
1455 const ODataType* o_ptr = reinterpret_cast<const ODataType*>(kargs.o_ptr) +
1456 static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_o +
1457 batch_offset_o;
1458 const OGradDataType* do_ptr = reinterpret_cast<const OGradDataType*>(kargs.do_ptr) +
1459 static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_do +
1460 batch_offset_do;
1461 DDataType* d_ptr = reinterpret_cast<DDataType*>(kargs.d_ptr) +
1462 static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_d +
1463 batch_offset_d;
1464
1465 // O/dO/D DRAM and DRAM window
// NOTE(review): in each of the three lambdas below, the line that opens the
// naive tensor-view construction and some argument lines are missing from
// this listing (e.g. 1467, 1471, 1474-1475); only the visible arguments are
// described here.
// O view: lengths (seqlen_q, hdim_v), strides (stride_o, 1).
1466 const auto o_dram = [&]() {
1468 o_ptr,
1469 make_tuple(kargs.seqlen_q, kargs.hdim_v),
1470 make_tuple(kargs.stride_o, 1),
1472 number<1>{});
1473 return pad_tensor_view(o_dram_naive,
1476 }();
// dO view: lengths (seqlen_q, hdim_v), strides (stride_do, 1).
1477 const auto do_dram = [&]() {
1479 do_ptr,
1480 make_tuple(kargs.seqlen_q, kargs.hdim_v),
1481 make_tuple(kargs.stride_do, 1),
1483 number<1>{});
1484 return pad_tensor_view(do_dram_naive,
1487 }();
// D view: one-dimensional, length seqlen_q; padded against a kM0 tile with
// the pad flag kPadSeqLenQ.
1488 auto d_dram = [&]() {
1490 d_ptr, make_tuple(kargs.seqlen_q), number<1>{});
1491 return pad_tensor_view(
1492 d_dram_naive, make_tuple(number<kM0>{}), sequence<kPadSeqLenQ>{});
1493 }();
1494
// All three windows start at row i_m0; O and dO windows are kM0 x kVHeaddim.
1495 auto o_dram_window =
1496 make_tile_window(o_dram, make_tuple(number<kM0>{}, number<kVHeaddim>{}), {i_m0, 0});
1497
1498 auto do_dram_window =
1499 make_tile_window(do_dram, make_tuple(number<kM0>{}, number<kVHeaddim>{}), {i_m0, 0});
1500
1501 auto d_dram_window = make_tile_window(d_dram, make_tuple(number<kM0>{}), {i_m0});
1502
// Run the pipeline. Presumably this computes D = rowsum(dO * O) as in the
// formula block at the top of the file; how p_undrop is applied is decided
// inside the pipeline, which is not visible here.
1503 FmhaBwdOGradDotO{}(o_dram_window, do_dram_window, d_dram_window, kargs.p_undrop);
1504 }
1505};
1506
1507template <typename FmhaBwdConvertQGrad_>
1509{
1511 static constexpr ck_tile::index_t kBlockSize = FmhaBwdConvertQGrad::kBlockSize;
1512 static constexpr ck_tile::index_t kBlockPerCu = FmhaBwdConvertQGrad::kBlockPerCu;
1513 static constexpr ck_tile::index_t kM0 = FmhaBwdConvertQGrad::kM0;
1514 static constexpr ck_tile::index_t kN0 = FmhaBwdConvertQGrad::kN0;
1515 static constexpr ck_tile::index_t kQKHeaddim = FmhaBwdConvertQGrad::kQKHeaddim;
1516
1519
1520 static constexpr bool kIsGroupMode = FmhaBwdConvertQGrad::kIsGroupMode;
1521 static constexpr bool kPadSeqLenQ = FmhaBwdConvertQGrad::kPadSeqLenQ;
1522 static constexpr bool kPadHeadDimQ = FmhaBwdConvertQGrad::kPadHeadDimQ;
1523 static constexpr bool kIsDeterministic = FmhaBwdConvertQGrad::kIsDeterministic;
1524
1525 // clang-format off
// Compile-time map from an element type to the short type-name string
// ("fp32" / "fp16" / "bf16"). Only the three specializations below are
// defined in this listing.
1526 template <typename T> struct t2s;
1527 template <> struct t2s<float> { static constexpr const char * name = "fp32"; };
1528 template <> struct t2s<ck_tile::fp16_t> { static constexpr const char * name = "fp16"; };
1529 template <> struct t2s<ck_tile::bf16_t> { static constexpr const char * name = "bf16"; };
1530 // clang-format on
1531
// Builds the name string of this kernel instantiation. The visible pieces are
//   fmha_bwd_convert_dq_d<kQKHeaddim>_ ... b<kM0>x<kN0>_<group|batch>_o<kBlockPerCu>_<pad>_<det>
// where <pad> is "npad" when no padding flag is set, otherwise "p" followed
// by "s" (kPadSeqLenQ) and/or "d" (kPadHeadDimQ), and <det> is
// "deterministic" or "ndeterministic" according to kIsDeterministic.
// NOTE(review): source line 1546, between the head-dim piece and the tile-size
// piece, is missing from this listing; presumably it appends the data-type
// name - confirm against the real header.
1532 CK_TILE_HOST static std::string GetName()
1533 {
1534 // sync with generate.py
1535 // clang-format off
1536
// Local shorthands, undefined again before the function ends.
1537 #define _SS_ std::string
1538 #define _TS_ std::to_string
// pn: padding suffix; stays empty when neither padding flag is set.
1539 auto pn = [&] () {
1540 std::string n;
1541 if (kPadSeqLenQ) n += "s";
1542 if (kPadHeadDimQ) n += "d";
1543 return n.empty() ? n : std::string("p") + n; }();
1544 return
1545 _SS_("fmha_bwd_convert_dq_d") + _TS_(kQKHeaddim) + "_"
1547 + "b" + _TS_(kM0) + "x" + _TS_(kN0) + "_"
1548 + (kIsGroupMode ? "group" : "batch") + "_"
1549 + ("o" + _TS_(kBlockPerCu)) + (pn.empty() ? "_npad" : "_" + pn)
1550 + (kIsDeterministic ? "_deterministic" : "_ndeterministic") ;
1551 #undef _SS_
1552 #undef _TS_
1553 // clang-format on
1554 }
1555
1556 // to avoid the duplicated base class problem, introduce a template arg
1557 template <ck_tile::index_t I>
1559 {
1560 };
1561
1562 // kargs use aggregate initialization, so no constructor is provided
1563 // use inheritance to minimize kargs size
1564 // users need to call the MakeKargs() function to create kargs.
1579
1584
1587 std::conditional_t<kIsDeterministic,
1588 FmhaBwdConvertQGradDeterministicKargs,
1589 FmhaBwdConvertQGradEmptyKargs<0>>
1590 {
1593 };
1594
1597 std::conditional_t<kIsDeterministic,
1598 FmhaBwdConvertQGradDeterministicKargs,
1599 FmhaBwdConvertQGradEmptyKargs<0>>
1600 {
1603 const int32_t* seqlen_q_ptr; // per-batch actual length [batch]
1604 const int32_t* seqlen_k_ptr; // per-batch actual length [batch]
1605 const int32_t* cu_seqlen_q_ptr; // cumulative seqlen [batch+1], optional
1606 const int32_t* cu_seqlen_k_ptr; // cumulative seqlen [batch+1], optional
1607 };
1608
1609 using Kargs = std::conditional_t<kIsGroupMode,
1612
// Batch-mode overload of MakeKargs; it is only enabled when kIsGroupMode is
// false (Cond defaults to !kIsGroupMode).
// Packs the host-side arguments into a Kargs aggregate and returns it by value.
// split_stride_dq_acc is stored only when kIsDeterministic is true; otherwise
// the argument is ignored.
1613 template <bool Cond = !kIsGroupMode>
1614 CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
1615 MakeKargs(const void* dq_acc_ptr,
1616 void* dq_ptr,
1617 ck_tile::index_t seqlen_q,
1618 ck_tile::index_t seqlen_k,
1619 ck_tile::index_t hdim_q,
1620 ck_tile::index_t stride_dq,
1621 ck_tile::index_t stride_dq_acc,
1622 ck_tile::index_t nhead_stride_dq,
1623 ck_tile::index_t nhead_stride_dq_acc,
1624 ck_tile::index_t batch_stride_dq,
1625 ck_tile::index_t batch_stride_dq_acc,
1626 ck_tile::index_t split_stride_dq_acc)
1627 {
// The first brace list fills the leading sub-aggregate. The empty {} is a
// placeholder for the second sub-aggregate (presumably the deterministic or
// empty kargs base), which is filled in below when needed. The trailing
// values are the batch strides.
1628 Kargs kargs{{dq_acc_ptr,
1629 dq_ptr,
1630 seqlen_q,
1631 seqlen_k,
1632 hdim_q,
1633 stride_dq,
1634 stride_dq_acc,
1635 nhead_stride_dq,
1636 nhead_stride_dq_acc},
1637 {},
1638 batch_stride_dq,
1639 batch_stride_dq_acc};
1640
// The split stride is assigned only in the deterministic variant.
1641 if constexpr(kIsDeterministic)
1642 {
1643 kargs.split_stride_dq_acc = split_stride_dq_acc;
1644 }
1645
1646 return kargs;
1647 }
1648
// Group-mode overload of MakeKargs; it is only enabled when kIsGroupMode is
// true (Cond defaults to kIsGroupMode).
// Takes no seqlen values: both are set to -1 here and resolved per batch
// inside operator() from the seqstart / seqlen / cu_seqlen pointers.
// operator() null-checks the seqlen_* and cu_seqlen_* pointers, so those may
// be null; seqstart_q_ptr is dereferenced unconditionally, and seqstart_k_ptr
// is dereferenced when kIsDeterministic is true.
1649 template <bool Cond = kIsGroupMode>
1650 CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
1651 MakeKargs(const void* dq_acc_ptr,
1652 void* dq_ptr,
1653 const void* seqstart_q_ptr,
1654 const void* seqstart_k_ptr,
1655 const void* seqlen_q_ptr,
1656 const void* seqlen_k_ptr,
1657 const void* cu_seqlen_q_ptr,
1658 const void* cu_seqlen_k_ptr,
1659 ck_tile::index_t hdim_q,
1660 ck_tile::index_t stride_dq,
1661 ck_tile::index_t stride_dq_acc,
1662 ck_tile::index_t nhead_stride_dq,
1663 ck_tile::index_t nhead_stride_dq_acc,
1664 ck_tile::index_t split_stride_dq_acc)
1665 {
// The two -1 values are placeholders for seqlen_q and seqlen_k, in the same
// positions the batch-mode overload fills. The untyped pointers are
// reinterpreted as const int32_t* for the trailing group-mode fields.
1666 Kargs kargs{{dq_acc_ptr,
1667 dq_ptr,
1668 -1, // seqlen will be updated by another pointer
1669 -1, //
1670 hdim_q,
1671 stride_dq,
1672 stride_dq_acc,
1673 nhead_stride_dq,
1674 nhead_stride_dq_acc},
1675 {},
1676 reinterpret_cast<const int32_t*>(seqstart_q_ptr),
1677 reinterpret_cast<const int32_t*>(seqstart_k_ptr),
1678 reinterpret_cast<const int32_t*>(seqlen_q_ptr),
1679 reinterpret_cast<const int32_t*>(seqlen_k_ptr),
1680 reinterpret_cast<const int32_t*>(cu_seqlen_q_ptr),
1681 reinterpret_cast<const int32_t*>(cu_seqlen_k_ptr)};
1682
// The split stride is assigned only in the deterministic variant.
1683 if constexpr(kIsDeterministic)
1684 {
1685 kargs.split_stride_dq_acc = split_stride_dq_acc;
1686 }
1687
1688 return kargs;
1689 }
1690
// Launch grid for this kernel: x = ceil(seqlen_q_ / kM0), y = nhead_,
// z = batch_size_. GetTileIndex() reads the block indices in the same order.
// NOTE(review): the line carrying the function name and parameter list (1692)
// is missing from this listing; the documentation index at the end of the page
// lists it as GridSize(batch_size_, nhead_, seqlen_q_), all ck_tile::index_t.
1691 CK_TILE_HOST static constexpr auto
1693 {
1694 return dim3(ck_tile::integer_divide_ceil(seqlen_q_, kM0), nhead_, batch_size_);
1695 }
1696
// Returns the tuple (blockIdx.x, blockIdx.y, blockIdx.z), named here
// (i_block, i_nhead, i_batch). This mirrors the x/y/z layout built by GridSize().
1697 CK_TILE_DEVICE static constexpr auto GetTileIndex()
1698 {
1699 const index_t i_block = blockIdx.x;
1700 const index_t i_nhead = blockIdx.y;
1701 const index_t i_batch = blockIdx.z;
1702
1703 return ck_tile::make_tuple(i_block, i_nhead, i_batch);
1704 }
1705
// Thread-block dimensions for the launch: a dim3 built from kBlockSize alone.
1706 CK_TILE_HOST static dim3 BlockSize() { return dim3(kBlockSize); }
1707
// Always reports a shared-memory size of 0 for this kernel.
1708 CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() { return 0; }
1709
// Device entry point of the dQ conversion kernel body.
// NOTE(review): the signature line (1710) is missing from this listing; the
// documentation index at the end of the page lists it as
// `CK_TILE_DEVICE void operator()(Kargs kargs) const`.
// Steps: locate this block's (tile, head, batch), resolve per-batch offsets
// and sequence lengths, build the dQ-accumulator and dQ tensor views and tile
// windows, then invoke the FmhaBwdConvertQGrad pipeline. The deterministic
// variant additionally passes the number of key splits.
1711 {
1712 // divide problem
1713 const auto [i_tile_m, i_nhead, i_batch] = GetTileIndex();
1714
// First query row handled by this block, passed through amd_wave_read_first_lane.
1715 const index_t i_m0 = amd_wave_read_first_lane(i_tile_m * kM0);
1716
1717 long_index_t batch_offset_dq = 0;
1718 long_index_t batch_offset_dq_acc = 0;
1719 if constexpr(kIsGroupMode)
1720 {
1721 // get starting offset for each batch
1722 const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
1723 batch_offset_dq = query_start * kargs.stride_dq;
1724 batch_offset_dq_acc = query_start * kargs.stride_dq_acc;
1725
// Query length priority: cu_seqlen_q_ptr, then seqlen_q_ptr, then the
// physical length taken from consecutive seqstart_q_ptr entries.
1726 if(kargs.cu_seqlen_q_ptr != nullptr)
1727 {
1728 kargs.seqlen_q =
1729 kargs.cu_seqlen_q_ptr[i_batch + 1] - kargs.cu_seqlen_q_ptr[i_batch];
1730 }
1731 else
1732 {
1733 // get real # queries under group mode
1734 const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
1735 const ck_tile::index_t physical_seqlen_q =
1736 adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
1737 kargs.seqlen_q = kargs.seqlen_q_ptr
1738 ? static_cast<ck_tile::index_t>(kargs.seqlen_q_ptr[i_batch])
1739 : physical_seqlen_q;
1740 }
1741
// The key length is resolved only in the deterministic variant, where it is
// used below to compute nsplits.
1742 if constexpr(kIsDeterministic)
1743 {
1744 const auto adjusted_seqstart_k_ptr = kargs.seqstart_k_ptr + i_batch;
1745 const ck_tile::index_t physical_seqlen_k =
1746 adjusted_seqstart_k_ptr[1] - adjusted_seqstart_k_ptr[0];
1747
1748 // Priority: cu_seqlen_k_ptr > seqlen_k_ptr > physical_seqlen_k
1749 if(kargs.cu_seqlen_k_ptr != nullptr)
1750 {
1751 kargs.seqlen_k =
1752 kargs.cu_seqlen_k_ptr[i_batch + 1] - kargs.cu_seqlen_k_ptr[i_batch];
1753 }
1754 else
1755 {
1756 kargs.seqlen_k =
1757 kargs.seqlen_k_ptr
1758 ? static_cast<ck_tile::index_t>(kargs.seqlen_k_ptr[i_batch])
1759 : physical_seqlen_k;
1760 }
1761 }
1762 // the # of required blocks differs between groups; terminate unnecessary blocks
1763 // early
1764 if(kargs.seqlen_q <= i_m0)
1765 {
1766 return;
1767 }
1768 }
1769 else
1770 {
// Batch mode: offsets are plain batch strides, widened to long_index_t before multiplying.
1771 batch_offset_dq = static_cast<long_index_t>(i_batch) * kargs.batch_stride_dq;
1772 batch_offset_dq_acc = static_cast<long_index_t>(i_batch) * kargs.batch_stride_dq_acc;
1773 }
1774
1775 // for simplicity, the head and batch offsets are folded into the base pointer
1776 QGradDataType* dq_ptr = reinterpret_cast<QGradDataType*>(kargs.dq_ptr) +
1777 static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_dq +
1778 batch_offset_dq;
1779
1780 // dQAcc/dQ DRAM and DRAM window
// NOTE(review): in the view-building lambdas below, the line that opens each
// naive tensor-view construction and some argument lines are missing from
// this listing (e.g. 1791, 1795, 1798-1799, 1808, 1812, 1815-1816); only the
// visible arguments are described here.
1781 const auto dq_acc_dram = [&, i_nhead_ = i_nhead]() {
1782 if constexpr(kIsDeterministic)
1783 {
1784 const AccDataType* dq_acc_ptr =
1785 reinterpret_cast<const AccDataType*>(kargs.dq_acc_ptr) +
1786 static_cast<long_index_t>(i_nhead_) * (kargs.nhead_stride_dq_acc) +
1787 batch_offset_dq_acc;
1788
// One accumulator slice per kN0-wide chunk of the key sequence.
1789 const index_t nsplits = ck_tile::integer_divide_ceil(kargs.seqlen_k, kN0);
1790
// 3-D view: lengths (nsplits, seqlen_q, hdim_q), strides
// (split_stride_dq_acc, stride_dq_acc, 1).
1792 dq_acc_ptr,
1793 make_tuple(nsplits, kargs.seqlen_q, kargs.hdim_q),
1794 make_tuple(kargs.split_stride_dq_acc, kargs.stride_dq_acc, 1),
1796 number<1>{});
1797 return pad_tensor_view(dq_acc_dram_naive,
1800 }
1801 else
1802 {
1803 const AccDataType* dq_acc_ptr =
1804 reinterpret_cast<const AccDataType*>(kargs.dq_acc_ptr) +
1805 static_cast<long_index_t>(i_nhead_) * (kargs.nhead_stride_dq_acc) +
1806 batch_offset_dq_acc;
1807
// 2-D view: lengths (seqlen_q, hdim_q), strides (stride_dq_acc, 1).
1809 dq_acc_ptr,
1810 make_tuple(kargs.seqlen_q, kargs.hdim_q),
1811 make_tuple(kargs.stride_dq_acc, 1),
1813 number<1>{});
1814 return pad_tensor_view(dq_acc_dram_naive,
1817 }
1818 }();
1819
// Output dQ view: lengths (seqlen_q, hdim_q), strides (stride_dq, 1).
1820 auto dq_dram = [&]() {
1822 dq_ptr,
1823 make_tuple(kargs.seqlen_q, kargs.hdim_q),
1824 make_tuple(kargs.stride_dq, 1),
1826 number<1>{});
1827 return pad_tensor_view(dq_dram_naive,
1830 }();
1831
// The accumulator window starts at split 0 / row i_m0 in the deterministic
// variant, and at row i_m0 otherwise.
1832 auto dq_acc_dram_window = [&]() {
1833 if constexpr(kIsDeterministic)
1834 {
1835 return make_tile_window(
1836 dq_acc_dram,
1838 {0, i_m0, 0});
1839 }
1840 else
1841 {
1842 return make_tile_window(
1843 dq_acc_dram, make_tuple(number<kM0>{}, number<kQKHeaddim>{}), {i_m0, 0});
1844 }
1845 }();
1846
1847 auto dq_dram_window =
1848 make_tile_window(dq_dram, make_tuple(number<kM0>{}, number<kQKHeaddim>{}), {i_m0, 0});
1849
// Run the pipeline; the deterministic variant also receives nsplits,
// recomputed here with the same formula as inside the view lambda.
1850 if constexpr(kIsDeterministic)
1851 {
1852 const index_t nsplits = ck_tile::integer_divide_ceil(kargs.seqlen_k, kN0);
1853 FmhaBwdConvertQGrad{}(dq_acc_dram_window, dq_dram_window, nsplits);
1854 }
1855 else
1856 {
1857 FmhaBwdConvertQGrad{}(dq_acc_dram_window, dq_dram_window);
1858 }
1859 }
1860};
1861
1862} // namespace ck_tile
#define _TS_
#define _SS_
#define CK_TILE_DEVICE
Definition config.hpp:41
#define CK_TILE_HOST
Definition config.hpp:40
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
CK_TILE_DEVICE void insert_dummy_dep()
Definition tile/core/arch/amd_buffer_addressing.hpp:1037
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_HOST_DEVICE constexpr auto make_naive_tensor_view(DataType *__restrict__ p, const tuple< Lengths... > &lengths, const tuple< Strides... > &strides, number< GuaranteedLastDimensionVectorLength >=number<-1 >{}, number< GuaranteedLastDimensionVectorStride >=number<-1 >{})
Definition tensor_view.hpp:471
constexpr auto conditional_expr(X &&x, Y &&y)
Definition tile/core/utility/functional.hpp:220
__device__ uint32_t amd_wave_read_first_lane(uint16_t v)
Definition tile/core/arch/amd_buffer_addressing.hpp:35
CK_TILE_HOST_DEVICE constexpr auto make_generic_attention_mask_from_lr_window(index_t left_size, index_t right_size, index_t y_total, index_t x_total, bool is_top_left=true)
Definition block_masking.hpp:632
@ atomic_add
Definition arch.hpp:58
@ set
Definition arch.hpp:57
@ ALIBI
Definition block_attention_bias_enum.hpp:15
@ NO_BIAS
Definition block_attention_bias_enum.hpp:13
@ ELEMENTWISE_BIAS
Definition block_attention_bias_enum.hpp:14
bfloat16_t bf16_t
Definition bfloat16.hpp:113
_Float16 fp16_t
Definition half.hpp:110
constexpr T log2e_v
Definition tile/core/numeric/math.hpp:488
int64_t long_index_t
Definition integer.hpp:11
CK_TILE_HOST_DEVICE auto make_alibi_from_lr_mask(DataType slope, index_t window_left_size, index_t window_right_size, index_t y_total, index_t x_total, GenericAttentionMaskEnum mask_enum)
Definition block_position_encoding.hpp:148
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_DEVICE constexpr auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition null_tile_window.hpp:75
CK_TILE_HOST_DEVICE constexpr auto integer_divide_ceil(X x, Y y)
Definition tile/core/numeric/math.hpp:149
CK_TILE_HOST_DEVICE constexpr auto pad_tensor_view(const TensorView &tensor_view, const TileLengths &tile_lengths, DoPads)
Definition tensor_view.hpp:530
int32_t int32_t
Definition integer.hpp:10
CK_TILE_DEVICE constexpr auto make_null_tile_window(const WindowLengths &window_lengths)
Definition null_tile_window.hpp:66
GenericAttentionMaskEnum
Definition block_masking.hpp:11
@ MASK_FROM_TOP_LEFT
Definition block_masking.hpp:15
CK_TILE_HOST_DEVICE constexpr T max(T x)
Definition tile/core/numeric/math.hpp:161
int32_t index_t
Definition integer.hpp:9
@ FROM_BOTTOM_RIGHT
Definition block_position_encoding.hpp:43
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
CK_TILE_HOST_DEVICE constexpr auto make_naive_tensor_view_packed(DataType *__restrict__ p, const tuple< Lengths... > &lengths, number< GuaranteedLastDimensionVectorLength >=number<-1 >{})
Definition tensor_view.hpp:494
CK_TILE_HOST bool is_wave32()
Definition arch.hpp:72
unsigned char uint8_t
Definition stdint.h:124
unsigned __int64 uint64_t
Definition stdint.h:136
Definition block_position_encoding.hpp:48
Definition block_attention_bias_enum.hpp:19
Definition block_position_encoding.hpp:137
ck_tile::index_t batch_stride_dq
Definition fmha_bwd_kernel.hpp:1591
ck_tile::index_t batch_stride_dq_acc
Definition fmha_bwd_kernel.hpp:1592
ck_tile::index_t seqlen_q
Definition fmha_bwd_kernel.hpp:1570
ck_tile::index_t stride_dq_acc
Definition fmha_bwd_kernel.hpp:1575
ck_tile::index_t nhead_stride_dq
Definition fmha_bwd_kernel.hpp:1576
ck_tile::index_t hdim_q
Definition fmha_bwd_kernel.hpp:1572
ck_tile::index_t stride_dq
Definition fmha_bwd_kernel.hpp:1574
const void * dq_acc_ptr
Definition fmha_bwd_kernel.hpp:1567
ck_tile::index_t seqlen_k
Definition fmha_bwd_kernel.hpp:1571
ck_tile::index_t nhead_stride_dq_acc
Definition fmha_bwd_kernel.hpp:1577
ck_tile::index_t split_stride_dq_acc
Definition fmha_bwd_kernel.hpp:1582
const int32_t * seqstart_k_ptr
Definition fmha_bwd_kernel.hpp:1602
const int32_t * cu_seqlen_q_ptr
Definition fmha_bwd_kernel.hpp:1605
const int32_t * seqstart_q_ptr
Definition fmha_bwd_kernel.hpp:1601
const int32_t * seqlen_q_ptr
Definition fmha_bwd_kernel.hpp:1603
const int32_t * cu_seqlen_k_ptr
Definition fmha_bwd_kernel.hpp:1606
const int32_t * seqlen_k_ptr
Definition fmha_bwd_kernel.hpp:1604
static constexpr const char * name
Definition fmha_bwd_kernel.hpp:1529
static constexpr const char * name
Definition fmha_bwd_kernel.hpp:1528
static constexpr const char * name
Definition fmha_bwd_kernel.hpp:1527
Definition fmha_bwd_kernel.hpp:1526
Definition fmha_bwd_kernel.hpp:1509
static constexpr bool kIsGroupMode
Definition fmha_bwd_kernel.hpp:1520
static constexpr bool kPadHeadDimQ
Definition fmha_bwd_kernel.hpp:1522
static constexpr bool kIsDeterministic
Definition fmha_bwd_kernel.hpp:1523
static constexpr ck_tile::index_t kBlockSize
Definition fmha_bwd_kernel.hpp:1511
static constexpr bool kPadSeqLenQ
Definition fmha_bwd_kernel.hpp:1521
static CK_TILE_HOST constexpr std::enable_if_t< Cond, Kargs > MakeKargs(const void *dq_acc_ptr, void *dq_ptr, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k, ck_tile::index_t hdim_q, ck_tile::index_t stride_dq, ck_tile::index_t stride_dq_acc, ck_tile::index_t nhead_stride_dq, ck_tile::index_t nhead_stride_dq_acc, ck_tile::index_t batch_stride_dq, ck_tile::index_t batch_stride_dq_acc, ck_tile::index_t split_stride_dq_acc)
Definition fmha_bwd_kernel.hpp:1615
static constexpr ck_tile::index_t kBlockPerCu
Definition fmha_bwd_kernel.hpp:1512
CK_TILE_DEVICE void operator()(Kargs kargs) const
Definition fmha_bwd_kernel.hpp:1710
ck_tile::remove_cvref_t< typename FmhaBwdConvertQGrad::AccDataType > AccDataType
Definition fmha_bwd_kernel.hpp:1517
static CK_TILE_HOST dim3 BlockSize()
Definition fmha_bwd_kernel.hpp:1706
static constexpr ck_tile::index_t kM0
Definition fmha_bwd_kernel.hpp:1513
ck_tile::remove_cvref_t< typename FmhaBwdConvertQGrad::QGradDataType > QGradDataType
Definition fmha_bwd_kernel.hpp:1518
static constexpr ck_tile::index_t kN0
Definition fmha_bwd_kernel.hpp:1514
static CK_TILE_HOST_DEVICE constexpr ck_tile::index_t GetSmemSize()
Definition fmha_bwd_kernel.hpp:1708
static CK_TILE_HOST constexpr std::enable_if_t< Cond, Kargs > MakeKargs(const void *dq_acc_ptr, void *dq_ptr, const void *seqstart_q_ptr, const void *seqstart_k_ptr, const void *seqlen_q_ptr, const void *seqlen_k_ptr, const void *cu_seqlen_q_ptr, const void *cu_seqlen_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t stride_dq, ck_tile::index_t stride_dq_acc, ck_tile::index_t nhead_stride_dq, ck_tile::index_t nhead_stride_dq_acc, ck_tile::index_t split_stride_dq_acc)
Definition fmha_bwd_kernel.hpp:1651
ck_tile::remove_cvref_t< FmhaBwdConvertQGrad_ > FmhaBwdConvertQGrad
Definition fmha_bwd_kernel.hpp:1510
static CK_TILE_HOST constexpr auto GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_q_)
Definition fmha_bwd_kernel.hpp:1692
static constexpr ck_tile::index_t kQKHeaddim
Definition fmha_bwd_kernel.hpp:1515
std::conditional_t< kIsGroupMode, FmhaBwdConvertQGradGroupModeKargs, FmhaBwdConvertQGradBatchModeKargs > Kargs
Definition fmha_bwd_kernel.hpp:1609
static CK_TILE_DEVICE constexpr auto GetTileIndex()
Definition fmha_bwd_kernel.hpp:1697
static CK_TILE_HOST std::string GetName()
Definition fmha_bwd_kernel.hpp:1532
Definition fmha_bwd_kernel.hpp:192
const void * alibi_slope_ptr
Definition fmha_bwd_kernel.hpp:194
ck_tile::index_t alibi_slope_stride
Definition fmha_bwd_kernel.hpp:195
ck_tile::index_t batch_stride_dbias
Definition fmha_bwd_kernel.hpp:207
ck_tile::index_t batch_stride_bias
Definition fmha_bwd_kernel.hpp:188
ck_tile::index_t batch_stride_randval
Definition fmha_bwd_kernel.hpp:272
Definition fmha_bwd_kernel.hpp:291
ck_tile::index_t batch_stride_v
Definition fmha_bwd_kernel.hpp:294
ck_tile::index_t batch_stride_k
Definition fmha_bwd_kernel.hpp:293
ck_tile::index_t batch_stride_q
Definition fmha_bwd_kernel.hpp:292
ck_tile::index_t batch_stride_do
Definition fmha_bwd_kernel.hpp:295
ck_tile::index_t batch_stride_dq_acc
Definition fmha_bwd_kernel.hpp:297
ck_tile::index_t batch_stride_dk
Definition fmha_bwd_kernel.hpp:298
ck_tile::index_t batch_stride_dv
Definition fmha_bwd_kernel.hpp:299
ck_tile::index_t batch_stride_lsed
Definition fmha_bwd_kernel.hpp:296
ck_tile::index_t nhead_stride_dbias
Definition fmha_bwd_kernel.hpp:202
void * dbias_ptr
Definition fmha_bwd_kernel.hpp:200
ck_tile::index_t stride_dbias
Definition fmha_bwd_kernel.hpp:201
ck_tile::index_t stride_bias
Definition fmha_bwd_kernel.hpp:182
ck_tile::index_t nhead_stride_bias
Definition fmha_bwd_kernel.hpp:183
const void * bias_ptr
Definition fmha_bwd_kernel.hpp:181
uint8_t p_undrop_in_uint8_t
Definition fmha_bwd_kernel.hpp:263
float rp_undrop
Definition fmha_bwd_kernel.hpp:261
ck_tile::index_t nhead_stride_randval
Definition fmha_bwd_kernel.hpp:267
void init_dropout(float p_drop, const uint64_t *seed_ptr, const uint64_t *offset_ptr, float raw_scale)
Definition fmha_bwd_kernel.hpp:245
float scale_rp_undrop
Definition fmha_bwd_kernel.hpp:262
void * rand_val_ptr
Definition fmha_bwd_kernel.hpp:264
ck_tile::index_t stride_randval
Definition fmha_bwd_kernel.hpp:266
void init_dropout(float p_drop, uint64_t seed, uint64_t offset, float raw_scale)
Definition fmha_bwd_kernel.hpp:232
Definition fmha_bwd_kernel.hpp:138
ck_tile::index_t nhead_stride_dk
Definition fmha_bwd_kernel.hpp:175
ck_tile::index_t stride_do
Definition fmha_bwd_kernel.hpp:164
ck_tile::index_t seqlen_k
Definition fmha_bwd_kernel.hpp:150
const void * q_ptr
Definition fmha_bwd_kernel.hpp:139
ck_tile::index_t hdim_q
Definition fmha_bwd_kernel.hpp:151
ck_tile::index_t nhead_stride_do
Definition fmha_bwd_kernel.hpp:172
ck_tile::index_t num_head_q
Definition fmha_bwd_kernel.hpp:156
const void * lse_ptr
Definition fmha_bwd_kernel.hpp:142
float raw_scale
Definition fmha_bwd_kernel.hpp:158
ck_tile::index_t nhead_stride_k
Definition fmha_bwd_kernel.hpp:170
ck_tile::index_t nhead_stride_q
Definition fmha_bwd_kernel.hpp:169
ck_tile::index_t stride_dv
Definition fmha_bwd_kernel.hpp:167
ck_tile::index_t nhead_stride_lsed
Definition fmha_bwd_kernel.hpp:173
void * dq_acc_ptr
Definition fmha_bwd_kernel.hpp:145
ck_tile::index_t stride_q
Definition fmha_bwd_kernel.hpp:161
ck_tile::index_t seqlen_q
Definition fmha_bwd_kernel.hpp:149
ck_tile::index_t stride_dk
Definition fmha_bwd_kernel.hpp:166
const void * do_ptr
Definition fmha_bwd_kernel.hpp:143
float scale
Definition fmha_bwd_kernel.hpp:159
void * dk_ptr
Definition fmha_bwd_kernel.hpp:146
ck_tile::index_t nhead_stride_v
Definition fmha_bwd_kernel.hpp:171
ck_tile::index_t stride_v
Definition fmha_bwd_kernel.hpp:163
const void * d_ptr
Definition fmha_bwd_kernel.hpp:144
const void * k_ptr
Definition fmha_bwd_kernel.hpp:140
ck_tile::index_t nhead_stride_dq_acc
Definition fmha_bwd_kernel.hpp:174
ck_tile::index_t nhead_ratio_qk
Definition fmha_bwd_kernel.hpp:157
void * dv_ptr
Definition fmha_bwd_kernel.hpp:147
const void * v_ptr
Definition fmha_bwd_kernel.hpp:141
ck_tile::index_t hdim_v
Definition fmha_bwd_kernel.hpp:152
ck_tile::index_t stride_k
Definition fmha_bwd_kernel.hpp:162
ck_tile::index_t nhead_stride_dv
Definition fmha_bwd_kernel.hpp:176
ck_tile::index_t stride_dq_acc
Definition fmha_bwd_kernel.hpp:165
ck_tile::index_t split_stride_dq_acc
Definition fmha_bwd_kernel.hpp:277
bool is_drop_seed_offset_from_host
Definition fmha_bwd_kernel.hpp:227
ValueOrPointer< uint64_t > drop_seed
Definition fmha_bwd_kernel.hpp:225
ValueOrPointer< uint64_t > drop_offset
Definition fmha_bwd_kernel.hpp:226
Definition fmha_bwd_kernel.hpp:131
Definition fmha_bwd_kernel.hpp:313
const int32_t * seqstart_k_ptr
Definition fmha_bwd_kernel.hpp:315
const int32_t * seqstart_q_ptr
Definition fmha_bwd_kernel.hpp:314
const int32_t * seqlen_k_ptr
Definition fmha_bwd_kernel.hpp:317
const int32_t * cu_seqlen_k_ptr
Definition fmha_bwd_kernel.hpp:319
const int32_t * seqlen_q_ptr
Definition fmha_bwd_kernel.hpp:316
const int32_t * cu_seqlen_q_ptr
Definition fmha_bwd_kernel.hpp:318
Definition fmha_bwd_kernel.hpp:211
ck_tile::GenericAttentionMaskEnum mask_type
Definition fmha_bwd_kernel.hpp:213
ck_tile::index_t window_size_right
Definition fmha_bwd_kernel.hpp:212
ck_tile::index_t window_size_left
Definition fmha_bwd_kernel.hpp:212
static constexpr const char * name
Definition fmha_bwd_kernel.hpp:87
static constexpr const char * name
Definition fmha_bwd_kernel.hpp:86
static constexpr const char * name
Definition fmha_bwd_kernel.hpp:85
Definition fmha_bwd_kernel.hpp:84
Definition fmha_bwd_kernel.hpp:35
ck_tile::remove_cvref_t< KGradEpiloguePipeline_ > KGradEpiloguePipeline
Definition fmha_bwd_kernel.hpp:37
ck_tile::remove_cvref_t< typename FmhaPipeline::RandValOutputDataType > RandValOutputDataType
Definition fmha_bwd_kernel.hpp:55
static constexpr auto BiasEnum
Definition fmha_bwd_kernel.hpp:66
static CK_TILE_HOST std::string GetName()
Definition fmha_bwd_kernel.hpp:90
ck_tile::remove_cvref_t< typename FmhaPipeline::FmhaDropout > FmhaDropout
Definition fmha_bwd_kernel.hpp:69
ck_tile::remove_cvref_t< typename FmhaPipeline::OGradDataType > OGradDataType
Definition fmha_bwd_kernel.hpp:57
static CK_TILE_HOST constexpr std::enable_if_t< Cond, Kargs > MakeKargsImpl(const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, const void *lse_ptr, const void *do_ptr, const void *d_ptr, void *rand_val_ptr, void *dk_ptr, void *dv_ptr, void *dbias_ptr, void *dq_acc_ptr, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, float scale, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_randval, ck_tile::index_t stride_do, ck_tile::index_t stride_dq_acc, ck_tile::index_t stride_dk, ck_tile::index_t stride_dv, ck_tile::index_t stride_dbias, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_do, ck_tile::index_t nhead_stride_lsed, ck_tile::index_t nhead_stride_dq_acc, ck_tile::index_t nhead_stride_dk, ck_tile::index_t nhead_stride_dv, ck_tile::index_t nhead_stride_dbias, ck_tile::index_t batch_stride_q, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, ck_tile::index_t batch_stride_bias, ck_tile::index_t batch_stride_randval, ck_tile::index_t batch_stride_do, ck_tile::index_t batch_stride_lsed, ck_tile::index_t batch_stride_dq_acc, ck_tile::index_t batch_stride_dk, ck_tile::index_t batch_stride_dv, ck_tile::index_t batch_stride_dbias, ck_tile::index_t split_stride_dq_acc, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, float p_drop, std::variant< std::pair< uint64_t, uint64_t >, std::pair< const void *, const void * > > drop_seed_offset)
Definition fmha_bwd_kernel.hpp:344
static constexpr bool kUseQrQtrDorPipeline
Definition fmha_bwd_kernel.hpp:42
ck_tile::remove_cvref_t< VGradEpiloguePipeline_ > VGradEpiloguePipeline
Definition fmha_bwd_kernel.hpp:38
ck_tile::remove_cvref_t< typename FmhaPipeline::BiasGradDataType > BiasGradDataType
Definition fmha_bwd_kernel.hpp:61
static constexpr ck_tile::index_t kBlockSize
Definition fmha_bwd_kernel.hpp:40
static constexpr bool kHasMask
Definition fmha_bwd_kernel.hpp:70
static CK_TILE_HOST dim3 BlockSize()
Definition fmha_bwd_kernel.hpp:680
ck_tile::remove_cvref_t< typename FmhaPipeline::KDataType > KDataType
Definition fmha_bwd_kernel.hpp:48
static constexpr bool kIsGroupMode
Definition fmha_bwd_kernel.hpp:63
ck_tile::remove_cvref_t< typename FmhaPipeline::QGradDataType > QGradDataType
Definition fmha_bwd_kernel.hpp:58
static CK_TILE_HOST_DEVICE constexpr ck_tile::index_t GetSmemSize()
Definition fmha_bwd_kernel.hpp:692
ck_tile::remove_cvref_t< typename FmhaPipeline::DDataType > DDataType
Definition fmha_bwd_kernel.hpp:54
ck_tile::remove_cvref_t< typename FmhaPipeline::QDataType > QDataType
Definition fmha_bwd_kernel.hpp:47
ck_tile::remove_cvref_t< typename FmhaPipeline::LSEDataType > LSEDataType
Definition fmha_bwd_kernel.hpp:52
ck_tile::remove_cvref_t< typename FmhaPipeline::VGradDataType > VGradDataType
Definition fmha_bwd_kernel.hpp:60
ck_tile::remove_cvref_t< typename FmhaPipeline::GemmDataType > GemmDataType
Definition fmha_bwd_kernel.hpp:51
CK_TILE_DEVICE void operator()(Kargs kargs) const
Definition fmha_bwd_kernel.hpp:699
static constexpr bool kIsAvailable
Definition fmha_bwd_kernel.hpp:80
ck_tile::remove_cvref_t< QGradEpiloguePipeline_ > QGradEpiloguePipeline
Definition fmha_bwd_kernel.hpp:39
static CK_TILE_HOST constexpr Kargs MakeKargs(Ts... args, const std::tuple< const void *, const void * > &drop_seed_offset)
Definition fmha_bwd_kernel.hpp:336
static constexpr bool kHasDropout
Definition fmha_bwd_kernel.hpp:71
ck_tile::remove_cvref_t< typename FmhaPipeline::AccDataType > AccDataType
Definition fmha_bwd_kernel.hpp:53
CK_TILE_DEVICE void run_(Kargs kargs) const
Definition fmha_bwd_kernel.hpp:705
static CK_TILE_HOST constexpr auto GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_k_)
Definition fmha_bwd_kernel.hpp:663
ck_tile::remove_cvref_t< typename FmhaPipeline::VDataType > VDataType
Definition fmha_bwd_kernel.hpp:49
static constexpr bool kHasBiasGrad
Definition fmha_bwd_kernel.hpp:67
static constexpr bool kIsDeterministic
Definition fmha_bwd_kernel.hpp:73
ck_tile::remove_cvref_t< typename FmhaPipeline::BiasDataType > BiasDataType
Definition fmha_bwd_kernel.hpp:50
static constexpr index_t kPadHeadDimQ
Definition fmha_bwd_kernel.hpp:64
ck_tile::remove_cvref_t< FmhaPipeline_ > FmhaPipeline
Definition fmha_bwd_kernel.hpp:36
static CK_TILE_HOST constexpr Kargs MakeKargs(Ts... args, const std::tuple< uint64_t, uint64_t > &drop_seed_offset)
Definition fmha_bwd_kernel.hpp:327
ck_tile::remove_cvref_t< typename FmhaPipeline::KGradDataType > KGradDataType
Definition fmha_bwd_kernel.hpp:59
ck_tile::remove_cvref_t< typename FmhaPipeline::FmhaMask > FmhaMask
Definition fmha_bwd_kernel.hpp:68
static constexpr bool kUseTrLoad
Definition fmha_bwd_kernel.hpp:74
static CK_TILE_DEVICE constexpr auto GetTileIndex()
Definition fmha_bwd_kernel.hpp:671
static constexpr index_t kPadHeadDimV
Definition fmha_bwd_kernel.hpp:65
static constexpr ck_tile::index_t kBlockPerCu
Definition fmha_bwd_kernel.hpp:41
std::conditional_t< kIsGroupMode, FmhaBwdGroupModeKargs, FmhaBwdBatchModeKargs > Kargs
Definition fmha_bwd_kernel.hpp:322
static CK_TILE_HOST constexpr std::enable_if_t< Cond, Kargs > MakeKargsImpl(const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, const void *lse_ptr, const void *do_ptr, const void *d_ptr, void *rand_val_ptr, void *dk_ptr, void *dv_ptr, void *dbias_ptr, void *dq_acc_ptr, const void *seqstart_q_ptr, const void *seqstart_k_ptr, const void *seqlen_q_ptr, const void *seqlen_k_ptr, const void *cu_seqlen_q_ptr, const void *cu_seqlen_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, float scale, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_randval, ck_tile::index_t stride_do, ck_tile::index_t stride_dq_acc, ck_tile::index_t stride_dk, ck_tile::index_t stride_dv, ck_tile::index_t stride_dbias, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_do, ck_tile::index_t nhead_stride_lsed, ck_tile::index_t nhead_stride_dq_acc, ck_tile::index_t nhead_stride_dk, ck_tile::index_t nhead_stride_dv, ck_tile::index_t nhead_stride_dbias, ck_tile::index_t split_stride_dq_acc, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, float p_drop, std::variant< std::pair< uint64_t, uint64_t >, std::pair< const void *, const void * > > drop_seed_offset)
Definition fmha_bwd_kernel.hpp:512
static constexpr bool kIsStoreRandval
Definition fmha_bwd_kernel.hpp:72
static constexpr index_t kMaxSeqLenQ
Definition fmha_bwd_kernel.hpp:75
ck_tile::index_t batch_stride_o
Definition fmha_bwd_kernel.hpp:1301
ck_tile::index_t batch_stride_do
Definition fmha_bwd_kernel.hpp:1300
ck_tile::index_t batch_stride_d
Definition fmha_bwd_kernel.hpp:1302
void * d_ptr
Definition fmha_bwd_kernel.hpp:1283
const void * o_ptr
Definition fmha_bwd_kernel.hpp:1281
ck_tile::index_t hdim_v
Definition fmha_bwd_kernel.hpp:1288
ck_tile::index_t nhead_stride_do
Definition fmha_bwd_kernel.hpp:1293
ck_tile::index_t stride_o
Definition fmha_bwd_kernel.hpp:1291
ck_tile::index_t nhead_stride_o
Definition fmha_bwd_kernel.hpp:1294
const void * do_ptr
Definition fmha_bwd_kernel.hpp:1282
ck_tile::index_t stride_do
Definition fmha_bwd_kernel.hpp:1290
ck_tile::index_t seqlen_q
Definition fmha_bwd_kernel.hpp:1287
float p_undrop
Definition fmha_bwd_kernel.hpp:1285
ck_tile::index_t nhead_stride_d
Definition fmha_bwd_kernel.hpp:1295
const int32_t * cu_seqlen_q_ptr
Definition fmha_bwd_kernel.hpp:1309
const int32_t * seqstart_q_ptr
Definition fmha_bwd_kernel.hpp:1307
const int32_t * seqlen_q_ptr
Definition fmha_bwd_kernel.hpp:1308
static constexpr const char * name
Definition fmha_bwd_kernel.hpp:1252
static constexpr const char * name
Definition fmha_bwd_kernel.hpp:1251
static constexpr const char * name
Definition fmha_bwd_kernel.hpp:1250
Definition fmha_bwd_kernel.hpp:1249
Definition fmha_bwd_kernel.hpp:1233
ck_tile::remove_cvref_t< typename FmhaBwdOGradDotO::ODataType > ODataType
Definition fmha_bwd_kernel.hpp:1241
static CK_TILE_HOST constexpr std::enable_if_t< Cond, Kargs > MakeKargs(const void *o_ptr, const void *do_ptr, void *d_ptr, float p_undrop, ck_tile::index_t seqlen_q, ck_tile::index_t hdim_v, ck_tile::index_t stride_do, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_do, ck_tile::index_t nhead_stride_o, ck_tile::index_t nhead_stride_d, ck_tile::index_t batch_stride_do, ck_tile::index_t batch_stride_o, ck_tile::index_t batch_stride_d)
Definition fmha_bwd_kernel.hpp:1317
ck_tile::remove_cvref_t< FmhaBwdOGradDotO_ > FmhaBwdOGradDotO
Definition fmha_bwd_kernel.hpp:1234
static CK_TILE_HOST dim3 BlockSize()
Definition fmha_bwd_kernel.hpp:1399
static constexpr ck_tile::index_t kVHeaddim
Definition fmha_bwd_kernel.hpp:1238
CK_TILE_DEVICE void operator()(Kargs kargs) const
Definition fmha_bwd_kernel.hpp:1403
static CK_TILE_HOST constexpr auto GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_q_)
Definition fmha_bwd_kernel.hpp:1385
static constexpr bool kIsGroupMode
Definition fmha_bwd_kernel.hpp:1244
ck_tile::remove_cvref_t< typename FmhaBwdOGradDotO::OGradDataType > OGradDataType
Definition fmha_bwd_kernel.hpp:1242
static constexpr ck_tile::index_t kM0
Definition fmha_bwd_kernel.hpp:1237
static constexpr ck_tile::index_t kBlockSize
Definition fmha_bwd_kernel.hpp:1235
static constexpr ck_tile::index_t kBlockPerCu
Definition fmha_bwd_kernel.hpp:1236
static CK_TILE_DEVICE constexpr auto GetTileIndex()
Definition fmha_bwd_kernel.hpp:1390
static constexpr bool kPadSeqLenQ
Definition fmha_bwd_kernel.hpp:1245
static CK_TILE_HOST std::string GetName()
Definition fmha_bwd_kernel.hpp:1255
static CK_TILE_HOST_DEVICE constexpr ck_tile::index_t GetSmemSize()
Definition fmha_bwd_kernel.hpp:1401
ck_tile::remove_cvref_t< typename FmhaBwdOGradDotO::DDataType > DDataType
Definition fmha_bwd_kernel.hpp:1240
static constexpr bool kPadHeadDimV
Definition fmha_bwd_kernel.hpp:1246
static CK_TILE_HOST constexpr std::enable_if_t< Cond, Kargs > MakeKargs(const void *o_ptr, const void *do_ptr, void *d_ptr, float p_undrop, const void *seqstart_q_ptr, const void *seqlen_q_ptr, const void *cu_seqlen_q_ptr, ck_tile::index_t hdim_v, ck_tile::index_t stride_do, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_do, ck_tile::index_t nhead_stride_o, ck_tile::index_t nhead_stride_d)
Definition fmha_bwd_kernel.hpp:1352
std:: conditional_t< kIsGroupMode, FmhaBwdOGradDotOGroupModeKargs, FmhaBwdOGradDotOBatchModeKargs > Kargs
Definition fmha_bwd_kernel.hpp:1312
Definition block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp:777
Definition coordinate_transform.hpp:1392
Definition tile/core/container/sequence.hpp:49