mxlib
c++ tools for analyzing astronomical data and other tasks by Jared R. Males. [git repo]
Loading...
Searching...
No Matches
templateCublas.hpp
Go to the documentation of this file.
1/** \file templateCublas.hpp
2 * \author Jared R. Males
3 * \brief A template interface to cuBlas
4 * \ingroup cuda_files
5 *
6 */
7
8//***********************************************************************//
9// Copyright 2019,2020 Jared R. Males (jaredmales@gmail.com)
10//
11// This file is part of mxlib.
12//
13// mxlib is free software: you can redistribute it and/or modify
14// it under the terms of the GNU General Public License as published by
15// the Free Software Foundation, either version 3 of the License, or
16// (at your option) any later version.
17//
18// mxlib is distributed in the hope that it will be useful,
19// but WITHOUT ANY WARRANTY; without even the implied warranty of
20// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
21// GNU General Public License for more details.
22//
23// You should have received a copy of the GNU General Public License
24// along with mxlib. If not, see <http://www.gnu.org/licenses/>.
25//***********************************************************************//
26
27#ifndef math_templateCublas_hpp
28#define math_templateCublas_hpp
29
30#include <cuda_runtime.h>
31#include <cublas_v2.h>
32
33namespace mx
34{
35namespace cuda
36{
37
38/// Multiplies a vector by a scalar, overwriting the vector with the result.
39/** Implements
40 * \f[
41 * \vec{x} = \alpha \vec{x}
42 * \f]
43 *
44 * Specializations are provided for float, double, complex-float, and complex-double
45 *
46 * \tparam floatT a floating-point type, either float, double, complex-float, or complex-double
47 *
48 * \test Scenario: scaling a vector with cublas \ref test_math_templateCublas_scal "[test doc]"
49 *
50 * \ingroup cublas
51 */
52template <typename floatT>
53cublasStatus_t cublasTscal( cublasHandle_t handle, ///< [in] The cublas context handle
54 int n, ///< [in] Number of elements in the vector
55 const floatT *alpha, ///< [in] The scalar
56 floatT *x, ///< [in.out] The vector of length n
57 int incx ///< [in] The stride of the vector
58);
59
60/// Multiplies a vector by a scalar, adding it to a second vector which is overwritten by the result.
61/** Implements
62 * \f[
63 * \vec{y} = \alpha \vec{x} + \vec{y}
64 * \f]
65 *
66 * Specializations are provided for float, double, complex-float, and complex-double
67 *
68 * \tparam floatT a floating-point type, either float, double, complex-float, or complex-double
69 *
70 * \test Scenario: scaling and accumulating a vector with cublas \ref test_math_templateCublas_axpy "[test doc]"
71 *
72 * \ingroup cublas
73 */
74template <typename floatT>
75cublasStatus_t cublasTaxpy( cublasHandle_t handle, ///< [in] handle to the cuBLAS library context.
76 int n, ///< [in] scalar used for multiplication.
77 const floatT *alpha, ///< [in] number of elements in the vector x and y
78 const floatT *x, ///< [in] vector with n elements.
79 int incx, ///< [in] stride between consecutive elements of x
80 floatT *y, ///< [in.out] vector with n elements.
81 int incy ///< [in] stride between consecutive elements of y
82);
83
84//----------------------------------------------------
85// Element-wise (Hadamard) products of vectors
86
87/// Calculates the element-wise product of two vectors, storing the result in the first.
88/** Calculates
89 * \f$
90 * x = x * y
91 * \f$
92 * element by element, a.k.a. the Hadamard product.
93 *
94 * Specializations are provided for:
95 * - float,float
96 * - complex-float, float
97 * - complex-float, complex-float
98 * - double, double
99 * - complex-double, double
100 * - complex-double, complex-double
101 *
102 * Tests:
103 * - Multiplying two vectors element by element \ref test_math_templateCublas_elementwiseXxY "[test doc]"
104 *
105 * \ingroup cublas
106 */
107template <typename dataT1, typename dataT2>
108cudaError_t elementwiseXxY(
109 dataT1 *x, ///< [in.out] device pointer for the 1st vector. Is replaced with the product of the two vectors
110 dataT2 *y, ///< [in] device pointer for the 2nd vector.
111 int size ///< [in] the number of elements in the vectors.
112);
113
114//----------------------------------------------------
115// Tgemv
116
117/// Perform a matrix-vector multiplication.
118/** Implements
119 * \f[
120 * \vec{y} = \alpha \mathbf{A} \vec{x} + \beta \vec{y}
121 * \f]
122 *
123 * Specializations are provided for float, double, complex-float, and complex-double
124 *
125 * \tparam floatT a floating-point type, either float, double, complex-float, or complex-double
126 *
127 * Tests:
128 * - Multiplying a vectory by a matrix \ref test_math_templateCublas_cublasTgemv_inc "[code doc]"
129 *
130 * \ingroup cublas
131 */
132template <typename floatT>
133cublasStatus_t
134cublasTgemv( cublasHandle_t handle, ///< [in] handle to the cuBLAS library context.
135 cublasOperation_t trans, ///< [in] operation on a, CUBLAS_OP_N for none, and CUBLAS_OP_T for transpose
136 int m, ///< [in] rows in matrix A.
137 int n, ///< [in] columns in matrix A.
138 const floatT *alpha, ///< [in] scalar used for multiplication of A
139 const floatT *A, ///< [in] array of dimension lda x n with lda >= max(1,m). The leading m by n part of the
140 ///< array A is multiplied by alpha and x. Unchanged.
141 int lda, ///< [in] leading dimension of A. lda must be at least max(1,m).
142 const floatT *x, ///< [in] vector of at least (1+(n-1)*abs(incx)) elements if transa==CUBLAS_OP_N and at
143 ///< least (1+(m-1)*abs(incx)) elements otherwise.
144 int incx, ///< [in] stride of x.
145 const floatT *
146 beta, ///< [in] scalar used for multiplication of y, if beta==0 then y does not need to be initialized.
147 floatT *y, ///< [in.out] vector of at least (1+(m-1)*abs(incy)) elements if transa==CUBLAS_OP_N and at
148 ///< least (1+(n-1)*abs(incy)) elements otherwise.
149 int incy ///< [in] stride of y
150);
151
152/// Perform a matrix-vector multiplication for stride-less arrays
153/** Implements
154 * \f[
155 * \vec{y} = \alpha \mathbf{A} \vec{x} + \beta \vec{y}
156 * \f]
157 *
158 * Specializations are provided for float, double, complex-float, and complex-double
159 *
160 * \overload
161 * This version assumes stride is 1 in all arrays.
162 *
163 * \tparam floatT a floating-point type, either float, double, complex-float, or complex-double
164 *
165 * \ingroup cublas
166 */
167template <typename floatT>
168cublasStatus_t
169cublasTgemv( cublasHandle_t handle, ///< [in] handle to the cuBLAS library context.
170 cublasOperation_t trans, ///< [in] operation on a, CUBLAS_OP_N for none, and CUBLAS_OP_T for transpose
171 int m, ///< [in] rows in matrix A.
172 int n, ///< [in] columns in matrix A.
173 const floatT *alpha, ///< [in] scalar used for multiplication of A
174 const floatT *A, ///< [in] array of dimension m x n. Unchanged.
175 const floatT *x, ///< [in] vector of at least (1+(n-1)*abs(incx)) elements if transa==CUBLAS_OP_N and at
176 ///< least (1+(m-1)*abs(incx)) elements otherwise.
177 const floatT *
178 beta, ///< [in] scalar used for multiplication of y, if beta==0 then y does not need to be initialized.
179 floatT *y ///< [in.out] vector of at least (1+(m-1)*abs(incy)) elements if transa==CUBLAS_OP_N and at least
180 ///< (1+(n-1)*abs(incy)) elements otherwise.
181);
182
183template <>
184cublasStatus_t cublasTgemv<float>( cublasHandle_t handle,
185 cublasOperation_t trans,
186 int m,
187 int n,
188 const float *alpha,
189 const float *A,
190 int lda,
191 const float *x,
192 int incx,
193 const float *beta,
194 float *y,
195 int incy );
196
197template <>
198cublasStatus_t cublasTgemv<double>( cublasHandle_t handle,
199 cublasOperation_t trans,
200 int m,
201 int n,
202 const double *alpha,
203 const double *A,
204 int lda,
205 const double *x,
206 int incx,
207 const double *beta,
208 double *y,
209 int incy );
210
211template <>
212cublasStatus_t cublasTgemv<float>( cublasHandle_t handle,
213 cublasOperation_t trans,
214 int m,
215 int n,
216 const float *alpha,
217 const float *A,
218 const float *x,
219 const float *beta,
220 float *y );
221
222template <>
223cublasStatus_t cublasTgemv<double>( cublasHandle_t handle,
224 cublasOperation_t trans,
225 int m,
226 int n,
227 const double *alpha,
228 const double *A,
229 const double *x,
230 const double *beta,
231 double *y );
232
233} // namespace cuda
234} // namespace mx
235#endif // math_templateCublas_hpp
cudaError_t elementwiseXxY(dataT1 *x, dataT2 *y, int size)
Calculates the element-wise product of two vectors, storing the result in the first.
cublasStatus_t cublasTaxpy(cublasHandle_t handle, int n, const floatT *alpha, const floatT *x, int incx, floatT *y, int incy)
Multiplies a vector by a scalar, adding it to a second vector which is overwritten by the result.
cublasStatus_t cublasTscal(cublasHandle_t handle, int n, const floatT *alpha, floatT *x, int incx)
Multiplies a vector by a scalar, overwriting the vector with the result.
cublasStatus_t cublasTgemv(cublasHandle_t handle, cublasOperation_t trans, int m, int n, const floatT *alpha, const floatT *A, int lda, const floatT *x, int incx, const floatT *beta, floatT *y, int incy)
Perform a matrix-vector multiplication.
The mxlib c++ namespace.
Definition mxError.hpp:40