Graph Framework
Loading...
Searching...
No Matches
backend.hpp
Go to the documentation of this file.
1//------------------------------------------------------------------------------
6//------------------------------------------------------------------------------
7
8#ifndef backend_h
9#define backend_h
10
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>
#include <vector>

#include "special_functions.hpp"
#include "register.hpp"
17
19namespace backend {
20//******************************************************************************
21// Data buffer.
22//******************************************************************************
23//------------------------------------------------------------------------------
27//------------------------------------------------------------------------------
28 template<jit::float_scalar T>
29 class buffer {
30 private:
32 std::vector<T> memory;
33
34 public:
35//------------------------------------------------------------------------------
37//------------------------------------------------------------------------------
39 memory() {}
40
41//------------------------------------------------------------------------------
45//------------------------------------------------------------------------------
46 buffer(const size_t s) :
47 memory(s) {}
48
49//------------------------------------------------------------------------------
54//------------------------------------------------------------------------------
55 buffer(const size_t s, const T d) :
56 memory(s, d) {}
57
58//------------------------------------------------------------------------------
62//------------------------------------------------------------------------------
63 buffer(const std::vector<T> &d) :
64 memory(d) {}
65
66//------------------------------------------------------------------------------
70//------------------------------------------------------------------------------
71 buffer(const buffer &d) :
72 memory(d.memory) {}
73
74//------------------------------------------------------------------------------
76//------------------------------------------------------------------------------
77 T &operator[] (const size_t index) {
78 return memory[index];
79 }
80
81//------------------------------------------------------------------------------
83//------------------------------------------------------------------------------
84 const T &operator[] (const size_t index) const {
85 return memory[index];
86 }
87
88//------------------------------------------------------------------------------
90//------------------------------------------------------------------------------
91 const T at(const size_t index) const {
92 return memory.at(index);
93 }
94
95//------------------------------------------------------------------------------
99//------------------------------------------------------------------------------
100 void set(const T d) {
101 memory.assign(memory.size(), d);
102 }
103
104//------------------------------------------------------------------------------
108//------------------------------------------------------------------------------
109 void set(const std::vector<T> &d) {
110 memory.assign(d.cbegin(), d.cend());
111 }
112
113//------------------------------------------------------------------------------
115//------------------------------------------------------------------------------
116 size_t size() const {
117 return memory.size();
118 }
119
120//------------------------------------------------------------------------------
124//------------------------------------------------------------------------------
125 bool is_same() const {
126 const T same = memory.at(0);
127 for (size_t i = 1, ie = memory.size(); i < ie; i++) {
128 if (memory.at(i) != same) {
129 return false;
130 }
131 }
132
133 return true;
134 }
135
136//------------------------------------------------------------------------------
140//------------------------------------------------------------------------------
141 bool is_zero() const {
142 for (const T &d : memory) {
143 if (d != static_cast<T> (0.0)) {
144 return false;
145 }
146 }
147
148 return true;
149 }
150
151//------------------------------------------------------------------------------
155//------------------------------------------------------------------------------
156 bool has_zero() const {
157 for (const T &d : memory) {
158 if (d == static_cast<T> (0.0)) {
159 return true;
160 }
161 }
162
163 return false;
164 }
165
166//------------------------------------------------------------------------------
170//------------------------------------------------------------------------------
171 bool is_negative() const {
172 for (const T &d : memory) {
173 if (std::real(d) > std::real(static_cast<T> (0.0))) {
174 return false;
175 }
176 }
177
178 return true;
179 }
180
181//------------------------------------------------------------------------------
185//------------------------------------------------------------------------------
186 bool is_even() const {
187 for (const T &d : memory) {
188 if (std::fmod(std::real(d), std::real(static_cast<T> (2.0)))) {
189 return false;
190 }
191 }
192
193 return true;
194 }
195
196//------------------------------------------------------------------------------
200//------------------------------------------------------------------------------
201 bool is_none() const {
202 for (const T &d : memory) {
203 if (d != static_cast<T> (-1.0)) {
204 return false;
205 }
206 }
207
208 return true;
209 }
210
211//------------------------------------------------------------------------------
213//------------------------------------------------------------------------------
214 void sqrt() {
215 for (T &d : memory) {
216 d = std::sqrt(d);
217 }
218 }
219
220//------------------------------------------------------------------------------
222//------------------------------------------------------------------------------
223 void exp() {
224 for (T &d : memory) {
225 d = std::exp(d);
226 }
227 }
228
229//------------------------------------------------------------------------------
231//------------------------------------------------------------------------------
232 void log() {
233 for (T &d : memory) {
234 d = std::log(d);
235 }
236 }
237
238//------------------------------------------------------------------------------
240//------------------------------------------------------------------------------
241 void sin() {
242 for (T &d : memory) {
243 d = std::sin(d);
244 }
245 }
246
247//------------------------------------------------------------------------------
249//------------------------------------------------------------------------------
250 void cos() {
251 for (T &d : memory) {
252 d = std::cos(d);
253 }
254 }
255
256//------------------------------------------------------------------------------
258//------------------------------------------------------------------------------
259 void erfi() requires(jit::complex_scalar<T>) {
260 for (T &d : memory) {
261 d = special::erfi(d);
262 }
263 }
264
265//------------------------------------------------------------------------------
269//------------------------------------------------------------------------------
270 T *data() {
271 return memory.data();
272 }
273
274//------------------------------------------------------------------------------
278//------------------------------------------------------------------------------
279 bool is_normal() const {
280 for (const T &x : memory) {
281 if constexpr (jit::complex_scalar<T>) {
282 if (std::isnan(std::real(x)) || std::isinf(std::real(x)) ||
283 std::isnan(std::imag(x)) || std::isinf(std::imag(x))) {
284 return false;
285 }
286 } else {
287 if (std::isnan(x) || std::isinf(x)) {
288 return false;
289 }
290 }
291 }
292 return true;
293 }
294
295//------------------------------------------------------------------------------
302//------------------------------------------------------------------------------
303 void add_row(const buffer<T> &x) {
304 if (size() > x.size()) {
305 assert(size()%x.size() == 0 &&
306 "Vector operand size is not a multiple of matrix operand size");
307
308 const size_t num_colmns = size()/x.size();
309 const size_t num_rows = x.size();
310 for (size_t i = 0; i < num_rows; i++) {
311 for (size_t j = 0; j < num_colmns; j++) {
312 memory[i*num_rows + j] += x[i];
313 }
314 }
315 } else {
316 assert(x.size()%size() == 0 &&
317 "Vector operand size is not a multiple of matrix operand size");
318
319 std::vector<T> m(x.size());
320 const size_t num_colmns = x.size()/size();
321 const size_t num_rows = size();
322 for (size_t i = 0; i < num_rows; i++) {
323 for (size_t j = 0; j < num_colmns; j++) {
324 m[i*num_colmns + j] = memory[i] + x[i*num_colmns + j];
325 }
326 }
327 memory = m;
328 }
329 }
330
331//------------------------------------------------------------------------------
338//------------------------------------------------------------------------------
339 void add_col(const buffer<T> &x) {
340 if (size() > x.size()) {
341 assert(size()%x.size() == 0 &&
342 "Vector operand size is not a multiple of matrix operand size");
343
344 const size_t num_colmns = size()/x.size();
345 const size_t num_rows = x.size();
346 for (size_t i = 0; i < num_rows; i++) {
347 for (size_t j = 0; j < num_colmns; j++) {
348 memory[i*num_colmns + j] += x[j];
349 }
350 }
351 } else {
352 assert(x.size()%size() == 0 &&
353 "Vector operand size is not a multiple of matrix operand size");
354
355 std::vector<T> m(x.size());
356 const size_t num_colmns = x.size()/size();
357 const size_t num_rows = size();
358 for (size_t i = 0; i < num_rows; i++) {
359 for (size_t j = 0; j < num_colmns; j++) {
360 m[i*num_colmns + j] = memory[j] + x[i*num_colmns + j];
361 }
362 }
363 memory = m;
364 }
365 }
366
367//------------------------------------------------------------------------------
374//------------------------------------------------------------------------------
375 void subtract_row(const buffer<T> &x) {
376 if (size() > x.size()) {
377 assert(size()%x.size() == 0 &&
378 "Vector operand size is not a multiple of matrix operand size");
379
380 const size_t num_colmns = size()/x.size();
381 const size_t num_rows = x.size();
382 for (size_t i = 0; i < num_rows; i++) {
383 for (size_t j = 0; j < num_colmns; j++) {
384 memory[i*num_colmns + j] -= x[i];
385 }
386 }
387 } else {
388 assert(x.size()%size() == 0 &&
389 "Vector operand size is not a multiple of matrix operand size");
390
391 std::vector<T> m(x.size());
392 const size_t num_colmns = x.size()/size();
393 const size_t num_rows = size();
394 for (size_t i = 0; i < num_colmns; i++) {
395 for (size_t j = 0; j < num_rows; j++) {
396 m[i*num_colmns + j] = memory[i] - x[i*num_colmns + j];
397 }
398 }
399 memory = m;
400 }
401 }
402
403//------------------------------------------------------------------------------
410//------------------------------------------------------------------------------
411 void subtract_col(const buffer<T> &x) {
412 if (size() > x.size()) {
413 assert(size()%x.size() == 0 &&
414 "Vector operand size is not a multiple of matrix operand size");
415
416 const size_t num_colmns = size()/x.size();
417 const size_t num_rows = x.size();
418 for (size_t i = 0; i < num_rows; i++) {
419 for (size_t j = 0; j < num_colmns; j++) {
420 memory[i*num_colmns + j] -= x[j];
421 }
422 }
423 } else {
424 assert(x.size()%size() == 0 &&
425 "Vector operand size is not a multiple of matrix operand size");
426
427 std::vector<T> m(x.size());
428 const size_t num_colmns = x.size()/size();
429 const size_t num_rows = size();
430 for (size_t i = 0; i < num_rows; i++) {
431 for (size_t j = 0; j < num_colmns; j++) {
432 m[i*num_colmns + j] = memory[j] - x[i*num_colmns + j];
433 }
434 }
435 memory = m;
436 }
437 }
438
439//------------------------------------------------------------------------------
446//------------------------------------------------------------------------------
447 void multiply_row(const buffer<T> &x) {
448 if (size() > x.size()) {
449 assert(size()%x.size() == 0 &&
450 "Vector operand size is not a multiple of matrix operand size");
451
452 const size_t num_colmns = size()/x.size();
453 const size_t num_rows = x.size();
454 for (size_t i = 0; i < num_rows; i++) {
455 for (size_t j = 0; j < num_colmns; j++) {
456 memory[i*num_colmns + j] *= x[i];
457 }
458 }
459 } else {
460 assert(x.size()%size() == 0 &&
461 "Vector operand size is not a multiple of matrix operand size");
462
463 std::vector<T> m(x.size());
464 const size_t num_colmns = x.size()/size();
465 const size_t num_rows = size();
466 for (size_t i = 0; i < num_rows; i++) {
467 for (size_t j = 0; j < num_colmns; j++) {
468 m[i*num_colmns + j] = memory[i]*x[i*num_colmns + j];
469 }
470 }
471 memory = m;
472 }
473 }
474
475//------------------------------------------------------------------------------
482//------------------------------------------------------------------------------
483 void multiply_col(const buffer<T> &x) {
484 if (size() > x.size()) {
485 assert(size()%x.size() == 0 &&
486 "Vector operand size is not a multiple of matrix operand size");
487
488 const size_t num_colmns = size()/x.size();
489 const size_t num_rows = x.size();
490 for (size_t i = 0; i < num_rows; i++) {
491 for (size_t j = 0; j < num_colmns; j++) {
492 memory[i*num_colmns + j] *= x[j];
493 }
494 }
495 } else {
496 assert(x.size()%size() == 0 &&
497 "Vector operand size is not a multiple of matrix operand size");
498
499 std::vector<T> m(x.size());
500 const size_t num_colmns = x.size()/size();
501 const size_t num_rows = size();
502 for (size_t i = 0; i < num_rows; i++) {
503 for (size_t j = 0; j < num_colmns; j++) {
504 m[i*num_colmns + j] = memory[j]*x[i*num_colmns + j];
505 }
506 }
507 memory = m;
508 }
509 }
510
511//------------------------------------------------------------------------------
518//------------------------------------------------------------------------------
519 void divide_row(const buffer<T> &x) {
520 if (size() > x.size()) {
521 assert(size()%x.size() == 0 &&
522 "Vector operand size is not a multiple of matrix operand size");
523
524 const size_t num_colmns = size()/x.size();
525 const size_t num_rows = x.size();
526 for (size_t i = 0; i < num_rows; i++) {
527 for (size_t j = 0; j < num_colmns; j++) {
528 memory[i*num_colmns + j] /= x[i];
529 }
530 }
531 } else {
532 assert(x.size()%size() == 0 &&
533 "Vector operand size is not a multiple of matrix operand size");
534
535 std::vector<T> m(x.size());
536 const size_t num_colmns = x.size()/size();
537 const size_t num_rows = size();
538 for (size_t i = 0; i < num_rows; i++) {
539 for (size_t j = 0; j < num_colmns; j++) {
540 m[i*num_colmns + j] = memory[i]/x[i*num_colmns + j];
541 }
542 }
543 memory = m;
544 }
545 }
546
547//------------------------------------------------------------------------------
554//------------------------------------------------------------------------------
555 void divide_col(const buffer<T> &x) {
556 if (size() > x.size()) {
557 assert(size()%x.size() == 0 &&
558 "Vector operand size is not a multiple of matrix operand size");
559
560 const size_t num_colmns = size()/x.size();
561 const size_t num_rows = x.size();
562 for (size_t i = 0; i < num_rows; i++) {
563 for (size_t j = 0; j < num_colmns; j++) {
564 memory[i*num_colmns + j] /= x[j];
565 }
566 }
567 } else {
568 assert(x.size()%size() == 0 &&
569 "Vector operand size is not a multiple of matrix operand size");
570
571 std::vector<T> m(x.size());
572 const size_t num_colmns = x.size()/size();
573 const size_t num_rows = size();
574 for (size_t i = 0; i < num_rows; i++) {
575 for (size_t j = 0; j < num_colmns; j++) {
576 m[i*num_colmns + j] = memory[j]/x[i*num_colmns + j];
577 }
578 }
579 memory = m;
580 }
581 }
582
583//------------------------------------------------------------------------------
590//------------------------------------------------------------------------------
591 void atan_row(const buffer<T> &x) {
592 if (size() > x.size()) {
593 assert(size()%x.size() == 0 &&
594 "Vector operand size is not a multiple of matrix operand size");
595
596 const size_t num_colmns = size()/x.size();
597 const size_t num_rows = x.size();
598 for (size_t i = 0; i < num_rows; i++) {
599 for (size_t j = 0; j < num_colmns; j++) {
600 if constexpr (jit::complex_scalar<T>) {
601 memory[i*num_colmns + j] = std::atan(x[i]/memory[i*num_colmns + j]);
602 } else {
603 memory[i*num_colmns + j] = std::atan2(x[i], memory[i*num_colmns + j]);
604 }
605 }
606 }
607 } else {
608 assert(x.size()%size() == 0 &&
609 "Vector operand size is not a multiple of matrix operand size");
610
611 std::vector<T> m(x.size());
612 const size_t num_colmns = x.size()/size();
613 const size_t num_rows = size();
614 for (size_t i = 0; i < num_rows; i++) {
615 for (size_t j = 0; j < num_colmns; j++) {
616 if constexpr (jit::complex_scalar<T>) {
617 m[i*num_colmns + j] = std::atan(x[i*num_colmns + j]/memory[i]);
618 } else {
619 m[i*num_colmns + j] = std::atan2(x[i*num_colmns + j], memory[i]);
620 }
621 }
622 }
623 memory = m;
624 }
625 }
626
627//------------------------------------------------------------------------------
634//------------------------------------------------------------------------------
635 void atan_col(const buffer<T> &x) {
636 if (size() > x.size()) {
637 assert(size()%x.size() == 0 &&
638 "Vector operand size is not a multiple of matrix operand size");
639
640 const size_t num_colmns = size()/x.size();
641 const size_t num_rows = x.size();
642 for (size_t i = 0; i < num_colmns; i++) {
643 for (size_t j = 0; j < num_rows; j++) {
644 if constexpr (jit::complex_scalar<T>) {
645 memory[i*num_colmns + j] = std::atan(x[j]/memory[i*num_colmns + j]);
646 } else {
647 memory[i*num_colmns + j] = std::atan2(x[j], memory[i*num_colmns + j]);
648 }
649 }
650 }
651 } else {
652 assert(x.size()%size() == 0 &&
653 "Vector operand size is not a multiple of matrix operand size");
654
655 std::vector<T> m(x.size());
656 const size_t num_colmns = x.size()/size();
657 const size_t num_rows = size();
658 for (size_t i = 0; i < num_rows; i++) {
659 for (size_t j = 0; j < num_colmns; j++) {
660 if constexpr (jit::complex_scalar<T>) {
661 m[i*num_colmns + j] = std::atan(x[i*num_colmns + j]/memory[j]);
662 } else {
663 m[i*num_colmns + j] = std::atan2(x[i*num_colmns + j], memory[j]);
664 }
665 }
666 }
667 memory = m;
668 }
669 }
670
671//------------------------------------------------------------------------------
678//------------------------------------------------------------------------------
679 void pow_row(const buffer<T> &x) {
680 if (size() > x.size()) {
681 assert(size()%x.size() == 0 &&
682 "Vector operand size is not a multiple of matrix operand size");
683
684 const size_t num_colmns = size()/x.size();
685 const size_t num_rows = x.size();
686 for (size_t i = 0; i < num_rows; i++) {
687 for (size_t j = 0; j < num_colmns; j++) {
688 memory[i*num_colmns + j] = std::pow(memory[i*num_colmns + j], x[i]);
689 }
690 }
691 } else {
692 assert(x.size()%size() == 0 &&
693 "Vector operand size is not a multiple of matrix operand size");
694
695 std::vector<T> m(x.size());
696 const size_t num_colmns = x.size()/size();
697 const size_t num_rows = size();
698 for (size_t i = 0; i < num_colmns; i++) {
699 for (size_t j = 0; j < num_rows; j++) {
700 m[i*num_colmns + j] = std::pow(memory[i], x[i*num_colmns + j]);
701 }
702 }
703 memory = m;
704 }
705 }
706
707//------------------------------------------------------------------------------
714//------------------------------------------------------------------------------
715 void pow_col(const buffer<T> &x) {
716 if (size() > x.size()) {
717 assert(size()%x.size() == 0 &&
718 "Vector operand size is not a multiple of matrix operand size");
719
720 const size_t num_colmns = size()/x.size();
721 const size_t num_rows = x.size();
722 for (size_t i = 0; i < num_rows; i++) {
723 for (size_t j = 0; j < num_colmns; j++) {
724 memory[i*num_colmns + j] = std::pow(memory[i*num_colmns + j], x[j]);
725 }
726 }
727 } else {
728 assert(x.size()%size() == 0 &&
729 "Vector operand size is not a multiple of matrix operand size");
730
731 std::vector<T> m(x.size());
732 const size_t num_colmns = x.size()/size();
733 const size_t num_rows = size();
734 for (size_t i = 0; i < num_rows; i++) {
735 for (size_t j = 0; j < num_colmns; j++) {
736 m[i*num_colmns + j] = std::pow(memory[j], x[i*num_colmns + j]);
737 }
738 }
739 memory = m;
740 }
741 }
742
744 typedef T base;
745 };
746
747//------------------------------------------------------------------------------
755//------------------------------------------------------------------------------
756 template<jit::float_scalar T>
758 buffer<T> &b) {
759 if (b.size() == 1) {
760 const T right = b.at(0);
761 for (size_t i = 0, ie = a.size(); i < ie; i++) {
762 a[i] += right;
763 }
764 return a;
765 } else if (a.size() == 1) {
766 const T left = a.at(0);
767 for (size_t i = 0, ie = b.size(); i < ie; i++) {
768 b[i] += left;
769 }
770 return b;
771 }
772
773 assert(a.size() == b.size() &&
774 "Left and right sizes are incompatable.");
775 for (size_t i = 0, ie = a.size(); i < ie; i++) {
776 a[i] += b.at(i);
777 }
778 return a;
779 }
780
781//------------------------------------------------------------------------------
789//------------------------------------------------------------------------------
790 template<jit::float_scalar T>
791 inline bool operator==(const buffer<T> &a,
792 const buffer<T> &b) {
793 if (a.size() != b.size()) {
794 return false;
795 }
796
797 for (size_t i = 0, ie = a.size(); i < ie; i++) {
798 if (a.at(i) != b.at(i)) {
799 return false;
800 }
801 }
802 return true;
803 }
804
805//------------------------------------------------------------------------------
813//------------------------------------------------------------------------------
814 template<jit::float_scalar T>
816 buffer<T> &b) {
817 if (b.size() == 1) {
818 const T right = b.at(0);
819 for (size_t i = 0, ie = a.size(); i < ie; i++) {
820 a[i] -= right;
821 }
822 return a;
823 } else if (a.size() == 1) {
824 const T left = a.at(0);
825 for (size_t i = 0, ie = b.size(); i < ie; i++) {
826 b[i] = left - b.at(i);
827 }
828 return b;
829 }
830
831 assert(a.size() == b.size() &&
832 "Left and right sizes are incompatable.");
833 for (size_t i = 0, ie = a.size(); i < ie; i++) {
834 a[i] -= b.at(i);
835 }
836 return a;
837 }
838
839//------------------------------------------------------------------------------
847//------------------------------------------------------------------------------
848 template<jit::float_scalar T>
850 buffer<T> &b) {
851 if (b.size() == 1) {
852 const T right = b.at(0);
853 for (size_t i = 0, ie = a.size(); i < ie; i++) {
854 a[i] *= right;
855 }
856 return a;
857 } else if (a.size() == 1) {
858 const T left = a.at(0);
859 for (size_t i = 0, ie = b.size(); i < ie; i++) {
860 b[i] *= left;
861 }
862 return b;
863 }
864
865 assert(a.size() == b.size() &&
866 "Left and right sizes are incompatable.");
867 for (size_t i = 0, ie = a.size(); i < ie; i++) {
868 a[i] *= b.at(i);
869 }
870 return a;
871 }
872
873//------------------------------------------------------------------------------
881//------------------------------------------------------------------------------
882 template<jit::float_scalar T>
884 buffer<T> &b) {
885 if (b.size() == 1) {
886 const T right = b.at(0);
887 for (size_t i = 0, ie = a.size(); i < ie; i++) {
888 a[i] /= right;
889 }
890 return a;
891 } else if (a.size() == 1) {
892 const T left = a.at(0);
893 for (size_t i = 0, ie = b.size(); i < ie; i++) {
894 b[i] = left/b.at(i);
895 }
896 return b;
897 }
898
899 assert(a.size() == b.size() &&
900 "Left and right sizes are incompatable.");
901 for (size_t i = 0, ie = a.size(); i < ie; i++) {
902 a[i] /= b.at(i);
903 }
904 return a;
905 }
906
907//------------------------------------------------------------------------------
916//------------------------------------------------------------------------------
917 template<jit::float_scalar T>
919 buffer<T> &b,
920 buffer<T> &c) {
921 constexpr bool use_fma = !jit::complex_scalar<T> &&
922#ifdef FP_FAST_FMA
923 true;
924#else
925 false;
926#endif
927
928 if (a.size() == 1) {
929 const T left = a.at(0);
930
931 if (b.size() == 1) {
932 const T middle = b.at(0);
933 for (size_t i = 0, ie = c.size(); i < ie; i++) {
934 if constexpr (use_fma) {
935 c[i] = std::fma(left, middle, c.at(i));
936 } else {
937 c[i] = left*middle + c.at(i);
938 }
939 }
940 return c;
941 } else if (c.size() == 1) {
942 const T right = c.at(0);
943 for (size_t i = 0, ie = b.size(); i < ie; i++) {
944 if constexpr (use_fma) {
945 b[i] = std::fma(left, b.at(i), right);
946 } else {
947 b[i] = left*b.at(i) + right;
948 }
949 }
950 return b;
951 }
952
953 assert(b.size() == c.size() &&
954 "Size mismatch between middle and right.");
955 for (size_t i = 0, ie = b.size(); i < ie; i++) {
956 if constexpr (use_fma) {
957 b[i] = std::fma(left, b.at(i), c.at(i));
958 } else {
959 b[i] = left*b.at(i) + c.at(i);
960 }
961 }
962 return b;
963 } else if (b.size() == 1) {
964 const T middle = b.at(0);
965 if (c.size() == 1) {
966 const T right = c.at(0);
967 for (size_t i = 0, ie = a.size(); i < ie; i++) {
968 if constexpr (use_fma) {
969 a[i] = std::fma(a.at(i), middle, right);
970 } else {
971 a[i] = a.at(i)*middle + right;
972 }
973 }
974 return a;
975 }
976
977 assert(a.size() == c.size() &&
978 "Size mismatch between left and right.");
979 for (size_t i = 0, ie = a.size(); i < ie; i++) {
980 if constexpr (use_fma) {
981 a[i] = std::fma(a.at(i), middle, c.at(i));
982 } else {
983 a[i] = a.at(i)*middle + c.at(i);
984 }
985 }
986 return a;
987 } else if (c.size() == 1) {
988 assert(a.size() == b.size() &&
989 "Size mismatch between left and middle.");
990 const T right = c.at(0);
991 for (size_t i = 0, ie = a.size(); i < ie; i++) {
992 if constexpr (use_fma) {
993 a[i] = std::fma(a.at(i), b.at(i), right);
994 } else {
995 a[i] = a.at(i)*b.at(i) + right;
996 }
997 }
998 return a;
999 }
1000
1001 assert(a.size() == b.size() &&
1002 b.size() == c.size() &&
1003 a.size() == c.size() &&
1004 "Left, middle and right sizes are incompatable.");
1005 for (size_t i = 0, ie = a.size(); i < ie; i++) {
1006 if constexpr (use_fma) {
1007 a[i] = std::fma(a.at(i), b.at(i), c.at(i));
1008 } else {
1009 a[i] = a.at(i)*b.at(i) + c.at(i);
1010 }
1011 }
1012 return a;
1013 }
1014
1015//------------------------------------------------------------------------------
1023//------------------------------------------------------------------------------
1024 template<jit::float_scalar T>
1026 buffer<T> &exponent) {
1027 if (exponent.size() == 1) {
1028 const T right = exponent.at(0);
1029 if (std::imag(right) == 0) {
1030 const int64_t right_int = static_cast<int64_t> (std::real(right));
1031 if (std::real(right) - right_int) {
1032 if (right == static_cast<T> (0.5)) {
1033 base.sqrt();
1034 return base;
1035 }
1036
1037 for (size_t i = 0, ie = base.size(); i < ie; i++) {
1038 base[i] = std::pow(base.at(i), right);
1039 }
1040 return base;
1041 }
1042
1043 if (right_int > 0) {
1044 for (size_t i = 0, ie = base.size(); i < ie; i++) {
1045 const T left = base.at(i);
1046 for (size_t j = 0, je = right_int - 1; j < je; j++) {
1047 base[i] *= left;
1048 }
1049 }
1050 return base;
1051 } else if (right_int == 0) {
1052 for (size_t i = 0, ie = base.size(); i < ie; i++) {
1053 base[i] = 1.0;
1054 }
1055 return base;
1056 } else {
1057 for (size_t i = 0, ie = base.size(); i < ie; i++) {
1058 const T left = static_cast<T> (1.0)/base.at(i);
1059 base[i] = left;
1060 for (size_t j = 0, je = std::abs(right_int) - 1; j < je; j++) {
1061 base[i] *= left;
1062 }
1063 }
1064 return base;
1065 }
1066 } else {
1067 for (size_t i = 0, ie = base.size(); i < ie; i++) {
1068 base[i] = std::pow(base.at(i), right);
1069 }
1070 return base;
1071 }
1072 } else if (base.size() == 1) {
1073 const T left = base.at(0);
1074 for (size_t i = 0, ie = exponent.size(); i < ie; i++) {
1075 exponent[i] = std::pow(left, exponent.at(i));
1076 }
1077 return exponent;
1078 }
1079
1080 assert(base.size() == exponent.size() &&
1081 "Left and right sizes are incompatable.");
1082 for (size_t i = 0, ie = base.size(); i < ie; i++) {
1083 base[i] = std::pow(base.at(i), exponent.at(i));
1084 }
1085 return base;
1086 }
1087
1088//------------------------------------------------------------------------------
1096//------------------------------------------------------------------------------
1097 template<jit::float_scalar T>
1099 buffer<T> &y) {
1100 if (y.size() == 1) {
1101 const T right = y.at(0);
1102 for (size_t i = 0, ie = x.size(); i < ie; i++) {
1103 if constexpr (jit::complex_scalar<T>) {
1104 x[i] = std::atan(right/x[i]);
1105 } else {
1106 x[i] = std::atan2(right, x[i]);
1107 }
1108 }
1109 return x;
1110 } else if (x.size() == 1) {
1111 const T left = x.at(0);
1112 for (size_t i = 0, ie = y.size(); i < ie; i++) {
1113 if constexpr (jit::complex_scalar<T>) {
1114 y[i] = std::atan(y[i]/left);
1115 } else {
1116 y[i] = std::atan2(y[i], left);
1117 }
1118 }
1119 return y;
1120 }
1121
1122 assert(x.size() == y.size() &&
1123 "Left and right sizes are incompatable.");
1124 for (size_t i = 0, ie = x.size(); i < ie; i++) {
1125 if constexpr (jit::complex_scalar<T>) {
1126 x[i] = std::atan(y[i]/x[i]);
1127 } else {
1128 x[i] = std::atan2(y[i], x[i]);
1129 }
1130 }
1131 return x;
1132 }
1133}
1134
1135#endif /* backend_h */
Class representing a generic buffer.
Definition backend.hpp:29
void set(const std::vector< T > &d)
Assign a vector value.
Definition backend.hpp:109
void subtract_row(const buffer< T > &x)
Subtract row operation.
Definition backend.hpp:375
void multiply_row(const buffer< T > &x)
Multiply row operation.
Definition backend.hpp:447
void add_col(const buffer< T > &x)
Add col operation.
Definition backend.hpp:339
void erfi()
Take erfi.
Definition backend.hpp:259
void log()
Take log.
Definition backend.hpp:232
buffer(const buffer &d)
Construct a buffer backend from a buffer backend.
Definition backend.hpp:71
void sqrt()
Take sqrt.
Definition backend.hpp:214
buffer(const size_t s)
Construct a buffer backend with a size.
Definition backend.hpp:46
T & operator[](const size_t index)
Index operator.
Definition backend.hpp:77
bool is_normal() const
Check for normal values.
Definition backend.hpp:279
buffer()
Construct an empty buffer backend.
Definition backend.hpp:38
bool is_negative() const
Is every element negative.
Definition backend.hpp:171
const T at(const size_t index) const
Get value at.
Definition backend.hpp:91
void pow_col(const buffer< T > &x)
Pow col operation.
Definition backend.hpp:715
buffer(const std::vector< T > &d)
Construct a buffer backend from a vector.
Definition backend.hpp:63
bool is_same() const
Is every element the same.
Definition backend.hpp:125
buffer(const size_t s, const T d)
Construct a buffer backend with a size.
Definition backend.hpp:55
void divide_col(const buffer< T > &x)
Divide col operation.
Definition backend.hpp:555
bool is_none() const
Is every element negative one.
Definition backend.hpp:201
void pow_row(const buffer< T > &x)
Pow row operation.
Definition backend.hpp:679
void sin()
Take sin.
Definition backend.hpp:241
void add_row(const buffer< T > &x)
Add row operation.
Definition backend.hpp:303
void subtract_col(const buffer< T > &x)
Subtract col operation.
Definition backend.hpp:411
void multiply_col(const buffer< T > &x)
Multiply col operation.
Definition backend.hpp:483
bool is_even() const
Is every element even.
Definition backend.hpp:186
T base
Typedef to retrieve the backend T type.
Definition backend.hpp:744
void divide_row(const buffer< T > &x)
Divide row operation.
Definition backend.hpp:519
void atan_col(const buffer< T > &x)
Atan col operation.
Definition backend.hpp:635
void atan_row(const buffer< T > &x)
Atan row operation.
Definition backend.hpp:591
void exp()
Take exp.
Definition backend.hpp:223
bool has_zero() const
Is any element zero.
Definition backend.hpp:156
bool is_zero() const
Is every element zero.
Definition backend.hpp:141
size_t size() const
Get size of the buffer.
Definition backend.hpp:116
void set(const T d)
Assign a constant value.
Definition backend.hpp:100
T * data()
Get a pointer to the basic memory buffer.
Definition backend.hpp:270
void cos()
Take cos.
Definition backend.hpp:250
Complex scalar concept.
Definition register.hpp:24
subroutine assert(test, message)
Assert check.
Definition f_binding_test.f90:38
Name space for backend buffers.
Definition backend.hpp:19
buffer< T > fma(buffer< T > &a, buffer< T > &b, buffer< T > &c)
Fused multiply add operation.
Definition backend.hpp:918
buffer< T > operator*(buffer< T > &a, buffer< T > &b)
Multiply operation.
Definition backend.hpp:849
bool operator==(const buffer< T > &a, const buffer< T > &b)
Equal operation.
Definition backend.hpp:791
buffer< T > atan(buffer< T > &x, buffer< T > &y)
Take the inverse tangent.
Definition backend.hpp:1098
buffer< T > operator/(buffer< T > &a, buffer< T > &b)
Divide operation.
Definition backend.hpp:883
buffer< T > operator-(buffer< T > &a, buffer< T > &b)
Subtract operation.
Definition backend.hpp:815
buffer< T > operator+(buffer< T > &a, buffer< T > &b)
Add operation.
Definition backend.hpp:757
buffer< T > pow(buffer< T > &base, buffer< T > &exponent)
Take the power.
Definition backend.hpp:1025
complex_type< T > erfi(const complex_type< T > z)
erfi(z) = -i erf(iz)
Definition special_functions.hpp:1580
Utilities for writing JIT source code.
Implementations of special functions.