threadwise_tensor_slice_transfer_v7r3.hpp Source File

threadwise_tensor_slice_transfer_v7r3.hpp Source File#

Composable Kernel: threadwise_tensor_slice_transfer_v7r3.hpp Source File
threadwise_tensor_slice_transfer_v7r3.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
12
14
15namespace ck {
16// Thread-level multi-source, multi-destination tensor slice data movement
17// Assume:
18// 1. All sources and destinations are DynamicBuffer
19// 2. Same VectorDim and ScalarPerVector for all sources and destinations
20// 3. DstInMemOps are per destination tensor
21// 4. ThreadTransferSrcResetCoordinateAfterRunFlags are per source tensor
22// 5. ThreadTransferDstResetCoordinateAfterRunFlags are per destination tensor
23// 6. Does not need to know src_descs and dst_descs at compile-time
24// 7. Does not need to know src_slice_origins and dst_slice_origins at compile-time,
25//
26// Does following things to avoid scratch memory issue
27// 1. Use StaticallyIndexedArray or vector_type instead of C array for thread buffer
28// 2. Pass tensor descriptors by reference (or tuple of references)
29// 3. Does not keep reference to tensor descriptor
30// 4. Does not construct new tensor coordinate when call Run()
31template <typename SrcDatas,
32 typename DstDatas,
33 typename SrcDescs,
34 typename DstDescs,
35 typename ElementwiseOperation,
36 typename DstInMemOps, // Sequence<InMemoryDataOperationEnum ...>
37 typename SliceLengths,
38 typename SrcDimAccessOrder,
39 typename DstDimAccessOrder,
40 index_t SrcVectorDim,
41 index_t DstVectorDim,
42 typename SrcScalarPerVectors,
43 index_t DstScalarPerVector,
44 typename SrcResetCoordinateAfterRunFlags, // Sequence<bool ...>
45 typename DstResetCoordinateAfterRunFlags, // Sequence<bool ...>
46 index_t NumThreadScratch = 1,
47 typename InterDatas = DstDatas>
49{
// Compile-time index 0, used below to address the first element of tuples/vectors.
50 static constexpr auto I0 = Number<0>{};
51
// Number of dimensions of the transferred slice.
52 static constexpr index_t nDim = SliceLengths::Size();
53
// Number of source / destination tensors, taken from the descriptor tuples.
54 static constexpr index_t nSrc = SrcDescs::Size();
55 static constexpr index_t nDst = DstDescs::Size();
56
58
59 // return a tuple of coordiantes for a tuple of tensor
60 template <typename Descs,
61 typename Indices,
62 enable_if_t<Descs::Size() == Indices::Size(), bool> = false>
63 static constexpr auto MakeCoordinates(const Descs& descs, const Indices& indices)
64 {
65 return generate_tuple([&](auto i) { return make_tensor_coordinate(descs[i], indices[i]); },
66 Number<Descs::Size()>{});
67 }
68
// Single source vector length used by this class, obtained by reducing the per-source
// SrcScalarPerVectors sequence with an initial value of 1.
// NOTE(review): the reduction functor (second argument of reduce_on_sequence) is missing
// from this listing - confirm which reduction is applied.
69 static constexpr auto SrcScalarPerVector =
70 reduce_on_sequence(SrcScalarPerVectors{},
72 Number<1>{}); // GetMinSrcScalarPerVector(); SrcScalarPerVectors{}[I0];
75
76 // scalar per access on each dim
77 // FIXME: don't use lambda_scalar_per_access
80
83
// NOTE(review): the two fragments below are the tails of alias declarations whose first
// lines are missing from this listing; presumably they define SrcSpaceFillingCurve and
// DstSpaceFillingCurve, which the rest of the class refers to - confirm against the
// original header.
85 SrcDimAccessOrder,
87 false>;
88
90 DstDimAccessOrder,
92 false>;
93
// Constructor (the line carrying the constructor name is missing from this listing).
// Builds one coordinate per source and per destination tensor from the given descriptors
// and slice origins via MakeCoordinates(), and stores a copy of element_op.
// The static_asserts require the slice length along SrcVectorDim / DstVectorDim to be an
// exact multiple of SrcScalarPerVector / DstScalarPerVector.
95 const SrcDescs& src_descs,
96 const StaticallyIndexedArray<Index, nSrc>& src_slice_origins,
97 const DstDescs& dst_descs,
98 const StaticallyIndexedArray<Index, nDst>& dst_slice_origins,
99 const ElementwiseOperation& element_op)
100 : src_coords_(MakeCoordinates(src_descs, src_slice_origins)),
101 dst_coords_(MakeCoordinates(dst_descs, dst_slice_origins)),
102 element_op_(element_op)
103 {
104 static_assert(SliceLengths::At(Number<SrcVectorDim>{}) % SrcScalarPerVector == 0,
105 "wrong! cannot evenly divide");
106
107 static_assert(SliceLengths::At(Number<DstVectorDim>{}) % DstScalarPerVector == 0,
108 "wrong! cannot evenly divide");
109 }
110
111 template <typename Indices, enable_if_t<SrcDescs::Size() == Indices::Size(), bool> = false>
112 __device__ void SetSrcSliceOrigins(const SrcDescs& src_descs,
113 const Indices& src_slice_origin_idxs)
114 {
115 static_for<0, nSrc, 1>{}([&](auto i) {
116 src_coords_(i) = make_tensor_coordinate(src_descs[i], src_slice_origin_idxs[i]);
117 });
118 }
119
120 template <typename Indices, enable_if_t<DstDescs::Size() == Indices::Size(), bool> = false>
121 __device__ void SetDstSliceOrigins(const DstDescs& dst_descs,
122 const Indices& dst_slice_origin_idxs)
123 {
124 static_for<0, nDst, 1>{}([&](auto i) {
125 dst_coords_(i) = make_tensor_coordinate(dst_descs[i], dst_slice_origin_idxs[i]);
126 });
127 }
128
// Produces a tuple with one entry per element type listed in DataTypes.
// It is only used in decltype() by the private *VectorsType aliases of this class.
// NOTE(review): the lambda's return statement is missing from this listing; presumably
// each entry is a vector of ScalarPerVector elements of DataType - confirm against the
// original header.
129 template <typename DataTypes, index_t ScalarPerVector>
130 __device__ static auto generate_vectors()
131 {
132 auto data_types = DataTypes{};
133
134 constexpr index_t num = data_types.Size();
135
136 return generate_tuple(
137 [&](auto i) {
// element type of the i-th entry, with cv/ref qualifiers stripped
138 using DataType = remove_cvref_t<decltype(data_types[i])>;
139
141 },
142 Number<num>{});
143 }
144
// RunRead: for every access on the source space-filling curve, loads one vector from each
// source buffer, applies element_op_ to produce the intermediate ("elm") vectors, and
// caches both the results and a validity flag in the scratch slot selected by
// thread_scratch_id. Afterwards the source coordinates are optionally moved back
// according to SrcResetCoordinateAfterRunFlags.
// NOTE(review): several lines are missing from this listing (the thread_scratch_id
// parameter, the declarations of src_vectors / elm_vectors, the call computing
// is_src_valid and the construction of src_reset_step); comments below only describe
// what is visible.
145 // SrcDescs: Tuple<const SrcDesc0&, const SrcDesc1&, ...>
146 // SrcBuffers: Tuple<const SrcBuffer0&, const SrcBuffer1&, ...>
147 template <typename SrcBuffers,
148 index_t ThreadScratchId = 0,
149 enable_if_t<SrcDescs::Size() == SrcBuffers::Size(), bool> = false>
150 __device__ void RunRead(const SrcDescs& src_descs,
151 const SrcBuffers& src_bufs,
153 {
154 // loop over space-filling curve
155 static_for<0, src_num_access, 1>{}([&](auto iAccess) {
158
// despite its name, oob_val stays true only while every source access is valid
159 bool oob_val = true;
160
161 // copy data from src_bufs into src_vectors
162 static_for<0, nSrc, 1>{}([&](auto i) {
163 using src_vector_t = typename remove_cvref_t<decltype(src_vectors[i])>::type;
164
165 const bool is_src_valid =
167 src_coords_[i]);
168
169 oob_val = oob_val & is_src_valid;
170
171 // TODO: With column-major matrices this step restricts the transferred tensor slice
172 // to just one element, which consequently prevents using atomic operations if the
173 // matrix data type is on 16 bits.
174 if constexpr(SrcScalarPerVectors{}[i] == 1)
175 {
// this source provides a single scalar per access: load it once and copy it into
// every lane of the source vector
176 auto data_types = SrcDatas{};
177 using DataType = remove_cvref_t<decltype(data_types[i])>;
178 const auto tmp =
179 src_bufs[i].template Get<DataType>(src_coords_[i].GetOffset(), true);
180
181 static_for<0, SrcScalarPerVector, 1>{}(
182 [&](auto j) { src_vectors(i).template AsType<DataType>()(j) = tmp; });
183 }
184 else
185 {
// the load is issued with the validity argument fixed to true; the real validity is
// only recorded in oob_val and consumed later by OOBCheck()
186 src_vectors(i).template AsType<src_vector_t>()(I0) =
187 src_bufs[i].template Get<src_vector_t>(src_coords_[i].GetOffset(), true);
188 }
189 });
190
// number of elements element_op_ consumes per call: 8, 4 or 2 if the operation declares
// the matching is_packN_invocable flag (capped by SrcScalarPerVector), otherwise 1
191 constexpr auto get_elem_op_vec_len = []() {
192 if constexpr(is_detected<is_pack8_invocable_t, decltype(element_op_)>::value)
193 {
194 if constexpr(decltype(element_op_)::is_pack8_invocable)
195 return math::min(8, SrcScalarPerVector);
196 }
197 if constexpr(is_detected<is_pack4_invocable_t, decltype(element_op_)>::value)
198 {
199 if constexpr(decltype(element_op_)::is_pack4_invocable)
200 return math::min(4, SrcScalarPerVector);
201 }
202 if constexpr(is_detected<is_pack2_invocable_t, decltype(element_op_)>::value)
203 {
204 if constexpr(decltype(element_op_)::is_pack2_invocable)
205 return math::min(2, SrcScalarPerVector);
206 }
207 return 1;
208 };
209
210 constexpr index_t elem_op_vec_len = get_elem_op_vec_len();
211
212 // apply pointwise function
213 static_for<0, SrcScalarPerVector / elem_op_vec_len, 1>{}([&](auto i) {
214 // get reference to src data
215 const auto src_data_refs = generate_tie(
216 // return type should be lvalue
217 [&](auto iSrc) -> const auto& {
218 using SrcData = remove_cvref_t<tuple_element_t<iSrc.value, SrcDatas>>;
219
220 using elem_op_vec_t = typename vector_type<SrcData, elem_op_vec_len>::type;
221
222 return src_vectors[iSrc].template AsType<elem_op_vec_t>()[i];
223 },
224 Number<nSrc>{});
225
226 // get reference to dst data
227 auto dst_data_refs = generate_tie(
228 // return type should be lvalue
229 [&](auto iDst) -> auto& {
230 using InterData = remove_cvref_t<tuple_element_t<iDst.value, InterDatas>>;
231
232 using elem_op_vec_t =
233 typename vector_type<InterData, elem_op_vec_len>::type;
234
235 return elm_vectors(iDst).template AsType<elem_op_vec_t>()(i);
236 },
237 Number<nDst>{});
238
239 // apply pointwise function
240 // pointwise function signature:
241 // element_op_(dst_data_refs[I0],
242 // dst_data_refs[I1],
243 // ...,
244 // src_data_refs[I0],
245 // src_data_refs[I1],
246 // ...)
247 unpack2(element_op_, dst_data_refs, src_data_refs);
248 });
249
// cache the results and the validity flag of this access for the later write stage
250 elm_vectors_tuple_(thread_scratch_id)(iAccess) = elm_vectors;
251 oob_vectors_tuple_(thread_scratch_id)(iAccess) = oob_val;
252
253 // move coordinate
254 if constexpr(iAccess.value != src_num_access - 1)
255 {
256 constexpr auto forward_step = SrcSpaceFillingCurve::GetForwardStep(iAccess);
257
258 static_for<0, nSrc, 1>{}([&](auto i) {
259 move_tensor_coordinate(src_descs[i],
260 src_coords_(i),
261 make_tensor_coordinate_step(src_descs[i], forward_step));
262 });
263 }
264 });
265
266 // move coordinate back to slice origin (or not)
267 static_for<0, nSrc, 1>{}([&](auto i) {
268 if constexpr(SrcResetCoordinateAfterRunFlags::At(i))
269 {
270 const auto src_reset_step =
272
273 move_tensor_coordinate(src_descs[i], src_coords_(i), src_reset_step);
274 }
275 });
276 }
277
278#if 1
279 template <index_t ThreadScratchId = 0>
280 __device__ void OOBCheck(Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
281 {
282 // loop over space-filling curve
283 static_for<0, src_num_access, 1>{}([&](auto iAccess) {
284 auto elm_vectors = elm_vectors_tuple_[thread_scratch_id][iAccess];
285 auto oob_val = oob_vectors_tuple_[thread_scratch_id][iAccess];
286
287 static_for<0, nDst, 1>{}([&](auto i) {
288 using elm_vector_t = typename remove_cvref_t<decltype(elm_vectors[i])>::type;
289 elm_vectors(i).template AsType<elm_vector_t>()(I0) =
290 oob_val ? elm_vectors(i).template AsType<elm_vector_t>()[I0] : elm_vector_t{0};
291 });
292
293 elm_vectors_tuple_(thread_scratch_id)(iAccess) = elm_vectors;
294 });
295 }
296#endif
297
// TransposeFromElmToDst (the line carrying the function name and parameter is missing
// from this listing): converts the element-wise results cached for thread_scratch_id
// from the source vector layout into the destination vector layout and stores the
// result in dst_vectors_tuple_.
// Two paths exist:
//  - when SrcVectorDim != DstVectorDim and the intermediate type / vector lengths match
//    one of the listed combinations (half_t with even lengths, f8_t or int8_t with
//    lengths divisible by 4), data is moved with transpose_vectors;
//  - otherwise every element is copied one by one.
// Only the first entry of InterDatas is used as the element type here.
// NOTE(review): several template-argument lines and the src_vector_t / dst_vector_t
// aliases are missing from this listing.
298 template <index_t ThreadScratchId = 0>
299 __device__ void
301 {
302 using InterData = remove_cvref_t<decltype(InterDatas{}[I0])>;
303
304 using ElmThreadScratch =
305 StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
306 InterData,
309 true>;
310 using DstThreadScratch =
311 StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
312 InterData,
313 DstScalarPerVector,
315 true>;
316
317 ElmThreadScratch elm_thread_scratch_;
318 DstThreadScratch dst_thread_scratch_;
319
// reinterpret the cached per-access vectors as the storage of the element scratch tensor
320 elm_thread_scratch_.data_ =
321 bit_cast<decltype(elm_thread_scratch_.data_)>(elm_vectors_tuple_[thread_scratch_id]);
322
323 if constexpr(SrcVectorDim != DstVectorDim &&
324 ((is_same<half_t, remove_cvref_t<InterData>>::value &&
325 SrcScalarPerVector % 2 == 0 && DstScalarPerVector % 2 == 0) ||
326 (is_same<f8_t, remove_cvref_t<InterData>>::value &&
327 SrcScalarPerVector % 4 == 0 && DstScalarPerVector % 4 == 0) ||
328 (is_same<int8_t, remove_cvref_t<InterData>>::value &&
329 SrcScalarPerVector % 4 == 0 && DstScalarPerVector % 4 == 0)))
330 {
331 // each transpose does
332 // DstScalarPerVector # of src vectors in elm_thread_scratch_
333 // SrcScalarPerVector # of dst vectors in dst_thread_scratch_
334 constexpr index_t num_src_vector = Number<DstScalarPerVector>{};
335 constexpr index_t num_dst_vector = Number<SrcScalarPerVector>{};
336
337 // Assume SrcVectorDim is not the same as DstVectorDim, so we do transpose
338 // TODO: make this logic generic for all scenario
339
340 constexpr auto src_scalar_step_in_vector = generate_sequence(
341 detail::lambda_scalar_step_in_vector<SrcVectorDim>{}, Number<nDim>{});
342
343 constexpr auto dst_scalar_step_in_vector = generate_sequence(
344 detail::lambda_scalar_step_in_vector<DstVectorDim>{}, Number<nDim>{});
345
346 constexpr auto scalar_per_access = generate_sequence(
347 detail::lambda_scalar_per_access_for_src_and_dst<SrcVectorDim,
349 DstVectorDim,
350 DstScalarPerVector>{},
351 Number<nDim>{});
352
353 constexpr auto access_lengths = SliceLengths{} / scalar_per_access;
354
355 static_ford<decltype(access_lengths)>{}([&](auto access_idx) {
356 constexpr auto data_idx = access_idx * scalar_per_access;
357
358 constexpr auto data_idx_seq = generate_sequence_v2(
359 [&](auto i) { return Number<data_idx[i]>{}; }, Number<nDim>{});
360
363
364 // get DstScalarPerVector # of read-only references to src vectors from
365 // elm_thread_scratch_
366 const auto src_vector_refs = generate_tie(
367 [&](auto i) -> const src_vector_t& {
368 // i increment corresponds to movement in DstVectorDim
369 return elm_thread_scratch_.GetVectorTypeReference(
370 data_idx_seq + i * dst_scalar_step_in_vector);
371 },
373
374 // get SrcScalarPerVector # of references to dst vectors from
375 // dst_thread_scratch_
376 auto dst_vector_refs = generate_tie(
377 [&](auto i) -> dst_vector_t& {
378 // i increment corresponds to movement in SrcVectorDim
379 return dst_thread_scratch_.GetVectorTypeReference(
380 data_idx_seq + i * src_scalar_step_in_vector);
381 },
383
384 // do data transpose
385 transpose_vectors<InterData, DstScalarPerVector, SrcScalarPerVector>{}(
386 src_vector_refs, dst_vector_refs);
387 });
388 }
389 else
390 {
// generic path: copy the slice element by element
391 static_ford<SliceLengths>{}(
392 [&](auto idx) { dst_thread_scratch_(idx) = elm_thread_scratch_[idx]; });
393 }
394
// publish the destination-layout data for the write stage
395 dst_vectors_tuple_(thread_scratch_id) = bit_cast<DstVectorTuple>(dst_thread_scratch_.data_);
396 }
397
// RunWriteAndStoreVgpr: same job as RunWrite(), plus a copy of the written data into
// dst_vgpr_buf. Only enabled for exactly one destination descriptor and buffer.
// Steps: OOBCheck() and TransposeFromElmToDst() prepare dst_vectors_tuple_; then, for
// every access on the destination space-filling curve, the intermediate vector is
// converted element by element to DstData with type_convert, written to dst_bufs with
// the configured DstInMemOps operation, and mirrored into dst_vgpr_buf at an offset
// computed at compile time from DstVgprDesc (origin fixed to 0).
// NOTE(review): the thread_scratch_id parameter line, the declaration of dst_vector, the
// call computing is_dst_valid, the value stored for invalid elements and the
// construction of dst_reset_step are missing from this listing.
398 // DstDescs: Tuple<const DstDesc0&, const DstDesc1&, ...>
399 // DstBuffers: Tuple<const DstBuffer0&, const DstBuffer1&, ...>
400 // DstVgprDescs: Tuple<const DstVgprDesc0&, const DstVgprDesc1&, ...>
401 // DstVgprBuffers: Tuple<DstVgprBuffer0&, DstVgprBuffer1&, ...>
402 template <typename DstBuffers,
403 typename DstVgprDescs,
404 typename DstVgprBuffers,
405 index_t ThreadScratchId = 0,
406 enable_if_t<DstDescs::Size() == 1 && DstBuffers::Size() == 1, bool> = false>
407 __device__ void
408 RunWriteAndStoreVgpr(const DstDescs& dst_descs,
409 DstBuffers dst_bufs,
410 const DstVgprDescs&,
411 DstVgprBuffers dst_vgpr_buf,
413 {
414 // Same functionality of RunWrite but additionally store internal Vgpr in dst_vgpr_buf
415 OOBCheck(thread_scratch_id);
416 TransposeFromElmToDst(thread_scratch_id);
417
418 // Vgpr buffer origin is set internally to 0
419 constexpr auto dst_slice_origin_idx =
420 generate_tuple([&](auto) { return I0; }, Number<nDim>{});
421 constexpr auto dst_scalar_step_in_vector =
422 generate_sequence(detail::lambda_scalar_step_in_vector<DstVectorDim>{}, Number<nDim>{});
423
424 // loop over space-filling curve
425 static_for<0, dst_num_access, 1>{}([&](auto iAccess) {
426 auto dst_vectors = dst_vectors_tuple_[thread_scratch_id][iAccess];
427
428 static_for<0, nDst, 1>{}([&](auto i) {
429 // copy data from buf_vectors into dst_bufs
430 using DstData = remove_cvref_t<decltype(DstDatas{}[i])>;
431 using InterData = remove_cvref_t<decltype(InterDatas{}[i])>;
432
434 using dst_vector_t =
435 typename vector_type_maker<DstData, DstScalarPerVector>::type::type;
436
// convert from the intermediate type to the destination type, one element at a time
437 static_for<0, DstScalarPerVector, 1>{}([&](auto j) {
438 dst_vector.template AsType<DstData>()(j) =
439 type_convert<DstData>(dst_vectors[i].template AsType<InterData>()[j]);
440 });
441
442 const bool is_dst_valid =
444 dst_coords_[i]);
445
446 constexpr InMemoryDataOperationEnum DstInMemOp =
447 static_cast<InMemoryDataOperationEnum>(DstInMemOps::At(i.value));
448
449 dst_bufs(i).template Update<DstInMemOp, dst_vector_t>(
450 dst_coords_[i].GetOffset(),
451 is_dst_valid,
452 dst_vector.template AsType<dst_vector_t>()[I0]);
453
454 // store Vgpr
455 using DstVgprDesc = remove_cvref_t<decltype(DstVgprDescs{}.At(i))>;
456 static_assert(DstVgprDesc::IsKnownAtCompileTime(),
457 "wrong! DstDesc need to known at compile-time");
458 constexpr auto dst_vgpr_desc = DstVgprDesc{};
459
460 constexpr auto src_data_idx = DstSpaceFillingCurve::GetIndex(iAccess);
461 static_for<0, DstScalarPerVector, 1>{}([&](auto j) {
462 constexpr index_t dst_offset =
463 dst_vgpr_desc.CalculateOffset(to_multi_index(dst_slice_origin_idx) +
464 src_data_idx + j * dst_scalar_step_in_vector);
465
// dst_vgpr_buf is always indexed with I0; the enable_if above restricts this function
// to a single destination
466 dst_vgpr_buf(I0)(Number<dst_offset>{}) =
467 is_dst_valid ? dst_vectors[i].template AsType<InterData>()[j]
469 });
470 });
471
472 // move coordinate
473 if constexpr(iAccess.value != dst_num_access - 1)
474 {
475 constexpr auto forward_step = DstSpaceFillingCurve::GetForwardStep(iAccess);
476
477 static_for<0, nDst, 1>{}([&](auto i) {
478 move_tensor_coordinate(dst_descs[i],
479 dst_coords_(i),
480 make_tensor_coordinate_step(dst_descs[i], forward_step));
481 });
482 }
483 });
484
// optionally move the destination coordinates back, as configured per destination
485 static_for<0, nDst, 1>{}([&](auto i) {
486 if constexpr(DstResetCoordinateAfterRunFlags::At(i))
487 {
488 const auto dst_reset_step =
490
491 move_tensor_coordinate(dst_descs[i], dst_coords_(i), dst_reset_step);
492 }
493 });
494 }
495
// RunWrite: writes the data cached by RunRead() for thread_scratch_id to the destination
// buffer. Only enabled for exactly one destination descriptor and buffer.
// Steps: OOBCheck() and TransposeFromElmToDst() prepare dst_vectors_tuple_; then, for
// every access on the destination space-filling curve, one vector is passed to
// dst_bufs(i).Update with the operation selected by DstInMemOps, and the destination
// coordinate is advanced. Afterwards the coordinates are optionally moved back
// according to DstResetCoordinateAfterRunFlags.
// NOTE(review): the thread_scratch_id parameter line, the condition of the static_assert,
// the call computing is_dst_valid and the construction of dst_reset_step are missing
// from this listing.
496 // DstDescs: Tuple<const DstDesc0&, const DstDesc1&, ...>
497 // DstBuffers: Tuple<const DstBuffer0&, const DstBuffer1&, ...>
498 template <typename DstBuffers,
499 index_t ThreadScratchId = 0,
500 enable_if_t<DstDescs::Size() == 1 && DstBuffers::Size() == 1, bool> = false>
501 __device__ void RunWrite(const DstDescs& dst_descs,
502 DstBuffers dst_bufs,
504 {
506 "RunWrite doesn't support inter data type different from dst data type");
507
508 OOBCheck(thread_scratch_id);
509 TransposeFromElmToDst(thread_scratch_id);
510
511 // loop over space-filling curve
512 static_for<0, dst_num_access, 1>{}([&](auto iAccess) {
513 auto dst_vectors = dst_vectors_tuple_[thread_scratch_id][iAccess];
514
515 // copy data from buf_vectors into dst_bufs
516 static_for<0, nDst, 1>{}([&](auto i) {
517 using dst_vector_t = typename remove_cvref_t<decltype(dst_vectors[i])>::type;
518
519 const bool is_dst_valid =
521 dst_coords_[i]);
522
523 constexpr InMemoryDataOperationEnum DstInMemOp =
524 static_cast<InMemoryDataOperationEnum>(DstInMemOps::At(i.value));
525
// the validity flag is forwarded to Update together with the offset and the vector
526 dst_bufs(i).template Update<DstInMemOp, dst_vector_t>(
527 dst_coords_[i].GetOffset(),
528 is_dst_valid,
529 dst_vectors[i].template AsType<dst_vector_t>()[I0]);
530 });
531
532 // move coordinate
533 if constexpr(iAccess.value != dst_num_access - 1)
534 {
535 constexpr auto forward_step = DstSpaceFillingCurve::GetForwardStep(iAccess);
536
537 static_for<0, nDst, 1>{}([&](auto i) {
538 move_tensor_coordinate(dst_descs[i],
539 dst_coords_(i),
540 make_tensor_coordinate_step(dst_descs[i], forward_step));
541 });
542 }
543 });
544
// optionally move the destination coordinates back, as configured per destination
545 static_for<0, nDst, 1>{}([&](auto i) {
546 if constexpr(DstResetCoordinateAfterRunFlags::At(i))
547 {
548 const auto dst_reset_step =
550
551 move_tensor_coordinate(dst_descs[i], dst_coords_(i), dst_reset_step);
552 }
553 });
554 }
555
556 // SrcDescs: Tuple<const SrcDesc0&, const SrcDesc1&, ...>
557 // SrcBuffers: Tuple<const SrcBuffer0&, const SrcBuffer1&, ...>
558 // DstDescs: Tuple<const DstDesc0&, const DstDesc1&, ...>
559 // DstBuffers: Tuple<const DstBuffer0&, const DstBuffer1&, ...>
560 template <typename SrcBuffers,
561 typename DstBuffers,
562 enable_if_t<SrcDescs::Size() == SrcBuffers::Size() &&
563 DstDescs::Size() == DstBuffers::Size(),
564 bool> = false>
565 __device__ void Run(const SrcDescs& src_descs,
566 const SrcBuffers& src_bufs,
567 const DstDescs& dst_descs,
568 DstBuffers dst_bufs)
569 {
570 RunRead(src_descs, src_bufs);
571 RunWrite(dst_descs, dst_bufs);
572 }
573
// Returns the step used to move a source coordinate back after a read pass.
// When the source curve has no access at all, a value-initialised index is returned.
// NOTE(review): the return statement of the else-branch is missing from this listing;
// presumably it is the step from the last access back to the first one - confirm.
574 __device__ static constexpr auto GetSrcCoordinateResetStep()
575 {
576 if constexpr(src_num_access == 0)
577 {
578 return typename SrcSpaceFillingCurve::Index{};
579 }
580 else
581 {
583 }
584 }
585
// Returns the step used to move a destination coordinate back after a write pass.
// When the destination curve has no access at all, a value-initialised index is returned.
// NOTE(review): the return statement of the else-branch is missing from this listing;
// presumably it is the step from the last access back to the first one - confirm.
586 __device__ static constexpr auto GetDstCoordinateResetStep()
587 {
588 if constexpr(dst_num_access == 0)
589 {
590 return typename DstSpaceFillingCurve::Index{};
591 }
592 else
593 {
595 }
596 }
597
// Builds the compile-time descriptor of the thread scratch tensor in source layout.
// Stage 1 creates a packed descriptor over (access lengths..., vector length); stage 2
// combines dimension SrcVectorDim with the trailing vector dimension (lower dimensions
// {SrcVectorDim, nDim}) and passes every other dimension through, giving an
// nDim-dimensional descriptor.
// NOTE(review): the second argument of container_push_back and the name of the transform
// returned for SrcVectorDim are missing from this listing; src_scalar_per_access is
// defined outside the visible lines.
598 __device__ static constexpr auto GetSrcThreadScratchDescriptor()
599 {
600 // constexpr auto src_scalar_per_access = generate_sequence(
601 // detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{},
602 // Number<nDim>{});
603
// number of accesses along each dimension
604 constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
605
606 constexpr auto src_access_lengths_and_vector_length = container_push_back(
608
609 // 1st stage of transforms
610 constexpr auto desc0 =
611 make_naive_tensor_descriptor_packed(src_access_lengths_and_vector_length);
612
613 // 2nd stage of transforms
614 constexpr auto transforms = generate_tuple(
615 [&](auto i) {
616 if constexpr(i == SrcVectorDim)
617 {
619 make_tuple(src_access_lengths_and_vector_length[i],
620 src_access_lengths_and_vector_length[Number<nDim>{}]));
621 }
622 else
623 {
624 return make_pass_through_transform(src_access_lengths_and_vector_length[i]);
625 }
626 },
627 Number<nDim>{});
628
// lower dimension ids consumed by each transform: the vector dimension also takes the
// trailing dimension nDim
629 constexpr auto low_dim_idss = generate_tuple(
630 [&](auto i) {
631 if constexpr(i == SrcVectorDim)
632 {
633 return Sequence<i.value, nDim>{};
634 }
635 else
636 {
637 return Sequence<i.value>{};
638 }
639 },
640 Number<nDim>{});
641
// each transform produces exactly one upper dimension with the same id
642 constexpr auto up_dim_idss =
643 generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
644
645 return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
646 }
647
// Builds the compile-time descriptor of the thread scratch tensor in destination layout.
// Stage 1 creates a packed descriptor over (access lengths..., vector length); stage 2
// combines dimension DstVectorDim with the trailing vector dimension (lower dimensions
// {DstVectorDim, nDim}) and passes every other dimension through, giving an
// nDim-dimensional descriptor.
// NOTE(review): the second argument of container_push_back and the name of the transform
// returned for DstVectorDim are missing from this listing; dst_scalar_per_access is
// defined outside the visible lines.
648 __device__ static constexpr auto GetDstThreadScratchDescriptor()
649 {
650 // 1st stage of transforms
651 // constexpr auto dst_scalar_per_access = generate_sequence(
652 // detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{},
653 // Number<nDim>{});
654
// number of accesses along each dimension
655 constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access;
656
657 constexpr auto dst_access_lengths_and_vector_length = container_push_back(
659
660 constexpr auto desc0 =
661 make_naive_tensor_descriptor_packed(dst_access_lengths_and_vector_length);
662
663 // 2nd stage of transforms
664 constexpr auto transforms = generate_tuple(
665 [&](auto i) {
666 if constexpr(i == DstVectorDim)
667 {
669 make_tuple(dst_access_lengths_and_vector_length[i],
670 dst_access_lengths_and_vector_length[Number<nDim>{}]));
671 }
672 else
673 {
674 return make_pass_through_transform(dst_access_lengths_and_vector_length[i]);
675 }
676 },
677 Number<nDim>{});
678
// lower dimension ids consumed by each transform: the vector dimension also takes the
// trailing dimension nDim
679 constexpr auto low_dim_idss = generate_tuple(
680 [&](auto i) {
681 if constexpr(i == DstVectorDim)
682 {
683 return Sequence<i.value, nDim>{};
684 }
685 else
686 {
687 return Sequence<i.value>{};
688 }
689 },
690 Number<nDim>{});
691
// each transform produces exactly one upper dimension with the same id
692 constexpr auto up_dim_idss =
693 generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
694
695 return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
696 }
697
698 // src_slice_origin_step_idx need to be known at compile-time, for performance reason
699 template <index_t ISrc>
700 __device__ void MoveSrcSliceWindow(const SrcDescs& src_descs,
701 Number<ISrc> iSrc,
702 const Index& src_slice_origin_step_idx)
703 {
704 // if src coord was not reset by RunRead(), then need to adjust the step here
705 const auto adjusted_step_idx =
706 SrcResetCoordinateAfterRunFlags::At(iSrc)
707 ? src_slice_origin_step_idx
708 : src_slice_origin_step_idx + GetSrcCoordinateResetStep();
709
710 // is it OK to construct a new step every time?
711 const auto adjusted_step = make_tensor_coordinate_step(src_descs[iSrc], adjusted_step_idx);
712
713 move_tensor_coordinate(src_descs[iSrc], src_coords_(iSrc), adjusted_step);
714 }
715
716 // dst_slice_origin_step_idx need to be known at compile-time, for performance reason
717 template <index_t IDst>
718 __device__ void MoveDstSliceWindow(const DstDescs& dst_descs,
719 Number<IDst> iDst,
720 const Index& dst_slice_origin_step_idx)
721 {
722 // if dst coord was not reset by Run(), then need to adjust the step here
723 const auto adjusted_step_idx =
724 DstResetCoordinateAfterRunFlags::At(iDst)
725 ? dst_slice_origin_step_idx
726 : dst_slice_origin_step_idx + GetDstCoordinateResetStep();
727
728 // is it OK to construct a new step every time?
729 const auto adjusted_step = make_tensor_coordinate_step(dst_descs[iDst], adjusted_step_idx);
730
731 move_tensor_coordinate(dst_descs[iDst], dst_coords_(iDst), adjusted_step);
732 }
733
734 private:
// Tuple types produced by generate_vectors(): source data at source vector length,
// intermediate data at source vector length, intermediate data at destination vector
// length.
735 using SrcVectorsType = decltype(generate_vectors<SrcDatas, SrcScalarPerVector>());
736 using ElmVectorsType = decltype(generate_vectors<InterDatas, SrcScalarPerVector>());
737 using DstVectorsType = decltype(generate_vectors<InterDatas, DstScalarPerVector>());
738
// Number of accesses on the source / destination space-filling curve.
739 static constexpr auto src_num_access = SrcSpaceFillingCurve::GetNumOfAccess();
740 static constexpr auto dst_num_access = DstSpaceFillingCurve::GetNumOfAccess();
741
// NOTE(review): the declarations between here and src_coords_ are missing from this
// listing; presumably they declare the per-scratch tuple types and the
// elm_vectors_tuple_ / oob_vectors_tuple_ / dst_vectors_tuple_ members used by the
// methods of this class - confirm against the original header.
744
747
750
// Current coordinates of all source / destination tensors, and the stored copy of the
// element-wise operation passed to the constructor.
751 SrcCoords src_coords_;
752 DstCoords dst_coords_;
753 const ElementwiseOperation element_op_;
754};
755
756} // namespace ck
__host__ __device__ constexpr T min(T x)
Definition utility/math.hpp:116
Definition ck.hpp:268
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
decltype(ck::declval< T & >().is_pack8_invocable) is_pack8_invocable_t
Definition is_detected.hpp:43
typename detail::StaticallyIndexedArrayImpl< T, N >::type StaticallyIndexedArray
Definition utility/statically_indexed_array.hpp:45
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_tensor_coordinate_step(const TensorDesc &, const VisibleIndex &idx_diff_visible, UpdateLowerIndexHack)
Definition tensor_description/tensor_descriptor.hpp:444
__host__ __device__ constexpr void move_tensor_coordinate(const TensorDesc &tensor_desc, TensorCoord &coord, const TensorCoordStep &coord_step)
Definition tensor_description/tensor_descriptor.hpp:508
__host__ __device__ constexpr auto container_push_back(const Array< TData, NSize > &a, const TData &x)
Definition utility/container_helper.hpp:18
InMemoryDataOperationEnum
Definition ck.hpp:277
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
decltype(ck::declval< T & >().is_pack4_invocable) is_pack4_invocable_t
Definition is_detected.hpp:40
__host__ __device__ constexpr bool coordinate_has_valid_offset_assuming_visible_index_is_valid(const TensorDesc &tensor_desc, const TensorCoord &coord)
Definition tensor_description/tensor_descriptor.hpp:560
typename detail::detector< nonesuch, void, Op, Args... >::value_t is_detected
Definition is_detected.hpp:34
integral_constant< index_t, N > Number
Definition number.hpp:12
@ Vgpr
Definition amd_address_space.hpp:20
typename tuple_element< I, TTuple >::type tuple_element_t
Definition utility/tuple.hpp:208
__host__ __device__ constexpr auto sequence_to_tuple_of_number(Sequence< Is... >)
Definition utility/container_helper.hpp:380
__host__ __device__ constexpr Y type_convert(X x)
Definition utility/type_convert.hpp:98
__host__ __device__ constexpr auto generate_sequence_v2(F &&f, Number< N >)
Definition sequence_helper.hpp:25
__host__ __device__ constexpr auto to_multi_index(const T &x)
Definition array_multi_index.hpp:28
constexpr bool is_same_v
Definition type.hpp:283
__host__ __device__ constexpr auto make_merge_transform_v3_division_mod(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:84
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
decltype(ck::declval< T & >().is_pack2_invocable) is_pack2_invocable_t
Definition is_detected.hpp:37
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__host__ __device__ constexpr auto generate_sequence(F, Number< N >)
Definition sequence_helper.hpp:18
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__host__ __device__ constexpr Y bit_cast(const X &x)
Definition type.hpp:306
typename remove_cv< T >::type remove_cv_t
Definition type.hpp:295
__host__ __device__ constexpr auto unpack2(F &&f, X &&x, Y &&y)
Definition functional4.hpp:55
__host__ __device__ constexpr index_t reduce_on_sequence(Seq, Reduce f, Number< Init >)
Definition utility/sequence.hpp:884
typename std::enable_if< B, T >::type enable_if_t
Definition enable_if.hpp:27
__host__ __device__ constexpr auto make_tensor_coordinate(const TensorDesc &tensor_desc, const VisibleIndex &idx_visible)
Definition tensor_description/tensor_descriptor.hpp:407
__host__ __device__ constexpr auto generate_tie(F &&f, Number< N >)
Definition tuple_helper.hpp:34
Array< index_t, N > MultiIndex
Definition array_multi_index.hpp:12
typename vector_type_maker< T, N >::type vector_type_maker_t
Definition dtype_vector.hpp:54
const GenericPointer< typename T::ValueType > T2 value
Definition pointer.h:1697
__host__ static __device__ constexpr T QuietNaN()
Definition numeric_limits.hpp:313
Definition utility/sequence.hpp:43
Definition tensor_space_filling_curve.hpp:20
static __device__ __host__ constexpr auto GetStepBetween(Number< AccessIdx1dBegin >, Number< AccessIdx1dEnd >)
Definition tensor_space_filling_curve.hpp:52
__host__ static __device__ constexpr index_t GetNumOfAccess()
Definition tensor_space_filling_curve.hpp:41
static __device__ __host__ constexpr Index GetIndex(Number< AccessIdx1d >)
Definition tensor_space_filling_curve.hpp:81
static __device__ __host__ constexpr auto GetForwardStep(Number< AccessIdx1d >)
Definition tensor_space_filling_curve.hpp:66
MultiIndex< nDim > Index
Definition tensor_space_filling_curve.hpp:23
__device__ void MoveSrcSliceWindow(const SrcDescs &src_descs, Number< ISrc > iSrc, const Index &src_slice_origin_step_idx)
Definition threadwise_tensor_slice_transfer_v7r3.hpp:700
__device__ void SetDstSliceOrigins(const DstDescs &dst_descs, const Indices &dst_slice_origin_idxs)
Definition threadwise_tensor_slice_transfer_v7r3.hpp:121
__device__ void Run(const SrcDescs &src_descs, const SrcBuffers &src_bufs, const DstDescs &dst_descs, DstBuffers dst_bufs)
Definition threadwise_tensor_slice_transfer_v7r3.hpp:565
static __device__ constexpr auto GetSrcCoordinateResetStep()
Definition threadwise_tensor_slice_transfer_v7r3.hpp:574
__device__ void RunRead(const SrcDescs &src_descs, const SrcBuffers &src_bufs, Number< ThreadScratchId > thread_scratch_id=Number< ThreadScratchId >{})
Definition threadwise_tensor_slice_transfer_v7r3.hpp:150
__device__ void SetSrcSliceOrigins(const SrcDescs &src_descs, const Indices &src_slice_origin_idxs)
Definition threadwise_tensor_slice_transfer_v7r3.hpp:112
static __device__ constexpr auto GetDstThreadScratchDescriptor()
Definition threadwise_tensor_slice_transfer_v7r3.hpp:648
__device__ void RunWrite(const DstDescs &dst_descs, DstBuffers dst_bufs, Number< ThreadScratchId > thread_scratch_id=Number< ThreadScratchId >{})
Definition threadwise_tensor_slice_transfer_v7r3.hpp:501
static constexpr auto MakeCoordinates(const Descs &descs, const Indices &indices)
Definition threadwise_tensor_slice_transfer_v7r3.hpp:63
__device__ constexpr ThreadwiseTensorSliceTransfer_v7r3(const SrcDescs &src_descs, const StaticallyIndexedArray< Index, nSrc > &src_slice_origins, const DstDescs &dst_descs, const StaticallyIndexedArray< Index, nDst > &dst_slice_origins, const ElementwiseOperation &element_op)
Definition threadwise_tensor_slice_transfer_v7r3.hpp:94
__device__ void RunWriteAndStoreVgpr(const DstDescs &dst_descs, DstBuffers dst_bufs, const DstVgprDescs &, DstVgprBuffers dst_vgpr_buf, Number< ThreadScratchId > thread_scratch_id=Number< ThreadScratchId >{})
Definition threadwise_tensor_slice_transfer_v7r3.hpp:408
__device__ void OOBCheck(Number< ThreadScratchId > thread_scratch_id=Number< ThreadScratchId >{})
Definition threadwise_tensor_slice_transfer_v7r3.hpp:280
static __device__ constexpr auto GetSrcThreadScratchDescriptor()
Definition threadwise_tensor_slice_transfer_v7r3.hpp:598
__device__ void TransposeFromElmToDst(Number< ThreadScratchId > thread_scratch_id=Number< ThreadScratchId >{})
Definition threadwise_tensor_slice_transfer_v7r3.hpp:300
static __device__ auto generate_vectors()
Definition threadwise_tensor_slice_transfer_v7r3.hpp:130
static __device__ constexpr auto GetDstCoordinateResetStep()
Definition threadwise_tensor_slice_transfer_v7r3.hpp:586
__device__ void MoveDstSliceWindow(const DstDescs &dst_descs, Number< IDst > iDst, const Index &dst_slice_origin_step_idx)
Definition threadwise_tensor_slice_transfer_v7r3.hpp:718
Definition threadwise_tensor_slice_transfer_util.hpp:20
Definition utility/math.hpp:50
Definition functional2.hpp:33
vector_type< T, N > type
Definition dtype_vector.hpp:31