24 template<jit::
float_scalar T,
bool SAFE_MATH=false>
27 if (
a->is_constant() &&
b->is_constant()) {
34 (
a1.get() &&
a1->is_arg_match(
b)) ||
35 (
a2.get() &&
a2->is_arg_match(
b)) ||
36 (
a2.get() && (
a2->is_row_match(
b) ||
a2->is_col_match(
b))) ||
37 (
b2.get() && (
b2->is_row_match(
a) ||
b2->is_col_match(
a)));
52 template<jit::
float_scalar T,
bool SAFE_MATH=false>
58 return a->is_constant() &&
74 template<jit::
float_scalar T,
bool SAFE_MATH=false>
77 return a->is_power_base_match(
b);
90 template<jit::
float_scalar T,
bool SAFE_MATH=false>
93 return !
b->is_constant() &&
94 (
a->is_all_variables() &&
95 (!
b->is_all_variables() ||
96 (
b->is_all_variables() &&
97 a->get_complexity() <
b->get_complexity())));
110 template<jit::
float_scalar T,
bool SAFE_MATH=false>
116 return ae.get() &&
be.get() &&
117 std::abs(
ae->evaluate().at(0)) > std::abs(
be->evaluate().at(0));
131 template<jit::
float_scalar T,
bool SAFE_MATH=false>
182 if (l.get() &&
l->is(0)) {
184 }
else if (r.get() &&
r->is(0)) {
186 }
else if (l.get() &&
r.get()) {
188 }
else if (r.get() && !
l.get()) {
195 if (
pl1.get() && (r.get() ||
pl1->is_arg_match(
this->right))) {
197 pl1->get_scale(),
pl1->get_offset());
198 }
else if (
pr1.get() && (l.get() ||
pr1->is_arg_match(
this->left))) {
200 pr1->get_scale(),
pr1->get_offset());
206 if (
pl2.get() && (r.get() ||
pl2->is_arg_match(
this->right))) {
208 pl2->get_num_columns(),
209 pl2->get_left(),
pl2->get_x_scale(),
pl2->get_x_offset(),
210 pl2->get_right(),
pl2->get_y_scale(),
pl2->get_y_offset());
211 }
else if (
pr2.get() && (l.get() ||
pr2->is_arg_match(
this->left))) {
213 pr2->get_num_columns(),
214 pr2->get_left(),
pr2->get_x_scale(),
pr2->get_x_offset(),
215 pr2->get_right(),
pr2->get_y_scale(),
pr2->get_y_offset());
219 if (
pr2.get() &&
pr2->is_row_match(
this->left)) {
223 pr2->get_num_columns(),
224 pr2->get_left(),
pr2->get_x_scale(),
pr2->get_x_offset(),
225 pr2->get_right(),
pr2->get_y_scale(),
pr2->get_y_offset());
226 }
else if (
pr2.get() &&
pr2->is_col_match(
this->left)) {
230 pr2->get_num_columns(),
231 pr2->get_left(),
pr2->get_x_scale(),
pr2->get_x_offset(),
232 pr2->get_right(),
pr2->get_y_scale(),
pr2->get_y_offset());
233 }
else if (
pl2.get() &&
pl2->is_row_match(
this->right)) {
237 pl2->get_num_columns(),
238 pl2->get_left(),
pl2->get_x_scale(),
pl2->get_x_offset(),
239 pl2->get_right(),
pl2->get_y_scale(),
pl2->get_y_offset());
240 }
else if (
pl2.get() &&
pl2->is_col_match(
this->right)) {
244 pl2->get_num_columns(),
245 pl2->get_left(),
pl2->get_x_scale(),
pl2->get_x_offset(),
246 pl2->get_right(),
pl2->get_y_scale(),
pl2->get_y_offset());
250 if (this->
left->is_match(
this->right)) {
251 return 2.0*this->
left;
262 rm->get_left()->is_constant() &&
263 rm->get_left()->evaluate().is_negative()) {
265 }
else if (
rm.get() &&
266 rm->get_left()->is_constant() &&
267 rm->get_left()->evaluate().is_negative()) {
274 return fma(
lm->get_left(),
lm->get_right(),
this->right);
275 }
else if (
rm.get()) {
276 return fma(
rm->get_left(),
rm->get_right(),
this->left);
291 if ((
rd->get_left()->is_constant() &&
292 rd->get_left()->evaluate().is_negative()) ||
294 (
rdlm->get_left()->is_constant() &&
295 rdlm->get_left()->evaluate().is_negative()))) {
296 return this->
left - (-
rd->get_left())/
rd->get_right();
298 }
else if (
ld.get()) {
300 if ((
ld->get_left()->is_constant() &&
301 ld->get_left()->evaluate().is_negative()) ||
303 (
ldlm->get_left()->is_constant() &&
304 ldlm->get_left()->evaluate().is_negative()))) {
305 return this->
right - (-
ld->get_left())/
ld->get_right();
309 if (
ld.get() &&
rd.get()) {
310 if (
ld->get_right()->is_match(
rd->get_right())) {
311 return (
ld->get_left() +
rd->get_left())/
ld->get_right();
321 if (
ld->get_left()->is_match(
rdlm->get_left())) {
322 return (1.0/
ld->get_right() +
323 rdlm->get_right()/
rd->get_right())*
rdlm->get_left();
324 }
else if (
ld->get_left()->is_match(
rdlm->get_right())) {
325 return (1.0/
ld->get_right() +
326 rdlm->get_left()/
rd->get_right())*
rdlm->get_right();
328 }
else if (
ldlm.get()) {
329 if (
rd->get_left()->is_match(
ldlm->get_left())) {
330 return (
ldlm->get_right()/
ld->get_right() +
331 1.0/
rd->get_right())*
ldlm->get_left();
332 }
else if (
rd->get_left()->is_match(
ldlm->get_right())) {
333 return (
ldlm->get_left()/
ld->get_right() +
334 1.0/
rd->get_right())*
ldlm->get_right();
346 !
ldlm->get_right()->is_match(
rdlm->get_right())) {
347 return (
ldlm->get_right()/
ld->get_right() +
349 rdlm->get_right()/
rd->get_right())*
ldlm->get_left();
352 if (
ldlm->get_right()->is_match(
rdlm->get_right())) {
353 return (
ldlm->get_left()/
ld->get_right() +
354 rdlm->get_left()/
rd->get_right())*
ldlm->get_right();
355 }
else if (
ldlm->get_right()->is_match(
rdlm->get_left())) {
356 return (
ldlm->get_left()/
ld->get_right() +
357 rdlm->get_right()/
rd->get_right())*
ldlm->get_right();
358 }
else if (
ldlm->get_left()->is_match(
rdlm->get_right())) {
359 return (
ldlm->get_right()/
ld->get_right() +
360 rdlm->get_left()/
rd->get_right())*
ldlm->get_left();
361 }
else if (
ldlm->get_left()->is_match(
rdlm->get_left())) {
362 return (
ldlm->get_right()/
ld->get_right() +
363 rdlm->get_right()/
rd->get_right())*
ldlm->get_left();
374 if (
ldrm->get_right()->is_match(
rdrm->get_right())) {
375 return (
ld->get_left()/
ldrm->get_left() +
376 rd->get_left()/
rdrm->get_left())/
ldrm->get_right();
377 }
else if (
ldrm->get_right()->is_match(
rdrm->get_left())) {
378 return (
ld->get_left()/
ldrm->get_left() +
379 rd->get_left()/
rdrm->get_right())/
ldrm->get_right();
380 }
else if (
ldrm->get_left()->is_match(
rdrm->get_right())) {
381 return (
ld->get_left()/
ldrm->get_right() +
382 rd->get_left()/
rdrm->get_left())/
ldrm->get_left();
383 }
else if (
ldrm->get_left()->is_match(
rdrm->get_left())) {
384 return (
ld->get_left()/
ldrm->get_right() +
385 rd->get_left()/
rdrm->get_right())/
ldrm->get_left();
394 if (
ld->get_right()->is_match(
rdrm->get_left())) {
395 return fma(
ld->get_left(),
399 }
else if (
ld->get_right()->is_match(
rdrm->get_right())) {
400 return fma(
ld->get_left(),
405 }
else if (
ldrm.get()) {
406 if (
rd->get_right()->is_match(
ldrm->get_left())) {
407 return fma(
rd->get_left(),
411 }
else if (
rd->get_right()->is_match(
ldrm->get_right())) {
412 return fma(
rd->get_left(),
427 if (this->
right->is_match(
la->get_left())) {
428 return fma(2.0, this->
right,
la->get_right());
429 }
else if (this->
right->is_match(
la->get_right())) {
435 if (this->
left->is_match(
ra->get_left())) {
436 return fma(2.0, this->
left,
ra->get_right());
437 }
else if (this->
left->is_match(
ra->get_right())) {
438 return fma(2.0, this->
left,
ra->get_left());
450 if (
ls->get_left()->is_match(
this->right)) {
451 return static_cast<T
> (2.0)*this->
right -
ls->get_right();
452 }
else if (
ls->get_right()->is_match(this->
right)) {
453 return ls->get_left();
464 return (this->
left +
rs->get_left()) -
rs->get_right();
466 return (this->
left -
rs->get_right()) +
rs->get_left();
467 }
else if (this->
left->is_match(
rs->get_left())) {
468 return static_cast<T
> (2.0)*this->
left -
rs->get_right();
469 }
else if (this->
left->is_match(
rs->get_right())) {
470 return rs->get_left();
480 return la->get_left() + (
la->get_right() + this->
right);
485 return ls->get_left() + (this->
right -
ls->get_right());
496 }
else if (
rfma.get()) {
508 if (
lfma->get_middle()->is_match(
rfma->get_middle())) {
510 lfma->get_left() +
rfma->get_left(),
511 lfma->get_right() +
rfma->get_right());
512 }
else if (
lfma->get_left()->is_match(
rfma->get_middle())) {
514 lfma->get_middle() +
rfma->get_left(),
515 lfma->get_right() +
rfma->get_right());
516 }
else if (
lfma->get_middle()->is_match(
rfma->get_left())) {
518 lfma->get_left() +
rfma->get_middle(),
519 lfma->get_right() +
rfma->get_right());
520 }
else if (
lfma->get_left()->is_match(
rfma->get_left())) {
522 lfma->get_middle() +
rfma->get_middle(),
523 lfma->get_right() +
rfma->get_right());
534 if (
pl.get() &&
pr.get() &&
535 pl->get_right()->is_match(
pr->get_right())) {
538 if (
plm.get() &&
prm.get()) {
539 if (
plm->get_left()->is_match(
prm->get_left())) {
540 return pow(
plm->get_left(),
pl->get_right())*
541 (
pow(
plm->get_right(),
pl->get_right()) +
542 pow(
prm->get_right(),
pl->get_right()));
543 }
else if (
plm->get_left()->is_match(
prm->get_right())) {
544 return pow(
plm->get_left(),
pl->get_right())*
545 (
pow(
plm->get_right(),
pl->get_right()) +
546 pow(
prm->get_left(),
pl->get_right()));
547 }
else if (
plm->get_right()->is_match(
prm->get_left())) {
548 return pow(
plm->get_right(),
pl->get_right())*
549 (
pow(
plm->get_left(),
pl->get_right()) +
550 pow(
prm->get_right(),
pl->get_right()));
551 }
else if (
plm->get_right()->is_match(
prm->get_right())) {
552 return pow(
plm->get_right(),
pl->get_right())*
553 (
pow(
plm->get_left(),
pl->get_right()) +
554 pow(
prm->get_left(),
pl->get_right()));
561 if (
plrc.get() &&
plrc->is(
static_cast<T
> (2.0))) {
566 if ((
pls.get() &&
prc.get() &&
pls->get_arg()->is_match(
prc->get_arg())) ||
567 (
plc.get() &&
prs.get() &&
plc->get_arg()->is_match(
prs->get_arg()))) {
576 if (
pl.get() &&
rd.get()) {
578 if (
rdp.get() &&
pl->get_right()->is_match(
rdp->get_right())) {
581 rdp->get_left()->is_match(
plld->get_right())) {
582 return (
pow(
plld->get_left(),
pl->get_right()) +
584 pow(
rdp->get_left(),
pl->get_right());
587 }
else if (
pr.get() &&
ld.get()) {
589 if (
ldp.get() &&
pr->get_right()->is_match(
ldp->get_right())) {
592 ldp->get_left()->is_match(
prld->get_right())) {
593 return (
pow(
prld->get_left(),
pr->get_right()) +
595 pow(
ldp->get_left(),
pr->get_right());
598 }
else if (
pl.get() &&
pr.get()) {
599 if (
pl->get_right()->is_match(
pr->get_right())) {
602 if (
pld.get() &&
prd.get() &&
603 pld->get_right()->is_match(
prd->get_right())) {
604 return (
pow(
pld->get_left(),
pl->get_right()) +
605 pow(
prd->get_left(),
pl->get_right())) /
606 pow(
pld->get_right(),
pl->get_right());
628 const size_t hash =
reinterpret_cast<size_t> (x.get());
649 if (registers.find(
this) == registers.end()) {
661 jit::add_type<T> (stream);
662 stream <<
" " << registers[
this] <<
" = "
663 << registers[
l.get()] <<
" + "
664 << registers[
r.get()];
678 if (
this == x.get()) {
685 if ((this->
left->is_match(
x_cast->get_left()) &&
686 this->right->is_match(
x_cast->get_right())) ||
705 std::cout <<
"\\left(";
707 this->
left->to_latex();
709 std::cout <<
"\\right)";
713 std::cout <<
"\\left(";
715 this->
right->to_latex();
717 std::cout <<
"\\right)";
728 return this->
left->remove_pseudo() +
729 this->
right->remove_pseudo();
743 if (registers.find(
this) == registers.end()) {
745 registers[
this] =
name;
746 stream <<
" " <<
name
747 <<
" [label = \"+\", shape = oval, style = filled, fillcolor = blue, fontcolor = white];" << std::endl;
749 auto l = this->
left->to_vizgraph(stream, registers);
750 stream <<
" " <<
name <<
" -- " << registers[
l.get()] <<
";" << std::endl;
751 auto r = this->
right->to_vizgraph(stream, registers);
752 stream <<
" " <<
name <<
" -- " << registers[
r.get()] <<
";" << std::endl;
771 template<jit::
float_scalar T,
bool SAFE_MATH=false>
774 auto temp = std::make_shared<add_node<T, SAFE_MATH>> (
l,
r)->reduce();
776 for (
size_t i =
temp->get_hash();
786#if defined(__clang__) || defined(__GNUC__)
789 assert(
false &&
"Should never reach.");
805 template<jit::
float_scalar T,
bool SAFE_MATH=false>
824 template<jit::
float_scalar T, jit::
float_scalar L,
bool SAFE_MATH=false>
843 template<jit::
float_scalar T, jit::
float_scalar R,
bool SAFE_MATH=false>
850 template<jit::
float_scalar T,
bool SAFE_MATH=false>
862 template<jit::
float_scalar T,
bool SAFE_MATH=false>
864 return std::dynamic_pointer_cast<add_node<T, SAFE_MATH>> (x);
878 template<jit::
float_scalar T,
bool SAFE_MATH=false>
927 if (this->
left->is_match(
this->right)) {
928 if (l.get() &&
l->is(0)) {
938 if (l.get() &&
l->is(0)) {
940 }
else if (r.get() &&
r->is(0)) {
942 }
else if (l.get() &&
r.get()) {
944 }
else if (r.get() &&
r->
evaluate().is_negative()) {
951 if (
pl1.get() && (r.get() ||
pl1->is_arg_match(
this->right))) {
953 pl1->get_scale(),
pl1->get_offset());
954 }
else if (
pr1.get() && (l.get() ||
pr1->is_arg_match(
this->left))) {
956 pr1->get_scale(),
pr1->get_offset());
962 if (
pl2.get() && (r.get() ||
pl2->is_arg_match(
this->right))) {
964 pl2->get_num_columns(),
965 pl2->get_left(),
pl2->get_x_scale(),
pl2->get_x_offset(),
966 pl2->get_right(),
pl2->get_y_scale(),
pl2->get_y_offset());
967 }
else if (
pr2.get() && (l.get() ||
pr2->is_arg_match(
this->left))) {
969 pr2->get_num_columns(),
970 pr2->get_left(),
pr2->get_x_scale(),
pr2->get_x_offset(),
971 pr2->get_right(),
pr2->get_y_scale(),
pr2->get_y_offset());
975 if (
pr2.get() &&
pr2->is_row_match(
this->left)) {
979 pr2->get_num_columns(),
980 pr2->get_left(),
pr2->get_x_scale(),
pr2->get_x_offset(),
981 pr2->get_right(),
pr2->get_y_scale(),
pr2->get_y_offset());
982 }
else if (
pr2.get() &&
pr2->is_col_match(
this->left)) {
986 pr2->get_num_columns(),
987 pr2->get_left(),
pr2->get_x_scale(),
pr2->get_x_offset(),
988 pr2->get_right(),
pr2->get_y_scale(),
pr2->get_y_offset());
989 }
else if (
pl2.get() &&
pl2->is_row_match(
this->right)) {
993 pl2->get_num_columns(),
994 pl2->get_left(),
pl2->get_x_scale(),
pl2->get_x_offset(),
995 pl2->get_right(),
pl2->get_y_scale(),
pl2->get_y_offset());
996 }
else if (
pl2.get() &&
pl2->is_col_match(
this->right)) {
1000 pl2->get_num_columns(),
1001 pl2->get_left(),
pl2->get_x_scale(),
pl2->get_x_offset(),
1002 pl2->get_right(),
pl2->get_y_scale(),
pl2->get_y_offset());
1009 return (
la->get_left() -
this->right) +
la->get_right();
1015 return (this->
left -
ra->get_left()) -
ra->get_right();
1024 return (
ls->get_left() -
this->right) -
ls->get_right();
1027 return -(
ls->get_right() + this->
right) -
ls->get_left();
1035 return (this->
left -
rs->get_left()) +
rs->get_right();
1037 return (this->
left +
rs->get_right()) -
rs->get_left();
1051 lmra->get_left()) &&
1054 return fma(
lm->get_left(),
1056 lm->get_left()*
lmra->get_left() -
this->right);
1063 lmrs->get_left()) &&
1066 return lm->get_left()*
lmrs->get_left() - this->
right -
1067 lm->get_left()*
lmrs->get_right();
1075 rm->get_left()->is_constant() &&
1076 rm->get_left()->evaluate().is_negative()) {
1084 if (
lmc.get() &&
lmc->is(-1)) {
1085 return lm->get_left()*(
lm->get_right() + this->
right);
1090 if (this->
right->is_match(
lm->get_right())) {
1091 return (
lm->get_left() - 1.0)*this->
right;
1092 }
else if (this->
right->is_match(
lm->get_left())) {
1093 return (
lm->get_right() - 1.0)*this->
right;
1099 if (this->
left->is_match(
rm->get_right())) {
1100 return (1.0 -
rm->get_left())*this->
left;
1101 }
else if (this->
left->is_match(
rm->get_left())) {
1102 return (1.0 -
rm->get_right())*this->
left;
1106 if (
lm.get() &&
rm.get()) {
1107 if (
lm->get_left()->is_match(
rm->get_left())) {
1109 return lm->get_left()*(
lm->get_right() -
rm->get_right());
1110 }
else if (
lm->get_left()->is_match(
rm->get_right())) {
1112 return lm->get_left()*(
lm->get_right() -
rm->get_left());
1113 }
else if (
lm->get_right()->is_match(
rm->get_left())) {
1115 return lm->get_right()*(
lm->get_left() -
rm->get_right());
1116 }
else if (
lm->get_right()->is_match(
rm->get_right())) {
1118 return lm->get_right()*(
lm->get_left() -
rm->get_left());
1123 if (
lm->get_left()->is_constant() &&
1124 rm->get_left()->is_constant() &&
1125 !
lm->get_left()->has_constant_zero()) {
1126 return lm->get_left()*(
lm->get_right() -
1127 (
rm->get_left()/
lm->get_left())*
rm->get_right());
1134 if (
lm->get_right()->is_match(
rmrm->get_right())) {
1135 return (
lm->get_left() -
rm->get_left()*
rmrm->get_left())*
lm->get_right();
1138 if (
lm->get_right()->is_match(
rmrm->get_left())) {
1139 return (
lm->get_left() -
rm->get_left()*
rmrm->get_right())*
lm->get_right();
1142 if (
lm->get_left()->is_match(
rmrm->get_right())) {
1143 return (
lm->get_right() -
rm->get_left()*
rmrm->get_left())*
lm->get_left();
1146 if (
lm->get_left()->is_match(
rmrm->get_left())) {
1147 return (
lm->get_right() -
rm->get_left()*
rmrm->get_right())*
lm->get_left();
1153 if (
rm->get_right()->is_match(
lmrm->get_right())) {
1154 return (
lm->get_left()*
lmrm->get_left() -
rm->get_left())*
rm->get_right();
1157 if (
rm->get_right()->is_match(
lmrm->get_left())) {
1158 return (
lm->get_left()*
lmrm->get_right() -
rm->get_left())*
rm->get_right();
1161 if (
rm->get_left()->is_match(
lmrm->get_right())) {
1162 return (
lm->get_left()*
lmrm->get_left() -
rm->get_right())*
rm->get_left();
1165 if (
rm->get_left()->is_match(
lmrm->get_left())) {
1166 return (
lm->get_left()*
lmrm->get_right() -
rm->get_right())*
rm->get_left();
1179 lmld->get_right()->is_match(
rmld->get_right())) {
1180 return (
lmld->get_left()*
lm->get_right() -
1181 rmld->get_left()*
rm->get_right())/
lmld->get_right();
1182 }
else if (
lmld.get() &&
rmrd.get() &&
1183 lmld->get_right()->is_match(
rmrd->get_right())) {
1184 return (
lmld->get_left()*
lm->get_right() -
1185 rmrd->get_left()*
rm->get_left())/
lmld->get_right();
1186 }
else if (
lmrd.get() &&
rmld.get() &&
1187 lmrd->get_right()->is_match(
rmld->get_right())) {
1188 return (
lmrd->get_left()*
lm->get_left() -
1189 rmld->get_left()*
rm->get_right())/
lmrd->get_right();
1190 }
else if (
lmrd.get() &&
rmrd.get() &&
1191 lmrd->get_right()->is_match(
rmrd->get_right())) {
1192 return (
lmrd->get_left()*
lm->get_left() -
1193 rmrd->get_left()*
rm->get_left())/
lmrd->get_right();
1200 if (
lrm.get() &&
rm.get()) {
1201 if (
lrm->get_left()->is_match(
rm->get_left())) {
1203 return ls->get_left() -
1205 rm->get_right())*
rm->get_left();
1206 }
else if (
lrm->get_left()->is_match(
rm->get_right())) {
1208 return ls->get_left() -
1210 rm->get_left())*
rm->get_right();
1211 }
else if (
lrm->get_right()->is_match(
rm->get_left())) {
1213 return ls->get_left() -
1215 rm->get_right())*
rm->get_left();
1216 }
else if (
lrm->get_right()->is_match(
rm->get_right())) {
1218 return ls->get_left() -
1220 rm->get_left())*
rm->get_right();
1237 if ((
rd->get_left()->is_constant() &&
1238 rd->get_left()->evaluate().is_negative()) ||
1240 (
rdlm->get_left()->is_constant() &&
1241 rdlm->get_left()->evaluate().is_negative()))) {
1244 }
else if (
ld.get()) {
1246 if ((
ld->get_left()->is_constant() &&
1247 ld->get_left()->evaluate().is_negative()) ||
1249 (
ldlm->get_left()->is_constant() &&
1250 ldlm->get_left()->evaluate().is_negative()))) {
1255 if (
ld.get() &&
rd.get()) {
1256 if (
ld->get_right()->is_match(
rd->get_right())) {
1257 return (
ld->get_left() -
rd->get_left())/
ld->get_right();
1267 if (
ld->get_left()->is_match(
rdlm->get_left())) {
1268 return (1.0/
ld->get_right() -
1269 rdlm->get_right()/
rd->get_right())*
rdlm->get_left();
1270 }
else if (
ld->get_left()->is_match(
rdlm->get_right())) {
1271 return (1.0/
ld->get_right() -
1272 rdlm->get_left()/
rd->get_right())*
rdlm->get_right();
1274 }
else if (
ldlm.get()) {
1275 if (
rd->get_left()->is_match(
ldlm->get_left())) {
1276 return (
ldlm->get_right()/
ld->get_right() -
1277 1.0/
rd->get_right())*
ldlm->get_left();
1278 }
else if (
rd->get_left()->is_match(
ldlm->get_right())) {
1279 return (
ldlm->get_left()/
ld->get_right() -
1280 1.0/
rd->get_right())*
ldlm->get_right();
1291 rdlm->get_left()) &&
1292 !
ldlm->get_right()->is_match(
rdlm->get_right())) {
1293 return (
ldlm->get_right()/
ld->get_right() -
1294 rdlm->get_left()/
ldlm->get_left() *
1295 rdlm->get_right()/
rd->get_right())*
ldlm->get_left();
1298 if (
ldlm->get_right()->is_match(
rdlm->get_right())) {
1299 return (
ldlm->get_left()/
ld->get_right() -
1300 rdlm->get_left()/
rd->get_right())*
ldlm->get_right();
1301 }
else if (
ldlm->get_right()->is_match(
rdlm->get_left())) {
1302 return (
ldlm->get_left()/
ld->get_right() -
1303 rdlm->get_right()/
rd->get_right())*
ldlm->get_right();
1304 }
else if (
ldlm->get_left()->is_match(
rdlm->get_right())) {
1305 return (
ldlm->get_right()/
ld->get_right() -
1306 rdlm->get_left()/
rd->get_right())*
ldlm->get_left();
1307 }
else if (
ldlm->get_left()->is_match(
rdlm->get_left())) {
1308 return (
ldlm->get_right()/
ld->get_right() -
1309 rdlm->get_right()/
rd->get_right())*
ldlm->get_left();
1320 if (
ldrm->get_right()->is_match(
rdrm->get_right())) {
1321 return (
ld->get_left()/
ldrm->get_left() -
1322 rd->get_left()/
rdrm->get_left())/
ldrm->get_right();
1323 }
else if (
ldrm->get_right()->is_match(
rdrm->get_left())) {
1324 return (
ld->get_left()/
ldrm->get_left() -
1325 rd->get_left()/
rdrm->get_right())/
ldrm->get_right();
1326 }
else if (
ldrm->get_left()->is_match(
rdrm->get_right())) {
1327 return (
ld->get_left()/
ldrm->get_right() -
1328 rd->get_left()/
rdrm->get_left())/
ldrm->get_left();
1329 }
else if (
ldrm->get_left()->is_match(
rdrm->get_left())) {
1330 return (
ld->get_left()/
ldrm->get_right() -
1331 rd->get_left()/
rdrm->get_right())/
ldrm->get_left();
1340 if (
ld->get_right()->is_match(
rdrm->get_left())) {
1341 return (
ld->get_left()*
rdrm->get_right() -
rd->get_left()) /
1343 }
else if (
ld->get_right()->is_match(
rdrm->get_right())) {
1344 return (
ld->get_left()*
rdrm->get_left() -
rd->get_left()) /
1347 }
else if (
ldrm.get()) {
1348 if (
rd->get_right()->is_match(
ldrm->get_left())) {
1349 return (
ld->get_left() -
rd->get_left()*
ldrm->get_right()) /
1351 }
else if (
rd->get_right()->is_match(
ldrm->get_right())) {
1352 return (
ld->get_left() -
rd->get_left()*
ldrm->get_left()) /
1365 return la->get_left() + (
la->get_right() - this->
right);
1367 return ls->get_left() - (this->
right +
ls->get_right());
1377 if (
pl.get() &&
rd.get()) {
1379 if (
rdp.get() &&
pl->get_right()->is_match(
rdp->get_right())) {
1382 rdp->get_left()->is_match(
plld->get_right())) {
1383 return (
pow(
plld->get_left(),
pl->get_right()) -
1385 pow(
rdp->get_left(),
pl->get_right());
1388 }
else if (
pr.get() &&
ld.get()) {
1390 if (
ldp.get() &&
pr->get_right()->is_match(
ldp->get_right())) {
1393 ldp->get_left()->is_match(
prld->get_right())) {
1394 return (
pow(
prld->get_left(),
pr->get_right()) -
1396 pow(
ldp->get_left(),
pr->get_right());
1399 }
else if (
pl.get() &&
pr.get()) {
1400 if (
pl->get_right()->is_match(
pr->get_right())) {
1403 if (
pld.get() &&
prd.get() &&
1404 pld->get_right()->is_match(
prd->get_right())) {
1405 return (
pow(
pld->get_left(),
pl->get_right()) -
1406 pow(
prd->get_left(),
pl->get_right())) /
1407 pow(
pld->get_right(),
pl->get_right());
1416 if (
lfma->get_middle()->is_match(
rfma->get_middle())) {
1419 lfma->get_right() -
rfma->get_right());
1424 if (
lfma.get() && !
this->right->is_all_variables()) {
1437 return ls->get_left() - (
ls->get_right() + this->
right);
1458 const size_t hash =
reinterpret_cast<size_t> (x.get());
1479 if (registers.find(
this) == registers.end()) {
1490 stream <<
" const ";
1491 jit::add_type<T> (stream);
1492 stream <<
" " << registers[
this] <<
" = "
1493 << registers[
l.get()] <<
" - "
1494 << registers[
r.get()];
1508 if (
this == x.get()) {
1514 return this->
left->is_match(
x_cast->get_left()) &&
1530 std::cout <<
"\\left(";
1532 this->
left->to_latex();
1534 std::cout <<
"\\right)";
1538 std::cout <<
"\\left(";
1540 this->
right->to_latex();
1542 std::cout <<
"\\right)";
1553 return this->
left->remove_pseudo() -
1554 this->
right->remove_pseudo();
1568 if (registers.find(
this) == registers.end()) {
1570 registers[
this] =
name;
1571 stream <<
" " <<
name
1572 <<
" [label = \"-\", shape = oval, style = filled, fillcolor = blue, fontcolor = white];" << std::endl;
1574 auto l = this->
left->to_vizgraph(stream, registers);
1575 stream <<
" " <<
name <<
" -- " << registers[
l.get()] <<
";" << std::endl;
1576 auto r = this->
right->to_vizgraph(stream, registers);
1577 stream <<
" " <<
name <<
" -- " << registers[
r.get()] <<
";" << std::endl;
1594 template<jit::
float_scalar T,
bool SAFE_MATH=false>
1597 auto temp = std::make_shared<subtract_node<T, SAFE_MATH>> (
l,
r)->reduce();
1599 for (
size_t i =
temp->get_hash();
1609#if defined(__clang__) || defined(__GNUC__)
1612 assert(
false &&
"Should never reach.");
1629 template<jit::
float_scalar T,
bool SAFE_MATH=false>
1649 template<jit::
float_scalar T, jit::
float_scalar L,
bool SAFE_MATH=false>
1669 template<jit::
float_scalar T, jit::
float_scalar R,
bool SAFE_MATH=false>
1687 template<jit::
float_scalar T,
bool SAFE_MATH=false>
1693 template<jit::
float_scalar T,
bool SAFE_MATH=false>
1705 template<jit::
float_scalar T,
bool SAFE_MATH=false>
1707 return std::dynamic_pointer_cast<subtract_node<T, SAFE_MATH>> (x);
1719 template<jit::
float_scalar T,
bool SAFE_MATH=false>
1740 auto temp2 = reduce_nested_fma_times_constant(
temp->get_left());
1765 if (
add->get_right()->is_match(
temp->get_middle()) &&
1767 auto temp2 = expand_nested_fma_times_add2(
temp->get_left(),
1772 temp->get_right()*
add->get_left());
1776 add->get_left()*
temp->get_left() +
temp->get_right()),
1778 temp->get_right()*
add->get_left());
1803 if (
add->get_right()->is_match(
temp->get_middle()) &&
1809 add->get_left()*
temp->get_left() +
1812 add->get_left()*
temp->get_right() +
1813 temp2->get_right());
1815 auto temp3 = expand_nested_fma_times_add2(
temp->get_left(),
1820 add->get_left()*
temp->get_right() +
1821 temp2->get_right());
1864 if (this->
left.get() ==
this->right.get()) {
1888 if (l.get() &&
l->is(1)) {
1890 }
else if (l.get() &&
l->is(0)) {
1892 }
else if (r.get() &&
r->is(1)) {
1894 }
else if (r.get() &&
r->is(0)) {
1896 }
else if (l.get() &&
r.get()) {
1903 if (
pl1.get() && (r.get() ||
pl1->is_arg_match(
this->right))) {
1905 pl1->get_scale(),
pl1->get_offset());
1906 }
else if (
pr1.get() && (l.get() ||
pr1->is_arg_match(
this->left))) {
1908 pr1->get_scale(),
pr1->get_offset());
1914 if (
pl2.get() && (r.get() ||
pl2->is_arg_match(
this->right))) {
1916 pl2->get_num_columns(),
1917 pl2->get_left(),
pl2->get_x_scale(),
pl2->get_x_offset(),
1918 pl2->get_right(),
pl2->get_y_scale(),
pl2->get_y_offset());
1919 }
else if (
pr2.get() && (l.get() ||
pr2->is_arg_match(
this->left))) {
1921 pr2->get_num_columns(),
1922 pr2->get_left(),
pr2->get_x_scale(),
pr2->get_x_offset(),
1923 pr2->get_right(),
pr2->get_y_scale(),
pr2->get_y_offset());
1927 if (
pr2.get() &&
pr2->is_row_match(
this->left)) {
1931 pr2->get_num_columns(),
1932 pr2->get_left(),
pr2->get_x_scale(),
pr2->get_x_offset(),
1933 pr2->get_right(),
pr2->get_y_scale(),
pr2->get_y_offset());
1934 }
else if (
pr2.get() &&
pr2->is_col_match(
this->left)) {
1938 pr2->get_num_columns(),
1939 pr2->get_left(),
pr2->get_x_scale(),
pr2->get_x_offset(),
1940 pr2->get_right(),
pr2->get_y_scale(),
pr2->get_y_offset());
1941 }
else if (
pl2.get() &&
pl2->is_row_match(
this->right)) {
1945 pl2->get_num_columns(),
1946 pl2->get_left(),
pl2->get_x_scale(),
pl2->get_x_offset(),
1947 pl2->get_right(),
pl2->get_y_scale(),
pl2->get_y_offset());
1948 }
else if (
pl2.get() &&
pl2->is_col_match(
this->right)) {
1952 pl2->get_num_columns(),
1953 pl2->get_left(),
pl2->get_x_scale(),
pl2->get_x_offset(),
1954 pl2->get_right(),
pl2->get_y_scale(),
pl2->get_y_offset());
1970 if ((
cl.get() && !
this->right->is_power_like() &&
1971 !
this->right->is_all_variables() &&
1973 (
sl.get() && !
this->right->is_power_like() &&
1974 !
this->right->is_all_variables()) ||
1980 if (this->
left->is_match(
this->right)) {
1991 return lm->get_left()*(
lm->get_right()*this->
right);
1997 return (this->
right*
lm->get_left())*
lm->get_right();
1999 return (this->
right*
lm->get_right())*
lm->get_left();
2006 return (
lm->get_left()*this->
right)*
lm->get_right();
2019 this->right->get_power_exponent());
2031 if (
temp->is_normal()) {
2032 return temp*
rm->get_right();
2039 return (this->
left*
rm->get_left())*
rm->get_right();
2041 return (this->
left*
rm->get_right())*
rm->get_left();
2047 return (this->
left*
rm->get_left())*
rm->get_right();
2054 auto temp = this->reduce_nested_fma_times_constant(
rm->get_left());
2056 return temp*
rm->get_right();
2063 return rm->get_left()*(this->
left*
rm->get_right());
2075 !this->
right->is_power_like()) {
2076 return (
lm->get_left()*this->
right)*
lm->get_right();
2077 }
else if (
rm.get() &&
2080 !this->
left->is_constant()) {
2081 return (this->
left*
rm->get_left())*
rm->get_right();
2086 if (
lm.get() &&
rm.get()) {
2089 auto temp =
lm->get_left()*
rm->get_left();
2090 if (
temp->is_normal()) {
2091 return temp*(
lm->get_right()*
rm->get_right());
2095 auto temp =
lm->get_left()*
rm->get_right();
2096 if (
temp->is_normal()) {
2097 return temp*(
lm->get_right()*
rm->get_left());
2101 auto temp =
lm->get_right()*
rm->get_left();
2102 if (
temp->is_normal()) {
2103 return temp*(
lm->get_left()*
rm->get_right());
2107 auto temp =
lm->get_right()*
rm->get_right();
2108 if (
temp->is_normal()) {
2109 return temp*(
lm->get_left()*
rm->get_left());
2114 if (
lm->get_left()->is_match(
rm->get_left())) {
2115 return (
lm->get_left()*
rm->get_left()) *
2116 (
lm->get_right()*
rm->get_right());
2117 }
else if (
lm->get_right()->is_match(
rm->get_left())) {
2118 return (
lm->get_right()*
rm->get_left()) *
2119 (
lm->get_left()*
rm->get_right());
2120 }
else if (
lm->get_left()->is_match(
rm->get_right())) {
2121 return (
lm->get_left()*
rm->get_right()) *
2122 (
lm->get_right()*
rm->get_left());
2123 }
else if (
lm->get_right()->is_match(
rm->get_right())) {
2124 return (
lm->get_right()*
rm->get_right()) *
2125 (
lm->get_left()*
rm->get_left());
2136 return (this->
left*
rd->get_left())/
rd->get_right();
2137 }
else if (
ld.get()) {
2138 return (
ld->get_left()*
this->right)/
ld->get_right();
2143 if (
ld.get() &&
rd.get()) {
2144 if (
ld->get_left()->is_match(
rd->get_right())) {
2145 return rd->get_left()/
ld->get_right();
2146 }
else if (
ld->get_right()->is_match(
rd->get_left())) {
2147 return ld->get_left()/
rd->get_right();
2152 return (
ld->get_left()*
rd->get_left()) /
2153 (
ld->get_right()*
rd->get_right());
2158 return pow(this->
left->get_power_base(),
2159 this->left->get_power_exponent() +
2160 this->right->get_power_exponent());
2168 return this->
left/
pow(
rp->get_left(), -
rp->get_right());
2176 return this->
right/
pow(
lp->get_left(), -
lp->get_right());
2180 if (
lp.get() &&
rp.get()) {
2181 if (
lp->get_right()->is_match(
rp->get_right())) {
2182 return pow(
lp->get_left()*
rp->get_left(),
lp->get_right());
2189 if (
lm.get() &&
rp.get()) {
2193 if (
lmrp->get_right()->is_match(
rp->get_right())) {
2194 return lm->get_left()*
pow(
lmrp->get_left()*
rp->get_left(),
2197 }
else if (
lmlp.get()) {
2198 if (
lmlp->get_right()->is_match(
rp->get_right())) {
2199 return lm->get_right()*
pow(
lmlp->get_left()*
rp->get_left(),
2203 }
else if (
rm.get() &&
lp.get()) {
2207 if (
rmrp->get_right()->is_match(
lp->get_right())) {
2208 return rm->get_left()*
pow(
lp->get_left()*
rmrp->get_left(),
2211 }
else if (
rmlp.get()) {
2212 if (
rmlp->get_right()->is_match(
lp->get_right())) {
2213 return rm->get_right()*
pow(
lp->get_left()*
rmlp->get_left(),
2223 if (
lp.get() &&
rp.get()) {
2229 return pow(
lplm->get_left()->get_power_base(),
2230 this->left->get_power_exponent())*
2232 this->left->get_power_exponent() +
2233 this->right->get_power_exponent());
2236 return pow(
lplm->get_right()->get_power_base(),
2237 this->left->get_power_exponent())*
2239 this->left->get_power_exponent() +
2240 this->right->get_power_exponent());
2247 return pow(
rplm->get_left()->get_power_base(),
2248 this->right->get_power_exponent())*
2249 pow(this->
left->get_power_base(),
2250 this->left->get_power_exponent() +
2251 this->right->get_power_exponent());
2254 return pow(
rplm->get_right()->get_power_base(),
2255 this->right->get_power_exponent())*
2256 pow(this->
left->get_power_base(),
2257 this->left->get_power_exponent() +
2258 this->right->get_power_exponent());
2268 return pow(
lpd->get_left(),
this->left->get_power_exponent()) *
2270 this->right->get_power_exponent() -
2271 this->left->get_power_exponent()*
lpd->get_right()->get_power_exponent());
2275 return pow(this->
right->get_power_base(),
2276 this->right->get_power_exponent() +
2277 this->left->get_power_exponent()*
lpd->get_left()->get_power_exponent()) /
2278 pow(
lpd->get_right(),
this->left->get_power_exponent());
2286 return pow(
rpd->get_left(),
this->right->get_power_exponent()) *
2287 pow(this->
left->get_power_base(),
2288 this->left->get_power_exponent() -
2289 this->right->get_power_exponent()*
rpd->get_right()->get_power_exponent());
2294 return pow(this->
right->get_power_base(),
2295 this->right->get_power_exponent() +
2296 this->right->get_power_exponent()*
rpd->get_left()->get_power_exponent()) /
2297 pow(
rpd->get_right(),
this->right->get_power_exponent());
2304 if (
le.get() &&
re.get()) {
2305 return exp(
le->get_arg() +
re->get_arg());
2310 if (
le.get() &&
rm.get()) {
2313 return rm->get_right()*(this->
left*
rm->get_left());
2317 return rm->get_left()*(this->
left*
rm->get_right());
2322 if (
re.get() &&
lm.get()) {
2325 return lm->get_right()*(this->
right*
lm->get_left());
2329 return lm->get_left()*(this->
right*
lm->get_right());
2336 if (
lm.get() &&
rm.get()) {
2341 return (
lm->get_right()*
rm->get_right()) *
2342 (
lm->get_left()*
rm->get_left());
2346 return (
lm->get_right()*
rm->get_left()) *
2347 (
lm->get_left()*
rm->get_right());
2354 return (
lm->get_left()*
rm->get_right()) *
2355 (
lm->get_right()*
rm->get_left());
2359 return (
lm->get_left()*
rm->get_left()) *
2360 (
lm->get_right()*
rm->get_right());
2365 if (
ld.get() &&
re.get()) {
2369 return ld->get_left()*(this->
right/
ld->get_right());
2374 return (
ld->get_left()*
this->right)/
ld->get_right();
2377 if (
rd.get() &&
le.get()) {
2381 return rd->get_left()*(this->
left/
rd->get_right());
2386 return (this->
left*
rd->get_left())/
rd->get_right();
2390 if (
ld.get() &&
rm.get()) {
2396 return (
ld->get_left()*
rm->get_right()) *
2397 (
rm->get_left()/
ld->get_right());
2402 return (
rm->get_right()/
ld->get_right()) *
2403 (
ld->get_left()*
rm->get_left());
2411 return (
ld->get_left()*
rm->get_left()) *
2412 (
rm->get_right()/
ld->get_right());
2417 return (
rm->get_left()/
ld->get_right()) *
2418 (
ld->get_left()*
rm->get_right());
2421 }
else if (
rd.get() &&
lm.get()) {
2427 return (
lm->get_left()/
rd->get_right()) *
2428 (
lm->get_right()*
rd->get_left());
2433 return (
lm->get_left()*
rd->get_left()) *
2434 (
lm->get_right()/
rd->get_right());
2442 return (
lm->get_right()*
rd->get_left()) *
2443 (
lm->get_left()/
rd->get_right());
2448 return (
lm->get_right()/
rd->get_right()) *
2449 (
lm->get_left()*
rd->get_left());
2458 auto fma_reduce = this->reduce_nested_fma_times_constant(this->
right);
2498 const size_t hash =
reinterpret_cast<size_t> (x.get());
2520 if (registers.find(
this) == registers.end()) {
2531 stream <<
" const ";
2532 jit::add_type<T> (stream);
2533 stream <<
" " << registers[
this] <<
" = ";
2535 stream <<
"(" << registers[
l.get()] <<
" == ";
2537 jit::add_type<T> (stream);
2542 stream <<
" || " << registers[
r.get()] <<
" == ";
2544 jit::add_type<T> (stream);
2551 jit::add_type<T> (stream);
2558 stream << registers[
l.get()] <<
"*"
2559 << registers[
r.get()];
2573 if (
this == x.get()) {
2580 if ((this->
left->is_match(
x_cast->get_left()) &&
2581 this->right->is_match(
x_cast->get_right())) ||
2582 (
this->right->is_match(
x_cast->get_left()) &&
2583 this->left->is_match(
x_cast->get_right()))) {
2598 std::cout <<
"\\left(";
2599 this->
left->to_latex();
2600 std::cout <<
"\\right)";
2602 this->
left->to_latex();
2608 std::cout <<
"\\left(";
2609 this->
right->to_latex();
2610 std::cout <<
"\\right)";
2612 this->
right->to_latex();
2623 return this->
left->remove_pseudo() *
2624 this->
right->remove_pseudo();
2638 if (registers.find(
this) == registers.end()) {
2640 registers[
this] =
name;
2641 stream <<
" " <<
name
2642 <<
" [label = \"⨉\", shape = oval, style = filled, fillcolor = blue, fontcolor = white];" << std::endl;
2644 auto l = this->
left->to_vizgraph(stream, registers);
2645 stream <<
" " <<
name <<
" -- " << registers[
l.get()] <<
";" << std::endl;
2646 auto r = this->
right->to_vizgraph(stream, registers);
2647 stream <<
" " <<
name <<
" -- " << registers[
r.get()] <<
";" << std::endl;
2663 template<jit::
float_scalar T,
bool SAFE_MATH=false>
2666 auto temp = std::make_shared<multiply_node<T, SAFE_MATH>> (
l,
r)->reduce();
2668 for (
size_t i =
temp->get_hash();
2678#if defined(__clang__) || defined(__GNUC__)
2681 assert(
false &&
"Should never reach.");
2697 template<jit::
float_scalar T,
bool SAFE_MATH=false>
2716 template<jit::
float_scalar T, jit::
float_scalar L,
bool SAFE_MATH=false>
2735 template<jit::
float_scalar T, jit::
float_scalar R,
bool SAFE_MATH=false>
2742 template<jit::
float_scalar T,
bool SAFE_MATH=false>
2754 template<jit::
float_scalar T,
bool SAFE_MATH=false>
2756 return std::dynamic_pointer_cast<multiply_node<T, SAFE_MATH>> (x);
2768 template<jit::
float_scalar T,
bool SAFE_MATH=false>
2827 if ((l.get() &&
l->is(0)) ||
2828 (
r.get() &&
r->is(1))) {
2830 }
else if (l.get() &&
r.get()) {
2837 if (
pl1.get() && (r.get() ||
pl1->is_arg_match(
this->right))) {
2839 pl1->get_scale(),
pl1->get_offset());
2840 }
else if (
pr1.get() && (l.get() ||
pr1->is_arg_match(
this->left))) {
2842 pr1->get_scale(),
pr1->get_offset());
2848 if (
pl2.get() && (r.get() ||
pl2->is_arg_match(
this->right))) {
2850 pl2->get_num_columns(),
2851 pl2->get_left(),
pl2->get_x_scale(),
pl2->get_x_offset(),
2852 pl2->get_right(),
pl2->get_y_scale(),
pl2->get_y_offset());
2853 }
else if (
pr2.get() && (l.get() ||
pr2->is_arg_match(
this->left))) {
2855 pr2->get_num_columns(),
2856 pr2->get_left(),
pr2->get_x_scale(),
pr2->get_x_offset(),
2857 pr2->get_right(),
pr2->get_y_scale(),
pr2->get_y_offset());
2861 if (
pr2.get() &&
pr2->is_row_match(
this->left)) {
2865 pr2->get_num_columns(),
2866 pr2->get_left(),
pr2->get_x_scale(),
pr2->get_x_offset(),
2867 pr2->get_right(),
pr2->get_y_scale(),
pr2->get_y_offset());
2868 }
else if (
pr2.get() &&
pr2->is_col_match(
this->left)) {
2872 pr2->get_num_columns(),
2873 pr2->get_left(),
pr2->get_x_scale(),
pr2->get_x_offset(),
2874 pr2->get_right(),
pr2->get_y_scale(),
pr2->get_y_offset());
2875 }
else if (
pl2.get() &&
pl2->is_row_match(
this->right)) {
2879 pl2->get_num_columns(),
2880 pl2->get_left(),
pl2->get_x_scale(),
pl2->get_x_offset(),
2881 pl2->get_right(),
pl2->get_y_scale(),
pl2->get_y_offset());
2882 }
else if (
pl2.get() &&
pl2->is_col_match(
this->right)) {
2886 pl2->get_num_columns(),
2887 pl2->get_left(),
pl2->get_x_scale(),
pl2->get_x_offset(),
2888 pl2->get_right(),
pl2->get_y_scale(),
pl2->get_y_offset());
2891 if (this->
left->is_match(
this->right)) {
2896 if (this->
right->is_constant()) {
2907 return this->
left*
rald->get_right() /
2911 }
else if (
rard.get()) {
2912 return this->
left*
rard->get_right() /
2926 return this->
left*
rsld->get_right() /
2928 rsld->get_right()*
rs->get_right());
2929 }
else if (
rsrd.get()) {
2930 return this->
left*
rsrd->get_right() /
2931 (
rsrd->get_right()*
rs->get_left() -
2944 if (
lfma->get_middle()->is_match(
this->right) &&
2945 fmarm->get_right()->is_match(
this->right)) {
2946 return lfma->get_left() +
fmarm->get_left();
2947 }
else if (
lfma->get_middle()->is_match(
this->right) &&
2948 fmarm->get_left()->is_match(
this->right)) {
2949 return lfma->get_left() +
fmarm->get_right();
2950 }
else if (
lfma->get_left()->is_match(
this->right) &&
2951 fmarm->get_right()->is_match(
this->right)) {
2952 return lfma->get_middle() +
fmarm->get_left();
2953 }
else if (
lfma->get_left()->is_match(
this->right) &&
2954 fmarm->get_left()->is_match(
this->right)) {
2955 return lfma->get_middle() +
fmarm->get_right();
2964 if (
lm.get() &&
rm.get()) {
2969 return (
lm->get_left()/
rm->get_left()) *
2970 (
lm->get_right()/
rm->get_right());
2975 return (
lm->get_left()/
rm->get_right()) *
2976 (
lm->get_right()/
rm->get_left());
2984 if (
rm->get_left()->is_constant() &&
2985 rm->get_left()->is_normal()) {
2986 return ((1.0/
rm->get_left())*this->
left)/
rm->get_right();
2987 }
else if (
rm->get_right()->is_constant() &&
2988 rm->get_right()->is_normal()) {
2989 return ((1.0/
rm->get_right())*this->
left)/
rm->get_left();
3005 rmlald->get_left())*
rm->get_right());
3006 }
else if (
rmlard.get()) {
3010 rmlard->get_left())*
rm->get_right());
3020 rmrald->get_left())*
rm->get_left());
3021 }
else if (
rmrard.get()) {
3025 rmrard->get_left())*
rm->get_left());
3041 rmlsld->get_right()*
rmls->get_right())*
rm->get_right());
3042 }
else if (
rmlsrd.get()) {
3045 rmlsrd->get_left())*
rm->get_right());
3054 rmrsld->get_right()*
rmrs->get_right())*
rm->get_left());
3055 }
else if (
rmrsrd.get()) {
3058 rmrsrd->get_left())*
rm->get_left());
3063 if (
lm.get() &&
rm.get()) {
3068 if (
lm->get_left()->is_match(
rm->get_left())) {
3069 return lm->get_right()/
rm->get_right();
3070 }
else if (
lm->get_left()->is_match(
rm->get_right())) {
3071 return lm->get_right()/
rm->get_left();
3072 }
else if (
lm->get_right()->is_match(
rm->get_left())) {
3073 return lm->get_left()/
rm->get_right();
3074 }
else if (
lm->get_right()->is_match(
rm->get_right())) {
3075 return lm->get_left()/
rm->get_left();
3082 if (
lm->get_left()->is_match(this->
right)) {
3083 return lm->get_right();
3084 }
else if (
lm->get_right()->is_match(this->
right)) {
3085 return lm->get_left();
3092 return lm->get_right()*(
lm->get_left()/this->
right);
3095 return lm->get_left()*(
lm->get_right()/this->
right);
3104 return ld->get_left()/(
ld->get_right()*this->
right);
3107 return this->
left*
rd->get_right()/
rd->get_left();
3113 return pow(this->
left->get_power_base(),
3114 this->left->get_power_exponent() -
3115 this->right->get_power_exponent());
3123 return this->
left*
pow(
rp->get_left(), -
rp->get_right());
3133 if (
lpm->get_left()->is_match(
this->right->get_power_base())) {
3134 return pow(this->
right->get_power_base(),
3135 this->left->get_power_exponent() -
3136 this->right->get_power_exponent()) *
3138 this->left->get_power_exponent());
3139 }
else if (
lpm->get_right()->is_match(
this->right->get_power_base())) {
3140 return pow(this->
right->get_power_base(),
3141 this->left->get_power_exponent() -
3142 this->right->get_power_exponent()) *
3144 this->left->get_power_exponent());
3150 if (
lp.get() &&
rp.get()) {
3151 if (
lp->get_right()->is_match(
rp->get_right())) {
3152 return pow(
lp->get_left()/
rp->get_left(),
lp->get_right());
3160 if (
lp.get() &&
rm.get()) {
3163 if (
lpm->get_left()->is_match(
rm->get_left()->get_power_base())) {
3164 return (
pow(
rm->get_left()->get_power_base(),
3165 this->left->get_power_exponent() -
3166 rm->get_left()->get_power_exponent()) *
3168 this->left->get_power_exponent())) /
3170 }
else if (
lpm->get_right()->is_match(
rm->get_left()->get_power_base())) {
3171 return (
pow(
rm->get_left()->get_power_base(),
3172 this->left->get_power_exponent() -
3173 rm->get_left()->get_power_exponent()) *
3175 this->left->get_power_exponent())) /
3177 }
else if (
lpm->get_left()->is_match(
rm->get_right()->get_power_base())) {
3178 return (
pow(
rm->get_right()->get_power_base(),
3179 this->left->get_power_exponent() -
3180 rm->get_right()->get_power_exponent()) *
3182 this->left->get_power_exponent())) /
3184 }
else if (
lpm->get_right()->is_match(
rm->get_right()->get_power_base())) {
3185 return (
pow(
rm->get_right()->get_power_base(),
3186 this->left->get_power_exponent() -
3187 rm->get_right()->get_power_exponent()) *
3189 this->left->get_power_exponent())) /
3205 return lm->get_left()*
lmrm->get_left() *
3209 return lm->get_left()*
lmrm->get_right() *
3212 }
else if (
lmlm.get()) {
3215 return lm->get_right()*
lmlm->get_left() *
3219 return lm->get_right()*
lmlm->get_right() *
3233 if (
lmlpm->get_left()->is_match(
this->right->get_power_base())) {
3234 return lm->get_right() *
3236 lmlp->get_power_exponent() -
3237 this->right->get_power_exponent()) *
3239 lmlp->get_power_exponent());
3240 }
else if (
lmlpm->get_right()->is_match(
this->right->get_power_base())) {
3241 return lm->get_right() *
3243 lmlp->get_power_exponent() -
3244 this->right->get_power_exponent()) *
3246 lmlp->get_power_exponent());
3249 }
else if (
lmrp.get()) {
3252 if (
lmrpm->get_left()->is_match(
this->right->get_power_base())) {
3253 return lm->get_left() *
3255 lmrp->get_power_exponent() -
3256 this->right->get_power_exponent()) *
3258 lmrp->get_power_exponent());
3259 }
else if (
lmrpm->get_right()->is_match(
this->right->get_power_base())) {
3260 return lm->get_left() *
3262 lmrp->get_power_exponent() -
3263 this->right->get_power_exponent()) *
3265 lmrp->get_power_exponent());
3279 if (
lm.get() &&
rm.get()) {
3285 if (
lmlpm->get_left()->is_match(
rm->get_left()->get_power_base())) {
3286 return lm->get_right() *
3287 (
pow(
rm->get_left()->get_power_base(),
3288 lmlp->get_power_exponent() -
3289 rm->get_left()->get_power_exponent())) *
3291 lmlp->get_power_exponent()) /
3293 }
else if (
lmlpm->get_right()->is_match(
rm->get_left()->get_power_base())) {
3294 return lm->get_right() *
3295 (
pow(
rm->get_left()->get_power_base(),
3296 lmlp->get_power_exponent() -
3297 rm->get_left()->get_power_exponent())) *
3299 lmlp->get_power_exponent()) /
3301 }
else if (
lmlpm->get_left()->is_match(
rm->get_right()->get_power_base())) {
3302 return lm->get_right() *
3303 (
pow(
rm->get_left()->get_power_base(),
3304 lmlp->get_power_exponent() -
3305 rm->get_right()->get_power_exponent())) *
3307 lmlp->get_power_exponent()) /
3309 }
else if (
lmlpm->get_right()->is_match(
rm->get_right()->get_power_base())) {
3310 return lm->get_right() *
3311 (
pow(
rm->get_left()->get_power_base(),
3312 lmlp->get_power_exponent() -
3313 rm->get_right()->get_power_exponent())) *
3315 lmlp->get_power_exponent()) /
3319 }
else if (
lmrp.get()) {
3322 if (
lmrpm->get_left()->is_match(
rm->get_left()->get_power_base())) {
3323 return lm->get_left() *
3324 (
pow(
rm->get_left()->get_power_base(),
3325 lmrp->get_power_exponent() -
3326 rm->get_left()->get_power_exponent())) *
3328 lmrp->get_power_exponent()) /
3330 }
else if (
lmrpm->get_right()->is_match(
rm->get_left()->get_power_base())) {
3331 return lm->get_left() *
3332 (
pow(
rm->get_left()->get_power_base(),
3333 lmrp->get_power_exponent() -
3334 rm->get_left()->get_power_exponent())) *
3336 lmrp->get_power_exponent()) /
3338 }
else if (
lmrpm->get_left()->is_match(
rm->get_right()->get_power_base())) {
3339 return lm->get_left() *
3340 (
pow(
rm->get_left()->get_power_base(),
3341 lmrp->get_power_exponent() -
3342 rm->get_right()->get_power_exponent())) *
3344 lmrp->get_power_exponent()) /
3346 }
else if (
lmrpm->get_right()->is_match(
rm->get_right()->get_power_base())) {
3347 return lm->get_left() *
3348 (
pow(
rm->get_left()->get_power_base(),
3349 lmrp->get_power_exponent() -
3350 rm->get_right()->get_power_exponent())) *
3352 lmrp->get_power_exponent()) /
3368 if (
rexp.get() &&
lm.get()) {
3371 return lm->get_left()*(
lm->get_right()/this->
right);
3375 return lm->get_right()*(
lm->get_left()/this->
right);
3382 if (
rexp.get() &&
lm.get()) {
3388 return lmlm->get_left()*
lm->get_right() *
3391 return lmlm->get_right()*
lm->get_right() *
3394 }
else if (
lmrm.get()) {
3396 return lmrm->get_left()*
lm->get_left() *
3399 return lmrm->get_right()*
lm->get_left() *
3407 if (
lexp.get() &&
rm.get()) {
3410 return (this->
left/
rm->get_right())/
rm->get_left();
3414 return (this->
left/
rm->get_left())/
rm->get_right();
3422 if (
lm.get() &&
rm.get()) {
3427 return (
lm->get_left()/
rm->get_left()) *
3428 (
lm->get_right()/
rm->get_right());
3432 return (
lm->get_left()/
rm->get_right()) *
3433 (
lm->get_right()/
rm->get_left());
3440 return (
lm->get_right()/
rm->get_left()) *
3441 (
lm->get_left()/
rm->get_right());
3445 return (
lm->get_right()/
rm->get_right()) *
3446 (
lm->get_left()/
rm->get_left());
3453 if (
rd.get() &&
lexp.get()) {
3456 return (this->
left*
rd->get_right())/
rd->get_left();
3460 return rd->get_right()*(this->
left/
rd->get_left());
3490 const size_t hash =
reinterpret_cast<size_t> (x.get());
3512 if (registers.find(
this) == registers.end()) {
3523 stream <<
" const ";
3524 jit::add_type<T> (stream);
3525 stream <<
" " << registers[
this] <<
" = ";
3527 stream << registers[
l.get()] <<
" == ";
3529 jit::add_type<T> (stream);
3536 jit::add_type<T> (stream);
3543 stream << registers[
l.get()] <<
"/"
3544 << registers[
r.get()];
3557 if (
this == x.get()) {
3563 return this->
left->is_match(
x_cast->get_left()) &&
3574 std::cout <<
"\\frac{";
3575 this->
left->to_latex();
3577 this->
right->to_latex();
3588 return this->
left->remove_pseudo() /
3589 this->
right->remove_pseudo();
3603 if (registers.find(
this) == registers.end()) {
3605 registers[
this] =
name;
3606 stream <<
" " <<
name
3607 <<
" [label = \"\\\\\", shape = oval, style = filled, fillcolor = blue, fontcolor = white];" << std::endl;
3609 auto l = this->
left->to_vizgraph(stream, registers);
3610 stream <<
" " <<
name <<
" -- " << registers[
l.get()] <<
";" << std::endl;
3611 auto r = this->
right->to_vizgraph(stream, registers);
3612 stream <<
" " <<
name <<
" -- " << registers[
r.get()] <<
";" << std::endl;
3628 template<jit::
float_scalar T,
bool SAFE_MATH=false>
3631 auto temp = std::make_shared<divide_node<T, SAFE_MATH>> (
l,
r)->reduce();
3633 for (
size_t i =
temp->get_hash();
3643#if defined(__clang__) || defined(__GNUC__)
3646 assert(
false &&
"Should never reach.");
3662 template<jit::
float_scalar T,
bool SAFE_MATH=false>
3681 template<jit::
float_scalar T, jit::
float_scalar L,
bool SAFE_MATH=false>
3700 template<jit::
float_scalar T, jit::
float_scalar R,
bool SAFE_MATH=false>
3707 template<jit::
float_scalar T,
bool SAFE_MATH=false>
3719 template<jit::
float_scalar T,
bool SAFE_MATH=false>
3721 return std::dynamic_pointer_cast<divide_node<T, SAFE_MATH>> (x);
3735 template<jit::
float_scalar T,
bool SAFE_MATH=false>
3753 temp->get_middle()->is_match(sub->get_left())) {
3756 temp->get_right() -
temp->get_left()*sub->get_right()),
3758 this->right -
temp->get_right()*sub->get_right());
3760 if (
temp->get_middle()->is_match(sub->get_left()) &&
3762 auto temp2 =
temp->reduce_nested_fma(sub);
3766 this->right -
temp->get_right()*sub->get_right());
3836 if ((l.get() &&
l->is(0)) ||
3837 (
m.get() &&
m->is(0))) {
3839 }
else if (r.get() &&
r->is(0)) {
3841 }
else if (l.get() &&
m.get() &&
r.get()) {
3843 }
else if (l.get() &&
m.get()) {
3845 }
else if (l.get() &&
l->is(-1)) {
3847 }
else if (m.get() &&
m->is(-1)) {
3849 }
else if (l.get() &&
l->is(1)) {
3851 }
else if (m.get() &&
m->is(1)) {
3870 if (this->
left->is_match(
this->right)) {
3872 }
else if (this->
middle->is_match(
this->right)) {
3893 return fma(-this->
left,
ms->get_right(),
3901 auto temp = this->reduce_nested_fma(
ms);
3902 if (
temp.get() !=
this) {
3913 if (
rm->get_left()->is_match(
this->left)) {
3915 }
else if (
rm->get_left()->is_match(this->
middle)) {
3917 }
else if (
rm->get_right()->is_match(this->
left)) {
3919 }
else if (
rm->get_right()->is_match(this->
middle)) {
3926 if (
rmlc.get() &&
rmlc->evaluate().is_negative()) {
3928 (-1.0*
rm->get_left())*
rm->get_right();
3938 !this->
left->has_constant_zero()) {
3940 if (
temp->is_normal()) {
3948 !
this->middle->has_constant_zero()) {
3950 if (
temp->is_normal()) {
3958 !
this->left->has_constant_zero()) {
3960 if (
temp->is_normal()) {
3968 !
this->middle->has_constant_zero()) {
3970 if (
temp->is_normal()) {
3982 if (
mm->get_left()->is_match(
rm->get_left())) {
3986 }
else if (
mm->get_left()->is_match(
rm->get_right())) {
3990 }
else if (
mm->get_right()->is_match(
rm->get_left())) {
3994 }
else if (
mm->get_right()->is_match(
rm->get_right())) {
4003 if ((
lm.get() ||
mm.get()) &&
4004 (
this->left->get_complexity() +
this->middle->get_complexity() >
4005 this->right->get_complexity())) {
4006 return fma(
rm->get_left(),
rm->get_right(),
4016 if (
lm.get() &&
rm.get()) {
4019 !
lm->get_left()->has_constant_zero()) {
4020 auto temp =
rm->get_left()/
lm->get_left();
4021 if (
temp->is_normal()){
4022 return lm->get_left()*
fma(
lm->get_right(),
4029 !
lm->get_right()->has_constant_zero()) {
4030 auto temp =
rm->get_left()/
lm->get_right();
4031 if (
temp->is_normal()){
4032 return lm->get_right()*
fma(
lm->get_left(),
4039 !
lm->get_left()->has_constant_zero()) {
4040 auto temp =
rm->get_right()/
lm->get_left();
4041 if (
temp->is_normal()) {
4042 return lm->get_left()*
fma(
lm->get_right(),
4049 !
lm->get_right()->has_constant_zero()) {
4050 auto temp =
rm->get_right()/
lm->get_right();
4051 if (
temp->is_normal()) {
4052 return lm->get_right()*
fma(
lm->get_left(),
4064 return fma(
lm->get_left(),
4065 lm->get_right()*
this->middle,
4068 }
else if (
mm.get()) {
4075 if (
temp->is_normal()) {
4084 if (
temp->is_normal()) {
4092 return fma(
mm->get_left(),
4093 this->left*
mm->get_right(),
4100 if (
mm->get_left()->is_match(
this->right)) {
4104 }
else if (
mm->get_right()->is_match(
this->right)) {
4117 !
this->left->has_constant_zero()) {
4119 if (
temp->is_normal()) {
4126 !this->
middle->has_constant_zero()) {
4128 if (
temp->is_normal()) {
4138 if (
ld.get() &&
ld->get_right()->is_match(
this->middle)) {
4139 return ld->get_left() + this->
right;
4142 if (
md.get() &&
md->get_right()->is_match(
this->left)) {
4143 return md->get_left() + this->
right;
4147 if (
ld.get() &&
rd.get()) {
4149 if (
ld->get_left()->is_match(
rd->get_left())) {
4150 return ld->get_left()*(this->
middle/
ld->get_right() +
4151 1.0/
rd->get_right());
4162 if (
ldrm->get_right()->is_match(
rd->get_right())) {
4163 return fma(
ld->get_left(),
this->middle,
4164 rd->get_left()*
ldrm->get_left()) /
4166 }
else if (
ldrm->get_left()->is_match(
rd->get_right())) {
4167 return fma(
ld->get_left(),
this->middle,
4168 rd->get_left()*
ldrm->get_right()) /
4171 }
else if (
rdrm.get()) {
4172 if (
rdrm->get_right()->is_match(
ld->get_right())) {
4173 return fma(
ld->get_left()*
rdrm->get_left(),
4174 this->middle,
rd->get_left()) /
4176 }
else if (
rdrm->get_left()->is_match(
ld->get_right())) {
4177 return fma(
ld->get_left()*
rdrm->get_right(),
4178 this->middle,
rd->get_left()) /
4182 }
else if (
md.get() &&
rd.get()) {
4191 if (
mdrm->get_right()->is_match(
rd->get_right())) {
4193 rd->get_left()*
mdrm->get_left()) /
4195 }
else if (
mdrm->get_left()->is_match(
rd->get_right())) {
4197 rd->get_left()*
mdrm->get_right()) /
4200 }
else if (
rdrm.get()) {
4201 if (
rdrm->get_right()->is_match(
md->get_right())) {
4205 }
else if (
rdrm->get_left()->is_match(
md->get_right())) {
4220 if (this->
middle->is_match(
rfma->get_middle())) {
4224 }
else if (this->
left->is_match(
rfma->get_middle())) {
4228 }
else if (this->
middle->is_match(
rfma->get_left())) {
4232 }
else if (this->
left->is_match(
rfma->get_left())) {
4243 if (
mm->get_right()->is_match(
rfma->get_middle())) {
4244 return fma(
mm->get_right(),
4249 }
else if (
mm->get_left()->is_match(
rfma->get_middle())) {
4250 return fma(
mm->get_left(),
4255 }
else if (
mm->get_right()->is_match(
rfma->get_left())) {
4256 return fma(
mm->get_right(),
4259 rfma->get_middle()),
4261 }
else if (
mm->get_left()->is_match(
rfma->get_left())) {
4262 return fma(
mm->get_left(),
4265 rfma->get_middle()),
4268 }
else if (
lm.get()) {
4273 if (
lm->get_right()->is_match(
rfma->get_middle())) {
4274 return fma(
lm->get_right(),
4279 }
else if (
lm->get_left()->is_match(
rfma->get_middle())) {
4280 return fma(
lm->get_left(),
4285 }
else if (
lm->get_right()->is_match(
rfma->get_left())) {
4286 return fma(
lm->get_right(),
4289 rfma->get_middle()),
4291 }
else if (
lm->get_left()->is_match(
rfma->get_left())) {
4292 return fma(
lm->get_left(),
4295 rfma->get_middle()),
4307 if (
rfmamm->get_right()->is_match(
this->middle)) {
4313 }
else if (
rfmamm->get_right()->is_match(
this->left)) {
4319 }
else if (
rfmamm->get_left()->is_match(
this->middle)) {
4325 }
else if (
rfmamm->get_left()->is_match(
this->left)) {
4332 }
else if (
rfmalm.get()) {
4337 if (
rfmalm->get_right()->is_match(
this->middle)) {
4343 }
else if (
rfmalm->get_right()->is_match(
this->left)) {
4349 }
else if (
rfmalm->get_left()->is_match(
this->middle)) {
4355 }
else if (
rfmalm->get_left()->is_match(
this->left)) {
4369 if (
mm->get_right()->is_match(
rfmamm->get_right())) {
4370 return fma(
mm->get_right(),
4375 }
else if (
mm->get_left()->is_match(
rfmamm->get_right())) {
4376 return fma(
mm->get_left(),
4381 }
else if (
mm->get_right()->is_match(
rfmamm->get_left())) {
4382 return fma(
mm->get_right(),
4387 }
else if (
mm->get_left()->is_match(
rfmamm->get_left())) {
4388 return fma(
mm->get_left(),
4394 }
else if (
lm.get() &&
rfmamm.get()) {
4399 if (
lm->get_right()->is_match(
rfmamm->get_right())) {
4400 return fma(
lm->get_right(),
4405 }
else if (
lm->get_left()->is_match(
rfmamm->get_right())) {
4406 return fma(
lm->get_left(),
4411 }
else if (
lm->get_right()->is_match(
rfmamm->get_left())) {
4412 return fma(
lm->get_right(),
4417 }
else if (
lm->get_left()->is_match(
rfmamm->get_left())) {
4418 return fma(
lm->get_left(),
4424 }
else if (
mm.get() &&
rfmalm.get()) {
4429 if (
mm->get_right()->is_match(
rfmalm->get_right())) {
4430 return fma(
mm->get_right(),
4435 }
else if (
mm->get_left()->is_match(
rfmalm->get_right())) {
4436 return fma(
mm->get_left(),
4441 }
else if (
mm->get_right()->is_match(
rfmalm->get_left())) {
4442 return fma(
mm->get_right(),
4447 }
else if (
mm->get_left()->is_match(
rfmalm->get_left())) {
4448 return fma(
mm->get_left(),
4454 }
else if (
lm.get() &&
rfmalm.get()) {
4459 if (
lm->get_right()->is_match(
rfmalm->get_right())) {
4460 return fma(
lm->get_right(),
4465 }
else if (
lm->get_left()->is_match(
rfmalm->get_right())) {
4466 return fma(
lm->get_left(),
4471 }
else if (
lm->get_right()->is_match(
rfmalm->get_left())) {
4472 return fma(
lm->get_right(),
4477 }
else if (
lm->get_left()->is_match(
rfmalm->get_left())) {
4478 return fma(
lm->get_left(),
4489 return fma(
rfma->get_middle(),
4505 return fma(
rfma->get_middle(),
4524 rfma->get_middle()),
4540 rfma->get_middle()),
4554 if (this->
left->is_match(
rfma->get_left()) &&
4555 this->middle->is_match(
rfma->get_middle())) {
4557 }
else if (this->
left->is_match(
rfma->get_middle()) &&
4558 this->middle->is_match(
rfma->get_left())) {
4569 if (
fmald.get() &&
ld->get_right()->is_match(
fmald->get_right())) {
4570 return (
ld->get_left()*
this->middle +
4571 fmald->get_left()*
rfma->get_middle())/
ld->get_right() +
4573 }
else if (
fmamd.get() &&
ld->get_right()->is_match(
fmamd->get_right())) {
4574 return (
ld->get_left()*
this->middle +
4575 fmamd->get_left()*
rfma->get_left())/
ld->get_right() +
4578 }
else if (
md.get()) {
4579 if (
fmald.get() &&
md->get_right()->is_match(
fmald->get_right())) {
4580 return (
md->get_left()*
this->left +
4581 fmald->get_left()*
rfma->get_middle())/
md->get_right() +
4583 }
else if (
fmamd.get() &&
md->get_right()->is_match(
fmamd->get_right())) {
4584 return (
md->get_left()*
this->left +
4585 fmamd->get_left()*
rfma->get_left())/
md->get_right() +
4594 if (this->
left->is_all_variables()) {
4596 if (
rdl->get_complexity() <
this->left->get_complexity() +
4597 this->right->get_complexity()) {
4600 }
else if (this->
middle->is_all_variables()) {
4603 if ((
rdm->get_complexity() <
this->middle->get_complexity() +
4604 this->right->get_complexity()) &&
4605 !(
rdmc.get() &&
rdmc->evaluate().is_negative())) {
4625 return this->
left/
pow(
mp->get_left(), -
mp->get_right()) +
4644 if (
lp.get() &&
mp.get()) {
4645 if (
lp->get_right()->is_match(
mp->get_right())) {
4646 return pow(
lp->get_left()*
mp->get_left(),
4653 if (
rm.get() &&
mp.get()) {
4674 if (
rfma.get() &&
mp.get()) {
4678 rfma->get_left())) {
4684 rfma->get_middle()),
4687 rfma->get_left())) {
4693 rfma->get_middle()),
4708 rfma->get_right()));
4716 if (
md.get() &&
rd.get()) {
4717 if (
md->get_left()->is_match(
rd->get_left())) {
4718 return md->get_left()*(this->
left/
md->get_right() +
4719 1.0/
rd->get_right());
4720 }
else if (
md->get_right()->is_match(
rd->get_right())) {
4721 return (this->
left*
md->get_left() +
4722 rd->get_left())/
md->get_right();
4727 if (
ld.get() &&
rd.get()) {
4728 if (
ld->get_left()->is_match(
rd->get_left())) {
4729 return ld->get_left()*(this->
middle/
ld->get_right() +
4730 1.0/
rd->get_right());
4731 }
else if (
ld->get_right()->is_match(
rd->get_right())) {
4732 return (this->
middle*
ld->get_left() +
4733 rd->get_left())/
ld->get_right();
4739 if (
rm.get() &&
ld.get()) {
4741 if (
rmld.get() &&
ld->get_right()->is_match(
rmld->get_right())) {
4742 return fma(
ld->get_left(),
this->middle,
rmld->get_left()*
rm->get_right())/
ld->get_right();
4745 if (
rmrd.get() &&
ld->get_right()->is_match(
rmrd->get_right())) {
4746 return fma(
ld->get_left(),
this->middle,
rmrd->get_left()*
rm->get_left())/
ld->get_right();
4751 if (
rm.get() &&
md.get()) {
4753 if (
rmld.get() &&
md->get_right()->is_match(
rmld->get_right())) {
4754 return fma(this->
left,
md->get_left(),
rmld->get_left()*
rm->get_right())/
md->get_right();
4757 if (
rmrd.get() &&
md->get_right()->is_match(
rmrd->get_right())) {
4758 return fma(this->
left,
md->get_left(),
rmrd->get_left()*
rm->get_left())/
md->get_right();
4764 if (
rd.get() &&
lm.get()) {
4766 if (
lmld.get() &&
rd->get_right()->is_match(
lmld->get_right())) {
4767 return fma(
lmld->get_left()*
lm->get_right(),
this->middle,
rd->get_left())/
rd->get_right();
4770 if (
lmrd.get() &&
rd->get_right()->is_match(
lmrd->get_right())) {
4771 return fma(
lmld->get_left()*
lm->get_left(),
this->middle,
rd->get_left())/
rd->get_right();
4776 if (
rd.get() &&
mm.get()) {
4778 if (
mmld.get() &&
rd->get_right()->is_match(
mmld->get_right())) {
4779 return fma(this->
left,
mmld->get_left()*
mm->get_right(),
rd->get_left())/
rd->get_right();
4782 if (
mmrd.get() &&
rd->get_right()->is_match(
mmrd->get_right())) {
4783 return fma(this->
left,
mmrd->get_left()*
mm->get_left(),
rd->get_left())/
rd->get_right();
4795 if (
md.get() &&
rm.get()) {
4799 if (
rmlmld.get() &&
rmlmld->get_right()->is_match(
md->get_right())) {
4801 rmlmld->get_left()*
rmlm->get_right()*
rm->get_right())/
md->get_right();
4804 if (
rmlmrd.get() &&
rmlmrd->get_right()->is_match(
md->get_right())) {
4806 rmlmrd->get_left()*
rmlm->get_left()*
rm->get_right())/
md->get_right();
4812 if (
rmrmld.get() &&
rmrmld->get_right()->is_match(
md->get_right())) {
4814 rmrmld->get_left()*
rmrm->get_right()*
rm->get_left())/
md->get_right();
4817 if (
rmrmrd.get() &&
rmrmrd->get_right()->is_match(
md->get_right())) {
4819 rmrmrd->get_left()*
rmrm->get_left()*
rm->get_left())/
md->get_right();
4822 }
else if (
ld.get() &&
rm.get()) {
4826 if (
rmlmld.get() &&
rmlmld->get_right()->is_match(
ld->get_right())) {
4827 return fma(
ld->get_left(),
this->middle,
4828 rmlmld->get_left()*
rmlm->get_right()*
rm->get_right())/
ld->get_right();
4831 if (
rmlmrd.get() &&
rmlmrd->get_right()->is_match(
ld->get_right())) {
4832 return fma(
ld->get_left(),
this->middle,
4833 rmlmrd->get_left()*
rmlm->get_right()*
rm->get_right())/
ld->get_right();
4839 if (
rmrmld.get() &&
rmrmld->get_right()->is_match(
ld->get_right())) {
4840 return fma(
ld->get_left(),
this->middle,
4841 rmrmld->get_left()*
rmrm->get_right()*
rm->get_left())/
ld->get_right();
4844 if (
rmrmrd.get() &&
rmrmrd->get_right()->is_match(
ld->get_right())) {
4845 return fma(
ld->get_left(),
this->middle,
4846 rmrmrd->get_left()*
rmrm->get_left()*
rm->get_left())/
ld->get_right();
4854 if (
le.get() && me.get()) {
4855 return exp(
le->get_arg() + me->get_arg()) + this->
right;
4860 if (
mm.get() &&
le.get()) {
4876 if (
lm.get() && me.get()) {
4885 return fma(
lm->get_right()*
this->middle,
4895 if (
lm.get() &&
mm.get()) {
4900 return fma(
lm->get_left()*
mm->get_left(),
4901 lm->get_right()*
mm->get_right(),
4906 return fma(
lm->get_left()*
mm->get_right(),
4907 lm->get_right()*
mm->get_left(),
4915 return fma(
lm->get_right()*
mm->get_left(),
4916 lm->get_left()*
mm->get_right(),
4921 return fma(
lm->get_right()*
mm->get_right(),
4922 lm->get_left()*
mm->get_left(),
4932 if (
lm.get() &&
md.get()) {
4937 return fma(
lm->get_left()*
md->get_left(),
4938 lm->get_right()/
md->get_right(),
4943 return fma(
lm->get_left()/
md->get_right(),
4944 lm->get_right()*
md->get_left(),
4952 return fma(
lm->get_right()*
md->get_left(),
4953 lm->get_left()/
md->get_right(),
4958 return fma(
lm->get_right()/
md->get_right(),
4959 lm->get_left()*
md->get_left(),
4969 if (
ld.get() &&
mm.get()) {
4974 return fma(
ld->get_left()*
mm->get_left(),
4975 mm->get_right()/
ld->get_right(),
4980 return fma(
ld->get_left()*
mm->get_right(),
4981 mm->get_left()/
ld->get_right(),
4989 return fma(
mm->get_left()/
ld->get_right(),
4990 ld->get_left()*
mm->get_right(),
4995 return fma(
mm->get_right()/
ld->get_right(),
4996 ld->get_left()*
mm->get_left(),
5006 if (
ld.get() &&
md.get()) {
5011 return ((
ld->get_left()*
md->get_left()) /
5012 (
ld->get_right()*
md->get_right())) +
5017 return fma(
ld->get_left()/
md->get_right(),
5018 md->get_left()/
ld->get_right(),
5026 return fma(
md->get_left()/
ld->get_right(),
5027 ld->get_left()/
md->get_right(),
5032 return ((
ld->get_left()*
md->get_left()) /
5033 (
ld->get_right()*
md->get_right())) +
5056 const size_t hash =
reinterpret_cast<size_t> (x.get());
5060 this->right->df(x));
5083 if (registers.find(
this) == registers.end()) {
5098 stream <<
" const ";
5099 jit::add_type<T> (stream);
5100 stream <<
" " << registers[
this] <<
" = ";
5102 stream <<
"(" << registers[
l.get()] <<
" == ";
5104 jit::add_type<T> (stream);
5109 stream <<
" || " << registers[
m.get()] <<
" == ";
5111 jit::add_type<T> (stream);
5116 stream <<
") ? " << registers[
r.get()] <<
" : ";
5119 stream << registers[
l.get()] <<
"*"
5120 << registers[
m.get()] <<
" + "
5121 << registers[
r.get()];
5124 << registers[
l.get()] <<
", "
5125 << registers[
m.get()] <<
", "
5126 << registers[
r.get()] <<
")";
5141 if (
this == x.get()) {
5147 return this->
left->is_match(
x_cast->get_left()) &&
5159 std::cout <<
"\\left(";
5162 std::cout <<
"\\left(";
5163 this->
left->to_latex();
5164 std::cout <<
"\\right)";
5166 this->
left->to_latex();
5171 std::cout <<
"\\left(";
5172 this->
middle->to_latex();
5173 std::cout <<
"\\right)";
5175 this->
middle->to_latex();
5178 this->
right->to_latex();
5179 std::cout <<
"\\right)";
5189 return fma(this->
left->remove_pseudo(),
5190 this->middle->remove_pseudo(),
5191 this->right->remove_pseudo());
5205 if (registers.find(
this) == registers.end()) {
5207 registers[
this] =
name;
5208 stream <<
" " <<
name
5209 <<
" [label = \"fma\", shape = oval, style = filled, fillcolor = blue, fontcolor = white];" << std::endl;
5211 auto l = this->
left->to_vizgraph(stream, registers);
5212 stream <<
" " <<
name <<
" -- " << registers[
l.get()] <<
";" << std::endl;
5213 auto m = this->
middle->to_vizgraph(stream, registers);
5214 stream <<
" " <<
name <<
" -- " << registers[
m.get()] <<
";" << std::endl;
5215 auto r = this->
right->to_vizgraph(stream, registers);
5216 stream <<
" " <<
name <<
" -- " << registers[
r.get()] <<
";" << std::endl;
5233 template<jit::
float_scalar T,
bool SAFE_MATH=false>
5237 auto temp = std::make_shared<fma_node<T, SAFE_MATH>> (
l,
m,
r)->reduce();
5239 for (
size_t i =
temp->get_hash();
5249#if defined(__clang__) || defined(__GNUC__)
5252 assert(
false &&
"Should never reach.");
5270 template<jit::
float_scalar T, jit::
float_scalar L,
bool SAFE_MATH=false>
5291 template<jit::
float_scalar T, jit::
float_scalar M,
bool SAFE_MATH=false>
5312 template<jit::
float_scalar T, jit::
float_scalar R,
bool SAFE_MATH=false>
5334 template<jit::
float_scalar T, jit::
float_scalar L, jit::
float_scalar M,
bool SAFE_MATH=false>
5357 template<jit::
float_scalar T, jit::
float_scalar M, jit::
float_scalar R,
bool SAFE_MATH=false>
5380 template<jit::
float_scalar T, jit::
float_scalar L, jit::
float_scalar R,
bool SAFE_MATH=false>
5389 template<jit::
float_scalar T,
bool SAFE_MATH=false>
5401 template<jit::
float_scalar T,
bool SAFE_MATH=false>
5403 return std::dynamic_pointer_cast<fma_node<T, SAFE_MATH>> (x);
Class representing a generic buffer.
Definition backend.hpp:29
void subtract_row(const buffer< T > &x)
Subtract row operation.
Definition backend.hpp:407
void multiply_row(const buffer< T > &x)
Multiply row operation.
Definition backend.hpp:479
void add_col(const buffer< T > &x)
Add col operation.
Definition backend.hpp:371
void divide_col(const buffer< T > &x)
Divide col operation.
Definition backend.hpp:587
void add_row(const buffer< T > &x)
Add row operation.
Definition backend.hpp:335
void subtract_col(const buffer< T > &x)
Subtract col operation.
Definition backend.hpp:443
void multiply_col(const buffer< T > &x)
Multiply col operation.
Definition backend.hpp:515
void divide_row(const buffer< T > &x)
Divide row operation.
Definition backend.hpp:551
bool is_zero() const
Is every element zero.
Definition backend.hpp:141
An addition node.
Definition arithmetic.hpp:132
virtual backend::buffer< T > evaluate()
Evaluate the results of addition.
Definition arithmetic.hpp:166
virtual shared_leaf< T, SAFE_MATH > to_vizgraph(std::stringstream &stream, jit::register_map &registers)
Convert the node to vizgraph.
Definition arithmetic.hpp:741
virtual shared_leaf< T, SAFE_MATH > df(shared_leaf< T, SAFE_MATH > x)
Transform node to derivative.
Definition arithmetic.hpp:623
virtual bool is_match(shared_leaf< T, SAFE_MATH > x)
Query if the nodes match.
Definition arithmetic.hpp:677
virtual shared_leaf< T, SAFE_MATH > remove_pseudo()
Remove pseudo variable nodes.
Definition arithmetic.hpp:726
virtual shared_leaf< T, SAFE_MATH > compile(std::ostringstream &stream, jit::register_map &registers, jit::register_map &indices, const jit::register_usage &usage)
Compile the node.
Definition arithmetic.hpp:645
add_node(shared_leaf< T, SAFE_MATH > l, shared_leaf< T, SAFE_MATH > r)
Construct an addition node.
Definition arithmetic.hpp:154
virtual shared_leaf< T, SAFE_MATH > reduce()
Reduce an addition node.
Definition arithmetic.hpp:177
virtual void to_latex() const
Convert the node to latex.
Definition arithmetic.hpp:699
Class representing a branch node.
Definition node.hpp:1165
shared_leaf< T, SAFE_MATH > right
Right branch of the tree.
Definition node.hpp:1170
shared_leaf< T, SAFE_MATH > left
Left branch of the tree.
Definition node.hpp:1168
A division node.
Definition arithmetic.hpp:2769
virtual shared_leaf< T, SAFE_MATH > reduce()
Reduce a division node.
Definition arithmetic.hpp:2822
virtual shared_leaf< T, SAFE_MATH > remove_pseudo()
Remove pseudo variable nodes.
Definition arithmetic.hpp:3586
virtual shared_leaf< T, SAFE_MATH > compile(std::ostringstream &stream, jit::register_map &registers, jit::register_map &indices, const jit::register_usage &usage)
Compile the node.
Definition arithmetic.hpp:3508
divide_node(shared_leaf< T, SAFE_MATH > n, shared_leaf< T, SAFE_MATH > d)
Construct a division node.
Definition arithmetic.hpp:2791
virtual bool is_match(shared_leaf< T, SAFE_MATH > x)
Query if the nodes match.
Definition arithmetic.hpp:3556
virtual backend::buffer< T > evaluate()
Evaluate the results of division.
Definition arithmetic.hpp:2803
virtual shared_leaf< T, SAFE_MATH > to_vizgraph(std::stringstream &stream, jit::register_map &registers)
Convert the node to vizgraph.
Definition arithmetic.hpp:3601
virtual void to_latex() const
Convert the node to latex.
Definition arithmetic.hpp:3573
virtual shared_leaf< T, SAFE_MATH > df(shared_leaf< T, SAFE_MATH > x)
Transform node to derivative.
Definition arithmetic.hpp:3485
A fused multiply add node.
Definition arithmetic.hpp:3736
virtual shared_leaf< T, SAFE_MATH > remove_pseudo()
Remove pseudo variable nodes.
Definition arithmetic.hpp:5187
virtual shared_leaf< T, SAFE_MATH > df(shared_leaf< T, SAFE_MATH > x)
Transform node to derivative.
Definition arithmetic.hpp:5051
virtual shared_leaf< T, SAFE_MATH > compile(std::ostringstream &stream, jit::register_map &registers, jit::register_map &indices, const jit::register_usage &usage)
Compile the node.
Definition arithmetic.hpp:5079
virtual backend::buffer< T > evaluate()
Evaluate the results of fused multiply add.
Definition arithmetic.hpp:3812
virtual shared_leaf< T, SAFE_MATH > to_vizgraph(std::stringstream &stream, jit::register_map &registers)
Convert the node to vizgraph.
Definition arithmetic.hpp:5203
fma_node(shared_leaf< T, SAFE_MATH > l, shared_leaf< T, SAFE_MATH > m, shared_leaf< T, SAFE_MATH > r)
Construct a fused multiply add node.
Definition arithmetic.hpp:3798
virtual void to_latex() const
Convert the node to latex.
Definition arithmetic.hpp:5158
virtual shared_leaf< T, SAFE_MATH > reduce()
Reduce a fused multiply add node.
Definition arithmetic.hpp:3831
virtual bool is_match(shared_leaf< T, SAFE_MATH > x)
Query if the nodes match.
Definition arithmetic.hpp:5140
Class representing a node leaf.
Definition node.hpp:364
virtual void endline(std::ostringstream &stream, const jit::register_usage &usage) const final
End a line in the kernel source.
Definition node.hpp:639
virtual backend::buffer< T > evaluate()=0
Evaluate method.
std::map< size_t, std::shared_ptr< leaf_node< T, SAFE_MATH > > > df_cache
Cache derivative terms.
Definition node.hpp:371
virtual bool has_pseudo() const
Query if the node contains pseudo variables.
Definition node.hpp:620
const size_t hash
Hash for node.
Definition node.hpp:367
A multiplication node.
Definition arithmetic.hpp:1720
virtual backend::buffer< T > evaluate()
Evaluate the results of multiplication.
Definition arithmetic.hpp:1859
virtual shared_leaf< T, SAFE_MATH > df(shared_leaf< T, SAFE_MATH > x)
Transform node to derivative.
Definition arithmetic.hpp:2493
virtual shared_leaf< T, SAFE_MATH > remove_pseudo()
Remove pseudo variable nodes.
Definition arithmetic.hpp:2621
virtual shared_leaf< T, SAFE_MATH > reduce()
Reduce a multiplication node.
Definition arithmetic.hpp:1884
multiply_node(shared_leaf< T, SAFE_MATH > l, shared_leaf< T, SAFE_MATH > r)
Construct a multiplication node.
Definition arithmetic.hpp:1848
virtual bool is_match(shared_leaf< T, SAFE_MATH > x)
Query if the nodes match.
Definition arithmetic.hpp:2572
virtual shared_leaf< T, SAFE_MATH > compile(std::ostringstream &stream, jit::register_map &registers, jit::register_map &indices, const jit::register_usage &usage)
Compile the node.
Definition arithmetic.hpp:2516
virtual shared_leaf< T, SAFE_MATH > to_vizgraph(std::stringstream &stream, jit::register_map &registers)
Convert the node to vizgraph.
Definition arithmetic.hpp:2636
virtual void to_latex() const
Convert the node to latex.
Definition arithmetic.hpp:2594
A subtraction node.
Definition arithmetic.hpp:879
virtual shared_leaf< T, SAFE_MATH > compile(std::ostringstream &stream, jit::register_map &registers, jit::register_map &indices, const jit::register_usage &usage)
Compile the node.
Definition arithmetic.hpp:1475
subtract_node(shared_leaf< T, SAFE_MATH > l, shared_leaf< T, SAFE_MATH > r)
Construct a subtraction node.
Definition arithmetic.hpp:901
virtual shared_leaf< T, SAFE_MATH > df(shared_leaf< T, SAFE_MATH > x)
Transform node to derivative.
Definition arithmetic.hpp:1453
virtual shared_leaf< T, SAFE_MATH > remove_pseudo()
Remove pseudo variable nodes.
Definition arithmetic.hpp:1551
virtual bool is_match(shared_leaf< T, SAFE_MATH > x)
Query if the nodes match.
Definition arithmetic.hpp:1507
virtual shared_leaf< T, SAFE_MATH > reduce()
Reduce a subtraction node.
Definition arithmetic.hpp:924
virtual shared_leaf< T, SAFE_MATH > to_vizgraph(std::stringstream &stream, jit::register_map &registers)
Convert the node to vizgraph.
Definition arithmetic.hpp:1566
virtual backend::buffer< T > evaluate()
Evaluate the results of subtraction.
Definition arithmetic.hpp:913
virtual void to_latex() const
Convert the node to latex.
Definition arithmetic.hpp:1524
Class representing a triple branch node.
Definition node.hpp:1289
shared_leaf< T, SAFE_MATH > middle
Middle branch of the tree.
Definition node.hpp:1292
Complex scalar concept.
Definition register.hpp:24
subroutine assert(test, message)
Assert check.
Definition f_binding_test.f90:38
buffer< T > fma(buffer< T > &a, buffer< T > &b, buffer< T > &c)
Fused multiply add operation.
Definition backend.hpp:950
Name space for graph nodes.
Definition arithmetic.hpp:13
shared_piecewise_2D< T, SAFE_MATH > piecewise_2D_cast(shared_leaf< T, SAFE_MATH > x)
Cast to a piecewise 2D node.
Definition piecewise.hpp:1406
shared_leaf< T, SAFE_MATH > operator*(shared_leaf< T, SAFE_MATH > l, shared_leaf< T, SAFE_MATH > r)
Build multiply node from two leaves.
Definition arithmetic.hpp:2698
std::shared_ptr< add_node< T, SAFE_MATH > > shared_add
Convenience type alias for shared add nodes.
Definition arithmetic.hpp:851
shared_pow< T, SAFE_MATH > pow_cast(shared_leaf< T, SAFE_MATH > x)
Cast to a power node.
Definition math.hpp:1416
shared_subtract< T, SAFE_MATH > subtract_cast(shared_leaf< T, SAFE_MATH > x)
Cast to a subtract node.
Definition arithmetic.hpp:1706
constexpr shared_leaf< T, SAFE_MATH > zero()
Forward declare for zero.
Definition node.hpp:986
shared_leaf< T, SAFE_MATH > operator/(shared_leaf< T, SAFE_MATH > l, shared_leaf< T, SAFE_MATH > r)
Build divide node from two leaves.
Definition arithmetic.hpp:3663
shared_leaf< T, SAFE_MATH > multiply(shared_leaf< T, SAFE_MATH > l, shared_leaf< T, SAFE_MATH > r)
Build multiply node from two leaves.
Definition arithmetic.hpp:2664
shared_add< T, SAFE_MATH > add_cast(shared_leaf< T, SAFE_MATH > x)
Cast to an add node.
Definition arithmetic.hpp:863
bool is_greater_exponent(shared_leaf< T, SAFE_MATH > a, shared_leaf< T, SAFE_MATH > b)
Check if the exponent is greater than the other.
Definition arithmetic.hpp:111
shared_leaf< T, SAFE_MATH > operator+(shared_leaf< T, SAFE_MATH > l, shared_leaf< T, SAFE_MATH > r)
Build add node from two leaves.
Definition arithmetic.hpp:806
std::shared_ptr< divide_node< T, SAFE_MATH > > shared_divide
Convenience type alias for shared divide nodes.
Definition arithmetic.hpp:3708
shared_leaf< T, SAFE_MATH > pow(shared_leaf< T, SAFE_MATH > l, shared_leaf< T, SAFE_MATH > r)
Build power node.
Definition math.hpp:1352
shared_sine< T, SAFE_MATH > sin_cast(shared_leaf< T, SAFE_MATH > x)
Cast to a sine node.
Definition trigonometry.hpp:262
shared_divide< T, SAFE_MATH > divide_cast(shared_leaf< T, SAFE_MATH > x)
Cast to a divide node.
Definition arithmetic.hpp:3720
shared_piecewise_1D< T, SAFE_MATH > piecewise_1D_cast(shared_leaf< T, SAFE_MATH > x)
Cast to a piecewise 1D node.
Definition piecewise.hpp:629
shared_leaf< T, SAFE_MATH > operator-(shared_leaf< T, SAFE_MATH > l, shared_leaf< T, SAFE_MATH > r)
Build subtract node from two leaves.
Definition arithmetic.hpp:1630
shared_leaf< T, SAFE_MATH > divide(shared_leaf< T, SAFE_MATH > l, shared_leaf< T, SAFE_MATH > r)
Build divide node from two leaves.
Definition arithmetic.hpp:3629
shared_multiply< T, SAFE_MATH > multiply_cast(shared_leaf< T, SAFE_MATH > x)
Cast to a multiply node.
Definition arithmetic.hpp:2755
shared_leaf< T, SAFE_MATH > exp(shared_leaf< T, SAFE_MATH > x)
Define exp convenience function.
Definition math.hpp:544
shared_exp< T, SAFE_MATH > exp_cast(shared_leaf< T, SAFE_MATH > x)
Cast to an exp node.
Definition math.hpp:578
bool is_constant_promotable(shared_leaf< T, SAFE_MATH > a, shared_leaf< T, SAFE_MATH > b)
Check if the constants are promotable.
Definition arithmetic.hpp:53
shared_leaf< T, SAFE_MATH > fma(shared_leaf< T, SAFE_MATH > l, shared_leaf< T, SAFE_MATH > m, shared_leaf< T, SAFE_MATH > r)
Build fused multiply add node.
Definition arithmetic.hpp:5234
shared_constant< T, SAFE_MATH > constant_cast(shared_leaf< T, SAFE_MATH > x)
Cast to a constant node.
Definition node.hpp:1034
bool is_variable_combinable(shared_leaf< T, SAFE_MATH > a, shared_leaf< T, SAFE_MATH > b)
Check if the variable is combinable.
Definition arithmetic.hpp:75
constexpr T i
Convenience type for imaginary constant.
Definition node.hpp:1018
shared_leaf< T, SAFE_MATH > subtract(shared_leaf< T, SAFE_MATH > l, shared_leaf< T, SAFE_MATH > r)
Build subtract node from two leaves.
Definition arithmetic.hpp:1595
std::shared_ptr< fma_node< T, SAFE_MATH > > shared_fma
Convenience type alias for shared fused multiply add nodes.
Definition arithmetic.hpp:5390
std::shared_ptr< leaf_node< T, SAFE_MATH > > shared_leaf
Convenience type alias for shared leaf nodes.
Definition node.hpp:676
bool is_variable_promotable(shared_leaf< T, SAFE_MATH > a, shared_leaf< T, SAFE_MATH > b)
Check if the variable is promotable.
Definition arithmetic.hpp:91
std::shared_ptr< multiply_node< T, SAFE_MATH > > shared_multiply
Convenience type alias for shared multiply nodes.
Definition arithmetic.hpp:2743
shared_leaf< T, SAFE_MATH > add(shared_leaf< T, SAFE_MATH > l, shared_leaf< T, SAFE_MATH > r)
Build add node from two leaves.
Definition arithmetic.hpp:772
shared_fma< T, SAFE_MATH > fma_cast(shared_leaf< T, SAFE_MATH > x)
Cast to a fma node.
Definition arithmetic.hpp:5402
shared_cosine< T, SAFE_MATH > cos_cast(shared_leaf< T, SAFE_MATH > x)
Cast to a cosine node.
Definition trigonometry.hpp:520
bool is_constant_combinable(shared_leaf< T, SAFE_MATH > a, shared_leaf< T, SAFE_MATH > b)
Check if nodes are constant combinable.
Definition arithmetic.hpp:25
std::shared_ptr< subtract_node< T, SAFE_MATH > > shared_subtract
Convenience type alias for shared subtract nodes.
Definition arithmetic.hpp:1694
std::string format_to_string(const T value)
Convert a value to a string while avoiding locale.
Definition register.hpp:212
std::map< void *, size_t > register_usage
Type alias for counting register usage.
Definition register.hpp:259
std::map< void *, std::string > register_map
Type alias for mapping node pointers to register names.
Definition register.hpp:257
std::string to_string(const char prefix, const NODE *pointer)
Convert a graph::leaf_node pointer to a string.
Definition register.hpp:246
Base nodes of graph computation framework.
void piecewise_1D()
Tests for 1D piecewise nodes.
Definition piecewise_test.cpp:80
void piecewise_2D()
Tests for 2D piecewise nodes.
Definition piecewise_test.cpp:306