reference_elementwise.hpp Source File

reference_elementwise.hpp Source File

Composable Kernel: reference_elementwise.hpp Source File
reference_elementwise.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
8#include <thread>
9
10namespace ck_tile {
// Host-side reference for a unary element-wise operation.
// The lambda f, for the index i it is invoked with, converts a.mData[i] to
// ComputeDataType and applies element_op to the converted value.
// NOTE(review): listing lines 12-13 (function name and leading parameters) and
// listing line 20 (presumably the store of v_b into b) are missing from this
// extract -- confirm against the original header before relying on this text.
11template <typename ADataType, typename BDataType, typename ComputeDataType, typename ElementOp>
14 ElementOp element_op)
15{
16 // TODO: implement gpu version reference function
17 auto f = [&](auto i) {
18 auto v_a = type_convert<ComputeDataType>(a.mData[i]);
19 auto v_b = element_op(v_a);
21 };
22
// Passes f and b's element-space size to make_ParallelTensorFunctor, then calls
// the returned object with std::thread::hardware_concurrency().
// NOTE(review): hardware_concurrency() may return 0 per the C++ standard; how the
// functor treats a count of 0 is not visible here -- verify in host_tensor.hpp.
23 make_ParallelTensorFunctor(f, b.get_element_space_size())(std::thread::hardware_concurrency());
24}
25
// Host-side reference for a binary element-wise operation.
// The lambda f, for the index i it is invoked with, converts a.mData[i] and
// b.mData[i] to ComputeDataType and applies element_op to the two converted values.
// NOTE(review): listing lines 31 and 33 (function name, the `a` and `c` parameters)
// and listing line 41 (presumably the store of v_c into c) are missing from this
// extract -- confirm against the original header before relying on this text.
26template <typename ADataType,
27 typename BDataType,
28 typename CDataType,
29 typename ComputeDataType,
30 typename ElementOp>
32 const HostTensor<BDataType>& b,
34 ElementOp element_op)
35{
36 // TODO: implement gpu version reference function
37 auto f = [&](auto i) {
38 auto v_a = type_convert<ComputeDataType>(a.mData[i]);
39 auto v_b = type_convert<ComputeDataType>(b.mData[i]);
40 auto v_c = element_op(v_a, v_b);
42 };
43
// Passes f and c's element-space size to make_ParallelTensorFunctor, then calls
// the returned object with std::thread::hardware_concurrency().
// NOTE(review): a and b are indexed with the same i as c; nothing visible here
// checks that their element counts match c's -- verify at the call sites.
44 make_ParallelTensorFunctor(f, c.get_element_space_size())(std::thread::hardware_concurrency());
45}
46
47} // namespace ck_tile
#define CK_TILE_HOST
Definition config.hpp:40
Definition tile/core/algorithm/cluster_descriptor.hpp:13
CK_TILE_HOST auto make_ParallelTensorFunctor(F f, Xs... xs)
Definition tile/host/host_tensor.hpp:329
CK_TILE_HOST void reference_binary_elementwise(const HostTensor< ADataType > &a, const HostTensor< BDataType > &b, HostTensor< CDataType > &c, ElementOp element_op)
Definition reference_elementwise.hpp:31
CK_TILE_HOST void reference_unary_elementwise(const HostTensor< ADataType > &a, HostTensor< BDataType > &b, ElementOp element_op)
Definition reference_elementwise.hpp:12
CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
Definition tile/core/numeric/type_convert.hpp:29
const GenericPointer< typename T::ValueType > T2 T::AllocatorType & a
Definition pointer.h:1517
Definition tile/host/host_tensor.hpp:336
std::size_t get_element_space_size() const
Definition tile/host/host_tensor.hpp:400
Data mData
Definition tile/host/host_tensor.hpp:801