mxlib
c++ tools for analyzing astronomical data and other tasks by Jared R. Males. [git repo]
Loading...
Searching...
No Matches
templateCublas_test.cpp
Go to the documentation of this file.
1/** \file templateCublas_test.cpp
2 */
#include <complex>
#include <vector>

#include "../../../catch2/catch.hpp"

#include "../../../../include/math/cuda/cudaPtr.hpp"
#include "../../../../include/math/cuda/templateCublas.hpp"
7
8/** Scenario: scaling a vector with cublas
9 * Tests cublasTscal, as well as basic cudaPtr operations.
10 *
11 * \anchor test_math_templateCublas_scal
12 */
13SCENARIO( "scaling a vector with cublas", "[math::cuda::templateCublas]" )
14{
15 GIVEN( "a vector" )
16 {
17 WHEN( "type is single precision real" )
18 {
19 std::vector<float> hx;
21 float alpha = 2;
22
23 hx.resize( 5 );
24 for( size_t n = 0; n < hx.size(); ++n )
25 hx[n] = n;
26
27 dx.upload( hx.data(), hx.size() );
28 REQUIRE( dx.size() == hx.size() );
29
30 cublasHandle_t handle;
31 cublasStatus_t stat;
32 stat = cublasCreate( &handle );
33 REQUIRE( stat == CUBLAS_STATUS_SUCCESS );
34
35 stat = mx::cuda::cublasTscal( handle, dx.size(), &alpha, dx(), 1 );
36 REQUIRE( stat == CUBLAS_STATUS_SUCCESS );
37
38 dx.download( hx.data() );
39
40 REQUIRE( hx[0] == 0 );
41 REQUIRE( hx[1] == 2 );
42 REQUIRE( hx[2] == 4 );
43 REQUIRE( hx[3] == 6 );
44 REQUIRE( hx[4] == 8 );
45
46 stat = cublasDestroy( handle );
47
48 REQUIRE( stat == CUBLAS_STATUS_SUCCESS );
49 }
50
51 WHEN( "type is double precision real" )
52 {
53 std::vector<double> hx;
55 double alpha = 2;
56
57 hx.resize( 5 );
58 for( size_t n = 0; n < hx.size(); ++n )
59 hx[n] = n;
60
61 dx.upload( hx.data(), hx.size() );
62 REQUIRE( dx.size() == hx.size() );
63
64 cublasHandle_t handle;
65 cublasStatus_t stat;
66 stat = cublasCreate( &handle );
67 REQUIRE( stat == CUBLAS_STATUS_SUCCESS );
68
69 stat = mx::cuda::cublasTscal( handle, dx.size(), &alpha, dx(), 1 );
70 REQUIRE( stat == CUBLAS_STATUS_SUCCESS );
71
72 dx.download( hx.data() );
73
74 REQUIRE( hx[0] == 0 );
75 REQUIRE( hx[1] == 2 );
76 REQUIRE( hx[2] == 4 );
77 REQUIRE( hx[3] == 6 );
78 REQUIRE( hx[4] == 8 );
79
80 stat = cublasDestroy( handle );
81
82 REQUIRE( stat == CUBLAS_STATUS_SUCCESS );
83 }
84
85 WHEN( "type is single precision complex" )
86 {
87 std::vector<std::complex<float>> hx;
89 std::complex<float> alpha = 2;
90
91 hx.resize( 5 );
92 for( size_t n = 0; n < hx.size(); ++n )
93 hx[n] = std::complex<float>( n, n );
94
95 dx.upload( hx.data(), hx.size() );
96 REQUIRE( dx.size() == hx.size() );
97
98 cublasHandle_t handle;
99 cublasStatus_t stat;
100 stat = cublasCreate( &handle );
101 REQUIRE( stat == CUBLAS_STATUS_SUCCESS );
102
103 stat = mx::cuda::cublasTscal( handle, dx.size(), (cuComplex *)&alpha, dx(), 1 );
104 REQUIRE( stat == CUBLAS_STATUS_SUCCESS );
105
106 dx.download( hx.data() );
107
108 REQUIRE( hx[0] == std::complex<float>( 0, 0 ) );
109 REQUIRE( hx[1] == std::complex<float>( 2, 2 ) );
110 REQUIRE( hx[2] == std::complex<float>( 4, 4 ) );
111 REQUIRE( hx[3] == std::complex<float>( 6, 6 ) );
112 REQUIRE( hx[4] == std::complex<float>( 8, 8 ) );
113
114 stat = cublasDestroy( handle );
115
116 REQUIRE( stat == CUBLAS_STATUS_SUCCESS );
117 }
118
119 WHEN( "type is double precision complex" )
120 {
121 std::vector<std::complex<double>> hx;
123 std::complex<double> alpha = 2;
124
125 hx.resize( 5 );
126 for( size_t n = 0; n < hx.size(); ++n )
127 hx[n] = std::complex<double>( n, n );
128
129 dx.upload( hx.data(), hx.size() );
130 REQUIRE( dx.size() == hx.size() );
131
132 cublasHandle_t handle;
133 cublasStatus_t stat;
134 stat = cublasCreate( &handle );
135 REQUIRE( stat == CUBLAS_STATUS_SUCCESS );
136
137 stat = mx::cuda::cublasTscal( handle, dx.size(), (cuDoubleComplex *)&alpha, dx(), 1 );
138 REQUIRE( stat == CUBLAS_STATUS_SUCCESS );
139
140 dx.download( hx.data() );
141
142 REQUIRE( hx[0] == std::complex<double>( 0, 0 ) );
143 REQUIRE( hx[1] == std::complex<double>( 2, 2 ) );
144 REQUIRE( hx[2] == std::complex<double>( 4, 4 ) );
145 REQUIRE( hx[3] == std::complex<double>( 6, 6 ) );
146 REQUIRE( hx[4] == std::complex<double>( 8, 8 ) );
147
148 stat = cublasDestroy( handle );
149
150 REQUIRE( stat == CUBLAS_STATUS_SUCCESS );
151 }
152 }
153}
154
155/** Scenario: scaling and accumulating a vector with cublas
156 * Tests cublasTaxpy, as well as basic cudaPtr operations.
157 *
158 * \anchor test_math_templateCublas_axpy
159 */
160SCENARIO( "scaling and accumulating a vector with cublas", "[math::cuda::templateCublas]" )
161{
162 GIVEN( "a vector" )
163 {
164 WHEN( "type is single precision real" )
165 {
166 std::vector<float> hx, hy;
168 float alpha = 2;
169
170 hx.resize( 5 );
171 for( size_t n = 0; n < hx.size(); ++n )
172 hx[n] = n;
173
174 hy.resize( 5 );
175 for( size_t n = 0; n < hy.size(); ++n )
176 hy[n] = 1;
177
178 dx.upload( hx.data(), hx.size() );
179 REQUIRE( dx.size() == hx.size() );
180
181 dy.upload( hy.data(), hy.size() );
182 REQUIRE( dy.size() == hy.size() );
183
184 cublasHandle_t handle;
185 cublasStatus_t stat;
186 stat = cublasCreate( &handle );
187 REQUIRE( stat == CUBLAS_STATUS_SUCCESS );
188
189 stat = mx::cuda::cublasTaxpy( handle, dx.size(), &alpha, dx(), 1, dy(), 1 );
190 REQUIRE( stat == CUBLAS_STATUS_SUCCESS );
191
192 dy.download( hy.data() );
193
194 REQUIRE( hy[0] == 1 );
195 REQUIRE( hy[1] == 3 );
196 REQUIRE( hy[2] == 5 );
197 REQUIRE( hy[3] == 7 );
198 REQUIRE( hy[4] == 9 );
199
200 stat = cublasDestroy( handle );
201
202 REQUIRE( stat == CUBLAS_STATUS_SUCCESS );
203 }
204
205 WHEN( "type is double precision real" )
206 {
207 std::vector<double> hx, hy;
209 double alpha = 2;
210
211 hx.resize( 5 );
212 for( size_t n = 0; n < hx.size(); ++n )
213 hx[n] = n;
214
215 hy.resize( 5 );
216 for( size_t n = 0; n < hy.size(); ++n )
217 hy[n] = 1;
218
219 dx.upload( hx.data(), hx.size() );
220 REQUIRE( dx.size() == hx.size() );
221
222 dy.upload( hy.data(), hy.size() );
223 REQUIRE( dy.size() == hy.size() );
224
225 cublasHandle_t handle;
226 cublasStatus_t stat;
227 stat = cublasCreate( &handle );
228 REQUIRE( stat == CUBLAS_STATUS_SUCCESS );
229
230 stat = mx::cuda::cublasTaxpy( handle, dx.size(), &alpha, dx(), 1, dy(), 1 );
231 REQUIRE( stat == CUBLAS_STATUS_SUCCESS );
232
233 dy.download( hy.data() );
234
235 REQUIRE( hy[0] == 1 );
236 REQUIRE( hy[1] == 3 );
237 REQUIRE( hy[2] == 5 );
238 REQUIRE( hy[3] == 7 );
239 REQUIRE( hy[4] == 9 );
240
241 stat = cublasDestroy( handle );
242
243 REQUIRE( stat == CUBLAS_STATUS_SUCCESS );
244 }
245
246 WHEN( "type is single precision complex" )
247 {
248 std::vector<std::complex<float>> hx, hy;
250 std::complex<float> alpha = 2;
251
252 hx.resize( 5 );
253 for( size_t n = 0; n < hx.size(); ++n )
254 hx[n] = std::complex<float>( n, n );
255
256 hy.resize( 5 );
257 for( size_t n = 0; n < hy.size(); ++n )
258 hy[n] = std::complex<float>( 1, 1 );
259
260 dx.upload( hx.data(), hx.size() );
261 REQUIRE( dx.size() == hx.size() );
262
263 dy.upload( hy.data(), hy.size() );
264 REQUIRE( dy.size() == hy.size() );
265
266 cublasHandle_t handle;
267 cublasStatus_t stat;
268 stat = cublasCreate( &handle );
269 REQUIRE( stat == CUBLAS_STATUS_SUCCESS );
270
271 stat = mx::cuda::cublasTaxpy( handle, dx.size(), (cuComplex *)&alpha, dx(), 1, dy(), 1 );
272 REQUIRE( stat == CUBLAS_STATUS_SUCCESS );
273
274 dy.download( hy.data() );
275
276 REQUIRE( hy[0] == std::complex<float>( 1, 1 ) );
277 REQUIRE( hy[1] == std::complex<float>( 3, 3 ) );
278 REQUIRE( hy[2] == std::complex<float>( 5, 5 ) );
279 REQUIRE( hy[3] == std::complex<float>( 7, 7 ) );
280 REQUIRE( hy[4] == std::complex<float>( 9, 9 ) );
281
282 stat = cublasDestroy( handle );
283
284 REQUIRE( stat == CUBLAS_STATUS_SUCCESS );
285 }
286
287 WHEN( "type is double precision complex" )
288 {
289 std::vector<std::complex<double>> hx, hy;
291 std::complex<double> alpha = 2;
292
293 hx.resize( 5 );
294 for( size_t n = 0; n < hx.size(); ++n )
295 hx[n] = std::complex<double>( n, n );
296
297 hy.resize( 5 );
298 for( size_t n = 0; n < hy.size(); ++n )
299 hy[n] = std::complex<double>( 1, 1 );
300
301 dx.upload( hx.data(), hx.size() );
302 REQUIRE( dx.size() == hx.size() );
303
304 dy.upload( hy.data(), hy.size() );
305 REQUIRE( dy.size() == hy.size() );
306
307 cublasHandle_t handle;
308 cublasStatus_t stat;
309 stat = cublasCreate( &handle );
310 REQUIRE( stat == CUBLAS_STATUS_SUCCESS );
311
312 stat = mx::cuda::cublasTaxpy( handle, dx.size(), (cuDoubleComplex *)&alpha, dx(), 1, dy(), 1 );
313 REQUIRE( stat == CUBLAS_STATUS_SUCCESS );
314
315 dy.download( hy.data() );
316
317 REQUIRE( hy[0] == std::complex<double>( 1, 1 ) );
318 REQUIRE( hy[1] == std::complex<double>( 3, 3 ) );
319 REQUIRE( hy[2] == std::complex<double>( 5, 5 ) );
320 REQUIRE( hy[3] == std::complex<double>( 7, 7 ) );
321 REQUIRE( hy[4] == std::complex<double>( 9, 9 ) );
322
323 stat = cublasDestroy( handle );
324
325 REQUIRE( stat == CUBLAS_STATUS_SUCCESS );
326 }
327 }
328}
329
330/** Scenario: multiplying two vectors element by element
331 * Tests mx::cuda::elementwiseXxY, as well as basic cudaPtr operations.
332 *
333 * \anchor test_math_templateCublas_elementwiseXxY
334 */
335SCENARIO( "multiplying two vectors element by element", "[math::cuda::templateCublas]" )
336{
337 GIVEN( "a vector" )
338 {
339 WHEN( "both types are single precision real" )
340 {
341 std::vector<float> hx, hy;
343
344 hx.resize( 5 );
345 for( size_t n = 0; n < hx.size(); ++n )
346 hx[n] = n;
347
348 hy.resize( 5 );
349 for( size_t n = 0; n < hy.size(); ++n )
350 hy[n] = 2 * n;
351
352 dx.upload( hx.data(), hx.size() );
353 REQUIRE( dx.size() == hx.size() );
354
355 dy.resize( hy.size() );
356 dy.upload( hy.data() );
357 REQUIRE( dy.size() == hy.size() );
358
359 cudaError_t rv = mx::cuda::elementwiseXxY( dx(), dy(), dx.size() );
360 REQUIRE( rv == cudaSuccess );
361
362 dx.download( hx.data() );
363
364 REQUIRE( hx[0] == 0 );
365 REQUIRE( hx[1] == 2 );
366 REQUIRE( hx[2] == 8 );
367 REQUIRE( hx[3] == 18 );
368 REQUIRE( hx[4] == 32 );
369 }
370
371 WHEN( "type1 is complex-float, and type2 is float" )
372 {
373 std::vector<std::complex<float>> hx;
375
376 std::vector<float> hy;
378
379 hx.resize( 5 );
380 for( size_t n = 0; n < hx.size(); ++n )
381 hx[n] = std::complex<float>( n, n );
382
383 hy.resize( 5 );
384 for( size_t n = 0; n < hy.size(); ++n )
385 hy[n] = 2 * n;
386
387 dx.upload( hx.data(), hx.size() );
388 REQUIRE( dx.size() == hx.size() );
389
390 dy.resize( hy.size() );
391 dy.upload( hy.data() );
392 REQUIRE( dy.size() == hy.size() );
393
394 cudaError_t rv = mx::cuda::elementwiseXxY( dx(), dy(), dx.size() );
395 REQUIRE( rv == cudaSuccess );
396
397 dx.download( hx.data() );
398
399 REQUIRE( hx[0] == std::complex<float>( 0, 0 ) );
400 REQUIRE( hx[1] == std::complex<float>( 2, 2 ) );
401 REQUIRE( hx[2] == std::complex<float>( 8, 8 ) );
402 REQUIRE( hx[3] == std::complex<float>( 18, 18 ) );
403 REQUIRE( hx[4] == std::complex<float>( 32, 32 ) );
404 }
405
406 WHEN( "type1 is complex-float, and type2 is complex-float" )
407 {
408 std::vector<std::complex<float>> hx;
410
411 std::vector<std::complex<float>> hy;
413
414 hx.resize( 5 );
415 for( size_t n = 0; n < hx.size(); ++n )
416 hx[n] = std::complex<float>( n, n );
417
418 hy.resize( 5 );
419 for( size_t n = 0; n < hy.size(); ++n )
420 hy[n] = std::complex<float>( 0, 2 * n );
421
422 dx.upload( hx.data(), hx.size() );
423 REQUIRE( dx.size() == hx.size() );
424
425 dy.resize( hy.size() );
426 dy.upload( hy.data() );
427 REQUIRE( dy.size() == hy.size() );
428
429 cudaError_t rv = mx::cuda::elementwiseXxY( dx(), dy(), dx.size() );
430 REQUIRE( rv == cudaSuccess );
431
432 dx.download( hx.data() );
433
434 REQUIRE( hx[0] == std::complex<float>( 0, 0 ) );
435 REQUIRE( hx[1] == std::complex<float>( -2, 2 ) ); //(1,1) * (0,2) = (0 + 2i + 0i -2) = (-2,2)
436 REQUIRE( hx[2] == std::complex<float>( -8, 8 ) );
437 REQUIRE( hx[3] == std::complex<float>( -18, 18 ) );
438 REQUIRE( hx[4] == std::complex<float>( -32, 32 ) );
439 }
440
441 WHEN( "both types are double precision real" )
442 {
443 std::vector<double> hx, hy;
445
446 hx.resize( 5 );
447 for( size_t n = 0; n < hx.size(); ++n )
448 hx[n] = n;
449
450 hy.resize( 5 );
451 for( size_t n = 0; n < hy.size(); ++n )
452 hy[n] = 2 * n;
453
454 dx.upload( hx.data(), hx.size() );
455 REQUIRE( dx.size() == hx.size() );
456
457 dy.resize( hy.size() );
458 dy.upload( hy.data() );
459 REQUIRE( dy.size() == hy.size() );
460
461 cudaError_t rv = mx::cuda::elementwiseXxY( dx(), dy(), dx.size() );
462 REQUIRE( rv == cudaSuccess );
463
464 dx.download( hx.data() );
465
466 REQUIRE( hx[0] == 0 );
467 REQUIRE( hx[1] == 2 );
468 REQUIRE( hx[2] == 8 );
469 REQUIRE( hx[3] == 18 );
470 REQUIRE( hx[4] == 32 );
471 }
472
473 WHEN( "type1 is complex-double, and type2 is double" )
474 {
475 std::vector<std::complex<double>> hx;
477
478 std::vector<double> hy;
480
481 hx.resize( 5 );
482 for( size_t n = 0; n < hx.size(); ++n )
483 hx[n] = std::complex<double>( n, n );
484
485 hy.resize( 5 );
486 for( size_t n = 0; n < hy.size(); ++n )
487 hy[n] = 2 * n;
488
489 dx.upload( hx.data(), hx.size() );
490 REQUIRE( dx.size() == hx.size() );
491
492 dy.resize( hy.size() );
493 dy.upload( hy.data() );
494 REQUIRE( dy.size() == hy.size() );
495
496 cudaError_t rv = mx::cuda::elementwiseXxY( dx(), dy(), dx.size() );
497 REQUIRE( rv == cudaSuccess );
498
499 dx.download( hx.data() );
500
501 REQUIRE( hx[0] == std::complex<double>( 0, 0 ) );
502 REQUIRE( hx[1] == std::complex<double>( 2, 2 ) );
503 REQUIRE( hx[2] == std::complex<double>( 8, 8 ) );
504 REQUIRE( hx[3] == std::complex<double>( 18, 18 ) );
505 REQUIRE( hx[4] == std::complex<double>( 32, 32 ) );
506 }
507
508 WHEN( "type1 is complex-double, and type2 is complex-double" )
509 {
510 std::vector<std::complex<double>> hx;
512
513 std::vector<std::complex<double>> hy;
515
516 hx.resize( 5 );
517 for( size_t n = 0; n < hx.size(); ++n )
518 hx[n] = std::complex<double>( n, n );
519
520 hy.resize( 5 );
521 for( size_t n = 0; n < hy.size(); ++n )
522 hy[n] = std::complex<double>( 1, 2 * n );
523
524 dx.upload( hx.data(), hx.size() );
525 REQUIRE( dx.size() == hx.size() );
526
527 dy.resize( hy.size() );
528 dy.upload( hy.data() );
529 REQUIRE( dy.size() == hy.size() );
530
531 cudaError_t rv = mx::cuda::elementwiseXxY( dx(), dy(), dx.size() );
532 REQUIRE( rv == cudaSuccess );
533
534 dx.download( hx.data() );
535
536 REQUIRE( hx[0] == std::complex<double>( 0, 0 ) ); //(0,0) * (0,0) = (0,0)
537 REQUIRE( hx[1] == std::complex<double>( -1, 3 ) ); //(1,1) * (1,2) = (1 + 2i + i -2) = (-1,3)
538 REQUIRE( hx[2] == std::complex<double>( -6, 10 ) );
539 REQUIRE( hx[3] == std::complex<double>( -15, 21 ) );
540 REQUIRE( hx[4] == std::complex<double>( -28, 36 ) );
541 }
542 }
543}
544
545/** Scenario: multiplying a vector by a matrix
546 * Tests mx::cuda::cublasTgemv, as well as basic cudaPtr operations.
547 *
548 * \anchor test_math_templateCublas_cublasTgemv_inc
549 */
550SCENARIO( "multiplying a vector by a matrix giving increments", "[math::cuda::templateCublas]" )
551{
552 GIVEN( "a 2x2 matrix, float" )
553 {
554 WHEN( "float precision, beta is 0" )
555 {
556 std::vector<float> hA; // This will actually be a vector
558
559 std::vector<float> hx; // This will actually be a vector
561
562 std::vector<float> hy;
564
565 /* Column major order:
566 1 3
567 2 4
568 */
569 hA.resize( 4 );
570 hA[0] = 1;
571 hA[1] = 2;
572 hA[2] = 3;
573 hA[3] = 4;
574
575 dA.resize( 2, 2 );
576 dA.upload( hA.data() );
577
578 hx.resize( 2 );
579 hx[0] = 1;
580 hx[1] = 2;
581
582 dx.upload( hx.data(), hx.size() );
583
584 hy.resize( 2 );
585
586 dy.resize( 2 );
587 dy.initialize();
588
589 float alpha = 1;
590 float beta = 0;
591
592 cublasHandle_t handle;
593 cublasStatus_t stat;
594 stat = cublasCreate( &handle );
595 REQUIRE( stat == CUBLAS_STATUS_SUCCESS );
596
597 stat = mx::cuda::cublasTgemv( handle, CUBLAS_OP_N, 2, 2, &alpha, dA(), 2, dx(), 1, &beta, dy(), 1 );
598 REQUIRE( stat == CUBLAS_STATUS_SUCCESS );
599
600 dy.download( hy.data() );
601
602 REQUIRE( hy[0] == 7 );
603 REQUIRE( hy[1] == 10 );
604
605 stat = cublasDestroy( handle );
606 REQUIRE( stat == CUBLAS_STATUS_SUCCESS );
607 }
608
609 WHEN( "float precision, beta is 1, but y is all 0" )
610 {
611 std::vector<float> hA; // This will actually be a vector
613
614 std::vector<float> hx; // This will actually be a vector
616
617 std::vector<float> hy;
619
620 /* Column major order:
621 1 3
622 2 4
623 */
624 hA.resize( 4 );
625 hA[0] = 1;
626 hA[1] = 2;
627 hA[2] = 3;
628 hA[3] = 4;
629
630 dA.resize( 2, 2 );
631 dA.upload( hA.data() );
632
633 hx.resize( 2 );
634 hx[0] = 1;
635 hx[1] = 2;
636
637 dx.upload( hx.data(), hx.size() );
638
639 hy.resize( 2 );
640
641 dy.resize( 2 );
642 dy.initialize();
643
644 float alpha = 1;
645 float beta = 1;
646
647 cublasHandle_t handle;
648 cublasStatus_t stat;
649 stat = cublasCreate( &handle );
650 REQUIRE( stat == CUBLAS_STATUS_SUCCESS );
651
652 stat = mx::cuda::cublasTgemv( handle, CUBLAS_OP_N, 2, 2, &alpha, dA(), 2, dx(), 1, &beta, dy(), 1 );
653 REQUIRE( stat == CUBLAS_STATUS_SUCCESS );
654
655 dy.download( hy.data() );
656
657 REQUIRE( hy[0] == 7 );
658 REQUIRE( hy[1] == 10 );
659
660 stat = cublasDestroy( handle );
661 REQUIRE( stat == CUBLAS_STATUS_SUCCESS );
662 }
663 WHEN( "float precision, beta is 1, y is [1,2]" )
664 {
665 std::vector<float> hA; // This will actually be a vector
667
668 std::vector<float> hx; // This will actually be a vector
670
671 std::vector<float> hy;
673
674 /* Column major order:
675 1 3
676 2 4
677 */
678 hA.resize( 4 );
679 hA[0] = 1;
680 hA[1] = 2;
681 hA[2] = 3;
682 hA[3] = 4;
683
684 dA.resize( 2, 2 );
685 dA.upload( hA.data() );
686
687 hx.resize( 2 );
688 hx[0] = 1;
689 hx[1] = 2;
690
691 dx.upload( hx.data(), hx.size() );
692
693 hy.resize( 2 );
694 hy[0] = 1;
695 hy[1] = 2;
696
697 dy.resize( 2 );
698 dy.upload( hx.data() );
699
700 float alpha = 1;
701 float beta = 1;
702
703 cublasHandle_t handle;
704 cublasStatus_t stat;
705 stat = cublasCreate( &handle );
706 REQUIRE( stat == CUBLAS_STATUS_SUCCESS );
707
708 stat = mx::cuda::cublasTgemv( handle, CUBLAS_OP_N, 2, 2, &alpha, dA(), 2, dx(), 1, &beta, dy(), 1 );
709 REQUIRE( stat == CUBLAS_STATUS_SUCCESS );
710
711 dy.download( hy.data() );
712
713 REQUIRE( hy[0] == 8 );
714 REQUIRE( hy[1] == 12 );
715
716 stat = cublasDestroy( handle );
717 REQUIRE( stat == CUBLAS_STATUS_SUCCESS );
718 }
719 }
720
721 GIVEN( "a 2x2 matrix, double" )
722 {
723 WHEN( "double precision, beta is 0" )
724 {
725 std::vector<double> hA; // This will actually be a vector
727
728 std::vector<double> hx; // This will actually be a vector
730
731 std::vector<double> hy;
733
734 /* Column major order:
735 1 3
736 2 4
737 */
738 hA.resize( 4 );
739 hA[0] = 1;
740 hA[1] = 2;
741 hA[2] = 3;
742 hA[3] = 4;
743
744 dA.resize( 2, 2 );
745 dA.upload( hA.data() );
746
747 hx.resize( 2 );
748 hx[0] = 1;
749 hx[1] = 2;
750
751 dx.upload( hx.data(), hx.size() );
752
753 hy.resize( 2 );
754
755 dy.resize( 2 );
756 dy.initialize();
757
758 double alpha = 1;
759 double beta = 0;
760
761 cublasHandle_t handle;
762 cublasStatus_t stat;
763 stat = cublasCreate( &handle );
764 REQUIRE( stat == CUBLAS_STATUS_SUCCESS );
765
766 stat = mx::cuda::cublasTgemv( handle, CUBLAS_OP_N, 2, 2, &alpha, dA(), 2, dx(), 1, &beta, dy(), 1 );
767 REQUIRE( stat == CUBLAS_STATUS_SUCCESS );
768
769 dy.download( hy.data() );
770
771 REQUIRE( hy[0] == 7 );
772 REQUIRE( hy[1] == 10 );
773
774 stat = cublasDestroy( handle );
775 REQUIRE( stat == CUBLAS_STATUS_SUCCESS );
776 }
777
778 WHEN( "double precision, beta is 1, but y is all 0" )
779 {
780 std::vector<double> hA; // This will actually be a vector
782
783 std::vector<double> hx; // This will actually be a vector
785
786 std::vector<double> hy;
788
789 /* Column major order:
790 1 3
791 2 4
792 */
793 hA.resize( 4 );
794 hA[0] = 1;
795 hA[1] = 2;
796 hA[2] = 3;
797 hA[3] = 4;
798
799 dA.resize( 2, 2 );
800 dA.upload( hA.data() );
801
802 hx.resize( 2 );
803 hx[0] = 1;
804 hx[1] = 2;
805
806 dx.upload( hx.data(), hx.size() );
807
808 hy.resize( 2 );
809
810 dy.resize( 2 );
811 dy.initialize();
812
813 double alpha = 1;
814 double beta = 1;
815
816 cublasHandle_t handle;
817 cublasStatus_t stat;
818 stat = cublasCreate( &handle );
819 REQUIRE( stat == CUBLAS_STATUS_SUCCESS );
820
821 stat = mx::cuda::cublasTgemv( handle, CUBLAS_OP_N, 2, 2, &alpha, dA(), 2, dx(), 1, &beta, dy(), 1 );
822 REQUIRE( stat == CUBLAS_STATUS_SUCCESS );
823
824 dy.download( hy.data() );
825
826 REQUIRE( hy[0] == 7 );
827 REQUIRE( hy[1] == 10 );
828
829 stat = cublasDestroy( handle );
830 REQUIRE( stat == CUBLAS_STATUS_SUCCESS );
831 }
832 WHEN( "double precision, beta is 1, y is [1,2]" )
833 {
834 std::vector<double> hA; // This will actually be a vector
836
837 std::vector<double> hx; // This will actually be a vector
839
840 std::vector<double> hy;
842
843 /* Column major order:
844 1 3
845 2 4
846 */
847 hA.resize( 4 );
848 hA[0] = 1;
849 hA[1] = 2;
850 hA[2] = 3;
851 hA[3] = 4;
852
853 dA.resize( 2, 2 );
854 dA.upload( hA.data() );
855
856 hx.resize( 2 );
857 hx[0] = 1;
858 hx[1] = 2;
859
860 dx.upload( hx.data(), hx.size() );
861
862 hy.resize( 2 );
863 hy[0] = 1;
864 hy[1] = 2;
865
866 dy.resize( 2 );
867 dy.upload( hx.data() );
868
869 double alpha = 1;
870 double beta = 1;
871
872 cublasHandle_t handle;
873 cublasStatus_t stat;
874 stat = cublasCreate( &handle );
875 REQUIRE( stat == CUBLAS_STATUS_SUCCESS );
876
877 stat = mx::cuda::cublasTgemv( handle, CUBLAS_OP_N, 2, 2, &alpha, dA(), 2, dx(), 1, &beta, dy(), 1 );
878 REQUIRE( stat == CUBLAS_STATUS_SUCCESS );
879
880 dy.download( hy.data() );
881
882 REQUIRE( hy[0] == 8 );
883 REQUIRE( hy[1] == 12 );
884
885 stat = cublasDestroy( handle );
886 REQUIRE( stat == CUBLAS_STATUS_SUCCESS );
887 }
888 }
889}
cudaError_t elementwiseXxY(dataT1 *x, dataT2 *y, int size)
Calculates the element-wise product of two vectors, storing the result in the first.
cublasStatus_t cublasTaxpy(cublasHandle_t handle, int n, const floatT *alpha, const floatT *x, int incx, floatT *y, int incy)
Multiplies a vector by a scalar, adding it to a second vector which is overwritten by the result.
cublasStatus_t cublasTscal(cublasHandle_t handle, int n, const floatT *alpha, floatT *x, int incx)
Multiplies a vector by a scalar, overwriting the vector with the result.
cublasStatus_t cublasTgemv(cublasHandle_t handle, cublasOperation_t trans, int m, int n, const floatT *alpha, const floatT *A, int lda, const floatT *x, int incx, const floatT *beta, floatT *y, int incy)
Perform a matrix-vector multiplication.
A smart-pointer wrapper for cuda device pointers.
Definition cudaPtr.hpp:47
cudaError_t initialize()
Initialize the array bytes to 0.
Definition cudaPtr.hpp:218
int download(hostPtrT *dest)
Copy from the device to the host.
Definition cudaPtr.hpp:272
int upload(const hostPtrT *src)
Copy from the host to the device, after allocation.
Definition cudaPtr.hpp:244
int resize(size_t sz)
Resize the memory allocation, in 1D.
Definition cudaPtr.hpp:187
SCENARIO("scaling a vector with cublas", "[math::cuda::templateCublas]")