device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp Source File

device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp Source File#

Composable Kernel: device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp Source File
device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <iostream>
7#include <sstream>
8#include <numeric>
9#include <initializer_list>
10#include <cstdlib>
11
12#include "ck/ck.hpp"
24
25namespace ck {
26namespace tensor_operation {
27namespace device {
28
29template <typename DeviceOp,
30 typename GridwiseOp,
31 typename ADataType,
32 typename B0DataType,
33 typename B1DataType,
34 typename CDataType,
35 typename AElementwiseOperation,
36 typename B0ElementwiseOperation,
37 typename AccElementwiseOperation,
38 typename B1ElementwiseOperation,
39 typename CElementwiseOperation,
40 bool HasMainKBlockLoop>
41__global__ void
42#if CK_USE_LAUNCH_BOUNDS
44#endif
45 kernel_batched_gemm_softmax_gemm_wmma_cshuffle(const ADataType* __restrict__ p_a_grid,
46 const B0DataType* __restrict__ p_b0_grid,
47 const B1DataType* __restrict__ p_b1_grid,
48 CDataType* __restrict__ p_c_grid,
49 index_t M,
50 index_t N,
51 index_t K,
52 index_t O,
53 index_t G0,
54 index_t G1,
55 float alpha,
56 bool input_permute,
57 bool output_permute)
58{
59#if(defined(__gfx11__) || defined(__gfx12__))
60
61 // clang-format off
62// ***************************************************
63// Make Tensor Descriptors
64 constexpr index_t array_size = 4;
65 std::array<ck::index_t, array_size> a_gs_ms_ks_lengths{G0, G1, M, K};
66 std::array<ck::index_t, array_size> a_gs_ms_ks_strides =
67 input_permute
68 ? std::array<ck::index_t, array_size>{M * G1 * K, K, G1 * K, 1} // A layout [G0, M, G1, K]
69 : std::array<ck::index_t, array_size>{G1 * M * K, M * K, K, 1}; // A layout [G0, G1, M, K]
70
71 std::array<ck::index_t, array_size> b0_gs_ns_ks_lengths{G0, G1, N, K};
72 std::array<ck::index_t, array_size> b0_gs_ns_ks_strides =
73 input_permute
74 ? std::array<ck::index_t, array_size>{N * G1 * K, K, G1 * K, 1} // B0 layout [G0, N, G1, K]
75 : std::array<ck::index_t, array_size>{G1 * N * K, N * K, K, 1}; // B0 layout [G0, G1, N, K]
76
77 std::array<ck::index_t, array_size> b1_gs_os_ns_lengths{G0, G1, O, N};
78 std::array<ck::index_t, array_size> b1_gs_os_ns_strides =
79 input_permute
80 ? std::array<ck::index_t, array_size>{N * G1 * O, O, 1, G1 * O} // B1 layout [G0, N, G1, O]
81 : std::array<ck::index_t, array_size>{G1 * N * O, N * O, 1, O}; // B1 layout [G0, G1, N, O]
82
83 std::array<ck::index_t, array_size> c_gs_ms_os_lengths{G0, G1, M, O};
84 std::array<ck::index_t, array_size> c_gs_ms_os_strides =
85 output_permute
86 ? std::array<ck::index_t, array_size>{M * G1 * O, O, G1 * O, 1} // C layout [G0, M, G1, O]
87 : std::array<ck::index_t, array_size>{G1 * M * O, M * O, O, 1}; // C layout [G0, G1, M, O]
88
89 const auto a_element_op = AElementwiseOperation{};
90 const auto b0_element_op = B0ElementwiseOperation{};
91 const auto acc0_element_op = AccElementwiseOperation{alpha};
92 const auto b1_element_op = B1ElementwiseOperation{};
93 const auto c_element_op = CElementwiseOperation{};
94 // fail to reuse DeviceOp::MakeArgument() because of the __device__ function required.
95
96 const auto a_grid_desc = DeviceOp::MakeAGridDescriptor(a_gs_ms_ks_lengths, a_gs_ms_ks_strides);
97 const auto b0_grid_desc =
98 DeviceOp::MakeB0GridDescriptor(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides);
99 const auto b1_grid_desc =
100 DeviceOp::MakeB1GridDescriptor(b1_gs_os_ns_lengths, b1_gs_os_ns_strides);
101 const auto c_grid_desc_m_n =
102 DeviceOp::Transform::MakeCGridDescriptor_M_N(c_gs_ms_os_lengths, c_gs_ms_os_strides);
103 const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
104 GridwiseOp::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n);
105 const auto block_2_ctile_map = GridwiseOp::MakeDefaultBlock2CTileMap(c_grid_desc_m_n, 1, 1);
106
107 const auto a_grid_desc_g_m_k =
108 DeviceOp::Transform::MakeAGridDescriptor_G_M_K(a_gs_ms_ks_lengths, a_gs_ms_ks_strides);
109 const auto b0_grid_desc_g_l_k =
110 DeviceOp::Transform::MakeB0GridDescriptor_G_N_K(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides);
111 const auto b1_grid_desc_g_n_l =
112 DeviceOp::Transform::MakeB1GridDescriptor_G_N_K(b1_gs_os_ns_lengths, b1_gs_os_ns_strides);
113 const auto c_grid_desc_g_m_n =
114 DeviceOp::Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_os_lengths, c_gs_ms_os_strides);
115 const auto compute_base_ptr_of_batch =
116 typename DeviceOp::ComputeBasePtrOfStridedBatch{a_grid_desc_g_m_k, b0_grid_desc_g_l_k, b1_grid_desc_g_n_l, c_grid_desc_g_m_n};
117 index_t batch_count = c_grid_desc_g_m_n.GetLength(Number<0>{});
118 const auto c0_matrix_mask = typename DeviceOp::C0MatrixMask{b0_grid_desc_g_l_k.GetLength(Number<1>{})};
119
120 // clang-format on
121 __shared__ char p_shared[GridwiseOp::GetSharedMemoryNumberOfByte()];
122 const index_t num_blocks_per_batch =
123 __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
124 const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
125
126 const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
127 static_cast<long_index_t>(compute_base_ptr_of_batch.GetABasePtr(g_idx)));
128 const long_index_t b0_batch_offset = __builtin_amdgcn_readfirstlane(
129 static_cast<long_index_t>(compute_base_ptr_of_batch.GetB0BasePtr(g_idx)));
130 const long_index_t b1_batch_offset = __builtin_amdgcn_readfirstlane(
131 static_cast<long_index_t>(compute_base_ptr_of_batch.GetB1BasePtr(g_idx)));
132 const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
133 static_cast<long_index_t>(compute_base_ptr_of_batch.GetCBasePtr(g_idx)));
134
135 GridwiseOp::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
136 p_b0_grid + b0_batch_offset,
137 p_b1_grid + b1_batch_offset,
138 p_c_grid + c_batch_offset,
139 p_shared,
140 a_grid_desc,
141 b0_grid_desc,
142 b1_grid_desc,
143 c_grid_desc_mblock_mperblock_nblock_nperblock,
144 a_element_op,
145 b0_element_op,
146 acc0_element_op,
147 b1_element_op,
148 c_element_op,
149 c0_matrix_mask,
150 block_2_ctile_map);
151#else
152 ignore = p_a_grid;
153 ignore = p_b0_grid;
154 ignore = p_b1_grid;
155 ignore = p_c_grid;
156 ignore = M;
157 ignore = N;
158 ignore = K;
159 ignore = O;
160 ignore = G0;
161 ignore = G1;
162 ignore = alpha;
163 ignore = input_permute;
164 ignore = output_permute;
165#endif // end of if (defined(__gfx11__))
166}
167
168// Self-Attention
169template <typename DeviceOp,
170 typename GridwiseOp,
171 typename QKVDataType,
172 typename ODataType,
173 typename AElementwiseOperation,
174 typename B0ElementwiseOperation,
175 typename AccElementwiseOperation,
176 typename B1ElementwiseOperation,
177 typename CElementwiseOperation,
178 bool HasMainKBlockLoop>
179__global__ void
180#if CK_USE_LAUNCH_BOUNDS
182#endif
183 kernel_wmma_self_attention_forward(const QKVDataType* __restrict__ p_qkv_grid,
184 ODataType* __restrict__ p_out_grid,
185 index_t batch_size,
186 index_t sequence_length,
187 index_t head_count,
188 index_t head_size,
189 float alpha)
190{
191#if(defined(__gfx11__) || defined(__gfx12__))
192
193 // clang-format off
194// ***************************************************
195// Make Tensor Descriptors
196// o Self-attention(packed QKV): [batchSize, sequenceLength, headCount, 3, headSize]
197 constexpr index_t array_size = 4;
198 std::array<ck::index_t, array_size> qk_gs_ms_ks_lengths{batch_size, head_count, sequence_length, head_size};
199 std::array<ck::index_t, array_size> qk_gs_ms_ks_strides{sequence_length * head_count * 3 * head_size, 3 * head_size, head_count * 3 * head_size, 1};
200
201 std::array<ck::index_t, array_size> v_gs_os_ns_lengths{batch_size, head_count, head_size, sequence_length};
202 std::array<ck::index_t, array_size> v_gs_os_ns_strides{sequence_length * head_count * 3 * head_size, 3 * head_size, 1, head_count * 3 * head_size};
203
204 std::array<ck::index_t, array_size> c_gs_ms_os_lengths{batch_size, head_count, sequence_length, head_size};
205 std::array<ck::index_t, array_size> c_gs_ms_os_strides{sequence_length * head_count * head_size, head_size, head_count * head_size, 1};
206
207
208 const auto a_element_op = AElementwiseOperation{};
209 const auto b0_element_op = B0ElementwiseOperation{};
210 const auto acc0_element_op = AccElementwiseOperation{alpha};
211 const auto b1_element_op = B1ElementwiseOperation{};
212 const auto c_element_op = CElementwiseOperation{};
213
214 const auto a_grid_desc = DeviceOp::MakeAGridDescriptor(qk_gs_ms_ks_lengths, qk_gs_ms_ks_strides);
215 const auto b0_grid_desc =
216 DeviceOp::MakeB0GridDescriptor(qk_gs_ms_ks_lengths, qk_gs_ms_ks_strides);
217 const auto b1_grid_desc =
218 DeviceOp::MakeB1GridDescriptor(v_gs_os_ns_lengths, v_gs_os_ns_strides);
219 const auto c_grid_desc_m_n =
220 DeviceOp::Transform::MakeCGridDescriptor_M_N(c_gs_ms_os_lengths, c_gs_ms_os_strides);
221 const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
222 GridwiseOp::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n);
223 const auto block_2_ctile_map = GridwiseOp::MakeDefaultBlock2CTileMap(c_grid_desc_m_n, 1, 1);
224
225 const auto a_grid_desc_g_m_k =
226 DeviceOp::Transform::MakeAGridDescriptor_G_M_K(qk_gs_ms_ks_lengths, qk_gs_ms_ks_strides);
227 const auto b0_grid_desc_g_l_k =
228 DeviceOp::Transform::MakeB0GridDescriptor_G_N_K(qk_gs_ms_ks_lengths, qk_gs_ms_ks_strides);
229 const auto b1_grid_desc_g_n_l =
230 DeviceOp::Transform::MakeB1GridDescriptor_G_N_K(v_gs_os_ns_lengths, v_gs_os_ns_strides);
231 const auto c_grid_desc_g_m_n =
232 DeviceOp::Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_os_lengths, c_gs_ms_os_strides);
233 const auto compute_base_ptr_of_batch =
234 typename DeviceOp::ComputeBasePtrOfStridedBatch{a_grid_desc_g_m_k, b0_grid_desc_g_l_k, b1_grid_desc_g_n_l, c_grid_desc_g_m_n};
235 index_t batch_count = c_grid_desc_g_m_n.GetLength(Number<0>{});
236 const auto c0_matrix_mask = typename DeviceOp::C0MatrixMask{b0_grid_desc_g_l_k.GetLength(Number<1>{})};
237
238 // clang-format on
239 __shared__ char p_shared[GridwiseOp::GetSharedMemoryNumberOfByte()];
240 const index_t num_blocks_per_batch =
241 __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
242 const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
243
244 const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
245 static_cast<long_index_t>(compute_base_ptr_of_batch.GetABasePtr(g_idx)));
246 const long_index_t b0_batch_offset = __builtin_amdgcn_readfirstlane(
247 static_cast<long_index_t>(compute_base_ptr_of_batch.GetB0BasePtr(g_idx)));
248 const long_index_t b1_batch_offset = __builtin_amdgcn_readfirstlane(
249 static_cast<long_index_t>(compute_base_ptr_of_batch.GetB1BasePtr(g_idx)));
250 const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
251 static_cast<long_index_t>(compute_base_ptr_of_batch.GetCBasePtr(g_idx)));
252
253 const index_t qkv_gap = __builtin_amdgcn_readfirstlane(head_size);
254#ifdef CK_SELF_ATTN_DEBUG
255 if(get_thread_global_1d_id() == 0)
256 {
257 printf("batch_size: %d\n", batch_size);
258 printf("sequence_length: %d\n", sequence_length);
259 printf("head_count: %d\n", head_count);
260 printf("head_size: %d\n", head_size);
261 printf("qkv_gap: %d\n", qkv_gap);
262 printf("get_grid_size(): %d\n", get_grid_size());
263 printf("batch_count: %d\n", batch_count);
264 printf("blockid: %d\n", get_block_1d_id());
265 printf("num_blocks_per_batch: %d\n", num_blocks_per_batch);
266 printf("g_idx: %d\n", g_idx);
267 printf("a_batch_offset: %ld\n", a_batch_offset);
268 printf("b0_batch_offset: %ld\n", b0_batch_offset);
269 printf("b1_batch_offset: %ld\n", b1_batch_offset);
270 }
271#endif
272 GridwiseOp::template Run<HasMainKBlockLoop>(p_qkv_grid + 0 * qkv_gap + a_batch_offset,
273 p_qkv_grid + 1 * qkv_gap + b0_batch_offset,
274 p_qkv_grid + 2 * qkv_gap + b1_batch_offset,
275 p_out_grid + c_batch_offset,
276 p_shared,
277 a_grid_desc,
278 b0_grid_desc,
279 b1_grid_desc,
280 c_grid_desc_mblock_mperblock_nblock_nperblock,
281 a_element_op,
282 b0_element_op,
283 acc0_element_op,
284 b1_element_op,
285 c_element_op,
286 c0_matrix_mask,
287 block_2_ctile_map);
288#else
289 ignore = p_qkv_grid;
290 ignore = p_out_grid;
291 ignore = batch_size;
292 ignore = sequence_length;
293 ignore = head_count;
294 ignore = head_size;
295 ignore = alpha;
296#endif // end of if (defined(__gfx11__))
297}
298// Cross-Attention
299// Self-Attention
300template <typename DeviceOp,
301 typename GridwiseOp,
302 typename QDataType,
303 typename KVDataType,
304 typename ODataType,
305 typename AElementwiseOperation,
306 typename B0ElementwiseOperation,
307 typename AccElementwiseOperation,
308 typename B1ElementwiseOperation,
309 typename CElementwiseOperation,
310 bool HasMainKBlockLoop>
311__global__ void
312#if CK_USE_LAUNCH_BOUNDS
314#endif
315 kernel_wmma_cross_attention_forward(const QDataType* __restrict__ p_q_grid,
316 const KVDataType* __restrict__ p_kv_grid,
317 ODataType* __restrict__ p_out_grid,
318 index_t batch_size,
319 index_t q_sequence_length,
320 index_t kv_sequence_length,
321 index_t head_count,
322 index_t head_size,
323 float alpha)
324{
325#if(defined(__gfx11__) || defined(__gfx12__))
326
327 // clang-format off
328// ***************************************************
329// Make Tensor Descriptors
330// o Self-attention(packed QKV): [batchSize, sequenceLength, headCount, 3, headSize]
331 constexpr index_t array_size = 4;
332 std::array<ck::index_t, array_size> q_gs_ms_ks_lengths{batch_size, head_count, q_sequence_length, head_size};
333 std::array<ck::index_t, array_size> q_gs_ms_ks_strides{q_sequence_length * head_count * head_size, head_size, head_count * head_size, 1};
334
335 std::array<ck::index_t, array_size> k_gs_ms_ks_lengths{batch_size, head_count, kv_sequence_length, head_size};
336 std::array<ck::index_t, array_size> k_gs_ms_ks_strides{kv_sequence_length * head_count * 2 * head_size, 2 * head_size, head_count * 2 * head_size, 1};
337
338 std::array<ck::index_t, array_size> v_gs_os_ns_lengths{batch_size, head_count, head_size, kv_sequence_length};
339 std::array<ck::index_t, array_size> v_gs_os_ns_strides{kv_sequence_length * head_count * 2 * head_size, 2 * head_size, 1, head_count * 2 * head_size};
340
341 std::array<ck::index_t, array_size> c_gs_ms_os_lengths{batch_size, head_count, q_sequence_length, head_size};
342 std::array<ck::index_t, array_size> c_gs_ms_os_strides{q_sequence_length * head_count * head_size, head_size, head_count * head_size, 1};
343
344
345 const auto a_element_op = AElementwiseOperation{};
346 const auto b0_element_op = B0ElementwiseOperation{};
347 const auto acc0_element_op = AccElementwiseOperation{alpha};
348 const auto b1_element_op = B1ElementwiseOperation{};
349 const auto c_element_op = CElementwiseOperation{};
350
351 const auto a_grid_desc = DeviceOp::MakeAGridDescriptor(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
352 const auto b0_grid_desc =
353 DeviceOp::MakeB0GridDescriptor(k_gs_ms_ks_lengths, k_gs_ms_ks_strides);
354 const auto b1_grid_desc =
355 DeviceOp::MakeB1GridDescriptor(v_gs_os_ns_lengths, v_gs_os_ns_strides);
356 const auto c_grid_desc_m_n =
357 DeviceOp::Transform::MakeCGridDescriptor_M_N(c_gs_ms_os_lengths, c_gs_ms_os_strides);
358 const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
359 GridwiseOp::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n);
360 const auto block_2_ctile_map = GridwiseOp::MakeDefaultBlock2CTileMap(c_grid_desc_m_n, 1, 1);
361
362 const auto a_grid_desc_g_m_k =
363 DeviceOp::Transform::MakeAGridDescriptor_G_M_K(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
364 const auto b0_grid_desc_g_l_k =
365 DeviceOp::Transform::MakeB0GridDescriptor_G_N_K(k_gs_ms_ks_lengths, k_gs_ms_ks_strides);
366 const auto b1_grid_desc_g_n_l =
367 DeviceOp::Transform::MakeB1GridDescriptor_G_N_K(v_gs_os_ns_lengths, v_gs_os_ns_strides);
368 const auto c_grid_desc_g_m_n =
369 DeviceOp::Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_os_lengths, c_gs_ms_os_strides);
370 const auto compute_base_ptr_of_batch =
371 typename DeviceOp::ComputeBasePtrOfStridedBatch{a_grid_desc_g_m_k, b0_grid_desc_g_l_k, b1_grid_desc_g_n_l, c_grid_desc_g_m_n};
372 index_t batch_count = c_grid_desc_g_m_n.GetLength(Number<0>{});
373 const auto c0_matrix_mask = typename DeviceOp::C0MatrixMask{b0_grid_desc_g_l_k.GetLength(Number<1>{})};
374
375 // clang-format on
376 __shared__ char p_shared[GridwiseOp::GetSharedMemoryNumberOfByte()];
377 const index_t num_blocks_per_batch =
378 __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
379 const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
380
381 const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
382 static_cast<long_index_t>(compute_base_ptr_of_batch.GetABasePtr(g_idx)));
383 const long_index_t b0_batch_offset = __builtin_amdgcn_readfirstlane(
384 static_cast<long_index_t>(compute_base_ptr_of_batch.GetB0BasePtr(g_idx)));
385 const long_index_t b1_batch_offset = __builtin_amdgcn_readfirstlane(
386 static_cast<long_index_t>(compute_base_ptr_of_batch.GetB1BasePtr(g_idx)));
387 const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
388 static_cast<long_index_t>(compute_base_ptr_of_batch.GetCBasePtr(g_idx)));
389
390 const index_t kv_gap = __builtin_amdgcn_readfirstlane(head_size);
391#ifdef CK_SELF_ATTN_DEBUG
392 if(get_thread_global_1d_id() == 0)
393 {
394 printf("batch_size: %d\n", batch_size);
395 printf("q_sequence_length: %d\n", q_sequence_length);
396 printf("k_sequence_length: %d\n", kv_sequence_length);
397 printf("head_count: %d\n", head_count);
398 printf("head_size: %d\n", head_size);
399 printf("kv_gap: %d\n", kv_gap);
400 printf("get_grid_size(): %d\n", get_grid_size());
401 printf("batch_count: %d\n", batch_count);
402 printf("blockid: %d\n", get_block_1d_id());
403 printf("num_blocks_per_batch: %d\n", num_blocks_per_batch);
404 printf("g_idx: %d\n", g_idx);
405 printf("a_batch_offset: %ld\n", a_batch_offset);
406 printf("b0_batch_offset: %ld\n", b0_batch_offset);
407 printf("b1_batch_offset: %ld\n", b1_batch_offset);
408 }
409#endif
410 GridwiseOp::template Run<HasMainKBlockLoop>(p_q_grid + a_batch_offset,
411 p_kv_grid + 0 * kv_gap + b0_batch_offset,
412 p_kv_grid + 1 * kv_gap + b1_batch_offset,
413 p_out_grid + c_batch_offset,
414 p_shared,
415 a_grid_desc,
416 b0_grid_desc,
417 b1_grid_desc,
418 c_grid_desc_mblock_mperblock_nblock_nperblock,
419 a_element_op,
420 b0_element_op,
421 acc0_element_op,
422 b1_element_op,
423 c_element_op,
424 c0_matrix_mask,
425 block_2_ctile_map);
426#else
427 ignore = p_q_grid;
428 ignore = p_kv_grid;
429 ignore = p_out_grid;
430 ignore = batch_size;
431 ignore = q_sequence_length;
432 ignore = kv_sequence_length;
433 ignore = head_count;
434 ignore = head_size;
435 ignore = alpha;
436#endif // end of if (defined(__gfx11__))
437}
438// Computes C = A * B0 * B1
439// MN = MK * KL * LN
440// ^^^^^^ (Acc0)
441// ^^^^^^^^^^^ (Acc1)
442template <index_t NumDimG,
443 index_t NumDimM,
444 index_t NumDimL,
445 index_t NumDimK,
446 index_t NumDimN,
447 typename ADataType,
448 typename B0DataType,
449 typename B1DataType,
450 typename CDataType,
451 typename Acc0BiasDataType,
452 typename Acc0DataType,
453 typename Acc1BiasDataType,
454 typename Acc1DataType,
455 typename CShuffleDataType,
456 typename AElementwiseOperation,
457 typename B0ElementwiseOperation,
458 typename AccElementwiseOperation,
459 typename B1ElementwiseOperation,
460 typename CElementwiseOperation,
461 GemmSpecialization GemmSpec,
466 ck::index_t NumPrefetch,
467 ck::index_t BlockSize,
468 ck::index_t MPerBlock,
469 ck::index_t LPerBlock,
470 ck::index_t KPerBlock,
471 ck::index_t AK1,
472 ck::index_t BK1,
473 ck::index_t NPerBlock,
474 ck::index_t LTilePerBlock,
475 ck::index_t L1,
476 ck::index_t MPerWmma,
477 ck::index_t LPerWmma,
478 ck::index_t NPerWmma,
479 ck::index_t MRepeat,
480 ck::index_t LRepeat,
481 ck::index_t NRepeat,
482 typename ABlockTransferThreadClusterLengths_K0_M_K1,
483 typename ABlockTransferThreadClusterArrangeOrder,
484 typename ABlockTransferSrcAccessOrder,
485 ck::index_t ABlockTransferSrcVectorDim,
486 ck::index_t ABlockTransferSrcScalarPerVector,
487 ck::index_t ABlockTransferDstScalarPerVector_K1,
488 bool ABlockLdsAddExtraM,
489 typename B0BlockTransferThreadClusterLengths_K0_L_K1,
490 typename B0BlockTransferThreadClusterArrangeOrder,
491 typename B0BlockTransferSrcAccessOrder,
492 ck::index_t B0BlockTransferSrcVectorDim,
493 ck::index_t B0BlockTransferSrcScalarPerVector,
494 ck::index_t B0BlockTransferDstScalarPerVector_K1,
495 bool B0BlockLdsAddExtraL,
496 typename B1BlockTransferThreadClusterLengths_L0_N_L1,
497 typename B1BlockTransferThreadClusterArrangeOrder,
498 typename B1BlockTransferSrcAccessOrder,
499 ck::index_t B1BlockTransferSrcVectorDim,
500 ck::index_t B1BlockTransferSrcScalarPerVector,
501 ck::index_t B1BlockTransferDstScalarPerVector_L1,
502 bool B1BlockLdsAddExtraN,
503 index_t CShuffleMRepeatPerShuffle,
504 index_t CShuffleNRepeatPerShuffle,
505 typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
506 index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
507 MaskingSpecialization MaskingSpec,
512 NumDimM,
513 NumDimL,
514 NumDimK,
515 NumDimN,
516 ADataType,
517 B0DataType,
518 B1DataType,
519 CDataType,
520 Acc0BiasDataType,
521 Acc1BiasDataType,
522 AElementwiseOperation,
523 B0ElementwiseOperation,
524 AccElementwiseOperation,
525 B1ElementwiseOperation,
526 CElementwiseOperation,
527 MaskingSpec>
528{
529 static_assert(NumDimG > 0 && NumDimM > 0 && NumDimL > 0 && NumDimK > 0 && NumDimN > 0,
530 "Number of dimension must be greater than 0");
531
532 static constexpr index_t NumAcc0Bias = Acc0BiasDataType::Size();
533 static constexpr index_t NumAcc1Bias = Acc1BiasDataType::Size();
534
535 // TODO ANT: implement bias combination
536 static_assert(NumAcc0Bias == 0 && NumAcc0Bias == 0, "Bias addition is unimplemented");
537
538 static constexpr index_t NumDimGemm0M = NumDimM;
539 static constexpr index_t NumDimGemm0N = NumDimL;
540 static constexpr index_t NumDimGemm0K = NumDimK;
541 static constexpr index_t NumDimGemm1M = NumDimM;
542 static constexpr index_t NumDimGemm1N = NumDimN;
543 static constexpr index_t NumDimGemm1K = NumDimL;
544
546
547 static constexpr auto I0 = Number<0>{};
548 static constexpr auto I1 = Number<1>{};
549 static constexpr auto I2 = Number<2>{};
550 static constexpr auto I3 = Number<3>{};
551 static constexpr auto I4 = Number<4>{};
552 static constexpr auto I5 = Number<5>{};
553 static constexpr auto I6 = Number<6>{};
554
555 static constexpr auto WmmaK = 16;
556
557 static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma);
558 static constexpr auto LWaves = LPerBlock / (LRepeat * LPerWmma);
559 static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma);
560
561 static constexpr auto AEnableLds_auto = LWaves == 1 ? false : true;
562 static constexpr auto B0EnableLds_auto = MWaves == 1 ? false : true;
563 static constexpr auto B1EnableLds_auto = MWaves == 1 ? false : true;
564
565 static constexpr auto AEnableLds_manu = false;
566 static constexpr auto B0EnableLds_manu = true;
567 static constexpr auto B1EnableLds_manu = true;
568
569 static constexpr auto AEnableLds = AEnableLds_auto || AEnableLds_manu || (NumPrefetch > 1);
570 static constexpr auto B0EnableLds = B0EnableLds_auto || B0EnableLds_manu || (NumPrefetch > 1);
571 static constexpr auto B1EnableLds = B1EnableLds_auto || B1EnableLds_manu || (NumPrefetch > 1);
572
576 GemmSpec,
577 ASpec,
578 B0Spec,
579 B1Spec,
580 CSpec>;
581
582 __host__ __device__ static auto MakeAGridDescriptor(
583 const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_lengths_vec,
584 const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_strides_vec)
585 {
586 if constexpr(AEnableLds)
587 {
589 Transform::MakeAGridDescriptor_M_K(a_gs_ms_ks_lengths_vec, a_gs_ms_ks_strides_vec),
590 Number<AK1>{});
591 }
592 else
593 {
594 return Transform::
596 Transform::MakeAGridDescriptor_M_K(a_gs_ms_ks_lengths_vec,
597 a_gs_ms_ks_strides_vec),
602 Number<AK1>{});
603 }
604 }
605
606 __host__ __device__ static auto MakeB0GridDescriptor(
607 const std::array<index_t, NumDimG + NumDimM + NumDimN>& b0_gs_ls_ks_lengths_vec,
608 const std::array<index_t, NumDimG + NumDimM + NumDimN>& b0_gs_ls_ks_strides_vec)
609 {
610 if constexpr(B0EnableLds)
611 {
613 Transform::MakeB0GridDescriptor_N_K(b0_gs_ls_ks_lengths_vec,
614 b0_gs_ls_ks_strides_vec),
615 Number<BK1>{});
616 }
617 else
618 {
619 return Transform::
621 Transform::MakeB0GridDescriptor_N_K(b0_gs_ls_ks_lengths_vec,
622 b0_gs_ls_ks_strides_vec),
627 Number<BK1>{});
628 }
629 }
630
631 __host__ __device__ static auto MakeB1GridDescriptor(
632 const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_ns_ls_lengths_vec,
633 const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_ns_ls_strides_vec)
634 {
635 if constexpr(B1EnableLds)
636 {
638 Transform::MakeB1GridDescriptor_N_K(b1_gs_ns_ls_lengths_vec,
639 b1_gs_ns_ls_strides_vec),
640 Number<L1>{});
641 }
642 else
643 {
644 return Transform::
646 Transform::MakeB1GridDescriptor_N_K(b1_gs_ns_ls_lengths_vec,
647 b1_gs_ns_ls_strides_vec),
652 Number<L1>{});
653 }
654 }
655
656 using AGridDesc = decltype(MakeAGridDescriptor({}, {}));
657 using B0GridDesc = decltype(MakeB0GridDescriptor({}, {}));
658 using B1GridDesc = decltype(MakeB1GridDescriptor({}, {}));
664
665 __host__ __device__ constexpr static auto make_MaskOutPredicate()
666 {
667 if constexpr(MaskingSpec == MaskingSpecialization::MaskDisabled)
668 {
669 return MaskDisabledPredicate{};
670 }
671 else if constexpr(MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle)
672 {
674 }
675 }
677
679 {
680 __host__ __device__ ComputeBasePtrOfStridedBatch(const AGridDesc_G_M_K& a_grid_desc_g_m_k,
681 const B0GridDesc_G_L_K& b0_grid_desc_g_l_k,
682 const B1GridDesc_G_N_L& b1_grid_desc_g_n_l,
683 const CGridDesc_G_M_N& c_grid_desc_g_m_n)
684 : a_grid_desc_g_m_k_(a_grid_desc_g_m_k),
685 b0_grid_desc_g_l_k_(b0_grid_desc_g_l_k),
686 b1_grid_desc_g_n_l_(b1_grid_desc_g_n_l),
687 c_grid_desc_g_m_n_(c_grid_desc_g_m_n)
688 {
689 }
690
691 __host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const
692 {
693 return a_grid_desc_g_m_k_.CalculateOffset(make_multi_index(g_idx, 0, 0));
694 }
695
696 __host__ __device__ constexpr long_index_t GetB0BasePtr(index_t g_idx) const
697 {
698 return b0_grid_desc_g_l_k_.CalculateOffset(make_multi_index(g_idx, 0, 0));
699 }
700
701 __host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const
702 {
703 return b1_grid_desc_g_n_l_.CalculateOffset(make_multi_index(g_idx, 0, 0));
704 }
705
706 __host__ __device__ constexpr long_index_t GetCBasePtr(index_t g_idx) const
707 {
708 return c_grid_desc_g_m_n_.CalculateOffset(make_multi_index(g_idx, 0, 0));
709 }
710
711 private:
712 AGridDesc_G_M_K a_grid_desc_g_m_k_;
713 B0GridDesc_G_L_K b0_grid_desc_g_l_k_;
714 B1GridDesc_G_N_L b1_grid_desc_g_n_l_;
715 CGridDesc_G_M_N c_grid_desc_g_m_n_;
716 };
717
718 // GridwiseOp
720 // DataType Family
721 ADataType,
722 B0DataType,
723 Acc0DataType,
724 B1DataType,
725 Acc1DataType,
726 CShuffleDataType,
727 CDataType,
728 // ElementwiseOp Family
729 AElementwiseOperation,
730 B0ElementwiseOperation,
731 AccElementwiseOperation,
732 B1ElementwiseOperation,
733 CElementwiseOperation,
735 // InMemory Data Descriptor
736 AGridDesc,
740 // Tiling Family
741 MPerBlock,
742 LPerBlock,
743 KPerBlock,
744 AK1,
745 BK1,
746 NPerBlock,
747 LTilePerBlock,
748 L1,
749 MPerWmma,
750 LPerWmma,
751 NPerWmma,
752 MRepeat,
753 LRepeat,
754 NRepeat,
755 // ThreadCluster Family
756 BlockSize,
757 ABlockTransferThreadClusterLengths_K0_M_K1,
758 ABlockTransferThreadClusterArrangeOrder,
759 ABlockTransferSrcAccessOrder,
760 ABlockTransferSrcVectorDim,
761 ABlockTransferSrcScalarPerVector,
762 ABlockTransferDstScalarPerVector_K1,
763 true,
765 ABlockLdsAddExtraM,
766 B0BlockTransferThreadClusterLengths_K0_L_K1,
767 B0BlockTransferThreadClusterArrangeOrder,
768 B0BlockTransferSrcAccessOrder,
769 B0BlockTransferSrcVectorDim,
770 B0BlockTransferSrcScalarPerVector,
771 B0BlockTransferDstScalarPerVector_K1,
772 true,
774 B0BlockLdsAddExtraL,
775 B1BlockTransferThreadClusterLengths_L0_N_L1,
776 B1BlockTransferThreadClusterArrangeOrder,
777 B1BlockTransferSrcAccessOrder,
778 B1BlockTransferSrcVectorDim,
779 B1BlockTransferSrcScalarPerVector,
780 B1BlockTransferDstScalarPerVector_L1,
781 false,
783 B1BlockLdsAddExtraN,
784 CShuffleMRepeatPerShuffle,
785 CShuffleNRepeatPerShuffle,
786 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
787 CShuffleBlockTransferScalarPerVector_NPerBlock,
790 NumPrefetch,
791 LoopSched,
792 PipelineVer>;
793
794 struct RawArg : public BaseArgument
795 {
796 RawArg(const ADataType* p_a_grid,
797 const B0DataType* p_b0_grid,
798 const B1DataType* p_b1_grid,
799 CDataType* p_c_grid,
800 index_t M,
801 index_t N,
802 index_t K,
803 index_t O,
804 index_t G0,
805 index_t G1,
806 float alpha,
807 bool input_permute,
808 bool output_permute)
809 : p_a_grid_{p_a_grid},
810 p_b0_grid_{p_b0_grid},
811 p_b1_grid_{p_b1_grid},
812 p_c_grid_{p_c_grid},
813 M_{M},
814 N_{N},
815 K_{K},
816 O_{O},
817 G0_{G0},
818 G1_{G1},
819 alpha_{alpha},
820 input_permute_{input_permute},
821 output_permute_{output_permute}
822 {
823 }
824 // Pointers
825 const ADataType* p_a_grid_;
826 const B0DataType* p_b0_grid_;
827 const B1DataType* p_b1_grid_;
828 CDataType* p_c_grid_;
829
830 // Raw Problem Size
837 float alpha_;
840 };
841
842 static auto MakeArgument(const ADataType* p_a,
843 const B0DataType* p_b0,
844 const B1DataType* p_b1,
845 CDataType* p_c,
846 index_t M,
847 index_t N,
848 index_t K,
849 index_t O,
850 index_t G0,
851 index_t G1,
852 float alpha,
853 bool input_permute,
854 bool output_permute)
855 {
856 return RawArg{
857 p_a, p_b0, p_b1, p_c, M, N, K, O, G0, G1, alpha, input_permute, output_permute};
858 }
859
860 static bool IsSupportedArgument(const RawArg& arg)
861 {
863 {
865 {
866 printf("DeviceOp: Acc0 Type err");
867 return false;
868 }
869
871 {
872 printf("DeviceOp: Acc1 Type err");
873 return false;
874 }
875 }
876 else
877 {
878 printf("DeviceOp: Arch err");
879 return false;
880 }
881
882 constexpr index_t array_size = 4;
883 ck::index_t G0 = arg.G0_;
884 ck::index_t G1 = arg.G1_;
885 ck::index_t M = arg.M_;
886 ck::index_t N = arg.N_;
887 ck::index_t K = arg.K_;
888 ck::index_t O = arg.O_;
889 bool input_permute = arg.input_permute_;
890 bool output_permute = arg.output_permute_;
891
892 std::array<ck::index_t, array_size> a_gs_ms_ks_lengths{G0, G1, M, K};
893 std::array<ck::index_t, array_size> a_gs_ms_ks_strides =
894 input_permute ? std::array<ck::index_t, array_size>{M * G1 * K, K, G1 * K, 1}
895 // A layout [G0, M, G1, K]
896 : std::array<ck::index_t, array_size>{
897 G1 * M * K, M * K, K, 1}; // A layout [G0, G1, M, K]
898
899 std::array<ck::index_t, array_size> b0_gs_ns_ks_lengths{G0, G1, N, K};
900 std::array<ck::index_t, array_size> b0_gs_ns_ks_strides =
901 input_permute ? std::array<ck::index_t, array_size>{N * G1 * K, K, G1 * K, 1}
902 // B0 layout [G0, N, G1, K]
903 : std::array<ck::index_t, array_size>{
904 G1 * N * K, N * K, K, 1}; // B0 layout [G0, G1, N, K]
905
906 std::array<ck::index_t, array_size> b1_gs_os_ns_lengths{G0, G1, O, N};
907 std::array<ck::index_t, array_size> b1_gs_os_ns_strides =
908 input_permute ? std::array<ck::index_t, array_size>{N * G1 * O, O, 1, G1 * O}
909 // B1 layout [G0, N, G1, O]
910 : std::array<ck::index_t, array_size>{
911 G1 * N * O, N * O, 1, O}; // B1 layout [G0, G1, N, O]
912
913 std::array<ck::index_t, array_size> c_gs_ms_os_lengths{G0, G1, M, O};
914 std::array<ck::index_t, array_size> c_gs_ms_os_strides =
915 output_permute ? std::array<ck::index_t, array_size>{M * G1 * O, O, G1 * O, 1}
916 // C layout [G0, M, G1, O]
917 : std::array<ck::index_t, array_size>{
918 G1 * M * O, M * O, O, 1}; // C layout [G0, G1, M, O]
919
920 const auto a_grid_desc =
921 DeviceOp::MakeAGridDescriptor(a_gs_ms_ks_lengths, a_gs_ms_ks_strides);
922 const auto b0_grid_desc =
923 DeviceOp::MakeB0GridDescriptor(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides);
924 const auto b1_grid_desc =
925 DeviceOp::MakeB1GridDescriptor(b1_gs_os_ns_lengths, b1_gs_os_ns_strides);
926 const auto c_grid_desc_m_n =
927 DeviceOp::Transform::MakeCGridDescriptor_M_N(c_gs_ms_os_lengths, c_gs_ms_os_strides);
928
929 const auto block_2_ctile_map = GridwiseOp::MakeDefaultBlock2CTileMap(c_grid_desc_m_n, 1, 1);
930
931 const auto c_grid_desc_g_m_n =
932 DeviceOp::Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_os_lengths, c_gs_ms_os_strides);
933 index_t batch_count = c_grid_desc_g_m_n.GetLength(Number<0>{});
934
936 a_grid_desc, b0_grid_desc, b1_grid_desc, c_grid_desc_m_n, block_2_ctile_map))
937 {
938 return false;
939 }
940
941 // Check if C permute dimension matches GEMM + GEMM shape
942 const index_t c_g = c_grid_desc_g_m_n.GetLength(I0); // unpadded
943
944 if(!(c_g == batch_count))
945 {
946 printf("DeviceOp: BatchCount err");
947 return false;
948 }
949
950 // Note: we need raw lengths since threadwise copy can not handle vector load when part of
951 // vector is out of bounds
952 // Note: need lowest dim in Ms/Ns/Ks/Os, not merged M/N/K/O
953 const auto MzRaw = M;
954 const auto LzRaw = N;
955 const auto KzRaw = K;
956 const auto NzRaw = O;
957
958 // Check scalar per vector requirement
959 const auto a_extent_lowest = ABlockTransferSrcVectorDim == 2 ? KzRaw : MzRaw;
960 const auto b0_extent_lowest = B0BlockTransferSrcVectorDim == 2 ? KzRaw : LzRaw;
961 const auto b1_extent_lowest = B1BlockTransferSrcVectorDim == 2 ? LzRaw : NzRaw;
962 const auto c_extent_lowest = NzRaw;
963
964 if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 &&
965 b0_extent_lowest % B0BlockTransferSrcScalarPerVector == 0 &&
966 b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 &&
967 c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0))
968 {
969 printf("DeviceOp: Data Transfer Vector scalar err");
970 return false;
971 }
972
973 std::array<index_t, NumDimG + NumDimM + NumDimN> a_mz_kz_strides_{
974 a_gs_ms_ks_strides[NumDimG + NumDimM - 1],
975 a_gs_ms_ks_strides[NumDimG + NumDimM + NumDimK - 1]};
976 std::array<index_t, NumDimG + NumDimM + NumDimN> b0_lz_kz_strides_{
977 b0_gs_ns_ks_strides[NumDimG + NumDimL - 1],
978 b0_gs_ns_ks_strides[NumDimG + NumDimL + NumDimK - 1]};
979 std::array<index_t, NumDimG + NumDimM + NumDimN> b1_nz_lz_strides_{
980 b1_gs_os_ns_strides[NumDimG + NumDimN - 1],
981 b1_gs_os_ns_strides[NumDimG + NumDimN + NumDimL - 1]};
982 std::array<index_t, NumDimG + NumDimM + NumDimN> c_mz_nz_strides_{
983 c_gs_ms_os_strides[NumDimG + NumDimM - 1],
984 c_gs_ms_os_strides[NumDimG + NumDimM + NumDimN - 1]};
985
986 // Check vector load/store requirement
987 const auto a_stride_lowest =
988 ABlockTransferSrcVectorDim == 2 ? a_mz_kz_strides_[1] : a_mz_kz_strides_[0];
989 const auto b0_stride_lowest =
990 B0BlockTransferSrcVectorDim == 2 ? b0_lz_kz_strides_[1] : b0_lz_kz_strides_[0];
991 const auto b1_stride_lowest =
992 B1BlockTransferSrcVectorDim == 2 ? b1_nz_lz_strides_[1] : b1_nz_lz_strides_[0];
993 const auto c_stride_lowest = c_mz_nz_strides_[1];
994
995 if(!(a_stride_lowest == 1 || b0_stride_lowest == 1 || b1_stride_lowest == 1 ||
996 c_stride_lowest == 1))
997 {
998 printf("DeviceOp: Data Vectorize transfer err");
999 return false;
1000 }
1001
1002 return true;
1003 }
1004
1005 // polymorphic
1006 bool IsSupportedArgument(const BaseArgument* p_arg) override
1007 {
1008 return IsSupportedArgument(*dynamic_cast<const RawArg*>(p_arg));
1009 }
1010
1012 {
1013 SelfAttnArg(const ADataType* p_qkv_grid,
1014 CDataType* p_out_grid,
1015 index_t batch_size,
1016 index_t sequence_length,
1017 index_t head_count,
1018 index_t head_size,
1019 float alpha)
1020 : p_qkv_grid_{p_qkv_grid},
1021 p_out_grid_{p_out_grid},
1022 batch_size_{batch_size},
1023 sequence_length_{sequence_length},
1024 head_count_{head_count},
1025 head_size_{head_size},
1026 alpha_{alpha}
1027 {
1028 }
1029 // Pointers
1030 const ADataType* p_qkv_grid_;
1031 CDataType* p_out_grid_;
1032
1033 // Raw Problem Size
1038 float alpha_;
1039 };
1040
1041 static auto MakeSelfAttnArgument(const ADataType* p_qkv_grid,
1042 CDataType* p_out_grid,
1043 index_t batch_size,
1044 index_t sequence_length,
1045 index_t head_count,
1046 index_t head_size,
1047 float alpha)
1048 {
1049 return SelfAttnArg{
1050 p_qkv_grid, p_out_grid, batch_size, sequence_length, head_count, head_size, alpha};
1051 }
1052
1054 {
1055 CrossAttnArg(const ADataType* p_q_grid,
1056 const B0DataType* p_kv_grid,
1057 CDataType* p_out_grid,
1058 index_t batch_size,
1059 index_t q_sequence_length,
1060 index_t kv_sequence_length,
1061 index_t head_count,
1062 index_t head_size,
1063 float alpha)
1064 : p_q_grid_{p_q_grid},
1065 p_kv_grid_{p_kv_grid},
1066 p_out_grid_{p_out_grid},
1067 batch_size_{batch_size},
1068 q_sequence_length_{q_sequence_length},
1069 kv_sequence_length_{kv_sequence_length},
1070 head_count_{head_count},
1071 head_size_{head_size},
1072 alpha_{alpha}
1073 {
1074 }
1075 // Pointers
1076 const ADataType* p_q_grid_;
1077 const B0DataType* p_kv_grid_;
1078 CDataType* p_out_grid_;
1079
1080 // Raw Problem Size
1086 float alpha_;
1087 };
1088
1089 static auto MakeCrossAttnArgument(const ADataType* p_q_grid,
1090 const B0DataType* p_kv_grid,
1091 CDataType* p_out_grid,
1092 index_t batch_size,
1093 index_t q_sequence_length,
1094 index_t kv_sequence_length,
1095 index_t head_count,
1096 index_t head_size,
1097 float alpha)
1098 {
1099 return CrossAttnArg{p_q_grid,
1100 p_kv_grid,
1101 p_out_grid,
1102 batch_size,
1103 q_sequence_length,
1104 kv_sequence_length,
1105 head_count,
1106 head_size,
1107 alpha};
1108 }
1109
1110 // Argument
1111 struct Argument : public BaseArgument
1112 {
1114 const ADataType* p_a_grid,
1115 const B0DataType* p_b0_grid,
1116 const B1DataType* p_b1_grid,
1117 CDataType* p_c_grid,
1118 const std::array<void*, NumAcc0Bias> p_acc0_biases,
1119 const std::array<void*, NumAcc1Bias> p_acc1_biases,
1120 const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_lengths,
1121 const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_strides,
1122 const std::array<index_t, NumDimG + NumDimM + NumDimN>& b0_gs_ls_ks_lengths,
1123 const std::array<index_t, NumDimG + NumDimM + NumDimN>& b0_gs_ls_ks_strides,
1124 const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_ns_ls_lengths,
1125 const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_ns_ls_strides,
1126 const std::array<index_t, NumDimG + NumDimM + NumDimN>& c_gs_ms_ns_lengths,
1127 const std::array<index_t, NumDimG + NumDimM + NumDimN>& c_gs_ms_ns_strides,
1128 const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ls_lengths,
1129 const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ls_strides,
1130 const std::array<std::vector<ck::index_t>, NumAcc1Bias> acc1_biases_gs_ms_ns_lengths,
1131 const std::array<std::vector<ck::index_t>, NumAcc1Bias> acc1_biases_gs_ms_ns_strides,
1132 const index_t M01,
1133 const index_t N01,
1134 AElementwiseOperation a_element_op,
1135 B0ElementwiseOperation b0_element_op,
1136 AccElementwiseOperation acc_element_op,
1137 B1ElementwiseOperation b1_element_op,
1138 CElementwiseOperation c_element_op)
1139 : p_a_grid_{p_a_grid},
1140 p_b0_grid_{p_b0_grid},
1141 p_b1_grid_{p_b1_grid},
1142 p_c_grid_{p_c_grid},
1143 a_grid_desc{DeviceOp::MakeAGridDescriptor(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)},
1145 DeviceOp::MakeB0GridDescriptor(b0_gs_ls_ks_lengths, b0_gs_ls_ks_strides)},
1147 DeviceOp::MakeB1GridDescriptor(b1_gs_ns_ls_lengths, b1_gs_ns_ls_strides)},
1149 Transform::MakeCGridDescriptor_M_N(c_gs_ms_ns_lengths, c_gs_ms_ns_strides)},
1151 Transform::MakeAGridDescriptor_G_M_K(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)},
1153 Transform::MakeB0GridDescriptor_G_N_K(b0_gs_ls_ks_lengths, b0_gs_ls_ks_strides)},
1155 Transform::MakeB1GridDescriptor_G_N_K(b1_gs_ns_ls_lengths, b1_gs_ns_ls_strides)},
1157 Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_ns_lengths, c_gs_ms_ns_strides)},
1159 block_2_ctile_map_{GridwiseOp::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01)},
1160 a_element_op_{a_element_op},
1161 b0_element_op_{b0_element_op},
1162 acc_element_op_{acc_element_op},
1163 b1_element_op_{b1_element_op},
1164 c_element_op_{c_element_op},
1166 raw_lengths_mz_lz_kz_nz_{a_gs_ms_ks_lengths[NumDimG + NumDimM - 1],
1167 b0_gs_ls_ks_lengths[NumDimG + NumDimL - 1],
1168 b0_gs_ls_ks_lengths[NumDimG + NumDimL + NumDimK - 1],
1169 b1_gs_ns_ls_lengths[NumDimG + NumDimN - 1]},
1170 a_mz_kz_strides_{a_gs_ms_ks_strides[NumDimG + NumDimM - 1],
1171 a_gs_ms_ks_strides[NumDimG + NumDimM + NumDimK - 1]},
1172 b0_lz_kz_strides_{b0_gs_ls_ks_strides[NumDimG + NumDimL - 1],
1173 b0_gs_ls_ks_strides[NumDimG + NumDimL + NumDimK - 1]},
1174 b1_nz_lz_strides_{b1_gs_ns_ls_strides[NumDimG + NumDimN - 1],
1175 b1_gs_ns_ls_strides[NumDimG + NumDimN + NumDimL - 1]},
1176 c_mz_nz_strides_{c_gs_ms_ns_strides[NumDimG + NumDimM - 1],
1177 c_gs_ms_ns_strides[NumDimG + NumDimM + NumDimN - 1]},
1181 {
1182 // TODO ANT: implement bias addition
1183 ignore = p_acc0_biases;
1184 ignore = p_acc1_biases;
1185 ignore = acc0_biases_gs_ms_ls_lengths;
1186 ignore = acc0_biases_gs_ms_ls_strides;
1187 ignore = acc1_biases_gs_ms_ns_lengths;
1188 ignore = acc1_biases_gs_ms_ns_strides;
1189
1192 {
1196 }
1197 }
1198
1199 // Pointers
1200 const ADataType* p_a_grid_;
1201 const B0DataType* p_b0_grid_;
1202 const B1DataType* p_b1_grid_;
1203 CDataType* p_c_grid_;
1204
1205 // Tensor Descriptors
1210
1215
1218
1219 // Block to Tile mapping
1221
1222 // ElementwiseOp
1223 AElementwiseOperation a_element_op_;
1224 B0ElementwiseOperation b0_element_op_;
1225 AccElementwiseOperation acc_element_op_;
1226 B1ElementwiseOperation b1_element_op_;
1227 CElementwiseOperation c_element_op_;
1228
1229 // check C0 masking and padding
1231
1232 // Strides for the last M/N/K dimensions of A/B0/B1/C
1233 // for sanity check of vector load/store
1234 std::array<index_t, NumDimG + NumDimM + NumDimN> raw_lengths_mz_lz_kz_nz_;
1235 std::array<index_t, NumDimG + NumDimM + NumDimN> a_mz_kz_strides_;
1236 std::array<index_t, NumDimG + NumDimM + NumDimN> b0_lz_kz_strides_;
1237 std::array<index_t, NumDimG + NumDimM + NumDimN> b1_nz_lz_strides_;
1238 std::array<index_t, NumDimG + NumDimM + NumDimN> c_mz_nz_strides_;
1239
1241 // Batch Offset
1243 };
1244
1245 // Invoker
1247 {
1249
1250 float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
1251 {
1252 const auto M0 = math::integer_divide_ceil(arg.sequence_length_, MPerBlock);
1253 const auto N0 = math::integer_divide_ceil(arg.head_size_, NPerBlock);
1254
1255 const index_t grid_size = arg.batch_size_ * arg.head_count_ * M0 * N0;
1256 const auto K = arg.head_size_;
1257
1258 auto launch_kernel = [&](auto has_main_k_block_loop) {
1259 const auto kernel = kernel_wmma_self_attention_forward<DeviceOp,
1260 GridwiseOp,
1261 ADataType,
1262 CDataType,
1263 AElementwiseOperation,
1264 B0ElementwiseOperation,
1265 AccElementwiseOperation,
1266 B1ElementwiseOperation,
1267 CElementwiseOperation,
1268 has_main_k_block_loop>;
1269
1270 return launch_and_time_kernel(stream_config,
1271 kernel,
1272 dim3(grid_size),
1273 dim3(BlockSize),
1274 0,
1275 arg.p_qkv_grid_,
1276 arg.p_out_grid_,
1277 arg.batch_size_,
1278 arg.sequence_length_,
1279 arg.head_count_,
1280 arg.head_size_,
1281 arg.alpha_);
1282 };
1283
1285 {
1286 return launch_kernel(integral_constant<bool, true>{});
1287 }
1288 else
1289 {
1290 return launch_kernel(integral_constant<bool, false>{});
1291 }
1292 }
1293
1294 // polymorphic
1295 float Run(const BaseArgument* p_arg,
1296 const StreamConfig& stream_config = StreamConfig{}) override
1297 {
1298 return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
1299 }
1300 };
1301
1302 static auto MakeSelfAttnInvoker() { return SelfAttnInvoker{}; }
1303
1304 // Invoker
1306 {
1308
1309 float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
1310 {
1311 const auto M0 = math::integer_divide_ceil(arg.q_sequence_length_, MPerBlock);
1312 const auto N0 = math::integer_divide_ceil(arg.head_size_, NPerBlock);
1313
1314 const index_t grid_size = arg.batch_size_ * arg.head_count_ * M0 * N0;
1315 const auto K = arg.head_size_;
1316
1317 auto launch_kernel = [&](auto has_main_k_block_loop) {
1319 GridwiseOp,
1320 ADataType,
1321 B0DataType,
1322 CDataType,
1323 AElementwiseOperation,
1324 B0ElementwiseOperation,
1325 AccElementwiseOperation,
1326 B1ElementwiseOperation,
1327 CElementwiseOperation,
1328 has_main_k_block_loop>;
1329
1330 return launch_and_time_kernel(stream_config,
1331 kernel,
1332 dim3(grid_size),
1333 dim3(BlockSize),
1334 0,
1335 arg.p_q_grid_,
1336 arg.p_kv_grid_,
1337 arg.p_out_grid_,
1338 arg.batch_size_,
1341 arg.head_count_,
1342 arg.head_size_,
1343 arg.alpha_);
1344 };
1345
1347 {
1348 return launch_kernel(integral_constant<bool, true>{});
1349 }
1350 else
1351 {
1352 return launch_kernel(integral_constant<bool, false>{});
1353 }
1354 }
1355
1356 // polymorphic
1357 float Run(const BaseArgument* p_arg,
1358 const StreamConfig& stream_config = StreamConfig{}) override
1359 {
1360 return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
1361 }
1362 };
1363
1364 static auto MakeCrossAttnInvoker() { return CrossAttnInvoker{}; }
1365
1366 struct Invoker : public BaseInvoker
1367 {
1369
1370 float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
1371 {
1372 const auto M0 = math::integer_divide_ceil(arg.M_, MPerBlock);
1373 const auto N0 = math::integer_divide_ceil(arg.O_, NPerBlock);
1374
1375 const index_t grid_size = arg.G0_ * arg.G1_ * M0 * N0;
1376 const auto K = arg.K_;
1377 // printf("HasKBlockLoop: %d\n", GridwiseOp::CalculateHasMainKBlockLoop(K));
1378 auto launch_kernel = [&](auto has_main_k_block_loop) {
1379 const auto kernel =
1381 GridwiseOp,
1382 ADataType,
1383 B0DataType,
1384 B1DataType,
1385 CDataType,
1386 AElementwiseOperation,
1387 B0ElementwiseOperation,
1388 AccElementwiseOperation,
1389 B1ElementwiseOperation,
1390 CElementwiseOperation,
1391 has_main_k_block_loop>;
1392
1393 return launch_and_time_kernel(stream_config,
1394 kernel,
1395 dim3(grid_size),
1396 dim3(BlockSize),
1397 0,
1398 arg.p_a_grid_,
1399 arg.p_b0_grid_,
1400 arg.p_b1_grid_,
1401 arg.p_c_grid_,
1402 arg.M_,
1403 arg.N_,
1404 arg.K_,
1405 arg.O_,
1406 arg.G0_,
1407 arg.G1_,
1408 arg.alpha_,
1409 arg.input_permute_,
1410 arg.output_permute_);
1411 };
1412
1414 {
1415 return launch_kernel(integral_constant<bool, true>{});
1416 }
1417 else
1418 {
1419 return launch_kernel(integral_constant<bool, false>{});
1420 }
1421 }
1422
1423 // polymorphic
1424 float Run(const BaseArgument* p_arg,
1425 const StreamConfig& stream_config = StreamConfig{}) override
1426 {
1427 return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
1428 }
1429 };
1430
// Compile-time sanity check of the template parameters.
// Placeholder: no validation is performed yet, so every instantiation reports valid.
// TODO: properly implement this check
static constexpr bool IsValidCompilationParameter()
{
    constexpr bool all_parameters_valid = true;
    return all_parameters_valid;
}
1436#if 0
1437 static bool IsSupportedArgument(const Argument& arg)
1438 {
1440 {
1442 {
1443 printf("DeviceOp: Acc0 Type err");
1444 return false;
1445 }
1446
1448 {
1449 printf("DeviceOp: Acc1 Type err");
1450 return false;
1451 }
1452 }
1453 else
1454 {
1455 printf("DeviceOp: Arch err");
1456 return false;
1457 }
1458
1459 if(!GridwiseOp::CheckValidity(arg.a_grid_desc,
1460 arg.b0_grid_desc,
1461 arg.b1_grid_desc,
1462 arg.c_grid_desc_m_n_,
1463 arg.block_2_ctile_map_))
1464 {
1465 return false;
1466 }
1467
1468 // Check if C permute dimension matches GEMM + GEMM shape
1469 const index_t c_g = arg.c_grid_desc_g_m_n_.GetLength(I0); // unpadded
1470
1471 if(!(c_g == arg.batch_count_))
1472 {
1473 printf("DeviceOp: BatchCount err");
1474 return false;
1475 }
1476
1477 // Note: we need raw lengths since threadwise copy can not handle vector load when part of
1478 // vector is out of bounds
1479 // Note: need lowest dim in Ms/Ns/Ks/Os, not merged M/N/K/O
1480 const auto MzRaw = arg.raw_lengths_mz_lz_kz_nz_[0];
1481 const auto LzRaw = arg.raw_lengths_mz_lz_kz_nz_[1];
1482 const auto KzRaw = arg.raw_lengths_mz_lz_kz_nz_[2];
1483 const auto NzRaw = arg.raw_lengths_mz_lz_kz_nz_[3];
1484
1485 // Check scalar per vector requirement
1486 const auto a_extent_lowest = ABlockTransferSrcVectorDim == 2 ? KzRaw : MzRaw;
1487 const auto b0_extent_lowest = B0BlockTransferSrcVectorDim == 2 ? KzRaw : LzRaw;
1488 const auto b1_extent_lowest = B1BlockTransferSrcVectorDim == 2 ? LzRaw : NzRaw;
1489 const auto c_extent_lowest = NzRaw;
1490
1491 if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 &&
1492 b0_extent_lowest % B0BlockTransferSrcScalarPerVector == 0 &&
1493 b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 &&
1494 c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0))
1495 {
1496 printf("DeviceOp: Data Transfer Vector scalar err");
1497 return false;
1498 }
1499
1500 // Check vector load/store requirement
1501 const auto a_stride_lowest =
1502 ABlockTransferSrcVectorDim == 2 ? arg.a_mz_kz_strides_[1] : arg.a_mz_kz_strides_[0];
1503 const auto b0_stride_lowest =
1504 B0BlockTransferSrcVectorDim == 2 ? arg.b0_lz_kz_strides_[1] : arg.b0_lz_kz_strides_[0];
1505 const auto b1_stride_lowest =
1506 B1BlockTransferSrcVectorDim == 2 ? arg.b1_nz_lz_strides_[1] : arg.b1_nz_lz_strides_[0];
1507 const auto c_stride_lowest = arg.c_mz_nz_strides_[1];
1508
1509 if(!(a_stride_lowest == 1 || b0_stride_lowest == 1 || b1_stride_lowest == 1 ||
1510 c_stride_lowest == 1))
1511 {
1512 printf("DeviceOp: Data Vectorize transfer err");
1513 return false;
1514 }
1515
1516 return true;
1517 }
1518
1519 // polymorphic
1520 bool IsSupportedArgument(const BaseArgument* p_arg) override
1521 {
1522 return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
1523 }
1524
1525 static auto MakeArgument(
1526 const ADataType* p_a,
1527 const B0DataType* p_b0,
1528 const B1DataType* p_b1,
1529 CDataType* p_c,
1530 const std::array<void*, NumAcc0Bias> p_acc0_biases,
1531 const std::array<void*, NumAcc1Bias> p_acc1_biases,
1532 const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_lengths,
1533 const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_strides,
1534 const std::array<index_t, NumDimG + NumDimM + NumDimN>& b0_gs_ls_ks_lengths,
1535 const std::array<index_t, NumDimG + NumDimM + NumDimN>& b0_gs_ls_ks_strides,
1536 const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_ns_ls_lengths,
1537 const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_ns_ls_strides,
1538 const std::array<index_t, NumDimG + NumDimM + NumDimN>& c_gs_ms_ns_lengths,
1539 const std::array<index_t, NumDimG + NumDimM + NumDimN>& c_gs_ms_ns_strides,
1540 const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ls_lengths,
1541 const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ls_strides,
1542 const std::array<std::vector<ck::index_t>, NumAcc1Bias> acc1_biases_gs_ms_ns_lengths,
1543 const std::array<std::vector<ck::index_t>, NumAcc1Bias> acc1_biases_gs_ms_ns_strides,
1544 AElementwiseOperation a_element_op,
1545 B0ElementwiseOperation b0_element_op,
1546 AccElementwiseOperation acc_element_op,
1547 B1ElementwiseOperation b1_element_op,
1548 CElementwiseOperation c_element_op)
1549 {
1550 return Argument{p_a,
1551 p_b0,
1552 p_b1,
1553 p_c,
1554 p_acc0_biases,
1555 p_acc1_biases,
1556 a_gs_ms_ks_lengths,
1557 a_gs_ms_ks_strides,
1558 b0_gs_ls_ks_lengths,
1559 b0_gs_ls_ks_strides,
1560 b1_gs_ns_ls_lengths,
1561 b1_gs_ns_ls_strides,
1562 c_gs_ms_ns_lengths,
1563 c_gs_ms_ns_strides,
1564 acc0_biases_gs_ms_ls_lengths,
1565 acc0_biases_gs_ms_ls_strides,
1566 acc1_biases_gs_ms_ns_lengths,
1567 acc1_biases_gs_ms_ns_strides,
1568 1,
1569 1,
1570 a_element_op,
1571 b0_element_op,
1572 acc_element_op,
1573 b1_element_op,
1574 c_element_op};
1575 }
1576#endif
1577
1578 // polymorphic
1579 std::unique_ptr<BaseArgument> MakeArgumentPointer(
1580 const void* p_a,
1581 const void* p_b0,
1582 const void* p_b1,
1583 void* p_c,
1584 const std::array<void*, NumAcc0Bias> p_acc0_biases,
1585 const std::array<void*, NumAcc1Bias> p_acc1_biases,
1586 const std::vector<index_t>& a_gs_ms_ks_lengths,
1587 const std::vector<index_t>& a_gs_ms_ks_strides,
1588 const std::vector<index_t>& b0_gs_ls_ks_lengths,
1589 const std::vector<index_t>& b0_gs_ls_ks_strides,
1590 const std::vector<index_t>& b1_gs_ns_ls_lengths,
1591 const std::vector<index_t>& b1_gs_ns_ls_strides,
1592 const std::vector<index_t>& c_gs_ms_ns_lengths,
1593 const std::vector<index_t>& c_gs_ms_ns_strides,
1594 const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ls_lengths,
1595 const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ls_strides,
1596 const std::array<std::vector<ck::index_t>, NumAcc1Bias> acc1_biases_gs_ms_ns_lengths,
1597 const std::array<std::vector<ck::index_t>, NumAcc1Bias> acc1_biases_gs_ms_ns_strides,
1598 AElementwiseOperation a_element_op,
1599 B0ElementwiseOperation b0_element_op,
1600 AccElementwiseOperation acc_element_op,
1601 B1ElementwiseOperation b1_element_op,
1602 CElementwiseOperation c_element_op) override
1603 {
1604 std::array<index_t, NumDimG + NumDimM + NumDimN> a_lengths;
1605 std::array<index_t, NumDimG + NumDimM + NumDimN> a_strides;
1606 std::array<index_t, NumDimG + NumDimM + NumDimN> b0_lengths;
1607 std::array<index_t, NumDimG + NumDimM + NumDimN> b0_strides;
1608 std::array<index_t, NumDimG + NumDimM + NumDimN> b1_lengths;
1609 std::array<index_t, NumDimG + NumDimM + NumDimN> b1_strides;
1610 std::array<index_t, NumDimG + NumDimM + NumDimN> c_lengths;
1611 std::array<index_t, NumDimG + NumDimM + NumDimN> c_strides;
1612 std::transform(a_gs_ms_ks_lengths.begin(),
1613 a_gs_ms_ks_lengths.end(),
1614 a_lengths.begin(),
1615 [](index_t i) { return i; });
1616 std::transform(a_gs_ms_ks_strides.begin(),
1617 a_gs_ms_ks_strides.end(),
1618 a_strides.begin(),
1619 [](index_t i) { return i; });
1620 std::transform(b0_gs_ls_ks_lengths.begin(),
1621 b0_gs_ls_ks_lengths.end(),
1622 b0_lengths.begin(),
1623 [](index_t i) { return i; });
1624 std::transform(b0_gs_ls_ks_strides.begin(),
1625 b0_gs_ls_ks_strides.end(),
1626 b0_strides.begin(),
1627 [](index_t i) { return i; });
1628 std::transform(b1_gs_ns_ls_lengths.begin(),
1629 b1_gs_ns_ls_lengths.end(),
1630 b1_lengths.begin(),
1631 [](index_t i) { return i; });
1632 std::transform(b1_gs_ns_ls_strides.begin(),
1633 b1_gs_ns_ls_strides.end(),
1634 b1_strides.begin(),
1635 [](index_t i) { return i; });
1636 std::transform(c_gs_ms_ns_lengths.begin(),
1637 c_gs_ms_ns_lengths.end(),
1638 c_lengths.begin(),
1639 [](index_t i) { return i; });
1640 std::transform(c_gs_ms_ns_strides.begin(),
1641 c_gs_ms_ns_strides.end(),
1642 c_strides.begin(),
1643 [](index_t i) { return i; });
1644 return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
1645 static_cast<const B0DataType*>(p_b0),
1646 static_cast<const B1DataType*>(p_b1),
1647 static_cast<CDataType*>(p_c),
1648 p_acc0_biases,
1649 p_acc1_biases,
1650 a_lengths,
1651 a_strides,
1652 b0_lengths,
1653 b0_strides,
1654 b1_lengths,
1655 b1_strides,
1656 c_lengths,
1657 c_strides,
1658 acc0_biases_gs_ms_ls_lengths,
1659 acc0_biases_gs_ms_ls_strides,
1660 acc1_biases_gs_ms_ns_lengths,
1661 acc1_biases_gs_ms_ns_strides,
1662 1,
1663 1,
1664 a_element_op,
1665 b0_element_op,
1666 acc_element_op,
1667 b1_element_op,
1668 c_element_op);
1669 }
1670
1671 static auto MakeInvoker() { return Invoker{}; }
1672
1673 // polymorphic
1674 std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
1675 {
1676 return std::make_unique<Invoker>(Invoker{});
1677 }
1678
1679 // polymorphic
1680 std::string GetTypeString() const override
1681 {
1682 auto str = std::stringstream();
1683
1684 std::map<LoopScheduler, std::string> LoopSchedToString{
1685 {LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}};
1686
1687 std::map<PipelineVersion, std::string> PipelineVersionToString{{PipelineVersion::v1, "v1"},
1688 {PipelineVersion::v2, "v2"}};
1689
1690 // clang-format off
1691 str << "DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle"
1692 << "<"
1693 << BlockSize << ", "
1694 << MPerBlock << ", "
1695 << LPerBlock << ", "
1696 << KPerBlock << ", "
1697 << AK1 << ", "
1698 << BK1 << ", "
1699 << MPerBlock << ", "
1700 << NPerBlock << ", "
1701 << LTilePerBlock << ", "
1702 << L1 << ", "
1703 << getGemmSpecializationString(GemmSpec) << ", "
1704 << "ASpec" << getTensorSpecializationString(ASpec) << ", "
1705 << "B0Spec" << getTensorSpecializationString(B0Spec) << ", "
1706 << "B1Spec" << getTensorSpecializationString(B1Spec) << ", "
1707 << "CSpec" << getTensorSpecializationString(CSpec) << ", "
1708 << getMaskingSpecializationString(MaskingSpec)
1709 << ">"
1710 << " AEnableLds: "
1711 << AEnableLds << ", "
1712 << "B0EnableLds: "
1713 << B0EnableLds << ", "
1714 << "B1EnableLds: "
1715 << B1EnableLds << ", "
1716 << "NumPrefetch: "
1717 << NumPrefetch << ", "
1718 << "LoopScheduler: "
1719 << LoopSchedToString[LoopSched] << ", "
1720 << "PipelineVersion: "
1721 << PipelineVersionToString[PipelineVer];
1722 // clang-format on
1723
1724 return str.str();
1725 }
1726};
1727
1728} // namespace device
1729} // namespace tensor_operation
1730} // namespace ck
#define CK_MIN_BLOCK_PER_CU
Definition ck.hpp:31
#define CK_MAX_THREAD_PER_BLOCK
Definition ck.hpp:30
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
Definition convolution_backward_data_specialization.hpp:8
std::string getGemmSpecializationString(const GemmSpecialization &s)
Definition gemm_specialization.hpp:32
std::string getMaskingSpecializationString(const MaskingSpecialization &s)
Definition masking_specialization.hpp:17
MaskingSpecialization
Definition masking_specialization.hpp:11
@ MaskDisabled
Definition masking_specialization.hpp:12
@ MaskOutUpperTriangle
Definition masking_specialization.hpp:13
TensorSpecialization
Definition tensor_specialization.hpp:11
GemmSpecialization
Definition gemm_specialization.hpp:11
__global__ void kernel_wmma_cross_attention_forward(const QDataType *__restrict__ p_q_grid, const KVDataType *__restrict__ p_kv_grid, ODataType *__restrict__ p_out_grid, index_t batch_size, index_t q_sequence_length, index_t kv_sequence_length, index_t head_count, index_t head_size, float alpha)
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:315
__global__ void kernel_batched_gemm_softmax_gemm_wmma_cshuffle(const ADataType *__restrict__ p_a_grid, const B0DataType *__restrict__ p_b0_grid, const B1DataType *__restrict__ p_b1_grid, CDataType *__restrict__ p_c_grid, index_t M, index_t N, index_t K, index_t O, index_t G0, index_t G1, float alpha, bool input_permute, bool output_permute)
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:45
std::string getTensorSpecializationString(const TensorSpecialization &s)
Definition tensor_specialization.hpp:16
__global__ void kernel_wmma_self_attention_forward(const QKVDataType *__restrict__ p_qkv_grid, ODataType *__restrict__ p_out_grid, index_t batch_size, index_t sequence_length, index_t head_count, index_t head_size, float alpha)
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:183
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
__device__ index_t get_grid_size()
Definition get_id.hpp:49
int32_t index_t
Definition ck.hpp:299
@ Set
Definition ck.hpp:278
integral_constant< index_t, N > Number
Definition number.hpp:12
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
__device__ index_t get_block_1d_id()
Definition get_id.hpp:47
bool is_gfx12_supported()
Definition host_utility/device_prop.hpp:55
__device__ index_t get_thread_global_1d_id()
Definition get_id.hpp:43
constexpr bool is_same_v
Definition type.hpp:283
LoopScheduler
Definition loop_scheduler.hpp:15
@ Default
Definition loop_scheduler.hpp:16
@ Interwave
Definition loop_scheduler.hpp:17
int64_t long_index_t
Definition ck.hpp:300
PipelineVersion
Definition gridwise_gemm_pipeline_selector.hpp:18
@ v2
Definition gridwise_gemm_pipeline_selector.hpp:20
@ v1
Definition gridwise_gemm_pipeline_selector.hpp:19
bool is_gfx11_supported()
Definition host_utility/device_prop.hpp:60
constexpr LoopScheduler make_default_loop_scheduler()
Definition loop_scheduler.hpp:20
Definition ck/stream_config.hpp:10
Definition gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp:93
ck::GridwiseBatchedGemmSoftmaxGemm_Wmma< ADataType, B0DataType, Acc0DataType, B1DataType, Acc1DataType, CShuffleDataType, CDataType, AElementwiseOperation, B0ElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc, B0GridDesc, B1GridDesc, CGridDesc_M_N, MPerBlock, LPerBlock, KPerBlock, AK1, BK1, NPerBlock, LTilePerBlock, L1, MPerWmma, LPerWmma, NPerWmma, MRepeat, LRepeat, NRepeat, BlockSize, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, true, AEnableLds, ABlockLdsAddExtraM, B0BlockTransferThreadClusterLengths_K0_L_K1, B0BlockTransferThreadClusterArrangeOrder, B0BlockTransferSrcAccessOrder, B0BlockTransferSrcVectorDim, B0BlockTransferSrcScalarPerVector, B0BlockTransferDstScalarPerVector_K1, true, B0EnableLds, B0BlockLdsAddExtraL, B1BlockTransferThreadClusterLengths_L0_N_L1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_L1, false, B1EnableLds, B1BlockLdsAddExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, Transform::matrix_padder.PadN, MaskingSpec==MaskingSpecialization::MaskOutUpperTriangle, NumPrefetch, LoopSched, PipelineVer >::DefaultBlock2CTileMap
remove_cvref_t< decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1))> DefaultBlock2CTileMap
Definition gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp:682
ck::GridwiseBatchedGemmSoftmaxGemm_Wmma< ADataType, B0DataType, Acc0DataType, B1DataType, Acc1DataType, CShuffleDataType, CDataType, AElementwiseOperation, B0ElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc, B0GridDesc, B1GridDesc, CGridDesc_M_N, MPerBlock, LPerBlock, KPerBlock, AK1, BK1, NPerBlock, LTilePerBlock, L1, MPerWmma, LPerWmma, NPerWmma, MRepeat, LRepeat, NRepeat, BlockSize, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, true, AEnableLds, ABlockLdsAddExtraM, B0BlockTransferThreadClusterLengths_K0_L_K1, B0BlockTransferThreadClusterArrangeOrder, B0BlockTransferSrcAccessOrder, B0BlockTransferSrcVectorDim, B0BlockTransferSrcScalarPerVector, B0BlockTransferDstScalarPerVector_K1, true, B0EnableLds, B0BlockLdsAddExtraL, B1BlockTransferThreadClusterLengths_L0_N_L1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_L1, false, B1EnableLds, B1BlockLdsAddExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, Transform::matrix_padder.PadN, MaskingSpec==MaskingSpecialization::MaskOutUpperTriangle, NumPrefetch, LoopSched, PipelineVer >::MakeDefaultBlock2CTileMap
__host__ static __device__ constexpr auto MakeDefaultBlock2CTileMap(const CGridDesc_M_N &c_grid_desc_m_n, index_t, index_t)
Definition gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp:672
ck::GridwiseBatchedGemmSoftmaxGemm_Wmma< ADataType, B0DataType, Acc0DataType, B1DataType, Acc1DataType, CShuffleDataType, CDataType, AElementwiseOperation, B0ElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc, B0GridDesc, B1GridDesc, CGridDesc_M_N, MPerBlock, LPerBlock, KPerBlock, AK1, BK1, NPerBlock, LTilePerBlock, L1, MPerWmma, LPerWmma, NPerWmma, MRepeat, LRepeat, NRepeat, BlockSize, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, true, AEnableLds, ABlockLdsAddExtraM, B0BlockTransferThreadClusterLengths_K0_L_K1, B0BlockTransferThreadClusterArrangeOrder, B0BlockTransferSrcAccessOrder, B0BlockTransferSrcVectorDim, B0BlockTransferSrcScalarPerVector, B0BlockTransferDstScalarPerVector_K1, true, B0EnableLds, B0BlockLdsAddExtraL, B1BlockTransferThreadClusterLengths_L0_N_L1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_L1, false, B1EnableLds, B1BlockLdsAddExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, Transform::matrix_padder.PadN, MaskingSpec==MaskingSpecialization::MaskOutUpperTriangle, NumPrefetch, LoopSched, PipelineVer >::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
remove_cvref_t< decltype(MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}))> CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
Definition gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp:679
ck::GridwiseBatchedGemmSoftmaxGemm_Wmma< ADataType, B0DataType, Acc0DataType, B1DataType, Acc1DataType, CShuffleDataType, CDataType, AElementwiseOperation, B0ElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc, B0GridDesc, B1GridDesc, CGridDesc_M_N, MPerBlock, LPerBlock, KPerBlock, AK1, BK1, NPerBlock, LTilePerBlock, L1, MPerWmma, LPerWmma, NPerWmma, MRepeat, LRepeat, NRepeat, BlockSize, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, true, AEnableLds, ABlockLdsAddExtraM, B0BlockTransferThreadClusterLengths_K0_L_K1, B0BlockTransferThreadClusterArrangeOrder, B0BlockTransferSrcAccessOrder, B0BlockTransferSrcVectorDim, B0BlockTransferSrcScalarPerVector, B0BlockTransferDstScalarPerVector_K1, true, B0EnableLds, B0BlockLdsAddExtraL, B1BlockTransferThreadClusterLengths_L0_N_L1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_L1, false, B1EnableLds, B1BlockLdsAddExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, Transform::matrix_padder.PadN, MaskingSpec==MaskingSpecialization::MaskOutUpperTriangle, NumPrefetch, LoopSched, PipelineVer >::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
__host__ static __device__ constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc_M_N &c_grid_desc_m_n)
Definition gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp:653
ck::GridwiseBatchedGemmSoftmaxGemm_Wmma< ADataType, B0DataType, Acc0DataType, B1DataType, Acc1DataType, CShuffleDataType, CDataType, AElementwiseOperation, B0ElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc, B0GridDesc, B1GridDesc, CGridDesc_M_N, MPerBlock, LPerBlock, KPerBlock, AK1, BK1, NPerBlock, LTilePerBlock, L1, MPerWmma, LPerWmma, NPerWmma, MRepeat, LRepeat, NRepeat, BlockSize, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, true, AEnableLds, ABlockLdsAddExtraM, B0BlockTransferThreadClusterLengths_K0_L_K1, B0BlockTransferThreadClusterArrangeOrder, B0BlockTransferSrcAccessOrder, B0BlockTransferSrcVectorDim, B0BlockTransferSrcScalarPerVector, B0BlockTransferDstScalarPerVector_K1, true, B0EnableLds, B0BlockLdsAddExtraL, B1BlockTransferThreadClusterLengths_L0_N_L1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_L1, false, B1EnableLds, B1BlockLdsAddExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, Transform::matrix_padder.PadN, MaskingSpec==MaskingSpecialization::MaskOutUpperTriangle, NumPrefetch, LoopSched, PipelineVer >::CalculateHasMainKBlockLoop
__host__ static __device__ constexpr bool CalculateHasMainKBlockLoop(index_t K)
Definition gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp:645
ck::GridwiseBatchedGemmSoftmaxGemm_Wmma< ADataType, B0DataType, Acc0DataType, B1DataType, Acc1DataType, CShuffleDataType, CDataType, AElementwiseOperation, B0ElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc, B0GridDesc, B1GridDesc, CGridDesc_M_N, MPerBlock, LPerBlock, KPerBlock, AK1, BK1, NPerBlock, LTilePerBlock, L1, MPerWmma, LPerWmma, NPerWmma, MRepeat, LRepeat, NRepeat, BlockSize, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, true, AEnableLds, ABlockLdsAddExtraM, B0BlockTransferThreadClusterLengths_K0_L_K1, B0BlockTransferThreadClusterArrangeOrder, B0BlockTransferSrcAccessOrder, B0BlockTransferSrcVectorDim, B0BlockTransferSrcScalarPerVector, B0BlockTransferDstScalarPerVector_K1, true, B0EnableLds, B0BlockLdsAddExtraL, B1BlockTransferThreadClusterLengths_L0_N_L1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_L1, false, B1EnableLds, B1BlockLdsAddExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, Transform::matrix_padder.PadN, MaskingSpec==MaskingSpecialization::MaskOutUpperTriangle, NumPrefetch, LoopSched, PipelineVer >::CheckValidity
__host__ static __device__ constexpr bool CheckValidity(const AGridDesc &a_grid_desc, const B0GridDesc &b0_grid_desc, const B1GridDesc &b1_grid_desc, const CGridDesc_M_N &c_grid_desc_m_n, const Block2CTileMap &block_2_ctile_map)
Definition gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp:511
Definition utility/sequence.hpp:43
Definition utility/integral_constant.hpp:20
Definition transform_contraction_to_gemm_arraybase.hpp:122
__host__ static __device__ constexpr auto MakeB0GridDescriptor_BK0_N_BK1(const BGridDesc_N_K &b_grid_desc_n_k, const Number &BK1)
Definition transform_contraction_to_gemm_arraybase.hpp:245
__host__ static __device__ auto MakeCGridDescriptor_G_M_N(const std::array< index_t, NumDimG+NumDimM+NumDimN > &c_gs_ms_os_lengths_vec, const std::array< index_t, NumDimG+NumDimM+NumDimN > &c_gs_ms_os_strides_vec)
Definition transform_contraction_to_gemm_arraybase.hpp:375
__host__ static __device__ constexpr auto MakeAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K &a_grid_desc_m_k, const Number &AK1)
Definition transform_contraction_to_gemm_arraybase.hpp:172
__host__ static __device__ auto MakeB1GridDescriptor_N_K(const std::array< index_t, NumDimG+NumDimM+NumDimN > &b1_gs_os_ns_lengths_vec, const std::array< index_t, NumDimG+NumDimM+NumDimN > &b1_gs_os_ns_strides_vec)
Definition transform_contraction_to_gemm_arraybase.hpp:307
__host__ static __device__ auto MakeB1GridDescriptor_G_N_K(const std::array< index_t, NumDimG+NumDimM+NumDimN > &b1_gs_os_ns_lengths_vec, const std::array< index_t, NumDimG+NumDimM+NumDimN > &b1_gs_os_ns_strides_vec)
Definition transform_contraction_to_gemm_arraybase.hpp:301
__host__ static __device__ constexpr auto MakeB0GridDescriptor_BKWmma_LBlockRepeat_LWaves_BK0PerWmma_BKRow_LPerWmma_BK1(const BGridDesc_L_K &b_grid_desc_l_k, const WmmaK &, const LRepeat &, const LWaves &, const LPerWmma &, const BK1 &)
Definition transform_contraction_to_gemm_arraybase.hpp:266
__host__ static __device__ auto MakeB0GridDescriptor_G_N_K(const std::array< index_t, NumDimG+NumDimM+NumDimN > &b0_gs_ns_ks_lengths_vec, const std::array< index_t, NumDimG+NumDimM+NumDimN > &b0_gs_ns_ks_strides_vec)
Definition transform_contraction_to_gemm_arraybase.hpp:228
__host__ static __device__ constexpr auto MakeB1GridDescriptor_BK0_N_BK1(const B1GridDesc_N_K &b1_grid_desc_n_k, const Number &B1K1)
Definition transform_contraction_to_gemm_arraybase.hpp:318
__host__ static __device__ auto MakeAGridDescriptor_G_M_K(const std::array< index_t, NumDimG+NumDimM+NumDimN > &a_gs_ms_ks_lengths_vec, const std::array< index_t, NumDimG+NumDimM+NumDimN > &a_gs_ms_ks_strides_vec)
Definition transform_contraction_to_gemm_arraybase.hpp:156
__host__ static __device__ constexpr auto MakeAGridDescriptor_AKWmma_MBlockRepeat_MWaves_AK0PerWmma_AKRow_MPerWmma_AK1(const AGridDesc_M_K &a_grid_desc_m_k, const WmmaK &, const MRepeat &, const MWaves &, const MPerWmma &, const AK1 &)
Definition transform_contraction_to_gemm_arraybase.hpp:193
__host__ static __device__ auto MakeCGridDescriptor_M_N(const std::array< index_t, NumDimG+NumDimM+NumDimN > &c_gs_ms_os_lengths_vec, const std::array< index_t, NumDimG+NumDimM+NumDimN > &c_gs_ms_os_strides_vec)
Definition transform_contraction_to_gemm_arraybase.hpp:381
__host__ static __device__ constexpr auto MakeB1GridDescriptor_BLWmma_NBlockRepeat_NWaves__BL0PerWmma_BLRow_NPerWmma_BL1(const BGridDesc_N_L &b_grid_desc_n_l, const WmmaL &, const NRepeat &, const NWaves &, const NPerWmma &, const BL1 &)
Definition transform_contraction_to_gemm_arraybase.hpp:340
__host__ static __device__ auto MakeAGridDescriptor_M_K(const std::array< index_t, NumDimG+NumDimM+NumDimN > &a_gs_ms_ks_lengths_vec, const std::array< index_t, NumDimG+NumDimM+NumDimN > &a_gs_ms_ks_strides_vec)
Definition transform_contraction_to_gemm_arraybase.hpp:162
__host__ static __device__ auto MakeB0GridDescriptor_N_K(const std::array< index_t, NumDimG+NumDimM+NumDimN > &b0_gs_ns_ks_lengths_vec, const std::array< index_t, NumDimG+NumDimM+NumDimN > &b0_gs_ns_ks_strides_vec)
Definition transform_contraction_to_gemm_arraybase.hpp:234
Definition device_base.hpp:197
Definition masking_specialization.hpp:57
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:679
__host__ __device__ constexpr long_index_t GetB0BasePtr(index_t g_idx) const
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:696
__host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:691
__host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:701
__host__ __device__ ComputeBasePtrOfStridedBatch(const AGridDesc_G_M_K &a_grid_desc_g_m_k, const B0GridDesc_G_L_K &b0_grid_desc_g_l_k, const B1GridDesc_G_N_L &b1_grid_desc_g_n_l, const CGridDesc_G_M_N &c_grid_desc_g_m_n)
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:680
__host__ __device__ constexpr long_index_t GetCBasePtr(index_t g_idx) const
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:706
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1112
CDataType * p_c_grid_
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1203
AGridDesc_G_M_K a_grid_desc_g_m_k_
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1211
ComputeBasePtrOfStridedBatch compute_ptr_offset_of_batch_
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1242
const B0DataType * p_b0_grid_
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1201
AccElementwiseOperation acc_element_op_
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1225
const B1DataType * p_b1_grid_
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1202
B1ElementwiseOperation b1_element_op_
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1226
std::array< index_t, NumDimG+NumDimM+NumDimN > a_mz_kz_strides_
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1235
std::array< index_t, NumDimG+NumDimM+NumDimN > c_mz_nz_strides_
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1238
Argument(const ADataType *p_a_grid, const B0DataType *p_b0_grid, const B1DataType *p_b1_grid, CDataType *p_c_grid, const std::array< void *, NumAcc0Bias > p_acc0_biases, const std::array< void *, NumAcc1Bias > p_acc1_biases, const std::array< index_t, NumDimG+NumDimM+NumDimN > &a_gs_ms_ks_lengths, const std::array< index_t, NumDimG+NumDimM+NumDimN > &a_gs_ms_ks_strides, const std::array< index_t, NumDimG+NumDimM+NumDimN > &b0_gs_ls_ks_lengths, const std::array< index_t, NumDimG+NumDimM+NumDimN > &b0_gs_ls_ks_strides, const std::array< index_t, NumDimG+NumDimM+NumDimN > &b1_gs_ns_ls_lengths, const std::array< index_t, NumDimG+NumDimM+NumDimN > &b1_gs_ns_ls_strides, const std::array< index_t, NumDimG+NumDimM+NumDimN > &c_gs_ms_ns_lengths, const std::array< index_t, NumDimG+NumDimM+NumDimN > &c_gs_ms_ns_strides, const std::array< std::vector< ck::index_t >, NumAcc0Bias > acc0_biases_gs_ms_ls_lengths, const std::array< std::vector< ck::index_t >, NumAcc0Bias > acc0_biases_gs_ms_ls_strides, const std::array< std::vector< ck::index_t >, NumAcc1Bias > acc1_biases_gs_ms_ns_lengths, const std::array< std::vector< ck::index_t >, NumAcc1Bias > acc1_biases_gs_ms_ns_strides, const index_t M01, const index_t N01, AElementwiseOperation a_element_op, B0ElementwiseOperation b0_element_op, AccElementwiseOperation acc_element_op, B1ElementwiseOperation b1_element_op, CElementwiseOperation c_element_op)
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1113
B0GridDesc_G_L_K b0_grid_desc_g_l_k_
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1212
CElementwiseOperation c_element_op_
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1227
GridwiseOp::DefaultBlock2CTileMap block_2_ctile_map_
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1220
C0MatrixMask c0_matrix_mask_
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1230
B0ElementwiseOperation b0_element_op_
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1224
AGridDesc a_grid_desc
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1206
B1GridDesc b1_grid_desc
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1208
B1GridDesc_G_N_L b1_grid_desc_g_n_l_
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1213
const ADataType * p_a_grid_
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1200
CGridDesc_M_N c_grid_desc_m_n_
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1209
AElementwiseOperation a_element_op_
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1223
std::array< index_t, NumDimG+NumDimM+NumDimN > raw_lengths_mz_lz_kz_nz_
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1234
CGridDesc_G_M_N c_grid_desc_g_m_n_
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1214
B0GridDesc b0_grid_desc
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1207
index_t batch_count_
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1240
std::array< index_t, NumDimG+NumDimM+NumDimN > b0_lz_kz_strides_
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1236
GridwiseOp::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock_
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1217
std::array< index_t, NumDimG+NumDimM+NumDimN > b1_nz_lz_strides_
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1237
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1054
const ADataType * p_q_grid_
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1076
float alpha_
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1086
const B0DataType * p_kv_grid_
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1077
CDataType * p_out_grid_
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1078
index_t q_sequence_length_
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1082
index_t kv_sequence_length_
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1083
index_t head_size_
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1085
CrossAttnArg(const ADataType *p_q_grid, const B0DataType *p_kv_grid, CDataType *p_out_grid, index_t batch_size, index_t q_sequence_length, index_t kv_sequence_length, index_t head_count, index_t head_size, float alpha)
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1055
index_t batch_size_
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1081
index_t head_count_
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1084
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1306
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1309
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1357
DeviceOp::CrossAttnArg Argument
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1307
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1367
DeviceOp::RawArg Argument
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1368
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1424
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1370
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:795
const B0DataType * p_b0_grid_
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:826
index_t K_
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:833
bool output_permute_
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:839
const B1DataType * p_b1_grid_
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:827
index_t G1_
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:836
float alpha_
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:837
index_t O_
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:834
CDataType * p_c_grid_
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:828
RawArg(const ADataType *p_a_grid, const B0DataType *p_b0_grid, const B1DataType *p_b1_grid, CDataType *p_c_grid, index_t M, index_t N, index_t K, index_t O, index_t G0, index_t G1, float alpha, bool input_permute, bool output_permute)
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:796
index_t N_
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:832
index_t M_
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:831
const ADataType * p_a_grid_
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:825
index_t G0_
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:835
bool input_permute_
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:838
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1012
SelfAttnArg(const ADataType *p_qkv_grid, CDataType *p_out_grid, index_t batch_size, index_t sequence_length, index_t head_count, index_t head_size, float alpha)
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1013
float alpha_
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1038
index_t sequence_length_
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1035
index_t head_count_
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1036
const ADataType * p_qkv_grid_
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1030
index_t head_size_
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1037
index_t batch_size_
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1034
CDataType * p_out_grid_
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1031
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1247
DeviceOp::SelfAttnArg Argument
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1248
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1250
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1295
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:528
static constexpr index_t NumAcc0Bias
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:532
static constexpr auto B0EnableLds_manu
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:566
__host__ __device__ static constexpr auto make_MaskOutPredicate()
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:665
decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {})) B1GridDesc_G_N_L
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:662
decltype(Transform::MakeB0GridDescriptor_G_N_K({}, {})) B0GridDesc_G_L_K
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:661
static constexpr auto I4
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:551
decltype(Transform::MakeAGridDescriptor_G_M_K({}, {})) AGridDesc_G_M_K
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:660
TransformBatchedContractionContractionToBatchedGemmGemm_Wmma< Sequence< NumDimG, NumDimM, NumDimL, NumDimK, NumDimN >, Sequence< MPerBlock, LPerBlock, KPerBlock, NPerBlock >, GemmSpec, ASpec, B0Spec, B1Spec, CSpec > Transform
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:573
__host__ static __device__ auto MakeB1GridDescriptor(const std::array< index_t, NumDimG+NumDimM+NumDimN > &b1_gs_ns_ls_lengths_vec, const std::array< index_t, NumDimG+NumDimM+NumDimN > &b1_gs_ns_ls_strides_vec)
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:631
static constexpr index_t NumDimGemm1N
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:542
decltype(MakeAGridDescriptor({}, {})) AGridDesc
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:656
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1674
static auto MakeArgument(const ADataType *p_a, const B0DataType *p_b0, const B1DataType *p_b1, CDataType *p_c, index_t M, index_t N, index_t K, index_t O, index_t G0, index_t G1, float alpha, bool input_permute, bool output_permute)
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:842
__host__ static __device__ auto MakeB0GridDescriptor(const std::array< index_t, NumDimG+NumDimM+NumDimN > &b0_gs_ls_ks_lengths_vec, const std::array< index_t, NumDimG+NumDimM+NumDimN > &b0_gs_ls_ks_strides_vec)
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:606
std::string GetTypeString() const override
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1680
static constexpr auto B1EnableLds
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:571
decltype(Transform::MakeCGridDescriptor_G_M_N({}, {})) CGridDesc_G_M_N
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:663
static constexpr auto MWaves
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:557
static constexpr auto WmmaK
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:555
static auto MakeCrossAttnArgument(const ADataType *p_q_grid, const B0DataType *p_kv_grid, CDataType *p_out_grid, index_t batch_size, index_t q_sequence_length, index_t kv_sequence_length, index_t head_count, index_t head_size, float alpha)
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1089
static constexpr auto B0EnableLds
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:570
static auto MakeCrossAttnInvoker()
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1364
DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle DeviceOp
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:545
static constexpr index_t NumDimGemm1M
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:541
static auto MakeInvoker()
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1671
static constexpr index_t NumDimGemm0N
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:539
static constexpr index_t NumDimGemm0K
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:540
static bool IsSupportedArgument(const RawArg &arg)
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:860
static constexpr auto I2
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:549
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1006
static constexpr index_t NumDimGemm0M
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:538
static constexpr auto I5
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:552
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b0, const void *p_b1, void *p_c, const std::array< void *, NumAcc0Bias > p_acc0_biases, const std::array< void *, NumAcc1Bias > p_acc1_biases, const std::vector< index_t > &a_gs_ms_ks_lengths, const std::vector< index_t > &a_gs_ms_ks_strides, const std::vector< index_t > &b0_gs_ls_ks_lengths, const std::vector< index_t > &b0_gs_ls_ks_strides, const std::vector< index_t > &b1_gs_ns_ls_lengths, const std::vector< index_t > &b1_gs_ns_ls_strides, const std::vector< index_t > &c_gs_ms_ns_lengths, const std::vector< index_t > &c_gs_ms_ns_strides, const std::array< std::vector< ck::index_t >, NumAcc0Bias > acc0_biases_gs_ms_ls_lengths, const std::array< std::vector< ck::index_t >, NumAcc0Bias > acc0_biases_gs_ms_ls_strides, const std::array< std::vector< ck::index_t >, NumAcc1Bias > acc1_biases_gs_ms_ns_lengths, const std::array< std::vector< ck::index_t >, NumAcc1Bias > acc1_biases_gs_ms_ns_strides, AElementwiseOperation a_element_op, B0ElementwiseOperation b0_element_op, AccElementwiseOperation acc_element_op, B1ElementwiseOperation b1_element_op, CElementwiseOperation c_element_op) override
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1579
static auto MakeSelfAttnArgument(const ADataType *p_qkv_grid, CDataType *p_out_grid, index_t batch_size, index_t sequence_length, index_t head_count, index_t head_size, float alpha)
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1041
static constexpr auto B1EnableLds_manu
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:567
static constexpr index_t NumDimGemm1K
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:543
static constexpr auto I6
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:553
static constexpr auto LWaves
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:558
static constexpr auto AEnableLds
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:569
C0MatrixMask_impl< decltype(make_MaskOutPredicate())> C0MatrixMask
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:676
static constexpr auto B0EnableLds_auto
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:562
decltype(MakeB0GridDescriptor({}, {})) B0GridDesc
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:657
decltype(MakeB1GridDescriptor({}, {})) B1GridDesc
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:658
static constexpr auto I3
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:550
static constexpr bool IsValidCompilationParameter()
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1431
static constexpr auto AEnableLds_auto
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:561
static constexpr auto I0
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:547
static constexpr auto I1
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:548
static constexpr index_t NumAcc1Bias
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:533
static auto MakeSelfAttnInvoker()
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:1302
__host__ static __device__ auto MakeAGridDescriptor(const std::array< index_t, NumDimG+NumDimM+NumDimN > &a_gs_ms_ks_lengths_vec, const std::array< index_t, NumDimG+NumDimM+NumDimN > &a_gs_ms_ks_strides_vec)
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:582
static constexpr auto NWaves
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:559
GridwiseBatchedGemmSoftmaxGemm_Wmma< ADataType, B0DataType, Acc0DataType, B1DataType, Acc1DataType, CShuffleDataType, CDataType, AElementwiseOperation, B0ElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc, B0GridDesc, B1GridDesc, CGridDesc_M_N, MPerBlock, LPerBlock, KPerBlock, AK1, BK1, NPerBlock, LTilePerBlock, L1, MPerWmma, LPerWmma, NPerWmma, MRepeat, LRepeat, NRepeat, BlockSize, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, true, AEnableLds, ABlockLdsAddExtraM, B0BlockTransferThreadClusterLengths_K0_L_K1, B0BlockTransferThreadClusterArrangeOrder, B0BlockTransferSrcAccessOrder, B0BlockTransferSrcVectorDim, B0BlockTransferSrcScalarPerVector, B0BlockTransferDstScalarPerVector_K1, true, B0EnableLds, B0BlockLdsAddExtraL, B1BlockTransferThreadClusterLengths_L0_N_L1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_L1, false, B1EnableLds, B1BlockLdsAddExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, Transform::matrix_padder.PadN, MaskingSpec==MaskingSpecialization::MaskOutUpperTriangle, NumPrefetch, LoopSched, PipelineVer > GridwiseOp
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:719
decltype(Transform::MakeCGridDescriptor_M_N({}, {})) CGridDesc_M_N
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:659
static constexpr auto B1EnableLds_auto
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:563
static constexpr auto AEnableLds_manu
Definition device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp:565
Definition device_batched_gemm_softmax_gemm_permute.hpp:34
Definition masking_specialization.hpp:29
Definition masking_specialization.hpp:43