Loading [MathJax]/extensions/tex2jax.js
mxlib
C++ tools for analyzing astronomical data and other tasks, by Jared R. Males. [git repo]
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Modules Pages
templateBLAS.hpp
Go to the documentation of this file.
1/** \file templateBLAS.hpp
2 * \brief Declares and defines templatized wrappers for the BLAS
3 * \ingroup gen_math_files
4 * \author Jared R. Males (jaredmales@gmail.com)
5 *
6 */
7
8//***********************************************************************//
9// Copyright 2015, 2016, 2017 Jared R. Males (jaredmales@gmail.com)
10//
11// This file is part of mxlib.
12//
13// mxlib is free software: you can redistribute it and/or modify
14// it under the terms of the GNU General Public License as published by
15// the Free Software Foundation, either version 3 of the License, or
16// (at your option) any later version.
17//
18// mxlib is distributed in the hope that it will be useful,
19// but WITHOUT ANY WARRANTY; without even the implied warranty of
20// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
21// GNU General Public License for more details.
22//
23// You should have received a copy of the GNU General Public License
24// along with mxlib. If not, see <http://www.gnu.org/licenses/>.
25//***********************************************************************//
26
27#include <complex>
28
29extern "C"
30{
31#if defined( MXLIB_MKL )
32
33#include <mkl.h>
34
35#elif defined( MXLIB_OPENBLAS )
36
37#include <cblas.h>
38
39#else
40
41#include <gsl/gsl_cblas.h>
42
43#endif
44}
45
46#ifndef math_templateBLAS_hpp
47#define math_templateBLAS_hpp
48
49namespace mx
50{
51namespace math
52{
53
/// Template wrapper for cblas xSCAL
/** Scales a strided vector in place: X[i*incX] <- alpha*X[i*incX] for 0 <= i < N.
 *
 * This primary template is a portable fallback for any dataT that has no explicit specialization.
 * Specializations for float, double, std::complex<float> and std::complex<double> are declared after
 * it. They are not defined in this header; presumably they call the corresponding cblas_?scal.
 *
 * Following the reference BLAS convention, nothing is done if N <= 0 or incX <= 0.
 *
 * \param N [in] the number of elements to scale
 * \param alpha [in] the scalar multiplier
 * \param X [in.out] the vector to scale in place
 * \param incX [in] in-memory increment (stride) between successive elements of X
 *
 * \tparam dataT the data type of alpha and X; must support `dataT *= dataT`
 *
 * \ingroup template_blas
 */
template <typename dataT>
void scal( const int N, const dataT &alpha, dataT *X, const int incX )
{
    // Same early exit as the reference BLAS xSCAL.
    if( N <= 0 || incX <= 0 )
    {
        return;
    }

    // Take a copy of alpha: it is passed by reference and may alias an element of X
    // (e.g. scal(N, X[0], X, 1)), which would otherwise change partway through the loop.
    const dataT a = alpha;

    for( int i = 0; i < N; ++i )
    {
        // Form the offset in 64-bit arithmetic so i*incX cannot overflow int.
        X[static_cast<long long>( i ) * incX] *= a;
    }
}
66
67template <>
68void scal<float>( const int N, const float &alpha, float *X, const int incX );
69
70template <>
71void scal<double>( const int N, const double &alpha, double *X, const int incX );
72
73template <>
74void scal<std::complex<float>>( const int N, const std::complex<float> &alpha, std::complex<float> *X, const int incX );
75
76template <>
77void scal<std::complex<double>>( const int N,
78 const std::complex<double> &alpha,
79 std::complex<double> *X,
80 const int incX );
81
/// Implementation of the Hadamard (element-wise) product of two vectors
/** Computes the Hadamard or element-wise product in place: X <- X*Y, i.e. X[i] *= Y[i] for 0 <= i < N.
 *
 * Both vectors are traversed contiguously (unit stride). Y and X are __restrict__-qualified, so they
 * must not overlap. No alignment beyond that of dataT is assumed, so pointers into the middle of an
 * array (e.g. x+1 for a float array) are valid inputs.
 *
 * \param N [in] the length of the two vectors; nothing is done if N <= 0
 * \param Y [in] vector to perform element-wise multiplication with (not modified)
 * \param X [in.out] vector which is element-wise multiplied by Y
 *
 * \tparam dataT the data type of X and Y
 *
 * \ingroup template_blas
 */
template <typename dataT>
void hadp_impl( const int N, const dataT *__restrict__ Y, dataT *__restrict__ X )
{
    // Deliberately no __builtin_assume_aligned here. Callers are not required to pass 16-byte-aligned
    // pointers, and a false alignment promise is undefined behavior. The restrict qualifiers plus the
    // simd pragma still let the compiler vectorize, using unaligned loads where needed.
#pragma omp simd
    for( int i = 0; i < N; ++i )
    {
        X[i] *= Y[i];
    }
}
108
109/// Template wrapper for cblas-extension xHADP
110/** Computes the the Hadamard or element-wise product: X <- alpha*X*Y
111 *
112 * \param N [in] the length of the two vectors
113 * \param alpha [in] scalar to multiply each element by
114 * \param Y [in] vector to perform element-wise multiplication with
115 * \param incY [in] in-memory increment or stride for Y
116 * \param X [in.out] vector which is multiplied by alpha and element-wise multiplied by Y
117 * \param incX [in] in-memeory increment or stride for X
118 *
119 * \tparam dataT the data type of the alpha, X, and Y
120 *
121 * \ingroup template_blas
122 */
123template <typename dataT>
124void hadp( const int N, dataT *Y, dataT *X )
125{
126 hadp_impl( N, Y, X );
127}
128
/// Implementation of the Hadamard (element-wise) division of two vectors
/** Computes the Hadamard or element-wise quotient in place: X <- alpha*X/Y, i.e.
 * X[i*incX] = alpha * X[i*incX] / Y[i*incY] for 0 <= i < N.
 *
 * When compiled with OpenMP the loop is parallelized. If incX == 0, every iteration updates the same
 * element X[0], so the loop then runs serially to avoid a data race.
 *
 * \param N [in] the length of the two vectors; nothing is done if N <= 0
 * \param alpha [in] scalar to multiply each element by
 * \param Y [in] vector to perform element-wise division with
 * \param incY [in] in-memory increment or stride for Y
 * \param X [in.out] vector which is multiplied by alpha and element-wise divided by Y
 * \param incX [in] in-memory increment or stride for X
 *
 * \tparam dataT the data type of alpha, X, and Y
 *
 * \ingroup template_blas
 */
template <typename dataT>
void hadd_impl( const int N, const dataT alpha, const dataT *Y, const int incY, dataT *X, const int incX )
{
    // With a zero X stride all iterations write X[0]; parallelizing that would be a data race.
#pragma omp parallel for if( incX != 0 )
    for( int i = 0; i < N; ++i )
    {
        // 64-bit offsets: i*incX can exceed INT_MAX for large strided vectors, and signed int
        // overflow is undefined behavior.
        const long long ix = static_cast<long long>( i ) * incX;
        const long long iy = static_cast<long long>( i ) * incY;

        X[ix] = alpha * X[ix] / Y[iy];
    }
}
152
153/// Template wrapper for cblas-extension xHADD
154/** Computes the the Hadamard or element-wise product: X <- alpha*X/Y
155 *
156 * \param N [in] the length of the two vectors
157 * \param alpha [in] scalar to multiply each element by
158 * \param Y [in] vector to perform element-wise division with
159 * \param incY [in] in-memory increment or stride for Y
160 * \param X [in.out] vector which is multiplied by alpha and element-wise divided by Y
161 * \param incX [in] in-memeory increment or stride for X
162 *
163 * \tparam dataT the data type of the alpha, X, and Y
164 *
165 * \ingroup template_blas
166 */
167template <typename dataT>
168void hadd( const int N, const dataT alpha, const dataT *Y, const int incY, dataT *X, const int incX )
169{
170 hadd_impl( N, alpha, Y, incY, X, incX );
171}
172
173/// Template Wrapper for cblas xGEMM
174/**
175 *
 * General matrix-matrix multiply. Per the BLAS definition, xGEMM computes
 * C <- alpha*op(A)*op(B) + beta*C, where op(.) is the identity, the transpose, or the conjugate
 * transpose, as selected by TransA / TransB.
 *
 * This primary template is a no-op placeholder. It returns immediately without reading A or B and
 * without writing C; the compile-time error that would flag an unsupported dataT is commented out.
 * Explicit specializations for float, double, std::complex<float> and std::complex<double> are
 * declared after it. Their definitions are not in this header; presumably they forward to the
 * matching cblas_?gemm (TODO confirm against the implementation file).
 *
 * NOTE(review): this rendering is missing original lines 179-181, which hold the function name and
 * the leading parameters. Per the gemm entry in the member index at the end of this page, the full
 * declaration is:
 *   void gemm( const CBLAS_ORDER Order, const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB,
 *              const int M, const int N, const int K, const dataT &alpha, const dataT *A,
 *              const int lda, const dataT *B, const int ldb, const dataT &beta, dataT *C,
 *              const int ldc )
 *
176 * \ingroup template_blas
177 */
178template <typename dataT>
182 const int M,
183 const int N,
184 const int K,
185 const dataT &alpha,
186 const dataT *A,
187 const int lda,
188 const dataT *B,
189 const int ldb,
190 const dataT &beta,
191 dataT *C,
192 const int ldc )
193{
194 // static_assert(0, "templateBLAS: no gemm wrapper defined for type dataT");
195 return; // No BLAS for this type -- C is left untouched.
196}
197
// Single-precision real specialization of gemm (declared here; not defined in this header).
// NOTE(review): original lines 200-201 are missing from this rendering; per the generic gemm
// signature they are presumably the TransA and TransB parameters.
198template <>
199void gemm<float>( const CBLAS_ORDER Order,
202 const int M,
203 const int N,
204 const int K,
205 const float &alpha,
206 const float *A,
207 const int lda,
208 const float *B,
209 const int ldb,
210 const float &beta,
211 float *C,
212 const int ldc );
213
// Double-precision real specialization of gemm (declared here; not defined in this header).
// NOTE(review): original lines 216-217 (presumably TransA and TransB) are missing from this rendering.
214template <>
215void gemm<double>( const CBLAS_ORDER Order,
218 const int M,
219 const int N,
220 const int K,
221 const double &alpha,
222 const double *A,
223 const int lda,
224 const double *B,
225 const int ldb,
226 const double &beta,
227 double *C,
228 const int ldc );
229
// Single-precision complex specialization of gemm (declared here; not defined in this header).
// NOTE(review): original lines 231-233 are missing from this rendering; presumably they are
// "void gemm<std::complex<float>>( const CBLAS_ORDER Order," followed by TransA and TransB.
230template <>
234 const int M,
235 const int N,
236 const int K,
237 const std::complex<float> &alpha,
238 const std::complex<float> *A,
239 const int lda,
240 const std::complex<float> *B,
241 const int ldb,
242 const std::complex<float> &beta,
243 std::complex<float> *C,
244 const int ldc );
245
// Double-precision complex specialization of gemm (declared here; not defined in this header).
// NOTE(review): original lines 247-249 are missing from this rendering; presumably they are
// "void gemm<std::complex<double>>( const CBLAS_ORDER Order," followed by TransA and TransB.
246template <>
250 const int M,
251 const int N,
252 const int K,
253 const std::complex<double> &alpha,
254 const std::complex<double> *A,
255 const int lda,
256 const std::complex<double> *B,
257 const int ldb,
258 const std::complex<double> &beta,
259 std::complex<double> *C,
260 const int ldc );
261
262/// Template Wrapper for cblas xSYRK
263/**
264 *
 * Symmetric rank-k update. Per the BLAS definition, xSYRK computes C <- alpha*A*A^T + beta*C or
 * C <- alpha*A^T*A + beta*C (selected by Trans). Only the triangle of C chosen by Uplo is updated.
 *
 * This primary template is a no-op placeholder. It returns immediately without reading A and
 * without writing C; the compile-time error that would flag an unsupported dataT is commented out.
 * Explicit specializations for float, double, std::complex<float> and std::complex<double> are
 * declared after it and are not defined in this header.
 *
 * The block comment directly after the template header is an older form of this signature
 * (alpha and beta by value, C-style "enum" qualifiers). It is not compiled.
 *
 * NOTE(review): this rendering is missing original lines 272 and 274. Per the syrk entry in the
 * member index at the end of this page, the full declaration is:
 *   void syrk( const CBLAS_ORDER Order, const CBLAS_UPLO Uplo, const CBLAS_TRANSPOSE Trans,
 *              const int N, const int K, const dataT &alpha, const dataT *A, const int lda,
 *              const dataT &beta, dataT *C, const int ldc )
 *
265 * \ingroup template_blas
266 */
267template <typename dataT>
268/*void syrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
269 const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
270 const dataT alpha, const dataT *A, const int lda,
271 const dataT beta, dataT *C, const int ldc)*/
273 const CBLAS_UPLO Uplo,
275 const int N,
276 const int K,
277 const dataT &alpha,
278 const dataT *A,
279 const int lda,
280 const dataT &beta,
281 dataT *C,
282 const int ldc )
283{
284 // static_assert(0, "templateBLAS: no syrk wrapper defined for type dataT");
285 return; // No BLAS for this type -- C is left untouched.
286}
287
// Single-precision real specialization of syrk (declared here; not defined in this header).
// NOTE(review): original line 291 (presumably the Trans parameter) is missing from this rendering.
288template <>
289void syrk<float>( const CBLAS_ORDER Order,
290 const CBLAS_UPLO Uplo,
292 const int N,
293 const int K,
294 const float &alpha,
295 const float *A,
296 const int lda,
297 const float &beta,
298 float *C,
299 const int ldc );
300
// Double-precision real specialization of syrk (declared here; not defined in this header).
// NOTE(review): original line 304 (presumably the Trans parameter) is missing from this rendering.
301template <>
302void syrk<double>( const CBLAS_ORDER Order,
303 const CBLAS_UPLO Uplo,
305 const int N,
306 const int K,
307 const double &alpha,
308 const double *A,
309 const int lda,
310 const double &beta,
311 double *C,
312 const int ldc );
313
// Single-precision complex specialization of syrk (declared here; not defined in this header).
// NOTE(review): original lines 315 and 317 are missing from this rendering; presumably they are
// "void syrk<std::complex<float>>( const CBLAS_ORDER Order," and the Trans parameter.
314template <>
316 const CBLAS_UPLO Uplo,
318 const int N,
319 const int K,
320 const std::complex<float> &alpha,
321 const std::complex<float> *A,
322 const int lda,
323 const std::complex<float> &beta,
324 std::complex<float> *C,
325 const int ldc );
326
// Double-precision complex specialization of syrk (declared here; not defined in this header).
// NOTE(review): original lines 328 and 330 are missing from this rendering; presumably they are
// "void syrk<std::complex<double>>( const CBLAS_ORDER Order," and the Trans parameter.
327template <>
329 const CBLAS_UPLO Uplo,
331 const int N,
332 const int K,
333 const std::complex<double> &alpha,
334 const std::complex<double> *A,
335 const int lda,
336 const std::complex<double> &beta,
337 std::complex<double> *C,
338 const int ldc );
339
340} // namespace math
341} // namespace mx
342
343#endif // math_templateBLAS_hpp
constexpr floatT six_fifths()
Return 6/5 in the specified precision.
void hadp(const int N, dataT *Y, dataT *X)
Template wrapper for cblas-extension xHADP.
void syrk(const CBLAS_ORDER Order, const CBLAS_UPLO Uplo, const CBLAS_TRANSPOSE Trans, const int N, const int K, const dataT &alpha, const dataT *A, const int lda, const dataT &beta, dataT *C, const int ldc)
Template Wrapper for cblas xSYRK.
void scal(const int N, const dataT &alpha, dataT *X, const int incX)
Template wrapper for cblas xSCAL.
void hadd(const int N, const dataT alpha, const dataT *Y, const int incY, dataT *X, const int incX)
Template wrapper for cblas-extension xHADD.
void hadp_impl(const int N, dataT *__restrict__ Y, dataT *__restrict__ X)
Implementation of the Hadamard (element-wise) product of two vectors.
void hadd_impl(const int N, const dataT alpha, const dataT *Y, const int incY, dataT *X, const int incX)
Implementation of the Hadamard (element-wise) division of two vectors.
void gemm(const CBLAS_ORDER Order, const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, const int M, const int N, const int K, const dataT &alpha, const dataT *A, const int lda, const dataT *B, const int ldb, const dataT &beta, dataT *C, const int ldc)
Template Wrapper for cblas xGEMM.
The mxlib C++ namespace.
Definition mxError.hpp:106