RTC Toolkit 5.0.0
Loading...
Searching...
No Matches
gpuLib.hpp
Go to the documentation of this file.
1
13#ifndef EXAMPLEGPULIB_H
14#define EXAMPLEGPULIB_H
15
16#include <cassert>
17#include <string>
18#include <vector>
19
20#include "cublas_v2.h"
21#include "cuda_runtime.h"
22
23#include "rtctk/exampleDataTask/CumulativeAverage_cuda.cuh"
24
25class GpuLib {
26public:
27 GpuLib(int input_length, int output_length, int gpu);
29
30 void SetMatrix(float* mat, bool flip = true);
31 std::vector<float> GetMatrix();
32
34
35 std::vector<float> GetAvgSlopes();
36
37 std::vector<float> GetResults(bool download = false);
38
39 void NewSample(const float* sample, int callback_count);
40 void Compute();
41
43
44protected:
45 // sets the required GPU
46 void SetGPU();
47
48 void PrintCudaError(cudaError_t error);
49 std::string CublasGetStatusString(cublasStatus_t error);
50 void PrintCublasStatus(cublasStatus_t status);
51
52private:
53 // cppcheck-suppress-begin unusedStructMember
54
55 int m_gpu;
56 int current_sample;
57 float m_alpha;
58 float m_beta;
59
60 cublasHandle_t handle;
61
62 // input_vector
63 int m_slopes;
64 int m_modes;
65
66 // slopes vector
67 float* m_slopes_vector;
68 float* m_slopes_vector_d;
69
70 // avg_array_vector
71 float* m_avg_slopes;
72 float* m_avg_slopes_d;
73
74 // matrix
75 float* m_slopes_to_modes_matrix;
76 float* m_slopes_to_modes_matrix_d;
77
78 // output vector
79 float* m_modes_vector;
80 float* m_modes_vector_d;
81
82 // cppcheck-suppress-end unusedStructMember
83};
84
85#endif // EXAMPLEGPULIB_H
Definition gpuLib.hpp:25
std::string CublasGetStatusString(cublasStatus_t error)
void SetMatrix(float *mat, bool flip=true)
void InitReaderThread()
void ResetAvgSlopes()
void SetGPU()
void Compute()
std::vector< float > GetResults(bool download=false)
std::vector< float > GetMatrix()
void NewSample(const float *sample, int callback_count)
std::vector< float > GetAvgSlopes()
void PrintCublasStatus(cublasStatus_t status)
GpuLib(int input_length, int output_length, int gpu)
void PrintCudaError(cudaError_t error)