Graph Framework
Loading...
Searching...
No Matches
arithmetic.hpp
Go to the documentation of this file.
1//------------------------------------------------------------------------------
6//------------------------------------------------------------------------------
7
8#ifndef arithmetic_h
9#define arithmetic_h
10
11#include "node.hpp"
12
13namespace graph {
14//------------------------------------------------------------------------------
23//------------------------------------------------------------------------------
///  @brief Test whether two constant operands can be folded into one.
///
///  Returns false unless both a and b report is_constant(). Two constant
///  operands then combine when either one is a plain constant node, when a is
///  a 1D or 2D piecewise node whose argument matches b, or when a 2D piecewise
///  operand has a row or column that matches the other operand.
///
///  NOTE(review): the declaration (original lines 25-26) is missing from this
///  extracted listing; a and b are presumably shared_leaf<T, SAFE_MATH> —
///  confirm against the real arithmetic.hpp.
24 template<jit::float_scalar T, bool SAFE_MATH=false>
27 if (a->is_constant() && b->is_constant()) {
28 auto a1 = piecewise_1D_cast(a);
29 auto a2 = piecewise_2D_cast(a);
30 auto b2 = piecewise_2D_cast(b);
31
32 return constant_cast(a).get() ||
33 constant_cast(b).get() ||
34 (a1.get() && a1->is_arg_match(b)) ||
35 (a2.get() && a2->is_arg_match(b)) ||
36 (a2.get() && (a2->is_row_match(b) || a2->is_col_match(b))) ||
37 (b2.get() && (b2->is_row_match(a) || b2->is_col_match(a)));
38 }
39 return false;
40 }
41
42//------------------------------------------------------------------------------
51//------------------------------------------------------------------------------
///  @brief Predicate on the constness / piecewise kind of a and b.
///
///  Returns true only when a is constant and one of the following holds:
///  b is not constant; a is a plain constant node and b is a 1D or 2D
///  piecewise node; or a is a 1D piecewise node and b is a 2D piecewise node.
///
///  NOTE(review): the declaration (original lines 53-54) is missing from this
///  extracted listing, so the function name and how callers use the result
///  are not visible here — confirm against the real arithmetic.hpp.
52 template<jit::float_scalar T, bool SAFE_MATH=false>
55 auto b1 = piecewise_1D_cast(b);
56 auto b2 = piecewise_2D_cast(b);
57
58 return a->is_constant() &&
59 (!b->is_constant() ||
60 (constant_cast(a).get() && (b1.get() || b2.get())) ||
61 (piecewise_1D_cast(a).get() && b2.get()));
62 }
63
64//------------------------------------------------------------------------------
73//------------------------------------------------------------------------------
///  @brief Delegates to a->is_power_base_match(b).
///
///  NOTE(review): the declaration (original lines 75-76) is missing from this
///  extracted listing — confirm the name and parameter types against the real
///  arithmetic.hpp.
74 template<jit::float_scalar T, bool SAFE_MATH=false>
77 return a->is_power_base_match(b);
78 }
79
80//------------------------------------------------------------------------------
89//------------------------------------------------------------------------------
///  @brief Ordering predicate for non-constant operands.
///
///  Returns true when b is not constant, a is made only of variables
///  (is_all_variables()), and either b is not made only of variables or a has
///  a lower get_complexity() than b.
///
///  NOTE(review): the inner `b->is_all_variables() &&` test is always true on
///  that side of the `||` and could be dropped.
///  NOTE(review): the declaration (original lines 91-92) is missing from this
///  extracted listing — confirm against the real arithmetic.hpp.
90 template<jit::float_scalar T, bool SAFE_MATH=false>
93 return !b->is_constant() &&
94 (a->is_all_variables() &&
95 (!b->is_all_variables() ||
96 (b->is_all_variables() &&
97 a->get_complexity() < b->get_complexity())));
98 }
99
100//------------------------------------------------------------------------------
109//------------------------------------------------------------------------------
///  @brief Compare the magnitudes of two constant power exponents.
///
///  Returns true only when both get_power_exponent() results are plain
///  constant nodes and |exponent of a| > |exponent of b|.
///
///  NOTE(review): only element 0 of each evaluated exponent buffer is compared
///  (`.at(0)`), which assumes scalar exponents — confirm.
///  NOTE(review): the declaration (original lines 111-112) is missing from this
///  extracted listing.
110 template<jit::float_scalar T, bool SAFE_MATH=false>
113 auto ae = constant_cast(a->get_power_exponent());
114 auto be = constant_cast(b->get_power_exponent());
115
116 return ae.get() && be.get() &&
117 std::abs(ae->evaluate().at(0)) > std::abs(be->evaluate().at(0));
118 }
119
120//******************************************************************************
121// Add node.
122//******************************************************************************
123//------------------------------------------------------------------------------
130//------------------------------------------------------------------------------
///  @brief Graph node for the element-wise sum of its left and right operands.
///
///  NOTE(review): this is an extraction of a rendered source listing; the
///  Doxygen comment lines, the constructor body and several member declaration
///  lines are missing and must be restored from the real arithmetic.hpp.
131 template<jit::float_scalar T, bool SAFE_MATH=false>
132 class add_node final : public branch_node<T, SAFE_MATH> {
133 private:
134//------------------------------------------------------------------------------
140//------------------------------------------------------------------------------
///  Key string for an add of l and r: both pointer addresses joined by "+".
141 static std::string to_string(leaf_node<T, SAFE_MATH> *l,
143 return jit::format_to_string(reinterpret_cast<size_t> (l)) + "+" +
144 jit::format_to_string(reinterpret_cast<size_t> (r));
145 }
146
147 public:
148//------------------------------------------------------------------------------
153//------------------------------------------------------------------------------
158
159//------------------------------------------------------------------------------
165//------------------------------------------------------------------------------
///  Evaluates both operands and returns their element-wise sum.
167 backend::buffer<T> l_result = this->left->evaluate();
168 backend::buffer<T> r_result = this->right->evaluate();
169 return l_result + r_result;
170 }
171
172//------------------------------------------------------------------------------
176//------------------------------------------------------------------------------
///  Algebraic simplification of left + right. Rules are tried in order
///  (constant, piecewise, identity, sign normalisation, fma fusion, divide,
///  chained add, add/subtract, fma and power/trig rewrites); the first rule
///  that applies returns the rewritten graph, otherwise this node is returned.
178// Constant reductions.
179 auto l = constant_cast(this->left);
180 auto r = constant_cast(this->right);
181
182 if (l.get() && l->is(0)) {
183 return this->right;
184 } else if (r.get() && r->is(0)) {
185 return this->left;
186 } else if (l.get() && r.get()) {
187 return constant<T, SAFE_MATH> (this->evaluate());
188 } else if (r.get() && !l.get()) {
// Canonical order: move a lone constant to the left operand.
189 return this->right + this->left;
190 }
191
192 auto pl1 = piecewise_1D_cast(this->left);
193 auto pr1 = piecewise_1D_cast(this->right);
194
195 if (pl1.get() && (r.get() || pl1->is_arg_match(this->right))) {
196 return piecewise_1D(this->evaluate(), pl1->get_arg(),
197 pl1->get_scale(), pl1->get_offset());
198 } else if (pr1.get() && (l.get() || pr1->is_arg_match(this->left))) {
199 return piecewise_1D(this->evaluate(), pr1->get_arg(),
200 pr1->get_scale(), pr1->get_offset());
201 }
202
203 auto pl2 = piecewise_2D_cast(this->left);
204 auto pr2 = piecewise_2D_cast(this->right);
205
206 if (pl2.get() && (r.get() || pl2->is_arg_match(this->right))) {
207 return piecewise_2D(this->evaluate(),
208 pl2->get_num_columns(),
209 pl2->get_left(), pl2->get_x_scale(), pl2->get_x_offset(),
210 pl2->get_right(), pl2->get_y_scale(), pl2->get_y_offset());
211 } else if (pr2.get() && (l.get() || pr2->is_arg_match(this->left))) {
212 return piecewise_2D(this->evaluate(),
213 pr2->get_num_columns(),
214 pr2->get_left(), pr2->get_x_scale(), pr2->get_x_offset(),
215 pr2->get_right(), pr2->get_y_scale(), pr2->get_y_offset());
216 }
217
218// Combine 2D and 1D piecewise constants if a row or column matches.
// NOTE(review): pl1/pr1 are dereferenced without a null check below; this
// assumes is_row_match/is_col_match only succeed when the other operand is a
// piecewise_1D node — confirm against the piecewise_2D node implementation.
219 if (pr2.get() && pr2->is_row_match(this->left)) {
220 backend::buffer<T> result = pl1->evaluate();
221 result.add_row(pr2->evaluate());
222 return piecewise_2D(result,
223 pr2->get_num_columns(),
224 pr2->get_left(), pr2->get_x_scale(), pr2->get_x_offset(),
225 pr2->get_right(), pr2->get_y_scale(), pr2->get_y_offset());
226 } else if (pr2.get() && pr2->is_col_match(this->left)) {
227 backend::buffer<T> result = pl1->evaluate();
228 result.add_col(pr2->evaluate());
229 return piecewise_2D(result,
230 pr2->get_num_columns(),
231 pr2->get_left(), pr2->get_x_scale(), pr2->get_x_offset(),
232 pr2->get_right(), pr2->get_y_scale(), pr2->get_y_offset());
233 } else if (pl2.get() && pl2->is_row_match(this->right)) {
234 backend::buffer<T> result = pl2->evaluate();
235 result.add_row(pr1->evaluate());
236 return piecewise_2D(result,
237 pl2->get_num_columns(),
238 pl2->get_left(), pl2->get_x_scale(), pl2->get_x_offset(),
239 pl2->get_right(), pl2->get_y_scale(), pl2->get_y_offset());
240 } else if (pl2.get() && pl2->is_col_match(this->right)) {
241 backend::buffer<T> result = pl2->evaluate();
242 result.add_col(pr1->evaluate());
243 return piecewise_2D(result,
244 pl2->get_num_columns(),
245 pl2->get_left(), pl2->get_x_scale(), pl2->get_x_offset(),
246 pl2->get_right(), pl2->get_y_scale(), pl2->get_y_offset());
247 }
248
249// Identity reductions.
250 if (this->left->is_match(this->right)) {
251 return 2.0*this->left;
252 }
253
254// Multiply-operand reductions: turn products with a negative constant factor
255// into subtractions, then fuse any remaining product into an fma node.
256 auto lm = multiply_cast(this->left);
257 auto rm = multiply_cast(this->right);
258
259// v1 + -c*v2 -> v1 - c*v2
260// -c*v1 + v2 -> v2 - c*v1
261 if (rm.get() &&
262 rm->get_left()->is_constant() &&
263 rm->get_left()->evaluate().is_negative()) {
264 return this->left - (-this->right);
265 } else if (lm.get() &&
266 lm->get_left()->is_constant() &&
267 lm->get_left()->evaluate().is_negative()) {
// This case needs the left multiply; re-testing rm would only duplicate the
// branch above and leave -c*v1 + v2 unreachable.
268 return this->right - (-this->left);
269 }
270
271// a*b + c -> fma(a,b,c)
272// a + b*c -> fma(b,c,a)
273 if (lm.get()) {
274 return fma(lm->get_left(), lm->get_right(), this->right);
275 } else if (rm.get()) {
276 return fma(rm->get_left(), rm->get_right(), this->left);
277 }
278
279// Common denominator reduction. If the left and right are both divide nodes
280// check for a common denominator. So you can change a/b + c/b -> (a + c)/b.
281 auto ld = divide_cast(this->left);
282 auto rd = divide_cast(this->right);
283
284// c is a constant.
285// a + -c/b -> a - c/b
286// a + (-c*d)/b -> a - (c*d)/b
287// -c/a + b -> b - c/a
288// (-c*d)/a + b -> b - (c*d)/a
289 if (rd.get()) {
290 auto rdlm = multiply_cast(rd->get_left());
291 if ((rd->get_left()->is_constant() &&
292 rd->get_left()->evaluate().is_negative()) ||
293 (rdlm.get() &&
294 (rdlm->get_left()->is_constant() &&
295 rdlm->get_left()->evaluate().is_negative()))) {
296 return this->left - (-rd->get_left())/rd->get_right();
297 }
298 } else if (ld.get()) {
299 auto ldlm = multiply_cast(ld->get_left());
300 if ((ld->get_left()->is_constant() &&
301 ld->get_left()->evaluate().is_negative()) ||
302 (ldlm.get() &&
303 (ldlm->get_left()->is_constant() &&
304 ldlm->get_left()->evaluate().is_negative()))) {
305 return this->right - (-ld->get_left())/ld->get_right();
306 }
307 }
308
309 if (ld.get() && rd.get()) {
310 if (ld->get_right()->is_match(rd->get_right())) {
311 return (ld->get_left() + rd->get_left())/ld->get_right();
312 }
313
314 auto ldlm = multiply_cast(ld->get_left());
315 auto rdlm = multiply_cast(rd->get_left());
316// a/b + a*c/d -> (1/b + c/d)*a
317// a/b + c*a/d -> (1/b + c/d)*a
318// a*c/b + a/d -> (c/b + 1/d)*a
319// c*a/b + a/d -> (c/b + 1/d)*a
320 if (rdlm.get()) {
321 if (ld->get_left()->is_match(rdlm->get_left())) {
322 return (1.0/ld->get_right() +
323 rdlm->get_right()/rd->get_right())*rdlm->get_left();
324 } else if (ld->get_left()->is_match(rdlm->get_right())) {
325 return (1.0/ld->get_right() +
326 rdlm->get_left()/rd->get_right())*rdlm->get_right();
327 }
328 } else if (ldlm.get()) {
329 if (rd->get_left()->is_match(ldlm->get_left())) {
330 return (ldlm->get_right()/ld->get_right() +
331 1.0/rd->get_right())*ldlm->get_left();
332 } else if (rd->get_left()->is_match(ldlm->get_right())) {
333 return (ldlm->get_left()/ld->get_right() +
334 1.0/rd->get_right())*ldlm->get_right();
335 }
336 }
337
338// c1*a/b + c2*e/d -> (a/b + (c2/c1)*e/d)*c1 when a and e differ.
339// a*b/c + d*b/e -> (a/c + d/e)*b
340// Make sure we prevent combining constants when we just need to factor out a
341// common term.
342// c1*a/b + c2*a/d -> (c1/b + c2/d)*a
343 if (ldlm.get() && rdlm.get()) {
344 if (is_constant_combinable(ldlm->get_left(),
345 rdlm->get_left()) &&
346 !ldlm->get_right()->is_match(rdlm->get_right())) {
347 return (ldlm->get_right()/ld->get_right() +
348 rdlm->get_left()/ldlm->get_left() *
349 rdlm->get_right()/rd->get_right())*ldlm->get_left();
350 }
351
352 if (ldlm->get_right()->is_match(rdlm->get_right())) {
353 return (ldlm->get_left()/ld->get_right() +
354 rdlm->get_left()/rd->get_right())*ldlm->get_right();
355 } else if (ldlm->get_right()->is_match(rdlm->get_left())) {
356 return (ldlm->get_left()/ld->get_right() +
357 rdlm->get_right()/rd->get_right())*ldlm->get_right();
358 } else if (ldlm->get_left()->is_match(rdlm->get_right())) {
359 return (ldlm->get_right()/ld->get_right() +
360 rdlm->get_left()/rd->get_right())*ldlm->get_left();
361 } else if (ldlm->get_left()->is_match(rdlm->get_left())) {
362 return (ldlm->get_right()/ld->get_right() +
363 rdlm->get_right()/rd->get_right())*ldlm->get_left();
364 }
365 }
366
367// (a/(c*b) + d/(e*c)) -> (a/b + d/e)/c
368// (a/(b*c) + d/(e*c)) -> (a/b + d/e)/c
369// (a/(c*b) + d/(c*e)) -> (a/b + d/e)/c
370// (a/(b*c) + d/(c*e)) -> (a/b + d/e)/c
371 auto ldrm = multiply_cast(ld->get_right());
372 auto rdrm = multiply_cast(rd->get_right());
373 if (ldrm.get() && rdrm.get()) {
374 if (ldrm->get_right()->is_match(rdrm->get_right())) {
375 return (ld->get_left()/ldrm->get_left() +
376 rd->get_left()/rdrm->get_left())/ldrm->get_right();
377 } else if (ldrm->get_right()->is_match(rdrm->get_left())) {
378 return (ld->get_left()/ldrm->get_left() +
379 rd->get_left()/rdrm->get_right())/ldrm->get_right();
380 } else if (ldrm->get_left()->is_match(rdrm->get_right())) {
381 return (ld->get_left()/ldrm->get_right() +
382 rd->get_left()/rdrm->get_left())/ldrm->get_left();
383 } else if (ldrm->get_left()->is_match(rdrm->get_left())) {
384 return (ld->get_left()/ldrm->get_right() +
385 rd->get_left()/rdrm->get_right())/ldrm->get_left();
386 }
387 }
388
389// a/b + c/(b*d) -> (a*d + c)/(b*d)
390// a/b + c/(d*b) -> (a*d + c)/(d*b)
391// a/(b*d) + c/b -> (c*d + a)/(b*d)
392// a/(d*b) + c/b -> (c*d + a)/(d*b)
393 if (rdrm.get()) {
394 if (ld->get_right()->is_match(rdrm->get_left())) {
395 return fma(ld->get_left(),
396 rdrm->get_right(),
397 rd->get_left()) /
398 rd->get_right();
399 } else if (ld->get_right()->is_match(rdrm->get_right())) {
400 return fma(ld->get_left(),
401 rdrm->get_left(),
402 rd->get_left()) /
403 rd->get_right();
404 }
405 } else if (ldrm.get()) {
406 if (rd->get_right()->is_match(ldrm->get_left())) {
407 return fma(rd->get_left(),
408 ldrm->get_right(),
409 ld->get_left()) /
410 ld->get_right();
411 } else if (rd->get_right()->is_match(ldrm->get_right())) {
412 return fma(rd->get_left(),
413 ldrm->get_left(),
414 ld->get_left()) /
415 ld->get_right();
416 }
417 }
418 }
419
420// Chained addition reductions.
421// a + (a + b) = fma(2,a,b)
422// a + (b + a) = fma(2,a,b)
423// (a + b) + a = fma(2,a,b)
424// (b + a) + a = fma(2,a,b)
425 auto la = add_cast(this->left);
426 if (la.get()) {
427 if (this->right->is_match(la->get_left())) {
428 return fma(2.0, this->right, la->get_right());
429 } else if (this->right->is_match(la->get_right())) {
430 return fma(2.0, this->right, la->get_left());
431 }
432 }
433 auto ra = add_cast(this->right);
434 if (ra.get()) {
435 if (this->left->is_match(ra->get_left())) {
436 return fma(2.0, this->left, ra->get_right());
437 } else if (this->left->is_match(ra->get_right())) {
438 return fma(2.0, this->left, ra->get_left());
439 }
440 }
441
442// Add subtraction reduction
443// (c1 - a) + c2 -> c3 - a
444// (a - c1) + c2 -> a + c3
445// These reductions are handled by moving constants to the left (see the
// constant reductions at the top of this method) and the c1 + (...) rules below.
446// (a - b) + a -> 2a - b
447// (b - a) + a -> b
448 auto ls = subtract_cast(this->left);
449 if (ls.get()) {
450 if (ls->get_left()->is_match(this->right)) {
451 return static_cast<T> (2.0)*this->right - ls->get_right();
452 } else if (ls->get_right()->is_match(this->right)) {
453 return ls->get_left();
454 }
455 }
456
457// c1 + (c2 - a) -> c3 - a
458// c1 + (a - c2) -> c3 + a
459// a + (a - b) -> 2a - b
460// a + (b - a) -> b
461 auto rs = subtract_cast(this->right);
462 if (rs.get()) {
463 if (is_constant_combinable(this->left, rs->get_left())) {
464 return (this->left + rs->get_left()) - rs->get_right();
465 } else if (is_constant_combinable(this->left, rs->get_right())) {
466 return (this->left - rs->get_right()) + rs->get_left();
467 } else if (this->left->is_match(rs->get_left())) {
468 return static_cast<T> (2.0)*this->left - rs->get_right();
469 } else if (this->left->is_match(rs->get_right())) {
470 return rs->get_left();
471 }
472 }
473
474// Move cases like
475// (c1 + c2/x) + c3/y -> c1 + (c2/x + c3/y)
476// (c1 - c2/x) + c3/y -> c1 + (c3/y - c2/x)
477// in case of common denominators.
478 if (rd.get()) {
479 if (la.get() && divide_cast(la->get_right()).get()) {
480 return la->get_left() + (la->get_right() + this->right);
481 }
482
// Reuses ls (subtract_cast of the left operand) computed above.
484 if (ls.get() && divide_cast(ls->get_right()).get()) {
485 return ls->get_left() + (this->right - ls->get_right());
486 }
487 }
488
489 auto lfma = fma_cast(this->left);
490 auto rfma = fma_cast(this->right);
// Factor two fma operands first: the single-fma folds below return whenever
// either operand is an fma, so this rule was unreachable when it came after them.
503// fma(b,a,d) + fma(c,a,e) -> fma(a,b + c, d + e)
504// fma(a,b,d) + fma(c,a,e) -> fma(a,b + c, d + e)
505// fma(b,a,d) + fma(a,c,e) -> fma(a,b + c, d + e)
506// fma(a,b,d) + fma(a,c,e) -> fma(a,b + c, d + e)
507 if (lfma.get() && rfma.get()) {
508 if (lfma->get_middle()->is_match(rfma->get_middle())) {
509 return fma(lfma->get_middle(),
510 lfma->get_left() + rfma->get_left(),
511 lfma->get_right() + rfma->get_right());
512 } else if (lfma->get_left()->is_match(rfma->get_middle())) {
513 return fma(lfma->get_left(),
514 lfma->get_middle() + rfma->get_left(),
515 lfma->get_right() + rfma->get_right());
516 } else if (lfma->get_middle()->is_match(rfma->get_left())) {
517 return fma(lfma->get_middle(),
518 lfma->get_left() + rfma->get_middle(),
519 lfma->get_right() + rfma->get_right());
520 } else if (lfma->get_left()->is_match(rfma->get_left())) {
521 return fma(lfma->get_left(),
522 lfma->get_middle() + rfma->get_middle(),
523 lfma->get_right() + rfma->get_right());
524 }
525 }
526
491 if (lfma.get()) {
492// fma(c,d,e) + a -> fma(c,d,e + a)
493 return fma(lfma->get_left(),
494 lfma->get_middle(),
495 lfma->get_right() + this->right);
496 } else if (rfma.get()) {
497// a + fma(c,d,e) -> fma(c,d,a + e)
498 return fma(rfma->get_left(),
499 rfma->get_middle(),
500 this->left + rfma->get_right());
501 }
502
527 auto pl = pow_cast(this->left);
528 auto pr = pow_cast(this->right);
529
530// (a*b)^c + (a*d)^c -> a^c*(b^c + d^c)
531// (b*a)^c + (a*d)^c -> a^c*(b^c + d^c)
532// (a*b)^c + (d*a)^c -> a^c*(b^c + d^c)
533// (b*a)^c + (d*a)^c -> a^c*(b^c + d^c)
534 if (pl.get() && pr.get() &&
535 pl->get_right()->is_match(pr->get_right())) {
536 auto plm = multiply_cast(pl->get_left());
537 auto prm = multiply_cast(pr->get_left());
538 if (plm.get() && prm.get()) {
539 if (plm->get_left()->is_match(prm->get_left())) {
540 return pow(plm->get_left(), pl->get_right())*
541 (pow(plm->get_right(), pl->get_right()) +
542 pow(prm->get_right(), pl->get_right()));
543 } else if (plm->get_left()->is_match(prm->get_right())) {
544 return pow(plm->get_left(), pl->get_right())*
545 (pow(plm->get_right(), pl->get_right()) +
546 pow(prm->get_left(), pl->get_right()));
547 } else if (plm->get_right()->is_match(prm->get_left())) {
548 return pow(plm->get_right(), pl->get_right())*
549 (pow(plm->get_left(), pl->get_right()) +
550 pow(prm->get_right(), pl->get_right()));
551 } else if (plm->get_right()->is_match(prm->get_right())) {
552 return pow(plm->get_right(), pl->get_right())*
553 (pow(plm->get_left(), pl->get_right()) +
554 pow(prm->get_left(), pl->get_right()));
555 }
556 }
557
558// cos(x)^2 + sin(x)^2 -> 1
559// sin(x)^2 + cos(x)^2 -> 1
560 auto plrc = constant_cast(pl->get_right());
561 if (plrc.get() && plrc->is(static_cast<T> (2.0))) {
562 auto pls = sin_cast(pl->get_left());
563 auto prc = cos_cast(pr->get_left());
564 auto plc = cos_cast(pl->get_left());
565 auto prs = sin_cast(pr->get_left());
566 if ((pls.get() && prc.get() && pls->get_arg()->is_match(prc->get_arg())) ||
567 (plc.get() && prs.get() && plc->get_arg()->is_match(prs->get_arg()))) {
568 return one<T, SAFE_MATH> ();
569 }
570 }
571 }
572
573// (a/y)^e + b/y^e -> (a^e + b)/(y^e)
574// b/y^e + (a/y)^e -> (a^e + b)/(y^e)
575// (a/y)^e + (b/y)^e -> (a^e + b^e)/(y^e)
576 if (pl.get() && rd.get()) {
577 auto rdp = pow_cast(rd->get_right());
578 if (rdp.get() && pl->get_right()->is_match(rdp->get_right())) {
579 auto plld = divide_cast(pl->get_left());
580 if (plld.get() &&
581 rdp->get_left()->is_match(plld->get_right())) {
582 return (pow(plld->get_left(), pl->get_right()) +
583 rd->get_left()) /
584 pow(rdp->get_left(), pl->get_right());
585 }
586 }
587 } else if (pr.get() && ld.get()) {
588 auto ldp = pow_cast(ld->get_right());
589 if (ldp.get() && pr->get_right()->is_match(ldp->get_right())) {
590 auto prld = divide_cast(pr->get_left());
591 if (prld.get() &&
592 ldp->get_left()->is_match(prld->get_right())) {
593 return (pow(prld->get_left(), pr->get_right()) +
594 ld->get_left()) /
595 pow(ldp->get_left(), pr->get_right());
596 }
597 }
598 } else if (pl.get() && pr.get()) {
599 if (pl->get_right()->is_match(pr->get_right())) {
600 auto pld = divide_cast(pl->get_left());
601 auto prd = divide_cast(pr->get_left());
602 if (pld.get() && prd.get() &&
603 pld->get_right()->is_match(prd->get_right())) {
604 return (pow(pld->get_left(), pl->get_right()) +
605 pow(prd->get_left(), pl->get_right())) /
606 pow(pld->get_right(), pl->get_right());
607 }
608 }
609 }
610
611 return this->shared_from_this();
612 }
613
614//------------------------------------------------------------------------------
621//------------------------------------------------------------------------------
///  Derivative with respect to x: one when this node matches x, otherwise
///  d(left)/dx + d(right)/dx, memoised in df_cache keyed by x's address.
624 if (this->is_match(x)) {
625 return one<T, SAFE_MATH> ();
626 }
627
628 const size_t hash = reinterpret_cast<size_t> (x.get());
629 if (this->df_cache.find(hash) == this->df_cache.end()) {
630 this->df_cache[hash] = this->left->df(x) + this->right->df(x);
631 }
632 return this->df_cache[hash];
633 }
634
635//------------------------------------------------------------------------------
643//------------------------------------------------------------------------------
///  Emits "const <T> rN = <left register> + <right register>" once per node
///  (guarded by the registers map), after compiling both operands.
645 compile(std::ostringstream &stream,
646 jit::register_map &registers,
648 const jit::register_usage &usage) {
649 if (registers.find(this) == registers.end()) {
650 shared_leaf<T, SAFE_MATH> l = this->left->compile(stream,
651 registers,
652 indices,
653 usage);
654 shared_leaf<T, SAFE_MATH> r = this->right->compile(stream,
655 registers,
656 indices,
657 usage);
658
659 registers[this] = jit::to_string('r', this);
660 stream << " const ";
661 jit::add_type<T> (stream);
662 stream << " " << registers[this] << " = "
663 << registers[l.get()] << " + "
664 << registers[r.get()];
665 this->endline(stream, usage);
666 }
667
668 return this->shared_from_this();
669 }
670
671//------------------------------------------------------------------------------
676//------------------------------------------------------------------------------
///  True for the same node, or for another add node whose operands match in
///  either order (addition is commutative).
678 if (this == x.get()) {
679 return true;
680 }
681
682 auto x_cast = add_cast(x);
683 if (x_cast.get()) {
684// Addition is commutative.
685 if ((this->left->is_match(x_cast->get_left()) &&
686 this->right->is_match(x_cast->get_right())) ||
687 (this->right->is_match(x_cast->get_left()) &&
688 this->left->is_match(x_cast->get_right()))) {
689 return true;
690 }
691 }
692
693 return false;
694 }
695
696//------------------------------------------------------------------------------
698//------------------------------------------------------------------------------
///  Prints left+right as LaTeX to std::cout, bracketing operands that are
///  themselves add or subtract nodes.
699 virtual void to_latex() const {
700 bool l_brackets = add_cast(this->left).get() ||
701 subtract_cast(this->left).get();
702 bool r_brackets = add_cast(this->right).get() ||
703 subtract_cast(this->right).get();
704 if (l_brackets) {
705 std::cout << "\\left(";
706 }
707 this->left->to_latex();
708 if (l_brackets) {
709 std::cout << "\\right)";
710 }
711 std::cout << "+";
712 if (r_brackets) {
713 std::cout << "\\left(";
714 }
715 this->right->to_latex();
716 if (r_brackets) {
717 std::cout << "\\right)";
718 }
719 }
720
721//------------------------------------------------------------------------------
725//------------------------------------------------------------------------------
///  Rebuilds the sum from pseudo-free operands when has_pseudo() is true;
///  otherwise returns this node unchanged.
727 if (this->has_pseudo()) {
728 return this->left->remove_pseudo() +
729 this->right->remove_pseudo();
730 }
731 return this->shared_from_this();
732 }
733
734//------------------------------------------------------------------------------
740//------------------------------------------------------------------------------
///  Writes a DOT-style node labelled "+" and edges to both operands, once per
///  node (guarded by the registers map).
741 virtual shared_leaf<T, SAFE_MATH> to_vizgraph(std::stringstream &stream,
742 jit::register_map &registers) {
743 if (registers.find(this) == registers.end()) {
744 const std::string name = jit::to_string('r', this);
745 registers[this] = name;
746 stream << " " << name
747 << " [label = \"+\", shape = oval, style = filled, fillcolor = blue, fontcolor = white];" << std::endl;
748
749 auto l = this->left->to_vizgraph(stream, registers);
750 stream << " " << name << " -- " << registers[l.get()] << ";" << std::endl;
751 auto r = this->right->to_vizgraph(stream, registers);
752 stream << " " << name << " -- " << registers[r.get()] << ";" << std::endl;
753 }
754
755 return this->shared_from_this();
756 }
757 };
758
759//------------------------------------------------------------------------------
770//------------------------------------------------------------------------------
///  @brief Build a reduced add node for l + r and deduplicate it through the
///  node cache.
///
///  The new add_node is reduced first. The loop then probes
///  leaf_node<T, SAFE_MATH>::caches.nodes starting at the reduced node's hash.
///  On one visible path (caches.nodes.find(i) compared against a value on a
///  missing line) temp is returned. An occupied slot whose node is_match()es
///  temp is treated as a hit. Any other occupied slot is a hash collision, and
///  probing continues.
///  Reaching the end of the loop is treated as impossible: assert(false) here,
///  or a compiler-specific statement on GCC/Clang.
///
///  NOTE(review): original lines 772-773 (declaration), 777 (loop bound and
///  increment), 779-780, 783 and 787 are missing from this listing. The cache
///  insertion and the value returned on a hit are presumably on those lines —
///  confirm against the real arithmetic.hpp.
771 template<jit::float_scalar T, bool SAFE_MATH=false>
774 auto temp = std::make_shared<add_node<T, SAFE_MATH>> (l, r)->reduce();
775// Test for hash collisions.
776 for (size_t i = temp->get_hash();
778 if (leaf_node<T, SAFE_MATH>::caches.nodes.find(i) ==
781 return temp;
782 } else if (temp->is_match(leaf_node<T, SAFE_MATH>::caches.nodes[i])) {
784 }
785 }
786#if defined(__clang__) || defined(__GNUC__)
788#else
789 assert(false && "Should never reach.");
790#endif
791 }
792
793//------------------------------------------------------------------------------
804//------------------------------------------------------------------------------
//  NOTE(review): the three templates below are truncated in this listing. Only
//  their template headers survive; original lines 806-809, 825-828 and 844-847
//  are missing. Judging by the extra scalar parameter L (left) or R (right),
//  these are presumably operator+ overloads for node + node, scalar + node and
//  node + scalar. Restore them from the real arithmetic.hpp.
805 template<jit::float_scalar T, bool SAFE_MATH=false>
810
811//------------------------------------------------------------------------------
823//------------------------------------------------------------------------------
824 template<jit::float_scalar T, jit::float_scalar L, bool SAFE_MATH=false>
829
830//------------------------------------------------------------------------------
842//------------------------------------------------------------------------------
843 template<jit::float_scalar T, jit::float_scalar R, bool SAFE_MATH=false>
848
///  @brief Shared-pointer alias for add_node<T, SAFE_MATH>.
850 template<jit::float_scalar T, bool SAFE_MATH=false>
851 using shared_add = std::shared_ptr<add_node<T, SAFE_MATH>>;
852
853//------------------------------------------------------------------------------
861//------------------------------------------------------------------------------
///  @brief Downcast x to an add node via std::dynamic_pointer_cast.
///
///  Yields an empty pointer when x is not an add_node<T, SAFE_MATH>; callers in
///  this file test the result with .get().
///
///  NOTE(review): the declaration (original line 863) is missing from this
///  extracted listing.
862 template<jit::float_scalar T, bool SAFE_MATH=false>
864 return std::dynamic_pointer_cast<add_node<T, SAFE_MATH>> (x);
865 }
866
867//******************************************************************************
868// Subtract node.
869//******************************************************************************
870//------------------------------------------------------------------------------
877//------------------------------------------------------------------------------
878 template<jit::float_scalar T, bool SAFE_MATH=false>
879 class subtract_node final : public branch_node<T, SAFE_MATH> {
880 private:
881//------------------------------------------------------------------------------
887//------------------------------------------------------------------------------
888 static std::string to_string(leaf_node<T, SAFE_MATH> *l,
890 return jit::format_to_string(reinterpret_cast<size_t> (l)) + "-" +
891 jit::format_to_string(reinterpret_cast<size_t> (r));
892 }
893
894 public:
895//------------------------------------------------------------------------------
900//------------------------------------------------------------------------------
905
906//------------------------------------------------------------------------------
912//------------------------------------------------------------------------------
914 backend::buffer<T> l_result = this->left->evaluate();
915 backend::buffer<T> r_result = this->right->evaluate();
916 return l_result - r_result;
917 }
918
919//------------------------------------------------------------------------------
923//------------------------------------------------------------------------------
925// Identity reductions.
926 auto l = constant_cast(this->left);
927 if (this->left->is_match(this->right)) {
928 if (l.get() && l->is(0)) {
929 return this->left;
930 }
931
932 return zero<T, SAFE_MATH> ();
933 }
934
935// Constant reductions.
936 auto r = constant_cast(this->right);
937
938 if (l.get() && l->is(0)) {
939 return -this->right;
940 } else if (r.get() && r->is(0)) {
941 return this->left;
942 } else if (l.get() && r.get()) {
943 return constant<T, SAFE_MATH> (this->evaluate());
944 } else if (r.get() && r->evaluate().is_negative()) {
945 return this->left + -this->right;
946 }
947
948 auto pl1 = piecewise_1D_cast(this->left);
949 auto pr1 = piecewise_1D_cast(this->right);
950
951 if (pl1.get() && (r.get() || pl1->is_arg_match(this->right))) {
952 return piecewise_1D(this->evaluate(), pl1->get_arg(),
953 pl1->get_scale(), pl1->get_offset());
954 } else if (pr1.get() && (l.get() || pr1->is_arg_match(this->left))) {
955 return piecewise_1D(this->evaluate(), pr1->get_arg(),
956 pr1->get_scale(), pr1->get_offset());
957 }
958
959 auto pl2 = piecewise_2D_cast(this->left);
960 auto pr2 = piecewise_2D_cast(this->right);
961
962 if (pl2.get() && (r.get() || pl2->is_arg_match(this->right))) {
963 return piecewise_2D(this->evaluate(),
964 pl2->get_num_columns(),
965 pl2->get_left(), pl2->get_x_scale(), pl2->get_x_offset(),
966 pl2->get_right(), pl2->get_y_scale(), pl2->get_y_offset());
967 } else if (pr2.get() && (l.get() || pr2->is_arg_match(this->left))) {
968 return piecewise_2D(this->evaluate(),
969 pr2->get_num_columns(),
970 pr2->get_left(), pr2->get_x_scale(), pr2->get_x_offset(),
971 pr2->get_right(), pr2->get_y_scale(), pr2->get_y_offset());
972 }
973
974// Combine 2D and 1D piecewise constants if a row or column matches.
975 if (pr2.get() && pr2->is_row_match(this->left)) {
976 backend::buffer<T> result = pl1->evaluate();
977 result.subtract_row(pr2->evaluate());
978 return piecewise_2D(result,
979 pr2->get_num_columns(),
980 pr2->get_left(), pr2->get_x_scale(), pr2->get_x_offset(),
981 pr2->get_right(), pr2->get_y_scale(), pr2->get_y_offset());
982 } else if (pr2.get() && pr2->is_col_match(this->left)) {
983 backend::buffer<T> result = pl1->evaluate();
984 result.subtract_col(pr2->evaluate());
985 return piecewise_2D(result,
986 pr2->get_num_columns(),
987 pr2->get_left(), pr2->get_x_scale(), pr2->get_x_offset(),
988 pr2->get_right(), pr2->get_y_scale(), pr2->get_y_offset());
989 } else if (pl2.get() && pl2->is_row_match(this->right)) {
990 backend::buffer<T> result = pl2->evaluate();
991 result.subtract_row(pr1->evaluate());
992 return piecewise_2D(result,
993 pl2->get_num_columns(),
994 pl2->get_left(), pl2->get_x_scale(), pl2->get_x_offset(),
995 pl2->get_right(), pl2->get_y_scale(), pl2->get_y_offset());
996 } else if (pl2.get() && pl2->is_col_match(this->right)) {
997 backend::buffer<T> result = pl2->evaluate();
998 result.subtract_col(pr1->evaluate());
999 return piecewise_2D(result,
1000 pl2->get_num_columns(),
1001 pl2->get_left(), pl2->get_x_scale(), pl2->get_x_offset(),
1002 pl2->get_right(), pl2->get_y_scale(), pl2->get_y_offset());
1003 }
1004// (c1 + a) - c2 -> c3 + a
1005// c1 - (c2 + a) -> c3 - a
1006 auto la = add_cast(this->left);
1007 if (la.get()) {
1008 if (is_constant_combinable(la->get_left(), this->right)) {
1009 return (la->get_left() - this->right) + la->get_right();
1010 }
1011 }
1012 auto ra = add_cast(this->right);
1013 if (ra.get()) {
1014 if (is_constant_combinable(this->left, ra->get_left())) {
1015 return (this->left - ra->get_left()) - ra->get_right();
1016 }
1017 }
1018
1019// (c1 - a) - c2 -> c3 - a
1020// (a - c3) - c2 -> a + c3
1021 auto ls = subtract_cast(this->left);
1022 if (ls.get()) {
1023 if (is_constant_combinable(ls->get_left(), this->right)) {
1024 return (ls->get_left() - this->right) - ls->get_right();
1025 } else if (is_constant_combinable(ls->get_right(),
1026 this->right)) {
1027 return -(ls->get_right() + this->right) - ls->get_left();
1028 }
1029 }
1030// c1 - (c2 - a) -> c3 + a
1031// c1 - (a - c2) -> c3 - a
1032 auto rs = subtract_cast(this->right);
1033 if (rs.get()) {
1034 if (is_constant_combinable(this->left, rs->get_left())) {
1035 return (this->left - rs->get_left()) + rs->get_right();
1036 } else if (is_constant_combinable(this->left, rs->get_right())) {
1037 return (this->left + rs->get_right()) - rs->get_left();
1038 }
1039 }
1040
1041// Common factor reduction. If the left and right are both multiply nodes check
1042// for a common factor. So you can change a*b - a*c -> a*(b - c).
1043 auto lm = multiply_cast(this->left);
1044 auto rm = multiply_cast(this->right);
1045
1046// c1*(c2 + a) - c3 -> fma(c1,a,c4)
1047 if (lm.get()) {
1048 auto lmra = add_cast(lm->get_right());
1049 if (lmra.get()) {
1050 if (is_constant_combinable(lm->get_left(),
1051 lmra->get_left()) &&
1052 is_constant_combinable(lm->get_left(),
1053 this->right)) {
1054 return fma(lm->get_left(),
1055 lmra->get_right(),
1056 lm->get_left()*lmra->get_left() - this->right);
1057 }
1058 }
1059// c1*(c2 - a) - c3 -> c4 - c1*a
1060 auto lmrs = subtract_cast(lm->get_right());
1061 if (lmrs.get()) {
1062 if (is_constant_combinable(lm->get_left(),
1063 lmrs->get_left()) &&
1064 is_constant_combinable(lm->get_left(),
1065 this->right)) {
1066 return lm->get_left()*lmrs->get_left() - this->right -
1067 lm->get_left()*lmrs->get_right();
1068 }
1069 }
1070 }
1071
1072// Assume constants are on the left.
1073// v1 - -c*v2 -> v1 + c*v2
1074 if (rm.get() &&
1075 rm->get_left()->is_constant() &&
1076 rm->get_left()->evaluate().is_negative()) {
1077 return this->left + (-this->right);
1078 }
1079
1080 if (lm.get()) {
1081// Assume constants are on the left.
1082// -a - b -> -(a + b)
1083 auto lmc = constant_cast(lm->get_left());
1084 if (lmc.get() && lmc->is(-1)) {
1085 return lm->get_left()*(lm->get_right() + this->right);
1086 }
1087
1088// a*v - v = (a - 1)*v
1089// v*a - v = (a - 1)*v
1090 if (this->right->is_match(lm->get_right())) {
1091 return (lm->get_left() - 1.0)*this->right;
1092 } else if (this->right->is_match(lm->get_left())) {
1093 return (lm->get_right() - 1.0)*this->right;
1094 }
1095 }
1096// v - a*v = (1 - a)*v
1097// v - v*a = (1 - a)*v
1098 if (rm.get()) {
1099 if (this->left->is_match(rm->get_right())) {
1100 return (1.0 - rm->get_left())*this->left;
1101 } else if (this->left->is_match(rm->get_left())) {
1102 return (1.0 - rm->get_right())*this->left;
1103 }
1104 }
1105
1106 if (lm.get() && rm.get()) {
1107 if (lm->get_left()->is_match(rm->get_left())) {
1108// a*b - a*c -> a*(b - c)
1109 return lm->get_left()*(lm->get_right() - rm->get_right());
1110 } else if (lm->get_left()->is_match(rm->get_right())) {
1111// a*b - c*a -> a*(b - c)
1112 return lm->get_left()*(lm->get_right() - rm->get_left());
1113 } else if (lm->get_right()->is_match(rm->get_left())) {
1114// b*a - a*c -> a*(b - c)
1115 return lm->get_right()*(lm->get_left() - rm->get_right());
1116 } else if (lm->get_right()->is_match(rm->get_right())) {
1117// b*a - c*a -> a*(b - c)
1118 return lm->get_right()*(lm->get_left() - rm->get_left());
1119 }
1120
1121// Change cases like c1*a - c2*b -> c1*(a - c2/c1*b)
1122// Note need to make sure c1 doesn't contain any zeros.
1123 if (lm->get_left()->is_constant() &&
1124 rm->get_left()->is_constant() &&
1125 !lm->get_left()->has_constant_zero()) {
1126 return lm->get_left()*(lm->get_right() -
1127 (rm->get_left()/lm->get_left())*rm->get_right());
1128 }
1129
1130// Handle case
1131 auto rmrm = multiply_cast(rm->get_right());
1132 if (rmrm.get()) {
1133// a*b - c*(d*b) -> (a - c*d)*b
1134 if (lm->get_right()->is_match(rmrm->get_right())) {
1135 return (lm->get_left() - rm->get_left()*rmrm->get_left())*lm->get_right();
1136 }
1137// a*b - c*(b*d) -> (a - c*d)*b
1138 if (lm->get_right()->is_match(rmrm->get_left())) {
1139 return (lm->get_left() - rm->get_left()*rmrm->get_right())*lm->get_right();
1140 }
1141// b*a - c*(d*b) -> (a - c*d)*b
1142 if (lm->get_left()->is_match(rmrm->get_right())) {
1143 return (lm->get_right() - rm->get_left()*rmrm->get_left())*lm->get_left();
1144 }
1145// b*a - c*(b*d) -> (a - c*d)*b
1146 if (lm->get_left()->is_match(rmrm->get_left())) {
1147 return (lm->get_right() - rm->get_left()*rmrm->get_right())*lm->get_left();
1148 }
1149 }
1150 auto lmrm = multiply_cast(lm->get_right());
1151 if (lmrm.get()) {
1152// c*(d*b) - a*b -> (c*d - a)*b
1153 if (rm->get_right()->is_match(lmrm->get_right())) {
1154 return (lm->get_left()*lmrm->get_left() - rm->get_left())*rm->get_right();
1155 }
1156// c*(b*d) - a*b -> (c*d - a)*b
1157 if (rm->get_right()->is_match(lmrm->get_left())) {
1158 return (lm->get_left()*lmrm->get_right() - rm->get_left())*rm->get_right();
1159 }
1160// c*(d*b) - b*a -> (c*d - a)*b
1161 if (rm->get_left()->is_match(lmrm->get_right())) {
1162 return (lm->get_left()*lmrm->get_left() - rm->get_right())*rm->get_left();
1163 }
1164// c*(b*d) - b*a -> (c*d - a)*b
1165 if (rm->get_left()->is_match(lmrm->get_left())) {
1166 return (lm->get_left()*lmrm->get_right() - rm->get_right())*rm->get_left();
1167 }
1168 }
1169
1170// a/b*c - d/b*e -> (a*b - d*e)/b
1171// a/b*c - d*e/b -> (a*b - d*e)/b
1172// a*c/b - d/b*e -> (a*b - d*e)/b
1173// a*c/b - d*e/b -> (a*b - d*e)/b
1174 auto lmld = divide_cast(lm->get_left());
1175 auto rmld = divide_cast(rm->get_left());
1176 auto lmrd = divide_cast(lm->get_right());
1177 auto rmrd = divide_cast(rm->get_right());
1178 if (lmld.get() && rmld.get() &&
1179 lmld->get_right()->is_match(rmld->get_right())) {
1180 return (lmld->get_left()*lm->get_right() -
1181 rmld->get_left()*rm->get_right())/lmld->get_right();
1182 } else if (lmld.get() && rmrd.get() &&
1183 lmld->get_right()->is_match(rmrd->get_right())) {
1184 return (lmld->get_left()*lm->get_right() -
1185 rmrd->get_left()*rm->get_left())/lmld->get_right();
1186 } else if (lmrd.get() && rmld.get() &&
1187 lmrd->get_right()->is_match(rmld->get_right())) {
1188 return (lmrd->get_left()*lm->get_left() -
1189 rmld->get_left()*rm->get_right())/lmrd->get_right();
1190 } else if (lmrd.get() && rmrd.get() &&
1191 lmrd->get_right()->is_match(rmrd->get_right())) {
1192 return (lmrd->get_left()*lm->get_left() -
1193 rmrd->get_left()*rm->get_left())/lmrd->get_right();
1194 }
1195 }
1196
1197// Chained subtraction reductions.
1198 if (ls.get()) {
1199 auto lrm = multiply_cast(ls->get_right());
1200 if (lrm.get() && rm.get()) {
1201 if (lrm->get_left()->is_match(rm->get_left())) {
1202// (a - c*b) - c*d -> a - (b + d)*c
1203 return ls->get_left() -
1204 (lrm->get_right() +
1205 rm->get_right())*rm->get_left();
1206 } else if (lrm->get_left()->is_match(rm->get_right())) {
1207// (a - c*b) - d*c -> a - (b + d)*c
1208 return ls->get_left() -
1209 (lrm->get_right() +
1210 rm->get_left())*rm->get_right();
1211 } else if (lrm->get_right()->is_match(rm->get_left())) {
1212// (a - c*b) - c*d -> a - (b + d)*c
1213 return ls->get_left() -
1214 (lrm->get_left() +
1215 rm->get_right())*rm->get_left();
1216 } else if (lrm->get_right()->is_match(rm->get_right())) {
1217// (a - c*b) - d*c -> a - (b + d)*c
1218 return ls->get_left() -
1219 (lrm->get_left() +
1220 rm->get_left())*rm->get_right();
1221 }
1222 }
1223 }
1224
1225// Common denominator reduction. If the left and right are both divide nodes
1226// for a common denominator. So you can change a/b - c/b -> (a - c)/d.
1227 auto ld = divide_cast(this->left);
1228 auto rd = divide_cast(this->right);
1229
1230// c is a constant.
1231// a - -c/b -> a + c/b
1232// a - (-c*d)/b -> a + (c*d)/b
1233// -c/a - b -> -(b + c/a)
1234// (-c*d)/a - b -> -(b + (c*d)/a)
1235 if (rd.get()) {
1236 auto rdlm = multiply_cast(rd->get_left());
1237 if ((rd->get_left()->is_constant() &&
1238 rd->get_left()->evaluate().is_negative()) ||
1239 (rdlm.get() &&
1240 (rdlm->get_left()->is_constant() &&
1241 rdlm->get_left()->evaluate().is_negative()))) {
1242 return this->left + -this->right;
1243 }
1244 } else if (ld.get()) {
1245 auto ldlm = multiply_cast(ld->get_left());
1246 if ((ld->get_left()->is_constant() &&
1247 ld->get_left()->evaluate().is_negative()) ||
1248 (ldlm.get() &&
1249 (ldlm->get_left()->is_constant() &&
1250 ldlm->get_left()->evaluate().is_negative()))) {
1251 return -(-this->left + this->right);
1252 }
1253 }
1254
1255 if (ld.get() && rd.get()) {
1256 if (ld->get_right()->is_match(rd->get_right())) {
1257 return (ld->get_left() - rd->get_left())/ld->get_right();
1258 }
1259
1260 auto ldlm = multiply_cast(ld->get_left());
1261 auto rdlm = multiply_cast(rd->get_left());
1262// a/b - c*a/d -> (1/b - c/d)*a
1263// a/b - a*c/d -> (1/b - c/d)*a
1264// c*a/b - a/d -> (c/b - 1/d)*a
1265// a*c/b - a/d -> (c/b - 1/d)*a
1266 if (rdlm.get()) {
1267 if (ld->get_left()->is_match(rdlm->get_left())) {
1268 return (1.0/ld->get_right() -
1269 rdlm->get_right()/rd->get_right())*rdlm->get_left();
1270 } else if (ld->get_left()->is_match(rdlm->get_right())) {
1271 return (1.0/ld->get_right() -
1272 rdlm->get_left()/rd->get_right())*rdlm->get_right();
1273 }
1274 } else if (ldlm.get()) {
1275 if (rd->get_left()->is_match(ldlm->get_left())) {
1276 return (ldlm->get_right()/ld->get_right() -
1277 1.0/rd->get_right())*ldlm->get_left();
1278 } else if (rd->get_left()->is_match(ldlm->get_right())) {
1279 return (ldlm->get_left()/ld->get_right() -
1280 1.0/rd->get_right())*ldlm->get_right();
1281 }
1282 }
1283
1284// c1*a/b - c2*e/d = c3*(a/b - c4*e/d)
1285// a*b/c - d*b/e -> (a/c - d/e)*b
1286// Make sure we prevent combining constants when we just need to factor out a
1287// common term.
1288// c1*a/b - c2*a/d -> (c1/b - c2/d)*a
1289 if (ldlm.get() && rdlm.get()) {
1290 if (is_constant_combinable(ldlm->get_left(),
1291 rdlm->get_left()) &&
1292 !ldlm->get_right()->is_match(rdlm->get_right())) {
1293 return (ldlm->get_right()/ld->get_right() -
1294 rdlm->get_left()/ldlm->get_left() *
1295 rdlm->get_right()/rd->get_right())*ldlm->get_left();
1296 }
1297
1298 if (ldlm->get_right()->is_match(rdlm->get_right())) {
1299 return (ldlm->get_left()/ld->get_right() -
1300 rdlm->get_left()/rd->get_right())*ldlm->get_right();
1301 } else if (ldlm->get_right()->is_match(rdlm->get_left())) {
1302 return (ldlm->get_left()/ld->get_right() -
1303 rdlm->get_right()/rd->get_right())*ldlm->get_right();
1304 } else if (ldlm->get_left()->is_match(rdlm->get_right())) {
1305 return (ldlm->get_right()/ld->get_right() -
1306 rdlm->get_left()/rd->get_right())*ldlm->get_left();
1307 } else if (ldlm->get_left()->is_match(rdlm->get_left())) {
1308 return (ldlm->get_right()/ld->get_right() -
1309 rdlm->get_right()/rd->get_right())*ldlm->get_left();
1310 }
1311 }
1312
1313// (a/(c*b) - d/(e*c)) -> (a/b - d/e)/c
1314// (a/(b*c) - d/(e*c)) -> (a/b - d/e)/c
1315// (a/(c*b) - d/(c*e)) -> (a/b - d/e)/c
1316// (a/(b*c) - d/(c*e)) -> (a/b - d/e)/c
1317 auto ldrm = multiply_cast(ld->get_right());
1318 auto rdrm = multiply_cast(rd->get_right());
1319 if (ldrm.get() && rdrm.get()) {
1320 if (ldrm->get_right()->is_match(rdrm->get_right())) {
1321 return (ld->get_left()/ldrm->get_left() -
1322 rd->get_left()/rdrm->get_left())/ldrm->get_right();
1323 } else if (ldrm->get_right()->is_match(rdrm->get_left())) {
1324 return (ld->get_left()/ldrm->get_left() -
1325 rd->get_left()/rdrm->get_right())/ldrm->get_right();
1326 } else if (ldrm->get_left()->is_match(rdrm->get_right())) {
1327 return (ld->get_left()/ldrm->get_right() -
1328 rd->get_left()/rdrm->get_left())/ldrm->get_left();
1329 } else if (ldrm->get_left()->is_match(rdrm->get_left())) {
1330 return (ld->get_left()/ldrm->get_right() -
1331 rd->get_left()/rdrm->get_right())/ldrm->get_left();
1332 }
1333 }
1334
1335// a/b - c/(b*d) -> (a*d - c)/(b*d)
1336// a/b - c/(d*b) -> (a*d - c)/(b*d)
1337// a/(b*d) - c/b -> (a - c*d)/(b*d)
1338// a/(d*b) - c/b -> (a - c*d)/(b*d)
1339 if (rdrm.get()) {
1340 if (ld->get_right()->is_match(rdrm->get_left())) {
1341 return (ld->get_left()*rdrm->get_right() - rd->get_left()) /
1342 rd->get_right();
1343 } else if (ld->get_right()->is_match(rdrm->get_right())) {
1344 return (ld->get_left()*rdrm->get_left() - rd->get_left()) /
1345 rd->get_right();
1346 }
1347 } else if (ldrm.get()) {
1348 if (rd->get_right()->is_match(ldrm->get_left())) {
1349 return (ld->get_left() - rd->get_left()*ldrm->get_right()) /
1350 ld->get_right();
1351 } else if (rd->get_right()->is_match(ldrm->get_right())) {
1352 return (ld->get_left() - rd->get_left()*ldrm->get_left()) /
1353 ld->get_right();
1354 }
1355 }
1356 }
1357
1358// Move cases like
1359// (c1 + c2/x) - c3/y -> c1 + (c2/x - c3/y)
1360// (c1 - c2/x) - c3/y -> c1 - (c2/x + c3/y)
1361// in case of common denominators.
1362 if (rd.get()) {
1363 auto la = add_cast(this->left);
1364 if (la.get() && divide_cast(la->get_right()).get()) {
1365 return la->get_left() + (la->get_right() - this->right);
1366 } else if (ls.get() && divide_cast(ls->get_right()).get()) {
1367 return ls->get_left() - (this->right + ls->get_right());
1368 }
1369 }
1370
1371// Handle cases like:
1372// (a/y)^e - b/y^e -> (a^2 - b)/(y^e)
1373// b/y^e - (a/y)^e -> (b - a^2)/(y^e)
1374// (a/y)^e - (b/y)^e -> (a^2 - b^2)/(y^e)
1375 auto pl = pow_cast(this->left);
1376 auto pr = pow_cast(this->right);
1377 if (pl.get() && rd.get()) {
1378 auto rdp = pow_cast(rd->get_right());
1379 if (rdp.get() && pl->get_right()->is_match(rdp->get_right())) {
1380 auto plld = divide_cast(pl->get_left());
1381 if (plld.get() &&
1382 rdp->get_left()->is_match(plld->get_right())) {
1383 return (pow(plld->get_left(), pl->get_right()) -
1384 rd->get_left()) /
1385 pow(rdp->get_left(), pl->get_right());
1386 }
1387 }
1388 } else if (pr.get() && ld.get()) {
1389 auto ldp = pow_cast(ld->get_right());
1390 if (ldp.get() && pr->get_right()->is_match(ldp->get_right())) {
1391 auto prld = divide_cast(pr->get_left());
1392 if (prld.get() &&
1393 ldp->get_left()->is_match(prld->get_right())) {
1394 return (pow(prld->get_left(), pr->get_right()) -
1395 ld->get_left()) /
1396 pow(ldp->get_left(), pr->get_right());
1397 }
1398 }
1399 } else if (pl.get() && pr.get()) {
1400 if (pl->get_right()->is_match(pr->get_right())) {
1401 auto pld = divide_cast(pl->get_left());
1402 auto prd = divide_cast(pr->get_left());
1403 if (pld.get() && prd.get() &&
1404 pld->get_right()->is_match(prd->get_right())) {
1405 return (pow(pld->get_left(), pl->get_right()) -
1406 pow(prd->get_left(), pl->get_right())) /
1407 pow(pld->get_right(), pl->get_right());
1408 }
1409 }
1410 }
1411
1412 auto lfma = fma_cast(this->left);
1413 auto rfma = fma_cast(this->right);
1414
1415 if (lfma.get() && rfma.get()) {
1416 if (lfma->get_middle()->is_match(rfma->get_middle())) {
1417 return fma(lfma->get_left() - rfma->get_left(),
1418 lfma->get_middle(),
1419 lfma->get_right() - rfma->get_right());
1420 }
1421 }
1422
1423// fma(c,d,e) - a -> fma(c,d,e - a)
1424 if (lfma.get() && !this->right->is_all_variables()) {
1425 return fma(lfma->get_left(),
1426 lfma->get_middle(),
1427 lfma->get_right() - this->right);
1428 }
1429
1430// Reduce cases chained subtract multiply divide.
1431 if (ls.get()) {
1432// (a - b*c) - d*e -> a - (b*c + d*e)
1433// (a - b/c) - d/e -> a - (b/c + d/e)
1434 auto lsrd = divide_cast(ls->get_right());
1435 if ((multiply_cast(ls->get_right()).get() && (rm.get() || rd.get())) ||
1436 (divide_cast(ls->get_right()).get() && (rm.get() || rd.get()))) {
1437 return ls->get_left() - (ls->get_right() + this->right);
1438 }
1439 }
1440
1441 return this->shared_from_this();
1442 }
1443
1444//------------------------------------------------------------------------------
1451//------------------------------------------------------------------------------
1454 if (this->is_match(x)) {
1455 return one<T, SAFE_MATH> ();
1456 }
1457
1458 const size_t hash = reinterpret_cast<size_t> (x.get());
1459 if (this->df_cache.find(hash) == this->df_cache.end()) {
1460 this->df_cache[hash] = this->left->df(x) - this->right->df(x);
1461 }
1462 return this->df_cache[hash];
1463 }
1464
1465//------------------------------------------------------------------------------
1473//------------------------------------------------------------------------------
1475 compile(std::ostringstream &stream,
1476 jit::register_map &registers,
1478 const jit::register_usage &usage) {
1479 if (registers.find(this) == registers.end()) {
1480 shared_leaf<T, SAFE_MATH> l = this->left->compile(stream,
1481 registers,
1482 indices,
1483 usage);
1484 shared_leaf<T, SAFE_MATH> r = this->right->compile(stream,
1485 registers,
1486 indices,
1487 usage);
1488
1489 registers[this] = jit::to_string('r', this);
1490 stream << " const ";
1491 jit::add_type<T> (stream);
1492 stream << " " << registers[this] << " = "
1493 << registers[l.get()] << " - "
1494 << registers[r.get()];
1495 this->endline(stream, usage);
1496 }
1497
1498 return this->shared_from_this();
1499 }
1500
1501//------------------------------------------------------------------------------
1506//------------------------------------------------------------------------------
1508 if (this == x.get()) {
1509 return true;
1510 }
1511
1512 auto x_cast = subtract_cast(x);
1513 if (x_cast.get()) {
1514 return this->left->is_match(x_cast->get_left()) &&
1515 this->right->is_match(x_cast->get_right());
1516 }
1517
1518 return false;
1519 }
1520
1521//------------------------------------------------------------------------------
1523//------------------------------------------------------------------------------
1524 virtual void to_latex() const {
1525 bool l_brackets = add_cast(this->left).get() ||
1526 subtract_cast(this->left).get();
1527 bool r_brackets = add_cast(this->right).get() ||
1528 subtract_cast(this->right).get();
1529 if (l_brackets) {
1530 std::cout << "\\left(";
1531 }
1532 this->left->to_latex();
1533 if (l_brackets) {
1534 std::cout << "\\right)";
1535 }
1536 std::cout << "-";
1537 if (r_brackets) {
1538 std::cout << "\\left(";
1539 }
1540 this->right->to_latex();
1541 if (r_brackets) {
1542 std::cout << "\\right)";
1543 }
1544 }
1545
1546//------------------------------------------------------------------------------
1550//------------------------------------------------------------------------------
1552 if (this->has_pseudo()) {
1553 return this->left->remove_pseudo() -
1554 this->right->remove_pseudo();
1555 }
1556 return this->shared_from_this();
1557 }
1558
1559//------------------------------------------------------------------------------
1565//------------------------------------------------------------------------------
1566 virtual shared_leaf<T, SAFE_MATH> to_vizgraph(std::stringstream &stream,
1567 jit::register_map &registers) {
1568 if (registers.find(this) == registers.end()) {
1569 const std::string name = jit::to_string('r', this);
1570 registers[this] = name;
1571 stream << " " << name
1572 << " [label = \"-\", shape = oval, style = filled, fillcolor = blue, fontcolor = white];" << std::endl;
1573
1574 auto l = this->left->to_vizgraph(stream, registers);
1575 stream << " " << name << " -- " << registers[l.get()] << ";" << std::endl;
1576 auto r = this->right->to_vizgraph(stream, registers);
1577 stream << " " << name << " -- " << registers[r.get()] << ";" << std::endl;
1578 }
1579
1580 return this->shared_from_this();
1581 }
1582 };
1583
1584//------------------------------------------------------------------------------
1593//------------------------------------------------------------------------------
1594 template<jit::float_scalar T, bool SAFE_MATH=false>
1597 auto temp = std::make_shared<subtract_node<T, SAFE_MATH>> (l, r)->reduce();
1598// Test for hash collisions.
1599 for (size_t i = temp->get_hash();
1601 if (leaf_node<T, SAFE_MATH>::caches.nodes.find(i) ==
1602 leaf_node<T, SAFE_MATH>::caches.nodes.end()) {
1604 return temp;
1605 } else if (temp->is_match(leaf_node<T, SAFE_MATH>::caches.nodes[i])) {
1606 return leaf_node<T, SAFE_MATH>::caches.nodes[i];
1607 }
1608 }
1609#if defined(__clang__) || defined(__GNUC__)
1611#else
1612 assert(false && "Should never reach.");
1613#endif
1614 }
1615
1616//------------------------------------------------------------------------------
1628//------------------------------------------------------------------------------
1629 template<jit::float_scalar T, bool SAFE_MATH=false>
1634
1635//------------------------------------------------------------------------------
1648//------------------------------------------------------------------------------
1649 template<jit::float_scalar T, jit::float_scalar L, bool SAFE_MATH=false>
1654
1655//------------------------------------------------------------------------------
1668//------------------------------------------------------------------------------
1669 template<jit::float_scalar T, jit::float_scalar R, bool SAFE_MATH=false>
1674
1675//------------------------------------------------------------------------------
1686//------------------------------------------------------------------------------
1687 template<jit::float_scalar T, bool SAFE_MATH=false>
1691
1693 template<jit::float_scalar T, bool SAFE_MATH=false>
1694 using shared_subtract = std::shared_ptr<subtract_node<T, SAFE_MATH>>;
1695
1696//------------------------------------------------------------------------------
1704//------------------------------------------------------------------------------
1705 template<jit::float_scalar T, bool SAFE_MATH=false>
1707 return std::dynamic_pointer_cast<subtract_node<T, SAFE_MATH>> (x);
1708 }
1709
1710//******************************************************************************
1711// Multiply node.
1712//******************************************************************************
1713//------------------------------------------------------------------------------
1718//------------------------------------------------------------------------------
1719 template<jit::float_scalar T, bool SAFE_MATH=false>
1720 class multiply_node final : public branch_node<T, SAFE_MATH> {
1721 private:
1722//------------------------------------------------------------------------------
1729//------------------------------------------------------------------------------
1731 reduce_nested_fma_times_constant(shared_leaf<T, SAFE_MATH> trial) {
1732 auto temp = fma_cast(trial);
1733 if (temp.get()) {
1734 if (is_constant_combinable(this->left, temp->get_left()) &&
1735 is_constant_combinable(this->left, temp->get_right())) {
1736 return fma(this->left*temp->get_left(),
1737 temp->get_middle(),
1738 this->left*temp->get_right());
1739 } else {
1740 auto temp2 = reduce_nested_fma_times_constant(temp->get_left());
1741 if (temp2.get()) {
1742 return fma(temp2,
1743 temp->get_middle(),
1744 this->left*temp->get_right());
1745 }
1746 }
1747 }
1748 return null_leaf<T, SAFE_MATH> ();
1749 }
1750
1751//------------------------------------------------------------------------------
1759//------------------------------------------------------------------------------
1761 expand_nested_fma_times_add(shared_leaf<T, SAFE_MATH> trial,
1763 auto temp = fma_cast(trial);
1764 if (temp.get()) {
1765 if (add->get_right()->is_match(temp->get_middle()) &&
1766 is_constant_combinable(add->get_left(), temp->get_right())) {
1767 auto temp2 = expand_nested_fma_times_add2(temp->get_left(),
1768 temp, add);
1769 if (temp2.get()) {
1770 return fma(temp2,
1771 add->get_right(),
1772 temp->get_right()*add->get_left());
1773 } else if (is_constant_combinable(add->get_left(), temp->get_left())) {
1774 return fma(fma(temp->get_left(),
1775 add->get_right(),
1776 add->get_left()*temp->get_left() + temp->get_right()),
1777 add->get_right(),
1778 temp->get_right()*add->get_left());
1779 }
1780 }
1781 }
1782 return null_leaf<T, SAFE_MATH> ();
1783 }
1784
1785//------------------------------------------------------------------------------
1794//------------------------------------------------------------------------------
1796 expand_nested_fma_times_add2(shared_leaf<T, SAFE_MATH> trial,
1799 auto temp = fma_cast(trial);
1800 auto temp2 = fma_cast(last);
1801 assert(temp2.get() && "Assumed a fma node.");
1802 if (temp.get()) {
1803 if (add->get_right()->is_match(temp->get_middle()) &&
1804 is_constant_combinable(add->get_left(), temp->get_left()) &&
1805 is_constant_combinable(add->get_left(), temp->get_right())) {
1806
1807 return fma(fma(temp->get_left(),
1808 add->get_right(),
1809 add->get_left()*temp->get_left() +
1810 temp->get_right()),
1811 add->get_right(),
1812 add->get_left()*temp->get_right() +
1813 temp2->get_right());
1814 } else {
1815 auto temp3 = expand_nested_fma_times_add2(temp->get_left(),
1816 temp, add);
1817 if (temp3.get()) {
1818 return fma(temp3,
1819 add->get_right(),
1820 add->get_left()*temp->get_right() +
1821 temp2->get_right());
1822 }
1823 }
1824 }
1825 return null_leaf<T, SAFE_MATH> ();
1826 }
1827
1828//------------------------------------------------------------------------------
1834//------------------------------------------------------------------------------
1835 static std::string to_string(leaf_node<T, SAFE_MATH> *l,
1837 return jit::format_to_string(reinterpret_cast<size_t> (l)) + "*" +
1838 jit::format_to_string(reinterpret_cast<size_t> (r));
1839 }
1840
1841 public:
1842//------------------------------------------------------------------------------
1847//------------------------------------------------------------------------------
1851
1852//------------------------------------------------------------------------------
1858//------------------------------------------------------------------------------
1860 backend::buffer<T> l_result = this->left->evaluate();
1861
1862// If the left are right are same don't evaluate the right.
1863// NOTE: Do not use is_match here. Remove once power is implemented.
1864 if (this->left.get() == this->right.get()) {
1865 return l_result*l_result;
1866 }
1867
1868// If all the elements on the left are zero, return the left side without
1869// reevaluating the right side. Stop this loop early once the first non zero
1870// element is encountered.
1871 if (l_result.is_zero()) {
1872 return l_result;
1873 }
1874
1875 backend::buffer<T> r_result = this->right->evaluate();
1876 return l_result*r_result;
1877 }
1878
1879//------------------------------------------------------------------------------
1883//------------------------------------------------------------------------------
1885 auto l = constant_cast(this->left);
1886 auto r = constant_cast(this->right);
1887
1888 if (l.get() && l->is(1)) {
1889 return this->right;
1890 } else if (l.get() && l->is(0)) {
1891 return this->left;
1892 } else if (r.get() && r->is(1)) {
1893 return this->left;
1894 } else if (r.get() && r->is(0)) {
1895 return this->right;
1896 } else if (l.get() && r.get()) {
1897 return constant<T, SAFE_MATH> (this->evaluate());
1898 }
1899
1900 auto pl1 = piecewise_1D_cast(this->left);
1901 auto pr1 = piecewise_1D_cast(this->right);
1902
1903 if (pl1.get() && (r.get() || pl1->is_arg_match(this->right))) {
1904 return piecewise_1D(this->evaluate(), pl1->get_arg(),
1905 pl1->get_scale(), pl1->get_offset());
1906 } else if (pr1.get() && (l.get() || pr1->is_arg_match(this->left))) {
1907 return piecewise_1D(this->evaluate(), pr1->get_arg(),
1908 pr1->get_scale(), pr1->get_offset());
1909 }
1910
1911 auto pl2 = piecewise_2D_cast(this->left);
1912 auto pr2 = piecewise_2D_cast(this->right);
1913
1914 if (pl2.get() && (r.get() || pl2->is_arg_match(this->right))) {
1915 return piecewise_2D(this->evaluate(),
1916 pl2->get_num_columns(),
1917 pl2->get_left(), pl2->get_x_scale(), pl2->get_x_offset(),
1918 pl2->get_right(), pl2->get_y_scale(), pl2->get_y_offset());
1919 } else if (pr2.get() && (l.get() || pr2->is_arg_match(this->left))) {
1920 return piecewise_2D(this->evaluate(),
1921 pr2->get_num_columns(),
1922 pr2->get_left(), pr2->get_x_scale(), pr2->get_x_offset(),
1923 pr2->get_right(), pr2->get_y_scale(), pr2->get_y_offset());
1924 }
1925
1926// Combine 2D and 1D piecewise constants if a row or column matches.
1927 if (pr2.get() && pr2->is_row_match(this->left)) {
1928 backend::buffer<T> result = pl1->evaluate();
1929 result.multiply_row(pr2->evaluate());
1930 return piecewise_2D(result,
1931 pr2->get_num_columns(),
1932 pr2->get_left(), pr2->get_x_scale(), pr2->get_x_offset(),
1933 pr2->get_right(), pr2->get_y_scale(), pr2->get_y_offset());
1934 } else if (pr2.get() && pr2->is_col_match(this->left)) {
1935 backend::buffer<T> result = pl1->evaluate();
1936 result.multiply_col(pr2->evaluate());
1937 return piecewise_2D(result,
1938 pr2->get_num_columns(),
1939 pr2->get_left(), pr2->get_x_scale(), pr2->get_x_offset(),
1940 pr2->get_right(), pr2->get_y_scale(), pr2->get_y_offset());
1941 } else if (pl2.get() && pl2->is_row_match(this->right)) {
1942 backend::buffer<T> result = pl2->evaluate();
1943 result.multiply_row(pr1->evaluate());
1944 return piecewise_2D(result,
1945 pl2->get_num_columns(),
1946 pl2->get_left(), pl2->get_x_scale(), pl2->get_x_offset(),
1947 pl2->get_right(), pl2->get_y_scale(), pl2->get_y_offset());
1948 } else if (pl2.get() && pl2->is_col_match(this->right)) {
1949 backend::buffer<T> result = pl2->evaluate();
1950 result.multiply_col(pr1->evaluate());
1951 return piecewise_2D(result,
1952 pl2->get_num_columns(),
1953 pl2->get_left(), pl2->get_x_scale(), pl2->get_x_offset(),
1954 pl2->get_right(), pl2->get_y_scale(), pl2->get_y_offset());
1955 }
1956
1957// Move constants to the left.
1958 if (is_constant_promotable(this->right, this->left)) {
1959 return this->right*this->left;
1960 }
1961
1962// Disable if the right is power like to avoid infinite loop.
1963 if (is_variable_promotable(this->left, this->right)) {
1964 return this->right*this->left;
1965 }
1966
1967// Move trig to the right.
1968 auto cl = cos_cast(this->left);
1969 auto sl = sin_cast(this->left);
1970 if ((cl.get() && !this->right->is_power_like() &&
1971 !this->right->is_all_variables() &&
1972 !sin_cast(this->right).get()) ||
1973 (sl.get() && !this->right->is_power_like() &&
1974 !this->right->is_all_variables()) ||
1975 (sl.get() && cos_cast(this->right).get())) {
1976 return this->right*this->left;
1977 }
1978
1979// Reduce x*x to x^2
1980 if (this->left->is_match(this->right)) {
1981 return pow(this->left, 2.0);
1982 }
1983
1984// Gather common terms.
1985 auto lm = multiply_cast(this->left);
1986 if (lm.get()) {
1987// Promote constants before variables.
1988// (c*v1)*v2 -> c*(v1*v2)
1989 if (is_constant_promotable(lm->get_left(),
1990 lm->get_right())) {
1991 return lm->get_left()*(lm->get_right()*this->right);
1992 }
1993
1994// (a^c*b)*a^d -> a^(c+d)*b
1995// (b*a^c)*a^d -> a^(c+d)*b
1996 if (is_variable_combinable(this->right, lm->get_left())) {
1997 return (this->right*lm->get_left())*lm->get_right();
1998 } else if (is_variable_combinable(this->right, lm->get_right())) {
1999 return (this->right*lm->get_right())*lm->get_left();
2000 }
2001
2002// Assume variables, sqrt of variables, and powers of variables are on the
2003// right.
2004// (a*v)*b -> (a*b)*v
2005 if (is_variable_promotable(lm->get_right(), this->right)) {
2006 return (lm->get_left()*this->right)*lm->get_right();
2007 }
2008
2009// (a*(b*c)^e)*c^f -> a*b^e*c^(e+f)
2010 auto lmrp = pow_cast(lm->get_right());
2011 if (lmrp.get()) {
2012 auto lmrplm = multiply_cast(lmrp->get_left());
2013 if (lmrplm.get() &&
2014 is_variable_combinable(lmrplm->get_right(),
2015 this->right)) {
2016 return (lm->get_left()*pow(lmrplm->get_left(),
2017 lmrp->get_right()))*pow(this->right->get_power_base(),
2018 lmrp->get_right() +
2019 this->right->get_power_exponent());
2020 }
2021 }
2022 }
2023
2024 auto rm = multiply_cast(this->right);
2025 if (rm.get()) {
2026// Assume constants are on the left.
2027// c1*(c2*v) -> c3*v
2028 if (is_constant_combinable(this->left,
2029 rm->get_left())) {
2030 auto temp = this->left*rm->get_left();
2031 if (temp->is_normal()) {
2032 return temp*rm->get_right();
2033 }
2034 }
2035
2036// a*(a*b) -> a^2*b
2037// a*(b*a) -> a^2*b
2038 if (is_variable_combinable(this->left, rm->get_left())) {
2039 return (this->left*rm->get_left())*rm->get_right();
2040 } else if (is_variable_combinable(this->left, rm->get_right())) {
2041 return (this->left*rm->get_right())*rm->get_left();
2042 }
2043
2044// Assume variables are on the left.
2045// a*(b*v) -> (a*b)*v
2046 if (is_variable_promotable(rm->get_right(), this->left)) {
2047 return (this->left*rm->get_left())*rm->get_right();
2048 }
2049
2050// c1*(fma(c2,x,c3)*y)-> fma(c4,x,c5)*y
2051// c1*(fma(fma(c2,x,c3),x,c4)*y)-> fma(fma(c5,x,c6),x,c7)*y
2052// c1*(fma(fma(fma(c2,x,c3),x,c4),x,c5)*y)-> fma(fma(fma(c6,x,c7),x,c8),x,c9)*y
2053// etc...
2054 auto temp = this->reduce_nested_fma_times_constant(rm->get_left());
2055 if (temp.get()) {
2056 return temp*rm->get_right();
2057 }
2058 }
2059
2060// v1*(c*v2) -> c*(v1*v2)
2061 if (rm.get() &&
2062 is_constant_promotable(rm->get_left(), this->left)) {
2063 return rm->get_left()*(this->left*rm->get_right());
2064 }
2065
2066// Assume trig on the right.
2067// a*(b*sin) -> (a*b)*sin
2068// a*(b*cos) -> (a*b)*cos
2069// (a*sin)*b -> (a*b)*sin
2070// (a*cos)*b -> (a*b)*cos
2071 if (lm.get() &&
2072 (sin_cast(lm->get_right()).get() ||
2073 cos_cast(lm->get_right()).get()) &&
2074 !sin_cast(this->right).get() &&
2075 !this->right->is_power_like()) {
2076 return (lm->get_left()*this->right)*lm->get_right();
2077 } else if (rm.get() &&
2078 (sin_cast(rm->get_right()).get() ||
2079 cos_cast(rm->get_right()).get()) &&
2080 !this->left->is_constant()) {
2081 return (this->left*rm->get_left())*rm->get_right();
2082 }
2083
2084// Factor out common constants c*b*c*d -> c*c*b*d. c*c will get reduced to c on
2085// the second pass.
2086 if (lm.get() && rm.get()) {
2087 if (is_constant_combinable(lm->get_left(),
2088 rm->get_left())) {
2089 auto temp = lm->get_left()*rm->get_left();
2090 if (temp->is_normal()) {
2091 return temp*(lm->get_right()*rm->get_right());
2092 }
2093 } else if (is_constant_combinable(lm->get_left(),
2094 rm->get_right())) {
2095 auto temp = lm->get_left()*rm->get_right();
2096 if (temp->is_normal()) {
2097 return temp*(lm->get_right()*rm->get_left());
2098 }
2099 } else if (is_constant_combinable(lm->get_right(),
2100 rm->get_left())) {
2101 auto temp = lm->get_right()*rm->get_left();
2102 if (temp->is_normal()) {
2103 return temp*(lm->get_left()*rm->get_right());
2104 }
2105 } else if (is_constant_combinable(lm->get_right(),
2106 rm->get_right())) {
2107 auto temp = lm->get_right()*rm->get_right();
2108 if (temp->is_normal()) {
2109 return temp*(lm->get_left()*rm->get_left());
2110 }
2111 }
2112
2113// Gather common terms. This will help reduce sqrt(a)*sqrt(a).
2114 if (lm->get_left()->is_match(rm->get_left())) {
2115 return (lm->get_left()*rm->get_left()) *
2116 (lm->get_right()*rm->get_right());
2117 } else if (lm->get_right()->is_match(rm->get_left())) {
2118 return (lm->get_right()*rm->get_left()) *
2119 (lm->get_left()*rm->get_right());
2120 } else if (lm->get_left()->is_match(rm->get_right())) {
2121 return (lm->get_left()*rm->get_right()) *
2122 (lm->get_right()*rm->get_left());
2123 } else if (lm->get_right()->is_match(rm->get_right())) {
2124 return (lm->get_right()*rm->get_right()) *
2125 (lm->get_left()*rm->get_left());
2126 }
2127 }
2128
2129// Common factor reduction. (a/b)*(c/a) = c/b.
2130 auto ld = divide_cast(this->left);
2131 auto rd = divide_cast(this->right);
2132
2133// a*(b/c) -> (a*b)/c
2134// (a/c)*b -> (a*b)/c
2135 if (rd.get()) {
2136 return (this->left*rd->get_left())/rd->get_right();
2137 } else if (ld.get()) {
2138 return (ld->get_left()*this->right)/ld->get_right();
2139 }
2140
2141// (a/b)*(c/a) -> c/b
2142// (b/a)*(a/c) -> c/b
2143 if (ld.get() && rd.get()) {
2144 if (ld->get_left()->is_match(rd->get_right())) {
2145 return rd->get_left()/ld->get_right();
2146 } else if (ld->get_right()->is_match(rd->get_left())) {
2147 return ld->get_left()/rd->get_right();
2148 }
2149
2150// Convert (a/b)*(c/d) -> (a*c)/(b*d). This should help reduce cases like.
2151// (a/b)*(a/b) + (c/b)*(c/b).
2152 return (ld->get_left()*rd->get_left()) /
2153 (ld->get_right()*rd->get_right());
2154 }
2155
2156// Power reductions.
2157 if (is_variable_combinable(this->left, this->right)) {
2158 return pow(this->left->get_power_base(),
2159 this->left->get_power_exponent() +
2160 this->right->get_power_exponent());
2161 }
2162
2163// a*b^-c -> a/b^c
2164 auto rp = pow_cast(this->right);
2165 if (rp.get()) {
2166 auto exponent = constant_cast(rp->get_right());
2167 if (exponent.get() && exponent->evaluate().is_negative()) {
2168 return this->left/pow(rp->get_left(), -rp->get_right());
2169 }
2170 }
2171// b^-c*a -> a/b^c
2172 auto lp = pow_cast(this->left);
2173 if (lp.get()) {
2174 auto exponent = constant_cast(lp->get_right());
2175 if (exponent.get() && exponent->evaluate().is_negative()) {
2176 return this->right/pow(lp->get_left(), -lp->get_right());
2177 }
2178 }
2179// a^b*c^b -> (a*c)^b
2180 if (lp.get() && rp.get()) {
2181 if (lp->get_right()->is_match(rp->get_right())) {
2182 return pow(lp->get_left()*rp->get_left(), lp->get_right());
2183 }
2184 }
2185// (a*b^c)*d^c -> a*(b*d)^c
2186// (a^c*b)*d^c -> b*(a*d)^c
2187// a^c*(b*d^c) -> b*(a*d)^c
2188// a^c*(b^c*d) -> d*(a*b)^c
2189 if (lm.get() && rp.get()) {
2190 auto lmlp = pow_cast(lm->get_left());
2191 auto lmrp = pow_cast(lm->get_right());
2192 if (lmrp.get()) {
2193 if (lmrp->get_right()->is_match(rp->get_right())) {
2194 return lm->get_left()*pow(lmrp->get_left()*rp->get_left(),
2195 rp->get_right());
2196 }
2197 } else if (lmlp.get()) {
2198 if (lmlp->get_right()->is_match(rp->get_right())) {
2199 return lm->get_right()*pow(lmlp->get_left()*rp->get_left(),
2200 rp->get_right());
2201 }
2202 }
2203 } else if (rm.get() && lp.get()) {
2204 auto rmlp = pow_cast(rm->get_left());
2205 auto rmrp = pow_cast(rm->get_right());
2206 if (rmrp.get()) {
2207 if (rmrp->get_right()->is_match(lp->get_right())) {
2208 return rm->get_left()*pow(lp->get_left()*rmrp->get_left(),
2209 lp->get_right());
2210 }
2211 } else if (rmlp.get()) {
2212 if (rmlp->get_right()->is_match(lp->get_right())) {
2213 return rm->get_right()*pow(lp->get_left()*rmlp->get_left(),
2214 lp->get_right());
2215 }
2216 }
2217 }
2218
2219// (b*a)^c*a^d -> b^c*a^(c + d)
2220// (a*b)^c*a^d -> b^c*a^(c + d)
2221// a^d*(b*a)^c -> b^c*a^(c + d)
2222// a^d*(a*b)^c -> b^c*a^(c + d)
2223 if (lp.get() && rp.get()) {
2224 auto lplm = multiply_cast(lp->get_left());
2225 auto rplm = multiply_cast(rp->get_left());
2226 if (lplm.get()) {
2227 if (is_variable_combinable(lplm->get_right(),
2228 this->right)) {
2229 return pow(lplm->get_left()->get_power_base(),
2230 this->left->get_power_exponent())*
2231 pow(this->right->get_power_base(),
2232 this->left->get_power_exponent() +
2233 this->right->get_power_exponent());
2234 } else if (is_variable_combinable(lplm->get_left(),
2235 this->right)) {
2236 return pow(lplm->get_right()->get_power_base(),
2237 this->left->get_power_exponent())*
2238 pow(this->right->get_power_base(),
2239 this->left->get_power_exponent() +
2240 this->right->get_power_exponent());
2241 }
2242 }
2243
2244 if (rplm.get()) {
2245 if (is_variable_combinable(rplm->get_right(),
2246 this->left)) {
2247 return pow(rplm->get_left()->get_power_base(),
2248 this->right->get_power_exponent())*
2249 pow(this->left->get_power_base(),
2250 this->left->get_power_exponent() +
2251 this->right->get_power_exponent());
2252 } else if (is_variable_combinable(rplm->get_left(),
2253 this->left)) {
2254 return pow(rplm->get_right()->get_power_base(),
2255 this->right->get_power_exponent())*
2256 pow(this->left->get_power_base(),
2257 this->left->get_power_exponent() +
2258 this->right->get_power_exponent());
2259 }
2260 }
2261 }
2262
2263 auto lpd = divide_cast(this->left->get_power_base());
2264 if (lpd.get()) {
2265// (a/b)^c*b^d -> a^c*b^(c-d)
2266 if (is_variable_combinable(lpd->get_right(),
2267 this->right)) {
2268 return pow(lpd->get_left(), this->left->get_power_exponent()) *
2269 pow(this->right->get_power_base(),
2270 this->right->get_power_exponent() -
2271 this->left->get_power_exponent()*lpd->get_right()->get_power_exponent());
2272 }
2273// (b/a)^c*b^d -> b^(c+d)/a^c
2274 if (is_variable_combinable(lpd->get_left(), this->right)) {
2275 return pow(this->right->get_power_base(),
2276 this->right->get_power_exponent() +
2277 this->left->get_power_exponent()*lpd->get_left()->get_power_exponent()) /
2278 pow(lpd->get_right(), this->left->get_power_exponent());
2279 }
2280 }
2281 auto rpd = divide_cast(this->right->get_power_base());
2282 if (rpd.get()) {
2283// b^d*(a/b)^c -> a^c*b^(c-d)
2284 if (is_variable_combinable(rpd->get_right(),
2285 this->left)) {
2286 return pow(rpd->get_left(), this->right->get_power_exponent()) *
2287 pow(this->left->get_power_base(),
2288 this->left->get_power_exponent() -
2289 this->right->get_power_exponent()*rpd->get_right()->get_power_exponent());
2290 }
2291// b^d*(b/a)^c -> b^(c+d)/a^c
2292 if (is_variable_combinable(rpd->get_left(),
2293 this->left)) {
2294 return pow(this->right->get_power_base(),
2295 this->right->get_power_exponent() +
2296 this->right->get_power_exponent()*rpd->get_left()->get_power_exponent()) /
2297 pow(rpd->get_right(), this->right->get_power_exponent());
2298 }
2299 }
2300
2301// exp(a)*exp(b) -> exp(a + b)
2302 auto le = exp_cast(this->left);
2303 auto re = exp_cast(this->right);
2304 if (le.get() && re.get()) {
2305 return exp(le->get_arg() + re->get_arg());
2306 }
2307
2308// exp(a)*(exp(b)*c) -> c*(exp(a)*exp(b))
2309// exp(a)*(c*exp(b)) -> c*(exp(a)*exp(b))
2310 if (le.get() && rm.get()) {
2311 auto rmle = exp_cast(rm->get_left());
2312 if (rmle.get()) {
2313 return rm->get_right()*(this->left*rm->get_left());
2314 }
2315 auto rmre = exp_cast(rm->get_right());
2316 if (rmre.get()) {
2317 return rm->get_left()*(this->left*rm->get_right());
2318 }
2319 }
2320// (exp(a)*c)*exp(b) -> c*(exp(a)*exp(b))
2321// (c*exp(a))*exp(b) -> c*(exp(a)*exp(b))
2322 if (re.get() && lm.get()) {
2323 auto lmle = exp_cast(lm->get_left());
2324 if (lmle.get()) {
2325 return lm->get_right()*(this->right*lm->get_left());
2326 }
2327 auto lmre = exp_cast(lm->get_right());
2328 if (lmre.get()) {
2329 return lm->get_left()*(this->right*lm->get_right());
2330 }
2331 }
2332// (exp(a)*c)*(exp(b)*d) -> (c*d)*(exp(a)*exp(b))
2333// (exp(a)*c)*(d*exp(b)) -> (c*d)*(exp(a)*exp(b))
2334// (c*exp(a))*(exp(b)*d) -> (c*d)*(exp(a)*exp(b))
2335// (c*exp(a))*(d*exp(b)) -> (c*d)*(exp(a)*exp(b))
2336 if (lm.get() && rm.get()) {
2337 auto lmle = exp_cast(lm->get_left());
2338 if (lmle.get()) {
2339 auto rmle = exp_cast(rm->get_left());
2340 if (rmle.get()) {
2341 return (lm->get_right()*rm->get_right()) *
2342 (lm->get_left()*rm->get_left());
2343 }
2344 auto rmre = exp_cast(rm->get_right());
2345 if (rmre.get()) {
2346 return (lm->get_right()*rm->get_left()) *
2347 (lm->get_left()*rm->get_right());
2348 }
2349 }
2350 auto lmre = exp_cast(lm->get_right());
2351 if (lmre.get()) {
2352 auto rmle = exp_cast(rm->get_left());
2353 if (rmle.get()) {
2354 return (lm->get_left()*rm->get_right()) *
2355 (lm->get_right()*rm->get_left());
2356 }
2357 auto rmre = exp_cast(rm->get_right());
2358 if (rmre.get()) {
2359 return (lm->get_left()*rm->get_left()) *
2360 (lm->get_right()*rm->get_right());
2361 }
2362 }
2363 }
2364
2365 if (ld.get() && re.get()) {
2366// (c/exp(a))*exp(b) -> c*(exp(b)/exp(a))
2367 auto ldre = exp_cast(ld->get_right());
2368 if (ldre.get()) {
2369 return ld->get_left()*(this->right/ld->get_right());
2370 }
2371// (exp(a)/c)*exp(b) -> (exp(a)*exp(b))/c
2372 auto ldle = exp_cast(ld->get_left());
2373 if (ldle.get()) {
2374 return (ld->get_left()*this->right)/ld->get_right();
2375 }
2376 }
2377 if (rd.get() && le.get()) {
2378// exp(a)*(c/exp(a)) -> c*(exp(a)/exp(b))
2379 auto rdre = exp_cast(rd->get_right());
2380 if (rdre.get()) {
2381 return rd->get_left()*(this->left/rd->get_right());
2382 }
2383// exp(a)*(exp(b)/c) -> (exp(a)*exp(b))/c
2384 auto rdle = exp_cast(rd->get_left());
2385 if (rdle.get()) {
2386 return (this->left*rd->get_left())/rd->get_right();
2387 }
2388 }
2389
2390 if (ld.get() && rm.get()) {
2391 auto rmle = exp_cast(rm->get_left());
2392 if (rmle.get()) {
2393// (c/exp(a))*(exp(b)*d) -> (c*d)*(exp(b)/exp(a))
2394 auto ldre = exp_cast(ld->get_right());
2395 if (ldre.get()) {
2396 return (ld->get_left()*rm->get_right()) *
2397 (rm->get_left()/ld->get_right());
2398 }
2399// (exp(a)/c)*(exp(b)*d) -> (d/c)*(exp(a)*exp(b))
2400 auto ldle = exp_cast(ld->get_left());
2401 if (ldle.get()) {
2402 return (rm->get_right()/ld->get_right()) *
2403 (ld->get_left()*rm->get_left());
2404 }
2405 }
2406 auto rmre = exp_cast(rm->get_right());
2407 if (rmre.get()) {
2408// (c/exp(a))*(d*exp(b)) -> (c*d)*(exp(b)/exp(a))
2409 auto ldre = exp_cast(ld->get_right());
2410 if (ldre.get()) {
2411 return (ld->get_left()*rm->get_left()) *
2412 (rm->get_right()/ld->get_right());
2413 }
2414// (exp(a)/c)*(d*exp(b)) -> (d/c)*(exp(a)*exp(b))
2415 auto ldle = exp_cast(ld->get_left());
2416 if (ldle.get()) {
2417 return (rm->get_left()/ld->get_right()) *
2418 (ld->get_left()*rm->get_right());
2419 }
2420 }
2421 } else if (rd.get() && lm.get()) {
2422 auto lmre = exp_cast(lm->get_right());
2423 if (lmre.get()) {
2424// (c*exp(a))*(exp(b)/d) -> (c/d)*(exp(a)*exp(b))
2425 auto rdre = exp_cast(rd->get_left());
2426 if (rdre.get()) {
2427 return (lm->get_left()/rd->get_right()) *
2428 (lm->get_right()*rd->get_left());
2429 }
2430// (c*exp(a))*(d/exp(b)) -> (c*d)*(exp(a)/exp(b))
2431 auto rdle = exp_cast(rd->get_right());
2432 if (rdle.get()) {
2433 return (lm->get_left()*rd->get_left()) *
2434 (lm->get_right()/rd->get_right());
2435 }
2436 }
2437 auto lmle = exp_cast(lm->get_left());
2438 if (lmle.get()) {
2439// (exp(a)*c)*(d/exp(b)) -> (c*d)*(exp(a)/exp(b))
2440 auto rdle = exp_cast(rd->get_right());
2441 if (rdle.get()) {
2442 return (lm->get_right()*rd->get_left()) *
2443 (lm->get_left()/rd->get_right());
2444 }
2445// (exp(a)*c)*(exp(b)/d) -> (c/d)*(exp(a)*exp(b))
2446 auto rdre = exp_cast(rd->get_left());
2447 if (rdre.get()) {
2448 return (lm->get_right()/rd->get_right()) *
2449 (lm->get_left()*rd->get_left());
2450 }
2451 }
2452 }
2453
2454// c1*fma(c2,x,c3) -> fma(c4,x,c5)
2455// c1*fma(fma(c2,x,c3),x,c4) -> fma(fma(c5,x,c6),x,c7)
2456// c1*fma(fma(fma(c2,x,c3),x,c4),x,c5) -> fma(fma(fma(c6,x,c7),x,c8),x,c9)
2457// etc...
2458 auto fma_reduce = this->reduce_nested_fma_times_constant(this->right);
2459 if (fma_reduce.get()) {
2460 return fma_reduce;
2461 }
2462
2463// fma(c1,x,c2)*(c3 + x) -> fma(fma(c1,x,c4),x,c5)
2464// fma(fma(c1,x,c2),x,c3)*(c4 + x) -> fma(fma(fma(c1,x,c5),x,c6),x,c7)
2465// etc...
2466 auto ra = add_cast(this->right);
2467 if (ra.get()) {
2468 auto fma_expand = this->expand_nested_fma_times_add(this->left,
2469 ra);
2470 if (fma_expand.get()) {
2471 return fma_expand;
2472 }
2473 }
2474
2475// Cases like
2476// (c/exp(a))*(exp(b)/d) -> (c/d)*(exp(b)/exp(a))
2477// (c/exp(a))*(d/exp(b)) -> (c*e)/(exp(b)*exp(a))
2478// (exp(a)/c)*(d/exp(b)) -> (d/c)*(exp(a)/exp(b))
2479// (exp(a)/c)*(exp(b)/d) -> (exp(a)*exp(b))/(c*d)
2480// Are taken care of by (a/b)*(c/d) -> (a*c)/(b*d) conversion above.
2481
2482 return this->shared_from_this();
2483 }
2484
2485//------------------------------------------------------------------------------
2492//------------------------------------------------------------------------------
// Derivative of this product node with respect to x.
// Differentiating with respect to an expression that matches this node
// (see is_match) yields one.
2494 if (this->is_match(x)) {
2495 return one<T, SAFE_MATH> ();
2496 }
2497
// Results are memoized in df_cache, keyed by the address of x.
// NOTE(review): a raw-address key assumes x stays alive while its entry is
// cached; a freed-and-reused address would return a stale derivative - confirm.
2498 const size_t hash = reinterpret_cast<size_t> (x.get());
2499 if (this->df_cache.find(hash) == this->df_cache.end()) {
// Product rule: d(l*r)/dx = (dl/dx)*r + l*(dr/dx).
2500 this->df_cache[hash] = this->left->df(x)*this->right
2501 + this->left*this->right->df(x);
2502 }
2503 return this->df_cache[hash];
2504 }
2505
2506//------------------------------------------------------------------------------
2514//------------------------------------------------------------------------------
2516 compile(std::ostringstream &stream,
2517 jit::register_map &registers,
2519 const jit::register_usage &usage) {
// Emit kernel source for this product exactly once per node: registers maps
// every already-compiled node to the name of the variable holding its value.
2520 if (registers.find(this) == registers.end()) {
// Compile both operands first so their variables are declared before use.
2521 shared_leaf<T, SAFE_MATH> l = this->left->compile(stream,
2522 registers,
2523 indices,
2524 usage);
2525 shared_leaf<T, SAFE_MATH> r = this->right->compile(stream,
2526 registers,
2527 indices,
2528 usage);
2529
// Name this node's variable with jit::to_string('r', this) and start a const
// declaration of type T.
2530 registers[this] = jit::to_string('r', this);
2531 stream << " const ";
2532 jit::add_type<T> (stream);
2533 stream << " " << registers[this] << " = ";
// SAFE_MATH builds emit the guarded form `(l == 0 || r == 0) ? 0 : l*r`.
// Complex types compare against and return a complex zero, T(0, 0).
// NOTE(review): presumably this makes 0*inf and 0*NaN yield 0 instead of NaN.
// Confirm.
2534 if constexpr (SAFE_MATH) {
2535 stream << "(" << registers[l.get()] << " == ";
2536 if constexpr (jit::complex_scalar<T>) {
2537 jit::add_type<T> (stream);
2538 stream << "(0, 0)";
2539 } else {
2540 stream << "0";
2541 }
2542 stream << " || " << registers[r.get()] << " == ";
2543 if constexpr (jit::complex_scalar<T>) {
2544 jit::add_type<T> (stream);
2545 stream << "(0, 0)";
2546 } else {
2547 stream << "0";
2548 }
2549 stream << ") ? ";
2550 if constexpr (jit::complex_scalar<T>) {
2551 jit::add_type<T> (stream);
2552 stream << "(0, 0)";
2553 } else {
2554 stream << "0";
2555 }
2556 stream << " : ";
2557 }
// The product of the two operand variables.
2558 stream << registers[l.get()] << "*"
2559 << registers[r.get()];
2560 this->endline(stream, usage);
2561 }
2562
2563 return this->shared_from_this();
2564 }
2565
2566//------------------------------------------------------------------------------
2571//------------------------------------------------------------------------------
// Structural equality test. Returns true when x is this very node, or when x
// is also a product whose two operands match this node's operands.
2573 if (this == x.get()) {
2574 return true;
2575 }
2576
2577 auto x_cast = multiply_cast(x);
2578 if (x_cast.get()) {
2579// Multiplication is commutative, so the operands may match in either order.
// Only the two direct operands are swapped: nested products are not
// reassociated, so (a*b)*c does not match a*(b*c).
2580 if ((this->left->is_match(x_cast->get_left()) &&
2581 this->right->is_match(x_cast->get_right())) ||
2582 (this->right->is_match(x_cast->get_left()) &&
2583 this->left->is_match(x_cast->get_right()))) {
2584 return true;
2585 }
2586 }
2587
2588 return false;
2589 }
2590
2591//------------------------------------------------------------------------------
2593//------------------------------------------------------------------------------
2594 virtual void to_latex() const {
2595 if (constant_cast(this->left).get() ||
2596 add_cast(this->left).get() ||
2597 subtract_cast(this->left).get()) {
2598 std::cout << "\\left(";
2599 this->left->to_latex();
2600 std::cout << "\\right)";
2601 } else {
2602 this->left->to_latex();
2603 }
2604 std::cout << " ";
2605 if (constant_cast(this->right).get() ||
2606 add_cast(this->right).get() ||
2607 subtract_cast(this->right).get()) {
2608 std::cout << "\\left(";
2609 this->right->to_latex();
2610 std::cout << "\\right)";
2611 } else {
2612 this->right->to_latex();
2613 }
2614 }
2615
2616//------------------------------------------------------------------------------
2620//------------------------------------------------------------------------------
// Strip pseudo nodes from this product. When has_pseudo() reports one in this
// subtree, rebuild the product with operator* from both operands'
// remove_pseudo() results. Otherwise return this node unchanged.
2622 if (this->has_pseudo()) {
2623 return this->left->remove_pseudo() *
2624 this->right->remove_pseudo();
2625 }
2626 return this->shared_from_this();
2627 }
2628
2629//------------------------------------------------------------------------------
2635//------------------------------------------------------------------------------
2636 virtual shared_leaf<T, SAFE_MATH> to_vizgraph(std::stringstream &stream,
2637 jit::register_map &registers) {
2638 if (registers.find(this) == registers.end()) {
2639 const std::string name = jit::to_string('r', this);
2640 registers[this] = name;
2641 stream << " " << name
2642 << " [label = \"⨉\", shape = oval, style = filled, fillcolor = blue, fontcolor = white];" << std::endl;
2643
2644 auto l = this->left->to_vizgraph(stream, registers);
2645 stream << " " << name << " -- " << registers[l.get()] << ";" << std::endl;
2646 auto r = this->right->to_vizgraph(stream, registers);
2647 stream << " " << name << " -- " << registers[r.get()] << ";" << std::endl;
2648 }
2649
2650 return this->shared_from_this();
2651 }
2652 };
2653
2654//------------------------------------------------------------------------------
2662//------------------------------------------------------------------------------
// Build the product l*r as a multiply_node and reduce it algebraically.
// Then deduplicate the result against the shared node cache
// leaf_node<T, SAFE_MATH>::caches.nodes, so that equal expressions share a
// single node.
2663 template<jit::float_scalar T, bool SAFE_MATH=false>
2666 auto temp = std::make_shared<multiply_node<T, SAFE_MATH>> (l, r)->reduce();
2667// Test for hash collisions.
// Probe the cache starting at the reduced node's hash:
// - an empty slot means this node is new, and temp is returned;
// - a slot holding a structurally matching node returns the cached node.
// NOTE(review): the loop bound/increment and the empty-slot bookkeeping are on
// lines not shown in this rendering. Presumably this is linear probing that
// also stores temp in the empty slot. Confirm.
2668 for (size_t i = temp->get_hash();
2670 if (leaf_node<T, SAFE_MATH>::caches.nodes.find(i) ==
2671 leaf_node<T, SAFE_MATH>::caches.nodes.end()) {
2673 return temp;
2674 } else if (temp->is_match(leaf_node<T, SAFE_MATH>::caches.nodes[i])) {
2675 return leaf_node<T, SAFE_MATH>::caches.nodes[i];
2676 }
2677 }
// Control is not expected to reach this point; the probe loop returns first.
2678#if defined(__clang__) || defined(__GNUC__)
2680#else
2681 assert(false && "Should never reach.");
2682#endif
2683 }
2684
2685//------------------------------------------------------------------------------
2696//------------------------------------------------------------------------------
2697 template<jit::float_scalar T, bool SAFE_MATH=false>
2702
2703//------------------------------------------------------------------------------
2715//------------------------------------------------------------------------------
2716 template<jit::float_scalar T, jit::float_scalar L, bool SAFE_MATH=false>
2721
2722//------------------------------------------------------------------------------
2734//------------------------------------------------------------------------------
2735 template<jit::float_scalar T, jit::float_scalar R, bool SAFE_MATH=false>
2740
2742 template<jit::float_scalar T, bool SAFE_MATH=false>
2743 using shared_multiply = std::shared_ptr<multiply_node<T, SAFE_MATH>>;
2744
2745//------------------------------------------------------------------------------
2753//------------------------------------------------------------------------------
// Downcast x to a multiply node. std::dynamic_pointer_cast returns a null
// shared pointer when x does not point to a multiply_node<T, SAFE_MATH>.
// Callers test the result with .get() before using it.
2754 template<jit::float_scalar T, bool SAFE_MATH=false>
2756 return std::dynamic_pointer_cast<multiply_node<T, SAFE_MATH>> (x);
2757 }
2758
2759//******************************************************************************
2760// Divide node.
2761//******************************************************************************
2762//------------------------------------------------------------------------------
2767//------------------------------------------------------------------------------
2768 template<jit::float_scalar T, bool SAFE_MATH=false>
2769 class divide_node final : public branch_node<T, SAFE_MATH> {
2770 private:
2771//------------------------------------------------------------------------------
2777//------------------------------------------------------------------------------
// Build the string "<l>/<r>" from the two operand node addresses, each one
// formatted with jit::format_to_string.
2778 static std::string to_string(leaf_node<T, SAFE_MATH> *l,
2780 return jit::format_to_string(reinterpret_cast<size_t> (l)) + "/" +
2781 jit::format_to_string(reinterpret_cast<size_t> (r));
2782 }
2783
2784 public:
2785//------------------------------------------------------------------------------
2790//------------------------------------------------------------------------------
2795
2796//------------------------------------------------------------------------------
2802//------------------------------------------------------------------------------
// Evaluate the quotient: evaluate the left operand, then divide it by the
// evaluated right operand with backend::buffer<T> division.
2804 backend::buffer<T> l_result = this->left->evaluate();
2805
2806// If every element of the left operand is zero, return it unchanged and skip
2807// evaluating the right operand entirely. As a consequence, 0/0 yields 0 here
2808// rather than NaN. (is_zero() presumably stops at the first non-zero element.)
2809 if (l_result.is_zero()) {
2810 return l_result;
2811 }
2812
2813 backend::buffer<T> r_result = this->right->evaluate();
2814 return l_result/r_result;
2815 }
2816
2817//------------------------------------------------------------------------------
2821//------------------------------------------------------------------------------
2823// Constant Reductions.
2824 auto l = constant_cast(this->left);
2825 auto r = constant_cast(this->right);
2826
2827 if ((l.get() && l->is(0)) ||
2828 (r.get() && r->is(1))) {
2829 return this->left;
2830 } else if (l.get() && r.get()) {
2831 return constant<T, SAFE_MATH> (this->evaluate());
2832 }
2833
2834 auto pl1 = piecewise_1D_cast(this->left);
2835 auto pr1 = piecewise_1D_cast(this->right);
2836
2837 if (pl1.get() && (r.get() || pl1->is_arg_match(this->right))) {
2838 return piecewise_1D(this->evaluate(), pl1->get_arg(),
2839 pl1->get_scale(), pl1->get_offset());
2840 } else if (pr1.get() && (l.get() || pr1->is_arg_match(this->left))) {
2841 return piecewise_1D(this->evaluate(), pr1->get_arg(),
2842 pr1->get_scale(), pr1->get_offset());
2843 }
2844
2845 auto pl2 = piecewise_2D_cast(this->left);
2846 auto pr2 = piecewise_2D_cast(this->right);
2847
2848 if (pl2.get() && (r.get() || pl2->is_arg_match(this->right))) {
2849 return piecewise_2D(this->evaluate(),
2850 pl2->get_num_columns(),
2851 pl2->get_left(), pl2->get_x_scale(), pl2->get_x_offset(),
2852 pl2->get_right(), pl2->get_y_scale(), pl2->get_y_offset());
2853 } else if (pr2.get() && (l.get() || pr2->is_arg_match(this->left))) {
2854 return piecewise_2D(this->evaluate(),
2855 pr2->get_num_columns(),
2856 pr2->get_left(), pr2->get_x_scale(), pr2->get_x_offset(),
2857 pr2->get_right(), pr2->get_y_scale(), pr2->get_y_offset());
2858 }
2859
2860// Combine 2D and 1D piecewise constants if a row or column matches.
2861 if (pr2.get() && pr2->is_row_match(this->left)) {
2862 backend::buffer<T> result = pl1->evaluate();
2863 result.divide_row(pr2->evaluate());
2864 return piecewise_2D(result,
2865 pr2->get_num_columns(),
2866 pr2->get_left(), pr2->get_x_scale(), pr2->get_x_offset(),
2867 pr2->get_right(), pr2->get_y_scale(), pr2->get_y_offset());
2868 } else if (pr2.get() && pr2->is_col_match(this->left)) {
2869 backend::buffer<T> result = pl1->evaluate();
2870 result.divide_col(pr2->evaluate());
2871 return piecewise_2D(result,
2872 pr2->get_num_columns(),
2873 pr2->get_left(), pr2->get_x_scale(), pr2->get_x_offset(),
2874 pr2->get_right(), pr2->get_y_scale(), pr2->get_y_offset());
2875 } else if (pl2.get() && pl2->is_row_match(this->right)) {
2876 backend::buffer<T> result = pl2->evaluate();
2877 result.divide_row(pr1->evaluate());
2878 return piecewise_2D(result,
2879 pl2->get_num_columns(),
2880 pl2->get_left(), pl2->get_x_scale(), pl2->get_x_offset(),
2881 pl2->get_right(), pl2->get_y_scale(), pl2->get_y_offset());
2882 } else if (pl2.get() && pl2->is_col_match(this->right)) {
2883 backend::buffer<T> result = pl2->evaluate();
2884 result.divide_col(pr1->evaluate());
2885 return piecewise_2D(result,
2886 pl2->get_num_columns(),
2887 pl2->get_left(), pl2->get_x_scale(), pl2->get_x_offset(),
2888 pl2->get_right(), pl2->get_y_scale(), pl2->get_y_offset());
2889 }
2890
2891 if (this->left->is_match(this->right)) {
2892 return one<T, SAFE_MATH> ();
2893 }
2894
2895// Reduce cases of a/c1 -> c2*a
2896 if (this->right->is_constant()) {
2897 return (1.0/this->right)*this->left;
2898 }
2899
2900// a/(b/c + d) -> a*c/(c*d + b)
2901// a/(d + b/c) -> a*c/(c*d + b)
2902 auto ra = add_cast(this->right);
2903 if (ra.get()) {
2904 auto rald = divide_cast(ra->get_left());
2905 auto rard = divide_cast(ra->get_right());
2906 if (rald.get()) {
2907 return this->left*rald->get_right() /
2908 fma(rald->get_right(),
2909 ra->get_right(),
2910 rald->get_left());
2911 } else if (rard.get()) {
2912 return this->left*rard->get_right() /
2913 fma(rard->get_right(),
2914 ra->get_left(),
2915 rard->get_left());
2916 }
2917 }
2918
2919// a/(b/c - d) -> a*c/(b - c*d)
2920// a/(d - b/c) -> a*c/(c*d - b)
2921 auto rs = subtract_cast(this->right);
2922 if (rs.get()) {
2923 auto rsld = divide_cast(rs->get_left());
2924 auto rsrd = divide_cast(rs->get_right());
2925 if (rsld.get()) {
2926 return this->left*rsld->get_right() /
2927 (rsld->get_left() -
2928 rsld->get_right()*rs->get_right());
2929 } else if (rsrd.get()) {
2930 return this->left*rsrd->get_right() /
2931 (rsrd->get_right()*rs->get_left() -
2932 rsrd->get_left());
2933 }
2934 }
2935
2936// fma(a,d,c*d)/d -> a + c
2937// fma(a,d,d*c)/d -> a + c
2938// fma(d,a,c*d)/d -> a + c
2939// fma(d,a,d*c)/d -> a + c
2940 auto lfma = fma_cast(this->left);
2941 if (lfma.get()) {
2942 auto fmarm = multiply_cast(lfma->get_right());
2943 if (fmarm.get()) {
2944 if (lfma->get_middle()->is_match(this->right) &&
2945 fmarm->get_right()->is_match(this->right)) {
2946 return lfma->get_left() + fmarm->get_left();
2947 } else if (lfma->get_middle()->is_match(this->right) &&
2948 fmarm->get_left()->is_match(this->right)) {
2949 return lfma->get_left() + fmarm->get_right();
2950 } else if (lfma->get_left()->is_match(this->right) &&
2951 fmarm->get_right()->is_match(this->right)) {
2952 return lfma->get_middle() + fmarm->get_left();
2953 } else if (lfma->get_left()->is_match(this->right) &&
2954 fmarm->get_left()->is_match(this->right)) {
2955 return lfma->get_middle() + fmarm->get_right();
2956 }
2957 }
2958 }
2959
2960// Common factor reduction. (a*b)/(a*c) = b/c.
2961 auto lm = multiply_cast(this->left);
2962 auto rm = multiply_cast(this->right);
2963
2964 if (lm.get() && rm.get()) {
2965 if (is_variable_combinable(lm->get_left(),
2966 rm->get_left()) ||
2967 is_variable_combinable(lm->get_right(),
2968 rm->get_right())) {
2969 return (lm->get_left()/rm->get_left()) *
2970 (lm->get_right()/rm->get_right());
2971 } else if (is_variable_combinable(lm->get_left(),
2972 rm->get_right()) ||
2973 is_variable_combinable(lm->get_right(),
2974 rm->get_left())) {
2975 return (lm->get_left()/rm->get_right()) *
2976 (lm->get_right()/rm->get_left());
2977 }
2978 }
2979
2980// Move constants to the numerator.
2981// a/(c1*b) -> (c2*a)/b
2982// a/(b*c1) -> (c2*a)/b
2983 if (rm.get()) {
2984 if (rm->get_left()->is_constant() &&
2985 rm->get_left()->is_normal()) {
2986 return ((1.0/rm->get_left())*this->left)/rm->get_right();
2987 } else if (rm->get_right()->is_constant() &&
2988 rm->get_right()->is_normal()) {
2989 return ((1.0/rm->get_right())*this->left)/rm->get_left();
2990 }
2991
2992// a/((b/c + d)*e) -> a*c/((c*d + b)*e)
2993// a/((d + b/c)*e) -> a*c/((c*d + b)*e)
2994// a/(e*(b/c + d)) -> a*c/((c*d + b)*e)
2995// a/(e*(d + b/c)) -> a*c/((c*d + b)*e)
2996 auto rmla = add_cast(rm->get_left());
2997 auto rmra = add_cast(rm->get_right());
2998 if (rmla.get()) {
2999 auto rmlald = divide_cast(rmla->get_left());
3000 auto rmlard = divide_cast(rmla->get_right());
3001 if (rmlald.get()) {
3002 return this->left*rmlald->get_right() /
3003 (fma(rmlald->get_right(),
3004 rmla->get_right(),
3005 rmlald->get_left())*rm->get_right());
3006 } else if (rmlard.get()) {
3007 return this->left*rmlard->get_right() /
3008 (fma(rmlard->get_right(),
3009 rmla->get_left(),
3010 rmlard->get_left())*rm->get_right());
3011 }
3012 }
3013 if (rmra.get()) {
3014 auto rmrald = divide_cast(rmra->get_left());
3015 auto rmrard = divide_cast(rmra->get_right());
3016 if (rmrald.get()) {
3017 return this->left*rmrald->get_right() /
3018 (fma(rmrald->get_right(),
3019 rmra->get_right(),
3020 rmrald->get_left())*rm->get_left());
3021 } else if (rmrard.get()) {
3022 return this->left*rmrard->get_right() /
3023 (fma(rmrard->get_right(),
3024 rmra->get_left(),
3025 rmrard->get_left())*rm->get_left());
3026 }
3027 }
3028
3029// a/((b/c - d)*e) -> a*c/((b - c*d)*e)
3030// a/(e*(b/c - d)) -> a*c/((b - c*d)*e)
3031// a/((d - b/c)*e) -> a*c/((c*d - b)*e)
3032// a/(e*(d - b/c)) -> a*c/((c*d - b)*e)
3033 auto rmls = subtract_cast(rm->get_left());
3034 auto rmrs = subtract_cast(rm->get_right());
3035 if (rmls.get()) {
3036 auto rmlsld = divide_cast(rmls->get_left());
3037 auto rmlsrd = divide_cast(rmls->get_right());
3038 if (rmlsld.get()) {
3039 return this->left*rmlsld->get_right() /
3040 ((rmlsld->get_left() -
3041 rmlsld->get_right()*rmls->get_right())*rm->get_right());
3042 } else if (rmlsrd.get()) {
3043 return this->left*rmlsrd->get_right() /
3044 ((rmlsrd->get_right()*rmls->get_left() -
3045 rmlsrd->get_left())*rm->get_right());
3046 }
3047 }
3048 if (rmrs.get()) {
3049 auto rmrsld = divide_cast(rmrs->get_left());
3050 auto rmrsrd = divide_cast(rmrs->get_right());
3051 if (rmrsld.get()) {
3052 return this->left*rmrsld->get_right() /
3053 ((rmrsld->get_left() -
3054 rmrsld->get_right()*rmrs->get_right())*rm->get_left());
3055 } else if (rmrsrd.get()) {
3056 return this->left*rmrsrd->get_right() /
3057 ((rmrsrd->get_right()*rmrs->get_left() -
3058 rmrsrd->get_left())*rm->get_left());
3059 }
3060 }
3061 }
3062
3063 if (lm.get() && rm.get()) {
3064// (a*b)/(a*c) -> b/c
3065// (b*a)/(a*c) -> b/c
3066// (a*b)/(c*a) -> b/c
3067// (b*a)/(c*a) -> b/c
3068 if (lm->get_left()->is_match(rm->get_left())) {
3069 return lm->get_right()/rm->get_right();
3070 } else if (lm->get_left()->is_match(rm->get_right())) {
3071 return lm->get_right()/rm->get_left();
3072 } else if (lm->get_right()->is_match(rm->get_left())) {
3073 return lm->get_left()/rm->get_right();
3074 } else if (lm->get_right()->is_match(rm->get_right())) {
3075 return lm->get_left()/rm->get_left();
3076 }
3077 }
3078
3079 if (lm.get()) {
3080// (v1*v2)/v1 -> v2
3081// (v2*v1)/v1 -> v2
3082 if (lm->get_left()->is_match(this->right)) {
3083 return lm->get_right();
3084 } else if (lm->get_right()->is_match(this->right)) {
3085 return lm->get_left();
3086 }
3087
3088// (v1^a*v2)/v1^b -> v2*(v1^a/v1^b)
3089// (v2*v1^a)/v1^b -> v2*(v1^a/v1^b)
3090 if (is_variable_combinable(lm->get_left(),
3091 this->right)) {
3092 return lm->get_right()*(lm->get_left()/this->right);
3093 } else if (is_variable_combinable(lm->get_right(),
3094 this->right)) {
3095 return lm->get_left()*(lm->get_right()/this->right);
3096 }
3097 }
3098
3099// (a/b)/c -> a/(b*c)
3100// a/(b/c) -> a*c/b
3101 auto ld = divide_cast(this->left);
3102 auto rd = divide_cast(this->right);
3103 if (ld.get()) {
3104 return ld->get_left()/(ld->get_right()*this->right);
3105 }
3106 if (rd.get()) {
3107 return this->left*rd->get_right()/rd->get_left();
3108 }
3109
3110// Power reductions.
3111 if (is_variable_combinable(this->left,
3112 this->right)) {
3113 return pow(this->left->get_power_base(),
3114 this->left->get_power_exponent() -
3115 this->right->get_power_exponent());
3116 }
3117
3118// a/b^-c -> a*b^c
3119 auto rp = pow_cast(this->right);
3120 if (rp.get()) {
3121 auto exponent = constant_cast(rp->get_right());
3122 if (exponent.get() && exponent->evaluate().is_negative()) {
3123 return this->left*pow(rp->get_left(), -rp->get_right());
3124 }
3125 }
3126
3127// (a*b)^c/(a^d) = a^(c - d)*b^c
3128// (b*a)^c/(a^d) = a^(c - d)*b^c
3129 auto lp = pow_cast(this->left);
3130 if (lp.get()) {
3131 auto lpm = multiply_cast(this->left->get_power_base());
3132 if (lpm.get()) {
3133 if (lpm->get_left()->is_match(this->right->get_power_base())) {
3134 return pow(this->right->get_power_base(),
3135 this->left->get_power_exponent() -
3136 this->right->get_power_exponent()) *
3137 pow(lpm->get_right(),
3138 this->left->get_power_exponent());
3139 } else if (lpm->get_right()->is_match(this->right->get_power_base())) {
3140 return pow(this->right->get_power_base(),
3141 this->left->get_power_exponent() -
3142 this->right->get_power_exponent()) *
3143 pow(lpm->get_left(),
3144 this->left->get_power_exponent());
3145 }
3146 }
3147 }
3148
3149// a^b/c^b -> (a/c)^b
3150 if (lp.get() && rp.get()) {
3151 if (lp->get_right()->is_match(rp->get_right())) {
3152 return pow(lp->get_left()/rp->get_left(), lp->get_right());
3153 }
3154 }
3155
3156// (a*b)^c/((a^d)*e) = a^(c - d)*b^c/e
3157// (b*a)^c/((a^d)*e) = a^(c - d)*b^c/e
3158// (a*b)^c/(e*(a^d)) = a^(c - d)*b^c/e
3159// (b*a)^c/(e*(a^d)) = a^(c - d)*b^c/e
3160 if (lp.get() && rm.get()) {
3161 auto lpm = multiply_cast(this->left->get_power_base());
3162 if (lpm.get()) {
3163 if (lpm->get_left()->is_match(rm->get_left()->get_power_base())) {
3164 return (pow(rm->get_left()->get_power_base(),
3165 this->left->get_power_exponent() -
3166 rm->get_left()->get_power_exponent()) *
3167 pow(lpm->get_right(),
3168 this->left->get_power_exponent())) /
3169 rm->get_right();
3170 } else if (lpm->get_right()->is_match(rm->get_left()->get_power_base())) {
3171 return (pow(rm->get_left()->get_power_base(),
3172 this->left->get_power_exponent() -
3173 rm->get_left()->get_power_exponent()) *
3174 pow(lpm->get_left(),
3175 this->left->get_power_exponent())) /
3176 rm->get_right();
3177 } else if (lpm->get_left()->is_match(rm->get_right()->get_power_base())) {
3178 return (pow(rm->get_right()->get_power_base(),
3179 this->left->get_power_exponent() -
3180 rm->get_right()->get_power_exponent()) *
3181 pow(lpm->get_right(),
3182 this->left->get_power_exponent())) /
3183 rm->get_left();
3184 } else if (lpm->get_right()->is_match(rm->get_right()->get_power_base())) {
3185 return (pow(rm->get_right()->get_power_base(),
3186 this->left->get_power_exponent() -
3187 rm->get_right()->get_power_exponent()) *
3188 pow(lpm->get_left(),
3189 this->left->get_power_exponent())) /
3190 rm->get_left();
3191 }
3192 }
3193 }
3194
3195 if (lm.get()) {
3196// a*(b*c)/c -> a*b
3197// a*(c*b)/c -> a*b
3198// (a*c)*b/c -> a*b
3199// (c*a)*b/c -> a*b
3200 auto lmrm = multiply_cast(lm->get_right());
3201 auto lmlm = multiply_cast(lm->get_left());
3202 if (lmrm.get()) {
3203 if (is_variable_combinable(lmrm->get_right(),
3204 this->right)) {
3205 return lm->get_left()*lmrm->get_left() *
3206 (lmrm->get_right()/this->right);
3207 } else if (is_variable_combinable(lmrm->get_left(),
3208 this->right)) {
3209 return lm->get_left()*lmrm->get_right() *
3210 (lmrm->get_left()/this->right);
3211 }
3212 } else if (lmlm.get()) {
3213 if (is_variable_combinable(lmlm->get_right(),
3214 this->right)) {
3215 return lm->get_right()*lmlm->get_left() *
3216 (lmlm->get_right()/this->right);
3217 } else if (is_variable_combinable(lmlm->get_left(),
3218 this->right)) {
3219 return lm->get_right()*lmlm->get_right() *
3220 (lmlm->get_left()/this->right);
3221 }
3222 }
3223
3224// (f*(a*b)^c)/(a^d) = f*a^(c - d)*b^c
3225// (f*(b*a)^c)/(a^d) = f*a^(c - d)*b^c
3226// (((a*b)^c)*f)/(a^d) = f*a^(c - d)*b^c
3227// (((b*a)^c)*f)/(a^d) = f*a^(c - d)*b^c
3228 auto lmlp = pow_cast(lm->get_left());
3229 auto lmrp = pow_cast(lm->get_right());
3230 if (lmlp.get()) {
3231 auto lmlpm = multiply_cast(lmlp->get_power_base());
3232 if (lmlpm.get()) {
3233 if (lmlpm->get_left()->is_match(this->right->get_power_base())) {
3234 return lm->get_right() *
3235 pow(this->right->get_power_base(),
3236 lmlp->get_power_exponent() -
3237 this->right->get_power_exponent()) *
3238 pow(lmlpm->get_right(),
3239 lmlp->get_power_exponent());
3240 } else if (lmlpm->get_right()->is_match(this->right->get_power_base())) {
3241 return lm->get_right() *
3242 pow(this->right->get_power_base(),
3243 lmlp->get_power_exponent() -
3244 this->right->get_power_exponent()) *
3245 pow(lmlpm->get_left(),
3246 lmlp->get_power_exponent());
3247 }
3248 }
3249 } else if (lmrp.get()) {
3250 auto lmrpm = multiply_cast(lmrp->get_power_base());
3251 if (lmrpm.get()) {
3252 if (lmrpm->get_left()->is_match(this->right->get_power_base())) {
3253 return lm->get_left() *
3254 pow(this->right->get_power_base(),
3255 lmrp->get_power_exponent() -
3256 this->right->get_power_exponent()) *
3257 pow(lmrpm->get_right(),
3258 lmrp->get_power_exponent());
3259 } else if (lmrpm->get_right()->is_match(this->right->get_power_base())) {
3260 return lm->get_left() *
3261 pow(this->right->get_power_base(),
3262 lmrp->get_power_exponent() -
3263 this->right->get_power_exponent()) *
3264 pow(lmrpm->get_left(),
3265 lmrp->get_power_exponent());
3266 }
3267 }
3268 }
3269 }
3270
3271// f*(a*b)^c/((a^d)*e) = a^(c - d)*b^c/e
3272// f*(b*a)^c/((a^d)*e) = a^(c - d)*b^c/e
3273// f*(a*b)^c/(e*(a^d)) = a^(c - d)*b^c/e
3274// f*(b*a)^c/(e*(a^d)) = a^(c - d)*b^c/e
3275// (a*b)^c*f/((a^d)*e) = a^(c - d)*b^c/e
3276// (b*a)^c*f/((a^d)*e) = a^(c - d)*b^c/e
3277// (a*b)^c*f/(e*(a^d)) = a^(c - d)*b^c/e
3278// (b*a)^c*f/(e*(a^d)) = a^(c - d)*b^c/e
3279 if (lm.get() && rm.get()) {
3280 auto lmlp = pow_cast(lm->get_left());
3281 auto lmrp = pow_cast(lm->get_right());
3282 if (lmlp.get()) {
3283 auto lmlpm = multiply_cast(lmlp->get_power_base());
3284 if (lmlpm.get()) {
3285 if (lmlpm->get_left()->is_match(rm->get_left()->get_power_base())) {
3286 return lm->get_right() *
3287 (pow(rm->get_left()->get_power_base(),
3288 lmlp->get_power_exponent() -
3289 rm->get_left()->get_power_exponent())) *
3290 pow(lmlpm->get_right(),
3291 lmlp->get_power_exponent()) /
3292 rm->get_right();
3293 } else if (lmlpm->get_right()->is_match(rm->get_left()->get_power_base())) {
3294 return lm->get_right() *
3295 (pow(rm->get_left()->get_power_base(),
3296 lmlp->get_power_exponent() -
3297 rm->get_left()->get_power_exponent())) *
3298 pow(lmlpm->get_left(),
3299 lmlp->get_power_exponent()) /
3300 rm->get_right();
3301 } else if (lmlpm->get_left()->is_match(rm->get_right()->get_power_base())) {
3302 return lm->get_right() *
3303 (pow(rm->get_left()->get_power_base(),
3304 lmlp->get_power_exponent() -
3305 rm->get_right()->get_power_exponent())) *
3306 pow(lmlpm->get_right(),
3307 lmlp->get_power_exponent()) /
3308 rm->get_left();
3309 } else if (lmlpm->get_right()->is_match(rm->get_right()->get_power_base())) {
3310 return lm->get_right() *
3311 (pow(rm->get_left()->get_power_base(),
3312 lmlp->get_power_exponent() -
3313 rm->get_right()->get_power_exponent())) *
3314 pow(lmlpm->get_left(),
3315 lmlp->get_power_exponent()) /
3316 rm->get_left();
3317 }
3318 }
3319 } else if (lmrp.get()) {
3320 auto lmrpm = multiply_cast(lmrp->get_power_base());
3321 if (lmrpm.get()) {
3322 if (lmrpm->get_left()->is_match(rm->get_left()->get_power_base())) {
3323 return lm->get_left() *
3324 (pow(rm->get_left()->get_power_base(),
3325 lmrp->get_power_exponent() -
3326 rm->get_left()->get_power_exponent())) *
3327 pow(lmrpm->get_right(),
3328 lmrp->get_power_exponent()) /
3329 rm->get_right();
3330 } else if (lmrpm->get_right()->is_match(rm->get_left()->get_power_base())) {
3331 return lm->get_left() *
3332 (pow(rm->get_left()->get_power_base(),
3333 lmrp->get_power_exponent() -
3334 rm->get_left()->get_power_exponent())) *
3335 pow(lmrpm->get_left(),
3336 lmrp->get_power_exponent()) /
3337 rm->get_right();
3338 } else if (lmrpm->get_left()->is_match(rm->get_right()->get_power_base())) {
3339 return lm->get_left() *
3340 (pow(rm->get_left()->get_power_base(),
3341 lmrp->get_power_exponent() -
3342 rm->get_right()->get_power_exponent())) *
3343 pow(lmrpm->get_right(),
3344 lmrp->get_power_exponent()) /
3345 rm->get_left();
3346 } else if (lmrpm->get_right()->is_match(rm->get_right()->get_power_base())) {
3347 return lm->get_left() *
3348 (pow(rm->get_left()->get_power_base(),
3349 lmrp->get_power_exponent() -
3350 rm->get_right()->get_power_exponent())) *
3351 pow(lmrpm->get_left(),
3352 lmrp->get_power_exponent()) /
3353 rm->get_left();
3354 }
3355 }
3356 }
3357 }
3358
3359// exp(a)/exp(b) -> exp(a - b)
3360 auto lexp = exp_cast(this->left);
3361 auto rexp = exp_cast(this->right);
3362 if (lexp.get() && rexp.get()) {
3363 return exp(lexp->get_arg() - rexp->get_arg());
3364 }
3365
3366// (c*exp(a))/exp(b) -> c*(exp(a)/exp(b))
3367// (exp(a)*c)/exp(b) -> c*(exp(a)/exp(b))
3368 if (rexp.get() && lm.get()) {
3369 auto lmre = exp_cast(lm->get_right());
3370 if (lmre.get()) {
3371 return lm->get_left()*(lm->get_right()/this->right);
3372 }
3373 auto lmle = exp_cast(lm->get_left());
3374 if (lmle.get()) {
3375 return lm->get_right()*(lm->get_left()/this->right);
3376 }
3377 }
3378// ((c*exp(a))*d)/exp(b)
3379// ((exp(a)*c)*d)/exp(b)
3380// (c*(exp(a)*d))/exp(b)
3381// (c*(d*exp(a)))/exp(b)
3382 if (rexp.get() && lm.get()) {
3383 auto lmlm = multiply_cast(lm->get_left());
3384 auto lmrm = multiply_cast(lm->get_right());
3385
3386 if (lmlm.get()) {
3387 if (exp_cast(lmlm->get_right()).get()) {
3388 return lmlm->get_left()*lm->get_right() *
3389 (lmlm->get_right()/this->right);
3390 } else if (exp_cast(lmlm->get_left()).get()) {
3391 return lmlm->get_right()*lm->get_right() *
3392 (lmlm->get_left()/this->right);
3393 }
3394 } else if (lmrm.get()) {
3395 if (exp_cast(lmrm->get_right()).get()) {
3396 return lmrm->get_left()*lm->get_left() *
3397 (lmrm->get_right()/this->right);
3398 } else if (exp_cast(lmrm->get_left()).get()) {
3399 return lmrm->get_right()*lm->get_left() *
3400 (lmrm->get_left()/this->right);
3401 }
3402 }
3403 }
3404
3405// exp(a)/(c*exp(b)) -> (exp(a)/exp(b))/c
3406// exp(a)/(exp(b)*c) -> (exp(a)/exp(b))/c
3407 if (lexp.get() && rm.get()) {
3408 auto rmre = exp_cast(rm->get_right());
3409 if (rmre.get()) {
3410 return (this->left/rm->get_right())/rm->get_left();
3411 }
3412 auto rmle = exp_cast(rm->get_left());
3413 if (rmle.get()) {
3414 return (this->left/rm->get_left())/rm->get_right();
3415 }
3416 }
3417
3418// (c*exp(a))/(d*exp(b)) -> (c/d)*(exp(a)/exp(b))
3419// (c*exp(a))/(exp(b)*d) -> (c/d)*(exp(a)/exp(b))
3420// (exp(a)*c)/(d*exp(b)) -> (c/d)*(exp(a)/exp(b))
3421// (exp(a)*c)/(exp(b)*d) -> (c/d)*(exp(a)/exp(b))
3422 if (lm.get() && rm.get()) {
3423 auto lmre = exp_cast(lm->get_right());
3424 if (lmre.get()) {
3425 auto rmre = exp_cast(rm->get_right());
3426 if (rmre.get()) {
3427 return (lm->get_left()/rm->get_left()) *
3428 (lm->get_right()/rm->get_right());
3429 }
3430 auto rmle = exp_cast(rm->get_left());
3431 if (rmle.get()) {
3432 return (lm->get_left()/rm->get_right()) *
3433 (lm->get_right()/rm->get_left());
3434 }
3435 }
3436 auto lmle = exp_cast(lm->get_left());
3437 if (lmle.get()) {
3438 auto rmre = exp_cast(rm->get_right());
3439 if (rmre.get()) {
3440 return (lm->get_right()/rm->get_left()) *
3441 (lm->get_left()/rm->get_right());
3442 }
3443 auto rmle = exp_cast(rm->get_left());
3444 if (rmle.get()) {
3445 return (lm->get_right()/rm->get_right()) *
3446 (lm->get_left()/rm->get_left());
3447 }
3448 }
3449 }
3450
3451// exp(a)/(c/exp(b)) -> (exp(a)*exp(b))/c
3452// exp(a)/(exp(b)/c) -> c*(exp(a)/exp(b))
3453 if (rd.get() && lexp.get()) {
3454 auto rdre = exp_cast(rd->get_right());
3455 if (rdre.get()) {
3456 return (this->left*rd->get_right())/rd->get_left();
3457 }
3458 auto rdle = exp_cast(rd->get_left());
3459 if (rdle.get()) {
3460 return rd->get_right()*(this->left/rd->get_left());
3461 }
3462 }
3463
3464// (c/exp(a))/exp(b) -> c/(exp(a)*exp(b))
3465// (exp(a)/c)/exp(b) -> exp(a)/(c*exp(b))
3466// (c/exp(a))/(d/exp(b)) -> (c*exp(b))/(d*exp(a))
3467// (c/exp(a))/(exp(b)/d) -> (c*d)/(exp(b)*exp(a))
3468// (exp(a)/c)/(d/exp(b)) -> (exp(a)*exp(b))/(d*c)
3469// (exp(a)/c)/(exp(b)/d) -> (exp(a)*d)/(exp(b)*c)
3470// Note cases like this are already transformed by the (a/b)/c -> a/(b*c)
3471// above.
3472
3473 return this->shared_from_this();
3474 }
3475
3476//------------------------------------------------------------------------------
// Derivative of this quotient node with respect to x.
//
// Applies the quotient rule d(l/r)/dx = (dl/dx)/r - l*(dr/dx)/(r*r) and
// memoizes the result in df_cache.
3483//------------------------------------------------------------------------------
3486 if (this->is_match(x)) {
// Differentiating a node with respect to itself gives one.
3487 return one<T, SAFE_MATH> ();
3488 }
3489
// The cache key is the raw address of x.
// NOTE(review): address keys assume x stays alive while the entry is cached;
// confirm nodes cannot be freed and their addresses reused.
3490 const size_t hash = reinterpret_cast<size_t> (x.get());
3491 if (this->df_cache.find(hash) == this->df_cache.end()) {
// Quotient rule: dl/r - l*dr/(r*r).
3492 this->df_cache[hash] = this->left->df(x)/this->right
3493 - this->left*this->right->df(x)/(this->right*this->right);
3494 }
3495 return this->df_cache[hash];
3496 }
3497
3498//------------------------------------------------------------------------------
// Emit this quotient into the generated source and return this node.
//
// Both operands are compiled first (left, then right). A node that already has
// an entry in registers emits nothing. Otherwise a line of the form
//   const <type> <name> = <l>/<r>
// is written, where <type> comes from jit::add_type<T> and <name> from
// jit::to_string('r', this); endline() finishes the line.
3506//------------------------------------------------------------------------------
3508 compile(std::ostringstream &stream,
3509 jit::register_map &registers,
3511 const jit::register_usage &usage) {
3512 if (registers.find(this) == registers.end()) {
3513 shared_leaf<T, SAFE_MATH> l = this->left->compile(stream,
3514 registers,
3515 indices,
3516 usage);
3517 shared_leaf<T, SAFE_MATH> r = this->right->compile(stream,
3518 registers,
3519 indices,
3520 usage);
3521
3522 registers[this] = jit::to_string('r', this);
3523 stream << " const ";
3524 jit::add_type<T> (stream);
3525 stream << " " << registers[this] << " = ";
// SAFE_MATH prefixes "<l> == 0 ? 0 : " so that a zero numerator yields zero
// even when the denominator is also zero (avoids 0/0). Complex types spell
// zero as <type>(0, 0).
3526 if constexpr (SAFE_MATH) {
3527 stream << registers[l.get()] << " == ";
3528 if constexpr (jit::complex_scalar<T>) {
3529 jit::add_type<T> (stream);
3530 stream << "(0, 0)";
3531 } else {
3532 stream << "0";
3533 }
3534 stream << " ? ";
3535 if constexpr (jit::complex_scalar<T>) {
3536 jit::add_type<T> (stream);
3537 stream << "(0, 0)";
3538 } else {
3539 stream << "0";
3540 }
3541 stream << " : ";
3542 }
3543 stream << registers[l.get()] << "/"
3544 << registers[r.get()];
3545 this->endline(stream, usage);
3546 }
3547 return this->shared_from_this();
3548 }
3549
3550//------------------------------------------------------------------------------
// Structural equality test for a quotient.
//
// True when x is this very node, or when x is also a divide node whose
// numerator and denominator match this node's. Operands are compared in order,
// since division is not commutative.
3555//------------------------------------------------------------------------------
3557 if (this == x.get()) {
3558 return true;
3559 }
3560
3561 auto x_cast = divide_cast(x);
3562 if (x_cast.get()) {
3563 return this->left->is_match(x_cast->get_left()) &&
3564 this->right->is_match(x_cast->get_right());
3565 }
3566
3567 return false;
3568 }
3569
3570//------------------------------------------------------------------------------
3572//------------------------------------------------------------------------------
3573 virtual void to_latex() const {
3574 std::cout << "\\frac{";
3575 this->left->to_latex();
3576 std::cout << "}{";
3577 this->right->to_latex();
3578 std::cout << "}";
3579 }
3580
3581//------------------------------------------------------------------------------
// Return an expression equivalent to this quotient with pseudo variables
// removed.
//
// When has_pseudo() reports any, both operands are stripped recursively and
// recombined with operator/. Otherwise this node is returned unchanged.
3585//------------------------------------------------------------------------------
3587 if (this->has_pseudo()) {
3588 return this->left->remove_pseudo() /
3589 this->right->remove_pseudo();
3590 }
3591 return this->shared_from_this();
3592 }
3593
3594//------------------------------------------------------------------------------
3600//------------------------------------------------------------------------------
3601 virtual shared_leaf<T, SAFE_MATH> to_vizgraph(std::stringstream &stream,
3602 jit::register_map &registers) {
3603 if (registers.find(this) == registers.end()) {
3604 const std::string name = jit::to_string('r', this);
3605 registers[this] = name;
3606 stream << " " << name
3607 << " [label = \"\\\\\", shape = oval, style = filled, fillcolor = blue, fontcolor = white];" << std::endl;
3608
3609 auto l = this->left->to_vizgraph(stream, registers);
3610 stream << " " << name << " -- " << registers[l.get()] << ";" << std::endl;
3611 auto r = this->right->to_vizgraph(stream, registers);
3612 stream << " " << name << " -- " << registers[r.get()] << ";" << std::endl;
3613 }
3614
3615 return this->shared_from_this();
3616 }
3617 };
3618
3619//------------------------------------------------------------------------------
// Build l/r: construct a divide_node from l and r, reduce it, then look the
// reduced node up in leaf_node<T, SAFE_MATH>::caches.nodes.
3627//------------------------------------------------------------------------------
3628 template<jit::float_scalar T, bool SAFE_MATH=false>
3631 auto temp = std::make_shared<divide_node<T, SAFE_MATH>> (l, r)->reduce();
3632// Test for hash collisions.
// Probe the cache starting at the reduced node's hash. An empty slot ends the
// search and the freshly built node is returned. A slot holding a structurally
// matching node (is_match) returns that cached node instead, so equal
// expressions share one node. Any other occupant is a collision, and probing
// continues.
3633 for (size_t i = temp->get_hash();
3635 if (leaf_node<T, SAFE_MATH>::caches.nodes.find(i) ==
3636 leaf_node<T, SAFE_MATH>::caches.nodes.end()) {
3638 return temp;
3639 } else if (temp->is_match(leaf_node<T, SAFE_MATH>::caches.nodes[i])) {
3640 return leaf_node<T, SAFE_MATH>::caches.nodes[i];
3641 }
3642 }
// Leaving the loop is not expected; the fallback below asserts on compilers
// other than GCC/Clang.
3643#if defined(__clang__) || defined(__GNUC__)
3645#else
3646 assert(false && "Should never reach.");
3647#endif
3648 }
3649
3650//------------------------------------------------------------------------------
3661//------------------------------------------------------------------------------
3662 template<jit::float_scalar T, bool SAFE_MATH=false>
3667
3668//------------------------------------------------------------------------------
3680//------------------------------------------------------------------------------
3681 template<jit::float_scalar T, jit::float_scalar L, bool SAFE_MATH=false>
3686
3687//------------------------------------------------------------------------------
3699//------------------------------------------------------------------------------
3700 template<jit::float_scalar T, jit::float_scalar R, bool SAFE_MATH=false>
3705
3707 template<jit::float_scalar T, bool SAFE_MATH=false>
3708 using shared_divide = std::shared_ptr<divide_node<T, SAFE_MATH>>;
3709
3710//------------------------------------------------------------------------------
// Downcast x to a divide node.
//
// Uses std::dynamic_pointer_cast, so the result is an empty pointer when x is
// not a divide_node. Callers test it with .get() before use.
3718//------------------------------------------------------------------------------
3719 template<jit::float_scalar T, bool SAFE_MATH=false>
3721 return std::dynamic_pointer_cast<divide_node<T, SAFE_MATH>> (x);
3722 }
3723
3724//******************************************************************************
3725// fused multiply add node.
3726//******************************************************************************
3727//------------------------------------------------------------------------------
3734//------------------------------------------------------------------------------
3735 template<jit::float_scalar T, bool SAFE_MATH=false>
3736 class fma_node final : public triple_node<T, SAFE_MATH> {
3737 private:
3738//------------------------------------------------------------------------------
3745//------------------------------------------------------------------------------
3747 reduce_nested_fma(shared_subtract<T, SAFE_MATH> sub) {
3748 auto temp = fma_cast(this->left);
3749 if (temp.get()) {
3750 if (is_constant_combinable(sub->get_right(), temp->get_left()) &&
3751 is_constant_combinable(sub->get_right(), temp->get_right()) &&
3752 is_constant_combinable(this->right, temp->get_right()) &&
3753 temp->get_middle()->is_match(sub->get_left())) {
3754 return fma(fma(temp->get_left(),
3755 sub->get_left(),
3756 temp->get_right() - temp->get_left()*sub->get_right()),
3757 sub->get_left(),
3758 this->right - temp->get_right()*sub->get_right());
3759 } else {
3760 if (temp->get_middle()->is_match(sub->get_left()) &&
3761 is_constant_combinable(sub->get_right(), this->right)) {
3762 auto temp2 = temp->reduce_nested_fma(sub);
3763 if (temp2.get()) {
3764 return fma(temp2,
3765 sub->get_left(),
3766 this->right - temp->get_right()*sub->get_right());
3767 }
3768 }
3769 }
3770 }
3771 return this->shared_from_this();
3772 }
3773
3774//------------------------------------------------------------------------------
// Build a string identifier for an fma from its three operand addresses:
// "fma" followed by jit::format_to_string of the addresses of l, m and r.
3781//------------------------------------------------------------------------------
3782 static std::string to_string(leaf_node<T, SAFE_MATH> *l,
3785 return "fma" + jit::format_to_string(reinterpret_cast<size_t> (l))
3786 + jit::format_to_string(reinterpret_cast<size_t> (m))
3787 + jit::format_to_string(reinterpret_cast<size_t> (r));
3788 }
3789
3790 public:
3791//------------------------------------------------------------------------------
3797//------------------------------------------------------------------------------
3804
3805//------------------------------------------------------------------------------
// Evaluate left*middle + right, combining the operand buffers with
// backend::fma.
3811//------------------------------------------------------------------------------
3813 backend::buffer<T> l_result = this->left->evaluate();
3814 backend::buffer<T> r_result = this->right->evaluate();
3815
3816// If every element of the left operand is zero, left*middle contributes
3817// nothing: return the right operand and skip evaluating the middle.
// NOTE(review): under IEEE arithmetic 0*inf and 0*NaN are NaN, so this
// shortcut can return a finite value where a full fma would not. Confirm
// that is intended.
3818 if (l_result.is_zero()) {
3819 return r_result;
3820 }
3821
3822 backend::buffer<T> m_result = this->middle->evaluate();
3823 return backend::fma(l_result, m_result, r_result);
3824 }
3825
3826//------------------------------------------------------------------------------
3830//------------------------------------------------------------------------------
3832 auto l = constant_cast(this->left);
3833 auto m = constant_cast(this->middle);
3834 auto r = constant_cast(this->right);
3835
3836 if ((l.get() && l->is(0)) ||
3837 (m.get() && m->is(0))) {
3838 return this->right;
3839 } else if (r.get() && r->is(0)) {
3840 return this->left*this->middle;
3841 } else if (l.get() && m.get() && r.get()) {
3842 return constant<T, SAFE_MATH> (this->evaluate());
3843 } else if (l.get() && m.get()) {
3844 return this->left*this->middle + this->right;
3845 } else if (l.get() && l->is(-1)) {
3846 return this->right - this->middle;
3847 } else if (m.get() && m->is(-1)) {
3848 return this->right - this->left;
3849 } else if (l.get() && l->is(1)) {
3850 return this->middle + this->right;
3851 } else if (m.get() && m->is(1)) {
3852 return this->left + this->right;
3853 }
3854
3855// Check if the left and middle are combinable. This will be constant merged in
3856// multiply reduction.
3857 if (is_constant_combinable(this->left, this->middle) ||
3858 is_variable_combinable(this->left, this->middle)) {
3859 return (this->left*this->middle) + this->right;
3860 }
3861
3862// fma(c2,c1,a) -> fma(c1,c2,a)
3863 if (is_constant_promotable(this->middle,
3864 this->left)) {
3865 return fma(this->middle, this->left, this->right);
3866 }
3867
3868// fma(a,b,a) -> a*(1 + b)
3869// fma(b,a,a) -> a*(1 + b)
3870 if (this->left->is_match(this->right)) {
3871 return this->left*(1.0 + this->middle);
3872 } else if (this->middle->is_match(this->right)) {
3873 return this->middle*(1.0 + this->left);
3874 }
3875
3876// fma(c1,c2 + a,c3) -> fma(c4,a,c5)
3877 auto ma = add_cast(this->middle);
3878 if (ma.get()) {
3879 if (is_constant_combinable(this->left, ma->get_left()) &&
3880 is_constant_combinable(this->left, this->right)) {
3881 return fma(this->left,
3882 ma->get_right(),
3883 fma(this->left, ma->get_left(), this->right));
3884 }
3885 }
3886
3887// fma(c1,c2 - a,c3) -> fma(-c1,a,c1*c2 + c3)
3888// fma(c1,a - c2,c3) -> fma(c1,a,c3 - c1*c2)
3889 auto ms = subtract_cast(this->middle);
3890 if (ms.get()) {
3891 if (is_constant_combinable(this->left, ms->get_left()) &&
3892 is_constant_combinable(this->left, this->right)) {
3893 return fma(-this->left, ms->get_right(),
3894 this->left*ms->get_left() + this->right);
3895 } else if (is_constant_combinable(this->left, ms->get_right()) &&
3896 is_constant_combinable(this->left, this->right)) {
3897 return fma(this->left, ms->get_left(),
3898 this->right - this->left*ms->get_right());
3899 }
3900
3901 auto temp = this->reduce_nested_fma(ms);
3902 if (temp.get() != this) {
3903 return temp;
3904 }
3905 }
3906
3907// Common factor reduction. If the left and right are both multiply nodes check
3908// for a common factor. So you can change a*b + (a*c) -> a*(b + c).
3909 auto lm = multiply_cast(this->left);
3910 auto mm = multiply_cast(this->middle);
3911 auto rm = multiply_cast(this->right);
3912 if (rm.get()) {
3913 if (rm->get_left()->is_match(this->left)) {
3914 return this->left*(this->middle + rm->get_right());
3915 } else if (rm->get_left()->is_match(this->middle)) {
3916 return this->middle*(this->left + rm->get_right());
3917 } else if (rm->get_right()->is_match(this->left)) {
3918 return this->left*(this->middle + rm->get_left());
3919 } else if (rm->get_right()->is_match(this->middle)) {
3920 return this->middle*(this->left + rm->get_left());
3921 }
3922
3923// Change case of
3924// fma(a,b,-c1*b) -> a*b - c1*b
3925 auto rmlc = constant_cast(rm->get_left());
3926 if (rmlc.get() && rmlc->evaluate().is_negative()) {
3927 return this->left*this->middle -
3928 (-1.0*rm->get_left())*rm->get_right();
3929 }
3930
3931// Change cases like
3932// fma(c1,a,c2*b) -> c1*fma(c3,b,a)
3933// fma(a,c1,c2*b) -> c1*fma(c3,b,a)
3934// fma(c1,a,b*c2) -> c1*fma(c3,b,a)
3935// fma(a,c1,b*c2) -> c1*fma(c3,b,a)
3936 if (is_constant_combinable(this->left,
3937 rm->get_left()) &&
3938 !this->left->has_constant_zero()) {
3939 auto temp = rm->get_left()/this->left;
3940 if (temp->is_normal()) {
3941 return this->left*fma(temp,
3942 rm->get_right(),
3943 this->middle);
3944 }
3945 }
3947 rm->get_left()) &&
3948 !this->middle->has_constant_zero()) {
3949 auto temp = rm->get_left()/this->middle;
3950 if (temp->is_normal()) {
3951 return this->middle*fma(temp,
3952 rm->get_right(),
3953 this->left);
3954 }
3955 }
3956 if (is_constant_combinable(this->left,
3957 rm->get_right()) &&
3958 !this->left->has_constant_zero()) {
3959 auto temp = rm->get_right()/this->left;
3960 if (temp->is_normal()) {
3961 return this->left*fma(temp,
3962 rm->get_left(),
3963 this->middle);
3964 }
3965 }
3967 rm->get_right()) &&
3968 !this->middle->has_constant_zero()) {
3969 auto temp = rm->get_right()/this->middle;
3970 if (temp->is_normal()) {
3971 return this->middle*fma(temp,
3972 rm->get_left(),
3973 this->left);
3974 }
3975 }
3976
3977// fma(a,b*c,b*d) -> b*fma(a,c,d)
3978// fma(a,c*b,b*d) -> b*fma(a,c,d)
3979// fma(a,b*c,d*b) -> b*fma(a,c,d)
3980// fma(a,c*b,d*b) -> b*fma(a,c,d)
3981 if (mm.get()) {
3982 if (mm->get_left()->is_match(rm->get_left())) {
3983 return mm->get_left()*fma(this->left,
3984 mm->get_right(),
3985 rm->get_right());
3986 } else if (mm->get_left()->is_match(rm->get_right())) {
3987 return mm->get_left()*fma(this->left,
3988 mm->get_right(),
3989 rm->get_left());
3990 } else if (mm->get_right()->is_match(rm->get_left())) {
3991 return mm->get_right()*fma(this->left,
3992 mm->get_left(),
3993 rm->get_right());
3994 } else if (mm->get_right()->is_match(rm->get_right())) {
3995 return mm->get_right()*fma(this->left,
3996 mm->get_left(),
3997 rm->get_left());
3998 }
3999 }
4000
4001// Convert fma(a*b,c,d*e) -> fma(d,e,a*b*c)
4002// Convert fma(a,b*c,d*e) -> fma(d,e,a*b*c)
4003 if ((lm.get() || mm.get()) &&
4004 (this->left->get_complexity() + this->middle->get_complexity() >
4005 this->right->get_complexity())) {
4006 return fma(rm->get_left(), rm->get_right(),
4007 this->left*this->middle);
4008 }
4009 }
4010
4011// Handle cases like.
4012// fma(c1*a,b,c2*d) -> c1*(a*b + c2/c1*d)
4013// fma(a*c1,b,c2*d) -> c1*(a*b + c2/c1*d)
4014// fma(c1*a,b,d*c2*d) -> c1*(a*b + c2/c1*d)
4015// fma(a*c1,b,d*c2*d) -> c1*(a*b + c2/c1*d)
4016 if (lm.get() && rm.get()) {
4017 if (is_constant_combinable(rm->get_left(),
4018 lm->get_left()) &&
4019 !lm->get_left()->has_constant_zero()) {
4020 auto temp = rm->get_left()/lm->get_left();
4021 if (temp->is_normal()){
4022 return lm->get_left()*fma(lm->get_right(),
4023 this->middle,
4024 temp*rm->get_right());
4025 }
4026 }
4027 if (is_constant_combinable(rm->get_left(),
4028 lm->get_right()) &&
4029 !lm->get_right()->has_constant_zero()) {
4030 auto temp = rm->get_left()/lm->get_right();
4031 if (temp->is_normal()){
4032 return lm->get_right()*fma(lm->get_left(),
4033 this->middle,
4034 temp*rm->get_right());
4035 }
4036 }
4037 if (is_constant_combinable(rm->get_right(),
4038 lm->get_left()) &&
4039 !lm->get_left()->has_constant_zero()) {
4040 auto temp = rm->get_right()/lm->get_left();
4041 if (temp->is_normal()) {
4042 return lm->get_left()*fma(lm->get_right(),
4043 this->middle,
4044 temp*rm->get_left());
4045 }
4046 }
4047 if (is_constant_combinable(rm->get_right(),
4048 lm->get_right()) &&
4049 !lm->get_right()->has_constant_zero()) {
4050 auto temp = rm->get_right()/lm->get_right();
4051 if (temp->is_normal()) {
4052 return lm->get_right()*fma(lm->get_left(),
4053 this->middle,
4054 temp*rm->get_left());
4055 }
4056 }
4057 }
4058
4059// Move constant multiplies to the left.
4060 if (lm.get()) {
4061// fma(c1*a,b,c) -> fma(c1,a*b,c)
4062 if (is_constant_promotable(lm->get_left(),
4063 lm->get_right())) {
4064 return fma(lm->get_left(),
4065 lm->get_right()*this->middle,
4066 this->right);
4067 }
4068 } else if (mm.get()) {
4069// fma(c1,c2*a,b) -> fma(c3,a,b)
4070// fma(c1,a*c2,b) -> fma(c3,a,b)
4071// fma(a,c1*b,c) -> fma(c1,a*b,c)
4072 if (is_constant_combinable(this->left,
4073 mm->get_left())) {
4074 auto temp = this->left*mm->get_left();
4075 if (temp->is_normal()) {
4076 return fma(temp,
4077 mm->get_right(),
4078 this->right);
4079 }
4080 }
4081 if (is_constant_combinable(this->left,
4082 mm->get_right())) {
4083 auto temp = this->left*mm->get_right();
4084 if (temp->is_normal()) {
4085 return fma(temp,
4086 mm->get_left(),
4087 this->right);
4088 }
4089 }
4090 if (is_constant_promotable(mm->get_left(),
4091 this->left)) {
4092 return fma(mm->get_left(),
4093 this->left*mm->get_right(),
4094 this->right);
4095 }
4096 }
4097
4098// fma(a,b*c,b) -> b*fma(a,c,1)
4099 if (mm.get()) {
4100 if (mm->get_left()->is_match(this->right)) {
4101 return mm->get_left()*fma(this->left,
4102 mm->get_right(),
4103 1.0);
4104 } else if (mm->get_right()->is_match(this->right)) {
4105 return mm->get_right()*fma(this->left,
4106 mm->get_left(),
4107 1.0);
4108 }
4109 }
4110
4111// fma(c1,a,c2/b) -> c1*(a + c3/b)
4112// fma(a,c1,c2/b) -> c1*(a + c3/b)
4113 auto rd = divide_cast(this->right);
4114 if (rd.get()) {
4115 if (is_constant_combinable(this->left,
4116 rd->get_left()) &&
4117 !this->left->has_constant_zero()) {
4118 auto temp = rd->get_left()/this->left;
4119 if (temp->is_normal()) {
4120 return this->left*(this->middle +
4121 temp/rd->get_right());
4122 }
4123 }
4125 rd->get_left()) &&
4126 !this->middle->has_constant_zero()) {
4127 auto temp = rd->get_left()/this->middle;
4128 if (temp->is_normal()) {
4129 return this->middle*(this->left +
4130 temp/rd->get_right());
4131 }
4132 }
4133 }
4134
4135// Reduce fma(a/b,b,c) -> a + c
4136// Reduce fma(a,b/a,c) -> b + c
4137 auto ld = divide_cast(this->left);
4138 if (ld.get() && ld->get_right()->is_match(this->middle)) {
4139 return ld->get_left() + this->right;
4140 }
4141 auto md = divide_cast(this->middle);
4142 if (md.get() && md->get_right()->is_match(this->left)) {
4143 return md->get_left() + this->right;
4144 }
4145
4146// Common denominator reductions.
4147 if (ld.get() && rd.get()) {
4148// fma(b/c,a,b,d) -> b(a/c + 1/d)
4149 if (ld->get_left()->is_match(rd->get_left())) {
4150 return ld->get_left()*(this->middle/ld->get_right() +
4151 1.0/rd->get_right());
4152 }
4153
4154// fma(a/(b*c),d,e/c) -> fma(a,d,e*b)/(b*c)
4155// fma(a/(c*b),d,e/c) -> fma(a,d,e*b)/(c*b)
4156// fma(a/c,d,e/(c*b)) -> fma(a*b,d,e)/(b*c)
4157// fma(a/c,d,e/(b*c)) -> fma(a*b,d,e)/(c*b)
4158 auto ldrm = multiply_cast(ld->get_right());
4159 auto rdrm = multiply_cast(rd->get_right());
4160
4161 if (ldrm.get()) {
4162 if (ldrm->get_right()->is_match(rd->get_right())) {
4163 return fma(ld->get_left(), this->middle,
4164 rd->get_left()*ldrm->get_left()) /
4165 ld->get_right();
4166 } else if (ldrm->get_left()->is_match(rd->get_right())) {
4167 return fma(ld->get_left(), this->middle,
4168 rd->get_left()*ldrm->get_right()) /
4169 ld->get_right();
4170 }
4171 } else if (rdrm.get()) {
4172 if (rdrm->get_right()->is_match(ld->get_right())) {
4173 return fma(ld->get_left()*rdrm->get_left(),
4174 this->middle, rd->get_left()) /
4175 rd->get_right();
4176 } else if (rdrm->get_left()->is_match(ld->get_right())) {
4177 return fma(ld->get_left()*rdrm->get_right(),
4178 this->middle, rd->get_left()) /
4179 rd->get_right();
4180 }
4181 }
4182 } else if (md.get() && rd.get()) {
4183// fma(a,d/(b*c),e/c) -> fma(a,d,e*b)/(b*c)
4184// fma(a,d/(c*b),e/c) -> fma(a,d,e*b)/(c*b)
4185// fma(a,d/c,e/(c*b)) -> fma(a,d*b,e)/(b*c)
4186// fma(a,d/c,e/(b*c)) -> fma(a,d*b,e)/(c*b)
4187 auto mdrm = multiply_cast(md->get_right());
4188 auto rdrm = multiply_cast(rd->get_right());
4189
4190 if (mdrm.get()) {
4191 if (mdrm->get_right()->is_match(rd->get_right())) {
4192 return fma(this->left, md->get_left(),
4193 rd->get_left()*mdrm->get_left()) /
4194 md->get_right();
4195 } else if (mdrm->get_left()->is_match(rd->get_right())) {
4196 return fma(this->left, md->get_left(),
4197 rd->get_left()*mdrm->get_right()) /
4198 md->get_right();
4199 }
4200 } else if (rdrm.get()) {
4201 if (rdrm->get_right()->is_match(md->get_right())) {
4202 return fma(this->left, md->get_left()*rdrm->get_left(),
4203 rd->get_left()) /
4204 rd->get_right();
4205 } else if (rdrm->get_left()->is_match(md->get_right())) {
4206 return fma(this->left, md->get_left()*rdrm->get_right(),
4207 rd->get_left()) /
4208 rd->get_right();
4209 }
4210 }
4211 }
4212
4213// Chained fma reductions.
4214 auto rfma = fma_cast(this->right);
4215 if (rfma.get()) {
4216// fma(a, b, fma(c, b, d)) -> fma(b, a + c, d)
4217// fma(b, a, fma(c, b, d)) -> fma(b, a + c, d)
4218// fma(a, b, fma(b, c, d)) -> fma(b, a + c, d)
4219// fma(b, a, fma(b, c, d)) -> fma(b, a + c, d)
4220 if (this->middle->is_match(rfma->get_middle())) {
4221 return fma(this->middle,
4222 this->left + rfma->get_left(),
4223 rfma->get_right());
4224 } else if (this->left->is_match(rfma->get_middle())) {
4225 return fma(this->left,
4226 this->middle + rfma->get_left(),
4227 rfma->get_right());
4228 } else if (this->middle->is_match(rfma->get_left())) {
4229 return fma(this->middle,
4230 this->left + rfma->get_middle(),
4231 rfma->get_right());
4232 } else if (this->left->is_match(rfma->get_left())) {
4233 return fma(this->left,
4234 this->middle + rfma->get_middle(),
4235 rfma->get_right());
4236 }
4237
4238 if (mm.get()) {
4239// fma(a, e*b, fma(c, b, d)) -> fma(b, fma(a, e, c), d)
4240// fma(a, b*e, fma(c, b, d)) -> fma(b, fma(a, e, c), d)
4241// fma(a, e*b, fma(b, c, d)) -> fma(b, fma(a, e, c), d)
4242// fma(a, b*e, fma(b, c, d)) -> fma(b, fma(a, e, c), d)
4243 if (mm->get_right()->is_match(rfma->get_middle())) {
4244 return fma(mm->get_right(),
4245 fma(this->left,
4246 mm->get_left(),
4247 rfma->get_left()),
4248 rfma->get_right());
4249 } else if (mm->get_left()->is_match(rfma->get_middle())) {
4250 return fma(mm->get_left(),
4251 fma(this->left,
4252 mm->get_right(),
4253 rfma->get_left()),
4254 rfma->get_right());
4255 } else if (mm->get_right()->is_match(rfma->get_left())) {
4256 return fma(mm->get_right(),
4257 fma(this->left,
4258 mm->get_left(),
4259 rfma->get_middle()),
4260 rfma->get_right());
4261 } else if (mm->get_left()->is_match(rfma->get_left())) {
4262 return fma(mm->get_left(),
4263 fma(this->left,
4264 mm->get_right(),
4265 rfma->get_middle()),
4266 rfma->get_right());
4267 }
4268 } else if (lm.get()) {
4269// fma(e*b, a, fma(c, b, d)) -> fma(b, fma(a, e, c), d)
4270// fma(b*e, a, fma(c, b, d)) -> fma(b, fma(a, e, c), d)
4271// fma(e*b, a, fma(b, c, d)) -> fma(b, fma(a, e, c), d)
4272// fma(e*d, a, fma(b, c, d)) -> fma(b, fma(a, e, c), d)
4273 if (lm->get_right()->is_match(rfma->get_middle())) {
4274 return fma(lm->get_right(),
4275 fma(this->middle,
4276 lm->get_left(),
4277 rfma->get_left()),
4278 rfma->get_right());
4279 } else if (lm->get_left()->is_match(rfma->get_middle())) {
4280 return fma(lm->get_left(),
4281 fma(this->middle,
4282 lm->get_right(),
4283 rfma->get_left()),
4284 rfma->get_right());
4285 } else if (lm->get_right()->is_match(rfma->get_left())) {
4286 return fma(lm->get_right(),
4287 fma(this->middle,
4288 lm->get_left(),
4289 rfma->get_middle()),
4290 rfma->get_right());
4291 } else if (lm->get_left()->is_match(rfma->get_left())) {
4292 return fma(lm->get_left(),
4293 fma(this->middle,
4294 lm->get_right(),
4295 rfma->get_middle()),
4296 rfma->get_right());
4297 }
4298 }
4299
4300 auto rfmamm = multiply_cast(rfma->get_middle());
4301 auto rfmalm = multiply_cast(rfma->get_left());
4302 if (rfmamm.get()) {
4303// fma(a, b, fma(c, e*b, d)) -> fma(b, fma(c, e, a), d)
4304// fma(b, a, fma(c, e*b, d)) -> fma(b, fma(c, e, a), d)
4305// fma(a, b, fma(c, b*e, d)) -> fma(b, fma(c, e, a), d)
4306// fma(b, a, fma(c, b*e, d)) -> fma(b, fma(c, e, a), d)
4307 if (rfmamm->get_right()->is_match(this->middle)) {
4308 return fma(this->middle,
4309 fma(rfma->get_left(),
4310 rfmamm->get_left(),
4311 this->left),
4312 rfma->get_right());
4313 } else if (rfmamm->get_right()->is_match(this->left)) {
4314 return fma(this->left,
4315 fma(rfma->get_left(),
4316 rfmamm->get_left(),
4317 this->middle),
4318 rfma->get_right());
4319 } else if (rfmamm->get_left()->is_match(this->middle)) {
4320 return fma(this->middle,
4321 fma(rfma->get_left(),
4322 rfmamm->get_right(),
4323 this->left),
4324 rfma->get_right());
4325 } else if (rfmamm->get_left()->is_match(this->left)) {
4326 return fma(this->left,
4327 fma(rfma->get_left(),
4328 rfmamm->get_right(),
4329 this->middle),
4330 rfma->get_right());
4331 }
4332 } else if (rfmalm.get()) {
4333// fma(a, b, fma(e*b, c, d)) -> fma(b, fma(c, e, a), d)
4334// fma(b, a, fma(e*b, c, d)) -> fma(b, fma(c, e, a), d)
4335// fma(a, b, fma(b*e, c, d)) -> fma(b, fma(c, e, a), d)
4336// fma(b, a, fma(b*e, c, d)) -> fma(b, fma(c, e, a), d)
4337 if (rfmalm->get_right()->is_match(this->middle)) {
4338 return fma(this->middle,
4339 fma(rfma->get_middle(),
4340 rfmalm->get_left(),
4341 this->left),
4342 rfma->get_right());
4343 } else if (rfmalm->get_right()->is_match(this->left)) {
4344 return fma(this->left,
4345 fma(rfma->get_middle(),
4346 rfmalm->get_left(),
4347 this->middle),
4348 rfma->get_right());
4349 } else if (rfmalm->get_left()->is_match(this->middle)) {
4350 return fma(this->middle,
4351 fma(rfma->get_middle(),
4352 rfmalm->get_right(),
4353 this->left),
4354 rfma->get_right());
4355 } else if (rfmalm->get_left()->is_match(this->left)) {
4356 return fma(this->left,
4357 fma(rfma->get_middle(),
4358 rfmalm->get_right(),
4359 this->middle),
4360 rfma->get_right());
4361 }
4362 }
4363
4364 if (mm.get() && rfmamm.get()) {
4365// fma(a, f*b, fma(c, e*b, d)) -> fma(b, fma(a, f, c*e), d)
4366// fma(a, b*f, fma(c, e*b, d)) -> fma(b, fma(a, f, c*e), d)
4367// fma(a, f*b, fma(c, b*e, d)) -> fma(b, fma(a, f, c*e), d)
4368// fma(a, b*f, fma(c, b*e, d)) -> fma(b, fma(a, f, c*e), d)
4369 if (mm->get_right()->is_match(rfmamm->get_right())) {
4370 return fma(mm->get_right(),
4371 fma(this->left,
4372 mm->get_left(),
4373 rfma->get_left()*rfmamm->get_left()),
4374 rfma->get_right());
4375 } else if (mm->get_left()->is_match(rfmamm->get_right())) {
4376 return fma(mm->get_left(),
4377 fma(this->left,
4378 mm->get_right(),
4379 rfma->get_left()*rfmamm->get_left()),
4380 rfma->get_right());
4381 } else if (mm->get_right()->is_match(rfmamm->get_left())) {
4382 return fma(mm->get_right(),
4383 fma(this->left,
4384 mm->get_left(),
4385 rfma->get_left()*rfmamm->get_right()),
4386 rfma->get_right());
4387 } else if (mm->get_left()->is_match(rfmamm->get_left())) {
4388 return fma(mm->get_left(),
4389 fma(this->left,
4390 mm->get_right(),
4391 rfma->get_left()*rfmamm->get_right()),
4392 rfma->get_right());
4393 }
4394 } else if (lm.get() && rfmamm.get()) {
4395// fma(f*b, a, fma(c, e*b, d)) -> fma(b, fma(a, f, c*e), d)
4396// fma(b*f, a, fma(c, e*b, d)) -> fma(b, fma(a, f, c*e), d)
4397// fma(f*b, a, fma(c, b*e, d)) -> fma(b, fma(a, f, c*e), d)
4398// fma(b*f, a, fma(c, b*e, d)) -> fma(b, fma(a, f, c*e), d)
4399 if (lm->get_right()->is_match(rfmamm->get_right())) {
4400 return fma(lm->get_right(),
4401 fma(this->middle,
4402 lm->get_left(),
4403 rfma->get_left()*rfmamm->get_left()),
4404 rfma->get_right());
4405 } else if (lm->get_left()->is_match(rfmamm->get_right())) {
4406 return fma(lm->get_left(),
4407 fma(this->middle,
4408 lm->get_right(),
4409 rfma->get_left()*rfmamm->get_left()),
4410 rfma->get_right());
4411 } else if (lm->get_right()->is_match(rfmamm->get_left())) {
4412 return fma(lm->get_right(),
4413 fma(this->middle,
4414 lm->get_left(),
4415 rfma->get_left()*rfmamm->get_right()),
4416 rfma->get_right());
4417 } else if (lm->get_left()->is_match(rfmamm->get_left())) {
4418 return fma(lm->get_left(),
4419 fma(this->middle,
4420 lm->get_right(),
4421 rfma->get_left()*rfmamm->get_right()),
4422 rfma->get_right());
4423 }
4424 } else if (mm.get() && rfmalm.get()) {
4425// fma(a, f*b, fma(e*b, c, d)) -> fma(b, fma(a, f, c*e), d)
4426// fma(a, b*f, fma(e*b, c, d)) -> fma(b, fma(a, f, c*e), d)
4427// fma(a, f*b, fma(b*e, c, d)) -> fma(b, fma(a, f, c*e), d)
4428// fma(a, b*f, fma(b*e, c, d)) -> fma(b, fma(a, f, c*e), d)
4429 if (mm->get_right()->is_match(rfmalm->get_right())) {
4430 return fma(mm->get_right(),
4431 fma(this->left,
4432 mm->get_left(),
4433 rfma->get_middle()*rfmalm->get_left()),
4434 rfma->get_right());
4435 } else if (mm->get_left()->is_match(rfmalm->get_right())) {
4436 return fma(mm->get_left(),
4437 fma(this->left,
4438 mm->get_right(),
4439 rfma->get_middle()*rfmalm->get_left()),
4440 rfma->get_right());
4441 } else if (mm->get_right()->is_match(rfmalm->get_left())) {
4442 return fma(mm->get_right(),
4443 fma(this->left,
4444 mm->get_left(),
4445 rfma->get_middle()*rfmalm->get_right()),
4446 rfma->get_right());
4447 } else if (mm->get_left()->is_match(rfmalm->get_left())) {
4448 return fma(mm->get_left(),
4449 fma(this->left,
4450 mm->get_right(),
4451 rfma->get_middle()*rfmalm->get_right()),
4452 rfma->get_right());
4453 }
4454 } else if (lm.get() && rfmalm.get()) {
4455// fma(f*b, a, fma(e*b, c, d)) -> fma(b, fma(a, f, c*e), d)
4456// fma(b*f, a, fma(e*b, c, d)) -> fma(b, fma(a, f, c*e), d)
4457// fma(f*b, a, fma(b*e, c, d)) -> fma(b, fma(a, f, c*e), d)
4458// fma(b*f, a, fma(b*e, c, d)) -> fma(b, fma(a, f, c*e), d)
4459 if (lm->get_right()->is_match(rfmalm->get_right())) {
4460 return fma(lm->get_right(),
4461 fma(this->middle,
4462 lm->get_left(),
4463 rfma->get_middle()*rfmalm->get_left()),
4464 rfma->get_right());
4465 } else if (lm->get_left()->is_match(rfmalm->get_right())) {
4466 return fma(lm->get_left(),
4467 fma(this->middle,
4468 lm->get_right(),
4469 rfma->get_middle()*rfmalm->get_left()),
4470 rfma->get_right());
4471 } else if (lm->get_right()->is_match(rfmalm->get_left())) {
4472 return fma(lm->get_right(),
4473 fma(this->middle,
4474 lm->get_left(),
4475 rfma->get_middle()*rfmalm->get_right()),
4476 rfma->get_right());
4477 } else if (lm->get_left()->is_match(rfmalm->get_left())) {
4478 return fma(lm->get_left(),
4479 fma(this->middle,
4480 lm->get_right(),
4481 rfma->get_middle()*rfmalm->get_right()),
4482 rfma->get_right());
4483 }
4484 }
4485
4486 if (is_variable_combinable(this->middle, rfma->get_middle())) {
4487 if (is_greater_exponent(this->middle, rfma->get_middle())) {
4488// fma(a,x^b,fma(c,x^d,e)) -> fma(x^d,fma(x^(d-b),a,c),e) if b > d
4489 return fma(rfma->get_middle(),
4490 fma(this->middle/rfma->get_middle(),
4491 this->left,
4492 rfma->get_left()),
4493 rfma->get_right());
4494 } else {
4495// fma(a,x^b,fma(c,x^d,e)) -> fma(x^b,fma(x^(d-b),c,a),e) if d > b
4496 return fma(this->middle,
4497 fma(rfma->get_middle()/this->middle,
4498 rfma->get_left(),
4499 this->left),
4500 rfma->get_right());
4501 }
4502 } else if (is_variable_combinable(this->left, rfma->get_middle())) {
4503 if (is_greater_exponent(this->left, rfma->get_middle())) {
4504// fma(x^b,a,fma(c,x^d,e)) -> fma(x^d,fma(x^(d-b),a,c),e) if b > d
4505 return fma(rfma->get_middle(),
4506 fma(this->left/rfma->get_middle(),
4507 this->middle,
4508 rfma->get_left()),
4509 rfma->get_right());
4510 } else {
4511// fma(x^b,a,fma(c,x^d,e)) -> fma(x^b,fma(x^(d-b),c,a),e) if d > b
4512 return fma(this->left,
4513 fma(rfma->get_middle()/this->left,
4514 rfma->get_left(),
4515 this->middle),
4516 rfma->get_right());
4517 }
4518 } else if (is_variable_combinable(this->middle, rfma->get_left())) {
4519 if (is_greater_exponent(this->middle, rfma->get_left())) {
4520// fma(a,x^b,fma(x^d,c,e)) -> fma(x^d,fma(x^(d-b),a,c),e) if b > d
4521 return fma(rfma->get_left(),
4522 fma(this->middle/rfma->get_left(),
4523 this->left,
4524 rfma->get_middle()),
4525 rfma->get_right());
4526 } else {
4527// fma(a,x^b,fma(x^d,c,e)) -> fma(x^b,fma(x^(d-b),c,a),e) if d > b
4528 return fma(this->middle,
4529 fma(rfma->get_left()/this->middle,
4530 rfma->get_middle(),
4531 this->left),
4532 rfma->get_right());
4533 }
4534 } else if (is_variable_combinable(this->left, rfma->get_left())) {
4535 if (is_greater_exponent(this->left, rfma->get_left())) {
4536// fma(x^b,a,fma(x^d,c,e)) -> fma(x^d,fma(x^(d-b),a,c),e) if b > d
4537 return fma(rfma->get_left(),
4538 fma(this->left/rfma->get_left(),
4539 this->middle,
4540 rfma->get_middle()),
4541 rfma->get_right());
4542 } else {
4543// fma(x^b,a,fma(x^d,c,e)) -> fma(x^b,fma(x^(d-b),c,a),e) if d > b
4544 return fma(this->left,
4545 fma(rfma->get_left()/this->left,
4546 rfma->get_middle(),
4547 this->middle),
4548 rfma->get_right());
4549 }
4550 }
4551
4552// fma(a,b,fma(a,b,c)) -> fma(2*a,b,c)
4553// fma(a,b,fma(b,a,c)) -> fma(2*a,b,c)
4554 if (this->left->is_match(rfma->get_left()) &&
4555 this->middle->is_match(rfma->get_middle())) {
4556 return fma(2.0*this->left, this->middle, rfma->get_right());
4557 } else if (this->left->is_match(rfma->get_middle()) &&
4558 this->middle->is_match(rfma->get_left())) {
4559 return fma(2.0*this->left, this->middle, rfma->get_right());
4560 }
4561
4562// fma(a,b/c,fma(e,f/c,g)) -> (a*b + e*f)/c + g
4563// fma(a,b/c,fma(e/c,f,g)) -> (a*b + e*f)/c + g
4564// fma(a/c,b,fma(e,f/c,g)) -> (a*b + e*f)/c + g
4565// fma(a/c,b,fma(e/c,f,g)) -> (a*b + e*f)/c + g
4566 auto fmald = divide_cast(rfma->get_left());
4567 auto fmamd = divide_cast(rfma->get_middle());
4568 if (ld.get()) {
4569 if (fmald.get() && ld->get_right()->is_match(fmald->get_right())) {
4570 return (ld->get_left()*this->middle +
4571 fmald->get_left()*rfma->get_middle())/ld->get_right() +
4572 rfma->get_right();
4573 } else if (fmamd.get() && ld->get_right()->is_match(fmamd->get_right())) {
4574 return (ld->get_left()*this->middle +
4575 fmamd->get_left()*rfma->get_left())/ld->get_right() +
4576 rfma->get_right();
4577 }
4578 } else if (md.get()) {
4579 if (fmald.get() && md->get_right()->is_match(fmald->get_right())) {
4580 return (md->get_left()*this->left +
4581 fmald->get_left()*rfma->get_middle())/md->get_right() +
4582 rfma->get_right();
4583 } else if (fmamd.get() && md->get_right()->is_match(fmamd->get_right())) {
4584 return (md->get_left()*this->left +
4585 fmamd->get_left()*rfma->get_left())/md->get_right() +
4586 rfma->get_right();
4587 }
4588 }
4589 }
4590
4591// Check to see if it is worth moving nodes out of a fma nodes. These should be
4592// restricted to variable like nodes. Only do this reduction if the complexity
4593// reduces.
4594 if (this->left->is_all_variables()) {
4595 auto rdl = this->right/this->left;
4596 if (rdl->get_complexity() < this->left->get_complexity() +
4597 this->right->get_complexity()) {
4598 return (this->middle + rdl)*this->left;
4599 }
4600 } else if (this->middle->is_all_variables()) {
4601 auto rdm = this->right/this->middle;
4602 auto rdmc = constant_cast(rdm->get_power_exponent());
4603 if ((rdm->get_complexity() < this->middle->get_complexity() +
4604 this->right->get_complexity()) &&
4605 !(rdmc.get() && rdmc->evaluate().is_negative())) {
4606 return (this->left + rdm)*this->middle;
4607 }
4608 }
4609
4610// Change negative exponents to divide so that can be factored out.
4611// fma(a,b^-c,d) = a/b^c + d
4612// fma(b^-c,a,d) = a/b^c + d
4613 auto lp = pow_cast(this->left);
4614 if (lp.get()) {
4615 auto exponent = constant_cast(lp->get_right());
4616 if (exponent.get() && exponent->evaluate().is_negative()) {
4617 return this->middle/pow(lp->get_left(), -lp->get_right()) +
4618 this->right;
4619 }
4620 }
4621 auto mp = pow_cast(this->middle);
4622 if (mp.get()) {
4623 auto exponent = constant_cast(mp->get_right());
4624 if (exponent.get() && exponent->evaluate().is_negative()) {
4625 return this->left/pow(mp->get_left(), -mp->get_right()) +
4626 this->right;
4627 }
4628
4629// fma(2,a^2,a) -> a*fma(2,a,1)
4630// Note this case is handled earlier. fma(2,a,a^2) -> a*fma(2,1,a)
4632 this->right)) {
4633 auto temp = this->right/this->middle;
4634 auto temp_exponent = constant_cast(temp->get_power_exponent());
4635 if (temp_exponent.get() && temp_exponent->evaluate().is_negative()) {
4636 return this->right*fma(this->left,
4637 this->middle/this->right,
4638 1.0);
4639 }
4640 }
4641 }
4642
4643// a^b*c^b + d -> (a*c)^b + d
4644 if (lp.get() && mp.get()) {
4645 if (lp->get_right()->is_match(mp->get_right())) {
4646 return pow(lp->get_left()*mp->get_left(),
4647 lp->get_right()) +
4648 this->right;
4649 }
4650 }
4651
4652// fma(2,(ab)^2,a^2b) -> a^2*fma(2, b^2, b)
4653 if (rm.get() && mp.get()) {
4654 auto mplm = multiply_cast(mp->get_left());
4655 if (mplm.get()) {
4656 if (is_variable_combinable(mplm->get_left(),
4657 rm->get_left())) {
4658 auto temp = pow(mplm->get_left(),
4659 mp->get_right());
4660 return temp*fma(this->left,
4661 this->middle/temp,
4662 this->right/temp);
4663 } else if (is_variable_combinable(mplm->get_right(),
4664 rm->get_left())) {
4665 auto temp = pow(mplm->get_right(),
4666 mp->get_right());
4667 return temp*fma(this->left,
4668 this->middle/temp,
4669 this->right/temp);
4670 }
4671 }
4672 }
4673// fma(2,(a*b)^2,fma(a^2,b,c)) -> fma(a^2,fma(2,b^2,b),c)
4674 if (rfma.get() && mp.get()) {
4675 auto mplm = multiply_cast(mp->get_left());
4676 if (mplm.get()) {
4677 if (is_variable_combinable(mplm->get_left(),
4678 rfma->get_left())) {
4679 auto temp = pow(mplm->get_left(),
4680 mp->get_right());
4681 return fma(temp,
4682 fma(this->left,
4683 this->middle/temp,
4684 rfma->get_middle()),
4685 rfma->get_right());
4686 } else if (is_variable_combinable(mplm->get_right(),
4687 rfma->get_left())) {
4688 auto temp = pow(mplm->get_right(),
4689 mp->get_right());
4690 return fma(temp,
4691 fma(this->left,
4692 this->middle/temp,
4693 rfma->get_middle()),
4694 rfma->get_right());
4695 }
4696
4697// fma(2,(a*b)^2,fma(3,a^2*b,c)) -> a^2*fma(2,b^2,fma(3,b,c))
4698 auto rfmamm = multiply_cast(rfma->get_middle());
4699 if (rfmamm.get()) {
4700 if (is_variable_combinable(mplm->get_left(),
4701 rfmamm->get_left())) {
4702 auto temp = pow(mplm->get_left(),
4703 mp->get_right());
4704 return temp*fma(this->left,
4705 this->middle/temp,
4706 fma(rfma->get_left(),
4707 rfma->get_middle()/temp,
4708 rfma->get_right()));
4709 }
4710 }
4711 }
4712 }
4713
4714// fma(a,b/c,b/d) -> b*(a/c + 1/d)
4715// fma(a,c/b,d/b) -> (a*c + d)/b
4716 if (md.get() && rd.get()) {
4717 if (md->get_left()->is_match(rd->get_left())) {
4718 return md->get_left()*(this->left/md->get_right() +
4719 1.0/rd->get_right());
4720 } else if (md->get_right()->is_match(rd->get_right())) {
4721 return (this->left*md->get_left() +
4722 rd->get_left())/md->get_right();
4723 }
4724 }
4725// fma(b/c,a,b/d) -> b*(a/c + 1/d)
4726// fma(c/b,a,d/b) -> (a*c + d)/b
4727 if (ld.get() && rd.get()) {
4728 if (ld->get_left()->is_match(rd->get_left())) {
4729 return ld->get_left()*(this->middle/ld->get_right() +
4730 1.0/rd->get_right());
4731 } else if (ld->get_right()->is_match(rd->get_right())) {
4732 return (this->middle*ld->get_left() +
4733 rd->get_left())/ld->get_right();
4734 }
4735 }
4736
4737// fma(a/b,c,(d/b)*e) -> fma(a,c,d*e)/b
4738// fma(a/b,c,e*(d/b)) -> fma(a,c,d*e)/b
4739 if (rm.get() && ld.get()) {
4740 auto rmld = divide_cast(rm->get_left());
4741 if (rmld.get() && ld->get_right()->is_match(rmld->get_right())) {
4742 return fma(ld->get_left(), this->middle, rmld->get_left()*rm->get_right())/ld->get_right();
4743 }
4744 auto rmrd = divide_cast(rm->get_right());
4745 if (rmrd.get() && ld->get_right()->is_match(rmrd->get_right())) {
4746 return fma(ld->get_left(), this->middle, rmrd->get_left()*rm->get_left())/ld->get_right();
4747 }
4748 }
4749// fma(a,c/b,(d/b)*e) -> fma(a,c,d*e)/b
4750// fma(a,c/b,e*(d/b)) -> fma(a,c,d*e)/b
4751 if (rm.get() && md.get()) {
4752 auto rmld = divide_cast(rm->get_left());
4753 if (rmld.get() && md->get_right()->is_match(rmld->get_right())) {
4754 return fma(this->left, md->get_left(), rmld->get_left()*rm->get_right())/md->get_right();
4755 }
4756 auto rmrd = divide_cast(rm->get_right());
4757 if (rmrd.get() && md->get_right()->is_match(rmrd->get_right())) {
4758 return fma(this->left, md->get_left(), rmrd->get_left()*rm->get_left())/md->get_right();
4759 }
4760 }
4761
4762// fma(a/b*c,d,e/b) -> fma(a*c,d,e)/b
4763// fma(a*c/b,d,e/b) -> fma(a*c,d,e)/b
4764 if (rd.get() && lm.get()) {
4765 auto lmld = divide_cast(lm->get_left());
4766 if (lmld.get() && rd->get_right()->is_match(lmld->get_right())) {
4767 return fma(lmld->get_left()*lm->get_right(), this->middle, rd->get_left())/rd->get_right();
4768 }
4769 auto lmrd = divide_cast(lm->get_right());
4770 if (lmrd.get() && rd->get_right()->is_match(lmrd->get_right())) {
4771 return fma(lmld->get_left()*lm->get_left(), this->middle, rd->get_left())/rd->get_right();
4772 }
4773 }
4774// fma(a,c/b*d,e/b) -> fma(a,c*d,e)/b
4775// fma(a,c*d/b,e/b) -> fma(a,c*d,e)/b
4776 if (rd.get() && mm.get()) {
4777 auto mmld = divide_cast(mm->get_left());
4778 if (mmld.get() && rd->get_right()->is_match(mmld->get_right())) {
4779 return fma(this->left, mmld->get_left()*mm->get_right(), rd->get_left())/rd->get_right();
4780 }
4781 auto mmrd = divide_cast(mm->get_right());
4782 if (mmrd.get() && rd->get_right()->is_match(mmrd->get_right())) {
4783 return fma(this->left, mmrd->get_left()*mm->get_left(), rd->get_left())/rd->get_right();
4784 }
4785 }
4786
4787// fma(a, b/c, ((f/c)*e)*d) -> fma(a, b, f*e*d)/c
4788// fma(a/c, b, ((f/c)*e)*d) -> fma(a, b, f*e*d)/c
4789// fma(a, b/c, (e*(f/c))*d) -> fma(a, b, f*e*d)/c
4790// fma(a/c, b, (e*(f/c))*d) -> fma(a, b, f*e*d)/c
4791// fma(a, b/c, d*((f/c)*e)) -> fma(a, b, f*e*d)/c
4792// fma(a/c, b, d*((f/c)*e)) -> fma(a, b, f*e*d)/c
4793// fma(a, b/c, d*(e*(f/c))) -> fma(a, b, f*e*d)/c
4794// fma(a/c, b, d*(e*(f/c))) -> fma(a, b, f*e*d)/c
4795 if (md.get() && rm.get()) {
4796 auto rmlm = multiply_cast(rm->get_left());
4797 if (rmlm.get()) {
4798 auto rmlmld = divide_cast(rmlm->get_left());
4799 if (rmlmld.get() && rmlmld->get_right()->is_match(md->get_right())) {
4800 return fma(this->left, md->get_left(),
4801 rmlmld->get_left()*rmlm->get_right()*rm->get_right())/md->get_right();
4802 }
4803 auto rmlmrd = divide_cast(rmlm->get_right());
4804 if (rmlmrd.get() && rmlmrd->get_right()->is_match(md->get_right())) {
4805 return fma(this->left, md->get_left(),
4806 rmlmrd->get_left()*rmlm->get_left()*rm->get_right())/md->get_right();
4807 }
4808 }
4809 auto rmrm = multiply_cast(rm->get_right());
4810 if (rmrm.get()) {
4811 auto rmrmld = divide_cast(rmrm->get_left());
4812 if (rmrmld.get() && rmrmld->get_right()->is_match(md->get_right())) {
4813 return fma(this->left, md->get_left(),
4814 rmrmld->get_left()*rmrm->get_right()*rm->get_left())/md->get_right();
4815 }
4816 auto rmrmrd = divide_cast(rmrm->get_right());
4817 if (rmrmrd.get() && rmrmrd->get_right()->is_match(md->get_right())) {
4818 return fma(this->left, md->get_left(),
4819 rmrmrd->get_left()*rmrm->get_left()*rm->get_left())/md->get_right();
4820 }
4821 }
4822 } else if (ld.get() && rm.get()) {
4823 auto rmlm = multiply_cast(rm->get_left());
4824 if (rmlm.get()) {
4825 auto rmlmld = divide_cast(rmlm->get_left());
4826 if (rmlmld.get() && rmlmld->get_right()->is_match(ld->get_right())) {
4827 return fma(ld->get_left(), this->middle,
4828 rmlmld->get_left()*rmlm->get_right()*rm->get_right())/ld->get_right();
4829 }
4830 auto rmlmrd = divide_cast(rmlm->get_right());
4831 if (rmlmrd.get() && rmlmrd->get_right()->is_match(ld->get_right())) {
4832 return fma(ld->get_left(), this->middle,
4833 rmlmrd->get_left()*rmlm->get_right()*rm->get_right())/ld->get_right();
4834 }
4835 }
4836 auto rmrm = multiply_cast(rm->get_right());
4837 if (rmrm.get()) {
4838 auto rmrmld = divide_cast(rmrm->get_left());
4839 if (rmrmld.get() && rmrmld->get_right()->is_match(ld->get_right())) {
4840 return fma(ld->get_left(), this->middle,
4841 rmrmld->get_left()*rmrm->get_right()*rm->get_left())/ld->get_right();
4842 }
4843 auto rmrmrd = divide_cast(rmrm->get_right());
4844 if (rmrmrd.get() && rmrmrd->get_right()->is_match(ld->get_right())) {
4845 return fma(ld->get_left(), this->middle,
4846 rmrmrd->get_left()*rmrm->get_left()*rm->get_left())/ld->get_right();
4847 }
4848 }
4849 }
4850
4851// fma(exp(a), exp(b), c) -> exp(a + b) + c
4852 auto le = exp_cast(this->left);
4853 auto me = exp_cast(this->middle);
4854 if (le.get() && me.get()) {
4855 return exp(le->get_arg() + me->get_arg()) + this->right;
4856 }
4857
4858// fma(exp(a), exp(b)*c, d) -> fma(exp(a)*exp(b), c, d)
4859// fma(exp(a), c*exp(b), d) -> fma(exp(a)*exp(b), c, d)
4860 if (mm.get() && le.get()) {
4861 auto mmle = exp_cast(mm->get_left());
4862 if (mmle.get()) {
4863 return fma(this->left*mm->get_left(),
4864 mm->get_right(),
4865 this->right);
4866 }
4867 auto mmre = exp_cast(mm->get_right());
4868 if (mmre.get()) {
4869 return fma(this->left*mm->get_right(),
4870 mm->get_left(),
4871 this->right);
4872 }
4873 }
4874// fma(exp(a)*c, exp(b), d) -> fma(exp(a)*exp(b), c, d)
4875// fma(c*exp(a), exp(b), d) -> fma(exp(a)*exp(b), c, d)
4876 if (lm.get() && me.get()) {
4877 auto lmle = exp_cast(lm->get_left());
4878 if (lmle.get()) {
4879 return fma(lm->get_left()*this->middle,
4880 lm->get_right(),
4881 this->right);
4882 }
4883 auto lmre = exp_cast(lm->get_right());
4884 if (lmre.get()) {
4885 return fma(lm->get_right()*this->middle,
4886 lm->get_left(),
4887 this->right);
4888 }
4889 }
4890
4891// fma(exp(a)*c, exp(b)*d, e) -> fma(exp(a)*exp(b), c*d, e)
4892// fma(exp(a)*c, d*exp(b), e) -> fma(exp(a)*exp(b), c*d, e)
4893// fma(c*exp(a), exp(b)*d, e) -> fma(exp(a)*exp(b), c*d, e)
4894// fma(c*exp(a), d*exp(b), e) -> fma(exp(a)*exp(b), c*d, e)
4895 if (lm.get() && mm.get()) {
4896 auto lmle = exp_cast(lm->get_left());
4897 if (lmle.get()) {
4898 auto mmle = exp_cast(mm->get_left());
4899 if (mmle.get()) {
4900 return fma(lm->get_left()*mm->get_left(),
4901 lm->get_right()*mm->get_right(),
4902 this->right);
4903 }
4904 auto mmre = exp_cast(mm->get_right());
4905 if (mmre.get()) {
4906 return fma(lm->get_left()*mm->get_right(),
4907 lm->get_right()*mm->get_left(),
4908 this->right);
4909 }
4910 }
4911 auto lmre = exp_cast(lm->get_right());
4912 if (lmre.get()) {
4913 auto mmle = exp_cast(mm->get_left());
4914 if (mmle.get()) {
4915 return fma(lm->get_right()*mm->get_left(),
4916 lm->get_left()*mm->get_right(),
4917 this->right);
4918 }
4919 auto mmre = exp_cast(mm->get_right());
4920 if (mmre.get()) {
4921 return fma(lm->get_right()*mm->get_right(),
4922 lm->get_left()*mm->get_left(),
4923 this->right);
4924 }
4925 }
4926 }
4927
4928// fma(exp(a)*c, exp(b)/d, e) -> fma(exp(a)*exp(b), c/d, e)
4929// fma(exp(a)*c, d/exp(b), e) -> fma(exp(a)/exp(b), c*d, e)
4930// fma(c*exp(a), exp(b)/d, e) -> fma(exp(a)*exp(b), c/d, e)
4931// fma(c*exp(a), d/exp(b), e) -> fma(exp(a)/exp(b), c*d, e)
4932 if (lm.get() && md.get()) {
4933 auto lmle = exp_cast(lm->get_left());
4934 if (lmle.get()) {
4935 auto mdle = exp_cast(md->get_left());
4936 if (mdle.get()) {
4937 return fma(lm->get_left()*md->get_left(),
4938 lm->get_right()/md->get_right(),
4939 this->right);
4940 }
4941 auto mdre = exp_cast(md->get_right());
4942 if (mdre.get()) {
4943 return fma(lm->get_left()/md->get_right(),
4944 lm->get_right()*md->get_left(),
4945 this->right);
4946 }
4947 }
4948 auto lmre = exp_cast(lm->get_right());
4949 if (lmre.get()) {
4950 auto mdle = exp_cast(md->get_left());
4951 if (mdle.get()) {
4952 return fma(lm->get_right()*md->get_left(),
4953 lm->get_left()/md->get_right(),
4954 this->right);
4955 }
4956 auto mdre = exp_cast(md->get_right());
4957 if (mdre.get()) {
4958 return fma(lm->get_right()/md->get_right(),
4959 lm->get_left()*md->get_left(),
4960 this->right);
4961 }
4962 }
4963 }
4964
4965// fma(exp(a)/c, exp(b)*d, e) -> fma(exp(a)*exp(b), d/c, e)
4966// fma(exp(a)/c, d*exp(b), e) -> fma(exp(a)*exp(b), d/c, e)
4967// fma(c/exp(a), exp(b)*d, e) -> fma(exp(b)/exp(a), c*d, e)
4968// fma(c/exp(a), d*exp(b), e) -> fma(exp(b)/exp(a), c*d, e)
4969 if (ld.get() && mm.get()) {
4970 auto ldle = exp_cast(ld->get_left());
4971 if (ldle.get()) {
4972 auto mmle = exp_cast(mm->get_left());
4973 if (mmle.get()) {
4974 return fma(ld->get_left()*mm->get_left(),
4975 mm->get_right()/ld->get_right(),
4976 this->right);
4977 }
4978 auto mmre = exp_cast(mm->get_right());
4979 if (mmre.get()) {
4980 return fma(ld->get_left()*mm->get_right(),
4981 mm->get_left()/ld->get_right(),
4982 this->right);
4983 }
4984 }
4985 auto ldre = exp_cast(ld->get_right());
4986 if (ldre.get()) {
4987 auto mmle = exp_cast(mm->get_left());
4988 if (mmle.get()) {
4989 return fma(mm->get_left()/ld->get_right(),
4990 ld->get_left()*mm->get_right(),
4991 this->right);
4992 }
4993 auto mmre = exp_cast(mm->get_right());
4994 if (mmre.get()) {
4995 return fma(mm->get_right()/ld->get_right(),
4996 ld->get_left()*mm->get_left(),
4997 this->right);
4998 }
4999 }
5000 }
5001
5002// fma(exp(a)/c, exp(b)/d, e) -> (exp(a)*exp(b))/(c*d) + e
5003// fma(exp(a)/c, d/exp(b), e) -> fma(exp(a)/exp(b), d/c, e)
5004// fma(c/exp(a), exp(b)/d, e) -> fma(exp(b)/exp(a), c/d, e)
5005// fma(c/exp(a), d/exp(b), e) -> (c*d)/(exp(a)*exp(b)) + e
5006 if (ld.get() && md.get()) {
5007 auto ldle = exp_cast(ld->get_left());
5008 if (ldle.get()) {
5009 auto mdle = exp_cast(md->get_left());
5010 if (mdle.get()) {
5011 return ((ld->get_left()*md->get_left()) /
5012 (ld->get_right()*md->get_right())) +
5013 this->right;
5014 }
5015 auto mdre = exp_cast(md->get_right());
5016 if (mdre.get()) {
5017 return fma(ld->get_left()/md->get_right(),
5018 md->get_left()/ld->get_right(),
5019 this->right);
5020 }
5021 }
5022 auto ldre = exp_cast(ld->get_right());
5023 if (ldre.get()) {
5024 auto mdle = exp_cast(md->get_left());
5025 if (mdle.get()) {
5026 return fma(md->get_left()/ld->get_right(),
5027 ld->get_left()/md->get_right(),
5028 this->right);
5029 }
5030 auto mdre = exp_cast(md->get_right());
5031 if (mdre.get()) {
5032 return ((ld->get_left()*md->get_left()) /
5033 (ld->get_right()*md->get_right())) +
5034 this->right;
5035 }
5036 }
5037 }
5038
5039 return this->shared_from_this();
5040 }
5041
5042//------------------------------------------------------------------------------
// Transform the fma node l*m + r into its derivative with respect to x.
// Product rule: d(l*m + r)/dx = l'*m + (l*m' + r'), assembled as two nested
// fma nodes. Returns one when this node itself matches x.
// NOTE(review): the local `hash` below appears to shadow the leaf_node member
// of the same name, and the cache key is the raw address of x's node. An
// address-keyed cache could return a stale entry if x is destroyed and a new
// node later reuses its address — confirm node lifetimes rule this out.
5049//------------------------------------------------------------------------------
// The derivative of a node with respect to itself is one.
5052 if (this->is_match(x)) {
5053 return one<T, SAFE_MATH> ();
5054 }
5055
// Memoize the result per x, keyed by the address of x's node.
5056 const size_t hash = reinterpret_cast<size_t> (x.get());
5057 if (this->df_cache.find(hash) == this->df_cache.end()) {
// Inner term: l*m' + r'.
5058 auto temp_right = fma(this->left,
5059 this->middle->df(x),
5060 this->right->df(x));
5061
// Outer term: l'*m + (l*m' + r').
5062 this->df_cache[hash] = fma(this->left->df(x),
5063 this->middle,
5064 temp_right);
5065 }
5066 return this->df_cache[hash];
5067 }
5068
5069//------------------------------------------------------------------------------
// Compile the node into kernel source. Each node is emitted at most once:
// the register map doubles as a visited set. The left, middle and right
// children are compiled first (indices and usage are forwarded to them
// unchanged). Then one line assigning l*m + r to this node's register is
// written and terminated via endline().
5077//------------------------------------------------------------------------------
5079 compile(std::ostringstream &stream,
5080 jit::register_map &registers,
5082 const jit::register_usage &usage) {
5083 if (registers.find(this) == registers.end()) {
5084 shared_leaf<T, SAFE_MATH> l = this->left->compile(stream,
5085 registers,
5086 indices,
5087 usage);
5088 shared_leaf<T, SAFE_MATH> m = this->middle->compile(stream,
5089 registers,
5090 indices,
5091 usage);
5092 shared_leaf<T, SAFE_MATH> r = this->right->compile(stream,
5093 registers,
5094 indices,
5095 usage);
5096
5097 registers[this] = jit::to_string('r', this);
5098 stream << " const ";
5099 jit::add_type<T> (stream);
5100 stream << " " << registers[this] << " = ";
// SAFE_MATH: select r when l or m is exactly zero instead of evaluating the
// product (presumably so 0*inf or 0*NaN cannot leak a NaN — confirm).
5101 if constexpr (SAFE_MATH) {
5102 stream << "(" << registers[l.get()] << " == ";
5103 if constexpr (jit::complex_scalar<T>) {
5104 jit::add_type<T> (stream);
5105 stream << "(0, 0)";
5106 } else {
5107 stream << "0";
5108 }
5109 stream << " || " << registers[m.get()] << " == ";
5110 if constexpr (jit::complex_scalar<T>) {
5111 jit::add_type<T> (stream);
5112 stream << "(0, 0)";
5113 } else {
5114 stream << "0";
5115 }
5116 stream << ") ? " << registers[r.get()] << " : ";
5117 }
// Complex types are written as an explicit l*m + r; real types call fma().
5118 if constexpr (jit::complex_scalar<T>) {
5119 stream << registers[l.get()] << "*"
5120 << registers[m.get()] << " + "
5121 << registers[r.get()];
5122 } else {
5123 stream << "fma("
5124 << registers[l.get()] << ", "
5125 << registers[m.get()] << ", "
5126 << registers[r.get()] << ")";
5127 }
5128 this->endline(stream, usage);
5129 }
5130
5131 return this->shared_from_this();
5132 }
5133
5134//------------------------------------------------------------------------------
// Query if the nodes match. This is true when x is this very node, or when x
// is an fma node whose left, middle and right branches each match the
// corresponding branch here. The comparison is positional: the commuted form
// fma(m, l, r) is not considered.
5139//------------------------------------------------------------------------------
5141 if (this == x.get()) {
5142 return true;
5143 }
5144
5145 auto x_cast = fma_cast(x);
5146 if (x_cast.get()) {
5147 return this->left->is_match(x_cast->get_left()) &&
5148 this->middle->is_match(x_cast->get_middle()) &&
5149 this->right->is_match(x_cast->get_right());
5150 }
5151
5152 return false;
5153 }
5154
5155//------------------------------------------------------------------------------
// Print the node as LaTeX to std::cout in the form \left(l m+r\right).
// An add or subtract operand of the product is wrapped in its own
// \left( ... \right) pair so the implied multiplication binds correctly.
5157//------------------------------------------------------------------------------
5158 virtual void to_latex() const {
5159 std::cout << "\\left(";
// Left factor, parenthesized when it is a sum or difference.
5160 if (add_cast(this->left).get() ||
5161 subtract_cast(this->left).get()) {
5162 std::cout << "\\left(";
5163 this->left->to_latex();
5164 std::cout << "\\right)";
5165 } else {
5166 this->left->to_latex();
5167 }
// A space denotes the implied multiplication.
5168 std::cout << " ";
// Middle factor, parenthesized when it is a sum or difference.
5169 if (add_cast(this->middle).get() ||
5170 subtract_cast(this->middle).get()) {
5171 std::cout << "\\left(";
5172 this->middle->to_latex();
5173 std::cout << "\\right)";
5174 } else {
5175 this->middle->to_latex();
5176 }
5177 std::cout << "+";
5178 this->right->to_latex();
5179 std::cout << "\\right)";
5180 }
5181
5182//------------------------------------------------------------------------------
// Remove pseudo variable nodes. When this node reports pseudo variables
// (has_pseudo()), a new fma is built from the pseudo-free left, middle and
// right branches; otherwise this node is returned unchanged.
5186//------------------------------------------------------------------------------
5188 if (this->has_pseudo()) {
5189 return fma(this->left->remove_pseudo(),
5190 this->middle->remove_pseudo(),
5191 this->right->remove_pseudo());
5192 }
5193 return this->shared_from_this();
5194 }
5195
5196//------------------------------------------------------------------------------
// Convert the node to Graphviz DOT text. The node is written as a blue oval
// labelled "fma", followed by one undirected edge ("--") to each of the left,
// middle and right children in that order. Each node is emitted once:
// registers records nodes that have already been written.
5202//------------------------------------------------------------------------------
5203 virtual shared_leaf<T, SAFE_MATH> to_vizgraph(std::stringstream &stream,
5204 jit::register_map &registers) {
5205 if (registers.find(this) == registers.end()) {
5206 const std::string name = jit::to_string('r', this);
5207 registers[this] = name;
5208 stream << " " << name
5209 << " [label = \"fma\", shape = oval, style = filled, fillcolor = blue, fontcolor = white];" << std::endl;
5210
// Emit each child subgraph, then the edge from this node to it.
5211 auto l = this->left->to_vizgraph(stream, registers);
5212 stream << " " << name << " -- " << registers[l.get()] << ";" << std::endl;
5213 auto m = this->middle->to_vizgraph(stream, registers);
5214 stream << " " << name << " -- " << registers[m.get()] << ";" << std::endl;
5215 auto r = this->right->to_vizgraph(stream, registers);
5216 stream << " " << name << " -- " << registers[r.get()] << ";" << std::endl;
5217 }
5218
5219 return this->shared_from_this();
5220 }
5221 };
5222
5223//------------------------------------------------------------------------------
// Build a fused multiply add node l*m + r. The new node is reduced first, then
// interned through leaf_node::caches.nodes so that structurally identical
// expressions share a single node.
5232//------------------------------------------------------------------------------
5233 template<jit::float_scalar T, bool SAFE_MATH=false>
5237 auto temp = std::make_shared<fma_node<T, SAFE_MATH>> (l, m, r)->reduce();
5238// Test for hash collisions.
// Probe cache slots starting at temp's hash. An empty slot means temp is new,
// and temp is returned. A slot holding a matching node returns that cached
// node instead of temp.
5239 for (size_t i = temp->get_hash();
5241 if (leaf_node<T, SAFE_MATH>::caches.nodes.find(i) ==
5242 leaf_node<T, SAFE_MATH>::caches.nodes.end()) {
5244 return temp;
5245 } else if (temp->is_match(leaf_node<T, SAFE_MATH>::caches.nodes[i])) {
5246 return leaf_node<T, SAFE_MATH>::caches.nodes[i];
5247 }
5248 }
// Every path through the loop above is expected to return; reaching this
// point is a logic error.
5249#if defined(__clang__) || defined(__GNUC__)
5251#else
5252 assert(false && "Should never reach.");
5253#endif
5254 }
5255
5256//------------------------------------------------------------------------------
// fma overload with one scalar argument of type L.
// NOTE(review): parameter list and body are not shown here; presumably the
// scalar is the left factor, promoted to a constant<T> node — confirm.
5269//------------------------------------------------------------------------------
5270 template<jit::float_scalar T, jit::float_scalar L, bool SAFE_MATH=false>
5276
5277//------------------------------------------------------------------------------
// fma overload with one scalar argument of type M.
// NOTE(review): parameter list and body are not shown here; presumably the
// scalar is the middle factor, promoted to a constant<T> node — confirm.
5290//------------------------------------------------------------------------------
5291 template<jit::float_scalar T, jit::float_scalar M, bool SAFE_MATH=false>
5297
5298//------------------------------------------------------------------------------
// fma overload with one scalar argument of type R.
// NOTE(review): parameter list and body are not shown here; presumably the
// scalar is the addend, promoted to a constant<T> node — confirm.
5311//------------------------------------------------------------------------------
5312 template<jit::float_scalar T, jit::float_scalar R, bool SAFE_MATH=false>
5318
5319//------------------------------------------------------------------------------
// fma overload for two scalar factors l (type L) and m (type M). Both are cast
// to T, wrapped in constant nodes, and forwarded with the node r to the
// node-based fma.
5333//------------------------------------------------------------------------------
5334 template<jit::float_scalar T, jit::float_scalar L, jit::float_scalar M, bool SAFE_MATH=false>
5336 const M m,
5338 return fma<T, SAFE_MATH> (constant<T, SAFE_MATH> (static_cast<T> (l)),
5339 constant<T, SAFE_MATH> (static_cast<T> (m)), r);
5340 }
5341
5342//------------------------------------------------------------------------------
// fma overload for a node left factor l with a scalar middle factor m
// (type M) and a scalar addend r (type R). Both scalars are cast to T and
// wrapped in constant nodes before forwarding to the node-based fma.
5356//------------------------------------------------------------------------------
5357 template<jit::float_scalar T, jit::float_scalar M, jit::float_scalar R, bool SAFE_MATH=false>
5359 const M m,
5360 const R r) {
5361 return fma<T, SAFE_MATH> (l, constant<T, SAFE_MATH> (static_cast<T> (m)),
5362 constant<T, SAFE_MATH> (static_cast<T> (r)));
5363 }
5364
5365//------------------------------------------------------------------------------
// fma overload for a scalar left factor l (type L), a node middle factor m and
// a scalar addend r (type R). Both scalars are cast to T and wrapped in
// constant nodes before forwarding to the node-based fma.
5379//------------------------------------------------------------------------------
5380 template<jit::float_scalar T, jit::float_scalar L, jit::float_scalar R, bool SAFE_MATH=false>
5383 const R r) {
5384 return fma<T, SAFE_MATH> (constant<T, SAFE_MATH> (static_cast<T> (l)), m,
5385 constant<T, SAFE_MATH> (static_cast<T> (r)));
5386 }
5387
// Convenience type alias for shared fused multiply add (fma) nodes.
5389 template<jit::float_scalar T, bool SAFE_MATH=false>
5390 using shared_fma = std::shared_ptr<fma_node<T, SAFE_MATH>>;
5391
5392//------------------------------------------------------------------------------
// Cast a generic leaf to an fma node. Returns an empty shared pointer when x
// does not point to an fma_node (std::dynamic_pointer_cast semantics).
5400//------------------------------------------------------------------------------
5401 template<jit::float_scalar T, bool SAFE_MATH=false>
5403 return std::dynamic_pointer_cast<fma_node<T, SAFE_MATH>> (x);
5404 }
5405}
5406
5407#endif /* arithmetic_h */
Class representing a generic buffer.
Definition backend.hpp:29
void subtract_row(const buffer< T > &x)
Subtract row operation.
Definition backend.hpp:407
void multiply_row(const buffer< T > &x)
Multiply row operation.
Definition backend.hpp:479
void add_col(const buffer< T > &x)
Add col operation.
Definition backend.hpp:371
void divide_col(const buffer< T > &x)
Divide col operation.
Definition backend.hpp:587
void add_row(const buffer< T > &x)
Add row operation.
Definition backend.hpp:335
void subtract_col(const buffer< T > &x)
Subtract col operation.
Definition backend.hpp:443
void multiply_col(const buffer< T > &x)
Multiply col operation.
Definition backend.hpp:515
void divide_row(const buffer< T > &x)
Divide row operation.
Definition backend.hpp:551
bool is_zero() const
Is every element zero.
Definition backend.hpp:141
An addition node.
Definition arithmetic.hpp:132
virtual backend::buffer< T > evaluate()
Evaluate the results of addition.
Definition arithmetic.hpp:166
virtual shared_leaf< T, SAFE_MATH > to_vizgraph(std::stringstream &stream, jit::register_map &registers)
Convert the node to vizgraph.
Definition arithmetic.hpp:741
virtual shared_leaf< T, SAFE_MATH > df(shared_leaf< T, SAFE_MATH > x)
Transform node to derivative.
Definition arithmetic.hpp:623
virtual bool is_match(shared_leaf< T, SAFE_MATH > x)
Query if the nodes match.
Definition arithmetic.hpp:677
virtual shared_leaf< T, SAFE_MATH > remove_pseudo()
Remove pseudo variable nodes.
Definition arithmetic.hpp:726
virtual shared_leaf< T, SAFE_MATH > compile(std::ostringstream &stream, jit::register_map &registers, jit::register_map &indices, const jit::register_usage &usage)
Compile the node.
Definition arithmetic.hpp:645
add_node(shared_leaf< T, SAFE_MATH > l, shared_leaf< T, SAFE_MATH > r)
Construct an addition node.
Definition arithmetic.hpp:154
virtual shared_leaf< T, SAFE_MATH > reduce()
Reduce an addition node.
Definition arithmetic.hpp:177
virtual void to_latex() const
Convert the node to latex.
Definition arithmetic.hpp:699
Class representing a branch node.
Definition node.hpp:1165
shared_leaf< T, SAFE_MATH > right
Right branch of the tree.
Definition node.hpp:1170
shared_leaf< T, SAFE_MATH > left
Left branch of the tree.
Definition node.hpp:1168
A division node.
Definition arithmetic.hpp:2769
virtual shared_leaf< T, SAFE_MATH > reduce()
Reduce a division node.
Definition arithmetic.hpp:2822
virtual shared_leaf< T, SAFE_MATH > remove_pseudo()
Remove pseudo variable nodes.
Definition arithmetic.hpp:3586
virtual shared_leaf< T, SAFE_MATH > compile(std::ostringstream &stream, jit::register_map &registers, jit::register_map &indices, const jit::register_usage &usage)
Compile the node.
Definition arithmetic.hpp:3508
divide_node(shared_leaf< T, SAFE_MATH > n, shared_leaf< T, SAFE_MATH > d)
Construct a division node.
Definition arithmetic.hpp:2791
virtual bool is_match(shared_leaf< T, SAFE_MATH > x)
Query if the nodes match.
Definition arithmetic.hpp:3556
virtual backend::buffer< T > evaluate()
Evaluate the results of division.
Definition arithmetic.hpp:2803
virtual shared_leaf< T, SAFE_MATH > to_vizgraph(std::stringstream &stream, jit::register_map &registers)
Convert the node to vizgraph.
Definition arithmetic.hpp:3601
virtual void to_latex() const
Convert the node to latex.
Definition arithmetic.hpp:3573
virtual shared_leaf< T, SAFE_MATH > df(shared_leaf< T, SAFE_MATH > x)
Transform node to derivative.
Definition arithmetic.hpp:3485
A fused multiply add node.
Definition arithmetic.hpp:3736
virtual shared_leaf< T, SAFE_MATH > remove_pseudo()
Remove pseudo variable nodes.
Definition arithmetic.hpp:5187
virtual shared_leaf< T, SAFE_MATH > df(shared_leaf< T, SAFE_MATH > x)
Transform node to derivative.
Definition arithmetic.hpp:5051
virtual shared_leaf< T, SAFE_MATH > compile(std::ostringstream &stream, jit::register_map &registers, jit::register_map &indices, const jit::register_usage &usage)
Compile the node.
Definition arithmetic.hpp:5079
virtual backend::buffer< T > evaluate()
Evaluate the results of fused multiply add.
Definition arithmetic.hpp:3812
virtual shared_leaf< T, SAFE_MATH > to_vizgraph(std::stringstream &stream, jit::register_map &registers)
Convert the node to vizgraph.
Definition arithmetic.hpp:5203
fma_node(shared_leaf< T, SAFE_MATH > l, shared_leaf< T, SAFE_MATH > m, shared_leaf< T, SAFE_MATH > r)
Construct a fused multiply add node.
Definition arithmetic.hpp:3798
virtual void to_latex() const
Convert the node to latex.
Definition arithmetic.hpp:5158
virtual shared_leaf< T, SAFE_MATH > reduce()
Reduce a fused multiply add node.
Definition arithmetic.hpp:3831
virtual bool is_match(shared_leaf< T, SAFE_MATH > x)
Query if the nodes match.
Definition arithmetic.hpp:5140
Class representing a node leaf.
Definition node.hpp:364
virtual void endline(std::ostringstream &stream, const jit::register_usage &usage) const final
End a line in the kernel source.
Definition node.hpp:639
virtual backend::buffer< T > evaluate()=0
Evaluate method.
std::map< size_t, std::shared_ptr< leaf_node< T, SAFE_MATH > > > df_cache
Cache derivative terms.
Definition node.hpp:371
virtual bool has_pseudo() const
Query if the node contains pseudo variables.
Definition node.hpp:620
const size_t hash
Hash for node.
Definition node.hpp:367
A multiplication node.
Definition arithmetic.hpp:1720
virtual backend::buffer< T > evaluate()
Evaluate the results of multiplication.
Definition arithmetic.hpp:1859
virtual shared_leaf< T, SAFE_MATH > df(shared_leaf< T, SAFE_MATH > x)
Transform node to derivative.
Definition arithmetic.hpp:2493
virtual shared_leaf< T, SAFE_MATH > remove_pseudo()
Remove pseudo variable nodes.
Definition arithmetic.hpp:2621
virtual shared_leaf< T, SAFE_MATH > reduce()
Reduce an multiplication node.
Definition arithmetic.hpp:1884
multiply_node(shared_leaf< T, SAFE_MATH > l, shared_leaf< T, SAFE_MATH > r)
Construct a multiplication node.
Definition arithmetic.hpp:1848
virtual bool is_match(shared_leaf< T, SAFE_MATH > x)
Query if the nodes match.
Definition arithmetic.hpp:2572
virtual shared_leaf< T, SAFE_MATH > compile(std::ostringstream &stream, jit::register_map &registers, jit::register_map &indices, const jit::register_usage &usage)
Compile the node.
Definition arithmetic.hpp:2516
virtual shared_leaf< T, SAFE_MATH > to_vizgraph(std::stringstream &stream, jit::register_map &registers)
Convert the node to vizgraph.
Definition arithmetic.hpp:2636
virtual void to_latex() const
Convert the node to latex.
Definition arithmetic.hpp:2594
A subtraction node.
Definition arithmetic.hpp:879
virtual shared_leaf< T, SAFE_MATH > compile(std::ostringstream &stream, jit::register_map &registers, jit::register_map &indices, const jit::register_usage &usage)
Compile the node.
Definition arithmetic.hpp:1475
subtract_node(shared_leaf< T, SAFE_MATH > l, shared_leaf< T, SAFE_MATH > r)
Construct a subtraction node.
Definition arithmetic.hpp:901
virtual shared_leaf< T, SAFE_MATH > df(shared_leaf< T, SAFE_MATH > x)
Transform node to derivative.
Definition arithmetic.hpp:1453
virtual shared_leaf< T, SAFE_MATH > remove_pseudo()
Remove pseudo variable nodes.
Definition arithmetic.hpp:1551
virtual bool is_match(shared_leaf< T, SAFE_MATH > x)
Query if the nodes match.
Definition arithmetic.hpp:1507
virtual shared_leaf< T, SAFE_MATH > reduce()
Reduce a subtraction node.
Definition arithmetic.hpp:924
virtual shared_leaf< T, SAFE_MATH > to_vizgraph(std::stringstream &stream, jit::register_map &registers)
Convert the node to vizgraph.
Definition arithmetic.hpp:1566
virtual backend::buffer< T > evaluate()
Evaluate the results of subtraction.
Definition arithmetic.hpp:913
virtual void to_latex() const
Convert the node to latex.
Definition arithmetic.hpp:1524
Class representing a triple branch node.
Definition node.hpp:1289
shared_leaf< T, SAFE_MATH > middle
Middle branch of the tree.
Definition node.hpp:1292
Complex scalar concept.
Definition register.hpp:24
subroutine assert(test, message)
Assert check.
Definition f_binding_test.f90:38
buffer< T > fma(buffer< T > &a, buffer< T > &b, buffer< T > &c)
Fused multiply add operation.
Definition backend.hpp:950
Name space for graph nodes.
Definition arithmetic.hpp:13
shared_piecewise_2D< T, SAFE_MATH > piecewise_2D_cast(shared_leaf< T, SAFE_MATH > x)
Cast to a piecewise 2D node.
Definition piecewise.hpp:1406
shared_leaf< T, SAFE_MATH > operator*(shared_leaf< T, SAFE_MATH > l, shared_leaf< T, SAFE_MATH > r)
Build multiply node from two leaves.
Definition arithmetic.hpp:2698
std::shared_ptr< add_node< T, SAFE_MATH > > shared_add
Convenience type alias for shared add nodes.
Definition arithmetic.hpp:851
shared_pow< T, SAFE_MATH > pow_cast(shared_leaf< T, SAFE_MATH > x)
Cast to a power node.
Definition math.hpp:1416
shared_subtract< T, SAFE_MATH > subtract_cast(shared_leaf< T, SAFE_MATH > x)
Cast to a subtract node.
Definition arithmetic.hpp:1706
constexpr shared_leaf< T, SAFE_MATH > zero()
Forward declare for zero.
Definition node.hpp:986
shared_leaf< T, SAFE_MATH > operator/(shared_leaf< T, SAFE_MATH > l, shared_leaf< T, SAFE_MATH > r)
Build divide node from two leaves.
Definition arithmetic.hpp:3663
shared_leaf< T, SAFE_MATH > multiply(shared_leaf< T, SAFE_MATH > l, shared_leaf< T, SAFE_MATH > r)
Build multiply node from two leaves.
Definition arithmetic.hpp:2664
shared_add< T, SAFE_MATH > add_cast(shared_leaf< T, SAFE_MATH > x)
Cast to an add node.
Definition arithmetic.hpp:863
bool is_greater_exponent(shared_leaf< T, SAFE_MATH > a, shared_leaf< T, SAFE_MATH > b)
Check if the exponent is greater than the other.
Definition arithmetic.hpp:111
shared_leaf< T, SAFE_MATH > operator+(shared_leaf< T, SAFE_MATH > l, shared_leaf< T, SAFE_MATH > r)
Build add node from two leaves.
Definition arithmetic.hpp:806
std::shared_ptr< divide_node< T, SAFE_MATH > > shared_divide
Convenience type alias for shared divide nodes.
Definition arithmetic.hpp:3708
shared_leaf< T, SAFE_MATH > pow(shared_leaf< T, SAFE_MATH > l, shared_leaf< T, SAFE_MATH > r)
Build power node.
Definition math.hpp:1352
shared_sine< T, SAFE_MATH > sin_cast(shared_leaf< T, SAFE_MATH > x)
Cast to a sine node.
Definition trigonometry.hpp:262
shared_divide< T, SAFE_MATH > divide_cast(shared_leaf< T, SAFE_MATH > x)
Cast to a divide node.
Definition arithmetic.hpp:3720
shared_piecewise_1D< T, SAFE_MATH > piecewise_1D_cast(shared_leaf< T, SAFE_MATH > x)
Cast to a piecewise 1D node.
Definition piecewise.hpp:629
shared_leaf< T, SAFE_MATH > operator-(shared_leaf< T, SAFE_MATH > l, shared_leaf< T, SAFE_MATH > r)
Build subtract node from two leaves.
Definition arithmetic.hpp:1630
shared_leaf< T, SAFE_MATH > divide(shared_leaf< T, SAFE_MATH > l, shared_leaf< T, SAFE_MATH > r)
Build divide node from two leaves.
Definition arithmetic.hpp:3629
shared_multiply< T, SAFE_MATH > multiply_cast(shared_leaf< T, SAFE_MATH > x)
Cast to a multiply node.
Definition arithmetic.hpp:2755
shared_leaf< T, SAFE_MATH > exp(shared_leaf< T, SAFE_MATH > x)
Define exp convenience function.
Definition math.hpp:544
shared_exp< T, SAFE_MATH > exp_cast(shared_leaf< T, SAFE_MATH > x)
Cast to an exp node.
Definition math.hpp:578
bool is_constant_promotable(shared_leaf< T, SAFE_MATH > a, shared_leaf< T, SAFE_MATH > b)
Check if the constants are promotable.
Definition arithmetic.hpp:53
shared_leaf< T, SAFE_MATH > fma(shared_leaf< T, SAFE_MATH > l, shared_leaf< T, SAFE_MATH > m, shared_leaf< T, SAFE_MATH > r)
Build fused multiply add node.
Definition arithmetic.hpp:5234
shared_constant< T, SAFE_MATH > constant_cast(shared_leaf< T, SAFE_MATH > x)
Cast to a constant node.
Definition node.hpp:1034
bool is_variable_combinable(shared_leaf< T, SAFE_MATH > a, shared_leaf< T, SAFE_MATH > b)
Check if the variable is combinable.
Definition arithmetic.hpp:75
constexpr T i
Convenience type for imaginary constant.
Definition node.hpp:1018
shared_leaf< T, SAFE_MATH > subtract(shared_leaf< T, SAFE_MATH > l, shared_leaf< T, SAFE_MATH > r)
Build subtract node from two leaves.
Definition arithmetic.hpp:1595
std::shared_ptr< fma_node< T, SAFE_MATH > > shared_fma
Convenience type alias for shared fused multiply add nodes.
Definition arithmetic.hpp:5390
std::shared_ptr< leaf_node< T, SAFE_MATH > > shared_leaf
Convenience type alias for shared leaf nodes.
Definition node.hpp:676
bool is_variable_promotable(shared_leaf< T, SAFE_MATH > a, shared_leaf< T, SAFE_MATH > b)
Check if the variable is promotable.
Definition arithmetic.hpp:91
std::shared_ptr< multiply_node< T, SAFE_MATH > > shared_multiply
Convenience type alias for shared multiply nodes.
Definition arithmetic.hpp:2743
shared_leaf< T, SAFE_MATH > add(shared_leaf< T, SAFE_MATH > l, shared_leaf< T, SAFE_MATH > r)
Build add node from two leaves.
Definition arithmetic.hpp:772
shared_fma< T, SAFE_MATH > fma_cast(shared_leaf< T, SAFE_MATH > x)
Cast to an fma node.
Definition arithmetic.hpp:5402
shared_cosine< T, SAFE_MATH > cos_cast(shared_leaf< T, SAFE_MATH > x)
Cast to a cosine node.
Definition trigonometry.hpp:520
bool is_constant_combinable(shared_leaf< T, SAFE_MATH > a, shared_leaf< T, SAFE_MATH > b)
Check if nodes are constant combinable.
Definition arithmetic.hpp:25
std::shared_ptr< subtract_node< T, SAFE_MATH > > shared_subtract
Convenience type alias for shared subtract nodes.
Definition arithmetic.hpp:1694
std::string format_to_string(const T value)
Convert a value to a string while avoiding locale.
Definition register.hpp:212
std::map< void *, size_t > register_usage
Type alias for counting register usage.
Definition register.hpp:259
std::map< void *, std::string > register_map
Type alias for mapping node pointers to register names.
Definition register.hpp:257
std::string to_string(const char prefix, const NODE *pointer)
Convert a graph::leaf_node pointer to a string.
Definition register.hpp:246
Base nodes of graph computation framework.
void piecewise_1D()
Tests for 1D piecewise nodes.
Definition piecewise_test.cpp:80
void piecewise_2D()
Tests for 2D piecewise nodes.
Definition piecewise_test.cpp:306