device_gemm_multiple_d_layernorm.hpp Source File

device_gemm_multiple_d_layernorm.hpp Source File

Composable Kernel: device_gemm_multiple_d_layernorm.hpp Source File
device_gemm_multiple_d_layernorm.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <array>
7#include "device_base.hpp"
8
9namespace ck {
10namespace tensor_operation {
11namespace device {
12
13// GEMM:
14// input : A[M, K]
15// input : B[N, K]
16// input : D0[M, N], D1[M, N], ...
17// output : E[M, N]
18// output : H[M, N]
19// C = a_op(A) * b_op(B)
20// E = cde_op(C, D0, D1, ...)
21// H = layernorm(E)
22// Assume:
23// D0, D1, ... and E have the same layout
24// Calculate mean & variance along N dimension in layernorm(E)
// Abstract device-operation interface for the fused
// GEMM -> elementwise(C, Ds) -> layernorm pipeline described above.
// Both member functions are pure virtual (= 0), so this struct is abstract;
// concrete device kernels are expected to derive from it and implement them.
//
// Template parameters:
//   ALayout, BLayout, DsLayout, HLayout  - memory layouts of A, B, the D
//                                          tensors, and the output H.
//   ADataType, BDataType, DsDataType     - element types of A, B and the D
//                                          tensors (DsDataType also determines
//                                          NumDTensor via DsDataType::Size()).
//   GammaDataType, BetaDataType          - element types of the layernorm
//                                          gamma / beta parameters.
//   HDataType                            - element type of the output H.
//   AElementwiseOperation                - a_op in the formula above.
//   BElementwiseOperation                - b_op in the formula above.
//   CDEElementwiseOperation              - cde_op in the formula above.
//   HElementwiseOperation                - presumably applied to the layernorm
//                                          result before H is stored — TODO confirm.
//
// NOTE(review): the comment above lists E[M, N] as an output, but
// MakeArgumentPointer accepts no E pointer and no E stride (only p_h /
// StrideH). Presumably E is an intermediate managed by the implementation —
// confirm and update the comment accordingly.
25template <typename ALayout,
26 typename BLayout,
27 typename DsLayout,
28 typename HLayout,
29 typename ADataType,
30 typename BDataType,
31 typename DsDataType,
32 typename GammaDataType,
33 typename BetaDataType,
34 typename HDataType,
35 typename AElementwiseOperation,
36 typename BElementwiseOperation,
37 typename CDEElementwiseOperation,
38 typename HElementwiseOperation>
// NOTE(review): listing line 39 (the struct declaration line, i.e. the
// struct's name and base class) is missing from this rendering; the `{` below
// opens that struct. Restore the line from the upstream header.
40{
// Number of D tensors; sizes the p_ds and StrideDs arrays below.
// Taken from DsDataType::Size() (DsDataType is presumably a compile-time
// type tuple — confirm).
41 static constexpr index_t NumDTensor = DsDataType::Size();
// Builds a type-erased argument object describing one problem instance.
//
// Buffers (type-erased; presumably device memory — TODO confirm):
//   p_a, p_b   - GEMM inputs A and B (read-only).
//   p_ds       - one read-only pointer per D tensor (NumDTensor entries).
//   p_gamma,
//   p_beta     - layernorm scale / shift parameters (read-only); presumably
//                length-N vectors since normalization runs along N — confirm.
//   p_h        - output H (the only non-const buffer, so the only one written).
// Sizes and strides:
//   MRaw, NRaw, KRaw - GEMM problem sizes M, N, K (the "Raw" suffix suggests
//                      unpadded sizes — confirm).
//   StrideA, StrideB, StrideDs, StrideH - per-tensor strides; their exact
//                      meaning (leading dimension?) depends on the layout
//                      template parameters — confirm.
// Other:
//   epsilon          - presumably the layernorm variance epsilon — confirm.
//   a_element_op, b_element_op, cde_element_op, h_element_op
//                    - functor instances for the elementwise operations
//                      named by the corresponding template parameters.
// Returns: an owning pointer to the implementation's BaseArgument.
42 virtual std::unique_ptr<BaseArgument>
43 MakeArgumentPointer(const void* p_a,
44 const void* p_b,
45 std::array<const void*, NumDTensor> p_ds,
46 const void* p_gamma,
47 const void* p_beta,
48 void* p_h,
49 index_t MRaw,
50 index_t NRaw,
51 index_t KRaw,
52 index_t StrideA,
53 index_t StrideB,
54 std::array<index_t, NumDTensor> StrideDs,
55 index_t StrideH,
56 double epsilon,
57 AElementwiseOperation a_element_op,
58 BElementwiseOperation b_element_op,
59 CDEElementwiseOperation cde_element_op,
60 HElementwiseOperation h_element_op) = 0;
61
// Returns an owning pointer to the implementation's BaseInvoker; presumably
// used to execute the operation described by an argument obtained from
// MakeArgumentPointer — confirm against the implementations.
62 virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
63}; // end of struct (GEMM + multiple-D + layernorm device-op interface)
64
65} // namespace device
66} // namespace tensor_operation
67} // namespace ck
Definition convolution_backward_data_specialization.hpp:8
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
index_t — alias of int32_t (definition: ck.hpp:299)
Definition device_gemm_multiple_d_layernorm.hpp:40
static constexpr index_t NumDTensor
Definition device_gemm_multiple_d_layernorm.hpp:41
virtual std::unique_ptr< BaseInvoker > MakeInvokerPointer()=0
virtual std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, std::array< const void *, NumDTensor > p_ds, const void *p_gamma, const void *p_beta, void *p_h, index_t MRaw, index_t NRaw, index_t KRaw, index_t StrideA, index_t StrideB, std::array< index_t, NumDTensor > StrideDs, index_t StrideH, double epsilon, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op, HElementwiseOperation h_element_op)=0