// Scenario: scaling a device vector through mx::cuda::cublasTscal.
// Every WHEN section below follows the same sequence for one element type:
//   1. upload the host vector hx into the device buffer dx (and check the sizes agree),
//   2. create a cuBLAS handle,
//   3. call cublasTscal on dx with an increment of 1,
//   4. download dx back into hx and compare the first five elements,
//   5. destroy the handle.
// Each cuBLAS status returned along the way is required to be CUBLAS_STATUS_SUCCESS.
12SCENARIO(
"scaling a vector with cublas",
"[math::cuda::templateCublas]" )
// --- float ---
16 WHEN(
"type is single precision real" )
18 std::vector<float> hx;
19 mx::cuda::cudaPtr<float> dx;
23 for(
size_t n = 0; n < hx.size(); ++n )
// Two-argument upload: the element count is passed along with the host pointer.
26 dx.upload( hx.data(), hx.size() );
27 REQUIRE( dx.size() == hx.size() );
29 cublasHandle_t handle;
31 stat = cublasCreate( &handle );
32 REQUIRE( stat == CUBLAS_STATUS_SUCCESS );
// The scale factor is handed over by address; the vector is scaled in place on the device.
34 stat = mx::cuda::cublasTscal( handle, dx.size(), &alpha, dx(), 1 );
35 REQUIRE( stat == CUBLAS_STATUS_SUCCESS );
37 dx.download( hx.data() );
// Exact comparison against small integers.
// NOTE(review): the values match a doubling of hx[n] = n, as in the complex sections
// below where alpha = 2 -- confirm against the initialisation of hx and alpha.
39 REQUIRE( hx[0] == 0 );
40 REQUIRE( hx[1] == 2 );
41 REQUIRE( hx[2] == 4 );
42 REQUIRE( hx[3] == 6 );
43 REQUIRE( hx[4] == 8 );
45 stat = cublasDestroy( handle );
47 REQUIRE( stat == CUBLAS_STATUS_SUCCESS );
// --- double ---
50 WHEN(
"type is double precision real" )
52 std::vector<double> hx;
53 mx::cuda::cudaPtr<double> dx;
57 for(
size_t n = 0; n < hx.size(); ++n )
60 dx.upload( hx.data(), hx.size() );
61 REQUIRE( dx.size() == hx.size() );
63 cublasHandle_t handle;
65 stat = cublasCreate( &handle );
66 REQUIRE( stat == CUBLAS_STATUS_SUCCESS );
68 stat = mx::cuda::cublasTscal( handle, dx.size(), &alpha, dx(), 1 );
69 REQUIRE( stat == CUBLAS_STATUS_SUCCESS );
71 dx.download( hx.data() );
// Same expected values as the float section.
73 REQUIRE( hx[0] == 0 );
74 REQUIRE( hx[1] == 2 );
75 REQUIRE( hx[2] == 4 );
76 REQUIRE( hx[3] == 6 );
77 REQUIRE( hx[4] == 8 );
79 stat = cublasDestroy( handle );
81 REQUIRE( stat == CUBLAS_STATUS_SUCCESS );
// --- std::complex<float> ---
84 WHEN(
"type is single precision complex" )
86 std::vector<std::complex<float>> hx;
87 mx::cuda::cudaPtr<std::complex<float>> dx;
// Real-valued scale factor of 2 stored as a complex number.
88 std::complex<float> alpha = 2;
// Element n is initialised to n + n*i.
91 for(
size_t n = 0; n < hx.size(); ++n )
92 hx[n] = std::complex<float>( n, n );
94 dx.upload( hx.data(), hx.size() );
95 REQUIRE( dx.size() == hx.size() );
97 cublasHandle_t handle;
99 stat = cublasCreate( &handle );
100 REQUIRE( stat == CUBLAS_STATUS_SUCCESS );
// The address of the std::complex scale factor is cast to the cuBLAS complex pointer type.
// NOTE(review): this presumes std::complex<float> and cuComplex are layout-compatible.
102 stat = mx::cuda::cublasTscal( handle, dx.size(), (cuComplex *)&alpha, dx(), 1 );
103 REQUIRE( stat == CUBLAS_STATUS_SUCCESS );
105 dx.download( hx.data() );
// Expect 2 * (n + n*i) = 2n + 2n*i for n = 0..4.
107 REQUIRE( hx[0] == std::complex<float>( 0, 0 ) );
108 REQUIRE( hx[1] == std::complex<float>( 2, 2 ) );
109 REQUIRE( hx[2] == std::complex<float>( 4, 4 ) );
110 REQUIRE( hx[3] == std::complex<float>( 6, 6 ) );
111 REQUIRE( hx[4] == std::complex<float>( 8, 8 ) );
113 stat = cublasDestroy( handle );
115 REQUIRE( stat == CUBLAS_STATUS_SUCCESS );
// --- std::complex<double> ---
118 WHEN(
"type is double precision complex" )
120 std::vector<std::complex<double>> hx;
121 mx::cuda::cudaPtr<std::complex<double>> dx;
// Real-valued scale factor of 2 stored as a complex number.
122 std::complex<double> alpha = 2;
// Element n is initialised to n + n*i.
125 for(
size_t n = 0; n < hx.size(); ++n )
126 hx[n] = std::complex<double>( n, n );
128 dx.upload( hx.data(), hx.size() );
129 REQUIRE( dx.size() == hx.size() );
131 cublasHandle_t handle;
133 stat = cublasCreate( &handle );
134 REQUIRE( stat == CUBLAS_STATUS_SUCCESS );
// Same call as the single precision complex section, with the scale factor cast to cuDoubleComplex*.
136 stat = mx::cuda::cublasTscal( handle, dx.size(), (cuDoubleComplex *)&alpha, dx(), 1 );
137 REQUIRE( stat == CUBLAS_STATUS_SUCCESS );
139 dx.download( hx.data() );
// Expect 2 * (n + n*i) = 2n + 2n*i for n = 0..4.
141 REQUIRE( hx[0] == std::complex<double>( 0, 0 ) );
142 REQUIRE( hx[1] == std::complex<double>( 2, 2 ) );
143 REQUIRE( hx[2] == std::complex<double>( 4, 4 ) );
144 REQUIRE( hx[3] == std::complex<double>( 6, 6 ) );
145 REQUIRE( hx[4] == std::complex<double>( 8, 8 ) );
147 stat = cublasDestroy( handle );
149 REQUIRE( stat == CUBLAS_STATUS_SUCCESS );
// Scenario: scale-and-accumulate through mx::cuda::cublasTaxpy.
// Every WHEN section uploads two host vectors (hx -> dx, hy -> dy), creates a cuBLAS
// handle, calls cublasTaxpy with both increments set to 1, downloads dy back into hy
// and compares its first five elements, then destroys the handle.
// In the complex sections the visible inputs (alpha = 2, hx[n] = n + n*i, hy[n] = 1 + i)
// and expected outputs agree with y <- alpha * x + y.
159SCENARIO(
"scaling and accumulating a vector with cublas",
"[math::cuda::templateCublas]" )
// --- float ---
163 WHEN(
"type is single precision real" )
165 std::vector<float> hx, hy;
166 mx::cuda::cudaPtr<float> dx, dy;
170 for(
size_t n = 0; n < hx.size(); ++n )
174 for(
size_t n = 0; n < hy.size(); ++n )
// Both device buffers are filled with the two-argument upload (pointer plus element count).
177 dx.upload( hx.data(), hx.size() );
178 REQUIRE( dx.size() == hx.size() );
180 dy.upload( hy.data(), hy.size() );
181 REQUIRE( dy.size() == hy.size() );
183 cublasHandle_t handle;
185 stat = cublasCreate( &handle );
186 REQUIRE( stat == CUBLAS_STATUS_SUCCESS );
// dx is the input operand; the checked result is read from dy afterwards.
188 stat = mx::cuda::cublasTaxpy( handle, dx.size(), &alpha, dx(), 1, dy(), 1 );
189 REQUIRE( stat == CUBLAS_STATUS_SUCCESS );
191 dy.download( hy.data() );
// NOTE(review): the values match 2 * n + 1, mirroring the complex sections below --
// confirm against the initialisation of hx, hy and alpha.
193 REQUIRE( hy[0] == 1 );
194 REQUIRE( hy[1] == 3 );
195 REQUIRE( hy[2] == 5 );
196 REQUIRE( hy[3] == 7 );
197 REQUIRE( hy[4] == 9 );
199 stat = cublasDestroy( handle );
201 REQUIRE( stat == CUBLAS_STATUS_SUCCESS );
// --- double ---
204 WHEN(
"type is double precision real" )
206 std::vector<double> hx, hy;
207 mx::cuda::cudaPtr<double> dx, dy;
211 for(
size_t n = 0; n < hx.size(); ++n )
215 for(
size_t n = 0; n < hy.size(); ++n )
218 dx.upload( hx.data(), hx.size() );
219 REQUIRE( dx.size() == hx.size() );
221 dy.upload( hy.data(), hy.size() );
222 REQUIRE( dy.size() == hy.size() );
224 cublasHandle_t handle;
226 stat = cublasCreate( &handle );
227 REQUIRE( stat == CUBLAS_STATUS_SUCCESS );
229 stat = mx::cuda::cublasTaxpy( handle, dx.size(), &alpha, dx(), 1, dy(), 1 );
230 REQUIRE( stat == CUBLAS_STATUS_SUCCESS );
232 dy.download( hy.data() );
// Same expected values as the float section.
234 REQUIRE( hy[0] == 1 );
235 REQUIRE( hy[1] == 3 );
236 REQUIRE( hy[2] == 5 );
237 REQUIRE( hy[3] == 7 );
238 REQUIRE( hy[4] == 9 );
240 stat = cublasDestroy( handle );
242 REQUIRE( stat == CUBLAS_STATUS_SUCCESS );
// --- std::complex<float> ---
245 WHEN(
"type is single precision complex" )
247 std::vector<std::complex<float>> hx, hy;
248 mx::cuda::cudaPtr<std::complex<float>> dx, dy;
// Real-valued factor of 2 stored as a complex number.
249 std::complex<float> alpha = 2;
// x[n] = n + n*i
252 for(
size_t n = 0; n < hx.size(); ++n )
253 hx[n] = std::complex<float>( n, n );
// y[n] = 1 + i for every element
256 for(
size_t n = 0; n < hy.size(); ++n )
257 hy[n] = std::complex<float>( 1, 1 );
259 dx.upload( hx.data(), hx.size() );
260 REQUIRE( dx.size() == hx.size() );
262 dy.upload( hy.data(), hy.size() );
263 REQUIRE( dy.size() == hy.size() );
265 cublasHandle_t handle;
267 stat = cublasCreate( &handle );
268 REQUIRE( stat == CUBLAS_STATUS_SUCCESS );
// The std::complex factor is passed through a cast to the cuBLAS complex pointer type.
// NOTE(review): this presumes std::complex<float> and cuComplex are layout-compatible.
270 stat = mx::cuda::cublasTaxpy( handle, dx.size(), (cuComplex *)&alpha, dx(), 1, dy(), 1 );
271 REQUIRE( stat == CUBLAS_STATUS_SUCCESS );
273 dy.download( hy.data() );
// Expect 2 * (n + n*i) + (1 + i) = (2n + 1) + (2n + 1)*i for n = 0..4.
275 REQUIRE( hy[0] == std::complex<float>( 1, 1 ) );
276 REQUIRE( hy[1] == std::complex<float>( 3, 3 ) );
277 REQUIRE( hy[2] == std::complex<float>( 5, 5 ) );
278 REQUIRE( hy[3] == std::complex<float>( 7, 7 ) );
279 REQUIRE( hy[4] == std::complex<float>( 9, 9 ) );
281 stat = cublasDestroy( handle );
283 REQUIRE( stat == CUBLAS_STATUS_SUCCESS );
// --- std::complex<double> ---
286 WHEN(
"type is double precision complex" )
288 std::vector<std::complex<double>> hx, hy;
289 mx::cuda::cudaPtr<std::complex<double>> dx, dy;
// Real-valued factor of 2 stored as a complex number.
290 std::complex<double> alpha = 2;
// x[n] = n + n*i
293 for(
size_t n = 0; n < hx.size(); ++n )
294 hx[n] = std::complex<double>( n, n );
// y[n] = 1 + i for every element
297 for(
size_t n = 0; n < hy.size(); ++n )
298 hy[n] = std::complex<double>( 1, 1 );
300 dx.upload( hx.data(), hx.size() );
301 REQUIRE( dx.size() == hx.size() );
303 dy.upload( hy.data(), hy.size() );
304 REQUIRE( dy.size() == hy.size() );
306 cublasHandle_t handle;
308 stat = cublasCreate( &handle );
309 REQUIRE( stat == CUBLAS_STATUS_SUCCESS );
// Same call as the single precision complex section, with the factor cast to cuDoubleComplex*.
311 stat = mx::cuda::cublasTaxpy( handle, dx.size(), (cuDoubleComplex *)&alpha, dx(), 1, dy(), 1 );
312 REQUIRE( stat == CUBLAS_STATUS_SUCCESS );
314 dy.download( hy.data() );
// Expect 2 * (n + n*i) + (1 + i) = (2n + 1) + (2n + 1)*i for n = 0..4.
316 REQUIRE( hy[0] == std::complex<double>( 1, 1 ) );
317 REQUIRE( hy[1] == std::complex<double>( 3, 3 ) );
318 REQUIRE( hy[2] == std::complex<double>( 5, 5 ) );
319 REQUIRE( hy[3] == std::complex<double>( 7, 7 ) );
320 REQUIRE( hy[4] == std::complex<double>( 9, 9 ) );
322 stat = cublasDestroy( handle );
324 REQUIRE( stat == CUBLAS_STATUS_SUCCESS );
// Scenario: element-by-element product through mx::cuda::elementwiseXxY.
// Every WHEN section uploads hx -> dx and hy -> dy, calls
// elementwiseXxY( dx(), dy(), dx.size() ), requires the returned cudaError_t to be
// cudaSuccess, and then downloads dx (the first argument) back into hx for checking,
// i.e. the product is read from the first operand's buffer.
// No cuBLAS handle is created in this scenario.
334SCENARIO(
"multiplying two vectors element by element",
"[math::cuda::templateCublas]" )
// --- float * float ---
338 WHEN(
"both types are single precision real" )
340 std::vector<float> hx, hy;
341 mx::cuda::cudaPtr<float> dx, dy;
344 for(
size_t n = 0; n < hx.size(); ++n )
348 for(
size_t n = 0; n < hy.size(); ++n )
351 dx.upload( hx.data(), hx.size() );
352 REQUIRE( dx.size() == hx.size() );
// dy is sized explicitly first and then filled with the pointer-only upload.
354 dy.resize( hy.size() );
355 dy.upload( hy.data() );
356 REQUIRE( dy.size() == hy.size() );
358 cudaError_t rv = mx::cuda::elementwiseXxY( dx(), dy(), dx.size() );
359 REQUIRE( rv == cudaSuccess );
361 dx.download( hx.data() );
// NOTE(review): the values match n * 2n = 2n^2 -- confirm against the
// initialisation of hx and hy.
363 REQUIRE( hx[0] == 0 );
364 REQUIRE( hx[1] == 2 );
365 REQUIRE( hx[2] == 8 );
366 REQUIRE( hx[3] == 18 );
367 REQUIRE( hx[4] == 32 );
// --- complex<float> * float ---
370 WHEN(
"type1 is complex-float, and type2 is float" )
372 std::vector<std::complex<float>> hx;
373 mx::cuda::cudaPtr<std::complex<float>> dx;
375 std::vector<float> hy;
376 mx::cuda::cudaPtr<float> dy;
// x[n] = n + n*i
379 for(
size_t n = 0; n < hx.size(); ++n )
380 hx[n] = std::complex<float>( n, n );
383 for(
size_t n = 0; n < hy.size(); ++n )
386 dx.upload( hx.data(), hx.size() );
387 REQUIRE( dx.size() == hx.size() );
389 dy.resize( hy.size() );
390 dy.upload( hy.data() );
391 REQUIRE( dy.size() == hy.size() );
// Mixed operand types: complex first operand, real second operand.
393 cudaError_t rv = mx::cuda::elementwiseXxY( dx(), dy(), dx.size() );
394 REQUIRE( rv == cudaSuccess );
396 dx.download( hx.data() );
// NOTE(review): the values match (n + n*i) * 2n -- confirm against the initialisation of hy.
398 REQUIRE( hx[0] == std::complex<float>( 0, 0 ) );
399 REQUIRE( hx[1] == std::complex<float>( 2, 2 ) );
400 REQUIRE( hx[2] == std::complex<float>( 8, 8 ) );
401 REQUIRE( hx[3] == std::complex<float>( 18, 18 ) );
402 REQUIRE( hx[4] == std::complex<float>( 32, 32 ) );
// --- complex<float> * complex<float> ---
405 WHEN(
"type1 is complex-float, and type2 is complex-float" )
407 std::vector<std::complex<float>> hx;
408 mx::cuda::cudaPtr<std::complex<float>> dx;
410 std::vector<std::complex<float>> hy;
411 mx::cuda::cudaPtr<std::complex<float>> dy;
// x[n] = n + n*i
414 for(
size_t n = 0; n < hx.size(); ++n )
415 hx[n] = std::complex<float>( n, n );
// y[n] = 2n*i (purely imaginary)
418 for(
size_t n = 0; n < hy.size(); ++n )
419 hy[n] = std::complex<float>( 0, 2 * n );
421 dx.upload( hx.data(), hx.size() );
422 REQUIRE( dx.size() == hx.size() );
424 dy.resize( hy.size() );
425 dy.upload( hy.data() );
426 REQUIRE( dy.size() == hy.size() );
428 cudaError_t rv = mx::cuda::elementwiseXxY( dx(), dy(), dx.size() );
429 REQUIRE( rv == cudaSuccess );
431 dx.download( hx.data() );
// Expect (n + n*i) * (2n*i) = -2n^2 + 2n^2*i for n = 0..4.
433 REQUIRE( hx[0] == std::complex<float>( 0, 0 ) );
434 REQUIRE( hx[1] == std::complex<float>( -2, 2 ) );
435 REQUIRE( hx[2] == std::complex<float>( -8, 8 ) );
436 REQUIRE( hx[3] == std::complex<float>( -18, 18 ) );
437 REQUIRE( hx[4] == std::complex<float>( -32, 32 ) );
// --- double * double ---
440 WHEN(
"both types are double precision real" )
442 std::vector<double> hx, hy;
443 mx::cuda::cudaPtr<double> dx, dy;
446 for(
size_t n = 0; n < hx.size(); ++n )
450 for(
size_t n = 0; n < hy.size(); ++n )
453 dx.upload( hx.data(), hx.size() );
454 REQUIRE( dx.size() == hx.size() );
456 dy.resize( hy.size() );
457 dy.upload( hy.data() );
458 REQUIRE( dy.size() == hy.size() );
460 cudaError_t rv = mx::cuda::elementwiseXxY( dx(), dy(), dx.size() );
461 REQUIRE( rv == cudaSuccess );
463 dx.download( hx.data() );
// Same expected values as the single precision real section.
465 REQUIRE( hx[0] == 0 );
466 REQUIRE( hx[1] == 2 );
467 REQUIRE( hx[2] == 8 );
468 REQUIRE( hx[3] == 18 );
469 REQUIRE( hx[4] == 32 );
// --- complex<double> * double ---
472 WHEN(
"type1 is complex-double, and type2 is double" )
474 std::vector<std::complex<double>> hx;
475 mx::cuda::cudaPtr<std::complex<double>> dx;
477 std::vector<double> hy;
478 mx::cuda::cudaPtr<double> dy;
// x[n] = n + n*i
481 for(
size_t n = 0; n < hx.size(); ++n )
482 hx[n] = std::complex<double>( n, n );
485 for(
size_t n = 0; n < hy.size(); ++n )
488 dx.upload( hx.data(), hx.size() );
489 REQUIRE( dx.size() == hx.size() );
491 dy.resize( hy.size() );
492 dy.upload( hy.data() );
493 REQUIRE( dy.size() == hy.size() );
495 cudaError_t rv = mx::cuda::elementwiseXxY( dx(), dy(), dx.size() );
496 REQUIRE( rv == cudaSuccess );
498 dx.download( hx.data() );
// Same expected values as the complex-float / float section.
500 REQUIRE( hx[0] == std::complex<double>( 0, 0 ) );
501 REQUIRE( hx[1] == std::complex<double>( 2, 2 ) );
502 REQUIRE( hx[2] == std::complex<double>( 8, 8 ) );
503 REQUIRE( hx[3] == std::complex<double>( 18, 18 ) );
504 REQUIRE( hx[4] == std::complex<double>( 32, 32 ) );
// --- complex<double> * complex<double> ---
507 WHEN(
"type1 is complex-double, and type2 is complex-double" )
509 std::vector<std::complex<double>> hx;
510 mx::cuda::cudaPtr<std::complex<double>> dx;
512 std::vector<std::complex<double>> hy;
513 mx::cuda::cudaPtr<std::complex<double>> dy;
// x[n] = n + n*i
516 for(
size_t n = 0; n < hx.size(); ++n )
517 hx[n] = std::complex<double>( n, n );
// y[n] = 1 + 2n*i; unlike the complex-float section, the real part is 1 here,
// so the product exercises both the real and the imaginary cross terms.
520 for(
size_t n = 0; n < hy.size(); ++n )
521 hy[n] = std::complex<double>( 1, 2 * n );
523 dx.upload( hx.data(), hx.size() );
524 REQUIRE( dx.size() == hx.size() );
526 dy.resize( hy.size() );
527 dy.upload( hy.data() );
528 REQUIRE( dy.size() == hy.size() );
530 cudaError_t rv = mx::cuda::elementwiseXxY( dx(), dy(), dx.size() );
531 REQUIRE( rv == cudaSuccess );
533 dx.download( hx.data() );
// Expect (n + n*i) * (1 + 2n*i) = (n - 2n^2) + (n + 2n^2)*i for n = 0..4.
535 REQUIRE( hx[0] == std::complex<double>( 0, 0 ) );
536 REQUIRE( hx[1] == std::complex<double>( -1, 3 ) );
537 REQUIRE( hx[2] == std::complex<double>( -6, 10 ) );
538 REQUIRE( hx[3] == std::complex<double>( -15, 21 ) );
539 REQUIRE( hx[4] == std::complex<double>( -28, 36 ) );
// Scenario: matrix-vector product through mx::cuda::cublasTgemv on a 2x2 matrix.
// Every WHEN section uploads a host matrix hA and host vector hx to the device,
// creates a cuBLAS handle, calls
//   cublasTgemv( handle, CUBLAS_OP_N, 2, 2, &alpha, dA(), 2, dx(), 1, &beta, dy(), 1 )
// (no transpose, 2 rows, 2 columns, leading dimension 2, unit increments for x and y),
// downloads dy into hy, compares both elements, and destroys the handle.
// The three sections per precision differ in beta and in the initial contents of y:
// the product alone is expected to be [7, 10], and [8, 12] once y = [1, 2] is accumulated.
549SCENARIO(
"multiplying a vector by a matrix giving increments",
"[math::cuda::templateCublas]" )
551 GIVEN(
"a 2x2 matrix, float" )
// --- float, beta = 0: y's prior contents must not contribute ---
553 WHEN(
"float precision, beta is 0" )
555 std::vector<float> hA;
556 mx::cuda::cudaPtr<float> dA;
558 std::vector<float> hx;
559 mx::cuda::cudaPtr<float> dx;
561 std::vector<float> hy;
562 mx::cuda::cudaPtr<float> dy;
// Pointer-only upload: dA is expected to have been sized beforehand.
575 dA.upload( hA.data() );
581 dx.upload( hx.data(), hx.size() );
591 cublasHandle_t handle;
593 stat = cublasCreate( &handle );
594 REQUIRE( stat == CUBLAS_STATUS_SUCCESS );
596 stat = mx::cuda::cublasTgemv( handle, CUBLAS_OP_N, 2, 2, &alpha, dA(), 2, dx(), 1, &beta, dy(), 1 );
597 REQUIRE( stat == CUBLAS_STATUS_SUCCESS );
599 dy.download( hy.data() );
601 REQUIRE( hy[0] == 7 );
602 REQUIRE( hy[1] == 10 );
604 stat = cublasDestroy( handle );
605 REQUIRE( stat == CUBLAS_STATUS_SUCCESS );
// --- float, beta = 1 with y = 0: accumulation onto zeros gives the bare product ---
608 WHEN(
"float precision, beta is 1, but y is all 0" )
610 std::vector<float> hA;
611 mx::cuda::cudaPtr<float> dA;
613 std::vector<float> hx;
614 mx::cuda::cudaPtr<float> dx;
616 std::vector<float> hy;
617 mx::cuda::cudaPtr<float> dy;
630 dA.upload( hA.data() );
636 dx.upload( hx.data(), hx.size() );
646 cublasHandle_t handle;
648 stat = cublasCreate( &handle );
649 REQUIRE( stat == CUBLAS_STATUS_SUCCESS );
651 stat = mx::cuda::cublasTgemv( handle, CUBLAS_OP_N, 2, 2, &alpha, dA(), 2, dx(), 1, &beta, dy(), 1 );
652 REQUIRE( stat == CUBLAS_STATUS_SUCCESS );
654 dy.download( hy.data() );
656 REQUIRE( hy[0] == 7 );
657 REQUIRE( hy[1] == 10 );
659 stat = cublasDestroy( handle );
660 REQUIRE( stat == CUBLAS_STATUS_SUCCESS );
// --- float, beta = 1 with y = [1,2]: the product is added to the existing y ---
662 WHEN(
"float precision, beta is 1, y is [1,2]" )
664 std::vector<float> hA;
665 mx::cuda::cudaPtr<float> dA;
667 std::vector<float> hx;
668 mx::cuda::cudaPtr<float> dx;
670 std::vector<float> hy;
671 mx::cuda::cudaPtr<float> dy;
684 dA.upload( hA.data() );
690 dx.upload( hx.data(), hx.size() );
// Seed the device y from hy, the host vector this section names in its title and
// reads back below. hx was uploaded here before, which only produced the intended
// starting y as long as hx and hy happened to hold identical values.
697 dy.upload( hy.data() );
702 cublasHandle_t handle;
704 stat = cublasCreate( &handle );
705 REQUIRE( stat == CUBLAS_STATUS_SUCCESS );
707 stat = mx::cuda::cublasTgemv( handle, CUBLAS_OP_N, 2, 2, &alpha, dA(), 2, dx(), 1, &beta, dy(), 1 );
708 REQUIRE( stat == CUBLAS_STATUS_SUCCESS );
710 dy.download( hy.data() );
// [7, 10] from the product plus the initial [1, 2].
712 REQUIRE( hy[0] == 8 );
713 REQUIRE( hy[1] == 12 );
715 stat = cublasDestroy( handle );
716 REQUIRE( stat == CUBLAS_STATUS_SUCCESS );
720 GIVEN(
"a 2x2 matrix, double" )
// --- double, beta = 0 ---
722 WHEN(
"double precision, beta is 0" )
724 std::vector<double> hA;
725 mx::cuda::cudaPtr<double> dA;
727 std::vector<double> hx;
728 mx::cuda::cudaPtr<double> dx;
730 std::vector<double> hy;
731 mx::cuda::cudaPtr<double> dy;
744 dA.upload( hA.data() );
750 dx.upload( hx.data(), hx.size() );
760 cublasHandle_t handle;
762 stat = cublasCreate( &handle );
763 REQUIRE( stat == CUBLAS_STATUS_SUCCESS );
765 stat = mx::cuda::cublasTgemv( handle, CUBLAS_OP_N, 2, 2, &alpha, dA(), 2, dx(), 1, &beta, dy(), 1 );
766 REQUIRE( stat == CUBLAS_STATUS_SUCCESS );
768 dy.download( hy.data() );
770 REQUIRE( hy[0] == 7 );
771 REQUIRE( hy[1] == 10 );
773 stat = cublasDestroy( handle );
774 REQUIRE( stat == CUBLAS_STATUS_SUCCESS );
// --- double, beta = 1 with y = 0 ---
777 WHEN(
"double precision, beta is 1, but y is all 0" )
779 std::vector<double> hA;
780 mx::cuda::cudaPtr<double> dA;
782 std::vector<double> hx;
783 mx::cuda::cudaPtr<double> dx;
785 std::vector<double> hy;
786 mx::cuda::cudaPtr<double> dy;
799 dA.upload( hA.data() );
805 dx.upload( hx.data(), hx.size() );
815 cublasHandle_t handle;
817 stat = cublasCreate( &handle );
818 REQUIRE( stat == CUBLAS_STATUS_SUCCESS );
820 stat = mx::cuda::cublasTgemv( handle, CUBLAS_OP_N, 2, 2, &alpha, dA(), 2, dx(), 1, &beta, dy(), 1 );
821 REQUIRE( stat == CUBLAS_STATUS_SUCCESS );
823 dy.download( hy.data() );
825 REQUIRE( hy[0] == 7 );
826 REQUIRE( hy[1] == 10 );
828 stat = cublasDestroy( handle );
829 REQUIRE( stat == CUBLAS_STATUS_SUCCESS );
// --- double, beta = 1 with y = [1,2] ---
831 WHEN(
"double precision, beta is 1, y is [1,2]" )
833 std::vector<double> hA;
834 mx::cuda::cudaPtr<double> dA;
836 std::vector<double> hx;
837 mx::cuda::cudaPtr<double> dx;
839 std::vector<double> hy;
840 mx::cuda::cudaPtr<double> dy;
853 dA.upload( hA.data() );
859 dx.upload( hx.data(), hx.size() );
// Seed the device y from hy rather than hx (same correction as the float section):
// the starting y must come from the vector that represents y on the host.
866 dy.upload( hy.data() );
871 cublasHandle_t handle;
873 stat = cublasCreate( &handle );
874 REQUIRE( stat == CUBLAS_STATUS_SUCCESS );
876 stat = mx::cuda::cublasTgemv( handle, CUBLAS_OP_N, 2, 2, &alpha, dA(), 2, dx(), 1, &beta, dy(), 1 );
877 REQUIRE( stat == CUBLAS_STATUS_SUCCESS );
879 dy.download( hy.data() );
// [7, 10] from the product plus the initial [1, 2].
881 REQUIRE( hy[0] == 8 );
882 REQUIRE( hy[1] == 12 );
884 stat = cublasDestroy( handle );
885 REQUIRE( stat == CUBLAS_STATUS_SUCCESS );