Graph Framework
Loading...
Searching...
No Matches
backend.hpp
Go to the documentation of this file.
1//------------------------------------------------------------------------------
6//------------------------------------------------------------------------------
7
8#ifndef backend_h
9#define backend_h
10
11#include <algorithm>
12#include <vector>
13#include <cmath>
14
15#include "special_functions.hpp"
16#include "register.hpp"
17
19namespace backend {
20//******************************************************************************
21// Data buffer.
22//******************************************************************************
23//------------------------------------------------------------------------------
27//------------------------------------------------------------------------------
28 template<jit::float_scalar T>
29 class buffer {
30 private:
32 std::vector<T> memory;
33
34 public:
35//------------------------------------------------------------------------------
37//------------------------------------------------------------------------------
39 memory() {}
40
41//------------------------------------------------------------------------------
45//------------------------------------------------------------------------------
46 buffer(const size_t s) :
47 memory(s) {}
48
49//------------------------------------------------------------------------------
54//------------------------------------------------------------------------------
55 buffer(const size_t s, const T d) :
56 memory(s, d) {}
57
58//------------------------------------------------------------------------------
62//------------------------------------------------------------------------------
63 buffer(const std::vector<T> &d) :
64 memory(d) {}
65
66//------------------------------------------------------------------------------
70//------------------------------------------------------------------------------
71 buffer(const buffer &d) :
72 memory(d.memory) {}
73
74//------------------------------------------------------------------------------
76//------------------------------------------------------------------------------
77 T &operator[] (const size_t index) {
78 return memory[index];
79 }
80
81//------------------------------------------------------------------------------
83//------------------------------------------------------------------------------
84 const T &operator[] (const size_t index) const {
85 return memory[index];
86 }
87
88//------------------------------------------------------------------------------
90//------------------------------------------------------------------------------
91 const T at(const size_t index) const {
92 return memory.at(index);
93 }
94
95//------------------------------------------------------------------------------
99//------------------------------------------------------------------------------
100 void set(const T d) {
101 memory.assign(memory.size(), d);
102 }
103
104//------------------------------------------------------------------------------
108//------------------------------------------------------------------------------
109 void set(const std::vector<T> &d) {
110 memory.assign(d.cbegin(), d.cend());
111 }
112
113//------------------------------------------------------------------------------
115//------------------------------------------------------------------------------
116 size_t size() const {
117 return memory.size();
118 }
119
120//------------------------------------------------------------------------------
124//------------------------------------------------------------------------------
125 bool is_same() const {
126 const T same = memory.at(0);
127 for (size_t i = 1, ie = memory.size(); i < ie; i++) {
128 if (memory.at(i) != same) {
129 return false;
130 }
131 }
132
133 return true;
134 }
135
136//------------------------------------------------------------------------------
140//------------------------------------------------------------------------------
141 bool is_zero() const {
142 for (const T &d : memory) {
143 if (d != static_cast<T> (0.0)) {
144 return false;
145 }
146 }
147
148 return true;
149 }
150
151//------------------------------------------------------------------------------
155//------------------------------------------------------------------------------
156 bool has_zero() const {
157 for (const T &d : memory) {
158 if (d == static_cast<T> (0.0)) {
159 return true;
160 }
161 }
162
163 return false;
164 }
165
166//------------------------------------------------------------------------------
170//------------------------------------------------------------------------------
171 bool is_negative() const {
172 for (const T &d : memory) {
173 if (std::real(d) > std::real(static_cast<T> (0.0))) {
174 return false;
175 }
176 }
177
178 return true;
179 }
180
181//------------------------------------------------------------------------------
185//------------------------------------------------------------------------------
186 bool is_even() const {
187 for (const T &d : memory) {
188 if (std::fmod(std::real(d), std::real(static_cast<T> (2.0)))) {
189 return false;
190 }
191 }
192
193 return true;
194 }
195
196//------------------------------------------------------------------------------
200//------------------------------------------------------------------------------
201 bool is_none() const {
202 for (const T &d : memory) {
203 if (d != static_cast<T> (-1.0)) {
204 return false;
205 }
206 }
207
208 return true;
209 }
210
211//------------------------------------------------------------------------------
213//------------------------------------------------------------------------------
214 void sqrt() {
215 for (T &d : memory) {
216 d = std::sqrt(d);
217 }
218 }
219
220//------------------------------------------------------------------------------
222//------------------------------------------------------------------------------
223 void exp() {
224 for (T &d : memory) {
225 d = std::exp(d);
226 }
227 }
228
229//------------------------------------------------------------------------------
231//------------------------------------------------------------------------------
232 void log() {
233 for (T &d : memory) {
234 d = std::log(d);
235 }
236 }
237
238//------------------------------------------------------------------------------
240//------------------------------------------------------------------------------
241 void sin() {
242 for (T &d : memory) {
243 d = std::sin(d);
244 }
245 }
246
247//------------------------------------------------------------------------------
249//------------------------------------------------------------------------------
250 void cos() {
251 for (T &d : memory) {
252 d = std::cos(d);
253 }
254 }
255
256//------------------------------------------------------------------------------
258//------------------------------------------------------------------------------
259 void erfi() requires(jit::complex_scalar<T>) {
260 for (T &d : memory) {
261 d = special::erfi(d);
262 }
263 }
264
265//------------------------------------------------------------------------------
269//------------------------------------------------------------------------------
270 T *data() {
271 return memory.data();
272 }
273
274//------------------------------------------------------------------------------
278//------------------------------------------------------------------------------
279 bool is_normal() const {
280 for (const T &x : memory) {
281 if constexpr (jit::complex_scalar<T>) {
282 if (std::isnan(std::real(x)) || std::isinf(std::real(x)) ||
283 std::isnan(std::imag(x)) || std::isinf(std::imag(x))) {
284 return false;
285 }
286 } else {
287 if (std::isnan(x) || std::isinf(x)) {
288 return false;
289 }
290 }
291 }
292 return true;
293 }
294
295//------------------------------------------------------------------------------
301//------------------------------------------------------------------------------
302 buffer<T> index_row(const size_t index, const size_t num_columns) {
303 buffer<T> b(num_columns);
304 const size_t num_rows = size()/num_columns;
305 for (size_t j = 0; j < num_columns; j++) {
306 b[j] = memory[index*num_rows + j];
307 }
308 return b;
309 }
310
311//------------------------------------------------------------------------------
317//------------------------------------------------------------------------------
318 buffer<T> index_column(const size_t index, const size_t num_columns) {
319 const size_t num_rows = size()/num_columns;
320 buffer<T> b(num_rows);
321 for (size_t i = 0; i < num_rows; i++) {
322 b[i] = memory[i*num_rows + index];
323 }
324 return b;
325 }
326
327//------------------------------------------------------------------------------
334//------------------------------------------------------------------------------
335 void add_row(const buffer<T> &x) {
336 if (size() > x.size()) {
337 assert(size()%x.size() == 0 &&
338 "Vector operand size is not a multiple of matrix operand size");
339
340 const size_t num_columns = size()/x.size();
341 const size_t num_rows = x.size();
342 for (size_t i = 0; i < num_rows; i++) {
343 for (size_t j = 0; j < num_columns; j++) {
344 memory[i*num_rows + j] += x[i];
345 }
346 }
347 } else {
348 assert(x.size()%size() == 0 &&
349 "Vector operand size is not a multiple of matrix operand size");
350
351 std::vector<T> m(x.size());
352 const size_t num_columns = x.size()/size();
353 const size_t num_rows = size();
354 for (size_t i = 0; i < num_rows; i++) {
355 for (size_t j = 0; j < num_columns; j++) {
356 m[i*num_columns + j] = memory[i] + x[i*num_columns + j];
357 }
358 }
359 memory = m;
360 }
361 }
362
363//------------------------------------------------------------------------------
370//------------------------------------------------------------------------------
371 void add_col(const buffer<T> &x) {
372 if (size() > x.size()) {
373 assert(size()%x.size() == 0 &&
374 "Vector operand size is not a multiple of matrix operand size");
375
376 const size_t num_columns = size()/x.size();
377 const size_t num_rows = x.size();
378 for (size_t i = 0; i < num_rows; i++) {
379 for (size_t j = 0; j < num_columns; j++) {
380 memory[i*num_columns + j] += x[j];
381 }
382 }
383 } else {
384 assert(x.size()%size() == 0 &&
385 "Vector operand size is not a multiple of matrix operand size");
386
387 std::vector<T> m(x.size());
388 const size_t num_columns = x.size()/size();
389 const size_t num_rows = size();
390 for (size_t i = 0; i < num_rows; i++) {
391 for (size_t j = 0; j < num_columns; j++) {
392 m[i*num_columns + j] = memory[j] + x[i*num_columns + j];
393 }
394 }
395 memory = m;
396 }
397 }
398
399//------------------------------------------------------------------------------
406//------------------------------------------------------------------------------
407 void subtract_row(const buffer<T> &x) {
408 if (size() > x.size()) {
409 assert(size()%x.size() == 0 &&
410 "Vector operand size is not a multiple of matrix operand size");
411
412 const size_t num_columns = size()/x.size();
413 const size_t num_rows = x.size();
414 for (size_t i = 0; i < num_rows; i++) {
415 for (size_t j = 0; j < num_columns; j++) {
416 memory[i*num_columns + j] -= x[i];
417 }
418 }
419 } else {
420 assert(x.size()%size() == 0 &&
421 "Vector operand size is not a multiple of matrix operand size");
422
423 std::vector<T> m(x.size());
424 const size_t num_columns = x.size()/size();
425 const size_t num_rows = size();
426 for (size_t i = 0; i < num_columns; i++) {
427 for (size_t j = 0; j < num_rows; j++) {
428 m[i*num_columns + j] = memory[i] - x[i*num_columns + j];
429 }
430 }
431 memory = m;
432 }
433 }
434
435//------------------------------------------------------------------------------
442//------------------------------------------------------------------------------
443 void subtract_col(const buffer<T> &x) {
444 if (size() > x.size()) {
445 assert(size()%x.size() == 0 &&
446 "Vector operand size is not a multiple of matrix operand size");
447
448 const size_t num_columns = size()/x.size();
449 const size_t num_rows = x.size();
450 for (size_t i = 0; i < num_rows; i++) {
451 for (size_t j = 0; j < num_columns; j++) {
452 memory[i*num_columns + j] -= x[j];
453 }
454 }
455 } else {
456 assert(x.size()%size() == 0 &&
457 "Vector operand size is not a multiple of matrix operand size");
458
459 std::vector<T> m(x.size());
460 const size_t num_columns = x.size()/size();
461 const size_t num_rows = size();
462 for (size_t i = 0; i < num_rows; i++) {
463 for (size_t j = 0; j < num_columns; j++) {
464 m[i*num_columns + j] = memory[j] - x[i*num_columns + j];
465 }
466 }
467 memory = m;
468 }
469 }
470
471//------------------------------------------------------------------------------
478//------------------------------------------------------------------------------
479 void multiply_row(const buffer<T> &x) {
480 if (size() > x.size()) {
481 assert(size()%x.size() == 0 &&
482 "Vector operand size is not a multiple of matrix operand size");
483
484 const size_t num_columns = size()/x.size();
485 const size_t num_rows = x.size();
486 for (size_t i = 0; i < num_rows; i++) {
487 for (size_t j = 0; j < num_columns; j++) {
488 memory[i*num_columns + j] *= x[i];
489 }
490 }
491 } else {
492 assert(x.size()%size() == 0 &&
493 "Vector operand size is not a multiple of matrix operand size");
494
495 std::vector<T> m(x.size());
496 const size_t num_columns = x.size()/size();
497 const size_t num_rows = size();
498 for (size_t i = 0; i < num_rows; i++) {
499 for (size_t j = 0; j < num_columns; j++) {
500 m[i*num_columns + j] = memory[i]*x[i*num_columns + j];
501 }
502 }
503 memory = m;
504 }
505 }
506
507//------------------------------------------------------------------------------
514//------------------------------------------------------------------------------
515 void multiply_col(const buffer<T> &x) {
516 if (size() > x.size()) {
517 assert(size()%x.size() == 0 &&
518 "Vector operand size is not a multiple of matrix operand size");
519
520 const size_t num_columns = size()/x.size();
521 const size_t num_rows = x.size();
522 for (size_t i = 0; i < num_rows; i++) {
523 for (size_t j = 0; j < num_columns; j++) {
524 memory[i*num_columns + j] *= x[j];
525 }
526 }
527 } else {
528 assert(x.size()%size() == 0 &&
529 "Vector operand size is not a multiple of matrix operand size");
530
531 std::vector<T> m(x.size());
532 const size_t num_columns = x.size()/size();
533 const size_t num_rows = size();
534 for (size_t i = 0; i < num_rows; i++) {
535 for (size_t j = 0; j < num_columns; j++) {
536 m[i*num_columns + j] = memory[j]*x[i*num_columns + j];
537 }
538 }
539 memory = m;
540 }
541 }
542
543//------------------------------------------------------------------------------
550//------------------------------------------------------------------------------
551 void divide_row(const buffer<T> &x) {
552 if (size() > x.size()) {
553 assert(size()%x.size() == 0 &&
554 "Vector operand size is not a multiple of matrix operand size");
555
556 const size_t num_columns = size()/x.size();
557 const size_t num_rows = x.size();
558 for (size_t i = 0; i < num_rows; i++) {
559 for (size_t j = 0; j < num_columns; j++) {
560 memory[i*num_columns + j] /= x[i];
561 }
562 }
563 } else {
564 assert(x.size()%size() == 0 &&
565 "Vector operand size is not a multiple of matrix operand size");
566
567 std::vector<T> m(x.size());
568 const size_t num_columns = x.size()/size();
569 const size_t num_rows = size();
570 for (size_t i = 0; i < num_rows; i++) {
571 for (size_t j = 0; j < num_columns; j++) {
572 m[i*num_columns + j] = memory[i]/x[i*num_columns + j];
573 }
574 }
575 memory = m;
576 }
577 }
578
579//------------------------------------------------------------------------------
586//------------------------------------------------------------------------------
587 void divide_col(const buffer<T> &x) {
588 if (size() > x.size()) {
589 assert(size()%x.size() == 0 &&
590 "Vector operand size is not a multiple of matrix operand size");
591
592 const size_t num_columns = size()/x.size();
593 const size_t num_rows = x.size();
594 for (size_t i = 0; i < num_rows; i++) {
595 for (size_t j = 0; j < num_columns; j++) {
596 memory[i*num_columns + j] /= x[j];
597 }
598 }
599 } else {
600 assert(x.size()%size() == 0 &&
601 "Vector operand size is not a multiple of matrix operand size");
602
603 std::vector<T> m(x.size());
604 const size_t num_columns = x.size()/size();
605 const size_t num_rows = size();
606 for (size_t i = 0; i < num_rows; i++) {
607 for (size_t j = 0; j < num_columns; j++) {
608 m[i*num_columns + j] = memory[j]/x[i*num_columns + j];
609 }
610 }
611 memory = m;
612 }
613 }
614
615//------------------------------------------------------------------------------
622//------------------------------------------------------------------------------
623 void atan_row(const buffer<T> &x) {
624 if (size() > x.size()) {
625 assert(size()%x.size() == 0 &&
626 "Vector operand size is not a multiple of matrix operand size");
627
628 const size_t num_columns = size()/x.size();
629 const size_t num_rows = x.size();
630 for (size_t i = 0; i < num_rows; i++) {
631 for (size_t j = 0; j < num_columns; j++) {
632 if constexpr (jit::complex_scalar<T>) {
633 memory[i*num_columns + j] = std::atan(x[i]/memory[i*num_columns + j]);
634 } else {
635 memory[i*num_columns + j] = std::atan2(x[i], memory[i*num_columns + j]);
636 }
637 }
638 }
639 } else {
640 assert(x.size()%size() == 0 &&
641 "Vector operand size is not a multiple of matrix operand size");
642
643 std::vector<T> m(x.size());
644 const size_t num_columns = x.size()/size();
645 const size_t num_rows = size();
646 for (size_t i = 0; i < num_rows; i++) {
647 for (size_t j = 0; j < num_columns; j++) {
648 if constexpr (jit::complex_scalar<T>) {
649 m[i*num_columns + j] = std::atan(x[i*num_columns + j]/memory[i]);
650 } else {
651 m[i*num_columns + j] = std::atan2(x[i*num_columns + j], memory[i]);
652 }
653 }
654 }
655 memory = m;
656 }
657 }
658
659//------------------------------------------------------------------------------
666//------------------------------------------------------------------------------
667 void atan_col(const buffer<T> &x) {
668 if (size() > x.size()) {
669 assert(size()%x.size() == 0 &&
670 "Vector operand size is not a multiple of matrix operand size");
671
672 const size_t num_columns = size()/x.size();
673 const size_t num_rows = x.size();
674 for (size_t i = 0; i < num_columns; i++) {
675 for (size_t j = 0; j < num_rows; j++) {
676 if constexpr (jit::complex_scalar<T>) {
677 memory[i*num_columns + j] = std::atan(x[j]/memory[i*num_columns + j]);
678 } else {
679 memory[i*num_columns + j] = std::atan2(x[j], memory[i*num_columns + j]);
680 }
681 }
682 }
683 } else {
684 assert(x.size()%size() == 0 &&
685 "Vector operand size is not a multiple of matrix operand size");
686
687 std::vector<T> m(x.size());
688 const size_t num_columns = x.size()/size();
689 const size_t num_rows = size();
690 for (size_t i = 0; i < num_rows; i++) {
691 for (size_t j = 0; j < num_columns; j++) {
692 if constexpr (jit::complex_scalar<T>) {
693 m[i*num_columns + j] = std::atan(x[i*num_columns + j]/memory[j]);
694 } else {
695 m[i*num_columns + j] = std::atan2(x[i*num_columns + j], memory[j]);
696 }
697 }
698 }
699 memory = m;
700 }
701 }
702
703//------------------------------------------------------------------------------
710//------------------------------------------------------------------------------
711 void pow_row(const buffer<T> &x) {
712 if (size() > x.size()) {
713 assert(size()%x.size() == 0 &&
714 "Vector operand size is not a multiple of matrix operand size");
715
716 const size_t num_columns = size()/x.size();
717 const size_t num_rows = x.size();
718 for (size_t i = 0; i < num_rows; i++) {
719 for (size_t j = 0; j < num_columns; j++) {
720 memory[i*num_columns + j] = std::pow(memory[i*num_columns + j], x[i]);
721 }
722 }
723 } else {
724 assert(x.size()%size() == 0 &&
725 "Vector operand size is not a multiple of matrix operand size");
726
727 std::vector<T> m(x.size());
728 const size_t num_columns = x.size()/size();
729 const size_t num_rows = size();
730 for (size_t i = 0; i < num_columns; i++) {
731 for (size_t j = 0; j < num_rows; j++) {
732 m[i*num_columns + j] = std::pow(memory[i], x[i*num_columns + j]);
733 }
734 }
735 memory = m;
736 }
737 }
738
739//------------------------------------------------------------------------------
746//------------------------------------------------------------------------------
747 void pow_col(const buffer<T> &x) {
748 if (size() > x.size()) {
749 assert(size()%x.size() == 0 &&
750 "Vector operand size is not a multiple of matrix operand size");
751
752 const size_t num_columns = size()/x.size();
753 const size_t num_rows = x.size();
754 for (size_t i = 0; i < num_rows; i++) {
755 for (size_t j = 0; j < num_columns; j++) {
756 memory[i*num_columns + j] = std::pow(memory[i*num_columns + j], x[j]);
757 }
758 }
759 } else {
760 assert(x.size()%size() == 0 &&
761 "Vector operand size is not a multiple of matrix operand size");
762
763 std::vector<T> m(x.size());
764 const size_t num_columns = x.size()/size();
765 const size_t num_rows = size();
766 for (size_t i = 0; i < num_rows; i++) {
767 for (size_t j = 0; j < num_columns; j++) {
768 m[i*num_columns + j] = std::pow(memory[j], x[i*num_columns + j]);
769 }
770 }
771 memory = m;
772 }
773 }
774
776 typedef T base;
777 };
778
779//------------------------------------------------------------------------------
787//------------------------------------------------------------------------------
788 template<jit::float_scalar T>
790 buffer<T> &b) {
791 if (b.size() == 1) {
792 const T right = b.at(0);
793 for (size_t i = 0, ie = a.size(); i < ie; i++) {
794 a[i] += right;
795 }
796 return a;
797 } else if (a.size() == 1) {
798 const T left = a.at(0);
799 for (size_t i = 0, ie = b.size(); i < ie; i++) {
800 b[i] += left;
801 }
802 return b;
803 }
804
805 assert(a.size() == b.size() &&
806 "Left and right sizes are incompatible.");
807 for (size_t i = 0, ie = a.size(); i < ie; i++) {
808 a[i] += b.at(i);
809 }
810 return a;
811 }
812
813//------------------------------------------------------------------------------
821//------------------------------------------------------------------------------
822 template<jit::float_scalar T>
823 inline bool operator==(const buffer<T> &a,
824 const buffer<T> &b) {
825 if (a.size() != b.size()) {
826 return false;
827 }
828
829 for (size_t i = 0, ie = a.size(); i < ie; i++) {
830 if (a.at(i) != b.at(i)) {
831 return false;
832 }
833 }
834 return true;
835 }
836
837//------------------------------------------------------------------------------
845//------------------------------------------------------------------------------
846 template<jit::float_scalar T>
848 buffer<T> &b) {
849 if (b.size() == 1) {
850 const T right = b.at(0);
851 for (size_t i = 0, ie = a.size(); i < ie; i++) {
852 a[i] -= right;
853 }
854 return a;
855 } else if (a.size() == 1) {
856 const T left = a.at(0);
857 for (size_t i = 0, ie = b.size(); i < ie; i++) {
858 b[i] = left - b.at(i);
859 }
860 return b;
861 }
862
863 assert(a.size() == b.size() &&
864 "Left and right sizes are incompatible.");
865 for (size_t i = 0, ie = a.size(); i < ie; i++) {
866 a[i] -= b.at(i);
867 }
868 return a;
869 }
870
871//------------------------------------------------------------------------------
879//------------------------------------------------------------------------------
880 template<jit::float_scalar T>
882 buffer<T> &b) {
883 if (b.size() == 1) {
884 const T right = b.at(0);
885 for (size_t i = 0, ie = a.size(); i < ie; i++) {
886 a[i] *= right;
887 }
888 return a;
889 } else if (a.size() == 1) {
890 const T left = a.at(0);
891 for (size_t i = 0, ie = b.size(); i < ie; i++) {
892 b[i] *= left;
893 }
894 return b;
895 }
896
897 assert(a.size() == b.size() &&
898 "Left and right sizes are incompatible.");
899 for (size_t i = 0, ie = a.size(); i < ie; i++) {
900 a[i] *= b.at(i);
901 }
902 return a;
903 }
904
905//------------------------------------------------------------------------------
913//------------------------------------------------------------------------------
914 template<jit::float_scalar T>
916 buffer<T> &b) {
917 if (b.size() == 1) {
918 const T right = b.at(0);
919 for (size_t i = 0, ie = a.size(); i < ie; i++) {
920 a[i] /= right;
921 }
922 return a;
923 } else if (a.size() == 1) {
924 const T left = a.at(0);
925 for (size_t i = 0, ie = b.size(); i < ie; i++) {
926 b[i] = left/b.at(i);
927 }
928 return b;
929 }
930
931 assert(a.size() == b.size() &&
932 "Left and right sizes are incompatible.");
933 for (size_t i = 0, ie = a.size(); i < ie; i++) {
934 a[i] /= b.at(i);
935 }
936 return a;
937 }
938
939//------------------------------------------------------------------------------
948//------------------------------------------------------------------------------
949 template<jit::float_scalar T>
951 buffer<T> &b,
952 buffer<T> &c) {
953 constexpr bool use_fma = !jit::complex_scalar<T> &&
954#ifdef FP_FAST_FMA
955 true;
956#else
957 false;
958#endif
959
960 if (a.size() == 1) {
961 const T left = a.at(0);
962
963 if (b.size() == 1) {
964 const T middle = b.at(0);
965 for (size_t i = 0, ie = c.size(); i < ie; i++) {
966 if constexpr (use_fma) {
967 c[i] = std::fma(left, middle, c.at(i));
968 } else {
969 c[i] = left*middle + c.at(i);
970 }
971 }
972 return c;
973 } else if (c.size() == 1) {
974 const T right = c.at(0);
975 for (size_t i = 0, ie = b.size(); i < ie; i++) {
976 if constexpr (use_fma) {
977 b[i] = std::fma(left, b.at(i), right);
978 } else {
979 b[i] = left*b.at(i) + right;
980 }
981 }
982 return b;
983 }
984
985 assert(b.size() == c.size() &&
986 "Size mismatch between middle and right.");
987 for (size_t i = 0, ie = b.size(); i < ie; i++) {
988 if constexpr (use_fma) {
989 b[i] = std::fma(left, b.at(i), c.at(i));
990 } else {
991 b[i] = left*b.at(i) + c.at(i);
992 }
993 }
994 return b;
995 } else if (b.size() == 1) {
996 const T middle = b.at(0);
997 if (c.size() == 1) {
998 const T right = c.at(0);
999 for (size_t i = 0, ie = a.size(); i < ie; i++) {
1000 if constexpr (use_fma) {
1001 a[i] = std::fma(a.at(i), middle, right);
1002 } else {
1003 a[i] = a.at(i)*middle + right;
1004 }
1005 }
1006 return a;
1007 }
1008
1009 assert(a.size() == c.size() &&
1010 "Size mismatch between left and right.");
1011 for (size_t i = 0, ie = a.size(); i < ie; i++) {
1012 if constexpr (use_fma) {
1013 a[i] = std::fma(a.at(i), middle, c.at(i));
1014 } else {
1015 a[i] = a.at(i)*middle + c.at(i);
1016 }
1017 }
1018 return a;
1019 } else if (c.size() == 1) {
1020 assert(a.size() == b.size() &&
1021 "Size mismatch between left and middle.");
1022 const T right = c.at(0);
1023 for (size_t i = 0, ie = a.size(); i < ie; i++) {
1024 if constexpr (use_fma) {
1025 a[i] = std::fma(a.at(i), b.at(i), right);
1026 } else {
1027 a[i] = a.at(i)*b.at(i) + right;
1028 }
1029 }
1030 return a;
1031 }
1032
1033 assert(a.size() == b.size() &&
1034 b.size() == c.size() &&
1035 a.size() == c.size() &&
1036 "Left, middle and right sizes are incompatible.");
1037 for (size_t i = 0, ie = a.size(); i < ie; i++) {
1038 if constexpr (use_fma) {
1039 a[i] = std::fma(a.at(i), b.at(i), c.at(i));
1040 } else {
1041 a[i] = a.at(i)*b.at(i) + c.at(i);
1042 }
1043 }
1044 return a;
1045 }
1046
1047//------------------------------------------------------------------------------
1055//------------------------------------------------------------------------------
1056 template<jit::float_scalar T>
1058 buffer<T> &exponent) {
1059 if (exponent.size() == 1) {
1060 const T right = exponent.at(0);
1061 if (std::imag(right) == 0) {
1062 const int64_t right_int = static_cast<int64_t> (std::real(right));
1063 if (std::real(right) - right_int) {
1064 if (right == static_cast<T> (0.5)) {
1065 base.sqrt();
1066 return base;
1067 }
1068
1069 for (size_t i = 0, ie = base.size(); i < ie; i++) {
1070 base[i] = std::pow(base.at(i), right);
1071 }
1072 return base;
1073 }
1074
1075 if (right_int > 0) {
1076 for (size_t i = 0, ie = base.size(); i < ie; i++) {
1077 const T left = base.at(i);
1078 for (size_t j = 0, je = right_int - 1; j < je; j++) {
1079 base[i] *= left;
1080 }
1081 }
1082 return base;
1083 } else if (right_int == 0) {
1084 for (size_t i = 0, ie = base.size(); i < ie; i++) {
1085 base[i] = 1.0;
1086 }
1087 return base;
1088 } else {
1089 for (size_t i = 0, ie = base.size(); i < ie; i++) {
1090 const T left = static_cast<T> (1.0)/base.at(i);
1091 base[i] = left;
1092 for (size_t j = 0, je = std::abs(right_int) - 1; j < je; j++) {
1093 base[i] *= left;
1094 }
1095 }
1096 return base;
1097 }
1098 } else {
1099 for (size_t i = 0, ie = base.size(); i < ie; i++) {
1100 base[i] = std::pow(base.at(i), right);
1101 }
1102 return base;
1103 }
1104 } else if (base.size() == 1) {
1105 const T left = base.at(0);
1106 for (size_t i = 0, ie = exponent.size(); i < ie; i++) {
1107 exponent[i] = std::pow(left, exponent.at(i));
1108 }
1109 return exponent;
1110 }
1111
1112 assert(base.size() == exponent.size() &&
1113 "Left and right sizes are incompatible.");
1114 for (size_t i = 0, ie = base.size(); i < ie; i++) {
1115 base[i] = std::pow(base.at(i), exponent.at(i));
1116 }
1117 return base;
1118 }
1119
1120//------------------------------------------------------------------------------
1128//------------------------------------------------------------------------------
1129 template<jit::float_scalar T>
1131 buffer<T> &y) {
1132 if (y.size() == 1) {
1133 const T right = y.at(0);
1134 for (size_t i = 0, ie = x.size(); i < ie; i++) {
1135 if constexpr (jit::complex_scalar<T>) {
1136 x[i] = std::atan(right/x[i]);
1137 } else {
1138 x[i] = std::atan2(right, x[i]);
1139 }
1140 }
1141 return x;
1142 } else if (x.size() == 1) {
1143 const T left = x.at(0);
1144 for (size_t i = 0, ie = y.size(); i < ie; i++) {
1145 if constexpr (jit::complex_scalar<T>) {
1146 y[i] = std::atan(y[i]/left);
1147 } else {
1148 y[i] = std::atan2(y[i], left);
1149 }
1150 }
1151 return y;
1152 }
1153
1154 assert(x.size() == y.size() &&
1155 "Left and right sizes are incompatible.");
1156 for (size_t i = 0, ie = x.size(); i < ie; i++) {
1157 if constexpr (jit::complex_scalar<T>) {
1158 x[i] = std::atan(y[i]/x[i]);
1159 } else {
1160 x[i] = std::atan2(y[i], x[i]);
1161 }
1162 }
1163 return x;
1164 }
1165}
1166
1167#endif /* backend_h */
Class representing a generic buffer.
Definition backend.hpp:29
void set(const std::vector< T > &d)
Assign a vector value.
Definition backend.hpp:109
void subtract_row(const buffer< T > &x)
Subtract row operation.
Definition backend.hpp:407
void multiply_row(const buffer< T > &x)
Multiply row operation.
Definition backend.hpp:479
void add_col(const buffer< T > &x)
Add col operation.
Definition backend.hpp:371
void erfi()
Take erfi.
Definition backend.hpp:259
void log()
Take log.
Definition backend.hpp:232
buffer(const buffer &d)
Construct a buffer backend from a buffer backend.
Definition backend.hpp:71
void sqrt()
Take sqrt.
Definition backend.hpp:214
buffer(const size_t s)
Construct a buffer backend with a size.
Definition backend.hpp:46
T & operator[](const size_t index)
Index operator.
Definition backend.hpp:77
bool is_normal() const
Check for normal values.
Definition backend.hpp:279
buffer()
Construct an empty buffer backend.
Definition backend.hpp:38
buffer< T > index_row(const size_t index, const size_t num_columns)
Index row.
Definition backend.hpp:302
bool is_negative() const
Is every element non-positive (no element has a positive real part).
Definition backend.hpp:171
const T at(const size_t index) const
Get the value at an index (bounds checked).
Definition backend.hpp:91
void pow_col(const buffer< T > &x)
Pow col operation.
Definition backend.hpp:747
buffer(const std::vector< T > &d)
Construct a buffer backend from a vector.
Definition backend.hpp:63
bool is_same() const
Is every element the same.
Definition backend.hpp:125
buffer(const size_t s, const T d)
Construct a buffer backend with a size.
Definition backend.hpp:55
void divide_col(const buffer< T > &x)
Divide col operation.
Definition backend.hpp:587
bool is_none() const
Is every element negative one.
Definition backend.hpp:201
void pow_row(const buffer< T > &x)
Pow row operation.
Definition backend.hpp:711
void sin()
Take sin.
Definition backend.hpp:241
void add_row(const buffer< T > &x)
Add row operation.
Definition backend.hpp:335
void subtract_col(const buffer< T > &x)
Subtract col operation.
Definition backend.hpp:443
void multiply_col(const buffer< T > &x)
Multiply col operation.
Definition backend.hpp:515
bool is_even() const
Is every element even.
Definition backend.hpp:186
T base
Typedef to retrieve the backend element type T.
Definition backend.hpp:776
void divide_row(const buffer< T > &x)
Divide row operation.
Definition backend.hpp:551
void atan_col(const buffer< T > &x)
Atan col operation.
Definition backend.hpp:667
void atan_row(const buffer< T > &x)
Atan row operation.
Definition backend.hpp:623
buffer< T > index_column(const size_t index, const size_t num_columns)
Index column.
Definition backend.hpp:318
void exp()
Take exp.
Definition backend.hpp:223
bool has_zero() const
Is any element zero.
Definition backend.hpp:156
bool is_zero() const
Is every element zero.
Definition backend.hpp:141
size_t size() const
Get size of the buffer.
Definition backend.hpp:116
void set(const T d)
Assign a constant value.
Definition backend.hpp:100
T * data()
Get a pointer to the basic memory buffer.
Definition backend.hpp:270
void cos()
Take cos.
Definition backend.hpp:250
Complex scalar concept.
Definition register.hpp:24
subroutine assert(test, message)
Assert check.
Definition f_binding_test.f90:38
Name space for backend buffers.
Definition backend.hpp:19
buffer< T > fma(buffer< T > &a, buffer< T > &b, buffer< T > &c)
Fused multiply add operation.
Definition backend.hpp:950
buffer< T > operator*(buffer< T > &a, buffer< T > &b)
Multiply operation.
Definition backend.hpp:881
bool operator==(const buffer< T > &a, const buffer< T > &b)
Equal operation.
Definition backend.hpp:823
buffer< T > atan(buffer< T > &x, buffer< T > &y)
Take the inverse tangent.
Definition backend.hpp:1130
buffer< T > operator/(buffer< T > &a, buffer< T > &b)
Divide operation.
Definition backend.hpp:915
buffer< T > operator-(buffer< T > &a, buffer< T > &b)
Subtract operation.
Definition backend.hpp:847
buffer< T > operator+(buffer< T > &a, buffer< T > &b)
Add operation.
Definition backend.hpp:789
buffer< T > pow(buffer< T > &base, buffer< T > &exponent)
Take the power.
Definition backend.hpp:1057
complex_type< T > erfi(const complex_type< T > z)
erfi(z) = -i erf(iz)
Definition special_functions.hpp:1583
Utilities for writing jit source code.
Implementations for special functions.