Graph Framework
Loading...
Searching...
No Matches
arithmetic.hpp
Go to the documentation of this file.
1//------------------------------------------------------------------------------
6//------------------------------------------------------------------------------
7
8#ifndef arithmetic_h
9#define arithmetic_h
10
11#include "node.hpp"
12
13namespace graph {
14//------------------------------------------------------------------------------
23//------------------------------------------------------------------------------
24 template<jit::float_scalar T, bool SAFE_MATH=false>
27 if (a->is_constant() && b->is_constant()) {
28 auto a1 = piecewise_1D_cast(a);
29 auto a2 = piecewise_2D_cast(a);
30 auto b2 = piecewise_2D_cast(b);
31
32 return constant_cast(a).get() ||
33 constant_cast(b).get() ||
34 (a1.get() && a1->is_arg_match(b)) ||
35 (a2.get() && a2->is_arg_match(b)) ||
36 (a2.get() && (a2->is_row_match(b) || a2->is_col_match(b))) ||
37 (b2.get() && (b2->is_row_match(a) || b2->is_col_match(a)));
38 }
39 return false;
40 }
41
42//------------------------------------------------------------------------------
51//------------------------------------------------------------------------------
52 template<jit::float_scalar T, bool SAFE_MATH=false>
55 auto b1 = piecewise_1D_cast(b);
56 auto b2 = piecewise_2D_cast(b);
57
58 return a->is_constant() &&
59 (!b->is_constant() ||
60 (constant_cast(a).get() && (b1.get() || b2.get())) ||
61 (piecewise_1D_cast(a).get() && b2.get()));
62 }
63
64//------------------------------------------------------------------------------
73//------------------------------------------------------------------------------
74 template<jit::float_scalar T, bool SAFE_MATH=false>
77 return a->is_power_base_match(b);
78 }
79
80//------------------------------------------------------------------------------
89//------------------------------------------------------------------------------
90 template<jit::float_scalar T, bool SAFE_MATH=false>
93 return !b->is_constant() &&
94 (a->is_all_variables() &&
95 (!b->is_all_variables() ||
96 (b->is_all_variables() &&
97 a->get_complexity() < b->get_complexity())));
98 }
99
100//------------------------------------------------------------------------------
109//------------------------------------------------------------------------------
110 template<jit::float_scalar T, bool SAFE_MATH=false>
113 auto ae = constant_cast(a->get_power_exponent());
114 auto be = constant_cast(b->get_power_exponent());
115
116 return ae.get() && be.get() &&
117 std::abs(ae->evaluate().at(0)) > std::abs(be->evaluate().at(0));
118 }
119
120//******************************************************************************
121// Add node.
122//******************************************************************************
123//------------------------------------------------------------------------------
130//------------------------------------------------------------------------------
131 template<jit::float_scalar T, bool SAFE_MATH=false>
132 class add_node final : public branch_node<T, SAFE_MATH> {
133 private:
134//------------------------------------------------------------------------------
140//------------------------------------------------------------------------------
141 static std::string to_string(leaf_node<T, SAFE_MATH> *l,
143 return jit::format_to_string(reinterpret_cast<size_t> (l)) + "+" +
144 jit::format_to_string(reinterpret_cast<size_t> (r));
145 }
146
147 public:
148//------------------------------------------------------------------------------
153//------------------------------------------------------------------------------
158
159//------------------------------------------------------------------------------
165//------------------------------------------------------------------------------
167 backend::buffer<T> l_result = this->left->evaluate();
168 backend::buffer<T> r_result = this->right->evaluate();
169 return l_result + r_result;
170 }
171
172//------------------------------------------------------------------------------
176//------------------------------------------------------------------------------
178// Constant reductions.
179 auto l = constant_cast(this->left);
180 auto r = constant_cast(this->right);
181
182 if (l.get() && l->is(0)) {
183 return this->right;
184 } else if (r.get() && r->is(0)) {
185 return this->left;
186 } else if (l.get() && r.get()) {
187 return constant<T, SAFE_MATH> (this->evaluate());
188 } else if (r.get() && !l.get()) {
189 return this->right + this->left;
190 }
191
192 auto pl1 = piecewise_1D_cast(this->left);
193 auto pr1 = piecewise_1D_cast(this->right);
194
195 if (pl1.get() && (r.get() || pl1->is_arg_match(this->right))) {
196 return piecewise_1D(this->evaluate(), pl1->get_arg(),
197 pl1->get_scale(), pl1->get_offset());
198 } else if (pr1.get() && (l.get() || pr1->is_arg_match(this->left))) {
199 return piecewise_1D(this->evaluate(), pr1->get_arg(),
200 pr1->get_scale(), pr1->get_offset());
201 }
202
203 auto pl2 = piecewise_2D_cast(this->left);
204 auto pr2 = piecewise_2D_cast(this->right);
205
206 if (pl2.get() && (r.get() || pl2->is_arg_match(this->right))) {
207 return piecewise_2D(this->evaluate(),
208 pl2->get_num_columns(),
209 pl2->get_left(), pl2->get_x_scale(), pl2->get_x_offset(),
210 pl2->get_right(), pl2->get_y_scale(), pl2->get_y_offset());
211 } else if (pr2.get() && (l.get() || pr2->is_arg_match(this->left))) {
212 return piecewise_2D(this->evaluate(),
213 pr2->get_num_columns(),
214 pr2->get_left(), pr2->get_x_scale(), pr2->get_x_offset(),
215 pr2->get_right(), pr2->get_y_scale(), pr2->get_y_offset());
216 }
217
218// Combine 2D and 1D piecewise constants if a row or column matches.
219 if (pr2.get() && pr2->is_row_match(this->left)) {
220 backend::buffer<T> result = pl1->evaluate();
221 result.add_row(pr2->evaluate());
222 return piecewise_2D(result,
223 pr2->get_num_columns(),
224 pr2->get_left(), pr2->get_x_scale(), pr2->get_x_offset(),
225 pr2->get_right(), pr2->get_y_scale(), pr2->get_y_offset());
226 } else if (pr2.get() && pr2->is_col_match(this->left)) {
227 backend::buffer<T> result = pl1->evaluate();
228 result.add_col(pr2->evaluate());
229 return piecewise_2D(result,
230 pr2->get_num_columns(),
231 pr2->get_left(), pr2->get_x_scale(), pr2->get_x_offset(),
232 pr2->get_right(), pr2->get_y_scale(), pr2->get_y_offset());
233 } else if (pl2.get() && pl2->is_row_match(this->right)) {
234 backend::buffer<T> result = pl2->evaluate();
235 result.add_row(pr1->evaluate());
236 return piecewise_2D(result,
237 pl2->get_num_columns(),
238 pl2->get_left(), pl2->get_x_scale(), pl2->get_x_offset(),
239 pl2->get_right(), pl2->get_y_scale(), pl2->get_y_offset());
240 } else if (pl2.get() && pl2->is_col_match(this->right)) {
241 backend::buffer<T> result = pl2->evaluate();
242 result.add_col(pr1->evaluate());
243 return piecewise_2D(result,
244 pl2->get_num_columns(),
245 pl2->get_left(), pl2->get_x_scale(), pl2->get_x_offset(),
246 pl2->get_right(), pl2->get_y_scale(), pl2->get_y_offset());
247 }
248
249// Identity reductions.
250 if (this->left->is_match(this->right)) {
251 return 2.0*this->left;
252 }
253
254// Common factor reduction. If the left and right are both multiply nodes check
255// for a common factor. So you can change a*b + a*c -> a*(b + c).
256 auto lm = multiply_cast(this->left);
257 auto rm = multiply_cast(this->right);
258
259// v1 + -c*v2 -> v1 - c*v2 (right operand rm carries the negative constant)
260// -c*v1 + v2 -> v2 - c*v1 (left operand lm carries the negative constant)
261 if (rm.get() &&
262 rm->get_left()->is_constant() &&
263 rm->get_left()->evaluate().is_negative()) {
264 return this->left - (-this->right);
265 } else if (lm.get() &&
266 lm->get_left()->is_constant() &&
267 lm->get_left()->evaluate().is_negative()) {
268 return this->right - (-this->left);
269 }
270
271// a*b + c -> fma(a,b,c)
272// a + b*c -> fma(b,c,a)
273 if (lm.get()) {
274 return fma(lm->get_left(), lm->get_right(), this->right);
275 } else if (rm.get()) {
276 return fma(rm->get_left(), rm->get_right(), this->left);
277 }
278
279// Common denominator reduction. If the left and right are both divide nodes check
280// for a common denominator. So you can change a/b + c/b -> (a + c)/b.
281 auto ld = divide_cast(this->left);
282 auto rd = divide_cast(this->right);
283
284// c is a constant.
285// a + -c/b -> a - c/b
286// a + (-c*d)/b -> a - (c*d)/b
287// -c/a + b -> b - c/a
288// (-c*d)/a + b -> b - (c*d)/a
289 if (rd.get()) {
290 auto rdlm = multiply_cast(rd->get_left());
291 if ((rd->get_left()->is_constant() &&
292 rd->get_left()->evaluate().is_negative()) ||
293 (rdlm.get() &&
294 (rdlm->get_left()->is_constant() &&
295 rdlm->get_left()->evaluate().is_negative()))) {
296 return this->left - (-rd->get_left())/rd->get_right();
297 }
298 } else if (ld.get()) {
299 auto ldlm = multiply_cast(ld->get_left());
300 if ((ld->get_left()->is_constant() &&
301 ld->get_left()->evaluate().is_negative()) ||
302 (ldlm.get() &&
303 (ldlm->get_left()->is_constant() &&
304 ldlm->get_left()->evaluate().is_negative()))) {
305 return this->right - (-ld->get_left())/ld->get_right();
306 }
307 }
308
309 if (ld.get() && rd.get()) {
310 if (ld->get_right()->is_match(rd->get_right())) {
311 return (ld->get_left() + rd->get_left())/ld->get_right();
312 }
313
314 auto ldlm = multiply_cast(ld->get_left());
315 auto rdlm = multiply_cast(rd->get_left());
316// a/b + c*a/d -> (1/b + c/d)*a
317// a/b + a*c/d -> (1/b + c/d)*a
318// c*a/b + a/d -> (c/b + 1/d)*a
319// a*c/b + a/d -> (c/b + 1/d)*a
320 if (rdlm.get()) {
321 if (ld->get_left()->is_match(rdlm->get_left())) {
322 return (1.0/ld->get_right() +
323 rdlm->get_right()/rd->get_right())*rdlm->get_left();
324 } else if (ld->get_left()->is_match(rdlm->get_right())) {
325 return (1.0/ld->get_right() +
326 rdlm->get_left()/rd->get_right())*rdlm->get_right();
327 }
328 } else if (ldlm.get()) {
329 if (rd->get_left()->is_match(ldlm->get_left())) {
330 return (ldlm->get_right()/ld->get_right() +
331 1.0/rd->get_right())*ldlm->get_left();
332 } else if (rd->get_left()->is_match(ldlm->get_right())) {
333 return (ldlm->get_left()/ld->get_right() +
334 1.0/rd->get_right())*ldlm->get_right();
335 }
336 }
337
338// c1*a/b + c2*a/d = c3*(a/b + c4*a/d)
339// a*b/c + d*b/e -> (a/c + d/e)*b
340// Make sure we prevent combining constants when we just need to factor out a
341// common term.
342// c1*a/b + c2*a/d -> (c1/b + c2/d)*a
343 if (ldlm.get() && rdlm.get()) {
344 if (is_constant_combineable(ldlm->get_left(),
345 rdlm->get_left()) &&
346 !ldlm->get_right()->is_match(rdlm->get_right())) {
347 return (ldlm->get_right()/ld->get_right() +
348 rdlm->get_left()/ldlm->get_left() *
349 rdlm->get_right()/rd->get_right())*ldlm->get_left();
350 }
351
352 if (ldlm->get_right()->is_match(rdlm->get_right())) {
353 return (ldlm->get_left()/ld->get_right() +
354 rdlm->get_left()/rd->get_right())*ldlm->get_right();
355 } else if (ldlm->get_right()->is_match(rdlm->get_left())) {
356 return (ldlm->get_left()/ld->get_right() +
357 rdlm->get_right()/rd->get_right())*ldlm->get_right();
358 } else if (ldlm->get_left()->is_match(rdlm->get_right())) {
359 return (ldlm->get_right()/ld->get_right() +
360 rdlm->get_left()/rd->get_right())*ldlm->get_left();
361 } else if (ldlm->get_left()->is_match(rdlm->get_left())) {
362 return (ldlm->get_right()/ld->get_right() +
363 rdlm->get_right()/rd->get_right())*ldlm->get_left();
364 }
365 }
366
367// (a/(c*b) + d/(e*c)) -> (a/b + d/e)/c
368// (a/(b*c) + d/(e*c)) -> (a/b + d/e)/c
369// (a/(c*b) + d/(c*e)) -> (a/b + d/e)/c
370// (a/(b*c) + d/(c*e)) -> (a/b + d/e)/c
371 auto ldrm = multiply_cast(ld->get_right());
372 auto rdrm = multiply_cast(rd->get_right());
373 if (ldrm.get() && rdrm.get()) {
374 if (ldrm->get_right()->is_match(rdrm->get_right())) {
375 return (ld->get_left()/ldrm->get_left() +
376 rd->get_left()/rdrm->get_left())/ldrm->get_right();
377 } else if (ldrm->get_right()->is_match(rdrm->get_left())) {
378 return (ld->get_left()/ldrm->get_left() +
379 rd->get_left()/rdrm->get_right())/ldrm->get_right();
380 } else if (ldrm->get_left()->is_match(rdrm->get_right())) {
381 return (ld->get_left()/ldrm->get_right() +
382 rd->get_left()/rdrm->get_left())/ldrm->get_left();
383 } else if (ldrm->get_left()->is_match(rdrm->get_left())) {
384 return (ld->get_left()/ldrm->get_right() +
385 rd->get_left()/rdrm->get_right())/ldrm->get_left();
386 }
387 }
388
389// a/b + c/(b*d) -> (a*d + c)/(b*d)
390// a/b + c/(d*b) -> (a*d + c)/(d*b)
391// a/(b*d) + c/b -> (c*d + a)/(b*d)
392// a/(d*b) + c/b -> (c*d + a)/(d*b)
393 if (rdrm.get()) {
394 if (ld->get_right()->is_match(rdrm->get_left())) {
395 return fma(ld->get_left(),
396 rdrm->get_right(),
397 rd->get_left()) /
398 rd->get_right();
399 } else if (ld->get_right()->is_match(rdrm->get_right())) {
400 return fma(ld->get_left(),
401 rdrm->get_left(),
402 rd->get_left()) /
403 rd->get_right();
404 }
405 } else if (ldrm.get()) {
406 if (rd->get_right()->is_match(ldrm->get_left())) {
407 return fma(rd->get_left(),
408 ldrm->get_right(),
409 ld->get_left()) /
410 ld->get_right();
411 } else if (rd->get_right()->is_match(ldrm->get_right())) {
412 return fma(rd->get_left(),
413 ldrm->get_left(),
414 ld->get_left()) /
415 ld->get_right();
416 }
417 }
418 }
419
420// Chained addition reductions.
421// a + (a + b) = fma(2,a,b)
422// a + (b + a) = fma(2,a,b)
423// (a + b) + a = fma(2,a,b)
424// (b + a) + a = fma(2,a,b)
425 auto la = add_cast(this->left);
426 if (la.get()) {
427 if (this->right->is_match(la->get_left())) {
428 return fma(2.0, this->right, la->get_right());
429 } else if (this->right->is_match(la->get_right())) {
430 return fma(2.0, this->right, la->get_left());
431 }
432 }
433 auto ra = add_cast(this->right);
434 if (ra.get()) {
435 if (this->left->is_match(ra->get_left())) {
436 return fma(2.0, this->left, ra->get_right());
437 } else if (this->left->is_match(ra->get_right())) {
438 return fma(2.0, this->left, ra->get_left());
439 }
440 }
441
442// Move cases like
443// (c1 + c2/x) + c3/y -> c1 + (c2/x + c3/y)
444// (c1 - c2/x) + c3/y -> c1 + (c3/y - c2/x)
445// in case of common denominators.
446 if (rd.get()) {
447 if (la.get() && divide_cast(la->get_right()).get()) {
448 return la->get_left() + (la->get_right() + this->right);
449 }
450
451 auto ls = subtract_cast(this->left);
452 if (ls.get() && divide_cast(ls->get_right()).get()) {
453 return ls->get_left() + (this->right - ls->get_right());
454 }
455 }
456
457 auto lfma = fma_cast(this->left);
458 auto rfma = fma_cast(this->right);
459// Try the pairwise merge first: the single fma cases below always return.
460// fma(b,a,d) + fma(c,a,e) -> fma(a,b + c, d + e)
461// fma(a,b,d) + fma(c,a,e) -> fma(a,b + c, d + e)
462// fma(b,a,d) + fma(a,c,e) -> fma(a,b + c, d + e)
463// fma(a,b,d) + fma(a,c,e) -> fma(a,b + c, d + e)
464 if (lfma.get() && rfma.get()) {
465 if (lfma->get_middle()->is_match(rfma->get_middle())) {
466 return fma(lfma->get_middle(),
467 lfma->get_left() + rfma->get_left(),
468 lfma->get_right() + rfma->get_right());
469 } else if (lfma->get_left()->is_match(rfma->get_middle())) {
470 return fma(lfma->get_left(),
471 lfma->get_middle() + rfma->get_left(),
472 lfma->get_right() + rfma->get_right());
473 } else if (lfma->get_middle()->is_match(rfma->get_left())) {
474 return fma(lfma->get_middle(),
475 lfma->get_left() + rfma->get_middle(),
476 lfma->get_right() + rfma->get_right());
477 } else if (lfma->get_left()->is_match(rfma->get_left())) {
478 return fma(lfma->get_left(),
479 lfma->get_middle() + rfma->get_middle(),
480 lfma->get_right() + rfma->get_right());
481 }
482 }
483 if (lfma.get()) {
484// fma(c,d,e) + a -> fma(c,d,e + a)
485 return fma(lfma->get_left(),
486 lfma->get_middle(),
487 lfma->get_right() + this->right);
488 } else if (rfma.get()) {
489// a + fma(c,d,e) -> fma(c,d,a + e)
490 return fma(rfma->get_left(),
491 rfma->get_middle(),
492 this->left + rfma->get_right());
493 }
494
495 auto pl = pow_cast(this->left);
496 auto pr = pow_cast(this->right);
497
498// (a*b)^c + (a*d)^c -> a^c*(b^c + d^c)
499// (b*a)^c + (a*d)^c -> a^c*(b^c + d^c)
500// (a*b)^c + (d*a)^c -> a^c*(b^c + d^c)
501// (b*a)^c + (d*a)^c -> a^c*(b^c + d^c)
502 if (pl.get() && pr.get() &&
503 pl->get_right()->is_match(pr->get_right())) {
504 auto plm = multiply_cast(pl->get_left());
505 auto prm = multiply_cast(pr->get_left());
506 if (plm.get() && prm.get()) {
507 if (plm->get_left()->is_match(prm->get_left())) {
508 return pow(plm->get_left(), pl->get_right())*
509 (pow(plm->get_right(), pl->get_right()) +
510 pow(prm->get_right(), pl->get_right()));
511 } else if (plm->get_left()->is_match(prm->get_right())) {
512 return pow(plm->get_left(), pl->get_right())*
513 (pow(plm->get_right(), pl->get_right()) +
514 pow(prm->get_left(), pl->get_right()));
515 } else if (plm->get_right()->is_match(prm->get_left())) {
516 return pow(plm->get_right(), pl->get_right())*
517 (pow(plm->get_left(), pl->get_right()) +
518 pow(prm->get_right(), pl->get_right()));
519 } else if (plm->get_right()->is_match(prm->get_right())) {
520 return pow(plm->get_right(), pl->get_right())*
521 (pow(plm->get_left(), pl->get_right()) +
522 pow(prm->get_left(), pl->get_right()));
523 }
524 }
525
526// cos(x)^2 + sin(x)^2 -> 1
527// sin(x)^2 + cos(x)^2 -> 1
528 auto plrc = constant_cast(pl->get_right());
529 if (plrc.get() && plrc->is(static_cast<T> (2.0))) {
530 auto pls = sin_cast(pl->get_left());
531 auto prc = cos_cast(pr->get_left());
532 auto plc = cos_cast(pl->get_left());
533 auto prs = sin_cast(pr->get_left());
534 if ((pls.get() && prc.get() && pls->get_arg()->is_match(prc->get_arg())) ||
535 (plc.get() && prs.get() && plc->get_arg()->is_match(prs->get_arg()))) {
536 return one<T, SAFE_MATH> ();
537 }
538 }
539 }
540
541// (a/y)^e + b/y^e -> (a^e + b)/(y^e)
542// b/y^e + (a/y)^e -> (b + a^e)/(y^e)
543// (a/y)^e + (b/y)^e -> (a^e + b^e)/(y^e)
544 if (pl.get() && rd.get()) {
545 auto rdp = pow_cast(rd->get_right());
546 if (rdp.get() && pl->get_right()->is_match(rdp->get_right())) {
547 auto plld = divide_cast(pl->get_left());
548 if (plld.get() &&
549 rdp->get_left()->is_match(plld->get_right())) {
550 return (pow(plld->get_left(), pl->get_right()) +
551 rd->get_left()) /
552 pow(rdp->get_left(), pl->get_right());
553 }
554 }
555 } else if (pr.get() && ld.get()) {
556 auto ldp = pow_cast(ld->get_right());
557 if (ldp.get() && pr->get_right()->is_match(ldp->get_right())) {
558 auto prld = divide_cast(pr->get_left());
559 if (prld.get() &&
560 ldp->get_left()->is_match(prld->get_right())) {
561 return (pow(prld->get_left(), pr->get_right()) +
562 ld->get_left()) /
563 pow(ldp->get_left(), pr->get_right());
564 }
565 }
566 } else if (pl.get() && pr.get()) {
567 if (pl->get_right()->is_match(pr->get_right())) {
568 auto pld = divide_cast(pl->get_left());
569 auto prd = divide_cast(pr->get_left());
570 if (pld.get() && prd.get() &&
571 pld->get_right()->is_match(prd->get_right())) {
572 return (pow(pld->get_left(), pl->get_right()) +
573 pow(prd->get_left(), pl->get_right())) /
574 pow(pld->get_right(), pl->get_right());
575 }
576 }
577 }
578
579 return this->shared_from_this();
580 }
581
582//------------------------------------------------------------------------------
589//------------------------------------------------------------------------------
592 if (this->is_match(x)) {
593 return one<T, SAFE_MATH> ();
594 }
595
596 const size_t hash = reinterpret_cast<size_t> (x.get());
597 if (this->df_cache.find(hash) == this->df_cache.end()) {
598 this->df_cache[hash] = this->left->df(x) + this->right->df(x);
599 }
600 return this->df_cache[hash];
601 }
602
603//------------------------------------------------------------------------------
611//------------------------------------------------------------------------------
613 compile(std::ostringstream &stream,
614 jit::register_map &registers,
616 const jit::register_usage &usage) {
617 if (registers.find(this) == registers.end()) {
618 shared_leaf<T, SAFE_MATH> l = this->left->compile(stream,
619 registers,
620 indices,
621 usage);
622 shared_leaf<T, SAFE_MATH> r = this->right->compile(stream,
623 registers,
624 indices,
625 usage);
626
627 registers[this] = jit::to_string('r', this);
628 stream << " const ";
629 jit::add_type<T> (stream);
630 stream << " " << registers[this] << " = "
631 << registers[l.get()] << " + "
632 << registers[r.get()];
633 this->endline(stream, usage);
634 }
635
636 return this->shared_from_this();
637 }
638
639//------------------------------------------------------------------------------
644//------------------------------------------------------------------------------
646 if (this == x.get()) {
647 return true;
648 }
649
650 auto x_cast = add_cast(x);
651 if (x_cast.get()) {
652// Addition is commutative.
653 if ((this->left->is_match(x_cast->get_left()) &&
654 this->right->is_match(x_cast->get_right())) ||
655 (this->right->is_match(x_cast->get_left()) &&
656 this->left->is_match(x_cast->get_right()))) {
657 return true;
658 }
659 }
660
661 return false;
662 }
663
664//------------------------------------------------------------------------------
///  @brief Write this sum to std::cout as LaTeX.
///
///  The left operand, a "+" sign and the right operand are printed in that
///  order. An operand that is itself an add or subtract node is enclosed in
///  LaTeX left/right parentheses.
666//------------------------------------------------------------------------------
667 virtual void to_latex() const {
668 bool l_brackets = add_cast(this->left).get() ||
669 subtract_cast(this->left).get();
670 bool r_brackets = add_cast(this->right).get() ||
671 subtract_cast(this->right).get();
672 if (l_brackets) {
673 std::cout << "\\left(";
674 }
675 this->left->to_latex();
676 if (l_brackets) {
677 std::cout << "\\right)";
678 }
679 std::cout << "+";
680 if (r_brackets) {
681 std::cout << "\\left(";
682 }
683 this->right->to_latex();
684 if (r_brackets) {
685 std::cout << "\\right)";
686 }
687 }
688
689//------------------------------------------------------------------------------
693//------------------------------------------------------------------------------
695 if (this->has_pseudo()) {
696 return this->left->remove_pseudo() +
697 this->right->remove_pseudo();
698 }
699 return this->shared_from_this();
700 }
701
702//------------------------------------------------------------------------------
///  @brief Emit this node and its operands as graph statements.
///
///  If this node has no entry in @p registers yet, a name is generated and
///  stored for it, a node statement labelled "+" is written, and each operand
///  is emitted in turn followed by a "--" edge from this node to that operand.
///  A node already present in @p registers produces no output.
///
///  @param[in,out] stream    Stream that receives the graph statements.
///  @param[in,out] registers Map from node pointer to the generated name.
///  @returns This node, via shared_from_this().
708//------------------------------------------------------------------------------
709 virtual shared_leaf<T, SAFE_MATH> to_vizgraph(std::stringstream &stream,
710 jit::register_map &registers) {
711 if (registers.find(this) == registers.end()) {
712 const std::string name = jit::to_string('r', this);
713 registers[this] = name;
714 stream << " " << name
715 << " [label = \"+\", shape = oval, style = filled, fillcolor = blue, fontcolor = white];" << std::endl;
716
717 auto l = this->left->to_vizgraph(stream, registers);
718 stream << " " << name << " -- " << registers[l.get()] << ";" << std::endl;
719 auto r = this->right->to_vizgraph(stream, registers);
720 stream << " " << name << " -- " << registers[r.get()] << ";" << std::endl;
721 }
722
723 return this->shared_from_this();
724 }
725 };
726
727//------------------------------------------------------------------------------
738//------------------------------------------------------------------------------
739 template<jit::float_scalar T, bool SAFE_MATH=false>
742 auto temp = std::make_shared<add_node<T, SAFE_MATH>> (l, r)->reduce();
743// Test for hash collisions.
744 for (size_t i = temp->get_hash();
746 if (leaf_node<T, SAFE_MATH>::caches.nodes.find(i) ==
749 return temp;
750 } else if (temp->is_match(leaf_node<T, SAFE_MATH>::caches.nodes[i])) {
752 }
753 }
754#if defined(__clang__) || defined(__GNUC__)
756#else
757 assert(false && "Should never reach.");
758#endif
759 }
760
761//------------------------------------------------------------------------------
772//------------------------------------------------------------------------------
773 template<jit::float_scalar T, bool SAFE_MATH=false>
778
779//------------------------------------------------------------------------------
791//------------------------------------------------------------------------------
792 template<jit::float_scalar T, jit::float_scalar L, bool SAFE_MATH=false>
797
798//------------------------------------------------------------------------------
810//------------------------------------------------------------------------------
811 template<jit::float_scalar T, jit::float_scalar R, bool SAFE_MATH=false>
816
///  Convenience alias for a std::shared_ptr to an add_node.
818 template<jit::float_scalar T, bool SAFE_MATH=false>
819 using shared_add = std::shared_ptr<add_node<T, SAFE_MATH>>;
820
821//------------------------------------------------------------------------------
829//------------------------------------------------------------------------------
830 template<jit::float_scalar T, bool SAFE_MATH=false>
832 return std::dynamic_pointer_cast<add_node<T, SAFE_MATH>> (x);
833 }
834
835//******************************************************************************
836// Subtract node.
837//******************************************************************************
838//------------------------------------------------------------------------------
845//------------------------------------------------------------------------------
846 template<jit::float_scalar T, bool SAFE_MATH=false>
847 class subtract_node final : public branch_node<T, SAFE_MATH> {
848 private:
849//------------------------------------------------------------------------------
855//------------------------------------------------------------------------------
856 static std::string to_string(leaf_node<T, SAFE_MATH> *l,
858 return jit::format_to_string(reinterpret_cast<size_t> (l)) + "-" +
859 jit::format_to_string(reinterpret_cast<size_t> (r));
860 }
861
862 public:
863//------------------------------------------------------------------------------
868//------------------------------------------------------------------------------
873
874//------------------------------------------------------------------------------
880//------------------------------------------------------------------------------
882 backend::buffer<T> l_result = this->left->evaluate();
883 backend::buffer<T> r_result = this->right->evaluate();
884 return l_result - r_result;
885 }
886
887//------------------------------------------------------------------------------
891//------------------------------------------------------------------------------
893// Identity reductions.
894 auto l = constant_cast(this->left);
895 if (this->left->is_match(this->right)) {
896 if (l.get() && l->is(0)) {
897 return this->left;
898 }
899
900 return zero<T, SAFE_MATH> ();
901 }
902
903// Constant reductions.
904 auto r = constant_cast(this->right);
905
906 if (l.get() && l->is(0)) {
907 return -this->right;
908 } else if (r.get() && r->is(0)) {
909 return this->left;
910 } else if (l.get() && r.get()) {
911 return constant<T, SAFE_MATH> (this->evaluate());
912 } else if (r.get() && r->evaluate().is_negative()) {
913 return this->left + -this->right;
914 }
915
916 auto pl1 = piecewise_1D_cast(this->left);
917 auto pr1 = piecewise_1D_cast(this->right);
918
919 if (pl1.get() && (r.get() || pl1->is_arg_match(this->right))) {
920 return piecewise_1D(this->evaluate(), pl1->get_arg(),
921 pl1->get_scale(), pl1->get_offset());
922 } else if (pr1.get() && (l.get() || pr1->is_arg_match(this->left))) {
923 return piecewise_1D(this->evaluate(), pr1->get_arg(),
924 pr1->get_scale(), pr1->get_offset());
925 }
926
927 auto pl2 = piecewise_2D_cast(this->left);
928 auto pr2 = piecewise_2D_cast(this->right);
929
930 if (pl2.get() && (r.get() || pl2->is_arg_match(this->right))) {
931 return piecewise_2D(this->evaluate(),
932 pl2->get_num_columns(),
933 pl2->get_left(), pl2->get_x_scale(), pl2->get_x_offset(),
934 pl2->get_right(), pl2->get_y_scale(), pl2->get_y_offset());
935 } else if (pr2.get() && (l.get() || pr2->is_arg_match(this->left))) {
936 return piecewise_2D(this->evaluate(),
937 pr2->get_num_columns(),
938 pr2->get_left(), pr2->get_x_scale(), pr2->get_x_offset(),
939 pr2->get_right(), pr2->get_y_scale(), pr2->get_y_offset());
940 }
941
942// Combine 2D and 1D piecewise constants if a row or column matches.
943 if (pr2.get() && pr2->is_row_match(this->left)) {
944 backend::buffer<T> result = pl1->evaluate();
945 result.subtract_row(pr2->evaluate());
946 return piecewise_2D(result,
947 pr2->get_num_columns(),
948 pr2->get_left(), pr2->get_x_scale(), pr2->get_x_offset(),
949 pr2->get_right(), pr2->get_y_scale(), pr2->get_y_offset());
950 } else if (pr2.get() && pr2->is_col_match(this->left)) {
951 backend::buffer<T> result = pl1->evaluate();
952 result.subtract_col(pr2->evaluate());
953 return piecewise_2D(result,
954 pr2->get_num_columns(),
955 pr2->get_left(), pr2->get_x_scale(), pr2->get_x_offset(),
956 pr2->get_right(), pr2->get_y_scale(), pr2->get_y_offset());
957 } else if (pl2.get() && pl2->is_row_match(this->right)) {
958 backend::buffer<T> result = pl2->evaluate();
959 result.subtract_row(pr1->evaluate());
960 return piecewise_2D(result,
961 pl2->get_num_columns(),
962 pl2->get_left(), pl2->get_x_scale(), pl2->get_x_offset(),
963 pl2->get_right(), pl2->get_y_scale(), pl2->get_y_offset());
964 } else if (pl2.get() && pl2->is_col_match(this->right)) {
965 backend::buffer<T> result = pl2->evaluate();
966 result.subtract_col(pr1->evaluate());
967 return piecewise_2D(result,
968 pl2->get_num_columns(),
969 pl2->get_left(), pl2->get_x_scale(), pl2->get_x_offset(),
970 pl2->get_right(), pl2->get_y_scale(), pl2->get_y_offset());
971 }
972// (c1 + a) - c2 -> c3 + a
973// c1 - (c2 + a) -> c3 - a
974 auto la = add_cast(this->left);
975 if (la.get()) {
976 if (is_constant_combineable(la->get_left(), this->right)) {
977 return (la->get_left() - this->right) + la->get_right();
978 }
979 }
980 auto ra = add_cast(this->right);
981 if (ra.get()) {
982 if (is_constant_combineable(this->left, ra->get_left())) {
983 return (this->left - ra->get_left()) - ra->get_right();
984 }
985 }
986
987// (c1 - a) - c2 -> c3 - a
988// (a - c1) - c2 -> a - c3
989 auto ls = subtract_cast(this->left);
990 if (ls.get()) {
991 if (is_constant_combineable(ls->get_left(), this->right)) {
992 return (ls->get_left() - this->right) - ls->get_right();
993 } else if (is_constant_combineable(ls->get_right(),
994 this->right)) {
995 return ls->get_left() - (ls->get_right() + this->right);
996 }
997 }
998// c1 - (c2 - a) -> c3 + a
999// c1 - (a - c2) -> c3 - a
1000 auto rs = subtract_cast(this->right);
1001 if (rs.get()) {
1002 if (is_constant_combineable(this->left, rs->get_left())) {
1003 return (this->left - rs->get_left()) + rs->get_right();
1004 } else if (is_constant_combineable(this->left, rs->get_right())) {
1005 return (this->left + rs->get_right()) - rs->get_left();
1006 }
1007 }
1008
1009// Common factor reduction. If the left and right are both multiply nodes check
1010// for a common factor. So you can change a*b - a*c -> a*(b - c).
1011 auto lm = multiply_cast(this->left);
1012 auto rm = multiply_cast(this->right);
1013
1014// c1*(c2 + a) - c3 -> fma(c1,a,c4)
1015 if (lm.get()) {
1016 auto lmra = add_cast(lm->get_right());
1017 if (lmra.get()) {
1018 if (is_constant_combineable(lm->get_left(),
1019 lmra->get_left()) &&
1020 is_constant_combineable(lm->get_left(),
1021 this->right)) {
1022 return fma(lm->get_left(),
1023 lmra->get_right(),
1024 lm->get_left()*lmra->get_left() - this->right);
1025 }
1026 }
1027// c1*(c2 - a) - c3 -> c4 - c1*a
1028 auto lmrs = subtract_cast(lm->get_right());
1029 if (lmrs.get()) {
1030 if (is_constant_combineable(lm->get_left(),
1031 lmrs->get_left()) &&
1032 is_constant_combineable(lm->get_left(),
1033 this->right)) {
1034 return lm->get_left()*lmrs->get_left() - this->right -
1035 lm->get_left()*lmrs->get_right();
1036 }
1037 }
1038 }
1039
1040// Assume constants are on the left.
1041// v1 - -c*v2 -> v1 + c*v2
1042 if (rm.get() &&
1043 rm->get_left()->is_constant() &&
1044 rm->get_left()->evaluate().is_negative()) {
1045 return this->left + (-this->right);
1046 }
1047
1048 if (lm.get()) {
1049// Assume constants are on the left.
1050// -a - b -> -(a + b)
1051 auto lmc = constant_cast(lm->get_left());
1052 if (lmc.get() && lmc->is(-1)) {
1053 return lm->get_left()*(lm->get_right() + this->right);
1054 }
1055
1056// a*v - v = (a - 1)*v
1057// v*a - v = (a - 1)*v
1058 if (this->right->is_match(lm->get_right())) {
1059 return (lm->get_left() - 1.0)*this->right;
1060 } else if (this->right->is_match(lm->get_left())) {
1061 return (lm->get_right() - 1.0)*this->right;
1062 }
1063 }
1064// v - a*v = (1 - a)*v
1065// v - v*a = (1 - a)*v
1066 if (rm.get()) {
1067 if (this->left->is_match(rm->get_right())) {
1068 return (1.0 - rm->get_left())*this->left;
1069 } else if (this->left->is_match(rm->get_left())) {
1070 return (1.0 - rm->get_right())*this->left;
1071 }
1072 }
1073
1074 if (lm.get() && rm.get()) {
1075 if (lm->get_left()->is_match(rm->get_left())) {
1076// a*b - a*c -> a*(b - c)
1077 return lm->get_left()*(lm->get_right() - rm->get_right());
1078 } else if (lm->get_left()->is_match(rm->get_right())) {
1079// a*b - c*a -> a*(b - c)
1080 return lm->get_left()*(lm->get_right() - rm->get_left());
1081 } else if (lm->get_right()->is_match(rm->get_left())) {
1082// b*a - a*c -> a*(b - c)
1083 return lm->get_right()*(lm->get_left() - rm->get_right());
1084 } else if (lm->get_right()->is_match(rm->get_right())) {
1085// b*a - c*a -> a*(b - c)
1086 return lm->get_right()*(lm->get_left() - rm->get_left());
1087 }
1088
1089// Change cases like c1*a - c2*b -> c1*(a - c2/c1*b)
1090// Note need to make sure c1 doesn't contain any zeros.
1091 if (lm->get_left()->is_constant() &&
1092 rm->get_left()->is_constant() &&
1093 !lm->get_left()->has_constant_zero()) {
1094 return lm->get_left()*(lm->get_right() -
1095 (rm->get_left()/lm->get_left())*rm->get_right());
1096 }
1097
1098// Handle case
1099 auto rmrm = multiply_cast(rm->get_right());
1100 if (rmrm.get()) {
1101// a*b - c*(d*b) -> (a - c*d)*b
1102 if (lm->get_right()->is_match(rmrm->get_right())) {
1103 return (lm->get_left() - rm->get_left()*rmrm->get_left())*lm->get_right();
1104 }
1105// a*b - c*(b*d) -> (a - c*d)*b
1106 if (lm->get_right()->is_match(rmrm->get_left())) {
1107 return (lm->get_left() - rm->get_left()*rmrm->get_right())*lm->get_right();
1108 }
1109// b*a - c*(d*b) -> (a - c*d)*b
1110 if (lm->get_left()->is_match(rmrm->get_right())) {
1111 return (lm->get_right() - rm->get_left()*rmrm->get_left())*lm->get_left();
1112 }
1113// b*a - c*(b*d) -> (a - c*d)*b
1114 if (lm->get_left()->is_match(rmrm->get_left())) {
1115 return (lm->get_right() - rm->get_left()*rmrm->get_right())*lm->get_left();
1116 }
1117 }
1118 auto lmrm = multiply_cast(lm->get_right());
1119 if (lmrm.get()) {
1120// c*(d*b) - a*b -> (c*d - a)*b
1121 if (rm->get_right()->is_match(lmrm->get_right())) {
1122 return (lm->get_left()*lmrm->get_left() - rm->get_left())*rm->get_right();
1123 }
1124// c*(b*d) - a*b -> (c*d - a)*b
1125 if (rm->get_right()->is_match(lmrm->get_left())) {
1126 return (lm->get_left()*lmrm->get_right() - rm->get_left())*rm->get_right();
1127 }
1128// c*(d*b) - b*a -> (c*d - a)*b
1129 if (rm->get_left()->is_match(lmrm->get_right())) {
1130 return (lm->get_left()*lmrm->get_left() - rm->get_right())*rm->get_left();
1131 }
1132// c*(b*d) - b*a -> (c*d - a)*b
1133 if (rm->get_left()->is_match(lmrm->get_left())) {
1134 return (lm->get_left()*lmrm->get_right() - rm->get_right())*rm->get_left();
1135 }
1136 }
1137
1138// a/b*c - d/b*e -> (a*c - d*e)/b
1139// a/b*c - d*e/b -> (a*c - d*e)/b
1140// a*c/b - d/b*e -> (a*c - d*e)/b
1141// a*c/b - d*e/b -> (a*c - d*e)/b
1142 auto lmld = divide_cast(lm->get_left());
1143 auto rmld = divide_cast(rm->get_left());
1144 auto lmrd = divide_cast(lm->get_right());
1145 auto rmrd = divide_cast(rm->get_right());
1146 if (lmld.get() && rmld.get() &&
1147 lmld->get_right()->is_match(rmld->get_right())) {
1148 return (lmld->get_left()*lm->get_right() -
1149 rmld->get_left()*rm->get_right())/lmld->get_right();
1150 } else if (lmld.get() && rmrd.get() &&
1151 lmld->get_right()->is_match(rmrd->get_right())) {
1152 return (lmld->get_left()*lm->get_right() -
1153 rmrd->get_left()*rm->get_left())/lmld->get_right();
1154 } else if (lmrd.get() && rmld.get() &&
1155 lmrd->get_right()->is_match(rmld->get_right())) {
1156 return (lmrd->get_left()*lm->get_left() -
1157 rmld->get_left()*rm->get_right())/lmrd->get_right();
1158 } else if (lmrd.get() && rmrd.get() &&
1159 lmrd->get_right()->is_match(rmrd->get_right())) {
1160 return (lmrd->get_left()*lm->get_left() -
1161 rmrd->get_left()*rm->get_left())/lmrd->get_right();
1162 }
1163 }
1164
1165// Chained subtraction reductions.
1166 if (ls.get()) {
1167 auto lrm = multiply_cast(ls->get_right());
1168 if (lrm.get() && rm.get()) {
1169 if (lrm->get_left()->is_match(rm->get_left())) {
1170// (a - c*b) - c*d -> a - (b + d)*c
1171 return ls->get_left() -
1172 (lrm->get_right() +
1173 rm->get_right())*rm->get_left();
1174 } else if (lrm->get_left()->is_match(rm->get_right())) {
1175// (a - c*b) - d*c -> a - (b + d)*c
1176 return ls->get_left() -
1177 (lrm->get_right() +
1178 rm->get_left())*rm->get_right();
1179 } else if (lrm->get_right()->is_match(rm->get_left())) {
1180// (a - c*b) - c*d -> a - (b + d)*c
1181 return ls->get_left() -
1182 (lrm->get_left() +
1183 rm->get_right())*rm->get_left();
1184 } else if (lrm->get_right()->is_match(rm->get_right())) {
1185// (a - c*b) - d*c -> a - (b + d)*c
1186 return ls->get_left() -
1187 (lrm->get_left() +
1188 rm->get_left())*rm->get_right();
1189 }
1190 }
1191 }
1192
1193// Common denominator reduction. If the left and right are both divide nodes
1194// check for a common denominator. So you can change a/b - c/b -> (a - c)/b.
1195 auto ld = divide_cast(this->left);
1196 auto rd = divide_cast(this->right);
1197
1198// c is a constant.
1199// a - -c/b -> a + c/b
1200// a - (-c*d)/b -> a + (c*d)/b
1201// -c/a - b -> -(b + c/a)
1202// (-c*d)/a - b -> -(b + (c*d)/a)
1203 if (rd.get()) {
1204 auto rdlm = multiply_cast(rd->get_left());
1205 if ((rd->get_left()->is_constant() &&
1206 rd->get_left()->evaluate().is_negative()) ||
1207 (rdlm.get() &&
1208 (rdlm->get_left()->is_constant() &&
1209 rdlm->get_left()->evaluate().is_negative()))) {
1210 return this->left + -this->right;
1211 }
1212 } else if (ld.get()) {
1213 auto ldlm = multiply_cast(ld->get_left());
1214 if ((ld->get_left()->is_constant() &&
1215 ld->get_left()->evaluate().is_negative()) ||
1216 (ldlm.get() &&
1217 (ldlm->get_left()->is_constant() &&
1218 ldlm->get_left()->evaluate().is_negative()))) {
1219 return -(-this->left + this->right);
1220 }
1221 }
1222
1223 if (ld.get() && rd.get()) {
1224 if (ld->get_right()->is_match(rd->get_right())) {
1225 return (ld->get_left() - rd->get_left())/ld->get_right();
1226 }
1227
1228 auto ldlm = multiply_cast(ld->get_left());
1229 auto rdlm = multiply_cast(rd->get_left());
1230// a/b - c*a/d -> (1/b - c/d)*a
1231// a/b - a*c/d -> (1/b - c/d)*a
1232// c*a/b - a/d -> (c/b - 1/d)*a
1233// a*c/b - a/d -> (c/b - 1/d)*a
1234 if (rdlm.get()) {
1235 if (ld->get_left()->is_match(rdlm->get_left())) {
1236 return (1.0/ld->get_right() -
1237 rdlm->get_right()/rd->get_right())*rdlm->get_left();
1238 } else if (ld->get_left()->is_match(rdlm->get_right())) {
1239 return (1.0/ld->get_right() -
1240 rdlm->get_left()/rd->get_right())*rdlm->get_right();
1241 }
1242 } else if (ldlm.get()) {
1243 if (rd->get_left()->is_match(ldlm->get_left())) {
1244 return (ldlm->get_right()/ld->get_right() -
1245 1.0/rd->get_right())*ldlm->get_left();
1246 } else if (rd->get_left()->is_match(ldlm->get_right())) {
1247 return (ldlm->get_left()/ld->get_right() -
1248 1.0/rd->get_right())*ldlm->get_right();
1249 }
1250 }
1251
1252// c1*a/b - c2*e/d = c3*(a/b - c4*e/d)
1253// a*b/c - d*b/e -> (a/c - d/e)*b
1254// Make sure we prevent combining constants when we just need to factor out a
1255// common term.
1256// c1*a/b - c2*a/d -> (c1/b - c2/d)*a
1257 if (ldlm.get() && rdlm.get()) {
1258 if (is_constant_combineable(ldlm->get_left(),
1259 rdlm->get_left()) &&
1260 !ldlm->get_right()->is_match(rdlm->get_right())) {
1261 return (ldlm->get_right()/ld->get_right() -
1262 rdlm->get_left()/ldlm->get_left() *
1263 rdlm->get_right()/rd->get_right())*ldlm->get_left();
1264 }
1265
1266 if (ldlm->get_right()->is_match(rdlm->get_right())) {
1267 return (ldlm->get_left()/ld->get_right() -
1268 rdlm->get_left()/rd->get_right())*ldlm->get_right();
1269 } else if (ldlm->get_right()->is_match(rdlm->get_left())) {
1270 return (ldlm->get_left()/ld->get_right() -
1271 rdlm->get_right()/rd->get_right())*ldlm->get_right();
1272 } else if (ldlm->get_left()->is_match(rdlm->get_right())) {
1273 return (ldlm->get_right()/ld->get_right() -
1274 rdlm->get_left()/rd->get_right())*ldlm->get_left();
1275 } else if (ldlm->get_left()->is_match(rdlm->get_left())) {
1276 return (ldlm->get_right()/ld->get_right() -
1277 rdlm->get_right()/rd->get_right())*ldlm->get_left();
1278 }
1279 }
1280
1281// (a/(c*b) - d/(e*c)) -> (a/b - d/e)/c
1282// (a/(b*c) - d/(e*c)) -> (a/b - d/e)/c
1283// (a/(c*b) - d/(c*e)) -> (a/b - d/e)/c
1284// (a/(b*c) - d/(c*e)) -> (a/b - d/e)/c
1285 auto ldrm = multiply_cast(ld->get_right());
1286 auto rdrm = multiply_cast(rd->get_right());
1287 if (ldrm.get() && rdrm.get()) {
1288 if (ldrm->get_right()->is_match(rdrm->get_right())) {
1289 return (ld->get_left()/ldrm->get_left() -
1290 rd->get_left()/rdrm->get_left())/ldrm->get_right();
1291 } else if (ldrm->get_right()->is_match(rdrm->get_left())) {
1292 return (ld->get_left()/ldrm->get_left() -
1293 rd->get_left()/rdrm->get_right())/ldrm->get_right();
1294 } else if (ldrm->get_left()->is_match(rdrm->get_right())) {
1295 return (ld->get_left()/ldrm->get_right() -
1296 rd->get_left()/rdrm->get_left())/ldrm->get_left();
1297 } else if (ldrm->get_left()->is_match(rdrm->get_left())) {
1298 return (ld->get_left()/ldrm->get_right() -
1299 rd->get_left()/rdrm->get_right())/ldrm->get_left();
1300 }
1301 }
1302
1303// a/b - c/(b*d) -> (a*d - c)/(b*d)
1304// a/b - c/(d*b) -> (a*d - c)/(b*d)
1305// a/(b*d) - c/b -> (a - c*d)/(b*d)
1306// a/(d*b) - c/b -> (a - c*d)/(b*d)
1307 if (rdrm.get()) {
1308 if (ld->get_right()->is_match(rdrm->get_left())) {
1309 return (ld->get_left()*rdrm->get_right() - rd->get_left()) /
1310 rd->get_right();
1311 } else if (ld->get_right()->is_match(rdrm->get_right())) {
1312 return (ld->get_left()*rdrm->get_left() - rd->get_left()) /
1313 rd->get_right();
1314 }
1315 } else if (ldrm.get()) {
1316 if (rd->get_right()->is_match(ldrm->get_left())) {
1317 return (ld->get_left() - rd->get_left()*ldrm->get_right()) /
1318 ld->get_right();
1319 } else if (rd->get_right()->is_match(ldrm->get_right())) {
1320 return (ld->get_left() - rd->get_left()*ldrm->get_left()) /
1321 ld->get_right();
1322 }
1323 }
1324 }
1325
1326// Move cases like
1327// (c1 + c2/x) - c3/y -> c1 + (c2/x - c3/y)
1328// (c1 - c2/x) - c3/y -> c1 - (c2/x + c3/y)
1329// in case of common denominators.
1330 if (rd.get()) {
1331 auto la = add_cast(this->left);
1332 if (la.get() && divide_cast(la->get_right()).get()) {
1333 return la->get_left() + (la->get_right() - this->right);
1334 } else if (ls.get() && divide_cast(ls->get_right()).get()) {
1335 return ls->get_left() - (this->right + ls->get_right());
1336 }
1337 }
1338
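1339// Handle cases like:
1340// (a/y)^e - b/y^e -> (a^e - b)/(y^e)
1341// b/y^e - (a/y)^e -> (b - a^e)/(y^e)
1342// (a/y)^e - (b/y)^e -> (a^e - b^e)/(y^e)
// NOTE(review): the second branch below builds (a^e - b)/(y^e) for the
// b/y^e - (a/y)^e case, which looks sign inverted relative to the rule
// above. Confirm and fix.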
1343 auto pl = pow_cast(this->left);
1344 auto pr = pow_cast(this->right);
1345 if (pl.get() && rd.get()) {
1346 auto rdp = pow_cast(rd->get_right());
1347 if (rdp.get() && pl->get_right()->is_match(rdp->get_right())) {
1348 auto plld = divide_cast(pl->get_left());
1349 if (plld.get() &&
1350 rdp->get_left()->is_match(plld->get_right())) {
1351 return (pow(plld->get_left(), pl->get_right()) -
1352 rd->get_left()) /
1353 pow(rdp->get_left(), pl->get_right());
1354 }
1355 }
1356 } else if (pr.get() && ld.get()) {
1357 auto ldp = pow_cast(ld->get_right());
1358 if (ldp.get() && pr->get_right()->is_match(ldp->get_right())) {
1359 auto prld = divide_cast(pr->get_left());
1360 if (prld.get() &&
1361 ldp->get_left()->is_match(prld->get_right())) {
1362 return (pow(prld->get_left(), pr->get_right()) -
1363 ld->get_left()) /
1364 pow(ldp->get_left(), pr->get_right());
1365 }
1366 }
1367 } else if (pl.get() && pr.get()) {
1368 if (pl->get_right()->is_match(pr->get_right())) {
1369 auto pld = divide_cast(pl->get_left());
1370 auto prd = divide_cast(pr->get_left());
1371 if (pld.get() && prd.get() &&
1372 pld->get_right()->is_match(prd->get_right())) {
1373 return (pow(pld->get_left(), pl->get_right()) -
1374 pow(prd->get_left(), pl->get_right())) /
1375 pow(pld->get_right(), pl->get_right());
1376 }
1377 }
1378 }
1379
1380 auto lfma = fma_cast(this->left);
1381 auto rfma = fma_cast(this->right);
1382
1383 if (lfma.get() && rfma.get()) {
1384 if (lfma->get_middle()->is_match(rfma->get_middle())) {
1385 return fma(lfma->get_left() - rfma->get_left(),
1386 lfma->get_middle(),
1387 lfma->get_right() - rfma->get_right());
1388 }
1389 }
1390
1391// fma(c,d,e) - a -> fma(c,d,e - a)
1392 if (lfma.get() && !this->right->is_all_variables()) {
1393 return fma(lfma->get_left(),
1394 lfma->get_middle(),
1395 lfma->get_right() - this->right);
1396 }
1397
1398// Reduce cases chained subtract multiply divide.
1399 if (ls.get()) {
1400// (a - b*c) - d*e -> a - (b*c + d*e)
1401// (a - b/c) - d/e -> a - (b/c + d/e)
1402 auto lsrd = divide_cast(ls->get_right());
1403 if ((multiply_cast(ls->get_right()).get() && (rm.get() || rd.get())) ||
1404 (divide_cast(ls->get_right()).get() && (rm.get() || rd.get()))) {
1405 return ls->get_left() - (ls->get_right() + this->right);
1406 }
1407 }
1408
1409 return this->shared_from_this();
1410 }
1411
1412//------------------------------------------------------------------------------
///  @brief Derivative of this subtraction with respect to x.
///
///  Uses d(l - r)/dx = dl/dx - dr/dx. When this node itself matches x the
///  result is one. Computed derivatives are memoized in df_cache.
///
///  NOTE(review): the signature line is missing from this listing; the body
///  implies a single node argument named x. Confirm against the header.
///
///  @param[in] x Node to differentiate with respect to.
///  @returns The derivative expression.
1419//------------------------------------------------------------------------------
// Differentiating a node with respect to itself gives one.
1422 if (this->is_match(x)) {
1423 return one<T, SAFE_MATH> ();
1424 }
1425
// The cache key is the raw address of x reinterpreted as an integer.
1426 const size_t hash = reinterpret_cast<size_t> (x.get());
// Build the derivative only when no entry exists for this key.
1427 if (this->df_cache.find(hash) == this->df_cache.end()) {
1428 this->df_cache[hash] = this->left->df(x) - this->right->df(x);
1429 }
1430 return this->df_cache[hash];
1431 }
1432
1433//------------------------------------------------------------------------------
///  @brief Write the generated source statement for this subtraction.
///
///  If this node has no entry in registers yet, both operands are compiled
///  first, a register name is created for this node, and a statement of the
///  form "const <type> <reg> = <left reg> - <right reg>" is written to the
///  stream. A node that already has an entry writes nothing.
///
///  NOTE(review): the return type line and the declaration of the indices
///  argument are missing from this listing; indices is only seen being
///  forwarded to the operand compile calls.
///
///  @param[in,out] stream    String stream receiving the generated source.
///  @param[in,out] registers Map from nodes to their register names.
///  @param[in]     usage     Register usage information passed to endline.
///  @returns This node.
1441//------------------------------------------------------------------------------
1443 compile(std::ostringstream &stream,
1444 jit::register_map &registers,
1446 const jit::register_usage &usage) {
// Only emit a statement the first time this node is encountered.
1447 if (registers.find(this) == registers.end()) {
// Operands are compiled first so their registers exist before they are read.
1448 shared_leaf<T, SAFE_MATH> l = this->left->compile(stream,
1449 registers,
1450 indices,
1451 usage);
1452 shared_leaf<T, SAFE_MATH> r = this->right->compile(stream,
1453 registers,
1454 indices,
1455 usage);
1456
// The register name is built from the character 'r' and this node's address.
1457 registers[this] = jit::to_string('r', this);
1458 stream << " const ";
1459 jit::add_type<T> (stream);
1460 stream << " " << registers[this] << " = "
1461 << registers[l.get()] << " - "
1462 << registers[r.get()];
// Presumably terminates the statement; endline is defined outside this listing.
1463 this->endline(stream, usage);
1464 }
1465
1466 return this->shared_from_this();
1467 }
1468
1469//------------------------------------------------------------------------------
///  @brief Test whether x is the same subtraction as this node.
///
///  True when x is this exact object, or when x is a subtract node whose left
///  and right operands match this node's left and right operands in the same
///  order. Operands are not compared in swapped order.
///
///  NOTE(review): the signature line is missing from this listing; the name
///  is inferred from the recursive is_match calls in the body.
///
///  @param[in] x Node to compare against.
///  @returns True if the nodes match.
1474//------------------------------------------------------------------------------
// Identity check avoids walking the operands.
1476 if (this == x.get()) {
1477 return true;
1478 }
1479
// A null cast result means x is not a subtract node.
1480 auto x_cast = subtract_cast(x);
1481 if (x_cast.get()) {
1482 return this->left->is_match(x_cast->get_left()) &&
1483 this->right->is_match(x_cast->get_right());
1484 }
1485
1486 return false;
1487 }
1488
1489//------------------------------------------------------------------------------
1491//------------------------------------------------------------------------------
1492 virtual void to_latex() const {
1493 bool l_brackets = add_cast(this->left).get() ||
1494 subtract_cast(this->left).get();
1495 bool r_brackets = add_cast(this->right).get() ||
1496 subtract_cast(this->right).get();
1497 if (l_brackets) {
1498 std::cout << "\\left(";
1499 }
1500 this->left->to_latex();
1501 if (l_brackets) {
1502 std::cout << "\\right)";
1503 }
1504 std::cout << "-";
1505 if (r_brackets) {
1506 std::cout << "\\left(";
1507 }
1508 this->right->to_latex();
1509 if (r_brackets) {
1510 std::cout << "\\right)";
1511 }
1512 }
1513
1514//------------------------------------------------------------------------------
///  @brief Rebuild this subtraction with pseudo nodes removed.
///
///  When has_pseudo reports true, a new subtraction is formed from the
///  operands after remove_pseudo has been applied to each. Otherwise this
///  node is returned unchanged.
///
///  NOTE(review): the signature line is missing from this listing; the name
///  is inferred from the recursive remove_pseudo calls in the body.
///
///  @returns The expression without pseudo nodes.
1518//------------------------------------------------------------------------------
1520 if (this->has_pseudo()) {
1521 return this->left->remove_pseudo() -
1522 this->right->remove_pseudo();
1523 }
// Nothing to strip, so the existing node is reused.
1524 return this->shared_from_this();
1525 }
1526
1527//------------------------------------------------------------------------------
1533//------------------------------------------------------------------------------
1534 virtual shared_leaf<T, SAFE_MATH> to_vizgraph(std::stringstream &stream,
1535 jit::register_map &registers) {
1536 if (registers.find(this) == registers.end()) {
1537 const std::string name = jit::to_string('r', this);
1538 registers[this] = name;
1539 stream << " " << name
1540 << " [label = \"-\", shape = oval, style = filled, fillcolor = blue, fontcolor = white];" << std::endl;
1541
1542 auto l = this->left->to_vizgraph(stream, registers);
1543 stream << " " << name << " -- " << registers[l.get()] << ";" << std::endl;
1544 auto r = this->right->to_vizgraph(stream, registers);
1545 stream << " " << name << " -- " << registers[r.get()] << ";" << std::endl;
1546 }
1547
1548 return this->shared_from_this();
1549 }
1550 };
1551
1552//------------------------------------------------------------------------------
///  @brief Build a reduced subtraction of l and r, reusing cached nodes.
///
///  A subtract_node is created from l and r and reduced. The node cache is
///  then probed starting at the reduced node's hash. If no node is stored at
///  the probed index the reduced node is returned. If the stored node matches
///  the reduced node the stored node is returned instead.
///
///  NOTE(review): several lines are missing from this listing (the function
///  signature, the remainder of the for statement, the statement before
///  "return temp", and the body of the first preprocessor branch). The
///  description above covers only the visible statements.
///
///  @param[in] l Left operand.
///  @param[in] r Right operand.
///  @returns The reduced expression, or an equivalent cached node.
1561//------------------------------------------------------------------------------
1562 template<jit::float_scalar T, bool SAFE_MATH=false>
1565 auto temp = std::make_shared<subtract_node<T, SAFE_MATH>> (l, r)->reduce();
1566// Test for hash collisions.
1567 for (size_t i = temp->get_hash();
1569 if (leaf_node<T, SAFE_MATH>::caches.nodes.find(i) ==
1570 leaf_node<T, SAFE_MATH>::caches.nodes.end()) {
1572 return temp;
1573 } else if (temp->is_match(leaf_node<T, SAFE_MATH>::caches.nodes[i])) {
1574 return leaf_node<T, SAFE_MATH>::caches.nodes[i];
1575 }
1576 }
// Every loop path is expected to return, so this point should be unreachable.
1577#if defined(__clang__) || defined(__GNUC__)
1579#else
1580 assert(false && "Should never reach.");
1581#endif
1582 }
1583
1584//------------------------------------------------------------------------------
1596//------------------------------------------------------------------------------
1597 template<jit::float_scalar T, bool SAFE_MATH=false>
1602
1603//------------------------------------------------------------------------------
1616//------------------------------------------------------------------------------
1617 template<jit::float_scalar T, jit::float_scalar L, bool SAFE_MATH=false>
1622
1623//------------------------------------------------------------------------------
1636//------------------------------------------------------------------------------
1637 template<jit::float_scalar T, jit::float_scalar R, bool SAFE_MATH=false>
1642
1643//------------------------------------------------------------------------------
1654//------------------------------------------------------------------------------
1655 template<jit::float_scalar T, bool SAFE_MATH=false>
1659
///  Convenience type alias for a shared pointer to a subtract_node.
1661 template<jit::float_scalar T, bool SAFE_MATH=false>
1662 using shared_subtract = std::shared_ptr<subtract_node<T, SAFE_MATH>>;
1663
1664//------------------------------------------------------------------------------
///  @brief Cast a node to a subtract node.
///
///  Performs a dynamic pointer cast, so the result is empty when x does not
///  point to a subtract_node.
///
///  NOTE(review): the signature line is missing from this listing; the name
///  is inferred from the subtract_cast calls elsewhere in the file.
///
///  @param[in] x Node to attempt the cast on.
///  @returns A shared pointer to the subtract node, or an empty pointer.
1672//------------------------------------------------------------------------------
1673 template<jit::float_scalar T, bool SAFE_MATH=false>
1675 return std::dynamic_pointer_cast<subtract_node<T, SAFE_MATH>> (x);
1676 }
1677
1678//******************************************************************************
1679// Multiply node.
1680//******************************************************************************
1681//------------------------------------------------------------------------------
1686//------------------------------------------------------------------------------
1687 template<jit::float_scalar T, bool SAFE_MATH=false>
1688 class multiply_node final : public branch_node<T, SAFE_MATH> {
1689 private:
1690//------------------------------------------------------------------------------
///  @brief Distribute the left operand of this multiply over a nested fma.
///
///  With c = this->left and trial = fma(a, x, b):
///
///    c*fma(a, x, b) -> fma(c*a, x, c*b)
///
///  This form is used when c is constant combineable with both a and b.
///  Otherwise the method recurses into a, and if that succeeds the result is
///  fma(reduced a, x, c*b). A null leaf is returned when trial is not an fma
///  node or no reduction applies.
///
///  NOTE(review): the return type line is missing from this listing.
///
///  @param[in] trial Candidate node, expected to be an fma node.
///  @returns The reduced expression, or a null leaf.
1697//------------------------------------------------------------------------------
1699 reduce_nested_fma_times_constant(shared_leaf<T, SAFE_MATH> trial) {
1700 auto temp = fma_cast(trial);
1701 if (temp.get()) {
// Innermost level: both fma coefficients combine with the left operand.
1702 if (is_constant_combineable(this->left, temp->get_left()) &&
1703 is_constant_combineable(this->left, temp->get_right())) {
1704 return fma(this->left*temp->get_left(),
1705 temp->get_middle(),
1706 this->left*temp->get_right());
1707 } else {
// Otherwise try to push the factor into the fma's left argument.
1708 auto temp2 = reduce_nested_fma_times_constant(temp->get_left());
1709 if (temp2.get()) {
1710 return fma(temp2,
1711 temp->get_middle(),
1712 this->left*temp->get_right());
1713 }
1714 }
1715 }
// Signals to the caller that no reduction was found.
1716 return null_leaf<T, SAFE_MATH> ();
1717 }
1718
1719//------------------------------------------------------------------------------
///  @brief Expand the product of a nested fma and an add node.
///
///  With trial = fma(a, x, b) and add = c + x, where x is the right operand
///  of add and matches the middle argument of the fma:
///
///    fma(a, x, b)*(c + x) -> fma(fma(a, x, c*a + b), x, b*c)
///
///  When a is itself a nested fma the inner levels are expanded by
///  expand_nested_fma_times_add2. A null leaf is returned when trial is not
///  an fma node or the conditions are not met.
///
///  NOTE(review): the return type line and the declaration of the add
///  argument are missing from this listing.
///
///  @param[in] trial Candidate node, expected to be an fma node.
///  @returns The expanded expression, or a null leaf.
1727//------------------------------------------------------------------------------
1729 expand_nested_fma_times_add(shared_leaf<T, SAFE_MATH> trial,
1731 auto temp = fma_cast(trial);
1732 if (temp.get()) {
// The add's right operand must be the fma's middle argument.
1733 if (add->get_right()->is_match(temp->get_middle()) &&
1734 is_constant_combineable(add->get_left(), temp->get_right())) {
// Try the deeper levels first.
1735 auto temp2 = expand_nested_fma_times_add2(temp->get_left(),
1736 temp, add);
1737 if (temp2.get()) {
1738 return fma(temp2,
1739 add->get_right(),
1740 temp->get_right()*add->get_left());
1741 } else if (is_constant_combineable(add->get_left(), temp->get_left())) {
// Single level fma: expand directly.
1742 return fma(fma(temp->get_left(),
1743 add->get_right(),
1744 add->get_left()*temp->get_left() + temp->get_right()),
1745 add->get_right(),
1746 temp->get_right()*add->get_left());
1747 }
1748 }
1749 }
1750 return null_leaf<T, SAFE_MATH> ();
1751 }
1752
1753//------------------------------------------------------------------------------
///  @brief Recursive helper that expands the inner levels of a nested fma
///  multiplied by an add node.
///
///  With trial = fma(a, x, b), add = c + x, and d the right argument of the
///  enclosing fma passed as last, the innermost level produces
///
///    fma(fma(a, x, c*a + b), x, c*b + d)
///
///  Otherwise the method recurses into a, and if that succeeds the result is
///  fma(expanded a, x, c*b + d). A null leaf is returned when trial is not an
///  fma node or no expansion applies.
///
///  NOTE(review): the return type line and the declarations of the last and
///  add arguments are missing from this listing.
///
///  @param[in] trial Candidate node, expected to be an fma node.
///  @returns The expanded expression, or a null leaf.
1762//------------------------------------------------------------------------------
1764 expand_nested_fma_times_add2(shared_leaf<T, SAFE_MATH> trial,
1767 auto temp = fma_cast(trial);
1768 auto temp2 = fma_cast(last);
// The enclosing level must be an fma node; its right argument is read below.
1769 assert(temp2.get() && "Assumed a fma node.");
1770 if (temp.get()) {
1771 if (add->get_right()->is_match(temp->get_middle()) &&
1772 is_constant_combineable(add->get_left(), temp->get_left()) &&
1773 is_constant_combineable(add->get_left(), temp->get_right())) {
1774
1775 return fma(fma(temp->get_left(),
1776 add->get_right(),
1777 add->get_left()*temp->get_left() +
1778 temp->get_right()),
1779 add->get_right(),
1780 add->get_left()*temp->get_right() +
1781 temp2->get_right());
1782 } else {
// NOTE(review): this branch is also taken when the middle argument does not
// match add->get_right(), yet the result below is built with add->get_right()
// as the middle. Confirm callers guarantee every level shares that argument.
1783 auto temp3 = expand_nested_fma_times_add2(temp->get_left(),
1784 temp, add);
1785 if (temp3.get()) {
1786 return fma(temp3,
1787 add->get_right(),
1788 add->get_left()*temp->get_right() +
1789 temp2->get_right());
1790 }
1791 }
1792 }
1793 return null_leaf<T, SAFE_MATH> ();
1794 }
1795
1796//------------------------------------------------------------------------------
///  @brief Build a string from the addresses of two operand nodes.
///
///  The two pointer values are formatted as integers and joined with "*".
///
///  NOTE(review): the declaration of the second argument, r, is missing from
///  this listing. What the string is used for is not visible here.
///
///  @param[in] l Pointer to the left operand.
///  @returns The combined string.
1802//------------------------------------------------------------------------------
1803 static std::string to_string(leaf_node<T, SAFE_MATH> *l,
1805 return jit::format_to_string(reinterpret_cast<size_t> (l)) + "*" +
1806 jit::format_to_string(reinterpret_cast<size_t> (r));
1807 }
1808
1809 public:
1810//------------------------------------------------------------------------------
1815//------------------------------------------------------------------------------
1819
1820//------------------------------------------------------------------------------
///  @brief Evaluate the product of the left and right operands.
///
///  The left operand is evaluated first. The right operand is not evaluated
///  when both operands are the same object, or when the left result is zero.
///
///  NOTE(review): the signature line is missing from this listing; the name
///  is inferred from the evaluate calls on the operands.
///
///  @returns The buffer holding left*right.
1826//------------------------------------------------------------------------------
1828 backend::buffer<T> l_result = this->left->evaluate();
1829
1830// If the left and right are the same don't evaluate the right.
1831// NOTE: Do not use is_match here. Remove once power is implemented.
1832 if (this->left.get() == this->right.get()) {
1833 return l_result*l_result;
1834 }
1835
1836// If all the elements on the left are zero, return the left side without
1837// evaluating the right side. Stop this loop early once the first non zero
1838// element is encountered.
1839 if (l_result.is_zero()) {
1840 return l_result;
1841 }
1842
1843 backend::buffer<T> r_result = this->right->evaluate();
1844 return l_result*r_result;
1845 }
1846
1847//------------------------------------------------------------------------------
1851//------------------------------------------------------------------------------
1853 auto l = constant_cast(this->left);
1854 auto r = constant_cast(this->right);
1855
1856 if (l.get() && l->is(1)) {
1857 return this->right;
1858 } else if (l.get() && l->is(0)) {
1859 return this->left;
1860 } else if (r.get() && r->is(1)) {
1861 return this->left;
1862 } else if (r.get() && r->is(0)) {
1863 return this->right;
1864 } else if (l.get() && r.get()) {
1865 return constant<T, SAFE_MATH> (this->evaluate());
1866 }
1867
1868 auto pl1 = piecewise_1D_cast(this->left);
1869 auto pr1 = piecewise_1D_cast(this->right);
1870
1871 if (pl1.get() && (r.get() || pl1->is_arg_match(this->right))) {
1872 return piecewise_1D(this->evaluate(), pl1->get_arg(),
1873 pl1->get_scale(), pl1->get_offset());
1874 } else if (pr1.get() && (l.get() || pr1->is_arg_match(this->left))) {
1875 return piecewise_1D(this->evaluate(), pr1->get_arg(),
1876 pr1->get_scale(), pr1->get_offset());
1877 }
1878
1879 auto pl2 = piecewise_2D_cast(this->left);
1880 auto pr2 = piecewise_2D_cast(this->right);
1881
1882 if (pl2.get() && (r.get() || pl2->is_arg_match(this->right))) {
1883 return piecewise_2D(this->evaluate(),
1884 pl2->get_num_columns(),
1885 pl2->get_left(), pl2->get_x_scale(), pl2->get_x_offset(),
1886 pl2->get_right(), pl2->get_y_scale(), pl2->get_y_offset());
1887 } else if (pr2.get() && (l.get() || pr2->is_arg_match(this->left))) {
1888 return piecewise_2D(this->evaluate(),
1889 pr2->get_num_columns(),
1890 pr2->get_left(), pr2->get_x_scale(), pr2->get_x_offset(),
1891 pr2->get_right(), pr2->get_y_scale(), pr2->get_y_offset());
1892 }
1893
1894// Combine 2D and 1D piecewise constants if a row or column matches.
1895 if (pr2.get() && pr2->is_row_match(this->left)) {
1896 backend::buffer<T> result = pl1->evaluate();
1897 result.multiply_row(pr2->evaluate());
1898 return piecewise_2D(result,
1899 pr2->get_num_columns(),
1900 pr2->get_left(), pr2->get_x_scale(), pr2->get_x_offset(),
1901 pr2->get_right(), pr2->get_y_scale(), pr2->get_y_offset());
1902 } else if (pr2.get() && pr2->is_col_match(this->left)) {
1903 backend::buffer<T> result = pl1->evaluate();
1904 result.multiply_col(pr2->evaluate());
1905 return piecewise_2D(result,
1906 pr2->get_num_columns(),
1907 pr2->get_left(), pr2->get_x_scale(), pr2->get_x_offset(),
1908 pr2->get_right(), pr2->get_y_scale(), pr2->get_y_offset());
1909 } else if (pl2.get() && pl2->is_row_match(this->right)) {
1910 backend::buffer<T> result = pl2->evaluate();
1911 result.multiply_row(pr1->evaluate());
1912 return piecewise_2D(result,
1913 pl2->get_num_columns(),
1914 pl2->get_left(), pl2->get_x_scale(), pl2->get_x_offset(),
1915 pl2->get_right(), pl2->get_y_scale(), pl2->get_y_offset());
1916 } else if (pl2.get() && pl2->is_col_match(this->right)) {
1917 backend::buffer<T> result = pl2->evaluate();
1918 result.multiply_col(pr1->evaluate());
1919 return piecewise_2D(result,
1920 pl2->get_num_columns(),
1921 pl2->get_left(), pl2->get_x_scale(), pl2->get_x_offset(),
1922 pl2->get_right(), pl2->get_y_scale(), pl2->get_y_offset());
1923 }
1924
1925// Move constants to the left.
1926 if (is_constant_promotable(this->right, this->left)) {
1927 return this->right*this->left;
1928 }
1929
1930// Disable if the right is power like to avoid infinite loop.
1931 if (is_variable_promotable(this->left, this->right)) {
1932 return this->right*this->left;
1933 }
1934
1935// Move trig to the right.
1936 auto cl = cos_cast(this->left);
1937 auto sl = sin_cast(this->left);
1938 if ((cl.get() && !this->right->is_power_like() &&
1939 !this->right->is_all_variables() &&
1940 !sin_cast(this->right).get()) ||
1941 (sl.get() && !this->right->is_power_like() &&
1942 !this->right->is_all_variables()) ||
1943 (sl.get() && cos_cast(this->right).get())) {
1944 return this->right*this->left;
1945 }
1946
1947// Reduce x*x to x^2
1948 if (this->left->is_match(this->right)) {
1949 return pow(this->left, 2.0);
1950 }
1951
1952// Gather common terms.
1953 auto lm = multiply_cast(this->left);
1954 if (lm.get()) {
1955// Promote constants before variables.
1956// (c*v1)*v2 -> c*(v1*v2)
1957 if (is_constant_promotable(lm->get_left(),
1958 lm->get_right())) {
1959 return lm->get_left()*(lm->get_right()*this->right);
1960 }
1961
1962// (a^c*b)*a^d -> a^(c+d)*b
1963// (b*a^c)*a^d -> a^(c+d)*b
1964 if (is_variable_combineable(this->right, lm->get_left())) {
1965 return (this->right*lm->get_left())*lm->get_right();
1966 } else if (is_variable_combineable(this->right, lm->get_right())) {
1967 return (this->right*lm->get_right())*lm->get_left();
1968 }
1969
1970// Assume variables, sqrt of variables, and powers of variables are on the
1971// right.
1972// (a*v)*b -> (a*b)*v
1973 if (is_variable_promotable(lm->get_right(), this->right)) {
1974 return (lm->get_left()*this->right)*lm->get_right();
1975 }
1976
1977// (a*(b*c)^e)*c^f -> a*b^e*c^(e+f)
1978 auto lmrp = pow_cast(lm->get_right());
1979 if (lmrp.get()) {
1980 auto lmrplm = multiply_cast(lmrp->get_left());
1981 if (lmrplm.get() &&
1982 is_variable_combineable(lmrplm->get_right(),
1983 this->right)) {
1984 return (lm->get_left()*pow(lmrplm->get_left(),
1985 lmrp->get_right()))*pow(this->right->get_power_base(),
1986 lmrp->get_right() +
1987 this->right->get_power_exponent());
1988 }
1989 }
1990 }
1991
1992 auto rm = multiply_cast(this->right);
1993 if (rm.get()) {
1994// Assume constants are on the left.
1995// c1*(c2*v) -> c3*v
1996 if (is_constant_combineable(this->left,
1997 rm->get_left())) {
1998 auto temp = this->left*rm->get_left();
1999 if (temp->is_normal()) {
2000 return temp*rm->get_right();
2001 }
2002 }
2003
2004// a*(a*b) -> a^2*b
2005// a*(b*a) -> a^2*b
2006 if (is_variable_combineable(this->left, rm->get_left())) {
2007 return (this->left*rm->get_left())*rm->get_right();
2008 } else if (is_variable_combineable(this->left, rm->get_right())) {
2009 return (this->left*rm->get_right())*rm->get_left();
2010 }
2011
2012// Assume variables are on the left.
2013// a*(b*v) -> (a*b)*v
2014 if (is_variable_promotable(rm->get_right(), this->left)) {
2015 return (this->left*rm->get_left())*rm->get_right();
2016 }
2017
2018// c1*(fma(c2,x,c3)*y)-> fma(c4,x,c5)*y
2019// c1*(fma(fma(c2,x,c3),x,c4)*y)-> fma(fma(c5,x,c6),x,c7)*y
2020// c1*(fma(fma(fma(c2,x,c3),x,c4),x,c5)*y)-> fma(fma(fma(c6,x,c7),x,c8),x,c9)*y
2021// etc...
2022 auto temp = this->reduce_nested_fma_times_constant(rm->get_left());
2023 if (temp.get()) {
2024 return temp*rm->get_right();
2025 }
2026 }
2027
2028// v1*(c*v2) -> c*(v1*v2)
2029 if (rm.get() &&
2030 is_constant_promotable(rm->get_left(), this->left)) {
2031 return rm->get_left()*(this->left*rm->get_right());
2032 }
2033
2034// Assume trig on the right.
2035// a*(b*sin) -> (a*b)*sin
2036// a*(b*cos) -> (a*b)*cos
2037// (a*sin)*b -> (a*b)*sin
2038// (a*cos)*b -> (a*b)*cos
2039 if (lm.get() &&
2040 (sin_cast(lm->get_right()).get() ||
2041 cos_cast(lm->get_right()).get()) &&
2042 !sin_cast(this->right).get() &&
2043 !this->right->is_power_like()) {
2044 return (lm->get_left()*this->right)*lm->get_right();
2045 } else if (rm.get() &&
2046 (sin_cast(rm->get_right()).get() ||
2047 cos_cast(rm->get_right()).get()) &&
2048 !this->left->is_constant()) {
2049 return (this->left*rm->get_left())*rm->get_right();
2050 }
2051
2052// Factor out common constants c*b*c*d -> c*c*b*d. c*c will get reduced to c on
2053// the second pass.
2054 if (lm.get() && rm.get()) {
2055 if (is_constant_combineable(lm->get_left(),
2056 rm->get_left())) {
2057 auto temp = lm->get_left()*rm->get_left();
2058 if (temp->is_normal()) {
2059 return temp*(lm->get_right()*rm->get_right());
2060 }
2061 } else if (is_constant_combineable(lm->get_left(),
2062 rm->get_right())) {
2063 auto temp = lm->get_left()*rm->get_right();
2064 if (temp->is_normal()) {
2065 return temp*(lm->get_right()*rm->get_left());
2066 }
2067 } else if (is_constant_combineable(lm->get_right(),
2068 rm->get_left())) {
2069 auto temp = lm->get_right()*rm->get_left();
2070 if (temp->is_normal()) {
2071 return temp*(lm->get_left()*rm->get_right());
2072 }
2073 } else if (is_constant_combineable(lm->get_right(),
2074 rm->get_right())) {
2075 auto temp = lm->get_right()*rm->get_right();
2076 if (temp->is_normal()) {
2077 return temp*(lm->get_left()*rm->get_left());
2078 }
2079 }
2080
2081// Gather common terms. This will help reduce sqrt(a)*sqrt(a).
2082 if (lm->get_left()->is_match(rm->get_left())) {
2083 return (lm->get_left()*rm->get_left()) *
2084 (lm->get_right()*rm->get_right());
2085 } else if (lm->get_right()->is_match(rm->get_left())) {
2086 return (lm->get_right()*rm->get_left()) *
2087 (lm->get_left()*rm->get_right());
2088 } else if (lm->get_left()->is_match(rm->get_right())) {
2089 return (lm->get_left()*rm->get_right()) *
2090 (lm->get_right()*rm->get_left());
2091 } else if (lm->get_right()->is_match(rm->get_right())) {
2092 return (lm->get_right()*rm->get_right()) *
2093 (lm->get_left()*rm->get_left());
2094 }
2095 }
2096
2097// Common factor reduction. (a/b)*(c/a) = c/b.
2098 auto ld = divide_cast(this->left);
2099 auto rd = divide_cast(this->right);
2100
2101// a*(b/c) -> (a*b)/c
2102// (a/c)*b -> (a*b)/c
2103 if (rd.get()) {
2104 return (this->left*rd->get_left())/rd->get_right();
2105 } else if (ld.get()) {
2106 return (ld->get_left()*this->right)/ld->get_right();
2107 }
2108
2109// (a/b)*(c/a) -> c/b
2110// (b/a)*(a/c) -> b/c
2111 if (ld.get() && rd.get()) {
2112 if (ld->get_left()->is_match(rd->get_right())) {
2113 return rd->get_left()/ld->get_right();
2114 } else if (ld->get_right()->is_match(rd->get_left())) {
2115 return ld->get_left()/rd->get_right();
2116 }
2117
2118// Convert (a/b)*(c/d) -> (a*c)/(b*d). This should help reduce cases like.
2119// (a/b)*(a/b) + (c/b)*(c/b).
2120 return (ld->get_left()*rd->get_left()) /
2121 (ld->get_right()*rd->get_right());
2122 }
2123
2124// Power reductions.
2125 if (is_variable_combineable(this->left, this->right)) {
2126 return pow(this->left->get_power_base(),
2127 this->left->get_power_exponent() +
2128 this->right->get_power_exponent());
2129 }
2130
2131// a*b^-c -> a/b^c
2132 auto rp = pow_cast(this->right);
2133 if (rp.get()) {
2134 auto exponent = constant_cast(rp->get_right());
2135 if (exponent.get() && exponent->evaluate().is_negative()) {
2136 return this->left/pow(rp->get_left(), -rp->get_right());
2137 }
2138 }
2139// b^-c*a -> a/b^c
2140 auto lp = pow_cast(this->left);
2141 if (lp.get()) {
2142 auto exponent = constant_cast(lp->get_right());
2143 if (exponent.get() && exponent->evaluate().is_negative()) {
2144 return this->right/pow(lp->get_left(), -lp->get_right());
2145 }
2146 }
2147// a^b*c^b -> (a*c)^b
2148 if (lp.get() && rp.get()) {
2149 if (lp->get_right()->is_match(rp->get_right())) {
2150 return pow(lp->get_left()*rp->get_left(), lp->get_right());
2151 }
2152 }
2153// (a*b^c)*d^c -> a*(b*d)^c
2154// (a^c*b)*d^c -> b*(a*d)^c
2155// a^c*(b*d^c) -> b*(a*d)^c
2156// a^c*(b^c*d) -> d*(a*b)^c
2157 if (lm.get() && rp.get()) {
2158 auto lmlp = pow_cast(lm->get_left());
2159 auto lmrp = pow_cast(lm->get_right());
2160 if (lmrp.get()) {
2161 if (lmrp->get_right()->is_match(rp->get_right())) {
2162 return lm->get_left()*pow(lmrp->get_left()*rp->get_left(),
2163 rp->get_right());
2164 }
2165 } else if (lmlp.get()) {
2166 if (lmlp->get_right()->is_match(rp->get_right())) {
2167 return lm->get_right()*pow(lmlp->get_left()*rp->get_left(),
2168 rp->get_right());
2169 }
2170 }
2171 } else if (rm.get() && lp.get()) {
2172 auto rmlp = pow_cast(rm->get_left());
2173 auto rmrp = pow_cast(rm->get_right());
2174 if (rmrp.get()) {
2175 if (rmrp->get_right()->is_match(lp->get_right())) {
2176 return rm->get_left()*pow(lp->get_left()*rmrp->get_left(),
2177 lp->get_right());
2178 }
2179 } else if (rmlp.get()) {
2180 if (rmlp->get_right()->is_match(lp->get_right())) {
2181 return rm->get_right()*pow(lp->get_left()*rmlp->get_left(),
2182 lp->get_right());
2183 }
2184 }
2185 }
2186
2187// (b*a)^c*a^d -> b^c*a^(c + d)
2188// (a*b)^c*a^d -> b^c*a^(c + d)
2189// a^d*(b*a)^c -> b^c*a^(c + d)
2190// a^d*(a*b)^c -> b^c*a^(c + d)
2191 if (lp.get() && rp.get()) {
2192 auto lplm = multiply_cast(lp->get_left());
2193 auto rplm = multiply_cast(rp->get_left());
2194 if (lplm.get()) {
2195 if (is_variable_combineable(lplm->get_right(),
2196 this->right)) {
2197 return pow(lplm->get_left()->get_power_base(),
2198 this->left->get_power_exponent())*
2199 pow(this->right->get_power_base(),
2200 this->left->get_power_exponent() +
2201 this->right->get_power_exponent());
2202 } else if (is_variable_combineable(lplm->get_left(),
2203 this->right)) {
2204 return pow(lplm->get_right()->get_power_base(),
2205 this->left->get_power_exponent())*
2206 pow(this->right->get_power_base(),
2207 this->left->get_power_exponent() +
2208 this->right->get_power_exponent());
2209 }
2210 }
2211
2212 if (rplm.get()) {
2213 if (is_variable_combineable(rplm->get_right(),
2214 this->left)) {
2215 return pow(rplm->get_left()->get_power_base(),
2216 this->right->get_power_exponent())*
2217 pow(this->left->get_power_base(),
2218 this->left->get_power_exponent() +
2219 this->right->get_power_exponent());
2220 } else if (is_variable_combineable(rplm->get_left(),
2221 this->left)) {
2222 return pow(rplm->get_right()->get_power_base(),
2223 this->right->get_power_exponent())*
2224 pow(this->left->get_power_base(),
2225 this->left->get_power_exponent() +
2226 this->right->get_power_exponent());
2227 }
2228 }
2229 }
2230
2231 auto lpd = divide_cast(this->left->get_power_base());
2232 if (lpd.get()) {
2233// (a/b)^c*b^d -> a^c*b^(d-c)
2234 if (is_variable_combineable(lpd->get_right(),
2235 this->right)) {
2236 return pow(lpd->get_left(), this->left->get_power_exponent()) *
2237 pow(this->right->get_power_base(),
2238 this->right->get_power_exponent() -
2239 this->left->get_power_exponent()*lpd->get_right()->get_power_exponent());
2240 }
2241// (b/a)^c*b^d -> b^(c+d)/a^c
2242 if (is_variable_combineable(lpd->get_left(), this->right)) {
2243 return pow(this->right->get_power_base(),
2244 this->right->get_power_exponent() +
2245 this->left->get_power_exponent()*lpd->get_left()->get_power_exponent()) /
2246 pow(lpd->get_right(), this->left->get_power_exponent());
2247 }
2248 }
2249 auto rpd = divide_cast(this->right->get_power_base());
2250 if (rpd.get()) {
2251// b^d*(a/b)^c -> a^c*b^(d-c)
2252 if (is_variable_combineable(rpd->get_right(),
2253 this->left)) {
2254 return pow(rpd->get_left(), this->right->get_power_exponent()) *
2255 pow(this->left->get_power_base(),
2256 this->left->get_power_exponent() -
2257 this->right->get_power_exponent()*rpd->get_right()->get_power_exponent());
2258 }
2259// b^d*(b/a)^c -> b^(c+d)/a^c
2260 if (is_variable_combineable(rpd->get_left(),
2261 this->left)) {
2262 return pow(this->right->get_power_base(),
2263 this->right->get_power_exponent() +
2264 this->right->get_power_exponent()*rpd->get_left()->get_power_exponent()) /
2265 pow(rpd->get_right(), this->right->get_power_exponent());
2266 }
2267 }
2268
2269// exp(a)*exp(b) -> exp(a + b)
2270 auto le = exp_cast(this->left);
2271 auto re = exp_cast(this->right);
2272 if (le.get() && re.get()) {
2273 return exp(le->get_arg() + re->get_arg());
2274 }
2275
2276// exp(a)*(exp(b)*c) -> c*(exp(a)*exp(b))
2277// exp(a)*(c*exp(b)) -> c*(exp(a)*exp(b))
2278 if (le.get() && rm.get()) {
2279 auto rmle = exp_cast(rm->get_left());
2280 if (rmle.get()) {
2281 return rm->get_right()*(this->left*rm->get_left());
2282 }
2283 auto rmre = exp_cast(rm->get_right());
2284 if (rmre.get()) {
2285 return rm->get_left()*(this->left*rm->get_right());
2286 }
2287 }
2288// (exp(a)*c)*exp(b) -> c*(exp(a)*exp(b))
2289// (c*exp(a))*exp(b) -> c*(exp(a)*exp(b))
2290 if (re.get() && lm.get()) {
2291 auto lmle = exp_cast(lm->get_left());
2292 if (lmle.get()) {
2293 return lm->get_right()*(this->right*lm->get_left());
2294 }
2295 auto lmre = exp_cast(lm->get_right());
2296 if (lmre.get()) {
2297 return lm->get_left()*(this->right*lm->get_right());
2298 }
2299 }
2300// (exp(a)*c)*(exp(b)*d) -> (c*d)*(exp(a)*exp(b))
2301// (exp(a)*c)*(d*exp(b)) -> (c*d)*(exp(a)*exp(b))
2302// (c*exp(a))*(exp(b)*d) -> (c*d)*(exp(a)*exp(b))
2303// (c*exp(a))*(d*exp(b)) -> (c*d)*(exp(a)*exp(b))
2304 if (lm.get() && rm.get()) {
2305 auto lmle = exp_cast(lm->get_left());
2306 if (lmle.get()) {
2307 auto rmle = exp_cast(rm->get_left());
2308 if (rmle.get()) {
2309 return (lm->get_right()*rm->get_right()) *
2310 (lm->get_left()*rm->get_left());
2311 }
2312 auto rmre = exp_cast(rm->get_right());
2313 if (rmre.get()) {
2314 return (lm->get_right()*rm->get_left()) *
2315 (lm->get_left()*rm->get_right());
2316 }
2317 }
2318 auto lmre = exp_cast(lm->get_right());
2319 if (lmre.get()) {
2320 auto rmle = exp_cast(rm->get_left());
2321 if (rmle.get()) {
2322 return (lm->get_left()*rm->get_right()) *
2323 (lm->get_right()*rm->get_left());
2324 }
2325 auto rmre = exp_cast(rm->get_right());
2326 if (rmre.get()) {
2327 return (lm->get_left()*rm->get_left()) *
2328 (lm->get_right()*rm->get_right());
2329 }
2330 }
2331 }
2332
2333 if (ld.get() && re.get()) {
2334// (c/exp(a))*exp(b) -> c*(exp(b)/exp(a))
2335 auto ldre = exp_cast(ld->get_right());
2336 if (ldre.get()) {
2337 return ld->get_left()*(this->right/ld->get_right());
2338 }
2339// (exp(a)/c)*exp(b) -> (exp(a)*exp(b))/c
2340 auto ldle = exp_cast(ld->get_left());
2341 if (ldle.get()) {
2342 return (ld->get_left()*this->right)/ld->get_right();
2343 }
2344 }
2345 if (rd.get() && le.get()) {
2346// exp(a)*(c/exp(b)) -> c*(exp(a)/exp(b))
2347 auto rdre = exp_cast(rd->get_right());
2348 if (rdre.get()) {
2349 return rd->get_left()*(this->left/rd->get_right());
2350 }
2351// exp(a)*(exp(b)/c) -> (exp(a)*exp(b))/c
2352 auto rdle = exp_cast(rd->get_left());
2353 if (rdle.get()) {
2354 return (this->left*rd->get_left())/rd->get_right();
2355 }
2356 }
2357
2358 if (ld.get() && rm.get()) {
2359 auto rmle = exp_cast(rm->get_left());
2360 if (rmle.get()) {
2361// (c/exp(a))*(exp(b)*d) -> (c*d)*(exp(b)/exp(a))
2362 auto ldre = exp_cast(ld->get_right());
2363 if (ldre.get()) {
2364 return (ld->get_left()*rm->get_right()) *
2365 (rm->get_left()/ld->get_right());
2366 }
2367// (exp(a)/c)*(exp(b)*d) -> (d/c)*(exp(a)*exp(b))
2368 auto ldle = exp_cast(ld->get_left());
2369 if (ldle.get()) {
2370 return (rm->get_right()/ld->get_right()) *
2371 (ld->get_left()*rm->get_left());
2372 }
2373 }
2374 auto rmre = exp_cast(rm->get_right());
2375 if (rmre.get()) {
2376// (c/exp(a))*(d*exp(b)) -> (c*d)*(exp(b)/exp(a))
2377 auto ldre = exp_cast(ld->get_right());
2378 if (ldre.get()) {
2379 return (ld->get_left()*rm->get_left()) *
2380 (rm->get_right()/ld->get_right());
2381 }
2382// (exp(a)/c)*(d*exp(b)) -> (d/c)*(exp(a)*exp(b))
2383 auto ldle = exp_cast(ld->get_left());
2384 if (ldle.get()) {
2385 return (rm->get_left()/ld->get_right()) *
2386 (ld->get_left()*rm->get_right());
2387 }
2388 }
2389 } else if (rd.get() && lm.get()) {
2390 auto lmre = exp_cast(lm->get_right());
2391 if (lmre.get()) {
2392// (c*exp(a))*(exp(b)/d) -> (c/d)*(exp(a)*exp(b))
2393 auto rdre = exp_cast(rd->get_left());
2394 if (rdre.get()) {
2395 return (lm->get_left()/rd->get_right()) *
2396 (lm->get_right()*rd->get_left());
2397 }
2398// (c*exp(a))*(d/exp(b)) -> (c*d)*(exp(a)/exp(b))
2399 auto rdle = exp_cast(rd->get_right());
2400 if (rdle.get()) {
2401 return (lm->get_left()*rd->get_left()) *
2402 (lm->get_right()/rd->get_right());
2403 }
2404 }
2405 auto lmle = exp_cast(lm->get_left());
2406 if (lmle.get()) {
2407// (exp(a)*c)*(d/exp(b)) -> (c*d)*(exp(a)/exp(b))
2408 auto rdle = exp_cast(rd->get_right());
2409 if (rdle.get()) {
2410 return (lm->get_right()*rd->get_left()) *
2411 (lm->get_left()/rd->get_right());
2412 }
2413// (exp(a)*c)*(exp(b)/d) -> (c/d)*(exp(a)*exp(b))
2414 auto rdre = exp_cast(rd->get_left());
2415 if (rdre.get()) {
2416 return (lm->get_right()/rd->get_right()) *
2417 (lm->get_left()*rd->get_left());
2418 }
2419 }
2420 }
2421
2422// c1*fma(c2,x,c3) -> fma(c4,x,c5)
2423// c1*fma(fma(c2,x,c3),x,c4) -> fma(fma(c5,x,c6),x,c7)
2424// c1*fma(fma(fma(c2,x,c3),x,c4),x,c5) -> fma(fma(fma(c6,x,c7),x,c8),x,c9)
2425// etc...
2426 auto fma_reduce = this->reduce_nested_fma_times_constant(this->right);
2427 if (fma_reduce.get()) {
2428 return fma_reduce;
2429 }
2430
2431// fma(c1,x,c2)*(c3 + x) -> fma(fma(c1,x,c4),x,c5)
2432// fma(fma(c1,x,c2),x,c3)*(c4 + x) -> fma(fma(fma(c1,x,c5),x,c6),x,c7)
2433// etc...
2434 auto ra = add_cast(this->right);
2435 if (ra.get()) {
2436 auto fma_expand = this->expand_nested_fma_times_add(this->left,
2437 ra);
2438 if (fma_expand.get()) {
2439 return fma_expand;
2440 }
2441 }
2442
2443// Cases like
2444// (c/exp(a))*(exp(b)/d) -> (c/d)*(exp(b)/exp(a))
2445// (c/exp(a))*(d/exp(b)) -> (c*d)/(exp(b)*exp(a))
2446// (exp(a)/c)*(d/exp(b)) -> (d/c)*(exp(a)/exp(b))
2447// (exp(a)/c)*(exp(b)/d) -> (exp(a)*exp(b))/(c*d)
2448// Are taken care of by (a/b)*(c/d) -> (a*c)/(b*d) conversion above.
2449
2450 return this->shared_from_this();
2451 }
2452
2453//------------------------------------------------------------------------------
///  @brief Derivative of this product with respect to x.
///
///  Applies the product rule, d(l*r)/dx = dl/dx*r + l*dr/dx. The result is
///  stored in df_cache under a key built from the address held by x, so a
///  repeated request for the same x returns the stored expression.
///
///  NOTE(review): the signature line is absent from this listing; the body
///  uses a single argument x that supports .get() - confirm its declared type.
///
///  @param[in] x Node to differentiate with respect to.
///  @returns One when this node matches x, otherwise the cached derivative.
2460//------------------------------------------------------------------------------
2462 if (this->is_match(x)) {
2463 return one<T, SAFE_MATH> ();
2464 }
2465
//  The cache key is the raw pointer value of x reinterpreted as an integer.
2466 const size_t hash = reinterpret_cast<size_t> (x.get());
2467 if (this->df_cache.find(hash) == this->df_cache.end()) {
2468 this->df_cache[hash] = this->left->df(x)*this->right
2469 + this->left*this->right->df(x);
2470 }
2471 return this->df_cache[hash];
2472 }
2473
2474//------------------------------------------------------------------------------
///  @brief Write the source statement that computes this product.
///
///  When this node has no entry in registers, both operands are compiled
///  first, a register name is created with jit::to_string('r', this) and a
///  statement of the form "const <type> <register> = <left>*<right>" is
///  written to the stream. With SAFE_MATH the product is preceded by a
///  conditional that selects zero when either operand register equals zero.
///
///  NOTE(review): part of the signature is absent from this listing; the body
///  also forwards a value named indices to the operands - confirm its type.
///
///  @param[in,out] stream    String stream receiving the generated source.
///  @param[in,out] registers Map from node to register name.
///  @param[in]     usage     Register usage forwarded to operands and endline.
///  @returns This node via shared_from_this().
2482//------------------------------------------------------------------------------
2484 compile(std::ostringstream &stream,
2485 jit::register_map &registers,
2487 const jit::register_usage &usage) {
//  Only emit code the first time this node is visited.
2488 if (registers.find(this) == registers.end()) {
2489 shared_leaf<T, SAFE_MATH> l = this->left->compile(stream,
2490 registers,
2491 indices,
2492 usage);
2493 shared_leaf<T, SAFE_MATH> r = this->right->compile(stream,
2494 registers,
2495 indices,
2496 usage);
2497
2498 registers[this] = jit::to_string('r', this);
2499 stream << " const ";
2500 jit::add_type<T> (stream);
2501 stream << " " << registers[this] << " = ";
//  Safe math guard: emits "(l == 0 || r == 0) ? 0 : " ahead of the product.
//  For complex scalars the zero literal is written as "<type>(0, 0)".
2502 if constexpr (SAFE_MATH) {
2503 stream << "(" << registers[l.get()] << " == ";
2504 if constexpr (jit::complex_scalar<T>) {
2505 jit::add_type<T> (stream);
2506 stream << "(0, 0)";
2507 } else {
2508 stream << "0";
2509 }
2510 stream << " || " << registers[r.get()] << " == ";
2511 if constexpr (jit::complex_scalar<T>) {
2512 jit::add_type<T> (stream);
2513 stream << "(0, 0)";
2514 } else {
2515 stream << "0";
2516 }
2517 stream << ") ? ";
2518 if constexpr (jit::complex_scalar<T>) {
2519 jit::add_type<T> (stream);
2520 stream << "(0, 0)";
2521 } else {
2522 stream << "0";
2523 }
2524 stream << " : ";
2525 }
2526 stream << registers[l.get()] << "*"
2527 << registers[r.get()];
2528 this->endline(stream, usage);
2529 }
2530
2531 return this->shared_from_this();
2532 }
2533
2534//------------------------------------------------------------------------------
///  @brief Test whether x represents the same product as this node.
///
///  Matches when x holds this very node, or when x casts to a multiply node
///  whose operands match this node's operands in either order.
///
///  NOTE(review): the signature line is absent from this listing; x is used
///  through .get() and multiply_cast - confirm its declared type.
///
///  @param[in] x Node to compare against.
///  @returns True when the nodes match, false otherwise.
2539//------------------------------------------------------------------------------
//  Identity check on the raw pointer first.
2541 if (this == x.get()) {
2542 return true;
2543 }
2544
2545 auto x_cast = multiply_cast(x);
2546 if (x_cast.get()) {
2547// Multiplication is commutative.
2548 if ((this->left->is_match(x_cast->get_left()) &&
2549 this->right->is_match(x_cast->get_right())) ||
2550 (this->right->is_match(x_cast->get_left()) &&
2551 this->left->is_match(x_cast->get_right()))) {
2552 return true;
2553 }
2554 }
2555
2556 return false;
2557 }
2558
2559//------------------------------------------------------------------------------
///  @brief Print this product as LaTeX on std::cout.
///
///  The left operand is printed, then a single space, then the right operand.
///  An operand that is a constant, an add node or a subtract node is wrapped
///  in \left( ... \right); any other operand is printed bare.
2561//------------------------------------------------------------------------------
2562 virtual void to_latex() const {
//  Left operand, bracketed when it is a constant, sum or difference.
2563 if (constant_cast(this->left).get() ||
2564 add_cast(this->left).get() ||
2565 subtract_cast(this->left).get()) {
2566 std::cout << "\\left(";
2567 this->left->to_latex();
2568 std::cout << "\\right)";
2569 } else {
2570 this->left->to_latex();
2571 }
2572 std::cout << " ";
//  Right operand, same bracketing rule as the left.
2573 if (constant_cast(this->right).get() ||
2574 add_cast(this->right).get() ||
2575 subtract_cast(this->right).get()) {
2576 std::cout << "\\left(";
2577 this->right->to_latex();
2578 std::cout << "\\right)";
2579 } else {
2580 this->right->to_latex();
2581 }
2582 }
2583
2584//------------------------------------------------------------------------------
///  @brief Build a copy of this product with pseudo nodes removed.
///
///  When has_pseudo() reports true, the product is rebuilt from the operands'
///  remove_pseudo() results; otherwise this node is returned unchanged.
///
///  NOTE(review): the signature line is absent from this listing.
///
///  @returns The rebuilt product, or this node via shared_from_this().
2588//------------------------------------------------------------------------------
2590 if (this->has_pseudo()) {
2591 return this->left->remove_pseudo() *
2592 this->right->remove_pseudo();
2593 }
2594 return this->shared_from_this();
2595 }
2596
2597//------------------------------------------------------------------------------
///  @brief Write this node and its operand edges to a vizgraph stream.
///
///  On the first visit (no entry for this node in registers) a node statement
///  labelled with the multiplication sign is written, followed by one
///  "name -- child" edge for each operand after that operand has written
///  itself. Later visits write nothing.
///
///  @param[in,out] stream    String stream receiving the graph description.
///  @param[in,out] registers Map from node to the name used in the graph.
///  @returns This node via shared_from_this().
2603//------------------------------------------------------------------------------
2604 virtual shared_leaf<T, SAFE_MATH> to_vizgraph(std::stringstream &stream,
2605 jit::register_map &registers) {
2606 if (registers.find(this) == registers.end()) {
//  Register the name before descending into the operands.
2607 const std::string name = jit::to_string('r', this);
2608 registers[this] = name;
2609 stream << " " << name
2610 << " [label = \"⨉\", shape = oval, style = filled, fillcolor = blue, fontcolor = white];" << std::endl;
2611
2612 auto l = this->left->to_vizgraph(stream, registers);
2613 stream << " " << name << " -- " << registers[l.get()] << ";" << std::endl;
2614 auto r = this->right->to_vizgraph(stream, registers);
2615 stream << " " << name << " -- " << registers[r.get()] << ";" << std::endl;
2616 }
2617
2618 return this->shared_from_this();
2619 }
2620 };
2621
2622//------------------------------------------------------------------------------
///  @brief Build a reduced product node, reusing a cached node when one matches.
///
///  A multiply_node is made from l and r and reduced. The node cache in
///  leaf_node<T, SAFE_MATH>::caches.nodes is then probed starting at the
///  reduced node's hash: an empty slot returns the new node, and an occupied
///  slot whose node matches returns the cached node instead.
///
///  NOTE(review): the signature line, the rest of the for statement and the
///  statement just before "return temp" are absent from this listing;
///  presumably the loop advances i and the new node is stored in the cache
///  before it is returned - confirm against the original file.
///
///  @returns The reduced expression, or the matching cached node.
2630//------------------------------------------------------------------------------
2631 template<jit::float_scalar T, bool SAFE_MATH=false>
2634 auto temp = std::make_shared<multiply_node<T, SAFE_MATH>> (l, r)->reduce();
2635// Test for hash collisions.
2636 for (size_t i = temp->get_hash();
2638 if (leaf_node<T, SAFE_MATH>::caches.nodes.find(i) ==
2639 leaf_node<T, SAFE_MATH>::caches.nodes.end()) {
2641 return temp;
2642 } else if (temp->is_match(leaf_node<T, SAFE_MATH>::caches.nodes[i])) {
2643 return leaf_node<T, SAFE_MATH>::caches.nodes[i];
2644 }
2645 }
//  Reaching this point is treated as a logic error on compilers other than
//  clang and GCC. NOTE(review): the clang/GCC branch is absent from this listing.
2646#if defined(__clang__) || defined(__GNUC__)
2648#else
2649 assert(false && "Should never reach.");
2650#endif
2651 }
2652
2653//------------------------------------------------------------------------------
2664//------------------------------------------------------------------------------
2665 template<jit::float_scalar T, bool SAFE_MATH=false>
2670
2671//------------------------------------------------------------------------------
2683//------------------------------------------------------------------------------
2684 template<jit::float_scalar T, jit::float_scalar L, bool SAFE_MATH=false>
2689
2690//------------------------------------------------------------------------------
2702//------------------------------------------------------------------------------
2703 template<jit::float_scalar T, jit::float_scalar R, bool SAFE_MATH=false>
2708
///  Alias for a std::shared_ptr that owns a multiply_node<T, SAFE_MATH>.
2710 template<jit::float_scalar T, bool SAFE_MATH=false>
2711 using shared_multiply = std::shared_ptr<multiply_node<T, SAFE_MATH>>;
2712
2713//------------------------------------------------------------------------------
///  @brief Cast a node pointer to a multiply node pointer.
///
///  Uses std::dynamic_pointer_cast, so the result is empty when x does not
///  point at a multiply_node<T, SAFE_MATH>.
///
///  NOTE(review): the signature line is absent from this listing - confirm the
///  declared type of x and the return type.
///
///  @param[in] x Node to attempt the cast on.
///  @returns A multiply node pointer, or an empty pointer when the cast fails.
2721//------------------------------------------------------------------------------
2722 template<jit::float_scalar T, bool SAFE_MATH=false>
2724 return std::dynamic_pointer_cast<multiply_node<T, SAFE_MATH>> (x);
2725 }
2726
2727//******************************************************************************
2728// Divide node.
2729//******************************************************************************
2730//------------------------------------------------------------------------------
2735//------------------------------------------------------------------------------
2736 template<jit::float_scalar T, bool SAFE_MATH=false>
2737 class divide_node final : public branch_node<T, SAFE_MATH> {
2738 private:
2739//------------------------------------------------------------------------------
///  @brief Build a string that identifies a division from its operand pointers.
///
///  Each pointer value is reinterpreted as a size_t, formatted with
///  jit::format_to_string and the two results are joined with "/".
///
///  NOTE(review): the line declaring the second parameter is absent from this
///  listing; the body refers to it as r - confirm its declared type.
///
///  @param[in] l Pointer to the left operand.
///  @returns The string "<l address>/<r address>".
2745//------------------------------------------------------------------------------
2746 static std::string to_string(leaf_node<T, SAFE_MATH> *l,
2748 return jit::format_to_string(reinterpret_cast<size_t> (l)) + "/" +
2749 jit::format_to_string(reinterpret_cast<size_t> (r));
2750 }
2751
2752 public:
2753//------------------------------------------------------------------------------
2758//------------------------------------------------------------------------------
2763
2764//------------------------------------------------------------------------------
///  @brief Evaluate the quotient left/right.
///
///  The left operand is evaluated first. When that buffer reports is_zero()
///  it is returned as is and the right operand is never evaluated; otherwise
///  the right operand is evaluated and the two buffers are divided.
///
///  NOTE(review): the signature line is absent from this listing.
///
///  @returns The buffer holding left/right, or the zero left buffer.
2770//------------------------------------------------------------------------------
2772 backend::buffer<T> l_result = this->left->evaluate();
2773
2774// If all the elements on the left are zero, return the left side without
2775// evaluating the right side. NOTE(review): the early exit on the first non
2776// zero element presumably happens inside is_zero() - confirm.
2777 if (l_result.is_zero()) {
2778 return l_result;
2779 }
2780
2781 backend::buffer<T> r_result = this->right->evaluate();
2782 return l_result/r_result;
2783 }
2784
2785//------------------------------------------------------------------------------
2789//------------------------------------------------------------------------------
2791// Constant Reductions.
2792 auto l = constant_cast(this->left);
2793 auto r = constant_cast(this->right);
2794
2795 if ((l.get() && l->is(0)) ||
2796 (r.get() && r->is(1))) {
2797 return this->left;
2798 } else if (l.get() && r.get()) {
2799 return constant<T, SAFE_MATH> (this->evaluate());
2800 }
2801
2802 auto pl1 = piecewise_1D_cast(this->left);
2803 auto pr1 = piecewise_1D_cast(this->right);
2804
2805 if (pl1.get() && (r.get() || pl1->is_arg_match(this->right))) {
2806 return piecewise_1D(this->evaluate(), pl1->get_arg(),
2807 pl1->get_scale(), pl1->get_offset());
2808 } else if (pr1.get() && (l.get() || pr1->is_arg_match(this->left))) {
2809 return piecewise_1D(this->evaluate(), pr1->get_arg(),
2810 pr1->get_scale(), pr1->get_offset());
2811 }
2812
2813 auto pl2 = piecewise_2D_cast(this->left);
2814 auto pr2 = piecewise_2D_cast(this->right);
2815
2816 if (pl2.get() && (r.get() || pl2->is_arg_match(this->right))) {
2817 return piecewise_2D(this->evaluate(),
2818 pl2->get_num_columns(),
2819 pl2->get_left(), pl2->get_x_scale(), pl2->get_x_offset(),
2820 pl2->get_right(), pl2->get_y_scale(), pl2->get_y_offset());
2821 } else if (pr2.get() && (l.get() || pr2->is_arg_match(this->left))) {
2822 return piecewise_2D(this->evaluate(),
2823 pr2->get_num_columns(),
2824 pr2->get_left(), pr2->get_x_scale(), pr2->get_x_offset(),
2825 pr2->get_right(), pr2->get_y_scale(), pr2->get_y_offset());
2826 }
2827
2828// Combine 2D and 1D piecewise constants if a row or column matches.
2829 if (pr2.get() && pr2->is_row_match(this->left)) {
2830 backend::buffer<T> result = pl1->evaluate();
2831 result.divide_row(pr2->evaluate());
2832 return piecewise_2D(result,
2833 pr2->get_num_columns(),
2834 pr2->get_left(), pr2->get_x_scale(), pr2->get_x_offset(),
2835 pr2->get_right(), pr2->get_y_scale(), pr2->get_y_offset());
2836 } else if (pr2.get() && pr2->is_col_match(this->left)) {
2837 backend::buffer<T> result = pl1->evaluate();
2838 result.divide_col(pr2->evaluate());
2839 return piecewise_2D(result,
2840 pr2->get_num_columns(),
2841 pr2->get_left(), pr2->get_x_scale(), pr2->get_x_offset(),
2842 pr2->get_right(), pr2->get_y_scale(), pr2->get_y_offset());
2843 } else if (pl2.get() && pl2->is_row_match(this->right)) {
2844 backend::buffer<T> result = pl2->evaluate();
2845 result.divide_row(pr1->evaluate());
2846 return piecewise_2D(result,
2847 pl2->get_num_columns(),
2848 pl2->get_left(), pl2->get_x_scale(), pl2->get_x_offset(),
2849 pl2->get_right(), pl2->get_y_scale(), pl2->get_y_offset());
2850 } else if (pl2.get() && pl2->is_col_match(this->right)) {
2851 backend::buffer<T> result = pl2->evaluate();
2852 result.divide_col(pr1->evaluate());
2853 return piecewise_2D(result,
2854 pl2->get_num_columns(),
2855 pl2->get_left(), pl2->get_x_scale(), pl2->get_x_offset(),
2856 pl2->get_right(), pl2->get_y_scale(), pl2->get_y_offset());
2857 }
2858
2859 if (this->left->is_match(this->right)) {
2860 return one<T, SAFE_MATH> ();
2861 }
2862
2863// Reduce cases of a/c1 -> c2*a
2864 if (this->right->is_constant()) {
2865 return (1.0/this->right)*this->left;
2866 }
2867
2868// a/(b/c + d) -> a*c/(c*d + b)
2869// a/(d + b/c) -> a*c/(c*d + b)
2870 auto ra = add_cast(this->right);
2871 if (ra.get()) {
2872 auto rald = divide_cast(ra->get_left());
2873 auto rard = divide_cast(ra->get_right());
2874 if (rald.get()) {
2875 return this->left*rald->get_right() /
2876 fma(rald->get_right(),
2877 ra->get_right(),
2878 rald->get_left());
2879 } else if (rard.get()) {
2880 return this->left*rard->get_right() /
2881 fma(rard->get_right(),
2882 ra->get_left(),
2883 rard->get_left());
2884 }
2885 }
2886
2887// a/(b/c - d) -> a*c/(b - c*d)
2888// a/(d - b/c) -> a*c/(c*d - b)
2889 auto rs = subtract_cast(this->right);
2890 if (rs.get()) {
2891 auto rsld = divide_cast(rs->get_left());
2892 auto rsrd = divide_cast(rs->get_right());
2893 if (rsld.get()) {
2894 return this->left*rsld->get_right() /
2895 (rsld->get_left() -
2896 rsld->get_right()*rs->get_right());
2897 } else if (rsrd.get()) {
2898 return this->left*rsrd->get_right() /
2899 (rsrd->get_right()*rs->get_left() -
2900 rsrd->get_left());
2901 }
2902 }
2903
2904// fma(a,d,c*d)/d -> a + c
2905// fma(a,d,d*c)/d -> a + c
2906// fma(d,a,c*d)/d -> a + c
2907// fma(d,a,d*c)/d -> a + c
2908 auto lfma = fma_cast(this->left);
2909 if (lfma.get()) {
2910 auto fmarm = multiply_cast(lfma->get_right());
2911 if (fmarm.get()) {
2912 if (lfma->get_middle()->is_match(this->right) &&
2913 fmarm->get_right()->is_match(this->right)) {
2914 return lfma->get_left() + fmarm->get_left();
2915 } else if (lfma->get_middle()->is_match(this->right) &&
2916 fmarm->get_left()->is_match(this->right)) {
2917 return lfma->get_left() + fmarm->get_right();
2918 } else if (lfma->get_left()->is_match(this->right) &&
2919 fmarm->get_right()->is_match(this->right)) {
2920 return lfma->get_middle() + fmarm->get_left();
2921 } else if (lfma->get_left()->is_match(this->right) &&
2922 fmarm->get_left()->is_match(this->right)) {
2923 return lfma->get_middle() + fmarm->get_right();
2924 }
2925 }
2926 }
2927
2928// Common factor reduction. (a*b)/(a*c) = b/c.
2929 auto lm = multiply_cast(this->left);
2930 auto rm = multiply_cast(this->right);
2931
2932 if (lm.get() && rm.get()) {
2933 if (is_variable_combineable(lm->get_left(),
2934 rm->get_left()) ||
2935 is_variable_combineable(lm->get_right(),
2936 rm->get_right())) {
2937 return (lm->get_left()/rm->get_left()) *
2938 (lm->get_right()/rm->get_right());
2939 } else if (is_variable_combineable(lm->get_left(),
2940 rm->get_right()) ||
2941 is_variable_combineable(lm->get_right(),
2942 rm->get_left())) {
2943 return (lm->get_left()/rm->get_right()) *
2944 (lm->get_right()/rm->get_left());
2945 }
2946 }
2947
2948// Move constants to the numerator.
2949// a/(c1*b) -> (c2*a)/b
2950// a/(b*c1) -> (c2*a)/b
2951 if (rm.get()) {
2952 if (rm->get_left()->is_constant() &&
2953 rm->get_left()->is_normal()) {
2954 return ((1.0/rm->get_left())*this->left)/rm->get_right();
2955 } else if (rm->get_right()->is_constant() &&
2956 rm->get_right()->is_normal()) {
2957 return ((1.0/rm->get_right())*this->left)/rm->get_left();
2958 }
2959
2960// a/((b/c + d)*e) -> a*c/((c*d + b)*e)
2961// a/((d + b/c)*e) -> a*c/((c*d + b)*e)
2962// a/(e*(b/c + d)) -> a*c/((c*d + b)*e)
2963// a/(e*(d + b/c)) -> a*c/((c*d + b)*e)
2964 auto rmla = add_cast(rm->get_left());
2965 auto rmra = add_cast(rm->get_right());
2966 if (rmla.get()) {
2967 auto rmlald = divide_cast(rmla->get_left());
2968 auto rmlard = divide_cast(rmla->get_right());
2969 if (rmlald.get()) {
2970 return this->left*rmlald->get_right() /
2971 (fma(rmlald->get_right(),
2972 rmla->get_right(),
2973 rmlald->get_left())*rm->get_right());
2974 } else if (rmlard.get()) {
2975 return this->left*rmlard->get_right() /
2976 (fma(rmlard->get_right(),
2977 rmla->get_left(),
2978 rmlard->get_left())*rm->get_right());
2979 }
2980 }
2981 if (rmra.get()) {
2982 auto rmrald = divide_cast(rmra->get_left());
2983 auto rmrard = divide_cast(rmra->get_right());
2984 if (rmrald.get()) {
2985 return this->left*rmrald->get_right() /
2986 (fma(rmrald->get_right(),
2987 rmra->get_right(),
2988 rmrald->get_left())*rm->get_left());
2989 } else if (rmrard.get()) {
2990 return this->left*rmrard->get_right() /
2991 (fma(rmrard->get_right(),
2992 rmra->get_left(),
2993 rmrard->get_left())*rm->get_left());
2994 }
2995 }
2996
2997// a/((b/c - d)*e) -> a*c/((b - c*d)*e)
2998// a/(e*(b/c - d)) -> a*c/((b - c*d)*e)
2999// a/((d - b/c)*e) -> a*c/((c*d - b)*e)
3000// a/(e*(d - b/c)) -> a*c/((c*d - b)*e)
3001 auto rmls = subtract_cast(rm->get_left());
3002 auto rmrs = subtract_cast(rm->get_right());
3003 if (rmls.get()) {
3004 auto rmlsld = divide_cast(rmls->get_left());
3005 auto rmlsrd = divide_cast(rmls->get_right());
3006 if (rmlsld.get()) {
3007 return this->left*rmlsld->get_right() /
3008 ((rmlsld->get_left() -
3009 rmlsld->get_right()*rmls->get_right())*rm->get_right());
3010 } else if (rmlsrd.get()) {
3011 return this->left*rmlsrd->get_right() /
3012 ((rmlsrd->get_right()*rmls->get_left() -
3013 rmlsrd->get_left())*rm->get_right());
3014 }
3015 }
3016 if (rmrs.get()) {
3017 auto rmrsld = divide_cast(rmrs->get_left());
3018 auto rmrsrd = divide_cast(rmrs->get_right());
3019 if (rmrsld.get()) {
3020 return this->left*rmrsld->get_right() /
3021 ((rmrsld->get_left() -
3022 rmrsld->get_right()*rmrs->get_right())*rm->get_left());
3023 } else if (rmrsrd.get()) {
3024 return this->left*rmrsrd->get_right() /
3025 ((rmrsrd->get_right()*rmrs->get_left() -
3026 rmrsrd->get_left())*rm->get_left());
3027 }
3028 }
3029 }
3030
3031 if (lm.get() && rm.get()) {
3032// (a*b)/(a*c) -> b/c
3033// (b*a)/(a*c) -> b/c
3034// (a*b)/(c*a) -> b/c
3035// (b*a)/(c*a) -> b/c
3036 if (lm->get_left()->is_match(rm->get_left())) {
3037 return lm->get_right()/rm->get_right();
3038 } else if (lm->get_left()->is_match(rm->get_right())) {
3039 return lm->get_right()/rm->get_left();
3040 } else if (lm->get_right()->is_match(rm->get_left())) {
3041 return lm->get_left()/rm->get_right();
3042 } else if (lm->get_right()->is_match(rm->get_right())) {
3043 return lm->get_left()/rm->get_left();
3044 }
3045 }
3046
3047 if (lm.get()) {
3048// (v1*v2)/v1 -> v2
3049// (v2*v1)/v1 -> v2
3050 if (lm->get_left()->is_match(this->right)) {
3051 return lm->get_right();
3052 } else if (lm->get_right()->is_match(this->right)) {
3053 return lm->get_left();
3054 }
3055
3056// (v1^a*v2)/v1^b -> v2*(v1^a/v1^b)
3057// (v2*v1^a)/v1^b -> v2*(v1^a/v1^b)
3058 if (is_variable_combineable(lm->get_left(),
3059 this->right)) {
3060 return lm->get_right()*(lm->get_left()/this->right);
3061 } else if (is_variable_combineable(lm->get_right(),
3062 this->right)) {
3063 return lm->get_left()*(lm->get_right()/this->right);
3064 }
3065 }
3066
3067// (a/b)/c -> a/(b*c)
3068// a/(b/c) -> a*c/b
3069 auto ld = divide_cast(this->left);
3070 auto rd = divide_cast(this->right);
3071 if (ld.get()) {
3072 return ld->get_left()/(ld->get_right()*this->right);
3073 }
3074 if (rd.get()) {
3075 return this->left*rd->get_right()/rd->get_left();
3076 }
3077
3078// Power reductions.
3079 if (is_variable_combineable(this->left,
3080 this->right)) {
3081 return pow(this->left->get_power_base(),
3082 this->left->get_power_exponent() -
3083 this->right->get_power_exponent());
3084 }
3085
3086// a/b^-c -> a*b^c
3087 auto rp = pow_cast(this->right);
3088 if (rp.get()) {
3089 auto exponent = constant_cast(rp->get_right());
3090 if (exponent.get() && exponent->evaluate().is_negative()) {
3091 return this->left*pow(rp->get_left(), -rp->get_right());
3092 }
3093 }
3094
3095// (a*b)^c/(a^d) = a^(c - d)*b^c
3096// (b*a)^c/(a^d) = a^(c - d)*b^c
3097 auto lp = pow_cast(this->left);
3098 if (lp.get()) {
3099 auto lpm = multiply_cast(this->left->get_power_base());
3100 if (lpm.get()) {
3101 if (lpm->get_left()->is_match(this->right->get_power_base())) {
3102 return pow(this->right->get_power_base(),
3103 this->left->get_power_exponent() -
3104 this->right->get_power_exponent()) *
3105 pow(lpm->get_right(),
3106 this->left->get_power_exponent());
3107 } else if (lpm->get_right()->is_match(this->right->get_power_base())) {
3108 return pow(this->right->get_power_base(),
3109 this->left->get_power_exponent() -
3110 this->right->get_power_exponent()) *
3111 pow(lpm->get_left(),
3112 this->left->get_power_exponent());
3113 }
3114 }
3115 }
3116
3117// a^b/c^b -> (a/c)^b
3118 if (lp.get() && rp.get()) {
3119 if (lp->get_right()->is_match(rp->get_right())) {
3120 return pow(lp->get_left()/rp->get_left(), lp->get_right());
3121 }
3122 }
3123
3124// (a*b)^c/((a^d)*e) = a^(c - d)*b^c/e
3125// (b*a)^c/((a^d)*e) = a^(c - d)*b^c/e
3126// (a*b)^c/(e*(a^d)) = a^(c - d)*b^c/e
3127// (b*a)^c/(e*(a^d)) = a^(c - d)*b^c/e
3128 if (lp.get() && rm.get()) {
3129 auto lpm = multiply_cast(this->left->get_power_base());
3130 if (lpm.get()) {
3131 if (lpm->get_left()->is_match(rm->get_left()->get_power_base())) {
3132 return (pow(rm->get_left()->get_power_base(),
3133 this->left->get_power_exponent() -
3134 rm->get_left()->get_power_exponent()) *
3135 pow(lpm->get_right(),
3136 this->left->get_power_exponent())) /
3137 rm->get_right();
3138 } else if (lpm->get_right()->is_match(rm->get_left()->get_power_base())) {
3139 return (pow(rm->get_left()->get_power_base(),
3140 this->left->get_power_exponent() -
3141 rm->get_left()->get_power_exponent()) *
3142 pow(lpm->get_left(),
3143 this->left->get_power_exponent())) /
3144 rm->get_right();
3145 } else if (lpm->get_left()->is_match(rm->get_right()->get_power_base())) {
3146 return (pow(rm->get_right()->get_power_base(),
3147 this->left->get_power_exponent() -
3148 rm->get_right()->get_power_exponent()) *
3149 pow(lpm->get_right(),
3150 this->left->get_power_exponent())) /
3151 rm->get_left();
3152 } else if (lpm->get_right()->is_match(rm->get_right()->get_power_base())) {
3153 return (pow(rm->get_right()->get_power_base(),
3154 this->left->get_power_exponent() -
3155 rm->get_right()->get_power_exponent()) *
3156 pow(lpm->get_left(),
3157 this->left->get_power_exponent())) /
3158 rm->get_left();
3159 }
3160 }
3161 }
3162
3163 if (lm.get()) {
3164// a*(b*c)/c -> a*b
3165// a*(c*b)/c -> a*b
3166// (a*c)*b/c -> a*b
3167// (c*a)*b/c -> a*b
3168 auto lmrm = multiply_cast(lm->get_right());
3169 auto lmlm = multiply_cast(lm->get_left());
3170 if (lmrm.get()) {
3171 if (is_variable_combineable(lmrm->get_right(),
3172 this->right)) {
3173 return lm->get_left()*lmrm->get_left() *
3174 (lmrm->get_right()/this->right);
3175 } else if (is_variable_combineable(lmrm->get_left(),
3176 this->right)) {
3177 return lm->get_left()*lmrm->get_right() *
3178 (lmrm->get_left()/this->right);
3179 }
3180 } else if (lmlm.get()) {
3181 if (is_variable_combineable(lmlm->get_right(),
3182 this->right)) {
3183 return lm->get_right()*lmlm->get_left() *
3184 (lmlm->get_right()/this->right);
3185 } else if (is_variable_combineable(lmlm->get_left(),
3186 this->right)) {
3187 return lm->get_right()*lmlm->get_right() *
3188 (lmlm->get_left()/this->right);
3189 }
3190 }
3191
3192// (f*(a*b)^c)/(a^d) = f*a^(c - d)*b^c
3193// (f*(b*a)^c)/(a^d) = f*a^(c - d)*b^c
3194// (((a*b)^c)*f)/(a^d) = f*a^(c - d)*b^c
3195// (((b*a)^c)*f)/(a^d) = f*a^(c - d)*b^c
3196 auto lmlp = pow_cast(lm->get_left());
3197 auto lmrp = pow_cast(lm->get_right());
3198 if (lmlp.get()) {
3199 auto lmlpm = multiply_cast(lmlp->get_power_base());
3200 if (lmlpm.get()) {
3201 if (lmlpm->get_left()->is_match(this->right->get_power_base())) {
3202 return lm->get_right() *
3203 pow(this->right->get_power_base(),
3204 lmlp->get_power_exponent() -
3205 this->right->get_power_exponent()) *
3206 pow(lmlpm->get_right(),
3207 lmlp->get_power_exponent());
3208 } else if (lmlpm->get_right()->is_match(this->right->get_power_base())) {
3209 return lm->get_right() *
3210 pow(this->right->get_power_base(),
3211 lmlp->get_power_exponent() -
3212 this->right->get_power_exponent()) *
3213 pow(lmlpm->get_left(),
3214 lmlp->get_power_exponent());
3215 }
3216 }
3217 } else if (lmrp.get()) {
3218 auto lmrpm = multiply_cast(lmrp->get_power_base());
3219 if (lmrpm.get()) {
3220 if (lmrpm->get_left()->is_match(this->right->get_power_base())) {
3221 return lm->get_left() *
3222 pow(this->right->get_power_base(),
3223 lmrp->get_power_exponent() -
3224 this->right->get_power_exponent()) *
3225 pow(lmrpm->get_right(),
3226 lmrp->get_power_exponent());
3227 } else if (lmrpm->get_right()->is_match(this->right->get_power_base())) {
3228 return lm->get_left() *
3229 pow(this->right->get_power_base(),
3230 lmrp->get_power_exponent() -
3231 this->right->get_power_exponent()) *
3232 pow(lmrpm->get_left(),
3233 lmrp->get_power_exponent());
3234 }
3235 }
3236 }
3237 }
3238
3239// f*(a*b)^c/((a^d)*e) = f*a^(c - d)*b^c/e
3240// f*(b*a)^c/((a^d)*e) = f*a^(c - d)*b^c/e
3241// f*(a*b)^c/(e*(a^d)) = f*a^(c - d)*b^c/e
3242// f*(b*a)^c/(e*(a^d)) = f*a^(c - d)*b^c/e
3243// (a*b)^c*f/((a^d)*e) = f*a^(c - d)*b^c/e
3244// (b*a)^c*f/((a^d)*e) = f*a^(c - d)*b^c/e
3245// (a*b)^c*f/(e*(a^d)) = f*a^(c - d)*b^c/e
3246// (b*a)^c*f/(e*(a^d)) = f*a^(c - d)*b^c/e
3247 if (lm.get() && rm.get()) {
3248 auto lmlp = pow_cast(lm->get_left());
3249 auto lmrp = pow_cast(lm->get_right());
3250 if (lmlp.get()) {
3251 auto lmlpm = multiply_cast(lmlp->get_power_base());
3252 if (lmlpm.get()) {
3253 if (lmlpm->get_left()->is_match(rm->get_left()->get_power_base())) {
3254 return lm->get_right() *
3255 (pow(rm->get_left()->get_power_base(),
3256 lmlp->get_power_exponent() -
3257 rm->get_left()->get_power_exponent())) *
3258 pow(lmlpm->get_right(),
3259 lmlp->get_power_exponent()) /
3260 rm->get_right();
3261 } else if (lmlpm->get_right()->is_match(rm->get_left()->get_power_base())) {
3262 return lm->get_right() *
3263 (pow(rm->get_left()->get_power_base(),
3264 lmlp->get_power_exponent() -
3265 rm->get_left()->get_power_exponent())) *
3266 pow(lmlpm->get_left(),
3267 lmlp->get_power_exponent()) /
3268 rm->get_right();
3269 } else if (lmlpm->get_left()->is_match(rm->get_right()->get_power_base())) {
3270 return lm->get_right() *
3271 (pow(rm->get_left()->get_power_base(),
3272 lmlp->get_power_exponent() -
3273 rm->get_right()->get_power_exponent())) *
3274 pow(lmlpm->get_right(),
3275 lmlp->get_power_exponent()) /
3276 rm->get_left();
3277 } else if (lmlpm->get_right()->is_match(rm->get_right()->get_power_base())) {
3278 return lm->get_right() *
3279 (pow(rm->get_left()->get_power_base(),
3280 lmlp->get_power_exponent() -
3281 rm->get_right()->get_power_exponent())) *
3282 pow(lmlpm->get_left(),
3283 lmlp->get_power_exponent()) /
3284 rm->get_left();
3285 }
3286 }
3287 } else if (lmrp.get()) {
3288 auto lmrpm = multiply_cast(lmrp->get_power_base());
3289 if (lmrpm.get()) {
3290 if (lmrpm->get_left()->is_match(rm->get_left()->get_power_base())) {
3291 return lm->get_left() *
3292 (pow(rm->get_left()->get_power_base(),
3293 lmrp->get_power_exponent() -
3294 rm->get_left()->get_power_exponent())) *
3295 pow(lmrpm->get_right(),
3296 lmrp->get_power_exponent()) /
3297 rm->get_right();
3298 } else if (lmrpm->get_right()->is_match(rm->get_left()->get_power_base())) {
3299 return lm->get_left() *
3300 (pow(rm->get_left()->get_power_base(),
3301 lmrp->get_power_exponent() -
3302 rm->get_left()->get_power_exponent())) *
3303 pow(lmrpm->get_left(),
3304 lmrp->get_power_exponent()) /
3305 rm->get_right();
3306 } else if (lmrpm->get_left()->is_match(rm->get_right()->get_power_base())) {
3307 return lm->get_left() *
3308 (pow(rm->get_left()->get_power_base(),
3309 lmrp->get_power_exponent() -
3310 rm->get_right()->get_power_exponent())) *
3311 pow(lmrpm->get_right(),
3312 lmrp->get_power_exponent()) /
3313 rm->get_left();
3314 } else if (lmrpm->get_right()->is_match(rm->get_right()->get_power_base())) {
3315 return lm->get_left() *
3316 (pow(rm->get_left()->get_power_base(),
3317 lmrp->get_power_exponent() -
3318 rm->get_right()->get_power_exponent())) *
3319 pow(lmrpm->get_left(),
3320 lmrp->get_power_exponent()) /
3321 rm->get_left();
3322 }
3323 }
3324 }
3325 }
3326
3327// exp(a)/exp(b) -> exp(a - b)
3328 auto lexp = exp_cast(this->left);
3329 auto rexp = exp_cast(this->right);
3330 if (lexp.get() && rexp.get()) {
3331 return exp(lexp->get_arg() - rexp->get_arg());
3332 }
3333
3334// (c*exp(a))/exp(b) -> c*(exp(a)/exp(b))
3335// (exp(a)*c)/exp(b) -> c*(exp(a)/exp(b))
3336 if (rexp.get() && lm.get()) {
3337 auto lmre = exp_cast(lm->get_right());
3338 if (lmre.get()) {
3339 return lm->get_left()*(lm->get_right()/this->right);
3340 }
3341 auto lmle = exp_cast(lm->get_left());
3342 if (lmle.get()) {
3343 return lm->get_right()*(lm->get_left()/this->right);
3344 }
3345 }
3346// ((c*exp(a))*d)/exp(b)
3347// ((exp(a)*c)*d)/exp(b)
3348// (c*(exp(a)*d))/exp(b)
3349// (c*(d*exp(a)))/exp(b)
3350 if (rexp.get() && lm.get()) {
3351 auto lmlm = multiply_cast(lm->get_left());
3352 auto lmrm = multiply_cast(lm->get_right());
3353
3354 if (lmlm.get()) {
3355 if (exp_cast(lmlm->get_right()).get()) {
3356 return lmlm->get_left()*lm->get_right() *
3357 (lmlm->get_right()/this->right);
3358 } else if (exp_cast(lmlm->get_left()).get()) {
3359 return lmlm->get_right()*lm->get_right() *
3360 (lmlm->get_left()/this->right);
3361 }
3362 } else if (lmrm.get()) {
3363 if (exp_cast(lmrm->get_right()).get()) {
3364 return lmrm->get_left()*lm->get_left() *
3365 (lmrm->get_right()/this->right);
3366 } else if (exp_cast(lmrm->get_left()).get()) {
3367 return lmrm->get_right()*lm->get_left() *
3368 (lmrm->get_left()/this->right);
3369 }
3370 }
3371 }
3372
3373// exp(a)/(c*exp(b)) -> (exp(a)/exp(b))/c
3374// exp(a)/(exp(b)*c) -> (exp(a)/exp(b))/c
3375 if (lexp.get() && rm.get()) {
3376 auto rmre = exp_cast(rm->get_right());
3377 if (rmre.get()) {
3378 return (this->left/rm->get_right())/rm->get_left();
3379 }
3380 auto rmle = exp_cast(rm->get_left());
3381 if (rmle.get()) {
3382 return (this->left/rm->get_left())/rm->get_right();
3383 }
3384 }
3385
3386// (c*exp(a))/(d*exp(b)) -> (c/d)*(exp(a)/exp(b))
3387// (c*exp(a))/(exp(b)*d) -> (c/d)*(exp(a)/exp(b))
3388// (exp(a)*c)/(d*exp(b)) -> (c/d)*(exp(a)/exp(b))
3389// (exp(a)*c)/(exp(b)*d) -> (c/d)*(exp(a)/exp(b))
3390 if (lm.get() && rm.get()) {
3391 auto lmre = exp_cast(lm->get_right());
3392 if (lmre.get()) {
3393 auto rmre = exp_cast(rm->get_right());
3394 if (rmre.get()) {
3395 return (lm->get_left()/rm->get_left()) *
3396 (lm->get_right()/rm->get_right());
3397 }
3398 auto rmle = exp_cast(rm->get_left());
3399 if (rmle.get()) {
3400 return (lm->get_left()/rm->get_right()) *
3401 (lm->get_right()/rm->get_left());
3402 }
3403 }
3404 auto lmle = exp_cast(lm->get_left());
3405 if (lmle.get()) {
3406 auto rmre = exp_cast(rm->get_right());
3407 if (rmre.get()) {
3408 return (lm->get_right()/rm->get_left()) *
3409 (lm->get_left()/rm->get_right());
3410 }
3411 auto rmle = exp_cast(rm->get_left());
3412 if (rmle.get()) {
3413 return (lm->get_right()/rm->get_right()) *
3414 (lm->get_left()/rm->get_left());
3415 }
3416 }
3417 }
3418
3419// exp(a)/(c/exp(b)) -> (exp(a)*exp(b))/c
3420// exp(a)/(exp(b)/c) -> c*(exp(a)/exp(b))
3421 if (rd.get() && lexp.get()) {
3422 auto rdre = exp_cast(rd->get_right());
3423 if (rdre.get()) {
3424 return (this->left*rd->get_right())/rd->get_left();
3425 }
3426 auto rdle = exp_cast(rd->get_left());
3427 if (rdle.get()) {
3428 return rd->get_right()*(this->left/rd->get_left());
3429 }
3430 }
3431
3432// (c/exp(a))/exp(b) -> c/(exp(a)*exp(b))
3433// (exp(a)/c)/exp(b) -> exp(a)/(c*exp(b))
3434// (c/exp(a))/(d/exp(b)) -> (c*exp(b))/(d*exp(a))
3435// (c/exp(a))/(exp(b)/d) -> (c*d)/(exp(b)*exp(a))
3436// (exp(a)/c)/(d/exp(b)) -> (exp(a)*exp(b))/(d*c)
3437// (exp(a)/c)/(exp(b)/d) -> (exp(a)*d)/(exp(b)*c)
3438// Note cases like this are already transformed by the (a/b)/c -> a/(b*c)
3439// above.
3440
3441 return this->shared_from_this();
3442 }
3443
3444//------------------------------------------------------------------------------
3451//------------------------------------------------------------------------------
// Body of the derivative method; the signature line is not present in this
// listing. Returns the derivative of this quotient with respect to x.
// A node that matches x itself differentiates to one.
3454 if (this->is_match(x)) {
3455 return one<T, SAFE_MATH> ();
3456 }
3457
// Results are memoized in df_cache, keyed by the address held by x.
3458 const size_t hash = reinterpret_cast<size_t> (x.get());
3459 if (this->df_cache.find(hash) == this->df_cache.end()) {
// Quotient rule: d(l/r)/dx = l'/r - l*r'/(r*r).
3460 this->df_cache[hash] = this->left->df(x)/this->right
3461 - this->left*this->right->df(x)/(this->right*this->right);
3462 }
3463 return this->df_cache[hash];
3464 }
3465
3466//------------------------------------------------------------------------------
3474//------------------------------------------------------------------------------
// Emit the source line that computes this quotient into stream.
// The return-type line and the declaration of the `indices` parameter are not
// present in this listing; `indices` is only forwarded to the operands here.
// Returns this node via shared_from_this().
3476 compile(std::ostringstream &stream,
3477 jit::register_map &registers,
3479 const jit::register_usage &usage) {
// Only emit code the first time this node is seen; afterwards its register
// entry already exists.
3480 if (registers.find(this) == registers.end()) {
// Compile both operands first so their register names are available.
3481 shared_leaf<T, SAFE_MATH> l = this->left->compile(stream,
3482 registers,
3483 indices,
3484 usage);
3485 shared_leaf<T, SAFE_MATH> r = this->right->compile(stream,
3486 registers,
3487 indices,
3488 usage);
3489
// Writes: const <type> <name> = ...
3490 registers[this] = jit::to_string('r', this);
3491 stream << " const ";
3492 jit::add_type<T> (stream);
3493 stream << " " << registers[this] << " = ";
// With SAFE_MATH the emitted expression is guarded by a ternary:
//   <l> == <zero> ? <zero> : <l>/<r>
// so a zero numerator yields zero without performing the division. For complex
// scalar types the zero literal is written as <type>(0, 0).
3494 if constexpr (SAFE_MATH) {
3495 stream << registers[l.get()] << " == ";
3496 if constexpr (jit::complex_scalar<T>) {
3497 jit::add_type<T> (stream);
3498 stream << "(0, 0)";
3499 } else {
3500 stream << "0";
3501 }
3502 stream << " ? ";
3503 if constexpr (jit::complex_scalar<T>) {
3504 jit::add_type<T> (stream);
3505 stream << "(0, 0)";
3506 } else {
3507 stream << "0";
3508 }
3509 stream << " : ";
3510 }
// The division itself, using the operand register names.
3511 stream << registers[l.get()] << "/"
3512 << registers[r.get()];
3513 this->endline(stream, usage);
3514 }
3515 return this->shared_from_this();
3516 }
3517
3518//------------------------------------------------------------------------------
3523//------------------------------------------------------------------------------
// Body of the match query; the signature line is not present in this listing.
// Returns true when x is this very node, or when x is a divide node whose
// numerator and denominator match this node's. Returns false otherwise.
// Fast path: pointer identity.
3525 if (this == x.get()) {
3526 return true;
3527 }
3528
// Structural comparison, only possible when x casts to a divide node.
3529 auto x_cast = divide_cast(x);
3530 if (x_cast.get()) {
3531 return this->left->is_match(x_cast->get_left()) &&
3532 this->right->is_match(x_cast->get_right());
3533 }
3534
3535 return false;
3536 }
3537
3538//------------------------------------------------------------------------------
3540//------------------------------------------------------------------------------
// Write this quotient to std::cout as a LaTeX fraction: \frac{left}{right}.
// The operands print themselves through their own to_latex().
3541 virtual void to_latex() const {
3542 std::cout << "\\frac{";
3543 this->left->to_latex();
3544 std::cout << "}{";
3545 this->right->to_latex();
3546 std::cout << "}";
3547 }
3548
3549//------------------------------------------------------------------------------
3553//------------------------------------------------------------------------------
// Body of the pseudo removal method; the signature line is not present in this
// listing. When has_pseudo() reports true, the quotient is rebuilt from the
// operands' remove_pseudo() results. Otherwise this node is returned unchanged.
3555 if (this->has_pseudo()) {
3556 return this->left->remove_pseudo() /
3557 this->right->remove_pseudo();
3558 }
3559 return this->shared_from_this();
3560 }
3561
3562//------------------------------------------------------------------------------
3568//------------------------------------------------------------------------------
// Write this node and its two operand edges to stream as graph description
// text. registers maps nodes to the names already emitted, so a node is
// written at most once. Returns this node via shared_from_this().
3569 virtual shared_leaf<T, SAFE_MATH> to_vizgraph(std::stringstream &stream,
3570 jit::register_map &registers) {
3571 if (registers.find(this) == registers.end()) {
// Record the name before recursing so this node is not emitted again.
3572 const std::string name = jit::to_string('r', this);
3573 registers[this] = name;
3574 stream << " " << name
3575 << " [label = \"\\\\\", shape = oval, style = filled, fillcolor = blue, fontcolor = white];" << std::endl;
3576
// Emit each operand, then an undirected edge (--) from this node to it.
3577 auto l = this->left->to_vizgraph(stream, registers);
3578 stream << " " << name << " -- " << registers[l.get()] << ";" << std::endl;
3579 auto r = this->right->to_vizgraph(stream, registers);
3580 stream << " " << name << " -- " << registers[r.get()] << ";" << std::endl;
3581 }
3582
3583 return this->shared_from_this();
3584 }
3585 };
3586
3587//------------------------------------------------------------------------------
3595//------------------------------------------------------------------------------
// Factory body for divide nodes. The signature line, the rest of the for
// header and one statement before `return temp` are not present in this
// listing, so the notes below cover only what is visible.
3596 template<jit::float_scalar T, bool SAFE_MATH=false>
// Build the node from l and r and reduce it immediately; temp holds the
// reduced expression.
3599 auto temp = std::make_shared<divide_node<T, SAFE_MATH>> (l, r)->reduce();
3600// Test for hash collisions.
// Probing starts at temp's hash. An unused key means temp is returned; a key
// whose cached node matches temp means the cached node is returned instead.
// NOTE(review): presumably the missing line stores temp in the cache and the
// loop advances i on a collision - confirm against the full source.
3601 for (size_t i = temp->get_hash();
3603 if (leaf_node<T, SAFE_MATH>::caches.nodes.find(i) ==
3604 leaf_node<T, SAFE_MATH>::caches.nodes.end()) {
3606 return temp;
3607 } else if (temp->is_match(leaf_node<T, SAFE_MATH>::caches.nodes[i])) {
3608 return leaf_node<T, SAFE_MATH>::caches.nodes[i];
3609 }
3610 }
// Falling out of the loop is not expected; the clang/GCC branch line is not
// present in this listing.
3611#if defined(__clang__) || defined(__GNUC__)
3613#else
3614 assert(false && "Should never reach.");
3615#endif
3616 }
3617
3618//------------------------------------------------------------------------------
3629//------------------------------------------------------------------------------
3630 template<jit::float_scalar T, bool SAFE_MATH=false>
3635
3636//------------------------------------------------------------------------------
3648//------------------------------------------------------------------------------
3649 template<jit::float_scalar T, jit::float_scalar L, bool SAFE_MATH=false>
3654
3655//------------------------------------------------------------------------------
3667//------------------------------------------------------------------------------
3668 template<jit::float_scalar T, jit::float_scalar R, bool SAFE_MATH=false>
3673
// Convenience alias for a std::shared_ptr to a divide_node<T, SAFE_MATH>.
3675 template<jit::float_scalar T, bool SAFE_MATH=false>
3676 using shared_divide = std::shared_ptr<divide_node<T, SAFE_MATH>>;
3677
3678//------------------------------------------------------------------------------
3686//------------------------------------------------------------------------------
// Cast helper; the signature line is not present in this listing. Returns x
// converted with std::dynamic_pointer_cast to a divide_node pointer, which is
// empty when x does not refer to a divide_node<T, SAFE_MATH>.
3687 template<jit::float_scalar T, bool SAFE_MATH=false>
3689 return std::dynamic_pointer_cast<divide_node<T, SAFE_MATH>> (x);
3690 }
3691
3692//******************************************************************************
3693// fused multiply add node.
3694//******************************************************************************
3695//------------------------------------------------------------------------------
3702//------------------------------------------------------------------------------
3703 template<jit::float_scalar T, bool SAFE_MATH=false>
3704 class fma_node final : public triple_node<T, SAFE_MATH> {
3705 private:
3706//------------------------------------------------------------------------------
3713//------------------------------------------------------------------------------
// Try to fold a subtraction into a nested fma; the return-type line is not
// present in this listing. With sub = x - c and this->left = fma(a, x, b):
//   fma(fma(a, x, b), x - c, d) -> fma(fma(a, x, b - a*c), x, d - b*c)
// The rewrite is applied only when the constant terms are combinable.
// Returns this node unchanged when no rewrite applies.
3715 reduce_nested_fma(shared_subtract<T, SAFE_MATH> sub) {
3716 auto temp = fma_cast(this->left);
3717 if (temp.get()) {
// Direct case: the inner fma multiplies by the same x that sub subtracts from,
// and c combines with a, b and d.
3718 if (is_constant_combineable(sub->get_right(), temp->get_left()) &&
3719 is_constant_combineable(sub->get_right(), temp->get_right()) &&
3720 is_constant_combineable(this->right, temp->get_right()) &&
3721 temp->get_middle()->is_match(sub->get_left())) {
3722 return fma(fma(temp->get_left(),
3723 sub->get_left(),
3724 temp->get_right() - temp->get_left()*sub->get_right()),
3725 sub->get_left(),
3726 this->right - temp->get_right()*sub->get_right());
3727 } else {
// Otherwise recurse into the inner fma and rebuild the outer level around the
// result.
// NOTE(review): the fallback return below is shared_from_this(), so
// temp2.get() is non-null on the visible paths - confirm whether this check
// was meant to detect "no reduction" (the caller compares against `this`).
3728 if (temp->get_middle()->is_match(sub->get_left()) &&
3729 is_constant_combineable(sub->get_right(), this->right)) {
3730 auto temp2 = temp->reduce_nested_fma(sub);
3731 if (temp2.get()) {
3732 return fma(temp2,
3733 sub->get_left(),
3734 this->right - temp->get_right()*sub->get_right());
3735 }
3736 }
3737 }
3738 }
3739 return this->shared_from_this();
3740 }
3741
3742//------------------------------------------------------------------------------
3749//------------------------------------------------------------------------------
// Build a string from three node pointers; the declarations of the m and r
// parameters are not present in this listing. The result is the literal "fma"
// followed by each pointer value, l then m then r, formatted by
// jit::format_to_string.
3750 static std::string to_string(leaf_node<T, SAFE_MATH> *l,
3753 return "fma" + jit::format_to_string(reinterpret_cast<size_t> (l))
3754 + jit::format_to_string(reinterpret_cast<size_t> (m))
3755 + jit::format_to_string(reinterpret_cast<size_t> (r));
3756 }
3757
3758 public:
3759//------------------------------------------------------------------------------
3765//------------------------------------------------------------------------------
3772
3773//------------------------------------------------------------------------------
3779//------------------------------------------------------------------------------
// Body of the evaluate method; the signature line is not present in this
// listing. Evaluates left*middle + right through backend::fma. The left and
// right operands are always evaluated; the middle only when it is needed.
3781 backend::buffer<T> l_result = this->left->evaluate();
3782 backend::buffer<T> r_result = this->right->evaluate();
3783
3784// If all the elements on the left are zero, return the right side without
3785// evaluating the middle.
3786 if (l_result.is_zero()) {
3787 return r_result;
3788 }
3789
3790 backend::buffer<T> m_result = this->middle->evaluate();
3791 return backend::fma(l_result, m_result, r_result);
3792 }
3793
3794//------------------------------------------------------------------------------
3798//------------------------------------------------------------------------------
3800 auto l = constant_cast(this->left);
3801 auto m = constant_cast(this->middle);
3802 auto r = constant_cast(this->right);
3803
3804 if ((l.get() && l->is(0)) ||
3805 (m.get() && m->is(0))) {
3806 return this->right;
3807 } else if (r.get() && r->is(0)) {
3808 return this->left*this->middle;
3809 } else if (l.get() && m.get() && r.get()) {
3810 return constant<T, SAFE_MATH> (this->evaluate());
3811 } else if (l.get() && m.get()) {
3812 return this->left*this->middle + this->right;
3813 } else if (l.get() && l->is(-1)) {
3814 return this->right - this->middle;
3815 } else if (m.get() && m->is(-1)) {
3816 return this->right - this->left;
3817 } else if (l.get() && l->is(1)) {
3818 return this->middle + this->right;
3819 } else if (m.get() && m->is(1)) {
3820 return this->left + this->right;
3821 }
3822
3823// Check if the left and middle are combinable. This will be constant merged in
3824// multiply reduction.
3825 if (is_constant_combineable(this->left, this->middle) ||
3826 is_variable_combineable(this->left, this->middle)) {
3827 return (this->left*this->middle) + this->right;
3828 }
3829
3830// fma(c2,c1,a) -> fma(c1,c2,a)
3831 if (is_constant_promotable(this->middle,
3832 this->left)) {
3833 return fma(this->middle, this->left, this->right);
3834 }
3835
3836// fma(a,b,a) -> a*(1 + b)
3837// fma(b,a,a) -> a*(1 + b)
3838 if (this->left->is_match(this->right)) {
3839 return this->left*(1.0 + this->middle);
3840 } else if (this->middle->is_match(this->right)) {
3841 return this->middle*(1.0 + this->left);
3842 }
3843
3844// fma(c1,c2 + a,c3) -> fma(c4,a,c5)
3845 auto ma = add_cast(this->middle);
3846 if (ma.get()) {
3847 if (is_constant_combineable(this->left, ma->get_left()) &&
3848 is_constant_combineable(this->left, this->right)) {
3849 return fma(this->left,
3850 ma->get_right(),
3851 fma(this->left, ma->get_left(), this->right));
3852 }
3853 }
3854
3855// fma(c1,c2 - a,c3) -> fma(-c1,a,c1*c2 + c3)
3856// fma(c1,a - c2,c3) -> fma(c1,a,c3 - c1*c2)
3857 auto ms = subtract_cast(this->middle);
3858 if (ms.get()) {
3859 if (is_constant_combineable(this->left, ms->get_left()) &&
3860 is_constant_combineable(this->left, this->right)) {
3861 return fma(-this->left, ms->get_right(),
3862 this->left*ms->get_left() + this->right);
3863 } else if (is_constant_combineable(this->left, ms->get_right()) &&
3864 is_constant_combineable(this->left, this->right)) {
3865 return fma(this->left, ms->get_left(),
3866 this->right - this->left*ms->get_right());
3867 }
3868
3869 auto temp = this->reduce_nested_fma(ms);
3870 if (temp.get() != this) {
3871 return temp;
3872 }
3873 }
3874
3875// Common factor reduction. If the left and right are both multiply nodes check
3876// for a common factor. So you can change a*b + (a*c) -> a*(b + c).
3877 auto lm = multiply_cast(this->left);
3878 auto mm = multiply_cast(this->middle);
3879 auto rm = multiply_cast(this->right);
3880 if (rm.get()) {
3881 if (rm->get_left()->is_match(this->left)) {
3882 return this->left*(this->middle + rm->get_right());
3883 } else if (rm->get_left()->is_match(this->middle)) {
3884 return this->middle*(this->left + rm->get_right());
3885 } else if (rm->get_right()->is_match(this->left)) {
3886 return this->left*(this->middle + rm->get_left());
3887 } else if (rm->get_right()->is_match(this->middle)) {
3888 return this->middle*(this->left + rm->get_left());
3889 }
3890
3891// Change case of
3892// fma(a,b,-c1*b) -> a*b - c1*b
3893 auto rmlc = constant_cast(rm->get_left());
3894 if (rmlc.get() && rmlc->evaluate().is_negative()) {
3895 return this->left*this->middle -
3896 (-1.0*rm->get_left())*rm->get_right();
3897 }
3898
3899// Change cases like
3900// fma(c1,a,c2*b) -> c1*fma(c3,b,a)
3901// fma(a,c1,c2*b) -> c1*fma(c3,b,a)
3902// fma(c1,a,b*c2) -> c1*fma(c3,b,a)
3903// fma(a,c1,b*c2) -> c1*fma(c3,b,a)
3904 if (is_constant_combineable(this->left,
3905 rm->get_left()) &&
3906 !this->left->has_constant_zero()) {
3907 auto temp = rm->get_left()/this->left;
3908 if (temp->is_normal()) {
3909 return this->left*fma(temp,
3910 rm->get_right(),
3911 this->middle);
3912 }
3913 }
3915 rm->get_left()) &&
3916 !this->middle->has_constant_zero()) {
3917 auto temp = rm->get_left()/this->middle;
3918 if (temp->is_normal()) {
3919 return this->middle*fma(temp,
3920 rm->get_right(),
3921 this->left);
3922 }
3923 }
3924 if (is_constant_combineable(this->left,
3925 rm->get_right()) &&
3926 !this->left->has_constant_zero()) {
3927 auto temp = rm->get_right()/this->left;
3928 if (temp->is_normal()) {
3929 return this->left*fma(temp,
3930 rm->get_left(),
3931 this->middle);
3932 }
3933 }
3935 rm->get_right()) &&
3936 !this->middle->has_constant_zero()) {
3937 auto temp = rm->get_right()/this->middle;
3938 if (temp->is_normal()) {
3939 return this->middle*fma(temp,
3940 rm->get_left(),
3941 this->left);
3942 }
3943 }
3944
3945// fma(a,b*c,b*d) -> b*fma(a,c,d)
3946// fma(a,c*b,b*d) -> b*fma(a,c,d)
3947// fma(a,b*c,d*b) -> b*fma(a,c,d)
3948// fma(a,c*b,d*b) -> b*fma(a,c,d)
3949 if (mm.get()) {
3950 if (mm->get_left()->is_match(rm->get_left())) {
3951 return mm->get_left()*fma(this->left,
3952 mm->get_right(),
3953 rm->get_right());
3954 } else if (mm->get_left()->is_match(rm->get_right())) {
3955 return mm->get_left()*fma(this->left,
3956 mm->get_right(),
3957 rm->get_left());
3958 } else if (mm->get_right()->is_match(rm->get_left())) {
3959 return mm->get_right()*fma(this->left,
3960 mm->get_left(),
3961 rm->get_right());
3962 } else if (mm->get_right()->is_match(rm->get_right())) {
3963 return mm->get_right()*fma(this->left,
3964 mm->get_left(),
3965 rm->get_left());
3966 }
3967 }
3968
3969// Convert fma(a*b,c,d*e) -> fma(d,e,a*b*c)
3970// Convert fma(a,b*c,d*e) -> fma(d,e,a*b*c)
3971 if ((lm.get() || mm.get()) &&
3972 (this->left->get_complexity() + this->middle->get_complexity() >
3973 this->right->get_complexity())) {
3974 return fma(rm->get_left(), rm->get_right(),
3975 this->left*this->middle);
3976 }
3977 }
3978
3979// Handle cases like.
3980// fma(c1*a,b,c2*d) -> c1*(a*b + c2/c1*d)
3981// fma(a*c1,b,c2*d) -> c1*(a*b + c2/c1*d)
3982// fma(c1*a,b,d*c2) -> c1*(a*b + c2/c1*d)
3983// fma(a*c1,b,d*c2) -> c1*(a*b + c2/c1*d)
3984 if (lm.get() && rm.get()) {
3985 if (is_constant_combineable(rm->get_left(),
3986 lm->get_left()) &&
3987 !lm->get_left()->has_constant_zero()) {
3988 auto temp = rm->get_left()/lm->get_left();
3989 if (temp->is_normal()){
3990 return lm->get_left()*fma(lm->get_right(),
3991 this->middle,
3992 temp*rm->get_right());
3993 }
3994 }
3995 if (is_constant_combineable(rm->get_left(),
3996 lm->get_right()) &&
3997 !lm->get_right()->has_constant_zero()) {
3998 auto temp = rm->get_left()/lm->get_right();
3999 if (temp->is_normal()){
4000 return lm->get_right()*fma(lm->get_left(),
4001 this->middle,
4002 temp*rm->get_right());
4003 }
4004 }
4005 if (is_constant_combineable(rm->get_right(),
4006 lm->get_left()) &&
4007 !lm->get_left()->has_constant_zero()) {
4008 auto temp = rm->get_right()/lm->get_left();
4009 if (temp->is_normal()) {
4010 return lm->get_left()*fma(lm->get_right(),
4011 this->middle,
4012 temp*rm->get_left());
4013 }
4014 }
4015 if (is_constant_combineable(rm->get_right(),
4016 lm->get_right()) &&
4017 !lm->get_right()->has_constant_zero()) {
4018 auto temp = rm->get_right()/lm->get_right();
4019 if (temp->is_normal()) {
4020 return lm->get_right()*fma(lm->get_left(),
4021 this->middle,
4022 temp*rm->get_left());
4023 }
4024 }
4025 }
4026
4027// Move constant multiplies to the left.
4028 if (lm.get()) {
4029// fma(c1*a,b,c) -> fma(c1,a*b,c)
4030 if (is_constant_promotable(lm->get_left(),
4031 lm->get_right())) {
4032 return fma(lm->get_left(),
4033 lm->get_right()*this->middle,
4034 this->right);
4035 }
4036 } else if (mm.get()) {
4037// fma(c1,c2*a,b) -> fma(c3,a,b)
4038// fma(c1,a*c2,b) -> fma(c3,a,b)
4039// fma(a,c1*b,c) -> fma(c1,a*b,c)
4040 if (is_constant_combineable(this->left,
4041 mm->get_left())) {
4042 auto temp = this->left*mm->get_left();
4043 if (temp->is_normal()) {
4044 return fma(temp,
4045 mm->get_right(),
4046 this->right);
4047 }
4048 }
4049 if (is_constant_combineable(this->left,
4050 mm->get_right())) {
4051 auto temp = this->left*mm->get_right();
4052 if (temp->is_normal()) {
4053 return fma(temp,
4054 mm->get_left(),
4055 this->right);
4056 }
4057 }
4058 if (is_constant_promotable(mm->get_left(),
4059 this->left)) {
4060 return fma(mm->get_left(),
4061 this->left*mm->get_right(),
4062 this->right);
4063 }
4064 }
4065
4066// fma(a,b*c,b) -> b*fma(a,c,1)
4067 if (mm.get()) {
4068 if (mm->get_left()->is_match(this->right)) {
4069 return mm->get_left()*fma(this->left,
4070 mm->get_right(),
4071 1.0);
4072 } else if (mm->get_right()->is_match(this->right)) {
4073 return mm->get_right()*fma(this->left,
4074 mm->get_left(),
4075 1.0);
4076 }
4077 }
4078
4079// fma(c1,a,c2/b) -> c1*(a + c3/b)
4080// fma(a,c1,c2/b) -> c1*(a + c3/b)
4081 auto rd = divide_cast(this->right);
4082 if (rd.get()) {
4083 if (is_constant_combineable(this->left,
4084 rd->get_left()) &&
4085 !this->left->has_constant_zero()) {
4086 auto temp = rd->get_left()/this->left;
4087 if (temp->is_normal()) {
4088 return this->left*(this->middle +
4089 temp/rd->get_right());
4090 }
4091 }
4093 rd->get_left()) &&
4094 !this->middle->has_constant_zero()) {
4095 auto temp = rd->get_left()/this->middle;
4096 if (temp->is_normal()) {
4097 return this->middle*(this->left +
4098 temp/rd->get_right());
4099 }
4100 }
4101 }
4102
4103// Reduce fma(a/b,b,c) -> a + c
4104// Reduce fma(a,b/a,c) -> b + c
4105 auto ld = divide_cast(this->left);
4106 if (ld.get() && ld->get_right()->is_match(this->middle)) {
4107 return ld->get_left() + this->right;
4108 }
4109 auto md = divide_cast(this->middle);
4110 if (md.get() && md->get_right()->is_match(this->left)) {
4111 return md->get_left() + this->right;
4112 }
4113
4114// Common denominator reductions.
4115 if (ld.get() && rd.get()) {
4116// fma(b/c,a,b/d) -> b*(a/c + 1/d)
4117 if (ld->get_left()->is_match(rd->get_left())) {
4118 return ld->get_left()*(this->middle/ld->get_right() +
4119 1.0/rd->get_right());
4120 }
4121
4122// fma(a/(b*c),d,e/c) -> fma(a,d,e*b)/(b*c)
4123// fma(a/(c*b),d,e/c) -> fma(a,d,e*b)/(c*b)
4124// fma(a/c,d,e/(c*b)) -> fma(a*b,d,e)/(b*c)
4125// fma(a/c,d,e/(b*c)) -> fma(a*b,d,e)/(c*b)
4126 auto ldrm = multiply_cast(ld->get_right());
4127 auto rdrm = multiply_cast(rd->get_right());
4128
4129 if (ldrm.get()) {
4130 if (ldrm->get_right()->is_match(rd->get_right())) {
4131 return fma(ld->get_left(), this->middle,
4132 rd->get_left()*ldrm->get_left()) /
4133 ld->get_right();
4134 } else if (ldrm->get_left()->is_match(rd->get_right())) {
4135 return fma(ld->get_left(), this->middle,
4136 rd->get_left()*ldrm->get_right()) /
4137 ld->get_right();
4138 }
4139 } else if (rdrm.get()) {
4140 if (rdrm->get_right()->is_match(ld->get_right())) {
4141 return fma(ld->get_left()*rdrm->get_left(),
4142 this->middle, rd->get_left()) /
4143 rd->get_right();
4144 } else if (rdrm->get_left()->is_match(ld->get_right())) {
4145 return fma(ld->get_left()*rdrm->get_right(),
4146 this->middle, rd->get_left()) /
4147 rd->get_right();
4148 }
4149 }
4150 } else if (md.get() && rd.get()) {
4151// fma(a,d/(b*c),e/c) -> fma(a,d,e*b)/(b*c)
4152// fma(a,d/(c*b),e/c) -> fma(a,d,e*b)/(c*b)
4153// fma(a,d/c,e/(c*b)) -> fma(a,d*b,e)/(b*c)
4154// fma(a,d/c,e/(b*c)) -> fma(a,d*b,e)/(c*b)
4155 auto mdrm = multiply_cast(md->get_right());
4156 auto rdrm = multiply_cast(rd->get_right());
4157
4158 if (mdrm.get()) {
4159 if (mdrm->get_right()->is_match(rd->get_right())) {
4160 return fma(this->left, md->get_left(),
4161 rd->get_left()*mdrm->get_left()) /
4162 md->get_right();
4163 } else if (mdrm->get_left()->is_match(rd->get_right())) {
4164 return fma(this->left, md->get_left(),
4165 rd->get_left()*mdrm->get_right()) /
4166 md->get_right();
4167 }
4168 } else if (rdrm.get()) {
4169 if (rdrm->get_right()->is_match(md->get_right())) {
4170 return fma(this->left, md->get_left()*rdrm->get_left(),
4171 rd->get_left()) /
4172 rd->get_right();
4173 } else if (rdrm->get_left()->is_match(md->get_right())) {
4174 return fma(this->left, md->get_left()*rdrm->get_right(),
4175 rd->get_left()) /
4176 rd->get_right();
4177 }
4178 }
4179 }
4180
4181// Chained fma reductions.
4182 auto rfma = fma_cast(this->right);
4183 if (rfma.get()) {
4184// fma(a, b, fma(c, b, d)) -> fma(b, a + c, d)
4185// fma(b, a, fma(c, b, d)) -> fma(b, a + c, d)
4186// fma(a, b, fma(b, c, d)) -> fma(b, a + c, d)
4187// fma(b, a, fma(b, c, d)) -> fma(b, a + c, d)
4188 if (this->middle->is_match(rfma->get_middle())) {
4189 return fma(this->middle,
4190 this->left + rfma->get_left(),
4191 rfma->get_right());
4192 } else if (this->left->is_match(rfma->get_middle())) {
4193 return fma(this->left,
4194 this->middle + rfma->get_left(),
4195 rfma->get_right());
4196 } else if (this->middle->is_match(rfma->get_left())) {
4197 return fma(this->middle,
4198 this->left + rfma->get_middle(),
4199 rfma->get_right());
4200 } else if (this->left->is_match(rfma->get_left())) {
4201 return fma(this->left,
4202 this->middle + rfma->get_middle(),
4203 rfma->get_right());
4204 }
4205
4206 if (mm.get()) {
4207// fma(a, e*b, fma(c, b, d)) -> fma(b, fma(a, e, c), d)
4208// fma(a, b*e, fma(c, b, d)) -> fma(b, fma(a, e, c), d)
4209// fma(a, e*b, fma(b, c, d)) -> fma(b, fma(a, e, c), d)
4210// fma(a, b*e, fma(b, c, d)) -> fma(b, fma(a, e, c), d)
4211 if (mm->get_right()->is_match(rfma->get_middle())) {
4212 return fma(mm->get_right(),
4213 fma(this->left,
4214 mm->get_left(),
4215 rfma->get_left()),
4216 rfma->get_right());
4217 } else if (mm->get_left()->is_match(rfma->get_middle())) {
4218 return fma(mm->get_left(),
4219 fma(this->left,
4220 mm->get_right(),
4221 rfma->get_left()),
4222 rfma->get_right());
4223 } else if (mm->get_right()->is_match(rfma->get_left())) {
4224 return fma(mm->get_right(),
4225 fma(this->left,
4226 mm->get_left(),
4227 rfma->get_middle()),
4228 rfma->get_right());
4229 } else if (mm->get_left()->is_match(rfma->get_left())) {
4230 return fma(mm->get_left(),
4231 fma(this->left,
4232 mm->get_right(),
4233 rfma->get_middle()),
4234 rfma->get_right());
4235 }
4236 } else if (lm.get()) {
4237// fma(e*b, a, fma(c, b, d)) -> fma(b, fma(a, e, c), d)
4238// fma(b*e, a, fma(c, b, d)) -> fma(b, fma(a, e, c), d)
4239// fma(e*b, a, fma(b, c, d)) -> fma(b, fma(a, e, c), d)
4240// fma(b*e, a, fma(b, c, d)) -> fma(b, fma(a, e, c), d)
4241 if (lm->get_right()->is_match(rfma->get_middle())) {
4242 return fma(lm->get_right(),
4243 fma(this->middle,
4244 lm->get_left(),
4245 rfma->get_left()),
4246 rfma->get_right());
4247 } else if (lm->get_left()->is_match(rfma->get_middle())) {
4248 return fma(lm->get_left(),
4249 fma(this->middle,
4250 lm->get_right(),
4251 rfma->get_left()),
4252 rfma->get_right());
4253 } else if (lm->get_right()->is_match(rfma->get_left())) {
4254 return fma(lm->get_right(),
4255 fma(this->middle,
4256 lm->get_left(),
4257 rfma->get_middle()),
4258 rfma->get_right());
4259 } else if (lm->get_left()->is_match(rfma->get_left())) {
4260 return fma(lm->get_left(),
4261 fma(this->middle,
4262 lm->get_right(),
4263 rfma->get_middle()),
4264 rfma->get_right());
4265 }
4266 }
4267
4268 auto rfmamm = multiply_cast(rfma->get_middle());
4269 auto rfmalm = multiply_cast(rfma->get_left());
4270 if (rfmamm.get()) {
4271// fma(a, b, fma(c, e*b, d)) -> fma(b, fma(c, e, a), d)
4272// fma(b, a, fma(c, e*b, d)) -> fma(b, fma(c, e, a), d)
4273// fma(a, b, fma(c, b*e, d)) -> fma(b, fma(c, e, a), d)
4274// fma(b, a, fma(c, b*e, d)) -> fma(b, fma(c, e, a), d)
4275 if (rfmamm->get_right()->is_match(this->middle)) {
4276 return fma(this->middle,
4277 fma(rfma->get_left(),
4278 rfmamm->get_left(),
4279 this->left),
4280 rfma->get_right());
4281 } else if (rfmamm->get_right()->is_match(this->left)) {
4282 return fma(this->left,
4283 fma(rfma->get_left(),
4284 rfmamm->get_left(),
4285 this->middle),
4286 rfma->get_right());
4287 } else if (rfmamm->get_left()->is_match(this->middle)) {
4288 return fma(this->middle,
4289 fma(rfma->get_left(),
4290 rfmamm->get_right(),
4291 this->left),
4292 rfma->get_right());
4293 } else if (rfmamm->get_left()->is_match(this->left)) {
4294 return fma(this->left,
4295 fma(rfma->get_left(),
4296 rfmamm->get_right(),
4297 this->middle),
4298 rfma->get_right());
4299 }
4300 } else if (rfmalm.get()) {
4301// fma(a, b, fma(e*b, c, d)) -> fma(b, fma(c, e, a), d)
4302// fma(b, a, fma(e*b, c, d)) -> fma(b, fma(c, e, a), d)
4303// fma(a, b, fma(b*e, c, d)) -> fma(b, fma(c, e, a), d)
4304// fma(b, a, fma(b*e, c, d)) -> fma(b, fma(c, e, a), d)
4305 if (rfmalm->get_right()->is_match(this->middle)) {
4306 return fma(this->middle,
4307 fma(rfma->get_middle(),
4308 rfmalm->get_left(),
4309 this->left),
4310 rfma->get_right());
4311 } else if (rfmalm->get_right()->is_match(this->left)) {
4312 return fma(this->left,
4313 fma(rfma->get_middle(),
4314 rfmalm->get_left(),
4315 this->middle),
4316 rfma->get_right());
4317 } else if (rfmalm->get_left()->is_match(this->middle)) {
4318 return fma(this->middle,
4319 fma(rfma->get_middle(),
4320 rfmalm->get_right(),
4321 this->left),
4322 rfma->get_right());
4323 } else if (rfmalm->get_left()->is_match(this->left)) {
4324 return fma(this->left,
4325 fma(rfma->get_middle(),
4326 rfmalm->get_right(),
4327 this->middle),
4328 rfma->get_right());
4329 }
4330 }
4331
4332 if (mm.get() && rfmamm.get()) {
4333// fma(a, f*b, fma(c, e*b, d)) -> fma(b, fma(a, f, c*e), d)
4334// fma(a, b*f, fma(c, e*b, d)) -> fma(b, fma(a, f, c*e), d)
4335// fma(a, f*b, fma(c, b*e, d)) -> fma(b, fma(a, f, c*e), d)
4336// fma(a, b*f, fma(c, b*e, d)) -> fma(b, fma(a, f, c*e), d)
4337 if (mm->get_right()->is_match(rfmamm->get_right())) {
4338 return fma(mm->get_right(),
4339 fma(this->left,
4340 mm->get_left(),
4341 rfma->get_left()*rfmamm->get_left()),
4342 rfma->get_right());
4343 } else if (mm->get_left()->is_match(rfmamm->get_right())) {
4344 return fma(mm->get_left(),
4345 fma(this->left,
4346 mm->get_right(),
4347 rfma->get_left()*rfmamm->get_left()),
4348 rfma->get_right());
4349 } else if (mm->get_right()->is_match(rfmamm->get_left())) {
4350 return fma(mm->get_right(),
4351 fma(this->left,
4352 mm->get_left(),
4353 rfma->get_left()*rfmamm->get_right()),
4354 rfma->get_right());
4355 } else if (mm->get_left()->is_match(rfmamm->get_left())) {
4356 return fma(mm->get_left(),
4357 fma(this->left,
4358 mm->get_right(),
4359 rfma->get_left()*rfmamm->get_right()),
4360 rfma->get_right());
4361 }
4362 } else if (lm.get() && rfmamm.get()) {
4363// fma(f*b, a, fma(c, e*b, d)) -> fma(b, fma(a, f, c*e), d)
4364// fma(b*f, a, fma(c, e*b, d)) -> fma(b, fma(a, f, c*e), d)
4365// fma(f*b, a, fma(c, b*e, d)) -> fma(b, fma(a, f, c*e), d)
4366// fma(b*f, a, fma(c, b*e, d)) -> fma(b, fma(a, f, c*e), d)
4367 if (lm->get_right()->is_match(rfmamm->get_right())) {
4368 return fma(lm->get_right(),
4369 fma(this->middle,
4370 lm->get_left(),
4371 rfma->get_left()*rfmamm->get_left()),
4372 rfma->get_right());
4373 } else if (lm->get_left()->is_match(rfmamm->get_right())) {
4374 return fma(lm->get_left(),
4375 fma(this->middle,
4376 lm->get_right(),
4377 rfma->get_left()*rfmamm->get_left()),
4378 rfma->get_right());
4379 } else if (lm->get_right()->is_match(rfmamm->get_left())) {
4380 return fma(lm->get_right(),
4381 fma(this->middle,
4382 lm->get_left(),
4383 rfma->get_left()*rfmamm->get_right()),
4384 rfma->get_right());
4385 } else if (lm->get_left()->is_match(rfmamm->get_left())) {
4386 return fma(lm->get_left(),
4387 fma(this->middle,
4388 lm->get_right(),
4389 rfma->get_left()*rfmamm->get_right()),
4390 rfma->get_right());
4391 }
4392 } else if (mm.get() && rfmalm.get()) {
4393// fma(a, f*b, fma(e*b, c, d)) -> fma(b, fma(a, f, c*e), d)
4394// fma(a, b*f, fma(e*b, c, d)) -> fma(b, fma(a, f, c*e), d)
4395// fma(a, f*b, fma(b*e, c, d)) -> fma(b, fma(a, f, c*e), d)
4396// fma(a, b*f, fma(b*e, c, d)) -> fma(b, fma(a, f, c*e), d)
4397 if (mm->get_right()->is_match(rfmalm->get_right())) {
4398 return fma(mm->get_right(),
4399 fma(this->left,
4400 mm->get_left(),
4401 rfma->get_middle()*rfmalm->get_left()),
4402 rfma->get_right());
4403 } else if (mm->get_left()->is_match(rfmalm->get_right())) {
4404 return fma(mm->get_left(),
4405 fma(this->left,
4406 mm->get_right(),
4407 rfma->get_middle()*rfmalm->get_left()),
4408 rfma->get_right());
4409 } else if (mm->get_right()->is_match(rfmalm->get_left())) {
4410 return fma(mm->get_right(),
4411 fma(this->left,
4412 mm->get_left(),
4413 rfma->get_middle()*rfmalm->get_right()),
4414 rfma->get_right());
4415 } else if (mm->get_left()->is_match(rfmalm->get_left())) {
4416 return fma(mm->get_left(),
4417 fma(this->left,
4418 mm->get_right(),
4419 rfma->get_middle()*rfmalm->get_right()),
4420 rfma->get_right());
4421 }
4422 } else if (lm.get() && rfmalm.get()) {
4423// fma(f*b, a, fma(e*b, c, d)) -> fma(b, fma(a, f, c*e), d)
4424// fma(b*f, a, fma(e*b, c, d)) -> fma(b, fma(a, f, c*e), d)
4425// fma(f*b, a, fma(b*e, c, d)) -> fma(b, fma(a, f, c*e), d)
4426// fma(b*f, a, fma(b*e, c, d)) -> fma(b, fma(a, f, c*e), d)
4427 if (lm->get_right()->is_match(rfmalm->get_right())) {
4428 return fma(lm->get_right(),
4429 fma(this->middle,
4430 lm->get_left(),
4431 rfma->get_middle()*rfmalm->get_left()),
4432 rfma->get_right());
4433 } else if (lm->get_left()->is_match(rfmalm->get_right())) {
4434 return fma(lm->get_left(),
4435 fma(this->middle,
4436 lm->get_right(),
4437 rfma->get_middle()*rfmalm->get_left()),
4438 rfma->get_right());
4439 } else if (lm->get_right()->is_match(rfmalm->get_left())) {
4440 return fma(lm->get_right(),
4441 fma(this->middle,
4442 lm->get_left(),
4443 rfma->get_middle()*rfmalm->get_right()),
4444 rfma->get_right());
4445 } else if (lm->get_left()->is_match(rfmalm->get_left())) {
4446 return fma(lm->get_left(),
4447 fma(this->middle,
4448 lm->get_right(),
4449 rfma->get_middle()*rfmalm->get_right()),
4450 rfma->get_right());
4451 }
4452 }
4453
4454 if (is_variable_combineable(this->middle, rfma->get_middle())) {
4455 if (is_greater_exponent(this->middle, rfma->get_middle())) {
4456// fma(a,x^b,fma(c,x^d,e)) -> fma(x^d,fma(x^(b-d),a,c),e) if b > d
4457 return fma(rfma->get_middle(),
4458 fma(this->middle/rfma->get_middle(),
4459 this->left,
4460 rfma->get_left()),
4461 rfma->get_right());
4462 } else {
4463// fma(a,x^b,fma(c,x^d,e)) -> fma(x^b,fma(x^(d-b),c,a),e) if d > b
4464 return fma(this->middle,
4465 fma(rfma->get_middle()/this->middle,
4466 rfma->get_left(),
4467 this->left),
4468 rfma->get_right());
4469 }
4470 } else if (is_variable_combineable(this->left, rfma->get_middle())) {
4471 if (is_greater_exponent(this->left, rfma->get_middle())) {
4472// fma(x^b,a,fma(c,x^d,e)) -> fma(x^d,fma(x^(b-d),a,c),e) if b > d
4473 return fma(rfma->get_middle(),
4474 fma(this->left/rfma->get_middle(),
4475 this->middle,
4476 rfma->get_left()),
4477 rfma->get_right());
4478 } else {
4479// fma(x^b,a,fma(c,x^d,e)) -> fma(x^b,fma(x^(d-b),c,a),e) if d > b
4480 return fma(this->left,
4481 fma(rfma->get_middle()/this->left,
4482 rfma->get_left(),
4483 this->middle),
4484 rfma->get_right());
4485 }
4486 } else if (is_variable_combineable(this->middle, rfma->get_left())) {
4487 if (is_greater_exponent(this->middle, rfma->get_left())) {
4488// fma(a,x^b,fma(x^d,c,e)) -> fma(x^d,fma(x^(b-d),a,c),e) if b > d
4489 return fma(rfma->get_left(),
4490 fma(this->middle/rfma->get_left(),
4491 this->left,
4492 rfma->get_middle()),
4493 rfma->get_right());
4494 } else {
4495// fma(a,x^b,fma(x^d,c,e)) -> fma(x^b,fma(x^(d-b),c,a),e) if d > b
4496 return fma(this->middle,
4497 fma(rfma->get_left()/this->middle,
4498 rfma->get_middle(),
4499 this->left),
4500 rfma->get_right());
4501 }
4502 } else if (is_variable_combineable(this->left, rfma->get_left())) {
4503 if (is_greater_exponent(this->left, rfma->get_left())) {
4504// fma(x^b,a,fma(x^d,c,e)) -> fma(x^d,fma(x^(b-d),a,c),e) if b > d
4505 return fma(rfma->get_left(),
4506 fma(this->left/rfma->get_left(),
4507 this->middle,
4508 rfma->get_middle()),
4509 rfma->get_right());
4510 } else {
4511// fma(x^b,a,fma(x^d,c,e)) -> fma(x^b,fma(x^(d-b),c,a),e) if d > b
4512 return fma(this->left,
4513 fma(rfma->get_left()/this->left,
4514 rfma->get_middle(),
4515 this->middle),
4516 rfma->get_right());
4517 }
4518 }
4519
4520// fma(a,b,fma(a,b,c)) -> fma(2*a,b,c)
4521// fma(a,b,fma(b,a,c)) -> fma(2*a,b,c)
4522 if (this->left->is_match(rfma->get_left()) &&
4523 this->middle->is_match(rfma->get_middle())) {
4524 return fma(2.0*this->left, this->middle, rfma->get_right());
4525 } else if (this->left->is_match(rfma->get_middle()) &&
4526 this->middle->is_match(rfma->get_left())) {
4527 return fma(2.0*this->left, this->middle, rfma->get_right());
4528 }
4529
4530// fma(a,b/c,fma(e,f/c,g)) -> (a*b + e*f)/c + g
4531// fma(a,b/c,fma(e/c,f,g)) -> (a*b + e*f)/c + g
4532// fma(a/c,b,fma(e,f/c,g)) -> (a*b + e*f)/c + g
4533// fma(a/c,b,fma(e/c,f,g)) -> (a*b + e*f)/c + g
4534 auto fmald = divide_cast(rfma->get_left());
4535 auto fmamd = divide_cast(rfma->get_middle());
4536 if (ld.get()) {
4537 if (fmald.get() && ld->get_right()->is_match(fmald->get_right())) {
4538 return (ld->get_left()*this->middle +
4539 fmald->get_left()*rfma->get_middle())/ld->get_right() +
4540 rfma->get_right();
4541 } else if (fmamd.get() && ld->get_right()->is_match(fmamd->get_right())) {
4542 return (ld->get_left()*this->middle +
4543 fmamd->get_left()*rfma->get_left())/ld->get_right() +
4544 rfma->get_right();
4545 }
4546 } else if (md.get()) {
4547 if (fmald.get() && md->get_right()->is_match(fmald->get_right())) {
4548 return (md->get_left()*this->left +
4549 fmald->get_left()*rfma->get_middle())/md->get_right() +
4550 rfma->get_right();
4551 } else if (fmamd.get() && md->get_right()->is_match(fmamd->get_right())) {
4552 return (md->get_left()*this->left +
4553 fmamd->get_left()*rfma->get_left())/md->get_right() +
4554 rfma->get_right();
4555 }
4556 }
4557 }
4558
4559// Check to see if it is worth moving nodes out of an fma node. These should be
4560// restricted to variable like nodes. Only do this reduction if the complexity
4561// reduces.
4562 if (this->left->is_all_variables()) {
4563 auto rdl = this->right/this->left;
4564 if (rdl->get_complexity() < this->left->get_complexity() +
4565 this->right->get_complexity()) {
4566 return (this->middle + rdl)*this->left;
4567 }
4568 } else if (this->middle->is_all_variables()) {
4569 auto rdm = this->right/this->middle;
4570 auto rdmc = constant_cast(rdm->get_power_exponent());
4571 if ((rdm->get_complexity() < this->middle->get_complexity() +
4572 this->right->get_complexity()) &&
4573 !(rdmc.get() && rdmc->evaluate().is_negative())) {
4574 return (this->left + rdm)*this->middle;
4575 }
4576 }
4577
4578// Change negative exponents to divide so that they can be factored out.
4579// fma(a,b^-c,d) = a/b^c + d
4580// fma(b^-c,a,d) = a/b^c + d
4581 auto lp = pow_cast(this->left);
4582 if (lp.get()) {
4583 auto exponent = constant_cast(lp->get_right());
4584 if (exponent.get() && exponent->evaluate().is_negative()) {
4585 return this->middle/pow(lp->get_left(), -lp->get_right()) +
4586 this->right;
4587 }
4588 }
4589 auto mp = pow_cast(this->middle);
4590 if (mp.get()) {
4591 auto exponent = constant_cast(mp->get_right());
4592 if (exponent.get() && exponent->evaluate().is_negative()) {
4593 return this->left/pow(mp->get_left(), -mp->get_right()) +
4594 this->right;
4595 }
4596
4597// fma(2,a^2,a) -> a*fma(2,a,1)
4598// Note this case is handled earlier. fma(2,a,a^2) -> a*fma(2,1,a)
4600 this->right)) {
4601 auto temp = this->right/this->middle;
4602 auto temp_exponent = constant_cast(temp->get_power_exponent());
4603 if (temp_exponent.get() && temp_exponent->evaluate().is_negative()) {
4604 return this->right*fma(this->left,
4605 this->middle/this->right,
4606 1.0);
4607 }
4608 }
4609 }
4610
4611// a^b*c^b + d -> (a*c)^b + d
4612 if (lp.get() && mp.get()) {
4613 if (lp->get_right()->is_match(mp->get_right())) {
4614 return pow(lp->get_left()*mp->get_left(),
4615 lp->get_right()) +
4616 this->right;
4617 }
4618 }
4619
4620// fma(2,(ab)^2,a^2b) -> a^2*fma(2, b^2, b)
4621 if (rm.get() && mp.get()) {
4622 auto mplm = multiply_cast(mp->get_left());
4623 if (mplm.get()) {
4624 if (is_variable_combineable(mplm->get_left(),
4625 rm->get_left())) {
4626 auto temp = pow(mplm->get_left(),
4627 mp->get_right());
4628 return temp*fma(this->left,
4629 this->middle/temp,
4630 this->right/temp);
4631 } else if (is_variable_combineable(mplm->get_right(),
4632 rm->get_left())) {
4633 auto temp = pow(mplm->get_right(),
4634 mp->get_right());
4635 return temp*fma(this->left,
4636 this->middle/temp,
4637 this->right/temp);
4638 }
4639 }
4640 }
4641// fma(2,(a*b)^2,fma(a^2,b,c)) -> fma(a^2,fma(2,b^2,b),c)
4642 if (rfma.get() && mp.get()) {
4643 auto mplm = multiply_cast(mp->get_left());
4644 if (mplm.get()) {
4645 if (is_variable_combineable(mplm->get_left(),
4646 rfma->get_left())) {
4647 auto temp = pow(mplm->get_left(),
4648 mp->get_right());
4649 return fma(temp,
4650 fma(this->left,
4651 this->middle/temp,
4652 rfma->get_middle()),
4653 rfma->get_right());
4654 } else if (is_variable_combineable(mplm->get_right(),
4655 rfma->get_left())) {
4656 auto temp = pow(mplm->get_right(),
4657 mp->get_right());
4658 return fma(temp,
4659 fma(this->left,
4660 this->middle/temp,
4661 rfma->get_middle()),
4662 rfma->get_right());
4663 }
4664
4665// fma(2,(a*b)^2,fma(3,a^2*b,c)) -> a^2*fma(2,b^2,fma(3,b,c))
4666 auto rfmamm = multiply_cast(rfma->get_middle());
4667 if (rfmamm.get()) {
4668 if (is_variable_combineable(mplm->get_left(),
4669 rfmamm->get_left())) {
4670 auto temp = pow(mplm->get_left(),
4671 mp->get_right());
4672 return temp*fma(this->left,
4673 this->middle/temp,
4674 fma(rfma->get_left(),
4675 rfma->get_middle()/temp,
4676 rfma->get_right()));
4677 }
4678 }
4679 }
4680 }
4681
4682// fma(a,b/c,b/d) -> b*(a/c + 1/d)
4683// fma(a,c/b,d/b) -> (a*c + d)/b
4684 if (md.get() && rd.get()) {
4685 if (md->get_left()->is_match(rd->get_left())) {
4686 return md->get_left()*(this->left/md->get_right() +
4687 1.0/rd->get_right());
4688 } else if (md->get_right()->is_match(rd->get_right())) {
4689 return (this->left*md->get_left() +
4690 rd->get_left())/md->get_right();
4691 }
4692 }
4693// fma(b/c,a,b/d) -> b*(a/c + 1/d)
4694// fma(c/b,a,d/b) -> (a*c + d)/b
4695 if (ld.get() && rd.get()) {
4696 if (ld->get_left()->is_match(rd->get_left())) {
4697 return ld->get_left()*(this->middle/ld->get_right() +
4698 1.0/rd->get_right());
4699 } else if (ld->get_right()->is_match(rd->get_right())) {
4700 return (this->middle*ld->get_left() +
4701 rd->get_left())/ld->get_right();
4702 }
4703 }
4704
4705// fma(a/b,c,(d/b)*e) -> fma(a,c,d*e)/b
4706// fma(a/b,c,e*(d/b)) -> fma(a,c,d*e)/b
4707 if (rm.get() && ld.get()) {
4708 auto rmld = divide_cast(rm->get_left());
4709 if (rmld.get() && ld->get_right()->is_match(rmld->get_right())) {
4710 return fma(ld->get_left(), this->middle, rmld->get_left()*rm->get_right())/ld->get_right();
4711 }
4712 auto rmrd = divide_cast(rm->get_right());
4713 if (rmrd.get() && ld->get_right()->is_match(rmrd->get_right())) {
4714 return fma(ld->get_left(), this->middle, rmrd->get_left()*rm->get_left())/ld->get_right();
4715 }
4716 }
4717// fma(a,c/b,(d/b)*e) -> fma(a,c,d*e)/b
4718// fma(a,c/b,e*(d/b)) -> fma(a,c,d*e)/b
4719 if (rm.get() && md.get()) {
4720 auto rmld = divide_cast(rm->get_left());
4721 if (rmld.get() && md->get_right()->is_match(rmld->get_right())) {
4722 return fma(this->left, md->get_left(), rmld->get_left()*rm->get_right())/md->get_right();
4723 }
4724 auto rmrd = divide_cast(rm->get_right());
4725 if (rmrd.get() && md->get_right()->is_match(rmrd->get_right())) {
4726 return fma(this->left, md->get_left(), rmrd->get_left()*rm->get_left())/md->get_right();
4727 }
4728 }
4729
4730// fma(a/b*c,d,e/b) -> fma(a*c,d,e)/b
4731// fma(a*c/b,d,e/b) -> fma(a*c,d,e)/b
4732 if (rd.get() && lm.get()) {
4733 auto lmld = divide_cast(lm->get_left());
4734 if (lmld.get() && rd->get_right()->is_match(lmld->get_right())) {
4735 return fma(lmld->get_left()*lm->get_right(), this->middle, rd->get_left())/rd->get_right();
4736 }
4737 auto lmrd = divide_cast(lm->get_right());
4738 if (lmrd.get() && rd->get_right()->is_match(lmrd->get_right())) {
4739 return fma(lmld->get_left()*lm->get_left(), this->middle, rd->get_left())/rd->get_right();
4740 }
4741 }
4742// fma(a,c/b*d,e/b) -> fma(a,c*d,e)/b
4743// fma(a,c*d/b,e/b) -> fma(a,c*d,e)/b
4744 if (rd.get() && mm.get()) {
4745 auto mmld = divide_cast(mm->get_left());
4746 if (mmld.get() && rd->get_right()->is_match(mmld->get_right())) {
4747 return fma(this->left, mmld->get_left()*mm->get_right(), rd->get_left())/rd->get_right();
4748 }
4749 auto mmrd = divide_cast(mm->get_right());
4750 if (mmrd.get() && rd->get_right()->is_match(mmrd->get_right())) {
4751 return fma(this->left, mmrd->get_left()*mm->get_left(), rd->get_left())/rd->get_right();
4752 }
4753 }
4754
4755// fma(a, b/c, ((f/c)*e)*d) -> fma(a, b, f*e*d)/c
4756// fma(a/c, b, ((f/c)*e)*d) -> fma(a, b, f*e*d)/c
4757// fma(a, b/c, (e*(f/c))*d) -> fma(a, b, f*e*d)/c
4758// fma(a/c, b, (e*(f/c))*d) -> fma(a, b, f*e*d)/c
4759// fma(a, b/c, d*((f/c)*e)) -> fma(a, b, f*e*d)/c
4760// fma(a/c, b, d*((f/c)*e)) -> fma(a, b, f*e*d)/c
4761// fma(a, b/c, d*(e*(f/c))) -> fma(a, b, f*e*d)/c
4762// fma(a/c, b, d*(e*(f/c))) -> fma(a, b, f*e*d)/c
4763 if (md.get() && rm.get()) {
4764 auto rmlm = multiply_cast(rm->get_left());
4765 if (rmlm.get()) {
4766 auto rmlmld = divide_cast(rmlm->get_left());
4767 if (rmlmld.get() && rmlmld->get_right()->is_match(md->get_right())) {
4768 return fma(this->left, md->get_left(),
4769 rmlmld->get_left()*rmlm->get_right()*rm->get_right())/md->get_right();
4770 }
4771 auto rmlmrd = divide_cast(rmlm->get_right());
4772 if (rmlmrd.get() && rmlmrd->get_right()->is_match(md->get_right())) {
4773 return fma(this->left, md->get_left(),
4774 rmlmrd->get_left()*rmlm->get_left()*rm->get_right())/md->get_right();
4775 }
4776 }
4777 auto rmrm = multiply_cast(rm->get_right());
4778 if (rmrm.get()) {
4779 auto rmrmld = divide_cast(rmrm->get_left());
4780 if (rmrmld.get() && rmrmld->get_right()->is_match(md->get_right())) {
4781 return fma(this->left, md->get_left(),
4782 rmrmld->get_left()*rmrm->get_right()*rm->get_left())/md->get_right();
4783 }
4784 auto rmrmrd = divide_cast(rmrm->get_right());
4785 if (rmrmrd.get() && rmrmrd->get_right()->is_match(md->get_right())) {
4786 return fma(this->left, md->get_left(),
4787 rmrmrd->get_left()*rmrm->get_left()*rm->get_left())/md->get_right();
4788 }
4789 }
4790 } else if (ld.get() && rm.get()) {
4791 auto rmlm = multiply_cast(rm->get_left());
4792 if (rmlm.get()) {
4793 auto rmlmld = divide_cast(rmlm->get_left());
4794 if (rmlmld.get() && rmlmld->get_right()->is_match(ld->get_right())) {
4795 return fma(ld->get_left(), this->middle,
4796 rmlmld->get_left()*rmlm->get_right()*rm->get_right())/ld->get_right();
4797 }
4798 auto rmlmrd = divide_cast(rmlm->get_right());
4799 if (rmlmrd.get() && rmlmrd->get_right()->is_match(ld->get_right())) {
4800 return fma(ld->get_left(), this->middle,
4801 rmlmrd->get_left()*rmlm->get_right()*rm->get_right())/ld->get_right();
4802 }
4803 }
4804 auto rmrm = multiply_cast(rm->get_right());
4805 if (rmrm.get()) {
4806 auto rmrmld = divide_cast(rmrm->get_left());
4807 if (rmrmld.get() && rmrmld->get_right()->is_match(ld->get_right())) {
4808 return fma(ld->get_left(), this->middle,
4809 rmrmld->get_left()*rmrm->get_right()*rm->get_left())/ld->get_right();
4810 }
4811 auto rmrmrd = divide_cast(rmrm->get_right());
4812 if (rmrmrd.get() && rmrmrd->get_right()->is_match(ld->get_right())) {
4813 return fma(ld->get_left(), this->middle,
4814 rmrmrd->get_left()*rmrm->get_left()*rm->get_left())/ld->get_right();
4815 }
4816 }
4817 }
4818
4819// fma(exp(a), exp(b), c) -> exp(a + b) + c
4820 auto le = exp_cast(this->left);
4821 auto me = exp_cast(this->middle);
4822 if (le.get() && me.get()) {
4823 return exp(le->get_arg() + me->get_arg()) + this->right;
4824 }
4825
4826// fma(exp(a), exp(b)*c, d) -> fma(exp(a)*exp(b), c, d)
4827// fma(exp(a), c*exp(b), d) -> fma(exp(a)*exp(b), c, d)
4828 if (mm.get() && le.get()) {
4829 auto mmle = exp_cast(mm->get_left());
4830 if (mmle.get()) {
4831 return fma(this->left*mm->get_left(),
4832 mm->get_right(),
4833 this->right);
4834 }
4835 auto mmre = exp_cast(mm->get_right());
4836 if (mmre.get()) {
4837 return fma(this->left*mm->get_right(),
4838 mm->get_left(),
4839 this->right);
4840 }
4841 }
4842// fma(exp(a)*c, exp(b), d) -> fma(exp(a)*exp(b), c, d)
4843// fma(c*exp(a), exp(b), d) -> fma(exp(a)*exp(b), c, d)
4844 if (lm.get() && me.get()) {
4845 auto lmle = exp_cast(lm->get_left());
4846 if (lmle.get()) {
4847 return fma(lm->get_left()*this->middle,
4848 lm->get_right(),
4849 this->right);
4850 }
4851 auto lmre = exp_cast(lm->get_right());
4852 if (lmre.get()) {
4853 return fma(lm->get_right()*this->middle,
4854 lm->get_left(),
4855 this->right);
4856 }
4857 }
4858
4859// fma(exp(a)*c, exp(b)*d, e) -> fma(exp(a)*exp(b), c*d, e)
4860// fma(exp(a)*c, d*exp(b), e) -> fma(exp(a)*exp(b), c*d, e)
4861// fma(c*exp(a), exp(b)*d, e) -> fma(exp(a)*exp(b), c*d, e)
4862// fma(c*exp(a), d*exp(b), e) -> fma(exp(a)*exp(b), c*d, e)
4863 if (lm.get() && mm.get()) {
4864 auto lmle = exp_cast(lm->get_left());
4865 if (lmle.get()) {
4866 auto mmle = exp_cast(mm->get_left());
4867 if (mmle.get()) {
4868 return fma(lm->get_left()*mm->get_left(),
4869 lm->get_right()*mm->get_right(),
4870 this->right);
4871 }
4872 auto mmre = exp_cast(mm->get_right());
4873 if (mmre.get()) {
4874 return fma(lm->get_left()*mm->get_right(),
4875 lm->get_right()*mm->get_left(),
4876 this->right);
4877 }
4878 }
4879 auto lmre = exp_cast(lm->get_right());
4880 if (lmre.get()) {
4881 auto mmle = exp_cast(mm->get_left());
4882 if (mmle.get()) {
4883 return fma(lm->get_right()*mm->get_left(),
4884 lm->get_left()*mm->get_right(),
4885 this->right);
4886 }
4887 auto mmre = exp_cast(mm->get_right());
4888 if (mmre.get()) {
4889 return fma(lm->get_right()*mm->get_right(),
4890 lm->get_left()*mm->get_left(),
4891 this->right);
4892 }
4893 }
4894 }
4895
4896// fma(exp(a)*c, exp(b)/d, e) -> fma(exp(a)*exp(b), c/d, e)
4897// fma(exp(a)*c, d/exp(b), e) -> fma(exp(a)/exp(b), c*d, e)
4898// fma(c*exp(a), exp(b)/d, e) -> fma(exp(a)*exp(b), c/d, e)
4899// fma(c*exp(a), d/exp(b), e) -> fma(exp(a)/exp(b), c*d, e)
4900 if (lm.get() && md.get()) {
4901 auto lmle = exp_cast(lm->get_left());
4902 if (lmle.get()) {
4903 auto mdle = exp_cast(md->get_left());
4904 if (mdle.get()) {
4905 return fma(lm->get_left()*md->get_left(),
4906 lm->get_right()/md->get_right(),
4907 this->right);
4908 }
4909 auto mdre = exp_cast(md->get_right());
4910 if (mdre.get()) {
4911 return fma(lm->get_left()/md->get_right(),
4912 lm->get_right()*md->get_left(),
4913 this->right);
4914 }
4915 }
4916 auto lmre = exp_cast(lm->get_right());
4917 if (lmre.get()) {
4918 auto mdle = exp_cast(md->get_left());
4919 if (mdle.get()) {
4920 return fma(lm->get_right()*md->get_left(),
4921 lm->get_left()/md->get_right(),
4922 this->right);
4923 }
4924 auto mdre = exp_cast(md->get_right());
4925 if (mdre.get()) {
4926 return fma(lm->get_right()/md->get_right(),
4927 lm->get_left()*md->get_left(),
4928 this->right);
4929 }
4930 }
4931 }
4932
4933// fma(exp(a)/c, exp(b)*d, e) -> fma(exp(a)*exp(b), d/c, e)
4934// fma(exp(a)/c, d*exp(b), e) -> fma(exp(a)*exp(b), d/c, e)
4935// fma(c/exp(a), exp(b)*d, e) -> fma(exp(b)/exp(a), c*d, e)
4936// fma(c/exp(a), d*exp(b), e) -> fma(exp(b)/exp(a), c*d, e)
4937 if (ld.get() && mm.get()) {
4938 auto ldle = exp_cast(ld->get_left());
4939 if (ldle.get()) {
4940 auto mmle = exp_cast(mm->get_left());
4941 if (mmle.get()) {
4942 return fma(ld->get_left()*mm->get_left(),
4943 mm->get_right()/ld->get_right(),
4944 this->right);
4945 }
4946 auto mmre = exp_cast(mm->get_right());
4947 if (mmre.get()) {
4948 return fma(ld->get_left()*mm->get_right(),
4949 mm->get_left()/ld->get_right(),
4950 this->right);
4951 }
4952 }
4953 auto ldre = exp_cast(ld->get_right());
4954 if (ldre.get()) {
4955 auto mmle = exp_cast(mm->get_left());
4956 if (mmle.get()) {
4957 return fma(mm->get_left()/ld->get_right(),
4958 ld->get_left()*mm->get_right(),
4959 this->right);
4960 }
4961 auto mmre = exp_cast(mm->get_right());
4962 if (mmre.get()) {
4963 return fma(mm->get_right()/ld->get_right(),
4964 ld->get_left()*mm->get_left(),
4965 this->right);
4966 }
4967 }
4968 }
4969
4970// fma(exp(a)/c, exp(b)/d, e) -> (exp(a)*exp(b))/(c*d) + e
4971// fma(exp(a)/c, d/exp(b), e) -> fma(exp(a)/exp(b), d/c, e)
4972// fma(c/exp(a), exp(b)/d, e) -> fma(exp(b)/exp(a), c/d, e)
4973// fma(c/exp(a), d/exp(b), e) -> (c*d)/(exp(a)*exp(b)) + e
4974 if (ld.get() && md.get()) {
4975 auto ldle = exp_cast(ld->get_left());
4976 if (ldle.get()) {
4977 auto mdle = exp_cast(md->get_left());
4978 if (mdle.get()) {
4979 return ((ld->get_left()*md->get_left()) /
4980 (ld->get_right()*md->get_right())) +
4981 this->right;
4982 }
4983 auto mdre = exp_cast(md->get_right());
4984 if (mdre.get()) {
4985 return fma(ld->get_left()/md->get_right(),
4986 md->get_left()/ld->get_right(),
4987 this->right);
4988 }
4989 }
4990 auto ldre = exp_cast(ld->get_right());
4991 if (ldre.get()) {
4992 auto mdle = exp_cast(md->get_left());
4993 if (mdle.get()) {
4994 return fma(md->get_left()/ld->get_right(),
4995 ld->get_left()/md->get_right(),
4996 this->right);
4997 }
4998 auto mdre = exp_cast(md->get_right());
4999 if (mdre.get()) {
5000 return ((ld->get_left()*md->get_left()) /
5001 (ld->get_right()*md->get_right())) +
5002 this->right;
5003 }
5004 }
5005 }
5006
5007 return this->shared_from_this();
5008 }
5009
5010//------------------------------------------------------------------------------
5017//------------------------------------------------------------------------------
// Derivative of this fma node with respect to x. Results are stored in
// this->df_cache so each x is differentiated at most once per node.
// NOTE(review): the declaration line of this method is not present in this
// listing (the embedded numbering jumps from 5017 to 5020) — presumably the
// df(x) override; confirm against the original header.
// A node that matches x itself differentiates to one.
5020 if (this->is_match(x)) {
5021 return one<T, SAFE_MATH> ();
5022 }
5023
// The cache key is the address of the node held by x, reinterpreted as size_t.
5024 const size_t hash = reinterpret_cast<size_t> (x.get());
5025 if (this->df_cache.find(hash) == this->df_cache.end()) {
// Product rule expressed as two nested fma nodes. Reading fma(a,b,c) as
// a*b + c, the inner node is left*middle' + right' ...
5026 auto temp_right = fma(this->left,
5027 this->middle->df(x),
5028 this->right->df(x));
5029
// ... and the cached result is left'*middle + (left*middle' + right').
5030 this->df_cache[hash] = fma(this->left->df(x),
5031 this->middle,
5032 temp_right);
5033 }
// Return the cached entry (either found above or just inserted).
5034 return this->df_cache[hash];
5035 }
5036
5037//------------------------------------------------------------------------------
5045//------------------------------------------------------------------------------
// Writes the source line that evaluates this fma node into stream and records
// the register name it was assigned in registers. Returns this node.
// NOTE(review): the return-type line and the declaration of the `indices`
// parameter are not present in this listing (the embedded numbering skips 5046
// and 5049); `indices` is only forwarded to the operands' compile calls below.
5047 compile(std::ostringstream &stream,
5048 jit::register_map &registers,
5050 const jit::register_usage &usage) {
// Emit code only if this node has no register yet; otherwise nothing is
// written.
5051 if (registers.find(this) == registers.end()) {
// Compile the three operands first so their registers exist before use.
5052 shared_leaf<T, SAFE_MATH> l = this->left->compile(stream,
5053 registers,
5054 indices,
5055 usage);
5056 shared_leaf<T, SAFE_MATH> m = this->middle->compile(stream,
5057 registers,
5058 indices,
5059 usage);
5060 shared_leaf<T, SAFE_MATH> r = this->right->compile(stream,
5061 registers,
5062 indices,
5063 usage);
5064
// Name this node's register and open the declaration:
// " const <output of jit::add_type<T>> <register> = ".
5065 registers[this] = jit::to_string('r', this);
5066 stream << " const ";
5067 jit::add_type<T> (stream);
5068 stream << " " << registers[this] << " = ";
// With SAFE_MATH the generated expression is guarded by a ternary:
// "(l == 0 || m == 0) ? r : <expression>", so the right operand is used
// directly whenever either factor compares equal to zero. For complex T the
// zero is written as the add_type<T> output followed by "(0, 0)".
5069 if constexpr (SAFE_MATH) {
5070 stream << "(" << registers[l.get()] << " == ";
5071 if constexpr (jit::complex_scalar<T>) {
5072 jit::add_type<T> (stream);
5073 stream << "(0, 0)";
5074 } else {
5075 stream << "0";
5076 }
5077 stream << " || " << registers[m.get()] << " == ";
5078 if constexpr (jit::complex_scalar<T>) {
5079 jit::add_type<T> (stream);
5080 stream << "(0, 0)";
5081 } else {
5082 stream << "0";
5083 }
5084 stream << ") ? " << registers[r.get()] << " : ";
5085 }
// Complex types are written out as "l*m + r"; all other types call
// "fma(l, m, r)" in the generated source.
5086 if constexpr (jit::complex_scalar<T>) {
5087 stream << registers[l.get()] << "*"
5088 << registers[m.get()] << " + "
5089 << registers[r.get()];
5090 } else {
5091 stream << "fma("
5092 << registers[l.get()] << ", "
5093 << registers[m.get()] << ", "
5094 << registers[r.get()] << ")";
5095 }
// Terminate the generated statement.
5096 this->endline(stream, usage);
5097 }
5098
5099 return this->shared_from_this();
5100 }
5101
5102//------------------------------------------------------------------------------
///  @brief Query if the nodes match.
///
///  Two fma nodes match when they are the same object, or when their left,
///  middle and right operands each match pairwise. Any node that is not an
///  fma node does not match.
///
///  @param[in] x Other graph to check if it is a match.
///  @returns True if the nodes are a match.
5107//------------------------------------------------------------------------------
     // Identity shortcut: the same object always matches itself.
5109 if (this == x.get()) {
5110 return true;
5111 }
5112
     // fma_cast yields an empty pointer when x is not an fma node.
5113 auto x_cast = fma_cast(x);
5114 if (x_cast.get()) {
5115 return this->left->is_match(x_cast->get_left()) &&
5116 this->middle->is_match(x_cast->get_middle()) &&
5117 this->right->is_match(x_cast->get_right());
5118 }
5119
5120 return false;
5121 }
5122
5123//------------------------------------------------------------------------------
///  @brief Convert the node to latex.
///
///  Writes the node to std::cout in the form
///
///    \left( <left> <middle>+<right> \right)
///
///  The left and middle operands are each wrapped in their own
///  \left( ... \right) pair when they are add or subtract nodes.
5125//------------------------------------------------------------------------------
5126 virtual void to_latex() const {
5127 std::cout << "\\left(";
     // Left factor, parenthesised when it is a sum or a difference.
5128 if (add_cast(this->left).get() ||
5129 subtract_cast(this->left).get()) {
5130 std::cout << "\\left(";
5131 this->left->to_latex();
5132 std::cout << "\\right)";
5133 } else {
5134 this->left->to_latex();
5135 }
     // The product is written as juxtaposition separated by a space.
5136 std::cout << " ";
     // Middle factor, parenthesised when it is a sum or a difference.
5137 if (add_cast(this->middle).get() ||
5138 subtract_cast(this->middle).get()) {
5139 std::cout << "\\left(";
5140 this->middle->to_latex();
5141 std::cout << "\\right)";
5142 } else {
5143 this->middle->to_latex();
5144 }
     // The addend is written without extra parentheses.
5145 std::cout << "+";
5146 this->right->to_latex();
5147 std::cout << "\\right)";
5148 }
5149
5150//------------------------------------------------------------------------------
///  @brief Remove pseudo variable nodes.
///
///  When has_pseudo() reports true, a new fma expression is built from the
///  three operands with remove_pseudo() applied to each. Otherwise this node
///  is returned unchanged.
///
///  @returns A tree without pseudo variable nodes.
5154//------------------------------------------------------------------------------
5156 if (this->has_pseudo()) {
5157 return fma(this->left->remove_pseudo(),
5158 this->middle->remove_pseudo(),
5159 this->right->remove_pseudo());
5160 }
     // Nothing to strip, so reuse the existing node.
5161 return this->shared_from_this();
5162 }
5163
5164//------------------------------------------------------------------------------
///  @brief Convert the node to vizgraph.
///
///  On the first visit the node registers a name for itself, writes a node
///  statement labelled "fma", then visits the left, middle and right operands
///  in that order, writing one "<name> -- <operand>;" edge after each visit.
///  A node that is already present in registers writes nothing.
///
///  @param[in,out] stream    String buffer stream receiving the graph text.
///  @param[in,out] registers Map from node pointers to the names already
///                           assigned.
///  @returns This node.
5170//------------------------------------------------------------------------------
5171 virtual shared_leaf<T, SAFE_MATH> to_vizgraph(std::stringstream &stream,
5172 jit::register_map &registers) {
5173 if (registers.find(this) == registers.end()) {
     // Register this node before descending into the operands.
5174 const std::string name = jit::to_string('r', this);
5175 registers[this] = name;
5176 stream << " " << name
5177 << " [label = \"fma\", shape = oval, style = filled, fillcolor = blue, fontcolor = white];" << std::endl;
5178
     // Each operand is emitted first so its name exists for the edge line.
5179 auto l = this->left->to_vizgraph(stream, registers);
5180 stream << " " << name << " -- " << registers[l.get()] << ";" << std::endl;
5181 auto m = this->middle->to_vizgraph(stream, registers);
5182 stream << " " << name << " -- " << registers[m.get()] << ";" << std::endl;
5183 auto r = this->right->to_vizgraph(stream, registers);
5184 stream << " " << name << " -- " << registers[r.get()] << ";" << std::endl;
5185 }
5186
5187 return this->shared_from_this();
5188 }
5189 };
5190
5191//------------------------------------------------------------------------------
///  @brief Build fused multiply add node.
///
///  Constructs an fma_node from the three operands and reduces it. The reduced
///  node is then looked up in leaf_node<T, SAFE_MATH>::caches.nodes starting
///  at its own hash: an unused slot returns the new node, and a slot whose
///  node matches returns the cached node instead.
///
///  @tparam T         Base type of the calculation.
///  @tparam SAFE_MATH Use safe math operations.
///
///  @returns The reduced expression l*m + r.
///
///  NOTE(review): the listing omits lines 5202-5204 (the signature), 5208,
///  5211 and 5218. Presumably 5208 holds the rest of the for-loop header and
///  5211 stores temp in the cache before it is returned; confirm against the
///  original header.
5200//------------------------------------------------------------------------------
5201 template<jit::float_scalar T, bool SAFE_MATH=false>
5205 auto temp = std::make_shared<fma_node<T, SAFE_MATH>> (l, m, r)->reduce();
5206// Test for hash collisions.
5207 for (size_t i = temp->get_hash();
     // Slot i is unused: the reduced node is new.
5209 if (leaf_node<T, SAFE_MATH>::caches.nodes.find(i) ==
5210 leaf_node<T, SAFE_MATH>::caches.nodes.end()) {
5212 return temp;
     // Slot i holds a matching node: reuse the cached instance.
5213 } else if (temp->is_match(leaf_node<T, SAFE_MATH>::caches.nodes[i])) {
5214 return leaf_node<T, SAFE_MATH>::caches.nodes[i];
5215 }
5216 }
     // Control is not expected to leave the loop above. The clang/GCC branch
     // is on a line absent from this listing; other compilers assert.
5217#if defined(__clang__) || defined(__GNUC__)
5219#else
5220 assert(false && "Should never reach.");
5221#endif
5222 }
5223
5224//------------------------------------------------------------------------------
5237//------------------------------------------------------------------------------
5238 template<jit::float_scalar T, jit::float_scalar L, bool SAFE_MATH=false>
5244
5245//------------------------------------------------------------------------------
5258//------------------------------------------------------------------------------
5259 template<jit::float_scalar T, jit::float_scalar M, bool SAFE_MATH=false>
5265
5266//------------------------------------------------------------------------------
5279//------------------------------------------------------------------------------
5280 template<jit::float_scalar T, jit::float_scalar R, bool SAFE_MATH=false>
5286
5287//------------------------------------------------------------------------------
///  @brief Build fused multiply add node from scalar left and middle operands.
///
///  The scalars l and m are converted to T with static_cast and wrapped in
///  constant nodes; r is forwarded unchanged to fma<T, SAFE_MATH>.
///
///  @tparam T         Base type of the calculation.
///  @tparam L         Float type for the left constant.
///  @tparam M         Float type for the middle constant.
///  @tparam SAFE_MATH Use safe math operations.
///
///  @returns The expression l*m + r.
///
///  NOTE(review): the listing omits lines 5303 and 5305, which presumably
///  declare the return type, the function name and the l and r parameters.
5301//------------------------------------------------------------------------------
5302 template<jit::float_scalar T, jit::float_scalar L, jit::float_scalar M, bool SAFE_MATH=false>
5304 const M m,
5306 return fma<T, SAFE_MATH> (constant<T, SAFE_MATH> (static_cast<T> (l)),
5307 constant<T, SAFE_MATH> (static_cast<T> (m)), r);
5308 }
5309
5310//------------------------------------------------------------------------------
///  @brief Build fused multiply add node from scalar middle and right operands.
///
///  The scalars m and r are converted to T with static_cast and wrapped in
///  constant nodes; l is forwarded unchanged to fma<T, SAFE_MATH>.
///
///  @tparam T         Base type of the calculation.
///  @tparam M         Float type for the middle constant.
///  @tparam R         Float type for the right constant.
///  @tparam SAFE_MATH Use safe math operations.
///
///  @param[in] m Middle branch scalar.
///  @param[in] r Right branch scalar.
///  @returns The expression l*m + r.
///
///  NOTE(review): the listing omits line 5326, which presumably declares the
///  return type, the function name and the l parameter.
5324//------------------------------------------------------------------------------
5325 template<jit::float_scalar T, jit::float_scalar M, jit::float_scalar R, bool SAFE_MATH=false>
5327 const M m,
5328 const R r) {
5329 return fma<T, SAFE_MATH> (l, constant<T, SAFE_MATH> (static_cast<T> (m)),
5330 constant<T, SAFE_MATH> (static_cast<T> (r)));
5331 }
5332
5333//------------------------------------------------------------------------------
///  @brief Build fused multiply add node from scalar left and right operands.
///
///  The scalars l and r are converted to T with static_cast and wrapped in
///  constant nodes; m is forwarded unchanged to fma<T, SAFE_MATH>.
///
///  @tparam T         Base type of the calculation.
///  @tparam L         Float type for the left constant.
///  @tparam R         Float type for the right constant.
///  @tparam SAFE_MATH Use safe math operations.
///
///  @param[in] r Right branch scalar.
///  @returns The expression l*m + r.
///
///  NOTE(review): the listing omits lines 5349 and 5350, which presumably
///  declare the return type, the function name and the l and m parameters.
5347//------------------------------------------------------------------------------
5348 template<jit::float_scalar T, jit::float_scalar L, jit::float_scalar R, bool SAFE_MATH=false>
5351 const R r) {
5352 return fma<T, SAFE_MATH> (constant<T, SAFE_MATH> (static_cast<T> (l)), m,
5353 constant<T, SAFE_MATH> (static_cast<T> (r)));
5354 }
5355
///  Convenience type alias for shared fma nodes.
5357 template<jit::float_scalar T, bool SAFE_MATH=false>
5358 using shared_fma = std::shared_ptr<fma_node<T, SAFE_MATH>>;
5359
5360//------------------------------------------------------------------------------
///  @brief Cast to an fma node.
///
///  Uses std::dynamic_pointer_cast, so the result is an empty pointer when x
///  does not refer to an fma_node<T, SAFE_MATH>.
///
///  @tparam T         Base type of the calculation.
///  @tparam SAFE_MATH Use safe math operations.
///
///  @returns A shared_fma pointing at x if the cast was successful, otherwise
///           an empty pointer.
///
///  NOTE(review): the listing omits line 5370, which presumably declares the
///  return type, the function name and the x parameter.
5368//------------------------------------------------------------------------------
5369 template<jit::float_scalar T, bool SAFE_MATH=false>
5371 return std::dynamic_pointer_cast<fma_node<T, SAFE_MATH>> (x);
5372 }
5373}
5374
5375#endif /* arithmetic_h */
Class representing a generic buffer.
Definition backend.hpp:29
void subtract_row(const buffer< T > &x)
Subtract row operation.
Definition backend.hpp:375
void multiply_row(const buffer< T > &x)
Multiply row operation.
Definition backend.hpp:447
void add_col(const buffer< T > &x)
Add col operation.
Definition backend.hpp:339
void divide_col(const buffer< T > &x)
Divide col operation.
Definition backend.hpp:555
void add_row(const buffer< T > &x)
Add row operation.
Definition backend.hpp:303
void subtract_col(const buffer< T > &x)
Subtract col operation.
Definition backend.hpp:411
void multiply_col(const buffer< T > &x)
Multiply col operation.
Definition backend.hpp:483
void divide_row(const buffer< T > &x)
Divide row operation.
Definition backend.hpp:519
bool is_zero() const
Is every element zero.
Definition backend.hpp:141
An addition node.
Definition arithmetic.hpp:132
virtual backend::buffer< T > evaluate()
Evaluate the results of addition.
Definition arithmetic.hpp:166
virtual shared_leaf< T, SAFE_MATH > to_vizgraph(std::stringstream &stream, jit::register_map &registers)
Convert the node to vizgraph.
Definition arithmetic.hpp:709
virtual shared_leaf< T, SAFE_MATH > df(shared_leaf< T, SAFE_MATH > x)
Transform node to derivative.
Definition arithmetic.hpp:591
virtual bool is_match(shared_leaf< T, SAFE_MATH > x)
Query if the nodes match.
Definition arithmetic.hpp:645
virtual shared_leaf< T, SAFE_MATH > remove_pseudo()
Remove pseudo variable nodes.
Definition arithmetic.hpp:694
virtual shared_leaf< T, SAFE_MATH > compile(std::ostringstream &stream, jit::register_map &registers, jit::register_map &indices, const jit::register_usage &usage)
Compile the node.
Definition arithmetic.hpp:613
add_node(shared_leaf< T, SAFE_MATH > l, shared_leaf< T, SAFE_MATH > r)
Construct an addition node.
Definition arithmetic.hpp:154
virtual shared_leaf< T, SAFE_MATH > reduce()
Reduce an addition node.
Definition arithmetic.hpp:177
virtual void to_latex() const
Convert the node to latex.
Definition arithmetic.hpp:667
Class representing a branch node.
Definition node.hpp:1173
shared_leaf< T, SAFE_MATH > right
Right branch of the tree.
Definition node.hpp:1178
shared_leaf< T, SAFE_MATH > left
Left branch of the tree.
Definition node.hpp:1176
A division node.
Definition arithmetic.hpp:2737
virtual shared_leaf< T, SAFE_MATH > reduce()
Reduce a division node.
Definition arithmetic.hpp:2790
virtual shared_leaf< T, SAFE_MATH > remove_pseudo()
Remove pseudo variable nodes.
Definition arithmetic.hpp:3554
virtual shared_leaf< T, SAFE_MATH > compile(std::ostringstream &stream, jit::register_map &registers, jit::register_map &indices, const jit::register_usage &usage)
Compile the node.
Definition arithmetic.hpp:3476
divide_node(shared_leaf< T, SAFE_MATH > n, shared_leaf< T, SAFE_MATH > d)
Construct a division node.
Definition arithmetic.hpp:2759
virtual bool is_match(shared_leaf< T, SAFE_MATH > x)
Query if the nodes match.
Definition arithmetic.hpp:3524
virtual backend::buffer< T > evaluate()
Evaluate the results of division.
Definition arithmetic.hpp:2771
virtual shared_leaf< T, SAFE_MATH > to_vizgraph(std::stringstream &stream, jit::register_map &registers)
Convert the node to vizgraph.
Definition arithmetic.hpp:3569
virtual void to_latex() const
Convert the node to latex.
Definition arithmetic.hpp:3541
virtual shared_leaf< T, SAFE_MATH > df(shared_leaf< T, SAFE_MATH > x)
Transform node to derivative.
Definition arithmetic.hpp:3453
A fused multiply add node.
Definition arithmetic.hpp:3704
virtual shared_leaf< T, SAFE_MATH > remove_pseudo()
Remove pseudo variable nodes.
Definition arithmetic.hpp:5155
virtual shared_leaf< T, SAFE_MATH > df(shared_leaf< T, SAFE_MATH > x)
Transform node to derivative.
Definition arithmetic.hpp:5019
virtual shared_leaf< T, SAFE_MATH > compile(std::ostringstream &stream, jit::register_map &registers, jit::register_map &indices, const jit::register_usage &usage)
Compile the node.
Definition arithmetic.hpp:5047
virtual backend::buffer< T > evaluate()
Evaluate the results of fused multiply add.
Definition arithmetic.hpp:3780
virtual shared_leaf< T, SAFE_MATH > to_vizgraph(std::stringstream &stream, jit::register_map &registers)
Convert the node to vizgraph.
Definition arithmetic.hpp:5171
fma_node(shared_leaf< T, SAFE_MATH > l, shared_leaf< T, SAFE_MATH > m, shared_leaf< T, SAFE_MATH > r)
Construct a fused multiply add node.
Definition arithmetic.hpp:3766
virtual void to_latex() const
Convert the node to latex.
Definition arithmetic.hpp:5126
virtual shared_leaf< T, SAFE_MATH > reduce()
Reduce a fused multiply add node.
Definition arithmetic.hpp:3799
virtual bool is_match(shared_leaf< T, SAFE_MATH > x)
Query if the nodes match.
Definition arithmetic.hpp:5108
Class representing a node leaf.
Definition node.hpp:364
virtual void endline(std::ostringstream &stream, const jit::register_usage &usage) const final
End a line in the kernel source.
Definition node.hpp:637
virtual backend::buffer< T > evaluate()=0
Evaluate method.
std::map< size_t, std::shared_ptr< leaf_node< T, SAFE_MATH > > > df_cache
Cache derivative terms.
Definition node.hpp:371
virtual bool has_pseudo() const
Query if the node contains pseudo variables.
Definition node.hpp:618
const size_t hash
Hash for node.
Definition node.hpp:367
A multiplication node.
Definition arithmetic.hpp:1688
virtual backend::buffer< T > evaluate()
Evaluate the results of multiplication.
Definition arithmetic.hpp:1827
virtual shared_leaf< T, SAFE_MATH > df(shared_leaf< T, SAFE_MATH > x)
Transform node to derivative.
Definition arithmetic.hpp:2461
virtual shared_leaf< T, SAFE_MATH > remove_pseudo()
Remove pseudo variable nodes.
Definition arithmetic.hpp:2589
virtual shared_leaf< T, SAFE_MATH > reduce()
Reduce a multiplication node.
Definition arithmetic.hpp:1852
multiply_node(shared_leaf< T, SAFE_MATH > l, shared_leaf< T, SAFE_MATH > r)
Construct a multiplication node.
Definition arithmetic.hpp:1816
virtual bool is_match(shared_leaf< T, SAFE_MATH > x)
Query if the nodes match.
Definition arithmetic.hpp:2540
virtual shared_leaf< T, SAFE_MATH > compile(std::ostringstream &stream, jit::register_map &registers, jit::register_map &indices, const jit::register_usage &usage)
Compile the node.
Definition arithmetic.hpp:2484
virtual shared_leaf< T, SAFE_MATH > to_vizgraph(std::stringstream &stream, jit::register_map &registers)
Convert the node to vizgraph.
Definition arithmetic.hpp:2604
virtual void to_latex() const
Convert the node to latex.
Definition arithmetic.hpp:2562
A subtraction node.
Definition arithmetic.hpp:847
virtual shared_leaf< T, SAFE_MATH > compile(std::ostringstream &stream, jit::register_map &registers, jit::register_map &indices, const jit::register_usage &usage)
Compile the node.
Definition arithmetic.hpp:1443
subtract_node(shared_leaf< T, SAFE_MATH > l, shared_leaf< T, SAFE_MATH > r)
Construct a subtraction node.
Definition arithmetic.hpp:869
virtual shared_leaf< T, SAFE_MATH > df(shared_leaf< T, SAFE_MATH > x)
Transform node to derivative.
Definition arithmetic.hpp:1421
virtual shared_leaf< T, SAFE_MATH > remove_pseudo()
Remove pseudo variable nodes.
Definition arithmetic.hpp:1519
virtual bool is_match(shared_leaf< T, SAFE_MATH > x)
Query if the nodes match.
Definition arithmetic.hpp:1475
virtual shared_leaf< T, SAFE_MATH > reduce()
Reduce a subtraction node.
Definition arithmetic.hpp:892
virtual shared_leaf< T, SAFE_MATH > to_vizgraph(std::stringstream &stream, jit::register_map &registers)
Convert the node to vizgraph.
Definition arithmetic.hpp:1534
virtual backend::buffer< T > evaluate()
Evaluate the results of subtraction.
Definition arithmetic.hpp:881
virtual void to_latex() const
Convert the node to latex.
Definition arithmetic.hpp:1492
Class representing a triple branch node.
Definition node.hpp:1297
shared_leaf< T, SAFE_MATH > middle
Middle branch of the tree.
Definition node.hpp:1300
Complex scalar concept.
Definition register.hpp:24
subroutine assert(test, message)
Assert check.
Definition f_binding_test.f90:38
buffer< T > fma(buffer< T > &a, buffer< T > &b, buffer< T > &c)
Fused multiply add operation.
Definition backend.hpp:918
Name space for graph nodes.
Definition arithmetic.hpp:13
shared_piecewise_2D< T, SAFE_MATH > piecewise_2D_cast(shared_leaf< T, SAFE_MATH > x)
Cast to a piecewise 2D node.
Definition piecewise.hpp:1323
shared_leaf< T, SAFE_MATH > operator*(shared_leaf< T, SAFE_MATH > l, shared_leaf< T, SAFE_MATH > r)
Build multiply node from two leaves.
Definition arithmetic.hpp:2666
bool is_variable_combineable(shared_leaf< T, SAFE_MATH > a, shared_leaf< T, SAFE_MATH > b)
Check if the variable is combinable.
Definition arithmetic.hpp:75
std::shared_ptr< add_node< T, SAFE_MATH > > shared_add
Convenience type alias for shared add nodes.
Definition arithmetic.hpp:819
shared_pow< T, SAFE_MATH > pow_cast(shared_leaf< T, SAFE_MATH > x)
Cast to a power node.
Definition math.hpp:1416
shared_subtract< T, SAFE_MATH > subtract_cast(shared_leaf< T, SAFE_MATH > x)
Cast to a subtract node.
Definition arithmetic.hpp:1674
constexpr shared_leaf< T, SAFE_MATH > zero()
Forward declare for zero.
Definition node.hpp:994
shared_leaf< T, SAFE_MATH > operator/(shared_leaf< T, SAFE_MATH > l, shared_leaf< T, SAFE_MATH > r)
Build divide node from two leaves.
Definition arithmetic.hpp:3631
shared_leaf< T, SAFE_MATH > multiply(shared_leaf< T, SAFE_MATH > l, shared_leaf< T, SAFE_MATH > r)
Build multiply node from two leaves.
Definition arithmetic.hpp:2632
shared_add< T, SAFE_MATH > add_cast(shared_leaf< T, SAFE_MATH > x)
Cast to an add node.
Definition arithmetic.hpp:831
bool is_greater_exponent(shared_leaf< T, SAFE_MATH > a, shared_leaf< T, SAFE_MATH > b)
Check if the exponent is greater than the other.
Definition arithmetic.hpp:111
shared_leaf< T, SAFE_MATH > operator+(shared_leaf< T, SAFE_MATH > l, shared_leaf< T, SAFE_MATH > r)
Build add node from two leaves.
Definition arithmetic.hpp:774
std::shared_ptr< divide_node< T, SAFE_MATH > > shared_divide
Convenience type alias for shared divide nodes.
Definition arithmetic.hpp:3676
shared_leaf< T, SAFE_MATH > pow(shared_leaf< T, SAFE_MATH > l, shared_leaf< T, SAFE_MATH > r)
Build power node.
Definition math.hpp:1352
shared_sine< T, SAFE_MATH > sin_cast(shared_leaf< T, SAFE_MATH > x)
Cast to a sine node.
Definition trigonometry.hpp:262
shared_divide< T, SAFE_MATH > divide_cast(shared_leaf< T, SAFE_MATH > x)
Cast to a divide node.
Definition arithmetic.hpp:3688
shared_piecewise_1D< T, SAFE_MATH > piecewise_1D_cast(shared_leaf< T, SAFE_MATH > x)
Cast to a piecewise 1D node.
Definition piecewise.hpp:601
shared_leaf< T, SAFE_MATH > operator-(shared_leaf< T, SAFE_MATH > l, shared_leaf< T, SAFE_MATH > r)
Build subtract node from two leaves.
Definition arithmetic.hpp:1598
shared_leaf< T, SAFE_MATH > divide(shared_leaf< T, SAFE_MATH > l, shared_leaf< T, SAFE_MATH > r)
Build divide node from two leaves.
Definition arithmetic.hpp:3597
shared_multiply< T, SAFE_MATH > multiply_cast(shared_leaf< T, SAFE_MATH > x)
Cast to a multiply node.
Definition arithmetic.hpp:2723
shared_leaf< T, SAFE_MATH > exp(shared_leaf< T, SAFE_MATH > x)
Define exp convenience function.
Definition math.hpp:544
shared_exp< T, SAFE_MATH > exp_cast(shared_leaf< T, SAFE_MATH > x)
Cast to an exp node.
Definition math.hpp:578
bool is_constant_combineable(shared_leaf< T, SAFE_MATH > a, shared_leaf< T, SAFE_MATH > b)
Check if nodes are constant combineable.
Definition arithmetic.hpp:25
bool is_constant_promotable(shared_leaf< T, SAFE_MATH > a, shared_leaf< T, SAFE_MATH > b)
Check if the constants are promotable.
Definition arithmetic.hpp:53
shared_leaf< T, SAFE_MATH > fma(shared_leaf< T, SAFE_MATH > l, shared_leaf< T, SAFE_MATH > m, shared_leaf< T, SAFE_MATH > r)
Build fused multiply add node.
Definition arithmetic.hpp:5202
shared_constant< T, SAFE_MATH > constant_cast(shared_leaf< T, SAFE_MATH > x)
Cast to a constant node.
Definition node.hpp:1042
constexpr T i
Convenience type for imaginary constant.
Definition node.hpp:1026
shared_leaf< T, SAFE_MATH > subtract(shared_leaf< T, SAFE_MATH > l, shared_leaf< T, SAFE_MATH > r)
Build subtract node from two leaves.
Definition arithmetic.hpp:1563
std::shared_ptr< fma_node< T, SAFE_MATH > > shared_fma
Convenience type alias for shared fma nodes.
Definition arithmetic.hpp:5358
std::shared_ptr< leaf_node< T, SAFE_MATH > > shared_leaf
Convenience type alias for shared leaf nodes.
Definition node.hpp:673
bool is_variable_promotable(shared_leaf< T, SAFE_MATH > a, shared_leaf< T, SAFE_MATH > b)
Check if the variable is promotable.
Definition arithmetic.hpp:91
std::shared_ptr< multiply_node< T, SAFE_MATH > > shared_multiply
Convenience type alias for shared multiply nodes.
Definition arithmetic.hpp:2711
shared_leaf< T, SAFE_MATH > add(shared_leaf< T, SAFE_MATH > l, shared_leaf< T, SAFE_MATH > r)
Build add node from two leaves.
Definition arithmetic.hpp:740
shared_fma< T, SAFE_MATH > fma_cast(shared_leaf< T, SAFE_MATH > x)
Cast to a fma node.
Definition arithmetic.hpp:5370
shared_cosine< T, SAFE_MATH > cos_cast(shared_leaf< T, SAFE_MATH > x)
Cast to a cosine node.
Definition trigonometry.hpp:514
std::shared_ptr< subtract_node< T, SAFE_MATH > > shared_subtract
Convenience type alias for shared subtract nodes.
Definition arithmetic.hpp:1662
std::string format_to_string(const T value)
Convert a value to a string while avoiding locale.
Definition register.hpp:211
std::map< void *, size_t > register_usage
Type alias for counting register usage.
Definition register.hpp:258
std::map< void *, std::string > register_map
Type alias for mapping node pointers to register names.
Definition register.hpp:256
std::string to_string(const char prefix, const NODE *pointer)
Convert a graph::leaf_node pointer to a string.
Definition register.hpp:245
Base nodes of graph computation framework.
void piecewise_1D()
Tests for 1D piecewise nodes.
Definition piecewise_test.cpp:80
void piecewise_2D()
Tests for 2D piecewise nodes.
Definition piecewise_test.cpp:283