device_memory.hpp Source File

device_memory.hpp Source File

Composable Kernel: device_memory.hpp Source File
library/utility/device_memory.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <hip/hip_runtime.h>
7
/**
 * @brief Fill a buffer with a single value, one element per loop iteration.
 *
 * Writes x to every element p[0 .. buffer_element_size). The loop starts at
 * the thread's global index and advances by the total number of launched
 * threads (gridDim.x * blockDim.x), so each element is written exactly once
 * for any 1-D launch configuration. With a single block (gridDim.x == 1,
 * blockIdx.x == 0) this reduces to striding by blockDim.x from threadIdx.x.
 *
 * @tparam T                  Element type; x is passed to the kernel by value.
 * @param p                   Pointer dereferenced on the device; must address at
 *                            least buffer_element_size elements of T.
 * @param x                   Value written to every element.
 * @param buffer_element_size Number of elements (not bytes) to write.
 */
template <typename T>
__global__ void set_buffer_value(T* p, T x, uint64_t buffer_element_size)
{
    // Widen before multiplying so blockIdx.x * blockDim.x cannot wrap in 32 bits.
    const uint64_t stride = static_cast<uint64_t>(gridDim.x) * blockDim.x;
    for(uint64_t i = static_cast<uint64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
        i < buffer_element_size;
        i += stride)
    {
        p[i] = x;
    }
}
16
22{
23 DeviceMem() : mpDeviceBuf(nullptr), mMemSize(0) {}
24 DeviceMem(std::size_t mem_size);
25 void Realloc(std::size_t mem_size);
26 void* GetDeviceBuffer() const;
27 std::size_t GetBufferSize() const;
28 void ToDevice(const void* p) const;
29 void ToDevice(const void* p, const std::size_t cpySize) const;
30 void FromDevice(void* p) const;
31 void FromDevice(void* p, const std::size_t cpySize) const;
32 void SetZero() const;
33 template <typename T>
34 void SetValue(T x) const;
36
38 std::size_t mMemSize;
39};
40
/**
 * @brief Set every element of the buffer to x, viewing the buffer as an array of T.
 *
 * The element count is mMemSize / sizeof(T). The fill is done by the
 * set_buffer_value kernel, launched with one block of 1024 threads on the
 * default stream. No synchronization is performed here: the call returns once
 * the launch has been issued, not when the kernel has finished.
 *
 * @tparam T Element type the buffer is interpreted as.
 * @param x  Value written to every element.
 * @throws std::runtime_error if mMemSize is not a multiple of sizeof(T) (part of
 *         the buffer would be left unset), or if the kernel launch reports an
 *         error. NOTE: hipGetLastError() may also surface an earlier HIP error on
 *         this thread that nobody queried.
 */
template <typename T>
void DeviceMem::SetValue(T x) const
{
    if(mMemSize % sizeof(T) != 0)
    {
        throw std::runtime_error("wrong! not entire DeviceMem will be set");
    }

    const uint64_t element_count = mMemSize / sizeof(T);

    // Nothing to fill (e.g. a default-constructed DeviceMem with a null buffer):
    // skip the launch instead of handing a null pointer to the kernel.
    if(element_count == 0)
    {
        return;
    }

    set_buffer_value<T><<<1, 1024>>>(static_cast<T*>(mpDeviceBuf), x, element_count);

    // A kernel launch returns no status; query it so a failed launch is reported
    // instead of silently leaving the buffer with its old contents.
    const hipError_t launch_status = hipGetLastError();
    if(launch_status != hipSuccess)
    {
        throw std::runtime_error(hipGetErrorString(launch_status));
    }
}
__global__ void set_buffer_value(T *p, T x, uint64_t buffer_element_size)
Definition library/utility/device_memory.hpp:9
unsigned __int64 uint64_t
Definition stdint.h:136
void SetValue(T x) const
Definition library/utility/device_memory.hpp:42
void Realloc(std::size_t mem_size)
void * GetDeviceBuffer() const
DeviceMem(std::size_t mem_size)
void FromDevice(void *p) const
void FromDevice(void *p, const std::size_t cpySize) const
DeviceMem()
Definition library/utility/device_memory.hpp:23
void ToDevice(const void *p, const std::size_t cpySize) const
void ToDevice(const void *p) const
void * mpDeviceBuf
Definition library/utility/device_memory.hpp:37
std::size_t GetBufferSize() const
void SetZero() const
std::size_t mMemSize
Definition library/utility/device_memory.hpp:38