gridwise_gemm_pipeline_v4_direct_load.hpp Source File

gridwise_gemm_pipeline_v4_direct_load.hpp Source File

Composable Kernel: gridwise_gemm_pipeline_v4_direct_load.hpp Source File
gridwise_gemm_pipeline_v4_direct_load.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
9
10namespace lds_direct_load {
11
12__device__ void sched_barrier()
13{
14#if CK_USE_AMD_LDS_DIRECT_LOAD_INLINE_ASM
15 // When direct loads and `waitcnt` instructions are submitted using inline asm, the usage of
16 // `sched_barrier` is necessary to make sure no instructions that use the loaded memory
17 // are scheduled by the compiler before the `waitcnt` instruction.
18 __builtin_amdgcn_sched_barrier(0);
19#endif
20}
21
22} // namespace lds_direct_load
23
24namespace ck {
25
26template <index_t NumPrefetch>
28
29// 1-stage prefetch
30template <>
32{
33 static constexpr auto I0 = Number<0>{};
34
35 __host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */) { return true; }
36
37 __host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
38 {
39 return num_loop > 1;
40 }
41
42 template <bool HasMainLoop,
43 typename AGridDesc,
44 typename ABlockDesc,
45 typename ABlockTransfer,
46 typename AGridBuffer,
47 typename ABlockBuffers,
48 typename ABlockTransferStep,
49 typename BGridDesc,
50 typename BBlockDesc,
51 typename BBlockTransfer,
52 typename BGridBuffer,
53 typename BBlockBuffers,
54 typename BBlockTransferStep,
55 typename BlockwiseGemm,
56 typename CThreadBuffer>
57 __device__ static void Run(const AGridDesc& a_grid_desc,
58 const ABlockDesc& a_block_desc,
59 ABlockTransfer& a_blockwise_copy,
60 const AGridBuffer& a_grid_buf,
61 ABlockBuffers& a_block_bufs,
62 const ABlockTransferStep& a_block_copy_step,
63 const BGridDesc& b_grid_desc,
64 const BBlockDesc& b_block_desc,
65 BBlockTransfer& b_blockwise_copy,
66 const BGridBuffer& b_grid_buf,
67 BBlockBuffers& b_block_bufs,
68 const BBlockTransferStep& b_block_copy_step,
69 const BlockwiseGemm& blockwise_gemm,
70 CThreadBuffer& c_thread_buf,
71 index_t num_loop)
72 {
73 static_assert(ABlockBuffers::Size() == 1 && BBlockBuffers::Size() == 1);
74 auto& a_block_buf = a_block_bufs.At(I0);
75 auto& b_block_buf = b_block_bufs.At(I0);
76
77 a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_buf);
78 b_blockwise_copy.Run(b_grid_desc, b_grid_buf, b_block_desc, b_block_buf);
79
80 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
81 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
82
83 // Initialize C
84 c_thread_buf.Clear();
85
86 // main body
87 if constexpr(HasMainLoop)
88 {
89 index_t i = 0;
90
91 do
92 {
95
96 blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
97
100
101 a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_buf);
102 b_blockwise_copy.Run(b_grid_desc, b_grid_buf, b_block_desc, b_block_buf);
103
104 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
105 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
106
107 ++i;
108 } while(i < (num_loop - 1));
109 }
110
111 // tail
112 {
115
116 blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
117 }
118 }
119};
120
121// 2-stages prefetch
122template <>
124{
125 static constexpr auto I0 = Number<0>{};
126 static constexpr auto I1 = Number<1>{};
127
128 __host__ __device__ static constexpr bool IsSupported(index_t num_loop)
129 {
130 return num_loop % 2 == 0;
131 }
132
133 __host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
134 {
135 return (num_loop / 2) > 1;
136 }
137
138 template <bool HasMainLoop,
139 typename AGridDesc,
140 typename ABlockDesc,
141 typename ABlockTransfer,
142 typename AGridBuffer,
143 typename ABlockBuffers,
144 typename ABlockTransferStep,
145 typename BGridDesc,
146 typename BBlockDesc,
147 typename BBlockTransfer,
148 typename BGridBuffer,
149 typename BBlockBuffers,
150 typename BBlockTransferStep,
151 typename BlockwiseGemm,
152 typename CThreadBuffer>
153 __device__ static void Run(const AGridDesc& a_grid_desc,
154 const ABlockDesc& a_block_desc,
155 ABlockTransfer& a_blockwise_copy,
156 const AGridBuffer& a_grid_buf,
157 ABlockBuffers& a_block_bufs,
158 const ABlockTransferStep& a_block_copy_step,
159 const BGridDesc& b_grid_desc,
160 const BBlockDesc& b_block_desc,
161 BBlockTransfer& b_blockwise_copy,
162 const BGridBuffer& b_grid_buf,
163 BBlockBuffers& b_block_bufs,
164 const BBlockTransferStep& b_block_copy_step,
165 const BlockwiseGemm& blockwise_gemm,
166 CThreadBuffer& c_thread_buf,
167 index_t num_loop)
168 {
169 static_assert(ABlockBuffers::Size() == 2 && BBlockBuffers::Size() == 2);
170 auto& a_block_buf1 = a_block_bufs.At(I0);
171 auto& a_block_buf2 = a_block_bufs.At(I1);
172 auto& b_block_buf1 = b_block_bufs.At(I0);
173 auto& b_block_buf2 = b_block_bufs.At(I1);
174
175 a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_buf1);
176 b_blockwise_copy.Run(b_grid_desc, b_grid_buf, b_block_desc, b_block_buf1);
177
178 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
179 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
180
181 // Initialize C
182 c_thread_buf.Clear();
183
184 // main body
185 if constexpr(HasMainLoop)
186 {
187 index_t i = 0;
188
189 do
190 {
193
194 a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_buf2);
195 b_blockwise_copy.Run(b_grid_desc, b_grid_buf, b_block_desc, b_block_buf2);
196
197 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
198 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
199
200 blockwise_gemm.Run(a_block_buf1, b_block_buf1, c_thread_buf);
201
204
205 a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_buf1);
206 b_blockwise_copy.Run(b_grid_desc, b_grid_buf, b_block_desc, b_block_buf1);
207
208 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
209 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
210
211 blockwise_gemm.Run(a_block_buf2, b_block_buf2, c_thread_buf);
212
213 i += 2;
214 } while(i < (num_loop - 2));
215 }
216
217 // tail
218 {
221
222 a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_buf2);
223 b_blockwise_copy.Run(b_grid_desc, b_grid_buf, b_block_desc, b_block_buf2);
224
225 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
226 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
227
228 blockwise_gemm.Run(a_block_buf1, b_block_buf1, c_thread_buf);
229
232
233 blockwise_gemm.Run(a_block_buf2, b_block_buf2, c_thread_buf);
234 }
235 }
236};
237
238} // namespace ck
namespace ck
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
integral_constant< index_t, N > Number
Definition number.hpp:12
__device__ void block_sync_lds_direct_load()
Definition synchronization.hpp:43
namespace lds_direct_load
Definition gridwise_gemm_pipeline_v4_direct_load.hpp:10
__device__ void sched_barrier()
Definition gridwise_gemm_pipeline_v4_direct_load.hpp:12
static constexpr auto I0
Definition gridwise_gemm_pipeline_v4_direct_load.hpp:33
__host__ static __device__ constexpr bool CalculateHasMainLoop(index_t num_loop)
Definition gridwise_gemm_pipeline_v4_direct_load.hpp:37
__host__ static __device__ constexpr bool IsSupported(index_t)
Definition gridwise_gemm_pipeline_v4_direct_load.hpp:35
static __device__ void Run(const AGridDesc &a_grid_desc, const ABlockDesc &a_block_desc, ABlockTransfer &a_blockwise_copy, const AGridBuffer &a_grid_buf, ABlockBuffers &a_block_bufs, const ABlockTransferStep &a_block_copy_step, const BGridDesc &b_grid_desc, const BBlockDesc &b_block_desc, BBlockTransfer &b_blockwise_copy, const BGridBuffer &b_grid_buf, BBlockBuffers &b_block_bufs, const BBlockTransferStep &b_block_copy_step, const BlockwiseGemm &blockwise_gemm, CThreadBuffer &c_thread_buf, index_t num_loop)
Definition gridwise_gemm_pipeline_v4_direct_load.hpp:57
static __device__ void Run(const AGridDesc &a_grid_desc, const ABlockDesc &a_block_desc, ABlockTransfer &a_blockwise_copy, const AGridBuffer &a_grid_buf, ABlockBuffers &a_block_bufs, const ABlockTransferStep &a_block_copy_step, const BGridDesc &b_grid_desc, const BBlockDesc &b_block_desc, BBlockTransfer &b_blockwise_copy, const BGridBuffer &b_grid_buf, BBlockBuffers &b_block_bufs, const BBlockTransferStep &b_block_copy_step, const BlockwiseGemm &blockwise_gemm, CThreadBuffer &c_thread_buf, index_t num_loop)
Definition gridwise_gemm_pipeline_v4_direct_load.hpp:153
__host__ static __device__ constexpr bool IsSupported(index_t num_loop)
Definition gridwise_gemm_pipeline_v4_direct_load.hpp:128
__host__ static __device__ constexpr bool CalculateHasMainLoop(index_t num_loop)
Definition gridwise_gemm_pipeline_v4_direct_load.hpp:133
static constexpr auto I1
Definition gridwise_gemm_pipeline_v4_direct_load.hpp:126
static constexpr auto I0
Definition gridwise_gemm_pipeline_v4_direct_load.hpp:125
struct GridwiseGemmPipeline_v4
Definition gridwise_gemm_pipeline_v4_direct_load.hpp:27