mxlib
C++ tools for analyzing astronomical data and other tasks, by Jared R. Males. [git repo]
Loading...
Searching...
No Matches
templateCublas.hpp
Go to the documentation of this file.
1/** \file templateCublas.hpp
2 * \author Jared R. Males
3 * \brief A template interface to cuBlas
4 * \ingroup cuda_files
5 *
6 */
7
8//***********************************************************************//
9// Copyright 2019,2020 Jared R. Males (jaredmales@gmail.com)
10//
11// This file is part of mxlib.
12//
13// mxlib is free software: you can redistribute it and/or modify
14// it under the terms of the GNU General Public License as published by
15// the Free Software Foundation, either version 3 of the License, or
16// (at your option) any later version.
17//
18// mxlib is distributed in the hope that it will be useful,
19// but WITHOUT ANY WARRANTY; without even the implied warranty of
20// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
21// GNU General Public License for more details.
22//
23// You should have received a copy of the GNU General Public License
24// along with mxlib. If not, see <http://www.gnu.org/licenses/>.
25//***********************************************************************//
26
27#ifndef math_templateCublas_hpp
28#define math_templateCublas_hpp
29
30#include <cuda_runtime.h>
31#include <cublas_v2.h>
32
33namespace mx
34{
35namespace cuda
36{
37
38/// Multiplies a vector by a scalar, overwriting the vector with the result.
39/** Implements
40 * \f[
41 * \vec{x} = \alpha \vec{x}
42 * \f]
43 *
44 * Specializations are provided for float, double, complex-float, and complex-double
45 *
46 * \tparam floatT a floating-point type, either float, double, complex-float, or complex-double
47 *
48 * \test Scenario: scaling a vector with cublas \ref test_math_templateCublas_scal "[test doc]"
49 *
50 * \ingroup cublas
51 */
52template <typename floatT>
53cublasStatus_t cublasTscal( cublasHandle_t handle, ///< [in] The cublas context handle
54 int n, ///< [in] Number of elements in the vector
55 const floatT *alpha, ///< [in] The scalar
56 floatT *x, ///< [in.out] The vector of length n
57 int incx ///< [in] The stride of the vector
58);
59
60/// Multiplies a vector by a scalar, adding it to a second vector which is overwritten by the result.
61/** Implements
62 * \f[
63 * \vec{y} = \alpha \vec{x} + \vec{y}
64 * \f]
65 *
66 * Specializations are provided for float, double, complex-float, and complex-double
67 *
68 * \tparam floatT a floating-point type, either float, double, complex-float, or complex-double
69 *
70 * \test Scenario: scaling and accumulating a vector with cublas \ref test_math_templateCublas_axpy "[test doc]"
71 *
72 * \ingroup cublas
73 */
74template <typename floatT>
75cublasStatus_t cublasTaxpy( cublasHandle_t handle, ///< [in] handle to the cuBLAS library context.
76 int n, ///< [in] scalar used for multiplication.
77 const floatT *alpha, ///< [in] number of elements in the vector x and y
78 const floatT *x, ///< [in] vector with n elements.
79 int incx, ///< [in] stride between consecutive elements of x
80 floatT *y, ///< [in.out] vector with n elements.
81 int incy ///< [in] stride between consecutive elements of y
82);
83
84//----------------------------------------------------
85// Element-wise (Hadamard) products of vectors
86
87/// Calculates the element-wise product of two vectors, storing the result in the first.
88/** Calculates
89 * \f$
90 * x = x * y
91 * \f$
92 * element by element, a.k.a. the Hadamard product.
93 *
94 * Specializations are provided for:
95 * - float,float
96 * - complex-float, float
97 * - complex-float, complex-float
98 * - double, double
99 * - complex-double, double
100 * - complex-double, complex-double
101 *
102 * \test Scenario: multiplying two vectors element by element \ref test_math_templateCublas_elementwiseXxY "[test doc]"
103 *
104 * \ingroup cublas
105 */
106template <typename dataT1, typename dataT2>
107cudaError_t elementwiseXxY(
108 dataT1 *x, ///< [in.out] device pointer for the 1st vector. Is replaced with the product of the two vectors
109 dataT2 *y, ///< [in] device pointer for the 2nd vector.
110 int size ///< [in] the number of elements in the vectors.
111);
112
113//----------------------------------------------------
114// Tgemv
115
116/// Perform a matrix-vector multiplication.
117/** Implements
118 * \f[
119 * \vec{y} = \alpha \mathbf{A} \vec{x} + \beta \vec{y}
120 * \f]
121 *
122 * Specializations are provided for float, double, complex-float, and complex-double
123 *
124 * \tparam floatT a floating-point type, either float, double, complex-float, or complex-double
125 *
126 * \tests Scenario: multiplying a vectory by a matrix \ref test_math_templateCublas_cublasTgemv_inc "[code doc]"
127 *
128 * \ingroup cublas
129 */
130template <typename floatT>
131cublasStatus_t
132cublasTgemv( cublasHandle_t handle, ///< [in] handle to the cuBLAS library context.
133 cublasOperation_t trans, ///< [in] operation on a, CUBLAS_OP_N for none, and CUBLAS_OP_T for transpose
134 int m, ///< [in] rows in matrix A.
135 int n, ///< [in] columns in matrix A.
136 const floatT *alpha, ///< [in] scalar used for multiplication of A
137 const floatT *A, ///< [in] array of dimension lda x n with lda >= max(1,m). The leading m by n part of the
138 ///< array A is multiplied by alpha and x. Unchanged.
139 int lda, ///< [in] leading dimension of A. lda must be at least max(1,m).
140 const floatT *x, ///< [in] vector of at least (1+(n-1)*abs(incx)) elements if transa==CUBLAS_OP_N and at
141 ///< least (1+(m-1)*abs(incx)) elements otherwise.
142 int incx, ///< [in] stride of x.
143 const floatT *
144 beta, ///< [in] scalar used for multiplication of y, if beta==0 then y does not need to be initialized.
145 floatT *y, ///< [in.out] vector of at least (1+(m-1)*abs(incy)) elements if transa==CUBLAS_OP_N and at
146 ///< least (1+(n-1)*abs(incy)) elements otherwise.
147 int incy ///< [in] stride of y
148);
149
150/// Perform a matrix-vector multiplication for stride-less arrays
151/** Implements
152 * \f[
153 * \vec{y} = \alpha \mathbf{A} \vec{x} + \beta \vec{y}
154 * \f]
155 *
156 * Specializations are provided for float, double, complex-float, and complex-double
157 *
158 * \overload
159 * This version assumes stride is 1 in all arrays.
160 *
161 * \tparam floatT a floating-point type, either float, double, complex-float, or complex-double
162 *
163 * \ingroup cublas
164 */
165template <typename floatT>
166cublasStatus_t
167cublasTgemv( cublasHandle_t handle, ///< [in] handle to the cuBLAS library context.
168 cublasOperation_t trans, ///< [in] operation on a, CUBLAS_OP_N for none, and CUBLAS_OP_T for transpose
169 int m, ///< [in] rows in matrix A.
170 int n, ///< [in] columns in matrix A.
171 const floatT *alpha, ///< [in] scalar used for multiplication of A
172 const floatT *A, ///< [in] array of dimension m x n. Unchanged.
173 const floatT *x, ///< [in] vector of at least (1+(n-1)*abs(incx)) elements if transa==CUBLAS_OP_N and at
174 ///< least (1+(m-1)*abs(incx)) elements otherwise.
175 const floatT *
176 beta, ///< [in] scalar used for multiplication of y, if beta==0 then y does not need to be initialized.
177 floatT *y ///< [in.out] vector of at least (1+(m-1)*abs(incy)) elements if transa==CUBLAS_OP_N and at least
178 ///< (1+(n-1)*abs(incy)) elements otherwise.
179);
180
181template <>
182cublasStatus_t cublasTgemv<float>( cublasHandle_t handle,
183 cublasOperation_t trans,
184 int m,
185 int n,
186 const float *alpha,
187 const float *A,
188 int lda,
189 const float *x,
190 int incx,
191 const float *beta,
192 float *y,
193 int incy );
194
195template <>
196cublasStatus_t cublasTgemv<double>( cublasHandle_t handle,
197 cublasOperation_t trans,
198 int m,
199 int n,
200 const double *alpha,
201 const double *A,
202 int lda,
203 const double *x,
204 int incx,
205 const double *beta,
206 double *y,
207 int incy );
208
209template <>
210cublasStatus_t cublasTgemv<float>( cublasHandle_t handle,
211 cublasOperation_t trans,
212 int m,
213 int n,
214 const float *alpha,
215 const float *A,
216 const float *x,
217 const float *beta,
218 float *y );
219
220template <>
221cublasStatus_t cublasTgemv<double>( cublasHandle_t handle,
222 cublasOperation_t trans,
223 int m,
224 int n,
225 const double *alpha,
226 const double *A,
227 const double *x,
228 const double *beta,
229 double *y );
230
231} // namespace cuda
232} // namespace mx
233#endif // math_templateCublas_hpp
cudaError_t elementwiseXxY(dataT1 *x, dataT2 *y, int size)
Calculates the element-wise product of two vectors, storing the result in the first.
cublasStatus_t cublasTaxpy(cublasHandle_t handle, int n, const floatT *alpha, const floatT *x, int incx, floatT *y, int incy)
Multiplies a vector by a scalar, adding it to a second vector which is overwritten by the result.
cublasStatus_t cublasTscal(cublasHandle_t handle, int n, const floatT *alpha, floatT *x, int incx)
Multiplies a vector by a scalar, overwriting the vector with the result.
cublasStatus_t cublasTgemv(cublasHandle_t handle, cublasOperation_t trans, int m, int n, const floatT *alpha, const floatT *A, int lda, const floatT *x, int incx, const floatT *beta, floatT *y, int incy)
Perform a matrix-vector multiplication.
The mxlib c++ namespace.
Definition mxError.hpp:106