block_gemm_areg_bgmem_creg_v1.hpp Source File

block_gemm_areg_bgmem_creg_v1.hpp Source File

Composable Kernel: block_gemm_areg_bgmem_creg_v1.hpp Source File
block_gemm_areg_bgmem_creg_v1.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
9
10namespace ck_tile {
11
12// A is block distributed tensor
13// B is block window on global memory
14// C is block distributed tensor
15// This will:
16// 1. load B from global memory into shared memory and then
17// 2. Call BlockGemmARegBSmemCRegV1
18template <typename Problem_, typename Policy_ = BlockGemmARegBGmemCRegV1DefaultPolicy>
20{
27
28 static constexpr index_t kBlockSize = Problem::kBlockSize;
29
30 // use BlockGemmARegBSmemCRegV1 as the underlying block-GEMM implementation
34
36 {
37 return sizeof(BDataType) *
38 Policy::template MakeBSmemBlockDescriptor<Problem>().get_element_space_size();
39 }
40
41 // C += A * B: copy the B tile (NPerBlock x KPerBlock) from global memory into the buffer at smem_ptr (which must hold at least GetStaticLdsSize() bytes), then run the block GEMM on it.
42 template <typename CBlockTensor, typename ABlockTensor, typename BBlockGmemWindowTmp>
43 CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
44 const ABlockTensor& a_block_tensor,
45 const BBlockGmemWindowTmp& b_block_gmem_window_tmp,
46 void* smem_ptr) const
47 {
48 static_assert(
49 std::is_same_v<ADataType, remove_cv_t<typename ABlockTensor::DataType>> &&
50 std::is_same_v<BDataType, remove_cv_t<typename BBlockGmemWindowTmp::DataType>> &&
51 std::is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>,
52 "wrong!");
53 // tile sizes: A is MPerBlock x KPerBlock; dim 0 of the B window is NPerBlock
54 constexpr index_t MPerBlock = ABlockTensor{}.get_lengths()[number<0>{}];
55 constexpr index_t NPerBlock = BBlockGmemWindowTmp{}.get_window_lengths()[number<0>{}];
56 constexpr index_t KPerBlock = ABlockTensor{}.get_lengths()[number<1>{}];
57
58 static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN &&
59 KPerBlock == BlockGemmShape::kK,
60 "wrong!");
61 // B gmem window: same tensor view and origin as the caller's window, with the policy's tile distribution
62 const auto b_block_gmem_window =
63 make_tile_window(b_block_gmem_window_tmp.get_bottom_tensor_view(),
65 b_block_gmem_window_tmp.get_window_origin(),
66 Policy::template MakeBGmemTileDistribution<Problem>());
67
68 // B LDS and LDS window
70 reinterpret_cast<BDataType*>(smem_ptr),
71 Policy::template MakeBSmemBlockDescriptor<Problem>());
72 // the B tile is NPerBlock x KPerBlock, so the LDS window is sized by NPerBlock (MPerBlock was only correct when kM == kN)
73 auto b_block_smem_window = make_tile_window(
74 b_block_smem, make_tuple(number<NPerBlock>{}, number<KPerBlock>{}), {0, 0});
75
76 // load B tile from global mem
77 const auto b_block_tile = load_tile(b_block_gmem_window);
78
79 // store B tile into shared mem
80 store_tile(b_block_smem_window, b_block_tile);
81
82 // wait for store_tile to finish
84
85 // block GEMM on the LDS-resident B tile (NOTE(review): the Impl alias is documented as BlockGemmARegBSmemCRegV1 but its rendered definition names BlockGemmARegBGmemCRegV1 - confirm)
86 BlockGemmARegBGmemCRegImpl{}(c_block_tensor, a_block_tensor, b_block_smem_window);
87 }
88
89 // C = A * B: copy the B tile (NPerBlock x KPerBlock) from global memory into the buffer at smem_ptr (which must hold at least GetStaticLdsSize() bytes), then return the block GEMM result.
90 template <typename ABlockTensor, typename BBlockGmemWindowTmp>
91 CK_TILE_DEVICE auto operator()(const ABlockTensor& a_block_tensor,
92 const BBlockGmemWindowTmp& b_block_gmem_window_tmp,
93 void* smem_ptr) const
94 {
95 static_assert(
96 std::is_same_v<ADataType, remove_cv_t<typename ABlockTensor::DataType>> &&
97 std::is_same_v<BDataType, remove_cv_t<typename BBlockGmemWindowTmp::DataType>>,
98 "wrong!");
99 // tile sizes: A is MPerBlock x KPerBlock; dim 0 of the B window is NPerBlock
100 constexpr index_t MPerBlock = ABlockTensor{}.get_lengths()[number<0>{}];
101 constexpr index_t NPerBlock = BBlockGmemWindowTmp{}.get_window_lengths()[number<0>{}];
102 constexpr index_t KPerBlock = ABlockTensor{}.get_lengths()[number<1>{}];
103
104 static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN &&
105 KPerBlock == BlockGemmShape::kK,
106 "wrong!");
107 // B gmem window: same tensor view and origin as the caller's window, with the policy's tile distribution
108 const auto b_block_gmem_window =
109 make_tile_window(b_block_gmem_window_tmp.get_bottom_tensor_view(),
111 b_block_gmem_window_tmp.get_window_origin(),
112 Policy::template MakeBGmemTileDistribution<Problem>());
113
114 // B LDS and LDS window
116 reinterpret_cast<BDataType*>(smem_ptr),
117 Policy::template MakeBSmemBlockDescriptor<Problem>());
118 // the B tile is NPerBlock x KPerBlock, so the LDS window is sized by NPerBlock (MPerBlock was only correct when kM == kN)
119 auto b_block_smem_window = make_tile_window(
120 b_block_smem, make_tuple(number<NPerBlock>{}, number<KPerBlock>{}), {0, 0});
121
122 // load B tile from global mem
123 const auto b_block_tile = load_tile(b_block_gmem_window);
124
125 // store B tile into shared mem
126 store_tile(b_block_smem_window, b_block_tile);
127
128 // wait for store_tile to finish
130
131 // block GEMM on the LDS-resident B tile (NOTE(review): the Impl alias is documented as BlockGemmARegBSmemCRegV1 but its rendered definition names BlockGemmARegBGmemCRegV1 - confirm)
132 return BlockGemmARegBGmemCRegImpl{}(a_block_tensor, b_block_smem_window);
133 }
134};
135
136} // namespace ck_tile
#define CK_TILE_DEVICE
Definition config.hpp:41
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_HOST_DEVICE constexpr auto make_tensor_view(DataType *__restrict__ p, const tensor_descriptor< Ts... > &desc)
Definition tensor_view.hpp:452
CK_TILE_DEVICE void block_sync_lds()
Definition arch.hpp:282
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_DEVICE constexpr auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition null_tile_window.hpp:75
CK_TILE_DEVICE void store_tile(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition store_tile.hpp:23
int32_t index_t
Definition integer.hpp:9
CK_TILE_DEVICE auto load_tile(const TileWindow_ &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={})
Definition load_tile.hpp:22
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
Definition block_gemm_areg_bgmem_creg_v1_default_policy.hpp:13
Definition block_gemm_areg_bgmem_creg_v1.hpp:20
static CK_TILE_HOST_DEVICE constexpr index_t GetStaticLdsSize()
Definition block_gemm_areg_bgmem_creg_v1.hpp:35
remove_cvref_t< Problem_ > Problem
Definition block_gemm_areg_bgmem_creg_v1.hpp:21
remove_cvref_t< typename Problem::ADataType > ADataType
Definition block_gemm_areg_bgmem_creg_v1.hpp:23
remove_cvref_t< typename Problem::CDataType > CDataType
Definition block_gemm_areg_bgmem_creg_v1.hpp:25
remove_cvref_t< typename Problem::BDataType > BDataType
Definition block_gemm_areg_bgmem_creg_v1.hpp:24
CK_TILE_DEVICE auto operator()(const ABlockTensor &a_block_tensor, const BBlockGmemWindowTmp &b_block_gmem_window_tmp, void *smem_ptr) const
Definition block_gemm_areg_bgmem_creg_v1.hpp:91
remove_cvref_t< typename Problem::BlockGemmShape > BlockGemmShape
Definition block_gemm_areg_bgmem_creg_v1.hpp:26
BlockGemmARegBGmemCRegV1< BlockGemmProblem< ADataType, BDataType, CDataType, kBlockSize, BlockGemmShape >, BlockGemmARegBGmemCRegV1DefaultPolicy > BlockGemmARegBGmemCRegImpl
Definition block_gemm_areg_bgmem_creg_v1.hpp:31
CK_TILE_DEVICE void operator()(CBlockTensor &c_block_tensor, const ABlockTensor &a_block_tensor, const BBlockGmemWindowTmp &b_block_gmem_window_tmp, void *smem_ptr) const
Definition block_gemm_areg_bgmem_creg_v1.hpp:43
remove_cvref_t< Policy_ > Policy
Definition block_gemm_areg_bgmem_creg_v1.hpp:22
Definition block_gemm_problem.hpp:18