mxlib
c++ tools for analyzing astronomical data and other tasks by Jared R. Males. [git repo]
Loading...
Searching...
No Matches
templateCublas_test.cpp
Go to the documentation of this file.
1/** \file templateCublas_test.cpp
2 */
3#include "../../../catch2/catch.hpp"
4
5#include "../../../../include/math/cuda/cudaPtr.hpp"
6#include "../../../../include/math/cuda/templateCublas.hpp"
7
8/// Tests cublasTscal, as well as basic cudaPtr operations.
9/**
10 * \test
11 */
12SCENARIO( "scaling a vector with cublas", "[math::cuda::templateCublas]" )
13{
14 GIVEN( "a vector" )
15 {
16 WHEN( "type is single precision real" )
17 {
18 std::vector<float> hx;
19 mx::cuda::cudaPtr<float> dx;
20 float alpha = 2;
21
22 hx.resize( 5 );
23 for( size_t n = 0; n < hx.size(); ++n )
24 hx[n] = n;
25
26 dx.upload( hx.data(), hx.size() );
27 REQUIRE( dx.size() == hx.size() );
28
29 cublasHandle_t handle;
30 cublasStatus_t stat;
31 stat = cublasCreate( &handle );
32 REQUIRE( stat == CUBLAS_STATUS_SUCCESS );
33
34 stat = mx::cuda::cublasTscal( handle, dx.size(), &alpha, dx(), 1 );
35 REQUIRE( stat == CUBLAS_STATUS_SUCCESS );
36
37 dx.download( hx.data() );
38
39 REQUIRE( hx[0] == 0 );
40 REQUIRE( hx[1] == 2 );
41 REQUIRE( hx[2] == 4 );
42 REQUIRE( hx[3] == 6 );
43 REQUIRE( hx[4] == 8 );
44
45 stat = cublasDestroy( handle );
46
47 REQUIRE( stat == CUBLAS_STATUS_SUCCESS );
48 }
49
50 WHEN( "type is double precision real" )
51 {
52 std::vector<double> hx;
53 mx::cuda::cudaPtr<double> dx;
54 double alpha = 2;
55
56 hx.resize( 5 );
57 for( size_t n = 0; n < hx.size(); ++n )
58 hx[n] = n;
59
60 dx.upload( hx.data(), hx.size() );
61 REQUIRE( dx.size() == hx.size() );
62
63 cublasHandle_t handle;
64 cublasStatus_t stat;
65 stat = cublasCreate( &handle );
66 REQUIRE( stat == CUBLAS_STATUS_SUCCESS );
67
68 stat = mx::cuda::cublasTscal( handle, dx.size(), &alpha, dx(), 1 );
69 REQUIRE( stat == CUBLAS_STATUS_SUCCESS );
70
71 dx.download( hx.data() );
72
73 REQUIRE( hx[0] == 0 );
74 REQUIRE( hx[1] == 2 );
75 REQUIRE( hx[2] == 4 );
76 REQUIRE( hx[3] == 6 );
77 REQUIRE( hx[4] == 8 );
78
79 stat = cublasDestroy( handle );
80
81 REQUIRE( stat == CUBLAS_STATUS_SUCCESS );
82 }
83
84 WHEN( "type is single precision complex" )
85 {
86 std::vector<std::complex<float>> hx;
87 mx::cuda::cudaPtr<std::complex<float>> dx;
88 std::complex<float> alpha = 2;
89
90 hx.resize( 5 );
91 for( size_t n = 0; n < hx.size(); ++n )
92 hx[n] = std::complex<float>( n, n );
93
94 dx.upload( hx.data(), hx.size() );
95 REQUIRE( dx.size() == hx.size() );
96
97 cublasHandle_t handle;
98 cublasStatus_t stat;
99 stat = cublasCreate( &handle );
100 REQUIRE( stat == CUBLAS_STATUS_SUCCESS );
101
102 stat = mx::cuda::cublasTscal( handle, dx.size(), (cuComplex *)&alpha, dx(), 1 );
103 REQUIRE( stat == CUBLAS_STATUS_SUCCESS );
104
105 dx.download( hx.data() );
106
107 REQUIRE( hx[0] == std::complex<float>( 0, 0 ) );
108 REQUIRE( hx[1] == std::complex<float>( 2, 2 ) );
109 REQUIRE( hx[2] == std::complex<float>( 4, 4 ) );
110 REQUIRE( hx[3] == std::complex<float>( 6, 6 ) );
111 REQUIRE( hx[4] == std::complex<float>( 8, 8 ) );
112
113 stat = cublasDestroy( handle );
114
115 REQUIRE( stat == CUBLAS_STATUS_SUCCESS );
116 }
117
118 WHEN( "type is double precision complex" )
119 {
120 std::vector<std::complex<double>> hx;
121 mx::cuda::cudaPtr<std::complex<double>> dx;
122 std::complex<double> alpha = 2;
123
124 hx.resize( 5 );
125 for( size_t n = 0; n < hx.size(); ++n )
126 hx[n] = std::complex<double>( n, n );
127
128 dx.upload( hx.data(), hx.size() );
129 REQUIRE( dx.size() == hx.size() );
130
131 cublasHandle_t handle;
132 cublasStatus_t stat;
133 stat = cublasCreate( &handle );
134 REQUIRE( stat == CUBLAS_STATUS_SUCCESS );
135
136 stat = mx::cuda::cublasTscal( handle, dx.size(), (cuDoubleComplex *)&alpha, dx(), 1 );
137 REQUIRE( stat == CUBLAS_STATUS_SUCCESS );
138
139 dx.download( hx.data() );
140
141 REQUIRE( hx[0] == std::complex<double>( 0, 0 ) );
142 REQUIRE( hx[1] == std::complex<double>( 2, 2 ) );
143 REQUIRE( hx[2] == std::complex<double>( 4, 4 ) );
144 REQUIRE( hx[3] == std::complex<double>( 6, 6 ) );
145 REQUIRE( hx[4] == std::complex<double>( 8, 8 ) );
146
147 stat = cublasDestroy( handle );
148
149 REQUIRE( stat == CUBLAS_STATUS_SUCCESS );
150 }
151 }
152}
153
154/** \test Scenario: scaling and accumulating a vector with cublas
155 * Tests cublasTaxpy, as well as basic cudaPtr operations.
156 *
157 * \anchor test_math_templateCublas_axpy
158 */
159SCENARIO( "scaling and accumulating a vector with cublas", "[math::cuda::templateCublas]" )
160{
161 GIVEN( "a vector" )
162 {
163 WHEN( "type is single precision real" )
164 {
165 std::vector<float> hx, hy;
166 mx::cuda::cudaPtr<float> dx, dy;
167 float alpha = 2;
168
169 hx.resize( 5 );
170 for( size_t n = 0; n < hx.size(); ++n )
171 hx[n] = n;
172
173 hy.resize( 5 );
174 for( size_t n = 0; n < hy.size(); ++n )
175 hy[n] = 1;
176
177 dx.upload( hx.data(), hx.size() );
178 REQUIRE( dx.size() == hx.size() );
179
180 dy.upload( hy.data(), hy.size() );
181 REQUIRE( dy.size() == hy.size() );
182
183 cublasHandle_t handle;
184 cublasStatus_t stat;
185 stat = cublasCreate( &handle );
186 REQUIRE( stat == CUBLAS_STATUS_SUCCESS );
187
188 stat = mx::cuda::cublasTaxpy( handle, dx.size(), &alpha, dx(), 1, dy(), 1 );
189 REQUIRE( stat == CUBLAS_STATUS_SUCCESS );
190
191 dy.download( hy.data() );
192
193 REQUIRE( hy[0] == 1 );
194 REQUIRE( hy[1] == 3 );
195 REQUIRE( hy[2] == 5 );
196 REQUIRE( hy[3] == 7 );
197 REQUIRE( hy[4] == 9 );
198
199 stat = cublasDestroy( handle );
200
201 REQUIRE( stat == CUBLAS_STATUS_SUCCESS );
202 }
203
204 WHEN( "type is double precision real" )
205 {
206 std::vector<double> hx, hy;
207 mx::cuda::cudaPtr<double> dx, dy;
208 double alpha = 2;
209
210 hx.resize( 5 );
211 for( size_t n = 0; n < hx.size(); ++n )
212 hx[n] = n;
213
214 hy.resize( 5 );
215 for( size_t n = 0; n < hy.size(); ++n )
216 hy[n] = 1;
217
218 dx.upload( hx.data(), hx.size() );
219 REQUIRE( dx.size() == hx.size() );
220
221 dy.upload( hy.data(), hy.size() );
222 REQUIRE( dy.size() == hy.size() );
223
224 cublasHandle_t handle;
225 cublasStatus_t stat;
226 stat = cublasCreate( &handle );
227 REQUIRE( stat == CUBLAS_STATUS_SUCCESS );
228
229 stat = mx::cuda::cublasTaxpy( handle, dx.size(), &alpha, dx(), 1, dy(), 1 );
230 REQUIRE( stat == CUBLAS_STATUS_SUCCESS );
231
232 dy.download( hy.data() );
233
234 REQUIRE( hy[0] == 1 );
235 REQUIRE( hy[1] == 3 );
236 REQUIRE( hy[2] == 5 );
237 REQUIRE( hy[3] == 7 );
238 REQUIRE( hy[4] == 9 );
239
240 stat = cublasDestroy( handle );
241
242 REQUIRE( stat == CUBLAS_STATUS_SUCCESS );
243 }
244
245 WHEN( "type is single precision complex" )
246 {
247 std::vector<std::complex<float>> hx, hy;
248 mx::cuda::cudaPtr<std::complex<float>> dx, dy;
249 std::complex<float> alpha = 2;
250
251 hx.resize( 5 );
252 for( size_t n = 0; n < hx.size(); ++n )
253 hx[n] = std::complex<float>( n, n );
254
255 hy.resize( 5 );
256 for( size_t n = 0; n < hy.size(); ++n )
257 hy[n] = std::complex<float>( 1, 1 );
258
259 dx.upload( hx.data(), hx.size() );
260 REQUIRE( dx.size() == hx.size() );
261
262 dy.upload( hy.data(), hy.size() );
263 REQUIRE( dy.size() == hy.size() );
264
265 cublasHandle_t handle;
266 cublasStatus_t stat;
267 stat = cublasCreate( &handle );
268 REQUIRE( stat == CUBLAS_STATUS_SUCCESS );
269
270 stat = mx::cuda::cublasTaxpy( handle, dx.size(), (cuComplex *)&alpha, dx(), 1, dy(), 1 );
271 REQUIRE( stat == CUBLAS_STATUS_SUCCESS );
272
273 dy.download( hy.data() );
274
275 REQUIRE( hy[0] == std::complex<float>( 1, 1 ) );
276 REQUIRE( hy[1] == std::complex<float>( 3, 3 ) );
277 REQUIRE( hy[2] == std::complex<float>( 5, 5 ) );
278 REQUIRE( hy[3] == std::complex<float>( 7, 7 ) );
279 REQUIRE( hy[4] == std::complex<float>( 9, 9 ) );
280
281 stat = cublasDestroy( handle );
282
283 REQUIRE( stat == CUBLAS_STATUS_SUCCESS );
284 }
285
286 WHEN( "type is double precision complex" )
287 {
288 std::vector<std::complex<double>> hx, hy;
289 mx::cuda::cudaPtr<std::complex<double>> dx, dy;
290 std::complex<double> alpha = 2;
291
292 hx.resize( 5 );
293 for( size_t n = 0; n < hx.size(); ++n )
294 hx[n] = std::complex<double>( n, n );
295
296 hy.resize( 5 );
297 for( size_t n = 0; n < hy.size(); ++n )
298 hy[n] = std::complex<double>( 1, 1 );
299
300 dx.upload( hx.data(), hx.size() );
301 REQUIRE( dx.size() == hx.size() );
302
303 dy.upload( hy.data(), hy.size() );
304 REQUIRE( dy.size() == hy.size() );
305
306 cublasHandle_t handle;
307 cublasStatus_t stat;
308 stat = cublasCreate( &handle );
309 REQUIRE( stat == CUBLAS_STATUS_SUCCESS );
310
311 stat = mx::cuda::cublasTaxpy( handle, dx.size(), (cuDoubleComplex *)&alpha, dx(), 1, dy(), 1 );
312 REQUIRE( stat == CUBLAS_STATUS_SUCCESS );
313
314 dy.download( hy.data() );
315
316 REQUIRE( hy[0] == std::complex<double>( 1, 1 ) );
317 REQUIRE( hy[1] == std::complex<double>( 3, 3 ) );
318 REQUIRE( hy[2] == std::complex<double>( 5, 5 ) );
319 REQUIRE( hy[3] == std::complex<double>( 7, 7 ) );
320 REQUIRE( hy[4] == std::complex<double>( 9, 9 ) );
321
322 stat = cublasDestroy( handle );
323
324 REQUIRE( stat == CUBLAS_STATUS_SUCCESS );
325 }
326 }
327}
328
329/** \test Scenario: multiplying two vectors element by element
330 * Tests mx::cuda::elementwiseXxY, as well as basic cudaPtr operations.
331 *
332 * \anchor test_math_templateCublas_elementwiseXxY
333 */
334SCENARIO( "multiplying two vectors element by element", "[math::cuda::templateCublas]" )
335{
336 GIVEN( "a vector" )
337 {
338 WHEN( "both types are single precision real" )
339 {
340 std::vector<float> hx, hy;
341 mx::cuda::cudaPtr<float> dx, dy;
342
343 hx.resize( 5 );
344 for( size_t n = 0; n < hx.size(); ++n )
345 hx[n] = n;
346
347 hy.resize( 5 );
348 for( size_t n = 0; n < hy.size(); ++n )
349 hy[n] = 2 * n;
350
351 dx.upload( hx.data(), hx.size() );
352 REQUIRE( dx.size() == hx.size() );
353
354 dy.resize( hy.size() );
355 dy.upload( hy.data() );
356 REQUIRE( dy.size() == hy.size() );
357
358 cudaError_t rv = mx::cuda::elementwiseXxY( dx(), dy(), dx.size() );
359 REQUIRE( rv == cudaSuccess );
360
361 dx.download( hx.data() );
362
363 REQUIRE( hx[0] == 0 );
364 REQUIRE( hx[1] == 2 );
365 REQUIRE( hx[2] == 8 );
366 REQUIRE( hx[3] == 18 );
367 REQUIRE( hx[4] == 32 );
368 }
369
370 WHEN( "type1 is complex-float, and type2 is float" )
371 {
372 std::vector<std::complex<float>> hx;
373 mx::cuda::cudaPtr<std::complex<float>> dx;
374
375 std::vector<float> hy;
376 mx::cuda::cudaPtr<float> dy;
377
378 hx.resize( 5 );
379 for( size_t n = 0; n < hx.size(); ++n )
380 hx[n] = std::complex<float>( n, n );
381
382 hy.resize( 5 );
383 for( size_t n = 0; n < hy.size(); ++n )
384 hy[n] = 2 * n;
385
386 dx.upload( hx.data(), hx.size() );
387 REQUIRE( dx.size() == hx.size() );
388
389 dy.resize( hy.size() );
390 dy.upload( hy.data() );
391 REQUIRE( dy.size() == hy.size() );
392
393 cudaError_t rv = mx::cuda::elementwiseXxY( dx(), dy(), dx.size() );
394 REQUIRE( rv == cudaSuccess );
395
396 dx.download( hx.data() );
397
398 REQUIRE( hx[0] == std::complex<float>( 0, 0 ) );
399 REQUIRE( hx[1] == std::complex<float>( 2, 2 ) );
400 REQUIRE( hx[2] == std::complex<float>( 8, 8 ) );
401 REQUIRE( hx[3] == std::complex<float>( 18, 18 ) );
402 REQUIRE( hx[4] == std::complex<float>( 32, 32 ) );
403 }
404
405 WHEN( "type1 is complex-float, and type2 is complex-float" )
406 {
407 std::vector<std::complex<float>> hx;
408 mx::cuda::cudaPtr<std::complex<float>> dx;
409
410 std::vector<std::complex<float>> hy;
411 mx::cuda::cudaPtr<std::complex<float>> dy;
412
413 hx.resize( 5 );
414 for( size_t n = 0; n < hx.size(); ++n )
415 hx[n] = std::complex<float>( n, n );
416
417 hy.resize( 5 );
418 for( size_t n = 0; n < hy.size(); ++n )
419 hy[n] = std::complex<float>( 0, 2 * n );
420
421 dx.upload( hx.data(), hx.size() );
422 REQUIRE( dx.size() == hx.size() );
423
424 dy.resize( hy.size() );
425 dy.upload( hy.data() );
426 REQUIRE( dy.size() == hy.size() );
427
428 cudaError_t rv = mx::cuda::elementwiseXxY( dx(), dy(), dx.size() );
429 REQUIRE( rv == cudaSuccess );
430
431 dx.download( hx.data() );
432
433 REQUIRE( hx[0] == std::complex<float>( 0, 0 ) );
434 REQUIRE( hx[1] == std::complex<float>( -2, 2 ) ); //(1,1) * (0,2) = (0 + 2i + 0i -2) = (-2,2)
435 REQUIRE( hx[2] == std::complex<float>( -8, 8 ) );
436 REQUIRE( hx[3] == std::complex<float>( -18, 18 ) );
437 REQUIRE( hx[4] == std::complex<float>( -32, 32 ) );
438 }
439
440 WHEN( "both types are double precision real" )
441 {
442 std::vector<double> hx, hy;
443 mx::cuda::cudaPtr<double> dx, dy;
444
445 hx.resize( 5 );
446 for( size_t n = 0; n < hx.size(); ++n )
447 hx[n] = n;
448
449 hy.resize( 5 );
450 for( size_t n = 0; n < hy.size(); ++n )
451 hy[n] = 2 * n;
452
453 dx.upload( hx.data(), hx.size() );
454 REQUIRE( dx.size() == hx.size() );
455
456 dy.resize( hy.size() );
457 dy.upload( hy.data() );
458 REQUIRE( dy.size() == hy.size() );
459
460 cudaError_t rv = mx::cuda::elementwiseXxY( dx(), dy(), dx.size() );
461 REQUIRE( rv == cudaSuccess );
462
463 dx.download( hx.data() );
464
465 REQUIRE( hx[0] == 0 );
466 REQUIRE( hx[1] == 2 );
467 REQUIRE( hx[2] == 8 );
468 REQUIRE( hx[3] == 18 );
469 REQUIRE( hx[4] == 32 );
470 }
471
472 WHEN( "type1 is complex-double, and type2 is double" )
473 {
474 std::vector<std::complex<double>> hx;
475 mx::cuda::cudaPtr<std::complex<double>> dx;
476
477 std::vector<double> hy;
478 mx::cuda::cudaPtr<double> dy;
479
480 hx.resize( 5 );
481 for( size_t n = 0; n < hx.size(); ++n )
482 hx[n] = std::complex<double>( n, n );
483
484 hy.resize( 5 );
485 for( size_t n = 0; n < hy.size(); ++n )
486 hy[n] = 2 * n;
487
488 dx.upload( hx.data(), hx.size() );
489 REQUIRE( dx.size() == hx.size() );
490
491 dy.resize( hy.size() );
492 dy.upload( hy.data() );
493 REQUIRE( dy.size() == hy.size() );
494
495 cudaError_t rv = mx::cuda::elementwiseXxY( dx(), dy(), dx.size() );
496 REQUIRE( rv == cudaSuccess );
497
498 dx.download( hx.data() );
499
500 REQUIRE( hx[0] == std::complex<double>( 0, 0 ) );
501 REQUIRE( hx[1] == std::complex<double>( 2, 2 ) );
502 REQUIRE( hx[2] == std::complex<double>( 8, 8 ) );
503 REQUIRE( hx[3] == std::complex<double>( 18, 18 ) );
504 REQUIRE( hx[4] == std::complex<double>( 32, 32 ) );
505 }
506
507 WHEN( "type1 is complex-double, and type2 is complex-double" )
508 {
509 std::vector<std::complex<double>> hx;
510 mx::cuda::cudaPtr<std::complex<double>> dx;
511
512 std::vector<std::complex<double>> hy;
513 mx::cuda::cudaPtr<std::complex<double>> dy;
514
515 hx.resize( 5 );
516 for( size_t n = 0; n < hx.size(); ++n )
517 hx[n] = std::complex<double>( n, n );
518
519 hy.resize( 5 );
520 for( size_t n = 0; n < hy.size(); ++n )
521 hy[n] = std::complex<double>( 1, 2 * n );
522
523 dx.upload( hx.data(), hx.size() );
524 REQUIRE( dx.size() == hx.size() );
525
526 dy.resize( hy.size() );
527 dy.upload( hy.data() );
528 REQUIRE( dy.size() == hy.size() );
529
530 cudaError_t rv = mx::cuda::elementwiseXxY( dx(), dy(), dx.size() );
531 REQUIRE( rv == cudaSuccess );
532
533 dx.download( hx.data() );
534
535 REQUIRE( hx[0] == std::complex<double>( 0, 0 ) ); //(0,0) * (0,0) = (0,0)
536 REQUIRE( hx[1] == std::complex<double>( -1, 3 ) ); //(1,1) * (1,2) = (1 + 2i + i -2) = (-1,3)
537 REQUIRE( hx[2] == std::complex<double>( -6, 10 ) );
538 REQUIRE( hx[3] == std::complex<double>( -15, 21 ) );
539 REQUIRE( hx[4] == std::complex<double>( -28, 36 ) );
540 }
541 }
542}
543
544/** Scenario: multiplying a vector by a matrix
545 * Tests mx::cuda::cublasTgemv, as well as basic cudaPtr operations.
546 *
547 * \anchor test_math_templateCublas_cublasTgemv_inc
548 */
549SCENARIO( "multiplying a vector by a matrix giving increments", "[math::cuda::templateCublas]" )
550{
551 GIVEN( "a 2x2 matrix, float" )
552 {
553 WHEN( "float precision, beta is 0" )
554 {
555 std::vector<float> hA; // This will actually be a vector
556 mx::cuda::cudaPtr<float> dA;
557
558 std::vector<float> hx; // This will actually be a vector
559 mx::cuda::cudaPtr<float> dx;
560
561 std::vector<float> hy;
562 mx::cuda::cudaPtr<float> dy;
563
564 /* Column major order:
565 1 3
566 2 4
567 */
568 hA.resize( 4 );
569 hA[0] = 1;
570 hA[1] = 2;
571 hA[2] = 3;
572 hA[3] = 4;
573
574 dA.resize( 2, 2 );
575 dA.upload( hA.data() );
576
577 hx.resize( 2 );
578 hx[0] = 1;
579 hx[1] = 2;
580
581 dx.upload( hx.data(), hx.size() );
582
583 hy.resize( 2 );
584
585 dy.resize( 2 );
586 dy.initialize();
587
588 float alpha = 1;
589 float beta = 0;
590
591 cublasHandle_t handle;
592 cublasStatus_t stat;
593 stat = cublasCreate( &handle );
594 REQUIRE( stat == CUBLAS_STATUS_SUCCESS );
595
596 stat = mx::cuda::cublasTgemv( handle, CUBLAS_OP_N, 2, 2, &alpha, dA(), 2, dx(), 1, &beta, dy(), 1 );
597 REQUIRE( stat == CUBLAS_STATUS_SUCCESS );
598
599 dy.download( hy.data() );
600
601 REQUIRE( hy[0] == 7 );
602 REQUIRE( hy[1] == 10 );
603
604 stat = cublasDestroy( handle );
605 REQUIRE( stat == CUBLAS_STATUS_SUCCESS );
606 }
607
608 WHEN( "float precision, beta is 1, but y is all 0" )
609 {
610 std::vector<float> hA; // This will actually be a vector
611 mx::cuda::cudaPtr<float> dA;
612
613 std::vector<float> hx; // This will actually be a vector
614 mx::cuda::cudaPtr<float> dx;
615
616 std::vector<float> hy;
617 mx::cuda::cudaPtr<float> dy;
618
619 /* Column major order:
620 1 3
621 2 4
622 */
623 hA.resize( 4 );
624 hA[0] = 1;
625 hA[1] = 2;
626 hA[2] = 3;
627 hA[3] = 4;
628
629 dA.resize( 2, 2 );
630 dA.upload( hA.data() );
631
632 hx.resize( 2 );
633 hx[0] = 1;
634 hx[1] = 2;
635
636 dx.upload( hx.data(), hx.size() );
637
638 hy.resize( 2 );
639
640 dy.resize( 2 );
641 dy.initialize();
642
643 float alpha = 1;
644 float beta = 1;
645
646 cublasHandle_t handle;
647 cublasStatus_t stat;
648 stat = cublasCreate( &handle );
649 REQUIRE( stat == CUBLAS_STATUS_SUCCESS );
650
651 stat = mx::cuda::cublasTgemv( handle, CUBLAS_OP_N, 2, 2, &alpha, dA(), 2, dx(), 1, &beta, dy(), 1 );
652 REQUIRE( stat == CUBLAS_STATUS_SUCCESS );
653
654 dy.download( hy.data() );
655
656 REQUIRE( hy[0] == 7 );
657 REQUIRE( hy[1] == 10 );
658
659 stat = cublasDestroy( handle );
660 REQUIRE( stat == CUBLAS_STATUS_SUCCESS );
661 }
662 WHEN( "float precision, beta is 1, y is [1,2]" )
663 {
664 std::vector<float> hA; // This will actually be a vector
665 mx::cuda::cudaPtr<float> dA;
666
667 std::vector<float> hx; // This will actually be a vector
668 mx::cuda::cudaPtr<float> dx;
669
670 std::vector<float> hy;
671 mx::cuda::cudaPtr<float> dy;
672
673 /* Column major order:
674 1 3
675 2 4
676 */
677 hA.resize( 4 );
678 hA[0] = 1;
679 hA[1] = 2;
680 hA[2] = 3;
681 hA[3] = 4;
682
683 dA.resize( 2, 2 );
684 dA.upload( hA.data() );
685
686 hx.resize( 2 );
687 hx[0] = 1;
688 hx[1] = 2;
689
690 dx.upload( hx.data(), hx.size() );
691
692 hy.resize( 2 );
693 hy[0] = 1;
694 hy[1] = 2;
695
696 dy.resize( 2 );
697 dy.upload( hx.data() );
698
699 float alpha = 1;
700 float beta = 1;
701
702 cublasHandle_t handle;
703 cublasStatus_t stat;
704 stat = cublasCreate( &handle );
705 REQUIRE( stat == CUBLAS_STATUS_SUCCESS );
706
707 stat = mx::cuda::cublasTgemv( handle, CUBLAS_OP_N, 2, 2, &alpha, dA(), 2, dx(), 1, &beta, dy(), 1 );
708 REQUIRE( stat == CUBLAS_STATUS_SUCCESS );
709
710 dy.download( hy.data() );
711
712 REQUIRE( hy[0] == 8 );
713 REQUIRE( hy[1] == 12 );
714
715 stat = cublasDestroy( handle );
716 REQUIRE( stat == CUBLAS_STATUS_SUCCESS );
717 }
718 }
719
720 GIVEN( "a 2x2 matrix, double" )
721 {
722 WHEN( "double precision, beta is 0" )
723 {
724 std::vector<double> hA; // This will actually be a vector
725 mx::cuda::cudaPtr<double> dA;
726
727 std::vector<double> hx; // This will actually be a vector
728 mx::cuda::cudaPtr<double> dx;
729
730 std::vector<double> hy;
731 mx::cuda::cudaPtr<double> dy;
732
733 /* Column major order:
734 1 3
735 2 4
736 */
737 hA.resize( 4 );
738 hA[0] = 1;
739 hA[1] = 2;
740 hA[2] = 3;
741 hA[3] = 4;
742
743 dA.resize( 2, 2 );
744 dA.upload( hA.data() );
745
746 hx.resize( 2 );
747 hx[0] = 1;
748 hx[1] = 2;
749
750 dx.upload( hx.data(), hx.size() );
751
752 hy.resize( 2 );
753
754 dy.resize( 2 );
755 dy.initialize();
756
757 double alpha = 1;
758 double beta = 0;
759
760 cublasHandle_t handle;
761 cublasStatus_t stat;
762 stat = cublasCreate( &handle );
763 REQUIRE( stat == CUBLAS_STATUS_SUCCESS );
764
765 stat = mx::cuda::cublasTgemv( handle, CUBLAS_OP_N, 2, 2, &alpha, dA(), 2, dx(), 1, &beta, dy(), 1 );
766 REQUIRE( stat == CUBLAS_STATUS_SUCCESS );
767
768 dy.download( hy.data() );
769
770 REQUIRE( hy[0] == 7 );
771 REQUIRE( hy[1] == 10 );
772
773 stat = cublasDestroy( handle );
774 REQUIRE( stat == CUBLAS_STATUS_SUCCESS );
775 }
776
777 WHEN( "double precision, beta is 1, but y is all 0" )
778 {
779 std::vector<double> hA; // This will actually be a vector
780 mx::cuda::cudaPtr<double> dA;
781
782 std::vector<double> hx; // This will actually be a vector
783 mx::cuda::cudaPtr<double> dx;
784
785 std::vector<double> hy;
786 mx::cuda::cudaPtr<double> dy;
787
788 /* Column major order:
789 1 3
790 2 4
791 */
792 hA.resize( 4 );
793 hA[0] = 1;
794 hA[1] = 2;
795 hA[2] = 3;
796 hA[3] = 4;
797
798 dA.resize( 2, 2 );
799 dA.upload( hA.data() );
800
801 hx.resize( 2 );
802 hx[0] = 1;
803 hx[1] = 2;
804
805 dx.upload( hx.data(), hx.size() );
806
807 hy.resize( 2 );
808
809 dy.resize( 2 );
810 dy.initialize();
811
812 double alpha = 1;
813 double beta = 1;
814
815 cublasHandle_t handle;
816 cublasStatus_t stat;
817 stat = cublasCreate( &handle );
818 REQUIRE( stat == CUBLAS_STATUS_SUCCESS );
819
820 stat = mx::cuda::cublasTgemv( handle, CUBLAS_OP_N, 2, 2, &alpha, dA(), 2, dx(), 1, &beta, dy(), 1 );
821 REQUIRE( stat == CUBLAS_STATUS_SUCCESS );
822
823 dy.download( hy.data() );
824
825 REQUIRE( hy[0] == 7 );
826 REQUIRE( hy[1] == 10 );
827
828 stat = cublasDestroy( handle );
829 REQUIRE( stat == CUBLAS_STATUS_SUCCESS );
830 }
831 WHEN( "double precision, beta is 1, y is [1,2]" )
832 {
833 std::vector<double> hA; // This will actually be a vector
834 mx::cuda::cudaPtr<double> dA;
835
836 std::vector<double> hx; // This will actually be a vector
837 mx::cuda::cudaPtr<double> dx;
838
839 std::vector<double> hy;
840 mx::cuda::cudaPtr<double> dy;
841
842 /* Column major order:
843 1 3
844 2 4
845 */
846 hA.resize( 4 );
847 hA[0] = 1;
848 hA[1] = 2;
849 hA[2] = 3;
850 hA[3] = 4;
851
852 dA.resize( 2, 2 );
853 dA.upload( hA.data() );
854
855 hx.resize( 2 );
856 hx[0] = 1;
857 hx[1] = 2;
858
859 dx.upload( hx.data(), hx.size() );
860
861 hy.resize( 2 );
862 hy[0] = 1;
863 hy[1] = 2;
864
865 dy.resize( 2 );
866 dy.upload( hx.data() );
867
868 double alpha = 1;
869 double beta = 1;
870
871 cublasHandle_t handle;
872 cublasStatus_t stat;
873 stat = cublasCreate( &handle );
874 REQUIRE( stat == CUBLAS_STATUS_SUCCESS );
875
876 stat = mx::cuda::cublasTgemv( handle, CUBLAS_OP_N, 2, 2, &alpha, dA(), 2, dx(), 1, &beta, dy(), 1 );
877 REQUIRE( stat == CUBLAS_STATUS_SUCCESS );
878
879 dy.download( hy.data() );
880
881 REQUIRE( hy[0] == 8 );
882 REQUIRE( hy[1] == 12 );
883
884 stat = cublasDestroy( handle );
885 REQUIRE( stat == CUBLAS_STATUS_SUCCESS );
886 }
887 }
888}
SCENARIO("scaling a vector with cublas", "[math::cuda::templateCublas]")
Tests cublasTscal, as well as basic cudaPtr operations.