18template <
typename ThreadGroup,
19 typename ElementwiseOperation,
20 typename SliceLengths,
21 typename ThreadClusterLengths,
22 typename ThreadClusterArrangeOrder,
27 typename DimAccessOrder,
30 bool ThreadTransferSrcResetCoordinateAfterRun,
31 bool ThreadTransferDstResetCoordinateAfterRun>
41 const SrcDesc& src_desc,
42 const Index& src_block_slice_origin,
43 const DstDesc& dst_desc,
44 const Index& dst_block_slice_origin,
45 const ElementwiseOperation& element_op)
46 : threadwise_transfer_(src_desc,
55 nDim == ThreadClusterLengths::Size() &&
56 nDim == ThreadClusterArrangeOrder::Size() &&
57 nDim == DimAccessOrder::Size(),
58 "wrong! nDim not consistent");
62 "wrong! threads should be mapped to cover entire slicing window");
64 static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(),
65 "wrong! ThreadGroup::GetNumOfThread() too small");
67 if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
68 ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
70 const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
71 make_multi_index(ThreadGroup::GetThreadId()));
73 const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;
75 threadwise_transfer_.SetSrcSliceOrigin(src_desc,
76 src_block_slice_origin + thread_data_idx_begin);
77 threadwise_transfer_.SetDstSliceOrigin(dst_desc,
78 dst_block_slice_origin + thread_data_idx_begin);
82 template <
typename SrcBuffer,
typename DstBuffer, InMemoryDataOperationEnum DstInMemOp>
83 __device__
void Run(
const SrcDesc& src_desc,
84 const SrcBuffer& src_buf,
85 const DstDesc& dst_desc,
88 if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
89 ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
92 src_desc, src_buf, dst_desc, dst_buf);
98 if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
99 ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
101 threadwise_transfer_.MoveSrcSliceWindow(src_desc, step);
107 if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
108 ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
110 threadwise_transfer_.MoveDstSliceWindow(dst_desc, step);
116 if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
117 ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
119 const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
124 threadwise_transfer_.SetSrcSliceOrigin(src_desc,
125 src_block_slice_origin + thread_data_idx_begin);
131 if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
132 ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
134 const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
139 threadwise_transfer_.SetDstSliceOrigin(dst_desc,
140 dst_block_slice_origin + thread_data_idx_begin);
145 static constexpr auto thread_cluster_desc_ =
148 using ThreadwiseTransfer =
149 ThreadwiseTensorSliceTransfer_v6r1r2<SrcData,
153 ElementwiseOperation,
158 ThreadTransferSrcResetCoordinateAfterRun,
159 ThreadTransferDstResetCoordinateAfterRun>;
161 ThreadwiseTransfer threadwise_transfer_;
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
int32_t index_t
Definition ck.hpp:299
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
__host__ __device__ constexpr auto make_cluster_descriptor(const Lengths &lengths, ArrangeOrder order=typename arithmetic_sequence_gen< 0, Lengths::Size(), 1 >::type{})
Definition tensor_description/cluster_descriptor.hpp:13
__host__ __device__ constexpr auto make_zero_multi_index()
Definition array_multi_index.hpp:21
typename remove_reference< T >::type remove_reference_t
Definition type.hpp:292
Array< index_t, N > MultiIndex
Definition array_multi_index.hpp:12
__device__ void MoveDstSliceWindow(const DstDesc &dst_desc, const Index &step)
Definition thread_group_tensor_slice_transfer_v6r1r2.hpp:105
static constexpr auto thread_slice_lengths
Definition thread_group_tensor_slice_transfer_v6r1r2.hpp:36
__device__ void SetSrcSliceOrigin(const SrcDesc &src_desc, const Index &src_block_slice_origin)
Definition thread_group_tensor_slice_transfer_v6r1r2.hpp:114
static constexpr index_t nDim
Definition thread_group_tensor_slice_transfer_v6r1r2.hpp:34
MultiIndex< nDim > Index
Definition thread_group_tensor_slice_transfer_v6r1r2.hpp:38
__device__ constexpr ThreadGroupTensorSliceTransfer_v6r1r2(const SrcDesc &src_desc, const Index &src_block_slice_origin, const DstDesc &dst_desc, const Index &dst_block_slice_origin, const ElementwiseOperation &element_op)
Definition thread_group_tensor_slice_transfer_v6r1r2.hpp:40
__device__ void SetDstSliceOrigin(const DstDesc &dst_desc, const Index &dst_block_slice_origin)
Definition thread_group_tensor_slice_transfer_v6r1r2.hpp:129
__device__ void Run(const SrcDesc &src_desc, const SrcBuffer &src_buf, const DstDesc &dst_desc, DstBuffer &dst_buf)
Definition thread_group_tensor_slice_transfer_v6r1r2.hpp:83
__device__ void MoveSrcSliceWindow(const SrcDesc &src_desc, const Index &step)
Definition thread_group_tensor_slice_transfer_v6r1r2.hpp:96