transform_contraction_to_gemm.hpp Source File

transform_contraction_to_gemm.hpp Source File#

Composable Kernel: transform_contraction_to_gemm.hpp Source File
transform_contraction_to_gemm.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
10
11namespace ck {
12namespace tensor_operation {
13
14// assume C[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...]
15template <index_t NumDimG,
16 index_t NumDimM,
17 index_t NumDimN,
19static auto MakeGridDescriptorPair(const std::vector<index_t>& gs_ms_ns_lengths_vec,
20 const std::vector<index_t>& gs_ms_ns_strides_vec)
21{
22 if(!(gs_ms_ns_lengths_vec.size() == NumDimG + NumDimM + NumDimN &&
23 gs_ms_ns_strides_vec.size() == NumDimG + NumDimM + NumDimN))
24 {
25 throw std::runtime_error("wrong! dimension must match input lengths");
26 }
27
28 const auto to_tuple = [&](auto& vec, auto start, auto end) {
29 return generate_tuple([&](auto i) { return vec[start + i]; }, Number<end - start>{});
30 };
31
32 const auto gs_ms_ns_lengths =
33 to_tuple(gs_ms_ns_lengths_vec, Number<0>{}, Number<NumDimG + NumDimM + NumDimN>{});
34 const auto gs_ms_ns_strides =
35 to_tuple(gs_ms_ns_strides_vec, Number<0>{}, Number<NumDimG + NumDimM + NumDimN>{});
36
37 // dimension Ids for G0, G1, ...
38 constexpr auto gDimIds = typename arithmetic_sequence_gen<0, NumDimG, 1>::type{};
39
40 // dimension Ids for M0, M1, ...
41 constexpr auto mDimIds =
43
44 // dimension Ids for N0, N1, ...
45 constexpr auto nDimIds =
47
48 // lengths for G0, G1, ...
49 const auto gLengths = get_container_subset(gs_ms_ns_lengths, gDimIds);
50
51 // lengths for M0, M1, ...
52 const auto mLengths = get_container_subset(gs_ms_ns_lengths, mDimIds);
53
54 // lengths for N0, N1, ...
55 const auto nLengths = get_container_subset(gs_ms_ns_lengths, nDimIds);
56
57 if constexpr(TensorSpec == device::TensorSpecialization::Packed)
58 {
59 auto G = container_reduce(gLengths, math::multiplies{}, Number<1>{});
60 auto M = container_reduce(mLengths, math::multiplies{}, Number<1>{});
61 auto N = container_reduce(nLengths, math::multiplies{}, Number<1>{});
62 const auto grid_desc_g_mraw_nraw = make_naive_tensor_descriptor(
63 make_tuple(G, M, N),
64 make_tuple(gs_ms_ns_strides[Number<NumDimG - 1>{}],
65 gs_ms_ns_strides[Number<NumDimG + NumDimM - 1>{}],
66 gs_ms_ns_strides[Number<NumDimG + NumDimM + NumDimN - 1>{}]));
67
68 const auto grid_desc_mraw_nraw = make_naive_tensor_descriptor(
69 make_tuple(M, N),
70 make_tuple(gs_ms_ns_strides[Number<NumDimG + NumDimM - 1>{}],
71 gs_ms_ns_strides[Number<NumDimG + NumDimM + NumDimN - 1>{}]));
72
73 return std::make_pair(grid_desc_g_mraw_nraw, grid_desc_mraw_nraw);
74 }
75 else
76 {
77 // naive tensor C[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...]
78 const auto grid_desc_gs_ms_ns =
79 make_naive_tensor_descriptor(gs_ms_ns_lengths, gs_ms_ns_strides);
80
81 // transformed tensor C[G = G0 * G1 * ..., MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 *
82 // N2 * ...]
83 // Note: This does not require padding, as it is only used for the G offset calculation.
84 // Technically, only a descriptor for G is needed; the full G_M_N descriptor is returned
85 // here for backward compatibility.
86 const auto grid_desc_g_mraw_nraw =
87 transform_tensor_descriptor(grid_desc_gs_ms_ns,
89 make_merge_transform(mLengths),
90 make_merge_transform(nLengths)),
91 make_tuple(gDimIds, mDimIds, nDimIds),
92 make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
93
94 const auto c_ms_ns_lengths = to_tuple(
95 gs_ms_ns_lengths_vec, Number<NumDimG>{}, Number<NumDimG + NumDimM + NumDimN>{});
96 const auto c_ms_ns_strides = to_tuple(
97 gs_ms_ns_strides_vec, Number<NumDimG>{}, Number<NumDimG + NumDimM + NumDimN>{});
98
99 // transformed tensor C[MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 *
100 // N2 * ...]
101 const auto grid_desc_ms_ns = make_naive_tensor_descriptor(c_ms_ns_lengths, c_ms_ns_strides);
102
103 const auto grid_desc_mraw_nraw = transform_tensor_descriptor(
104 grid_desc_ms_ns,
106 make_tuple(mDimIds - Number<NumDimG>{}, nDimIds - Number<NumDimG>{}),
107 make_tuple(Sequence<0>{}, Sequence<1>{}));
108
109 return std::make_pair(grid_desc_g_mraw_nraw, grid_desc_mraw_nraw);
110 }
111}
112
113template <typename NumDims_G_M_N_K_O, // Sequence<>
114 typename PerBlock_M_N_K_O, // Sequence<>
121{
122 static constexpr auto I0 = Number<0>{};
123 static constexpr auto I1 = Number<1>{};
124 static constexpr auto I2 = Number<2>{};
125 static constexpr auto I3 = Number<3>{};
126 static constexpr auto I4 = Number<4>{};
127
128 static constexpr index_t NumDimG = NumDims_G_M_N_K_O::At(I0);
129 static constexpr index_t NumDimM = NumDims_G_M_N_K_O::At(I1);
130 static constexpr index_t NumDimN = NumDims_G_M_N_K_O::At(I2);
131 static constexpr index_t NumDimK = NumDims_G_M_N_K_O::At(I3);
132 static constexpr index_t NumDimO = NumDims_G_M_N_K_O::At(I4);
133
134 static constexpr index_t MPerBlock = PerBlock_M_N_K_O::At(I0);
135 static constexpr index_t NPerBlock = PerBlock_M_N_K_O::At(I1);
136 static constexpr index_t KPerBlock = PerBlock_M_N_K_O::At(I2);
137 static constexpr index_t OPerBlock = PerBlock_M_N_K_O::At(I3);
138
142
143 //
144 // A
145 //
// Builds both grid descriptors for tensor A by forwarding to
// MakeGridDescriptorPair<NumDimG, NumDimM, NumDimK, ASpec>.
// The two vectors carry A's lengths and strides, ordered [G..., M..., K...]. The helper
// throws std::runtime_error unless each vector has NumDimG + NumDimM + NumDimK entries.
// Returns a std::pair: .first is the (G, MRaw, KRaw) descriptor, .second is the
// (MRaw, KRaw) descriptor.
// NOTE(review): ASpec is declared outside the visible part of this listing; presumably the
// TensorSpecialization chosen for A - confirm.
146 static auto MakeAGridDescriptorPair(const std::vector<index_t>& a_gs_ms_ks_lengths_vec,
147 const std::vector<index_t>& a_gs_ms_ks_strides_vec)
148 {
149 return MakeGridDescriptorPair<NumDimG, NumDimM, NumDimK, ASpec>(a_gs_ms_ks_lengths_vec,
150 a_gs_ms_ks_strides_vec);
151 }
152
// Returns only the first element of MakeAGridDescriptorPair, i.e. the unpadded
// (G, MRaw, KRaw) descriptor of A. No padding is applied in this function.
153 // TODO: rename to G_MRaw_KRaw
154 static auto MakeAGridDescriptor_G_M_K(const std::vector<index_t>& a_gs_ms_ks_lengths_vec,
155 const std::vector<index_t>& a_gs_ms_ks_strides_vec)
156 {
157 return MakeAGridDescriptorPair(a_gs_ms_ks_lengths_vec, a_gs_ms_ks_strides_vec).first;
158 }
// Returns the (MRaw, KRaw) descriptor of A (the second element of MakeAGridDescriptorPair)
// after passing it through matrix_padder.PadADescriptor_M_K.
// NOTE(review): matrix_padder is declared outside the visible part of this listing, so the
// padding it applies cannot be confirmed from here.
159 static auto MakeAGridDescriptor_M_K(const std::vector<index_t>& a_gs_ms_ks_lengths_vec,
160 const std::vector<index_t>& a_gs_ms_ks_strides_vec)
161 {
162 return matrix_padder.PadADescriptor_M_K(
163 MakeAGridDescriptorPair(a_gs_ms_ks_lengths_vec, a_gs_ms_ks_strides_vec).second);
164 }
165
166 template <typename AGridDesc_M_K, typename Number>
167 __host__ __device__ static constexpr auto
168 MakeAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K& a_grid_desc_m_k, const Number& AK1)
169 {
170 const auto M = a_grid_desc_m_k.GetLength(I0);
171 const auto K = a_grid_desc_m_k.GetLength(I1);
172
173 const auto AK0 = K / AK1;
174
175 return transform_tensor_descriptor(a_grid_desc_m_k,
180 }
181
182 //
183 // B (alias of B0)
184 //
// Builds both grid descriptors for tensor B0 by forwarding to
// MakeGridDescriptorPair<NumDimG, NumDimN, NumDimK, B0Spec>.
// The two vectors carry B0's lengths and strides, ordered [G..., N..., K...]. The helper
// throws std::runtime_error unless each vector has NumDimG + NumDimN + NumDimK entries.
// Returns a std::pair: .first is the (G, NRaw, KRaw) descriptor, .second is the
// (NRaw, KRaw) descriptor.
// NOTE(review): B0Spec is declared outside the visible part of this listing; presumably the
// TensorSpecialization chosen for B0 - confirm.
185 static auto MakeB0GridDescriptorPair(const std::vector<index_t>& b0_gs_ns_ks_lengths_vec,
186 const std::vector<index_t>& b0_gs_ns_ks_strides_vec)
187 {
188 return MakeGridDescriptorPair<NumDimG, NumDimN, NumDimK, B0Spec>(b0_gs_ns_ks_lengths_vec,
189 b0_gs_ns_ks_strides_vec);
190 }
191
// Returns only the first element of MakeB0GridDescriptorPair, i.e. the unpadded
// (G, NRaw, KRaw) descriptor of B0. No padding is applied in this function.
192 // TODO: rename to G_NRaw_KRaw (B0 is laid out [G, N, K], so the raw dims are N and K)
193 static auto MakeB0GridDescriptor_G_N_K(const std::vector<index_t>& b0_gs_ns_ks_lengths_vec,
194 const std::vector<index_t>& b0_gs_ns_ks_strides_vec)
195 {
196 return MakeB0GridDescriptorPair(b0_gs_ns_ks_lengths_vec, b0_gs_ns_ks_strides_vec).first;
197 }
// Returns the (NRaw, KRaw) descriptor of B0 (the second element of MakeB0GridDescriptorPair)
// after passing it through matrix_padder.PadBDescriptor_N_K.
// NOTE(review): matrix_padder is declared outside the visible part of this listing, so the
// padding it applies cannot be confirmed from here.
198 static auto MakeB0GridDescriptor_N_K(const std::vector<index_t>& b0_gs_ns_ks_lengths_vec,
199 const std::vector<index_t>& b0_gs_ns_ks_strides_vec)
200 {
201 // B0 is the tensor called "B" by the padder: PadBDescriptor_N_K plays the role of a
201 // would-be PadB0Descriptor_N_K
202 return matrix_padder.PadBDescriptor_N_K(
203 MakeB0GridDescriptorPair(b0_gs_ns_ks_lengths_vec, b0_gs_ns_ks_strides_vec).second);
204 }
205
206 template <typename BGridDesc_N_K, typename Number>
207 __host__ __device__ static constexpr auto
208 MakeB0GridDescriptor_BK0_N_BK1(const BGridDesc_N_K& b_grid_desc_n_k, const Number& BK1)
209 {
210 const auto N = b_grid_desc_n_k.GetLength(I0);
211 const auto K = b_grid_desc_n_k.GetLength(I1);
212
213 const auto BK0 = K / BK1;
214
215 return transform_tensor_descriptor(b_grid_desc_n_k,
220 }
221
222 //
223 // B1
224 //
// Builds both grid descriptors for tensor B1 by forwarding to
// MakeGridDescriptorPair<NumDimG, NumDimO, NumDimN, B1Spec>.
// The two vectors carry B1's lengths and strides, ordered [G..., O..., N...]. The helper
// throws std::runtime_error unless each vector has NumDimG + NumDimO + NumDimN entries.
// Returns a std::pair: .first is the (G, ORaw, NRaw) descriptor, .second is the
// (ORaw, NRaw) descriptor.
// NOTE(review): B1Spec is declared outside the visible part of this listing; presumably the
// TensorSpecialization chosen for B1 - confirm.
225 static auto MakeB1GridDescriptorPair(const std::vector<index_t>& b1_gs_os_ns_lengths_vec,
226 const std::vector<index_t>& b1_gs_os_ns_strides_vec)
227 {
228 return MakeGridDescriptorPair<NumDimG, NumDimO, NumDimN, B1Spec>(b1_gs_os_ns_lengths_vec,
229 b1_gs_os_ns_strides_vec);
230 }
231
// Returns only the first element of MakeB1GridDescriptorPair, i.e. the unpadded descriptor
// whose three dimensions are G, merged O and merged N. The "_G_N_K" in the name follows the
// GEMM naming, where B1's O dims sit in the N slot and its N dims in the K slot.
// No padding is applied in this function.
232 // TODO: rename to G_NRaw_KRaw
233 static auto MakeB1GridDescriptor_G_N_K(const std::vector<index_t>& b1_gs_os_ns_lengths_vec,
234 const std::vector<index_t>& b1_gs_os_ns_strides_vec)
235 {
236 return MakeB1GridDescriptorPair(b1_gs_os_ns_lengths_vec, b1_gs_os_ns_strides_vec).first;
237 }
// Returns the two-dimensional descriptor of B1 (the second element of
// MakeB1GridDescriptorPair, with dimensions merged O and merged N) after passing it through
// matrix_padder.PadB1Descriptor_N_K.
// NOTE(review): matrix_padder is declared outside the visible part of this listing, so the
// padding it applies cannot be confirmed from here.
238 static auto MakeB1GridDescriptor_N_K(const std::vector<index_t>& b1_gs_os_ns_lengths_vec,
239 const std::vector<index_t>& b1_gs_os_ns_strides_vec)
240 {
241 // the descriptor passed in is (O, N); the padder method is named in GEMM terms (N, K),
241 // so PadB1Descriptor_N_K stands in for a would-be PadB1Descriptor_O_N
242 return matrix_padder.PadB1Descriptor_N_K(
243 MakeB1GridDescriptorPair(b1_gs_os_ns_lengths_vec, b1_gs_os_ns_strides_vec).second);
244 }
245
246 template <typename B1GridDesc_N_K, typename Number>
247 __host__ __device__ static constexpr auto
248 MakeB1GridDescriptor_BK0_N_BK1(const B1GridDesc_N_K& b1_grid_desc_n_k, const Number& B1K1)
249 {
250 const auto N = b1_grid_desc_n_k.GetLength(I0);
251 const auto K = b1_grid_desc_n_k.GetLength(I1);
252
253 const auto B1K0 = K / B1K1;
254
256 b1_grid_desc_n_k,
261 }
262
263 //
264 // C
265 //
// Builds both grid descriptors for tensor C by forwarding to
// MakeGridDescriptorPair<NumDimG, NumDimM, NumDimO, CSpec>.
// The two vectors carry C's lengths and strides, ordered [G..., M..., O...]. The helper
// throws std::runtime_error unless each vector has NumDimG + NumDimM + NumDimO entries.
// Returns a std::pair: .first is the (G, MRaw, ORaw) descriptor, .second is the
// (MRaw, ORaw) descriptor.
// NOTE(review): CSpec is declared outside the visible part of this listing; presumably the
// TensorSpecialization chosen for C - confirm.
266 static auto MakeCGridDescriptorPair(const std::vector<index_t>& c_gs_ms_os_lengths_vec,
267 const std::vector<index_t>& c_gs_ms_os_strides_vec)
268 {
269 return MakeGridDescriptorPair<NumDimG, NumDimM, NumDimO, CSpec>(c_gs_ms_os_lengths_vec,
270 c_gs_ms_os_strides_vec);
271 }
272
// Returns only the first element of MakeCGridDescriptorPair, i.e. the unpadded descriptor
// whose three dimensions are G, merged M and merged O. The "_G_M_N" in the name follows the
// GEMM naming, where C's O dims sit in the N slot. No padding is applied in this function.
273 // TODO: rename to G_MRaw_NRaw
274 static auto MakeCGridDescriptor_G_M_N(const std::vector<index_t>& c_gs_ms_os_lengths_vec,
275 const std::vector<index_t>& c_gs_ms_os_strides_vec)
276 {
277 return MakeCGridDescriptorPair(c_gs_ms_os_lengths_vec, c_gs_ms_os_strides_vec).first;
278 }
// Returns the two-dimensional descriptor of C (the second element of
// MakeCGridDescriptorPair, with dimensions merged M and merged O) after passing it through
// matrix_padder.PadCDescriptor_M_N.
// NOTE(review): matrix_padder is declared outside the visible part of this listing, so the
// padding it applies cannot be confirmed from here.
279 static auto MakeCGridDescriptor_M_N(const std::vector<index_t>& c_gs_ms_os_lengths_vec,
280 const std::vector<index_t>& c_gs_ms_os_strides_vec)
281 {
282 return matrix_padder.PadCDescriptor_M_N(
283 MakeCGridDescriptorPair(c_gs_ms_os_lengths_vec, c_gs_ms_os_strides_vec).second);
284 }
285};
286
287} // namespace tensor_operation
288} // namespace ck
TensorSpecialization
Definition tensor_specialization.hpp:11
@ Packed
Definition tensor_specialization.hpp:13
GemmSpecialization
Definition gemm_specialization.hpp:11
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
integral_constant< index_t, N > Number
Definition number.hpp:12
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
__host__ __device__ constexpr auto get_container_subset(const Array< T, N > &arr, Sequence< Is... >)
Definition utility/container_helper.hpp:346
__host__ __device__ constexpr auto container_reduce(const Container &x, Reduce reduce, Init init, Number< IBegin >=Number< 0 >{}, Number< IEnd >=Number< Container::Size()>{}, Number< IStep >=Number< 1 >{})
Definition utility/container_helper.hpp:111
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
Definition utility/sequence.hpp:43
typename conditional< kHasContent, type0, type1 >::type type
Definition utility/sequence.hpp:271
static auto MakeB0GridDescriptor_N_K(const std::vector< index_t > &b0_gs_ns_ks_lengths_vec, const std::vector< index_t > &b0_gs_ns_ks_strides_vec)
Definition transform_contraction_to_gemm.hpp:198
static auto MakeAGridDescriptor_G_M_K(const std::vector< index_t > &a_gs_ms_ks_lengths_vec, const std::vector< index_t > &a_gs_ms_ks_strides_vec)
Definition transform_contraction_to_gemm.hpp:154
static auto MakeB0GridDescriptor_G_N_K(const std::vector< index_t > &b0_gs_ns_ks_lengths_vec, const std::vector< index_t > &b0_gs_ns_ks_strides_vec)
Definition transform_contraction_to_gemm.hpp:193
__host__ static __device__ constexpr auto MakeAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K &a_grid_desc_m_k, const Number &AK1)
Definition transform_contraction_to_gemm.hpp:168
static auto MakeB1GridDescriptorPair(const std::vector< index_t > &b1_gs_os_ns_lengths_vec, const std::vector< index_t > &b1_gs_os_ns_strides_vec)
Definition transform_contraction_to_gemm.hpp:225
__host__ static __device__ constexpr auto MakeB1GridDescriptor_BK0_N_BK1(const B1GridDesc_N_K &b1_grid_desc_n_k, const Number &B1K1)
Definition transform_contraction_to_gemm.hpp:248
static auto MakeB0GridDescriptorPair(const std::vector< index_t > &b0_gs_ns_ks_lengths_vec, const std::vector< index_t > &b0_gs_ns_ks_strides_vec)
Definition transform_contraction_to_gemm.hpp:185
static auto MakeAGridDescriptor_M_K(const std::vector< index_t > &a_gs_ms_ks_lengths_vec, const std::vector< index_t > &a_gs_ms_ks_strides_vec)
Definition transform_contraction_to_gemm.hpp:159
static auto MakeAGridDescriptorPair(const std::vector< index_t > &a_gs_ms_ks_lengths_vec, const std::vector< index_t > &a_gs_ms_ks_strides_vec)
Definition transform_contraction_to_gemm.hpp:146
static auto MakeCGridDescriptor_G_M_N(const std::vector< index_t > &c_gs_ms_os_lengths_vec, const std::vector< index_t > &c_gs_ms_os_strides_vec)
Definition transform_contraction_to_gemm.hpp:274
static auto MakeB1GridDescriptor_G_N_K(const std::vector< index_t > &b1_gs_os_ns_lengths_vec, const std::vector< index_t > &b1_gs_os_ns_strides_vec)
Definition transform_contraction_to_gemm.hpp:233
static auto MakeB1GridDescriptor_N_K(const std::vector< index_t > &b1_gs_os_ns_lengths_vec, const std::vector< index_t > &b1_gs_os_ns_strides_vec)
Definition transform_contraction_to_gemm.hpp:238
static auto MakeCGridDescriptorPair(const std::vector< index_t > &c_gs_ms_os_lengths_vec, const std::vector< index_t > &c_gs_ms_os_strides_vec)
Definition transform_contraction_to_gemm.hpp:266
static auto MakeCGridDescriptor_M_N(const std::vector< index_t > &c_gs_ms_os_lengths_vec, const std::vector< index_t > &c_gs_ms_os_strides_vec)
Definition transform_contraction_to_gemm.hpp:279
__host__ static __device__ constexpr auto MakeB0GridDescriptor_BK0_N_BK1(const BGridDesc_N_K &b_grid_desc_n_k, const Number &BK1)
Definition transform_contraction_to_gemm.hpp:208
Definition matrix_padder.hpp:63