Graph Framework
Loading...
Searching...
No Matches
math.hpp
Go to the documentation of this file.
1//------------------------------------------------------------------------------
4//------------------------------------------------------------------------------
5
6#ifndef math_h
7#define math_h
8
9#include <cmath>
10
11#include "node.hpp"
12
13namespace graph {
14//******************************************************************************
15// Sqrt node.
16//******************************************************************************
17//------------------------------------------------------------------------------
24//------------------------------------------------------------------------------
25 template<jit::float_scalar T, bool SAFE_MATH=false>
26 class sqrt_node final : public straight_node<T, SAFE_MATH> {
27 private:
28//------------------------------------------------------------------------------
33//------------------------------------------------------------------------------
34 static std::string to_string(leaf_node<T, SAFE_MATH> *a) {
35 return "sqrt" + jit::format_to_string(reinterpret_cast<size_t> (a));
36 }
37
38 public:
39//------------------------------------------------------------------------------
43//------------------------------------------------------------------------------
46
47//------------------------------------------------------------------------------
53//------------------------------------------------------------------------------
55 backend::buffer<T> result = this->arg->evaluate();
56 result.sqrt();
57 return result;
58 }
59
60//------------------------------------------------------------------------------
64//------------------------------------------------------------------------------
66 auto ac = constant_cast(this->arg);
67
68 if (ac.get()) {
69 if (ac->is(0) || ac->is(1)) {
70 return this->arg;
71 }
72 return constant<T, SAFE_MATH> (this->evaluate());
73 }
74
75 auto ap1 = piecewise_1D_cast(this->arg);
76 if (ap1.get()) {
77 return piecewise_1D(this->evaluate(),
78 ap1->get_arg(),
79 ap1->get_scale(),
80 ap1->get_offset());
81 }
82
83 auto ap2 = piecewise_2D_cast(this->arg);
84 if (ap2.get()) {
85 return piecewise_2D(this->evaluate(),
86 ap2->get_num_columns(),
87 ap2->get_left(), ap2->get_x_scale(), ap2->get_x_offset(),
88 ap2->get_right(), ap2->get_y_scale(), ap2->get_y_offset());
89 }
90
91// Handle cases like sqrt(c*x) where c is constant or cases like
92// sqrt((x^a)*y).
93 auto am = multiply_cast(this->arg);
94 if (am.get()) {
95 if (pow_cast(am->get_left()).get() ||
96 am->get_left()->is_constant() ||
97 pow_cast(am->get_right()).get() ||
98 am->get_right()->is_constant()) {
99 return sqrt(am->get_left()) *
100 sqrt(am->get_right());
101 }
102 }
103
104 auto ad = divide_cast(this->arg);
105 if (ad.get()) {
106// sqrt((c1*x)/y) -> c2*sqrt(x/y)
107 auto alm = multiply_cast(ad->get_left());
108 if (alm.get() && alm->get_left()->is_constant()) {
109 return sqrt(alm->get_left()) *
110 sqrt(alm->get_right()/ad->get_right());
111 }
112
113// Handle cases like sqrt(x^a/b) and sqrt(a/x^b) or sqrt(c/b) and sqrt(a/c)
114// where c is a constant.
115 if (pow_cast(ad->get_left()).get() ||
116 ad->get_left()->is_constant() ||
117 pow_cast(ad->get_right()).get() ||
118 ad->get_right()->is_constant()) {
119 return sqrt(ad->get_left()) /
120 sqrt(ad->get_right());
121 }
122 }
123
124 return this->shared_from_this();
125 }
126
127//------------------------------------------------------------------------------
134//------------------------------------------------------------------------------
136 if (this->is_match(x)) {
137 return one<T, SAFE_MATH> ();
138 }
139
140 const size_t hash = reinterpret_cast<size_t> (x.get());
141 if (this->df_cache.find(hash) == this->df_cache.end()) {
142 this->df_cache[hash] = this->arg->df(x)
143 / (2.0*this->shared_from_this());
144 }
145 return this->df_cache[hash];
146 }
147
148//------------------------------------------------------------------------------
156//------------------------------------------------------------------------------
158 compile(std::ostringstream &stream,
159 jit::register_map &registers,
161 const jit::register_usage &usage) {
162 if (registers.find(this) == registers.end()) {
163 shared_leaf<T, SAFE_MATH> a = this->arg->compile(stream,
164 registers,
165 indices,
166 usage);
167
168 registers[this] = jit::to_string('r', this);
169 stream << " const ";
170 jit::add_type<T> (stream);
171 stream << " " << registers[this] << " = sqrt("
172 << registers[a.get()] << ")";
173 this->endline(stream, usage);
174 }
175
176 return this->shared_from_this();
177 }
178
179//------------------------------------------------------------------------------
184//------------------------------------------------------------------------------
186 if (this == x.get()) {
187 return true;
188 }
189
190 auto x_cast = sqrt_cast(x);
191 if (x_cast.get()) {
192 return this->arg->is_match(x_cast->get_arg());
193 }
194
195 return false;
196 }
197
198//------------------------------------------------------------------------------
200//------------------------------------------------------------------------------
201 virtual void to_latex() const {
202 std::cout << "\\sqrt{";
203 this->arg->to_latex();
204 std::cout << "}";
205 }
206
207//------------------------------------------------------------------------------
211//------------------------------------------------------------------------------
///  @brief Report whether this node can be treated like a power.
///
///  @returns Always true for a sqrt node.
212 virtual bool is_power_like() const {
213 return true;
214 }
215
216//------------------------------------------------------------------------------
220//------------------------------------------------------------------------------
222 return this->arg;
223 }
224
225//------------------------------------------------------------------------------
229//------------------------------------------------------------------------------
231 return constant<T, SAFE_MATH> (static_cast<T> (0.5));
232 }
233
234//------------------------------------------------------------------------------
238//------------------------------------------------------------------------------
240 if (this->has_pseudo()) {
241 return sqrt(this->arg->remove_pseudo());
242 }
243 return this->shared_from_this();
244 }
245
246//------------------------------------------------------------------------------
252//------------------------------------------------------------------------------
253 virtual shared_leaf<T, SAFE_MATH> to_vizgraph(std::stringstream &stream,
254 jit::register_map &registers) {
255 if (registers.find(this) == registers.end()) {
256 const std::string name = jit::to_string('r', this);
257 registers[this] = name;
258 stream << " " << name
259 << " [label = \"sqrt\", shape = oval, style = filled, fillcolor = blue, fontcolor = white];" << std::endl;
260
261 auto a = this->arg->to_vizgraph(stream, registers);
262 stream << " " << name << " -- " << registers[a.get()] << ";" << std::endl;
263 }
264
265 return this->shared_from_this();
266 }
267 };
268
269//------------------------------------------------------------------------------
277//------------------------------------------------------------------------------
278 template<jit::float_scalar T, bool SAFE_MATH=false>
280 auto temp = std::make_shared<sqrt_node<T, SAFE_MATH>> (x)->reduce();
281// Test for hash collisions.
282 for (size_t i = temp->get_hash();
284 if (leaf_node<T, SAFE_MATH>::caches.nodes.find(i) ==
287 return temp;
288 } else if (temp->is_match(leaf_node<T, SAFE_MATH>::caches.nodes[i])) {
290 }
291 }
292#if defined(__clang__) || defined(__GNUC__)
294#else
295 assert(false && "Should never reach.");
296#endif
297 }
298
///  Convenience alias for a shared pointer to a sqrt_node.
300 template<jit::float_scalar T, bool SAFE_MATH=false>
301 using shared_sqrt = std::shared_ptr<sqrt_node<T, SAFE_MATH>>;
302
303//------------------------------------------------------------------------------
311//------------------------------------------------------------------------------
312 template<jit::float_scalar T, bool SAFE_MATH=false>
314 return std::dynamic_pointer_cast<sqrt_node<T, SAFE_MATH>> (x);
315 }
316
317//******************************************************************************
318// Exp node.
319//******************************************************************************
320//------------------------------------------------------------------------------
327//------------------------------------------------------------------------------
328 template<jit::float_scalar T, bool SAFE_MATH=false>
329 class exp_node final : public straight_node<T, SAFE_MATH> {
330 private:
331//------------------------------------------------------------------------------
336//------------------------------------------------------------------------------
337 static std::string to_string(leaf_node<T, SAFE_MATH> *a) {
338 return "exp" + jit::format_to_string(reinterpret_cast<size_t> (a));
339 }
340
341 public:
342//------------------------------------------------------------------------------
346//------------------------------------------------------------------------------
349
350//------------------------------------------------------------------------------
356//------------------------------------------------------------------------------
358 backend::buffer<T> result = this->arg->evaluate();
359 result.exp();
360 return result;
361 }
362
363//------------------------------------------------------------------------------
367//------------------------------------------------------------------------------
369 if (constant_cast(this->arg).get()) {
370 return constant<T, SAFE_MATH> (this->evaluate());
371 }
372
373 auto ap1 = piecewise_1D_cast(this->arg);
374 if (ap1.get()) {
375 return piecewise_1D(this->evaluate(),
376 ap1->get_arg(),
377 ap1->get_scale(),
378 ap1->get_offset());
379 }
380
381 auto ap2 = piecewise_2D_cast(this->arg);
382 if (ap2.get()) {
383 return piecewise_2D(this->evaluate(),
384 ap2->get_num_columns(),
385 ap2->get_left(), ap2->get_x_scale(), ap2->get_x_offset(),
386 ap2->get_right(), ap2->get_y_scale(), ap2->get_y_offset());
387 }
388
389// Reduce exp(log(x)) -> x
390 auto a = log_cast(this->arg);
391 if (a.get()) {
392 return a->get_arg();
393 }
394
395 return this->shared_from_this();
396 }
397
398//------------------------------------------------------------------------------
405//------------------------------------------------------------------------------
407 if (this->is_match(x)) {
408 return one<T, SAFE_MATH> ();
409 }
410
411 const size_t hash = reinterpret_cast<size_t> (x.get());
412 if (this->df_cache.find(hash) == this->df_cache.end()) {
413 this->df_cache[hash] = this->shared_from_this()*this->arg->df(x);
414 }
415 return this->df_cache[hash];
416 }
417
418//------------------------------------------------------------------------------
426//------------------------------------------------------------------------------
428 compile(std::ostringstream &stream,
429 jit::register_map &registers,
431 const jit::register_usage &usage) {
432 if (registers.find(this) == registers.end()) {
433 shared_leaf<T, SAFE_MATH> a = this->arg->compile(stream,
434 registers,
435 indices,
436 usage);
437
438 registers[this] = jit::to_string('r', this);
439 stream << " const ";
440 jit::add_type<T> (stream);
441 stream << " " << registers[this] << " = ";
442 if constexpr (SAFE_MATH) {
443 if constexpr (jit::complex_scalar<T>) {
444 stream << "real(";
445 }
446 stream << registers[a.get()];
447 if constexpr (jit::complex_scalar<T>) {
448 stream << ")";
449 }
450 stream << " < 709.8 ? ";
451 }
452 stream << "exp(" << registers[a.get()] << ")";
453 if constexpr (SAFE_MATH) {
454 stream << " : ";
455 if constexpr (jit::complex_scalar<T>) {
456 jit::add_type<T> (stream);
457 stream << "(";
458 }
460 if constexpr (jit::complex_scalar<T>) {
461 stream << ")";
462 }
463 }
464 stream << "";
465 this->endline(stream, usage);
466 }
467
468 return this->shared_from_this();
469 }
470
471//------------------------------------------------------------------------------
476//------------------------------------------------------------------------------
478 if (this == x.get()) {
479 return true;
480 }
481
482 auto x_cast = exp_cast(x);
483 if (x_cast.get()) {
484 return this->arg->is_match(x_cast->get_arg());
485 }
486
487 return false;
488 }
489
490//------------------------------------------------------------------------------
492//------------------------------------------------------------------------------
493 virtual void to_latex() const {
494 std::cout << "e^{\\left(";
495 this->arg->to_latex();
496 std::cout << "\\right)}";
497 }
498
499//------------------------------------------------------------------------------
503//------------------------------------------------------------------------------
505 if (this->has_pseudo()) {
506 return exp(this->arg->remove_pseudo());
507 }
508 return this->shared_from_this();
509 }
510
511//------------------------------------------------------------------------------
517//------------------------------------------------------------------------------
518 virtual shared_leaf<T, SAFE_MATH> to_vizgraph(std::stringstream &stream,
519 jit::register_map &registers) {
520 if (registers.find(this) == registers.end()) {
521 const std::string name = jit::to_string('r', this);
522 registers[this] = name;
523 stream << " " << name
524 << " [label = \"exp\", shape = oval, style = filled, fillcolor = blue, fontcolor = white];" << std::endl;
525
526 auto a = this->arg->to_vizgraph(stream, registers);
527 stream << " " << name << " -- " << registers[a.get()] << ";" << std::endl;
528 }
529
530 return this->shared_from_this();
531 }
532 };
533
534//------------------------------------------------------------------------------
542//------------------------------------------------------------------------------
543 template<jit::float_scalar T, bool SAFE_MATH=false>
545 auto temp = std::make_shared<exp_node<T, SAFE_MATH>> (x)->reduce();
546// Test for hash collisions.
547 for (size_t i = temp->get_hash();
549 if (leaf_node<T, SAFE_MATH>::caches.nodes.find(i) ==
552 return temp;
553 } else if (temp->is_match(leaf_node<T, SAFE_MATH>::caches.nodes[i])) {
555 }
556 }
557#if defined(__clang__) || defined(__GNUC__)
559#else
560 assert(false && "Should never reach.");
561#endif
562 }
563
///  Convenience alias for a shared pointer to an exp_node.
565 template<jit::float_scalar T, bool SAFE_MATH=false>
566 using shared_exp = std::shared_ptr<exp_node<T, SAFE_MATH>>;
567
568//------------------------------------------------------------------------------
576//------------------------------------------------------------------------------
577 template<jit::float_scalar T, bool SAFE_MATH=false>
579 return std::dynamic_pointer_cast<exp_node<T, SAFE_MATH>> (x);
580 }
581
582//******************************************************************************
583// Log node.
584//******************************************************************************
585//------------------------------------------------------------------------------
592//------------------------------------------------------------------------------
593 template<jit::float_scalar T, bool SAFE_MATH=false>
594 class log_node final : public straight_node<T, SAFE_MATH> {
595 private:
596//------------------------------------------------------------------------------
601//------------------------------------------------------------------------------
602 static std::string to_string(leaf_node<T, SAFE_MATH> *a) {
603 return "log" + jit::format_to_string(reinterpret_cast<size_t> (a));
604 }
605
606 public:
607//------------------------------------------------------------------------------
611//------------------------------------------------------------------------------
614
615//------------------------------------------------------------------------------
621//------------------------------------------------------------------------------
623 backend::buffer<T> result = this->arg->evaluate();
624 result.log();
625 return result;
626 }
627
628//------------------------------------------------------------------------------
632//------------------------------------------------------------------------------
634 if (constant_cast(this->arg).get()) {
635 return constant<T, SAFE_MATH> (this->evaluate());
636 }
637
638 auto ap1 = piecewise_1D_cast(this->arg);
639 if (ap1.get()) {
640 return piecewise_1D(this->evaluate(),
641 ap1->get_arg(),
642 ap1->get_scale(),
643 ap1->get_offset());
644 }
645
646 auto ap2 = piecewise_2D_cast(this->arg);
647 if (ap2.get()) {
648 return piecewise_2D(this->evaluate(),
649 ap2->get_num_columns(),
650 ap2->get_left(), ap2->get_x_scale(), ap2->get_x_offset(),
651 ap2->get_right(), ap2->get_y_scale(), ap2->get_y_offset());
652 }
653
654// Reduce log(exp(x)) -> x
655 auto a = exp_cast(this->arg);
656 if (a.get()) {
657 return a->get_arg();
658 }
659
660 return this->shared_from_this();
661 }
662
663//------------------------------------------------------------------------------
670//------------------------------------------------------------------------------
672 if (this->is_match(x)) {
673 return one<T, SAFE_MATH> ();
674 }
675
676 const size_t hash = reinterpret_cast<size_t> (x.get());
677 if (this->df_cache.find(hash) == this->df_cache.end()) {
678 this->df_cache[hash] = this->arg->df(x)/this->arg;
679 }
680 return this->df_cache[hash];
681 }
682
683//------------------------------------------------------------------------------
691//------------------------------------------------------------------------------
693 compile(std::ostringstream &stream,
694 jit::register_map &registers,
696 const jit::register_usage &usage) {
697 if (registers.find(this) == registers.end()) {
698 shared_leaf<T, SAFE_MATH> a = this->arg->compile(stream,
699 registers,
700 indices,
701 usage);
702
703 registers[this] = jit::to_string('r', this);
704 stream << " const ";
705 jit::add_type<T> (stream);
706 stream << " " << registers[this] << " = log("
707 << registers[a.get()] << ")";
708 this->endline(stream, usage);
709 }
710
711 return this->shared_from_this();
712 }
713
714//------------------------------------------------------------------------------
719//------------------------------------------------------------------------------
721 if (this == x.get()) {
722 return true;
723 }
724
725 auto x_cast = log_cast(x);
726 if (x_cast.get()) {
727 return this->arg->is_match(x_cast->get_arg());
728 }
729
730 return false;
731 }
732
733//------------------------------------------------------------------------------
735//------------------------------------------------------------------------------
736 virtual void to_latex() const {
737 std::cout << "\\ln{\\left(";
738 this->arg->to_latex();
739 std::cout << "\\right)}";
740 }
741
742//------------------------------------------------------------------------------
746//------------------------------------------------------------------------------
748 if (this->has_pseudo()) {
749 return log(this->arg->remove_pseudo());
750 }
751 return this->shared_from_this();
752 }
753
754//------------------------------------------------------------------------------
760//------------------------------------------------------------------------------
761 virtual shared_leaf<T, SAFE_MATH> to_vizgraph(std::stringstream &stream,
762 jit::register_map &registers) {
763 if (registers.find(this) == registers.end()) {
764 const std::string name = jit::to_string('r', this);
765 registers[this] = name;
766 stream << " " << name
767 << " [label = \"log\", shape = oval, style = filled, fillcolor = blue, fontcolor = white];" << std::endl;
768
769 auto a = this->arg->to_vizgraph(stream, registers);
770 stream << " " << name << " -- " << registers[a.get()] << ";" << std::endl;
771 }
772
773 return this->shared_from_this();
774 }
775 };
776
777//------------------------------------------------------------------------------
785//------------------------------------------------------------------------------
786 template<jit::float_scalar T, bool SAFE_MATH=false>
788 auto temp = std::make_shared<log_node<T, SAFE_MATH>> (x)->reduce();
789// Test for hash collisions.
790 for (size_t i = temp->get_hash(); i < std::numeric_limits<size_t>::max(); i++) {
791 if (leaf_node<T, SAFE_MATH>::caches.nodes.find(i) ==
794 return temp;
795 } else if (temp->is_match(leaf_node<T, SAFE_MATH>::caches.nodes[i])) {
797 }
798 }
799#if defined(__clang__) || defined(__GNUC__)
801#else
802 assert(false && "Should never reach.");
803#endif
804 }
805
///  Convenience alias for a shared pointer to a log_node.
807 template<jit::float_scalar T, bool SAFE_MATH=false>
808 using shared_log = std::shared_ptr<log_node<T, SAFE_MATH>>;
809
810//------------------------------------------------------------------------------
818//------------------------------------------------------------------------------
819 template<jit::float_scalar T, bool SAFE_MATH=false>
821 return std::dynamic_pointer_cast<log_node<T, SAFE_MATH>> (x);
822 }
823
824//******************************************************************************
825// Pow node.
826//******************************************************************************
827//------------------------------------------------------------------------------
834//------------------------------------------------------------------------------
835 template<jit::float_scalar T, bool SAFE_MATH=false>
836 class pow_node final : public branch_node<T, SAFE_MATH> {
837 private:
838//------------------------------------------------------------------------------
844//------------------------------------------------------------------------------
845 static std::string to_string(leaf_node<T, SAFE_MATH> *l,
847 return "pow" + jit::format_to_string(reinterpret_cast<size_t> (l))
848 + jit::format_to_string(reinterpret_cast<size_t> (r));
849 }
850
851 public:
852//------------------------------------------------------------------------------
857//------------------------------------------------------------------------------
862
863//------------------------------------------------------------------------------
869//------------------------------------------------------------------------------
871 backend::buffer<T> l_result = this->left->evaluate();
872 backend::buffer<T> r_result = this->right->evaluate();
873 return backend::pow(l_result, r_result);
874 }
875
876//------------------------------------------------------------------------------
880//------------------------------------------------------------------------------
882 auto lc = constant_cast(this->left);
883 auto rc = constant_cast(this->right);
884
885 if (rc.get() && rc->is(0)) {
886 return one<T, SAFE_MATH> ();
887 } else if (rc.get() && rc->is(1)) {
888 return this->left;
889 } else if (rc.get() && rc->is(0.5)) {
890 return sqrt(this->left);
891 } else if (rc.get() && rc->is(2)){
892 auto sq = sqrt_cast(this->left);
893 if (sq.get()) {
894 return sq->get_arg();
895 }
896 }
897
898 if (lc.get() && rc.get()) {
899 return constant<T, SAFE_MATH> (this->evaluate());
900 }
901
902 auto pl1 = piecewise_1D_cast(this->left);
903 auto pr1 = piecewise_1D_cast(this->right);
904 if (pl1.get() && (rc.get() || pl1->is_arg_match(this->right))) {
905 return piecewise_1D(this->evaluate(), pl1->get_arg(),
906 pl1->get_scale(), pl1->get_offset());
907 } else if (pr1.get() && (lc.get() || pr1->is_arg_match(this->left))) {
908 return piecewise_1D(this->evaluate(), pr1->get_arg(),
909 pr1->get_scale(), pr1->get_offset());
910 }
911
912 auto pl2 = piecewise_2D_cast(this->left);
913 auto pr2 = piecewise_2D_cast(this->right);
914 if (pl2.get() && (rc.get() || pl2->is_arg_match(this->right))) {
915 return piecewise_2D(this->evaluate(),
916 pl2->get_num_columns(),
917 pl2->get_left(), pl2->get_x_scale(), pl2->get_x_offset(),
918 pl2->get_right(), pl2->get_y_scale(), pl2->get_y_offset());
919 } else if (pr2.get() && (lc.get() || pr2->is_arg_match(this->left))) {
920 return piecewise_2D(this->evaluate(),
921 pr2->get_num_columns(),
922 pr2->get_left(), pr2->get_x_scale(), pr2->get_x_offset(),
923 pr2->get_right(), pr2->get_y_scale(), pr2->get_y_offset());
924 }
925
926// Combine 2D and 1D piecewise constants if a row or column matches.
927 if (pr2.get() && pr2->is_row_match(this->left)) {
928 backend::buffer<T> result = pl1->evaluate();
929 result.pow_row(pr2->evaluate());
930 return piecewise_2D(result,
931 pr2->get_num_columns(),
932 pr2->get_left(), pr2->get_x_scale(), pr2->get_x_offset(),
933 pr2->get_right(), pr2->get_y_scale(), pr2->get_y_offset());
934 } else if (pr2.get() && pr2->is_col_match(this->left)) {
935 backend::buffer<T> result = pl1->evaluate();
936 result.pow_col(pr2->evaluate());
937 return piecewise_2D(result,
938 pr2->get_num_columns(),
939 pr2->get_left(), pr2->get_x_scale(), pr2->get_x_offset(),
940 pr2->get_right(), pr2->get_y_scale(), pr2->get_y_offset());
941 } else if (pl2.get() && pl2->is_row_match(this->right)) {
942 backend::buffer<T> result = pl2->evaluate();
943 result.pow_row(pr1->evaluate());
944 return piecewise_2D(result,
945 pl2->get_num_columns(),
946 pl2->get_left(), pl2->get_x_scale(), pl2->get_x_offset(),
947 pl2->get_right(), pl2->get_y_scale(), pl2->get_y_offset());
948 } else if (pl2.get() && pl2->is_col_match(this->right)) {
949 backend::buffer<T> result = pl2->evaluate();
950 result.pow_col(pr1->evaluate());
951 return piecewise_2D(result,
952 pl2->get_num_columns(),
953 pl2->get_left(), pl2->get_x_scale(), pl2->get_x_offset(),
954 pl2->get_right(), pl2->get_y_scale(), pl2->get_y_offset());
955 }
956
957 auto lp = pow_cast(this->left);
958// Only run this reduction if the right is an integer constant value.
959 if (lp.get() && rc.get() && rc->is_integer()) {
960 return pow(lp->get_left(), lp->get_right()*this->right);
961 }
962
963// Handle cases where (c*x)^a, (x*c)^a, (a*sqrt(b))^c and (a*b^c)^2.
964// These reductions only make sense if the power is constant.
965 auto lm = multiply_cast(this->left);
966 if (lm.get() && rc.get()) {
967 if (lm->get_left()->is_constant() ||
968 lm->get_right()->is_constant() ||
969 sqrt_cast(lm->get_left()).get() ||
970 sqrt_cast(lm->get_right()).get() ||
971 pow_cast(lm->get_left()).get() ||
972 pow_cast(lm->get_right()).get()) {
973 return pow(lm->get_left(), this->right) *
974 pow(lm->get_right(), this->right);
975 }
976
977// ((Sqrt(a)*b)*c)^d -> a^(d/2)*(b*c)^d
978// ((b*Sqrt(a))*c)^d -> a^(d/2)*(b*c)^d
979 auto lmlm = multiply_cast(lm->get_left());
980 if (lmlm.get()) {
981 if (lmlm->get_left()->is_constant() ||
982 lmlm->get_right()->is_constant() ||
983 sqrt_cast(lmlm->get_left()).get() ||
984 sqrt_cast(lmlm->get_right()).get() ||
985 pow_cast(lmlm->get_left()).get() ||
986 pow_cast(lmlm->get_right()).get()) {
987 return pow(lmlm->get_left(), this->right) *
988 pow(lmlm->get_right(), this->right) *
989 pow(lm->get_right(), this->right);
990 }
991 }
992 }
993
994// These reductions only make sense if the power is constant.
995 auto ld = divide_cast(this->left);
996 if (ld.get() && rc.get()) {
997// For even exponents e.
998// (-a/b)^e -> (a/b)^e
999 auto ldlm = multiply_cast(ld->get_left());
1000 if (ldlm.get()) {
1001 if (rc.get() &&
1002 rc->evaluate().is_even()) {
1003 if (ldlm->get_left()->is_constant()) {
1004 return pow(ldlm->get_left(), this->right) *
1005 pow(ldlm->get_right()/ld->get_right(),
1006 this->right);
1007 }
1008 }
1009 if (ldlm->get_left()->is_constant() ||
1010 ldlm->get_right()->is_constant() ||
1011 sqrt_cast(ldlm->get_left()).get() ||
1012 sqrt_cast(ldlm->get_right()).get() ||
1013 pow_cast(ldlm->get_left()).get() ||
1014 pow_cast(ldlm->get_right()).get()) {
1015 return pow(ldlm->get_left(), this->right) *
1016 pow(ldlm->get_right(), this->right)/
1017 pow(ld->get_right(), this->right);
1018 }
1019
1020 auto ldlmlm = multiply_cast(ldlm->get_left());
1021 if (ldlmlm.get()) {
1022 if (ldlmlm->get_left()->is_constant() ||
1023 ldlmlm->get_right()->is_constant() ||
1024 sqrt_cast(ldlmlm->get_left()).get() ||
1025 sqrt_cast(ldlmlm->get_right()).get() ||
1026 pow_cast(ldlmlm->get_left()).get() ||
1027 pow_cast(ldlmlm->get_right()).get()) {
1028 return (pow(ldlmlm->get_left(), this->right) *
1029 pow(ldlmlm->get_right(), this->right) *
1030 pow(ldlm->get_right(), this->right)) /
1031 pow(ld->get_right(), this->right);
1032 }
1033 }
1034
1035 auto ldlmrm = multiply_cast(ldlm->get_right());
1036 if (ldlmrm.get()) {
1037 if (ldlmrm->get_left()->is_constant() ||
1038 ldlmrm->get_right()->is_constant() ||
1039 sqrt_cast(ldlmrm->get_left()).get() ||
1040 sqrt_cast(ldlmrm->get_right()).get() ||
1041 pow_cast(ldlmrm->get_left()).get() ||
1042 pow_cast(ldlmrm->get_right()).get()) {
1043 return (pow(ldlmrm->get_left(), this->right) *
1044 pow(ldlmrm->get_right(), this->right) *
1045 pow(ldlm->get_left(), this->right)) /
1046 pow(ld->get_right(), this->right);
1047 }
1048 }
1049 }
1050
1051// Handle cases where (c/x)^a, (x/c)^a, (a/sqrt(b))^c and (a/b^c)^2.
1052 if (ld->get_left()->is_constant() ||
1053 ld->get_right()->is_constant() ||
1054 sqrt_cast(ld->get_left()).get() ||
1055 sqrt_cast(ld->get_right()).get() ||
1056 pow_cast(ld->get_left()).get() ||
1057 pow_cast(ld->get_right()).get()) {
1058 return pow(ld->get_left(), this->right) /
1059 pow(ld->get_right(), this->right);
1060 }
1061
1062// Handle cases where (a/(b*sqrt(c))), (a/(sqrt(c)*b)), (a/(b*c^d)), (a/(c^d*b))
1063 auto ldrm = multiply_cast(ld->get_right());
1064 if (ldrm.get()) {
1065 if (ldrm->get_left()->is_constant() ||
1066 ldrm->get_right()->is_constant() ||
1067 sqrt_cast(ldrm->get_left()).get() ||
1068 sqrt_cast(ldrm->get_right()).get() ||
1069 pow_cast(ldrm->get_left()).get() ||
1070 pow_cast(ldrm->get_right()).get()) {
1071 return pow(ld->get_left(), this->right) /
1072 (pow(ldrm->get_left(), this->right) *
1073 pow(ldrm->get_right(), this->right));
1074 }
1075
1076 auto ldrmlm = multiply_cast(ldrm->get_left());
1077 if (ldrmlm.get()) {
1078 if (ldrmlm->get_left()->is_constant() ||
1079 ldrmlm->get_right()->is_constant() ||
1080 sqrt_cast(ldrmlm->get_left()).get() ||
1081 sqrt_cast(ldrmlm->get_right()).get() ||
1082 pow_cast(ldrmlm->get_left()).get() ||
1083 pow_cast(ldrmlm->get_right()).get()) {
1084 return pow(ld->get_left(), this->right) /
1085 (pow(ldrmlm->get_left(), this->right) *
1086 pow(ldrmlm->get_right(), this->right) *
1087 pow(ldrm->get_right(), this->right));
1088 }
1089 }
1090
1091 auto ldrmrm = multiply_cast(ldrm->get_right());
1092 if (ldrmrm.get()) {
1093 if (ldrmrm->get_left()->is_constant() ||
1094 ldrmrm->get_right()->is_constant() ||
1095 sqrt_cast(ldrmrm->get_left()).get() ||
1096 sqrt_cast(ldrmrm->get_right()).get() ||
1097 pow_cast(ldrmrm->get_left()).get() ||
1098 pow_cast(ldrmrm->get_right()).get()) {
1099 return pow(ld->get_left(), this->right) /
1100 (pow(ldrmrm->get_left(), this->right) *
1101 pow(ldrmrm->get_right(), this->right) *
1102 pow(ldrm->get_left(), this->right));
1103 }
1104 }
1105 }
1106
1107 if (is_variable_combineable(ld->get_left(),
1108 ld->get_right())) {
1109 return pow(ld->get_left()->get_power_base(),
1110 this->right*(ld->get_left()->get_power_exponent() -
1111 ld->get_right()->get_power_exponent()));
1112 }
1113
1114 if (ldrm.get()) {
1115 auto ldrmlm = multiply_cast(ldrm->get_left());
1116 if (ldrmlm.get()) {
1117 if (is_variable_combineable(ldrm->get_right(),
1118 ldrmlm->get_right()->get_power_base())) {
1119 return pow(ld->get_left()/ldrmlm->get_left(),
1120 this->right) /
1121 pow(ldrm->get_right()*ldrmlm->get_right(),
1122 this->right);
1123 } else if (is_variable_combineable(ldrm->get_right(),
1124 ldrmlm->get_left()->get_power_base())) {
1125 return pow(ld->get_left()/ldrmlm->get_right(),
1126 this->right) /
1127 pow(ldrm->get_right()*ldrmlm->get_left(),
1128 this->right);
1129 } else if (is_variable_combineable(ldrmlm->get_left(),
1130 ldrmlm->get_right()->get_power_base()) ||
1131 is_variable_combineable(ldrmlm->get_right(),
1132 ldrmlm->get_left()->get_power_base())) {
1133 return pow(ld->get_left()/ldrm->get_right(),
1134 this->right) /
1135 pow(ldrmlm->get_left()*ldrmlm->get_right(),
1136 this->right);
1137 }
1138 }
1139 }
1140 }
1141
1142// Reduce sqrt(a)^b
1143 auto lsq = sqrt_cast(this->left);
1144 if (lsq.get()) {
1145 return pow(lsq->get_arg(),
1146 this->right/2.0);
1147 }
1148
1149// Reduce exp(x)^n -> exp(n*x) when n is an integer.
1150 auto temp = exp_cast(this->left);
1151 if (temp.get() && rc.get() && rc->is_integer()) {
1152 return exp(this->right*temp->get_arg());
1153 }
1154
1155 return this->shared_from_this();
1156 }
1157
1158//------------------------------------------------------------------------------
1165//------------------------------------------------------------------------------
// Body of the power node's df(x); the signature line (1167) is not present in
// this listing.
// Returns one when this node matches x. Otherwise the derivative expression
//   l^(r - 1)*(r*dl/dx + l*log(l)*dr/dx)
// is built the first time it is requested for a given x and stored in
// df_cache, keyed by the address held by x.
1168 if (this->is_match(x)) {
1169 return one<T, SAFE_MATH> ();
1170 }
1171
// The cache key is the raw pointer value of x, not the node hash.
1172 const size_t hash = reinterpret_cast<size_t> (x.get());
1173 if (this->df_cache.find(hash) == this->df_cache.end()) {
1174 this->df_cache[hash] = pow(this->left, this->right - 1.0)
1175 * (this->right*this->left->df(x) +
1176 this->left*log(this->left)*this->right->df(x));
1177 }
1178 return this->df_cache[hash];
1179 }
1180
1181//------------------------------------------------------------------------------
1189//------------------------------------------------------------------------------
// Compile the power node into kernel source text.
// The return type line (1190), the indices parameter (1193) and the
// declaration of r (1200) are not present in this listing.
// Nothing is emitted when this node already has an entry in registers.
// Returns this node via shared_from_this().
1191 compile(std::ostringstream &stream,
1192 jit::register_map &registers,
1194 const jit::register_usage &usage) {
1195 if (registers.find(this) == registers.end()) {
1196 shared_leaf<T, SAFE_MATH> l = this->left->compile(stream,
1197 registers,
1198 indices,
1199 usage);
// The exponent is only compiled when it is not an integer constant; the
// integer case below never reads a register for it.
1201 auto temp = constant_cast(this->right);
1202 if (!temp.get() || !temp->is_integer()) {
1203 r = this->right->compile(stream, registers, indices, usage);
1204 }
1205
1206 registers[this] = jit::to_string('r', this);
1207 stream << " const ";
1208 jit::add_type<T> (stream);
1209 stream << " " << registers[this] << " = ";
1210 if (temp.get() && temp->is_integer()) {
// Integer constant exponent: write the base register multiplied by itself
// instead of calling pow.
// NOTE(review): the loop starts at 1, so an exponent of 0 still emits the bare
// base register, and a negative exponent is cast to size_t. Presumably
// reduce() removes those cases before compile is reached - confirm.
1211 stream << registers[l.get()];
1212 const size_t end = static_cast<size_t> (std::real(this->right->evaluate().at(0)));
1213 for (size_t i = 1; i < end; i++) {
1214 stream << "*" << registers[l.get()];
1215 }
1216 } else {
// General case: emit a call to pow(base, exponent).
1217 stream << "pow("
1218 << registers[l.get()] << ", "
1219 << registers[r.get()] << ")";
1220 }
1221 this->endline(stream, usage);
1222 }
1223
1224 return this->shared_from_this();
1225 }
1226
1227//------------------------------------------------------------------------------
1232//------------------------------------------------------------------------------
// Body of the power node's is_match(x); the signature line (1233) is not
// present in this listing.
// True when x is this very node, or when x is a power node whose left and
// right operands both match this node's left and right operands.
1234 if (this == x.get()) {
1235 return true;
1236 }
1237
// pow_cast yields an empty pointer when x is not a power node.
1238 auto x_cast = pow_cast(x);
1239 if (x_cast.get()) {
1240 return this->left->is_match(x_cast->get_left()) &&
1241 this->right->is_match(x_cast->get_right());
1242 }
1243
1244 return false;
1245 }
1246
1247//------------------------------------------------------------------------------
1249//------------------------------------------------------------------------------
// Write the node to std::cout as LaTeX in the form base^{exponent}.
// The base is wrapped in \left( ... \right) unless it is a constant node or a
// variable node.
1250 virtual void to_latex() const {
1251 auto use_brackets = !constant_cast(this->left).get() &&
1252 !variable_cast(this->left).get();
1253
1254 if (use_brackets) {
1255 std::cout << "\\left(";
1256 }
1257 this->left->to_latex();
1258 if (use_brackets) {
1259 std::cout << "\\right)";
1260 }
// The exponent is always enclosed in braces.
1261 std::cout << "^{";
1262 this->right->to_latex();
1263 std::cout << "}";
1264 }
1265
1266//------------------------------------------------------------------------------
1272//------------------------------------------------------------------------------
// Write this node and its two operands to stream as vizgraph text.
// stream    - receives the node declaration and the edges.
// registers - maps node pointers to names; a node already present in the map
//             is not written again.
// Returns this node via shared_from_this().
1273 virtual shared_leaf<T, SAFE_MATH> to_vizgraph(std::stringstream &stream,
1274 jit::register_map &registers) {
1275 if (registers.find(this) == registers.end()) {
// Register the name before descending into the operands.
1276 const std::string name = jit::to_string('r', this);
1277 registers[this] = name;
1278 stream << " " << name
1279 << " [label = \"pow\", shape = oval, style = filled, fillcolor = blue, fontcolor = white];" << std::endl;
1280
// One "--" edge from this node to each operand, left first.
1281 auto l = this->left->to_vizgraph(stream, registers);
1282 stream << " " << name << " -- " << registers[l.get()] << ";" << std::endl;
1283 auto r = this->right->to_vizgraph(stream, registers);
1284 stream << " " << name << " -- " << registers[r.get()] << ";" << std::endl;
1285 }
1286
1287 return this->shared_from_this();
1288 }
1289
1290//------------------------------------------------------------------------------
1294//------------------------------------------------------------------------------
// True when the base reports is_all_variables() and the exponent either
// reports is_all_variables() or is a constant node.
1295 virtual bool is_all_variables() const {
1296 return this->left->is_all_variables() &&
1297 (this->right->is_all_variables() ||
1298 constant_cast(this->right).get());
1299 }
1300
1301//------------------------------------------------------------------------------
1305//------------------------------------------------------------------------------
// A power node always reports itself as power like.
1306 virtual bool is_power_like() const {
1307 return true;
1308 }
1309
1310//------------------------------------------------------------------------------
1314//------------------------------------------------------------------------------
// Body of get_power_base(); the signature line (1315) is not present in this
// listing. The base of the power is the left operand.
1316 return this->left;
1317 }
1318
1319//------------------------------------------------------------------------------
1323//------------------------------------------------------------------------------
// Body of get_power_exponent(); the signature line (1324) is not present in
// this listing. The exponent of the power is the right operand.
1325 return this->right;
1326 }
1327
1328//------------------------------------------------------------------------------
1332//------------------------------------------------------------------------------
// Body of remove_pseudo(); the signature line (1333) is not present in this
// listing.
// When has_pseudo() is true a new power is built from both operands with
// remove_pseudo() applied to each; otherwise this node is returned unchanged.
1334 if (this->has_pseudo()) {
1335 return pow(this->left->remove_pseudo(),
1336 this->right->remove_pseudo());
1337 }
1338 return this->shared_from_this();
1339 }
1340 };
1341
1342//------------------------------------------------------------------------------
1350//------------------------------------------------------------------------------
// Build a power node from two leaf nodes l and r.
// The signature (1352-1353), the rest of the for header (1357) and lines 1360
// and 1367 are not present in this listing.
// A pow_node is created and reduced, then leaf_node::caches.nodes is searched
// starting at the hash of the reduced node:
//  - no entry at index i: the reduced node is returned
//    (NOTE(review): the missing line 1360 presumably stores it in the cache
//    first - confirm against the original header);
//  - an entry that matches the reduced node: the cached node is returned.
1351 template<jit::float_scalar T, bool SAFE_MATH=false>
1354 auto temp = std::make_shared<pow_node<T, SAFE_MATH>> (l, r)->reduce();
1355// Test for hash collisions.
1356 for (size_t i = temp->get_hash();
1358 if (leaf_node<T, SAFE_MATH>::caches.nodes.find(i) ==
1359 leaf_node<T, SAFE_MATH>::caches.nodes.end()) {
1361 return temp;
1362 } else if (temp->is_match(leaf_node<T, SAFE_MATH>::caches.nodes[i])) {
1363 return leaf_node<T, SAFE_MATH>::caches.nodes[i];
1364 }
1365 }
// Falling out of the loop is treated as impossible; the clang/GNU branch
// (line 1367) is not present in this listing.
1366#if defined(__clang__) || defined(__GNUC__)
1368#else
1369 assert(false && "Should never reach.");
1370#endif
1371 }
1372
1373//------------------------------------------------------------------------------
1382//------------------------------------------------------------------------------
// Overload of pow for a scalar base l of type L; the signature lines
// (1384-1385) are not present in this listing.
// l is cast to T, wrapped in a constant node and forwarded to the node
// version of pow together with r.
1383 template<jit::float_scalar T, jit::float_scalar L, bool SAFE_MATH=false>
1386 return pow(constant<T, SAFE_MATH> (static_cast<T> (l)), r);
1387 }
1388
1389//------------------------------------------------------------------------------
1398//------------------------------------------------------------------------------
// Overload of pow for a scalar exponent r of type R; the first signature line
// (1400) is not present in this listing.
// r is cast to T, wrapped in a constant node and forwarded to the node
// version of pow together with l.
1399 template<jit::float_scalar T, jit::float_scalar R, bool SAFE_MATH=false>
1401 const R r) {
1402 return pow(l, constant<T, SAFE_MATH> (static_cast<T> (r)));
1403 }
1404
// Alias for a std::shared_ptr that holds a pow_node.
1406 template<jit::float_scalar T, bool SAFE_MATH=false>
1407 using shared_pow = std::shared_ptr<pow_node<T, SAFE_MATH>>;
1408
1409//------------------------------------------------------------------------------
1414//------------------------------------------------------------------------------
// Cast a leaf node to a power node; the signature line (1416) is not present
// in this listing.
// Uses std::dynamic_pointer_cast, so the result is empty when x is not a
// pow_node.
1415 template<jit::float_scalar T, bool SAFE_MATH=false>
1417 return std::dynamic_pointer_cast<pow_node<T, SAFE_MATH>> (x);
1418 }
1419
1420//******************************************************************************
1421// Erfi node.
1422//******************************************************************************
1423//------------------------------------------------------------------------------
1430//------------------------------------------------------------------------------
// Node applying erfi to a single argument (this->arg). The template is
// constrained to jit::complex_scalar types.
// Several signature lines of this class (constructor, evaluate, reduce, df,
// compile, is_match, remove_pseudo) are not present in this listing; the
// comments below describe the bodies that are visible.
1431 template<jit::complex_scalar T, bool SAFE_MATH=false>
1432 class erfi_node final : public straight_node<T, SAFE_MATH> {
1433 private:
1434//------------------------------------------------------------------------------
1439//------------------------------------------------------------------------------
// Returns the text "erfi" followed by the formatted integer value of the
// pointer a.
1440 static std::string to_string(leaf_node<T, SAFE_MATH> *a) {
1441 return "erfi" + jit::format_to_string(reinterpret_cast<size_t> (a));
1442 }
1443
1444 public:
1445//------------------------------------------------------------------------------
1449//------------------------------------------------------------------------------
// NOTE(review): the constructor (lines 1450-1451) is not present in this
// listing.
1452
1453//------------------------------------------------------------------------------
1459//------------------------------------------------------------------------------
// evaluate(): evaluates the argument into a buffer, applies erfi to that
// buffer in place and returns it.
1461 backend::buffer<T> result = this->arg->evaluate();
1462 result.erfi();
1463 return result;
1464 }
1465
1466//------------------------------------------------------------------------------
1470//------------------------------------------------------------------------------
// reduce(): a constant argument collapses to a constant holding the evaluated
// result. A piecewise 1D or 2D argument becomes a piecewise node of the same
// kind built from the evaluated buffer with the argument's own arg/scale/offset
// values. Anything else returns this node unchanged.
1472 if (constant_cast(this->arg).get()) {
1473 return constant<T, SAFE_MATH> (this->evaluate());
1474 }
1475
1476 auto ap1 = piecewise_1D_cast(this->arg);
1477 if (ap1.get()) {
1478 return piecewise_1D(this->evaluate(),
1479 ap1->get_arg(),
1480 ap1->get_scale(),
1481 ap1->get_offset());
1482 }
1483
1484 auto ap2 = piecewise_2D_cast(this->arg);
1485 if (ap2.get()) {
1486 return piecewise_2D(this->evaluate(),
1487 ap2->get_num_columns(),
1488 ap2->get_left(), ap2->get_x_scale(), ap2->get_x_offset(),
1489 ap2->get_right(), ap2->get_y_scale(), ap2->get_y_offset());
1490 }
1491
1492 return this->shared_from_this();
1493 }
1494
1495//------------------------------------------------------------------------------
1502//------------------------------------------------------------------------------
// df(x): returns one when this node matches x. Otherwise builds
//   2/sqrt(pi)*exp(arg*arg)*darg/dx
// once per x and stores it in df_cache, keyed by the address held by x.
1504 if (this->is_match(x)) {
1505 return one<T, SAFE_MATH> ();
1506 }
1507
1508 const size_t hash = reinterpret_cast<size_t> (x.get());
1509 if (this->df_cache.find(hash) == this->df_cache.end()) {
1510 this->df_cache[hash] = 2.0/std::sqrt(M_PI)
1511 * exp(this->arg*this->arg)*this->arg->df(x);
1512 }
1513 return this->df_cache[hash];
1514 }
1515
1516//------------------------------------------------------------------------------
1524//------------------------------------------------------------------------------
// compile(): when this node has no entry in registers yet, compiles the
// argument, registers a name for this node and writes a const declaration
// assigning special::erfi(<argument register>) to it. Returns this node.
// The return type line (1525) and the indices parameter (1528) are not present
// in this listing.
1526 compile(std::ostringstream &stream,
1527 jit::register_map &registers,
1529 const jit::register_usage &usage) {
1530 if (registers.find(this) == registers.end()) {
1531 shared_leaf<T, SAFE_MATH> a = this->arg->compile(stream,
1532 registers,
1533 indices,
1534 usage);
1535
1536 registers[this] = jit::to_string('r', this);
1537 stream << " const ";
1538 jit::add_type<T> (stream);
1539 stream << " " << registers[this] << " = special::erfi("
1540 << registers[a.get()] << ")";
1541 this->endline(stream, usage);
1542 }
1543
1544 return this->shared_from_this();
1545 }
1546
1547//------------------------------------------------------------------------------
1552//------------------------------------------------------------------------------
// is_match(x): true when x is this very node, or when x is an erfi node whose
// argument matches this node's argument.
1554 if (this == x.get()) {
1555 return true;
1556 }
1557
1558 auto x_cast = erfi_cast(x);
1559 if (x_cast.get()) {
1560 return this->arg->is_match(x_cast->get_arg());
1561 }
1562
1563 return false;
1564 }
1565
1566//------------------------------------------------------------------------------
1568//------------------------------------------------------------------------------
// Write the node to std::cout as LaTeX: erfi\left( argument \right).
1569 virtual void to_latex() const {
1570 std::cout << "erfi\\left(";
1571 this->arg->to_latex();
1572 std::cout << "\\right)";
1573 }
1574
1575//------------------------------------------------------------------------------
1579//------------------------------------------------------------------------------
// remove_pseudo(): when has_pseudo() is true, rebuilds the node as erfi of the
// argument with remove_pseudo() applied; otherwise returns this node.
1581 if (this->has_pseudo()) {
1582 return erfi(this->arg->remove_pseudo());
1583 }
1584 return this->shared_from_this();
1585 }
1586
1587//------------------------------------------------------------------------------
1593//------------------------------------------------------------------------------
// Write this node and its argument to stream as vizgraph text. A node already
// present in registers is not written again. Returns this node.
1594 virtual shared_leaf<T, SAFE_MATH> to_vizgraph(std::stringstream &stream,
1595 jit::register_map &registers) {
1596 if (registers.find(this) == registers.end()) {
1597 const std::string name = jit::to_string('r', this);
1598 registers[this] = name;
1599 stream << " " << name
1600 << " [label = \"erfi\", shape = oval, style = filled, fillcolor = blue, fontcolor = white];" << std::endl;
1601
// Single "--" edge from this node to its argument.
1602 auto a = this->arg->to_vizgraph(stream, registers);
1603 stream << " " << name << " -- " << registers[a.get()] << ";" << std::endl;
1604 }
1605
1606 return this->shared_from_this();
1607 }
1608 };
1609
1610//------------------------------------------------------------------------------
1618//------------------------------------------------------------------------------
// Build an erfi node from a leaf node x.
// The signature (1620), the rest of the for header (1624) and lines 1627 and
// 1634 are not present in this listing.
// An erfi_node is created and reduced, then leaf_node::caches.nodes is
// searched starting at the hash of the reduced node:
//  - no entry at index i: the reduced node is returned
//    (NOTE(review): the missing line 1627 presumably stores it in the cache
//    first - confirm against the original header);
//  - an entry that matches the reduced node: the cached node is returned.
1619 template<jit::complex_scalar T, bool SAFE_MATH=false>
1621 auto temp = std::make_shared<erfi_node<T, SAFE_MATH>> (x)->reduce();
1622// Test for hash collisions.
1623 for (size_t i = temp->get_hash();
1625 if (leaf_node<T, SAFE_MATH>::caches.nodes.find(i) ==
1626 leaf_node<T, SAFE_MATH>::caches.nodes.end()) {
1628 return temp;
1629 } else if (temp->is_match(leaf_node<T, SAFE_MATH>::caches.nodes[i])) {
1630 return leaf_node<T, SAFE_MATH>::caches.nodes[i];
1631 }
1632 }
// Falling out of the loop is treated as impossible; the clang/GNU branch
// (line 1634) is not present in this listing.
1633#if defined(__clang__) || defined(__GNUC__)
1635#else
1636 assert(false && "Should never reach.");
1637#endif
1638 }
1639
// Alias for a std::shared_ptr that holds an erfi_node.
1641 template<jit::complex_scalar T, bool SAFE_MATH=false>
1642 using shared_erfi = std::shared_ptr<erfi_node<T, SAFE_MATH>>;
1643
1644//------------------------------------------------------------------------------
1652//------------------------------------------------------------------------------
// Cast a leaf node to an erfi node; the signature line (1654) is not present
// in this listing.
// Uses std::dynamic_pointer_cast, so the result is empty when x is not an
// erfi_node.
1653 template<jit::complex_scalar T, bool SAFE_MATH=false>
1655 return std::dynamic_pointer_cast<erfi_node<T, SAFE_MATH>> (x);
1656 }
1657}
1658
1659#endif /* math_h */
Class representing a generic buffer.
Definition backend.hpp:29
void erfi()
Take erfi.
Definition backend.hpp:259
void log()
Take log.
Definition backend.hpp:232
void sqrt()
Take sqrt.
Definition backend.hpp:214
void pow_col(const buffer< T > &x)
Pow col operation.
Definition backend.hpp:715
void pow_row(const buffer< T > &x)
Pow row operation.
Definition backend.hpp:679
void exp()
Take exp.
Definition backend.hpp:223
Class representing a branch node.
Definition node.hpp:1173
shared_leaf< T, SAFE_MATH > right
Right branch of the tree.
Definition node.hpp:1178
shared_leaf< T, SAFE_MATH > left
Left branch of the tree.
Definition node.hpp:1176
An imaginary error function node.
Definition math.hpp:1432
erfi_node(shared_leaf< T, SAFE_MATH > x)
Construct an erfi node.
Definition math.hpp:1450
virtual shared_leaf< T, SAFE_MATH > reduce()
Reduce the erfi(x).
Definition math.hpp:1471
virtual shared_leaf< T, SAFE_MATH > compile(std::ostringstream &stream, jit::register_map &registers, jit::register_map &indices, const jit::register_usage &usage)
Compile the node.
Definition math.hpp:1526
virtual void to_latex() const
Convert the node to latex.
Definition math.hpp:1569
virtual shared_leaf< T, SAFE_MATH > remove_pseudo()
Remove pseudo variable nodes.
Definition math.hpp:1580
virtual shared_leaf< T, SAFE_MATH > df(shared_leaf< T, SAFE_MATH > x)
Transform node to derivative.
Definition math.hpp:1503
virtual bool is_match(shared_leaf< T, SAFE_MATH > x)
Query if the nodes match.
Definition math.hpp:1553
virtual backend::buffer< T > evaluate()
Evaluate the results of erfi.
Definition math.hpp:1460
virtual shared_leaf< T, SAFE_MATH > to_vizgraph(std::stringstream &stream, jit::register_map &registers)
Convert the node to vizgraph.
Definition math.hpp:1594
An exp node.
Definition math.hpp:329
virtual backend::buffer< T > evaluate()
Evaluate the results of exp.
Definition math.hpp:357
virtual shared_leaf< T, SAFE_MATH > df(shared_leaf< T, SAFE_MATH > x)
Transform node to derivative.
Definition math.hpp:406
virtual bool is_match(shared_leaf< T, SAFE_MATH > x)
Query if the nodes match.
Definition math.hpp:477
virtual shared_leaf< T, SAFE_MATH > remove_pseudo()
Remove pseudo variable nodes.
Definition math.hpp:504
exp_node(shared_leaf< T, SAFE_MATH > x)
Construct an exp node.
Definition math.hpp:347
virtual void to_latex() const
Convert the node to latex.
Definition math.hpp:493
virtual shared_leaf< T, SAFE_MATH > reduce()
Reduce the exp(x).
Definition math.hpp:368
virtual shared_leaf< T, SAFE_MATH > to_vizgraph(std::stringstream &stream, jit::register_map &registers)
Convert the node to vizgraph.
Definition math.hpp:518
virtual shared_leaf< T, SAFE_MATH > compile(std::ostringstream &stream, jit::register_map &registers, jit::register_map &indices, const jit::register_usage &usage)
Compile the node.
Definition math.hpp:428
Class representing a node leaf.
Definition node.hpp:364
virtual void endline(std::ostringstream &stream, const jit::register_usage &usage) const final
End a line in the kernel source.
Definition node.hpp:637
std::map< size_t, std::shared_ptr< leaf_node< T, SAFE_MATH > > > df_cache
Cache derivative terms.
Definition node.hpp:371
virtual bool has_pseudo() const
Query if the node contains pseudo variables.
Definition node.hpp:618
const size_t hash
Hash for node.
Definition node.hpp:367
A log node.
Definition math.hpp:594
virtual shared_leaf< T, SAFE_MATH > to_vizgraph(std::stringstream &stream, jit::register_map &registers)
Convert the node to vizgraph.
Definition math.hpp:761
virtual shared_leaf< T, SAFE_MATH > df(shared_leaf< T, SAFE_MATH > x)
Transform node to derivative.
Definition math.hpp:671
virtual backend::buffer< T > evaluate()
Evaluate the results of log.
Definition math.hpp:622
virtual bool is_match(shared_leaf< T, SAFE_MATH > x)
Query if the nodes match.
Definition math.hpp:720
virtual shared_leaf< T, SAFE_MATH > reduce()
Reduce the log(x).
Definition math.hpp:633
virtual void to_latex() const
Convert the node to latex.
Definition math.hpp:736
virtual shared_leaf< T, SAFE_MATH > compile(std::ostringstream &stream, jit::register_map &registers, jit::register_map &indices, const jit::register_usage &usage)
Compile the node.
Definition math.hpp:693
log_node(shared_leaf< T, SAFE_MATH > x)
Construct a log node.
Definition math.hpp:612
virtual shared_leaf< T, SAFE_MATH > remove_pseudo()
Remove pseudo variable nodes.
Definition math.hpp:747
A power node.
Definition math.hpp:836
virtual shared_leaf< T, SAFE_MATH > to_vizgraph(std::stringstream &stream, jit::register_map &registers)
Convert the node to vizgraph.
Definition math.hpp:1273
virtual bool is_all_variables() const
Test if node acts like a variable.
Definition math.hpp:1295
virtual shared_leaf< T, SAFE_MATH > get_power_base()
Get the base of a power.
Definition math.hpp:1315
virtual backend::buffer< T > evaluate()
Evaluate the results of the power.
Definition math.hpp:870
pow_node(shared_leaf< T, SAFE_MATH > l, shared_leaf< T, SAFE_MATH > r)
Construct a power node.
Definition math.hpp:858
virtual bool is_power_like() const
Test if the node acts like a power of variable.
Definition math.hpp:1306
virtual shared_leaf< T, SAFE_MATH > compile(std::ostringstream &stream, jit::register_map &registers, jit::register_map &indices, const jit::register_usage &usage)
Compile the node.
Definition math.hpp:1191
virtual shared_leaf< T, SAFE_MATH > remove_pseudo()
Remove pseudo variable nodes.
Definition math.hpp:1333
virtual bool is_match(shared_leaf< T, SAFE_MATH > x)
Query if the nodes match.
Definition math.hpp:1233
virtual shared_leaf< T, SAFE_MATH > get_power_exponent() const
Get the exponent of a power.
Definition math.hpp:1324
virtual void to_latex() const
Convert the node to latex.
Definition math.hpp:1250
virtual shared_leaf< T, SAFE_MATH > reduce()
Reduce a power node.
Definition math.hpp:881
virtual shared_leaf< T, SAFE_MATH > df(shared_leaf< T, SAFE_MATH > x)
Transform node to derivative.
Definition math.hpp:1167
A sqrt node.
Definition math.hpp:26
virtual shared_leaf< T, SAFE_MATH > get_power_exponent() const
Get the exponent of a power.
Definition math.hpp:230
virtual backend::buffer< T > evaluate()
Evaluate the results of sqrt.
Definition math.hpp:54
virtual shared_leaf< T, SAFE_MATH > reduce()
Reduce the sqrt(x).
Definition math.hpp:65
sqrt_node(shared_leaf< T, SAFE_MATH > x)
Construct a sqrt node.
Definition math.hpp:44
virtual bool is_power_like() const
Test if the node acts like a power of variable.
Definition math.hpp:212
virtual bool is_match(shared_leaf< T, SAFE_MATH > x)
Query if the nodes match.
Definition math.hpp:185
virtual shared_leaf< T, SAFE_MATH > compile(std::ostringstream &stream, jit::register_map &registers, jit::register_map &indices, const jit::register_usage &usage)
Compile the node.
Definition math.hpp:158
virtual shared_leaf< T, SAFE_MATH > remove_pseudo()
Remove pseudo variable nodes.
Definition math.hpp:239
virtual shared_leaf< T, SAFE_MATH > df(shared_leaf< T, SAFE_MATH > x)
Transform node to derivative.
Definition math.hpp:135
virtual shared_leaf< T, SAFE_MATH > to_vizgraph(std::stringstream &stream, jit::register_map &registers)
Convert the node to vizgraph.
Definition math.hpp:253
virtual void to_latex() const
Convert the node to latex.
Definition math.hpp:201
virtual shared_leaf< T, SAFE_MATH > get_power_base()
Get the base of a power.
Definition math.hpp:221
Class representing a straight node.
Definition node.hpp:1059
shared_leaf< T, SAFE_MATH > arg
Argument.
Definition node.hpp:1062
Complex scalar concept.
Definition register.hpp:24
subroutine assert(test, message)
Assert check.
Definition f_binding_test.f90:38
buffer< T > pow(buffer< T > &base, buffer< T > &exponent)
Take the power.
Definition backend.hpp:1025
Name space for graph nodes.
Definition arithmetic.hpp:13
shared_piecewise_2D< T, SAFE_MATH > piecewise_2D_cast(shared_leaf< T, SAFE_MATH > x)
Cast to a piecewise 2D node.
Definition piecewise.hpp:1323
bool is_variable_combineable(shared_leaf< T, SAFE_MATH > a, shared_leaf< T, SAFE_MATH > b)
Check if the variable is combinable.
Definition arithmetic.hpp:75
shared_leaf< T, SAFE_MATH > log(shared_leaf< T, SAFE_MATH > x)
Define log convenience function.
Definition math.hpp:787
shared_pow< T, SAFE_MATH > pow_cast(shared_leaf< T, SAFE_MATH > x)
Cast to a power node.
Definition math.hpp:1416
constexpr shared_leaf< T, SAFE_MATH > zero()
Forward declare for zero.
Definition node.hpp:994
shared_sqrt< T, SAFE_MATH > sqrt_cast(shared_leaf< T, SAFE_MATH > x)
Cast to a sqrt node.
Definition math.hpp:313
shared_leaf< T, SAFE_MATH > pow(shared_leaf< T, SAFE_MATH > l, shared_leaf< T, SAFE_MATH > r)
Build power node.
Definition math.hpp:1352
shared_divide< T, SAFE_MATH > divide_cast(shared_leaf< T, SAFE_MATH > x)
Cast to a divide node.
Definition arithmetic.hpp:3688
std::shared_ptr< erfi_node< T, SAFE_MATH > > shared_erfi
Convenience type alias for shared erfi nodes.
Definition math.hpp:1642
shared_piecewise_1D< T, SAFE_MATH > piecewise_1D_cast(shared_leaf< T, SAFE_MATH > x)
Cast to a piecewise 1D node.
Definition piecewise.hpp:601
shared_multiply< T, SAFE_MATH > multiply_cast(shared_leaf< T, SAFE_MATH > x)
Cast to a multiply node.
Definition arithmetic.hpp:2723
shared_leaf< T, SAFE_MATH > exp(shared_leaf< T, SAFE_MATH > x)
Define exp convenience function.
Definition math.hpp:544
shared_exp< T, SAFE_MATH > exp_cast(shared_leaf< T, SAFE_MATH > x)
Cast to an exp node.
Definition math.hpp:578
shared_constant< T, SAFE_MATH > constant_cast(shared_leaf< T, SAFE_MATH > x)
Cast to a constant node.
Definition node.hpp:1042
shared_leaf< T, SAFE_MATH > erfi(shared_leaf< T, SAFE_MATH > x)
Define erfi convenience function.
Definition math.hpp:1620
shared_variable< T, SAFE_MATH > variable_cast(shared_leaf< T, SAFE_MATH > x)
Cast to a variable node.
Definition node.hpp:1746
constexpr T i
Convenience type for imaginary constant.
Definition node.hpp:1026
shared_erfi< T, SAFE_MATH > erfi_cast(shared_leaf< T, SAFE_MATH > x)
Cast to an erfi node.
Definition math.hpp:1654
shared_leaf< T, SAFE_MATH > sqrt(shared_leaf< T, SAFE_MATH > x)
Define sqrt convenience function.
Definition math.hpp:279
std::shared_ptr< exp_node< T, SAFE_MATH > > shared_exp
Convenience type alias for shared exp nodes.
Definition math.hpp:566
shared_log< T, SAFE_MATH > log_cast(shared_leaf< T, SAFE_MATH > x)
Cast to a log node.
Definition math.hpp:820
std::shared_ptr< leaf_node< T, SAFE_MATH > > shared_leaf
Convenience type alias for shared leaf nodes.
Definition node.hpp:673
std::shared_ptr< log_node< T, SAFE_MATH > > shared_log
Convenience type alias for shared log nodes.
Definition math.hpp:808
std::shared_ptr< pow_node< T, SAFE_MATH > > shared_pow
Convenience type alias for shared pow nodes.
Definition math.hpp:1407
std::shared_ptr< sqrt_node< T, SAFE_MATH > > shared_sqrt
Convenience type alias for shared sqrt nodes.
Definition math.hpp:301
std::string format_to_string(const T value)
Convert a value to a string while avoiding locale.
Definition register.hpp:211
std::map< void *, size_t > register_usage
Type alias for counting register usage.
Definition register.hpp:258
std::map< void *, std::string > register_map
Type alias for mapping node pointers to register names.
Definition register.hpp:256
std::string to_string(const char prefix, const NODE *pointer)
Convert a graph::leaf_node pointer to a string.
Definition register.hpp:245
Base nodes of graph computation framework.
void piecewise_1D()
Tests for 1D piecewise nodes.
Definition piecewise_test.cpp:80
void piecewise_2D()
Tests for 2D piecewise nodes.
Definition piecewise_test.cpp:283