Graph Framework
Loading...
Searching...
No Matches
node.hpp
Go to the documentation of this file.
1//------------------------------------------------------------------------------
6//------------------------------------------------------------------------------
7//------------------------------------------------------------------------------
340//------------------------------------------------------------------------------
341#ifndef node_h
342#define node_h
343
344#include <iostream>
345#include <string>
346#include <memory>
347#include <iomanip>
348#include <functional>
349
350#include "backend.hpp"
351
353namespace graph {
354//******************************************************************************
355// Base leaf node.
356//******************************************************************************
357//------------------------------------------------------------------------------
362//------------------------------------------------------------------------------
363 template<jit::float_scalar T, bool SAFE_MATH=false>
364 class leaf_node : public std::enable_shared_from_this<leaf_node<T, SAFE_MATH>> {
365 protected:
/// Hash of the key string given to the constructor (std::hash<std::string>).
367 const size_t hash;
/// Node complexity reported by get_complexity(). NOTE(review): presumably set
/// from the constructor's count argument; that initializer is not shown here.
369 const size_t complexity;
/// Cache of nodes keyed by size_t. NOTE(review): presumably caches df()
/// results keyed by the hash of the differentiation variable — confirm.
371 std::map<size_t, std::shared_ptr<leaf_node<T, SAFE_MATH>>> df_cache;
/// Flag reported by has_pseudo(). NOTE(review): presumably set from the
/// constructor's pseudo argument; that initializer is not shown here.
373 const bool contains_pseudo;
374
375 public:
376//------------------------------------------------------------------------------
///  @brief Construct the state shared by every node.
///
///  @param[in] s      Key string; its std::hash becomes the node hash.
///  @param[in] count  Complexity count (presumably stored in complexity — the
///                    remaining member initializers are not shown here).
///  @param[in] pseudo Pseudo-variable flag (presumably stored in
///                    contains_pseudo).
///
///  NOTE(review): s is taken by const value, so every call copies the string;
///  a const reference would avoid the copy.
382//------------------------------------------------------------------------------
383 leaf_node(const std::string s,
384 const size_t count,
385 const bool pseudo) :
386 hash(std::hash<std::string>{} (s)),
388
389//------------------------------------------------------------------------------
391//------------------------------------------------------------------------------
392 virtual ~leaf_node() {}
393
394//------------------------------------------------------------------------------
398//------------------------------------------------------------------------------
400
401//------------------------------------------------------------------------------
407//------------------------------------------------------------------------------
408 virtual std::shared_ptr<leaf_node> reduce() = 0;
409
410//------------------------------------------------------------------------------
415//------------------------------------------------------------------------------
416 virtual std::shared_ptr<leaf_node<T, SAFE_MATH>>
417 df(std::shared_ptr<leaf_node<T, SAFE_MATH>> x) = 0;
418
419//------------------------------------------------------------------------------
///  @brief Hook for emitting preamble code for this node.
///
///  The base implementation writes nothing to @p stream. When SHOW_USE_COUNT
///  is defined, it only counts how many times this node is visited in usage.
///  Derived nodes with children override it to recurse into those children.
///
///  NOTE(review): the parameters between registers and avail_const_mem
///  (presumably including usage) are not shown in this listing.
432//------------------------------------------------------------------------------
433 virtual void compile_preamble(std::ostringstream &stream,
434 jit::register_map &registers,
439 int &avail_const_mem) {
440#ifdef SHOW_USE_COUNT
// The first visit starts the count at 1; later visits increment it.
441 if (usage.find(this) == usage.end()) {
442 usage[this] = 1;
443 } else {
444 ++usage[this];
445 }
446#endif
447 }
448
449//------------------------------------------------------------------------------
///  @brief Compile this node into kernel source written to @p stream.
///
///  Pure virtual. The implementations in this file either record a register
///  for the node and return it, or (straight_node) forward to their argument.
///
///  NOTE(review): one parameter between registers and usage is not shown in
///  this listing.
457//------------------------------------------------------------------------------
458 virtual std::shared_ptr<leaf_node<T, SAFE_MATH>>
459 compile(std::ostringstream &stream,
460 jit::register_map &registers,
462 const jit::register_usage &usage) = 0;
463
464//------------------------------------------------------------------------------
469//------------------------------------------------------------------------------
470 virtual bool is_match(std::shared_ptr<leaf_node<T, SAFE_MATH>> x) {
471 return this == x.get();
472 }
473
474//------------------------------------------------------------------------------
///  @brief Compare power bases.
///
///  Returns true when get_power_base() of this node and of x match according
///  to is_match(). NOTE(review): the signature line of this method is not
///  shown in this listing.
479//------------------------------------------------------------------------------
481 return this->get_power_base()->is_match(x->get_power_base());
482 }
483
484//------------------------------------------------------------------------------
488//------------------------------------------------------------------------------
489 virtual void set(const T d) {}
490
491//------------------------------------------------------------------------------
496//------------------------------------------------------------------------------
497 virtual void set(const size_t index,
498 const T d) {}
499
500//------------------------------------------------------------------------------
504//------------------------------------------------------------------------------
505 virtual void set(const std::vector<T> &d) {}
506
507//------------------------------------------------------------------------------
511//------------------------------------------------------------------------------
512 virtual void set(const backend::buffer<T> &d) {}
513
514//------------------------------------------------------------------------------
516//------------------------------------------------------------------------------
517 virtual void to_latex() const = 0;
518
519//------------------------------------------------------------------------------
525//------------------------------------------------------------------------------
526 virtual std::shared_ptr<leaf_node<T, SAFE_MATH>> to_vizgraph(std::stringstream &stream,
527 jit::register_map &registers) = 0;
528
529//------------------------------------------------------------------------------
533//------------------------------------------------------------------------------
534 virtual bool is_constant() const {
535 return false;
536 }
537
538//------------------------------------------------------------------------------
542//------------------------------------------------------------------------------
543 virtual bool has_constant_zero() const {
544 return false;
545 }
546
547//------------------------------------------------------------------------------
551//------------------------------------------------------------------------------
552 bool is_normal() {
553 return this->evaluate().is_normal();
554 }
555
556//------------------------------------------------------------------------------
560//------------------------------------------------------------------------------
561 virtual bool is_all_variables() const = 0;
562
563//------------------------------------------------------------------------------
569//------------------------------------------------------------------------------
570 virtual bool is_power_like() const {
571 return false;
572 }
573
574//------------------------------------------------------------------------------
580//------------------------------------------------------------------------------
581 virtual std::shared_ptr<leaf_node<T, SAFE_MATH>> get_power_base() {
582 return this->shared_from_this();
583 }
584
585//------------------------------------------------------------------------------
592//------------------------------------------------------------------------------
593 virtual std::shared_ptr<leaf_node<T, SAFE_MATH>> get_power_exponent() const = 0;
594
595//------------------------------------------------------------------------------
599//------------------------------------------------------------------------------
600 size_t get_hash() const {
601 return hash;
602 }
603
604//------------------------------------------------------------------------------
608//------------------------------------------------------------------------------
609 size_t get_complexity() const {
610 return complexity;
611 }
612
613//------------------------------------------------------------------------------
617//------------------------------------------------------------------------------
618 virtual bool has_pseudo() const {
619 return contains_pseudo;
620 }
621
622//------------------------------------------------------------------------------
626//------------------------------------------------------------------------------
627 virtual std::shared_ptr<leaf_node<T, SAFE_MATH>> remove_pseudo() {
628 return this->shared_from_this();
629 }
630
631//------------------------------------------------------------------------------
///  @brief Terminate a generated statement.
///
///  Writes ";" to @p stream, then std::endl. When SHOW_USE_COUNT is defined, a
///  comment with this node's use count from usage is written in between.
///
///  The method is const only when SHOW_USE_COUNT is not defined.
///  NOTE(review): presumably because usage is keyed by a non-const pointer —
///  confirm. The parameter line between stream and the qualifiers (presumably
///  usage) is not shown in this listing.
///  NOTE(review): virtual combined with final in the declaring class means the
///  function can never be overridden, so virtual adds nothing here.
636//------------------------------------------------------------------------------
637 virtual void endline(std::ostringstream &stream,
639#ifndef SHOW_USE_COUNT
640 const
641#endif
642 final {
643 stream << ";"
644#ifdef SHOW_USE_COUNT
645 << " // used " << usage.at(this)
646#endif
647 << std::endl;
648 }
649
650// Create one struct that holds both caches: for constructed nodes and for the
651// backend buffers
652//------------------------------------------------------------------------------
657//------------------------------------------------------------------------------
658 struct caches_t {
660 std::map<size_t, std::shared_ptr<leaf_node<T, SAFE_MATH>>> nodes;
662 std::map<size_t, backend::buffer<T>> backends;
663 };
664
666 inline static thread_local caches_t caches;
667
669 typedef T base;
670 };
671
/// Shared pointer to a graph node; the handle type used throughout the graph.
673 template<jit::float_scalar T, bool SAFE_MATH=false>
674 using shared_leaf = std::shared_ptr<leaf_node<T, SAFE_MATH>>;
675//------------------------------------------------------------------------------
682//------------------------------------------------------------------------------
// NOTE(review): the declaration belonging to the next template header is not
// shown in this listing.
683 template<jit::float_scalar T, bool SAFE_MATH=false>
/// Vector of shared graph nodes.
688 template<jit::float_scalar T, bool SAFE_MATH=false>
689 using output_nodes = std::vector<shared_leaf<T, SAFE_MATH>>;
690
// NOTE(review): the declaration belonging to the next template header is not
// shown in this listing.
692 template<jit::float_scalar T, bool SAFE_MATH=false>
/// Forward declaration of one(); node classes defined below call it before its
/// definition. The default for SAFE_MATH is given here, which is why the later
/// definition omits it. NOTE(review): the definition returns the result of the
/// constant() factory (make_shared plus a thread_local map), which presumably
/// cannot be constant-evaluated; consider dropping constexpr.
695 template<jit::float_scalar T, bool SAFE_MATH=false>
696 constexpr shared_leaf<T, SAFE_MATH> one();
697
698//------------------------------------------------------------------------------
///  @brief Print a graph to std::cout in Graphviz dot format.
///
///  Wraps the output of node->to_vizgraph() in a "graph" block, then writes the
///  whole text to std::cout.
///
///  NOTE(review): the signature line and one further line are not shown in
///  this listing; node is presumably a shared_leaf<T, SAFE_MATH> parameter.
702//------------------------------------------------------------------------------
703 template<jit::float_scalar T, bool SAFE_MATH=false>
705 std::stringstream stream;
// Records nodes already emitted, so the to_vizgraph implementations draw a
// shared subgraph only once.
706 jit::register_map registers;
708
709 stream << "graph \"\" {" << std::endl;
710 stream << " node [fontname = \"Helvetica\", ordering = out]" << std::endl << std::endl;
711 node->to_vizgraph(stream, registers);
712 stream << "}" << std::endl;
713
714 std::cout << stream.str() << std::endl;
715 }
716
717//******************************************************************************
718// Constant node.
719//******************************************************************************
720//------------------------------------------------------------------------------
725//------------------------------------------------------------------------------
726 template<jit::float_scalar T, bool SAFE_MATH=false>
727 class constant_node final : public leaf_node<T, SAFE_MATH> {
728//------------------------------------------------------------------------------
733//------------------------------------------------------------------------------
734 static std::string to_string(const T d) {
735 return jit::format_to_string<T> (d);
736 }
737
738 private:
/// The constant's value; the constructor asserts it holds exactly one element.
740 const backend::buffer<T> data;
741
742 public:
743//------------------------------------------------------------------------------
///  @brief Construct a constant node from the buffer d.
///
///  The node's hash key is the formatted first element. Its complexity is 1
///  and it is flagged as containing no pseudo variables. The signature line is
///  not shown in this listing.
///
///  NOTE(review): d.at(0) runs in the base-class initializer before the size
///  assertion in the body, so an empty d fails inside at() rather than at the
///  assert.
747//------------------------------------------------------------------------------
749 leaf_node<T, SAFE_MATH> (constant_node::to_string(d.at(0)), 1, false), data(d) {
750 assert(d.size() == 1 && "Constants need to be scalar functions.");
751 }
752
753//------------------------------------------------------------------------------
///  Returns the stored value buffer. NOTE(review): the signature line is not
///  shown; presumably evaluate().
757//------------------------------------------------------------------------------
759 return data;
760 }
761
762//------------------------------------------------------------------------------
///  Returns this node unchanged. NOTE(review): the signature line is not
///  shown; presumably reduce(), since a constant needs no further reduction.
768//------------------------------------------------------------------------------
770 return this->shared_from_this();
771 }
772
773//------------------------------------------------------------------------------
778//------------------------------------------------------------------------------
782
783//------------------------------------------------------------------------------
///  @brief Compile the constant into kernel source.
///
///  On the first visit (no entry in registers yet):
///  - with USE_CONSTANT_CACHE, it declares a const register initialized with
///    the value and records the register name in registers;
///  - otherwise, it records an inline literal expression as the register
///    "name": a type-constructor call for complex types, a C-style cast for
///    real types.
///  Later visits reuse the recorded entry. Always returns this node.
///
///  NOTE(review): the return-type line and one parameter are not shown in this
///  listing.
791//------------------------------------------------------------------------------
793 compile(std::ostringstream &stream,
794 jit::register_map &registers,
796 const jit::register_usage &usage) {
797 if (registers.find(this) == registers.end()) {
798#ifdef USE_CONSTANT_CACHE
799 registers[this] = jit::to_string('r', this);
800 stream << " const ";
801 jit::add_type<T> (stream);
802 const T temp = this->evaluate().at(0);
803
804 stream << " " << registers[this] << " = ";
805 if constexpr (jit::complex_scalar<T>) {
806 jit::add_type<T> (stream);
807 }
// NOTE(review): this uses the stream's current numeric formatting, unlike the
// #else branch, which uses jit::format_to_string. Confirm enough precision is
// set on stream, or the literal may be truncated.
808 stream << temp;
809 this->endline(stream, usage);
810#else
811 if constexpr (jit::complex_scalar<T>) {
812 registers[this] = jit::get_type_string<T> () + "("
813 + jit::format_to_string(this->evaluate().at(0))
814 + ")";
815 } else {
816 registers[this] = "(" + jit::get_type_string<T> () + ")"
817 + jit::format_to_string(this->evaluate().at(0));
818 }
819#endif
820 }
821
822 return this->shared_from_this();
823 }
824
825//------------------------------------------------------------------------------
///  @brief Match another node.
///
///  True when x is this very node, or when x is a constant whose evaluated
///  value equals this one's. NOTE(review): the signature line is not shown.
830//------------------------------------------------------------------------------
832 if (this == x.get()) {
833 return true;
834 }
835
836 auto x_cast = constant_cast(x);
837 if (x_cast.get()) {
838 return this->evaluate() == x_cast->evaluate();
839 }
840
841 return false;
842 }
843
844//------------------------------------------------------------------------------
846//------------------------------------------------------------------------------
847 bool is(const T d) {
848 return data.size() == 1 && data.at(0) == d;
849 }
850
851//------------------------------------------------------------------------------
853//------------------------------------------------------------------------------
854 bool is_integer() {
855 const auto temp = this->evaluate().at(0);
856 return std::imag(temp) == 0 &&
857 fmod(std::real(temp), 1.0) == 0.0;
858 }
859
860//------------------------------------------------------------------------------
862//------------------------------------------------------------------------------
863 virtual void to_latex() const {
864 std::cout << data.at(0);
865 }
866
867//------------------------------------------------------------------------------
873//------------------------------------------------------------------------------
874 virtual shared_leaf<T, SAFE_MATH> to_vizgraph(std::stringstream &stream,
875 jit::register_map &registers) {
876 if (registers.find(this) == registers.end()) {
877 const std::string name = jit::to_string('r', this);
878 registers[this] = name;
879 stream << " " << name
880 << " [label = \"" << this->evaluate().at(0)
881 << "\", shape = box, style = \"rounded,filled\", fillcolor = black, fontcolor = white];" << std::endl;
882 }
883
884 return this->shared_from_this();
885 }
886
887//------------------------------------------------------------------------------
891//------------------------------------------------------------------------------
892 virtual bool is_constant() const {
893 return true;
894 }
895
896//------------------------------------------------------------------------------
900//------------------------------------------------------------------------------
901 virtual bool has_constant_zero() const {
902 return data.has_zero();
903 }
904
905//------------------------------------------------------------------------------
909//------------------------------------------------------------------------------
910 virtual bool is_all_variables() const {
911 return false;
912 }
913
914//------------------------------------------------------------------------------
918//------------------------------------------------------------------------------
919 virtual bool is_power_like() const {
920 return true;
921 }
922
923//------------------------------------------------------------------------------
///  Returns this node as its own power base. NOTE(review): the signature line
///  is not shown; presumably get_power_base().
927//------------------------------------------------------------------------------
929 return this->shared_from_this();
930 }
931
932//------------------------------------------------------------------------------
///  Returns the constant one as the power exponent. NOTE(review): the
///  signature line is not shown; presumably get_power_exponent().
936//------------------------------------------------------------------------------
938 return one<T, SAFE_MATH> ();
939 }
940 };
941
942//------------------------------------------------------------------------------
///  @brief Create a constant node, reusing an equal cached constant.
///
///  Probes leaf_node<T, SAFE_MATH>::caches.nodes starting at the new node's
///  hash and moves to the next key on each collision:
///  - an empty slot ends the search and returns the new node (presumably after
///    caching it — the insertion line is not shown in this listing);
///  - a slot holding a matching node ends the search (presumably returning the
///    cached node — that line is not shown either).
///  NOTE(review): the signature line is not shown; d is presumably the value.
///  NOTE(review): in the else-if, caches.nodes[i] repeats the lookup that
///  find(i) already performed; keeping the iterator would avoid a second
///  search.
950//------------------------------------------------------------------------------
951 template<jit::float_scalar T, bool SAFE_MATH=false>
953 auto temp = std::make_shared<constant_node<T, SAFE_MATH>> (d);
954// Test for hash collisions.
955 for (size_t i = temp->get_hash(); i < std::numeric_limits<size_t>::max(); i++) {
956 if (leaf_node<T, SAFE_MATH>::caches.nodes.find(i) ==
959 return temp;
960 } else if (temp->is_match(leaf_node<T, SAFE_MATH>::caches.nodes[i])) {
962 }
963 }
// Only reachable if every probed key held a non-matching node.
964#if defined(__clang__) || defined(__GNUC__)
966#else
967 assert(false && "Should never reach.");
968#endif
969 }
970
971//------------------------------------------------------------------------------
// NOTE(review): the declaration belonging to the template header below is not
// shown in this listing.
979//------------------------------------------------------------------------------
980 template<jit::float_scalar T, bool SAFE_MATH=false>
984
985// Define some common constants.
986//------------------------------------------------------------------------------
///  @brief The constant 0. NOTE(review): the signature line is not shown;
///  presumably zero(). SAFE_MATH has no default here, presumably because an
///  earlier forward declaration already supplies it.
993//------------------------------------------------------------------------------
994 template<jit::float_scalar T, bool SAFE_MATH>
996 return constant<T, SAFE_MATH> (static_cast<T> (0.0));
997 }
998
999//------------------------------------------------------------------------------
///  @brief The constant 1. This is presumably the definition of the one()
///  forward-declared earlier (signature line not shown), which is why SAFE_MATH
///  carries no default here.
1006//------------------------------------------------------------------------------
1007 template<jit::float_scalar T, bool SAFE_MATH>
1009 return constant<T, SAFE_MATH> (static_cast<T> (1.0));
1010 }
1011
1012//------------------------------------------------------------------------------
///  @brief The constant -1. NOTE(review): the signature line is not shown.
1019//------------------------------------------------------------------------------
1020 template<jit::float_scalar T, bool SAFE_MATH=false>
1022 return constant<T, SAFE_MATH> (static_cast<T> (-1.0));
1023 }
1024
1026 template<jit::complex_scalar T>
1027 constexpr T i = T(0.0, 1.0);
1028
1030 template<jit::float_scalar T, bool SAFE_MATH=false>
1031 using shared_constant = std::shared_ptr<constant_node<T, SAFE_MATH>>;
1032
1033//------------------------------------------------------------------------------
///  @brief Cast a node to a constant node.
///
///  Uses std::dynamic_pointer_cast, so the result is null when x is not a
///  constant_node. NOTE(review): the signature line is not shown in this
///  listing.
1041//------------------------------------------------------------------------------
1042 template<jit::float_scalar T, bool SAFE_MATH=false>
1044 return std::dynamic_pointer_cast<constant_node<T, SAFE_MATH>> (x);
1045 }
1046
1047//******************************************************************************
1048// Base straight node.
1049//******************************************************************************
1050//------------------------------------------------------------------------------
1058//------------------------------------------------------------------------------
1059 template<jit::float_scalar T, bool SAFE_MATH=false>
1060 class straight_node : public leaf_node<T, SAFE_MATH> {
1061 protected:
1064
1065 public:
1066//------------------------------------------------------------------------------
///  @brief Construct a single-argument node.
///
///  The complexity is the argument's complexity plus one, and the pseudo flag
///  is inherited from the argument. NOTE(review): the first signature line
///  (presumably the argument a) is not shown in this listing.
1071//------------------------------------------------------------------------------
1073 const std::string s) :
1074 leaf_node<T, SAFE_MATH> (s, a->get_complexity() + 1, a->has_pseudo()),
1075 arg(a) {}
1076
1077//------------------------------------------------------------------------------
///  Evaluates by forwarding to the argument's evaluate(). NOTE(review): the
///  signature line is not shown; presumably evaluate().
1081//------------------------------------------------------------------------------
1083 return this->arg->evaluate();
1084 }
1085
1086//------------------------------------------------------------------------------
///  @brief Recurse into the argument's preamble once per node.
///
///  visited records nodes already processed, so this node recurses into its
///  argument only on its first visit. With SHOW_USE_COUNT defined, usage
///  counts the visits; the #ifdef deliberately spans the else branch.
///  NOTE(review): several parameters and call arguments are not shown in this
///  listing.
1096//------------------------------------------------------------------------------
1097 virtual void compile_preamble(std::ostringstream &stream,
1098 jit::register_map &registers,
1103 int &avail_const_mem) {
1104 if (visited.find(this) == visited.end()) {
1105 this->arg->compile_preamble(stream, registers,
1106 visited, usage,
1109 visited.insert(this);
1110#ifdef SHOW_USE_COUNT
1111 usage[this] = 1;
1112 } else {
1113 ++usage[this];
1114#endif
1115 }
1116 }
1117
1118//------------------------------------------------------------------------------
///  Compiles by delegating entirely to the argument's compile().
///  NOTE(review): the return-type line and the indices parameter line are not
///  shown in this listing.
1126//------------------------------------------------------------------------------
1128 compile(std::ostringstream &stream,
1129 jit::register_map &registers,
1131 const jit::register_usage &usage) {
1132 return this->arg->compile(stream, registers, indices, usage);
1133 }
1134
1135//------------------------------------------------------------------------------
///  Returns the argument node. NOTE(review): the signature line is not shown.
1137//------------------------------------------------------------------------------
1139 return this->arg;
1140 }
1141
1142//------------------------------------------------------------------------------
1146//------------------------------------------------------------------------------
1147 virtual bool is_all_variables() const {
1148 return this->arg->is_all_variables();
1149 }
1150
1151//------------------------------------------------------------------------------
///  Returns the constant one as the power exponent. NOTE(review): the
///  signature line is not shown; presumably get_power_exponent().
1155//------------------------------------------------------------------------------
1157 return one<T, SAFE_MATH> ();
1158 }
1159 };
1160
1161//******************************************************************************
1162// Base branch node.
1163//******************************************************************************
1164//------------------------------------------------------------------------------
1172//------------------------------------------------------------------------------
1173 template<jit::float_scalar T, bool SAFE_MATH=false>
1174 class branch_node : public leaf_node<T, SAFE_MATH> {
1175 protected:
1180
1181 public:
1182
1183//------------------------------------------------------------------------------
///  @brief Construct a two-argument node from left l and right r.
///
///  The pseudo flag is set when either argument has one. NOTE(review): the
///  first signature lines and the complexity argument are not shown in this
///  listing.
1189//------------------------------------------------------------------------------
1192 const std::string s) :
1194 l->has_pseudo() || r->has_pseudo()),
1195 left(l), right(r) {}
1196
1197//------------------------------------------------------------------------------
///  @brief Construct a two-argument node with an explicit complexity count and
///  pseudo flag (triple_node uses this form).
///  NOTE(review): the first signature lines are not shown in this listing.
1205//------------------------------------------------------------------------------
1208 const std::string s,
1209 const size_t count,
1210 const bool pseudo) :
1211 leaf_node<T, SAFE_MATH> (s, count, pseudo),
1212 left(l), right(r) {}
1213
1214//------------------------------------------------------------------------------
///  @brief Recurse into both arguments' preambles (left, then right) once per
///  node.
///
///  visited records nodes already processed. With SHOW_USE_COUNT defined,
///  usage counts the visits; the #ifdef deliberately spans the else branch.
///  NOTE(review): several parameters and call arguments are not shown in this
///  listing.
1224//------------------------------------------------------------------------------
1225 virtual void compile_preamble(std::ostringstream &stream,
1226 jit::register_map &registers,
1231 int &avail_const_mem) {
1232 if (visited.find(this) == visited.end()) {
1233 this->left->compile_preamble(stream, registers,
1234 visited, usage,
1237 this->right->compile_preamble(stream, registers,
1238 visited, usage,
1241 visited.insert(this);
1242#ifdef SHOW_USE_COUNT
1243 usage[this] = 1;
1244 } else {
1245 ++usage[this];
1246#endif
1247 }
1248 }
1249
1250//------------------------------------------------------------------------------
///  Returns the left argument. NOTE(review): the signature line is not shown.
1252//------------------------------------------------------------------------------
1254 return this->left;
1255 }
1256
1257//------------------------------------------------------------------------------
///  Returns the right argument. NOTE(review): the signature line is not shown.
1259//------------------------------------------------------------------------------
1261 return this->right;
1262 }
1263
1264//------------------------------------------------------------------------------
1268//------------------------------------------------------------------------------
1269 virtual bool is_all_variables() const {
1270 return this->left->is_all_variables() &&
1271 this->right->is_all_variables();
1272 }
1273
1274//------------------------------------------------------------------------------
///  Returns the constant one as the power exponent. NOTE(review): the second
///  signature line is not shown; presumably get_power_exponent().
1278//------------------------------------------------------------------------------
1279 virtual std::shared_ptr<leaf_node<T, SAFE_MATH>>
1281 return one<T, SAFE_MATH> ();
1282 }
1283 };
1284
1285//******************************************************************************
1286// Base triple node.
1287//******************************************************************************
1288//------------------------------------------------------------------------------
1296//------------------------------------------------------------------------------
1297 template<jit::float_scalar T, bool SAFE_MATH=false>
1298 class triple_node : public branch_node<T, SAFE_MATH> {
1299 protected:
1302
1303 public:
1304
1305//------------------------------------------------------------------------------
///  @brief Construct a three-argument node from l, m and r.
///
///  Delegates to the branch_node constructor that takes an explicit count. The
///  complexity is the sum of the three arguments' complexities, and the pseudo
///  flag is set when any argument has one.
///  NOTE(review): straight_node adds 1 for itself, but this sum does not;
///  confirm whether the omission is intentional. The first signature lines are
///  not shown in this listing.
1312//------------------------------------------------------------------------------
1316 const std::string s) :
1317 branch_node<T, SAFE_MATH> (l, r, s,
1318 l->get_complexity() +
1319 m->get_complexity() +
1320 r->get_complexity(),
1321 l->has_pseudo() ||
1322 m->has_pseudo() ||
1323 r->has_pseudo()),
1324 middle(m) {}
1325
1326//------------------------------------------------------------------------------
///  @brief Recurse into all three arguments' preambles (left, middle, right)
///  once per node.
///
///  visited records nodes already processed. With SHOW_USE_COUNT defined,
///  usage counts the visits; the #ifdef deliberately spans the else branch.
///  NOTE(review): several parameters and call arguments are not shown in this
///  listing.
1336//------------------------------------------------------------------------------
1337 virtual void compile_preamble(std::ostringstream &stream,
1338 jit::register_map &registers,
1343 int &avail_const_mem) {
1344 if (visited.find(this) == visited.end()) {
1345 this->left->compile_preamble(stream, registers,
1346 visited, usage,
1349 this->middle->compile_preamble(stream, registers,
1350 visited, usage,
1353 this->right->compile_preamble(stream, registers,
1354 visited, usage,
1357 visited.insert(this);
1358#ifdef SHOW_USE_COUNT
1359 usage[this] = 1;
1360 } else {
1361 ++usage[this];
1362#endif
1363 }
1364 }
1365
1366//------------------------------------------------------------------------------
///  Returns the middle argument. NOTE(review): the signature line is not shown.
1368//------------------------------------------------------------------------------
1370 return this->middle;
1371 }
1372
1373//------------------------------------------------------------------------------
1377//------------------------------------------------------------------------------
1378 virtual bool is_all_variables() const {
1379 return this->left->is_all_variables() &&
1380 this->middle->is_all_variables() &&
1381 this->right->is_all_variables();
1382 }
1383 };
1384
1385//******************************************************************************
1386// Variable node.
1387//******************************************************************************
1388//------------------------------------------------------------------------------
1393//------------------------------------------------------------------------------
1394 template<jit::float_scalar T, bool SAFE_MATH=false>
1395 class variable_node final : public leaf_node<T, SAFE_MATH> {
1396 private:
/// Storage for the variable's values (returned by the evaluate fragment below
/// and written by the set() overloads).
1398 backend::buffer<T> buffer;
/// Symbol used when printing the variable (to_latex, to_vizgraph).
1400 const std::string symbol;
1401
1402//------------------------------------------------------------------------------
1407//------------------------------------------------------------------------------
1408 static std::string to_string(variable_node<T, SAFE_MATH> *p) {
1409 return jit::format_to_string(reinterpret_cast<size_t> (p));
1410 }
1411
1412 public:
1413//------------------------------------------------------------------------------
1418//------------------------------------------------------------------------------
1419 variable_node(const size_t s,
1420 const std::string &symbol) :
1421 leaf_node<T, SAFE_MATH> (variable_node::to_string(this), 1, false),
1422 buffer(s), symbol(symbol) {}
1423
1424//------------------------------------------------------------------------------
1430//------------------------------------------------------------------------------
1431 variable_node(const size_t s, const T d,
1432 const std::string &symbol) :
1433 leaf_node<T, SAFE_MATH> (variable_node::to_string(this), 1, false),
1434 buffer(s, d), symbol(symbol) {
1435 assert(buffer.is_normal() && "NaN or Inf value.");
1436 }
1437
1438//------------------------------------------------------------------------------
1443//------------------------------------------------------------------------------
1444 variable_node(const std::vector<T> &d,
1445 const std::string &symbol) :
1446 leaf_node<T, SAFE_MATH> (variable_node::to_string(this), 1, false),
1447 buffer(d), symbol(symbol) {
1448 assert(buffer.is_normal() && "NaN or Inf value.");
1449 }
1450
1451//------------------------------------------------------------------------------
///  @brief Construct a variable from an existing buffer d.
///
///  Asserts that the buffer holds no NaN or Inf values (via is_normal()).
///  NOTE(review): the first signature line is not shown in this listing.
1456//------------------------------------------------------------------------------
1458 const std::string &symbol) :
1459 leaf_node<T, SAFE_MATH> (variable_node::to_string(this), 1, false),
1460 buffer(d), symbol(symbol) {
1461 assert(buffer.is_normal() && "NaN or Inf value.");
1462 }
1463
1464//------------------------------------------------------------------------------
///  Returns the variable's buffer. NOTE(review): the signature line is not
///  shown; presumably evaluate().
1468//------------------------------------------------------------------------------
1470 return buffer;
1471 }
1472
1473//------------------------------------------------------------------------------
///  Returns this node unchanged. NOTE(review): the signature line is not
///  shown; presumably reduce().
1479//------------------------------------------------------------------------------
1481 return this->shared_from_this();
1482 }
1483
1484//------------------------------------------------------------------------------
///  Derivative of a variable: is_match(x) converted to T, i.e. one when x
///  matches this variable and zero otherwise. NOTE(review): the signature line
///  is not shown; presumably df(x).
1489//------------------------------------------------------------------------------
1491 return constant<T, SAFE_MATH> (static_cast<T> (this->is_match(x)));
1492 }
1493
1494//------------------------------------------------------------------------------
///  @brief Record uses of this variable in usage.
///
///  Unlike the base version, the first visit is always recorded (even without
///  SHOW_USE_COUNT). Later visits are only counted when SHOW_USE_COUNT is
///  defined. NOTE(review): several parameters are not shown in this listing.
1507//------------------------------------------------------------------------------
1508 virtual void compile_preamble(std::ostringstream &stream,
1509 jit::register_map &registers,
1514 int &avail_const_mem) {
1515 if (usage.find(this) == usage.end()) {
1516 usage[this] = 1;
1517#ifdef SHOW_USE_COUNT
1518 } else {
1519 ++usage[this];
1520#endif
1521 }
1522 }
1523
1524//------------------------------------------------------------------------------
///  Compiling a variable writes nothing and returns this node.
///  NOTE(review): the return-type line and one parameter are not shown.
1532//------------------------------------------------------------------------------
1534 compile(std::ostringstream &stream,
1535 jit::register_map &registers,
1537 const jit::register_usage &usage) {
1538 return this->shared_from_this();
1539 }
1540
1541//------------------------------------------------------------------------------
1545//------------------------------------------------------------------------------
1546 virtual void set(const T d) {
1547 buffer.set(d);
1548 }
1549
1550//------------------------------------------------------------------------------
1555//------------------------------------------------------------------------------
1556 virtual void set(const size_t index, const T d) {
1557 buffer[index] = d;
1558 }
1559
1560//------------------------------------------------------------------------------
1564//------------------------------------------------------------------------------
1565 virtual void set(const std::vector<T> &d) {
1566 buffer.set(d);
1567 }
1568
1569//------------------------------------------------------------------------------
1573//------------------------------------------------------------------------------
1574 virtual void set(const backend::buffer<T> &d) {
1575 buffer = d;
1576 }
1577
1578//------------------------------------------------------------------------------
1580//------------------------------------------------------------------------------
1581 std::string get_symbol() const {
1582 return symbol;
1583 }
1584
1585//------------------------------------------------------------------------------
1587//------------------------------------------------------------------------------
1588 virtual void to_latex() const {
1589 std::cout << get_symbol();
1590 }
1591
1592//------------------------------------------------------------------------------
1598//------------------------------------------------------------------------------
1599 virtual shared_leaf<T, SAFE_MATH> to_vizgraph(std::stringstream &stream,
1600 jit::register_map &registers) {
1601 if (registers.find(this) == registers.end()) {
1602 const std::string name = jit::to_string('r', this);
1603 registers[this] = name;
1604 stream << " " << name
1605 << " [label = \"" << this->get_symbol()
1606 << "\", shape = box];" << std::endl;
1607 }
1608
1609 return this->shared_from_this();
1610 }
1611
1612//------------------------------------------------------------------------------
1614//------------------------------------------------------------------------------
1615 size_t size() {
1616 return buffer.size();
1617 }
1618
1619//------------------------------------------------------------------------------
1623//------------------------------------------------------------------------------
1624 T *data() {
1625 return buffer.data();
1626 }
1627
1628//------------------------------------------------------------------------------
1632//------------------------------------------------------------------------------
1633 virtual bool is_all_variables() const {
1634 return true;
1635 }
1636
1637//------------------------------------------------------------------------------
1641//------------------------------------------------------------------------------
1642 virtual bool is_power_like() const {
1643 return true;
1644 }
1645
1646//------------------------------------------------------------------------------
1650//------------------------------------------------------------------------------
1652 return this->shared_from_this();
1653 }
1654
1655//------------------------------------------------------------------------------
1659//------------------------------------------------------------------------------
1661 return one<T, SAFE_MATH> ();
1662 }
1663 };
1664
1665//------------------------------------------------------------------------------
1673//------------------------------------------------------------------------------
1674 template<jit::float_scalar T, bool SAFE_MATH=false>
1676 const std::string &symbol) {
1677 return std::make_shared<variable_node<T, SAFE_MATH>> (s, symbol);
1678 }
1679
1680//------------------------------------------------------------------------------
1689//------------------------------------------------------------------------------
1690 template<jit::float_scalar T, bool SAFE_MATH=false>
1691 shared_leaf<T, SAFE_MATH> variable(const size_t s, const T d,
1692 const std::string &symbol) {
1693 return std::make_shared<variable_node<T, SAFE_MATH>> (s, d, symbol);
1694 }
1695
1696//------------------------------------------------------------------------------
1704//------------------------------------------------------------------------------
1705 template<jit::float_scalar T, bool SAFE_MATH=false>
1706 shared_leaf<T, SAFE_MATH> variable(const std::vector<T> &d,
1707 const std::string &symbol) {
1708 return std::make_shared<variable_node<T, SAFE_MATH>> (d, symbol);
1709 }
1710
1711//------------------------------------------------------------------------------
1719//------------------------------------------------------------------------------
1720 template<jit::float_scalar T, bool SAFE_MATH=false>
1722 const std::string &symbol) {
1723 return std::make_shared<variable_node<T, SAFE_MATH>> (d, symbol);
1724 }
1725
1727 template<jit::float_scalar T, bool SAFE_MATH=false>
1728 using shared_variable = std::shared_ptr<variable_node<T, SAFE_MATH>>;
1730 template<jit::float_scalar T, bool SAFE_MATH=false>
1731 using input_nodes = std::vector<shared_variable<T, SAFE_MATH>>;
1733 template<jit::float_scalar T, bool SAFE_MATH=false>
1734 using map_nodes = std::vector<std::pair<shared_leaf<T, SAFE_MATH>,
1736
1737//------------------------------------------------------------------------------
1745//------------------------------------------------------------------------------
1746 template<jit::float_scalar T, bool SAFE_MATH=false>
1748 return std::dynamic_pointer_cast<variable_node<T, SAFE_MATH>> (x);
1749 }
1750
1751//******************************************************************************
1752// Pseudo variable node.
1753//******************************************************************************
1754//------------------------------------------------------------------------------
1763//------------------------------------------------------------------------------
1764 template<jit::float_scalar T, bool SAFE_MATH=false>
1765 class pseudo_variable_node final : public straight_node<T, SAFE_MATH> {
1766 private:
1767//------------------------------------------------------------------------------
1772//------------------------------------------------------------------------------
1773 static std::string to_string(leaf_node<T, SAFE_MATH> *p) {
1774 return jit::format_to_string(reinterpret_cast<size_t> (p));
1775 }
1776
1777 public:
1778//------------------------------------------------------------------------------
1782//------------------------------------------------------------------------------
1785
1786//------------------------------------------------------------------------------
1792//------------------------------------------------------------------------------
1794 return this->shared_from_this();
1795 }
1796
1797//------------------------------------------------------------------------------
1802//------------------------------------------------------------------------------
1804 return constant<T, SAFE_MATH> (static_cast<T> (this->is_match(x)));
1805 }
1806
1807//------------------------------------------------------------------------------
1809//------------------------------------------------------------------------------
1810 virtual void to_latex() const {
1811 std::cout << "\\left(";
1812 this->arg->to_latex();
1813 std::cout << "\\right)";
1814 }
1815
1816//------------------------------------------------------------------------------
1820//------------------------------------------------------------------------------
1821 virtual bool is_all_variables() const {
1822 return true;
1823 }
1824
1825//------------------------------------------------------------------------------
1829//------------------------------------------------------------------------------
1830 virtual bool is_power_like() const {
1831 return true;
1832 }
1833
1834//------------------------------------------------------------------------------
1838//------------------------------------------------------------------------------
1840 return this->arg->get_power_base();
1841 }
1842
1843//------------------------------------------------------------------------------
1847//------------------------------------------------------------------------------
1849 return this->arg->get_power_exponent();
1850 }
1851
1852//------------------------------------------------------------------------------
1856//------------------------------------------------------------------------------
1857 virtual bool has_pseudo() const {
1858 return true;
1859 }
1860
1861//------------------------------------------------------------------------------
1865//------------------------------------------------------------------------------
1867 return this->arg;
1868 }
1869
1870//------------------------------------------------------------------------------
1876//------------------------------------------------------------------------------
1877 virtual shared_leaf<T, SAFE_MATH> to_vizgraph(std::stringstream &stream,
1878 jit::register_map &registers) {
1879 if (registers.find(this) == registers.end()) {
1880 const std::string name = jit::to_string('r', this);
1881 registers[this] = name;
1882 stream << " " << name
1883 << " [label = \"pseudo_variable\", shape = oval, style = filled, fillcolor = blue, fontcolor = white];" << std::endl;
1884
1885 auto a = this->arg->to_vizgraph(stream, registers);
1886 stream << " " << name << " -- " << registers[a.get()] << ";" << std::endl;
1887 }
1888
1889 return this->shared_from_this();
1890 }
1891 };
1892
1893//------------------------------------------------------------------------------
1901//------------------------------------------------------------------------------
1902 template<jit::float_scalar T, bool SAFE_MATH=false>
1904 return std::make_shared<pseudo_variable_node<T, SAFE_MATH>> (x);
1905 }
1906
1908 template<jit::float_scalar T, bool SAFE_MATH=false>
1909 using shared_pseudo_variable = std::shared_ptr<pseudo_variable_node<T, SAFE_MATH>>;
1910
1911//------------------------------------------------------------------------------
1919//------------------------------------------------------------------------------
1920 template<jit::float_scalar T, bool SAFE_MATH=false>
1922 return std::dynamic_pointer_cast<pseudo_variable_node<T, SAFE_MATH>> (x);
1923 }
1924}
1925
1926#endif /* node_h */
Class signature to implement compute backends.
Class representing a generic buffer.
Definition backend.hpp:29
bool is_normal() const
Check for normal values.
Definition backend.hpp:279
const T at(const size_t index) const
Get value at.
Definition backend.hpp:91
bool has_zero() const
Is any element zero.
Definition backend.hpp:156
size_t size() const
Get size of the buffer.
Definition backend.hpp:116
void set(const T d)
Assign a constant value.
Definition backend.hpp:100
T * data()
Get a pointer to the basic memory buffer.
Definition backend.hpp:270
Class representing a branch node.
Definition node.hpp:1174
shared_leaf< T, SAFE_MATH > get_left()
Get the left branch.
Definition node.hpp:1253
virtual bool is_all_variables() const
Test if node acts like a variable.
Definition node.hpp:1269
virtual void compile_preamble(std::ostringstream &stream, jit::register_map &registers, jit::visiter_map &visited, jit::register_usage &usage, jit::texture1d_list &textures1d, jit::texture2d_list &textures2d, int &avail_const_mem)
Compile preamble.
Definition node.hpp:1225
branch_node(shared_leaf< T, SAFE_MATH > l, shared_leaf< T, SAFE_MATH > r, const std::string s, const size_t count, const bool pseudo)
Assigns the left and right branches.
Definition node.hpp:1206
shared_leaf< T, SAFE_MATH > get_right()
Get the right branch.
Definition node.hpp:1260
virtual std::shared_ptr< leaf_node< T, SAFE_MATH > > get_power_exponent() const
Get the exponent of a power.
Definition node.hpp:1280
branch_node(shared_leaf< T, SAFE_MATH > l, shared_leaf< T, SAFE_MATH > r, const std::string s)
Assigns the left and right branches.
Definition node.hpp:1190
shared_leaf< T, SAFE_MATH > right
Right branch of the tree.
Definition node.hpp:1179
shared_leaf< T, SAFE_MATH > left
Left branch of the tree.
Definition node.hpp:1177
Class representing data that cannot change.
Definition node.hpp:727
virtual shared_leaf< T, SAFE_MATH > get_power_exponent() const
Get the exponent of a power.
Definition node.hpp:937
virtual void to_latex() const
Convert the node to latex.
Definition node.hpp:863
virtual shared_leaf< T, SAFE_MATH > compile(std::ostringstream &stream, jit::register_map &registers, jit::register_map &indices, const jit::register_usage &usage)
Compile the node.
Definition node.hpp:793
virtual shared_leaf< T, SAFE_MATH > get_power_base()
Get the base of a power.
Definition node.hpp:928
bool is(const T d)
Check if the constant is value.
Definition node.hpp:847
virtual backend::buffer< T > evaluate()
Evaluate method.
Definition node.hpp:758
virtual bool is_all_variables() const
Test if node acts like a variable.
Definition node.hpp:910
virtual bool is_match(shared_leaf< T, SAFE_MATH > x)
Query if the nodes match.
Definition node.hpp:831
constant_node(const backend::buffer< T > &d)
Construct a constant node from a backend buffer.
Definition node.hpp:748
virtual shared_leaf< T, SAFE_MATH > to_vizgraph(std::stringstream &stream, jit::register_map &registers)
Convert the node to vizgraph.
Definition node.hpp:874
virtual shared_leaf< T, SAFE_MATH > df(shared_leaf< T, SAFE_MATH > x)
Transform node to derivative.
Definition node.hpp:779
virtual bool is_constant() const
Test if node is a constant.
Definition node.hpp:892
virtual bool has_constant_zero() const
Test the constant node has a zero.
Definition node.hpp:901
virtual shared_leaf< T, SAFE_MATH > reduce()
Reduction method.
Definition node.hpp:769
bool is_integer()
Check if the value is an integer.
Definition node.hpp:854
virtual bool is_power_like() const
Test if the node acts like a power of variable.
Definition node.hpp:919
Class representing a node leaf.
Definition node.hpp:364
virtual ~leaf_node()
Destructor.
Definition node.hpp:392
virtual void endline(std::ostringstream &stream, const jit::register_usage &usage) const final
End a line in the kernel source.
Definition node.hpp:637
virtual void set(const std::vector< T > &d)
Set the value of variable data.
Definition node.hpp:505
virtual bool is_power_like() const
Test if the node acts like a power of variable.
Definition node.hpp:570
virtual bool is_match(std::shared_ptr< leaf_node< T, SAFE_MATH > > x)
Query if the nodes match.
Definition node.hpp:470
virtual std::shared_ptr< leaf_node< T, SAFE_MATH > > compile(std::ostringstream &stream, jit::register_map &registers, jit::register_map &indices, const jit::register_usage &usage)=0
Compile the node.
virtual std::shared_ptr< leaf_node< T, SAFE_MATH > > get_power_exponent() const =0
Get the exponent of a power.
virtual void to_latex() const =0
Convert the node to latex.
T base
Type def to retrieve the backend type.
Definition node.hpp:669
virtual std::shared_ptr< leaf_node< T, SAFE_MATH > > df(std::shared_ptr< leaf_node< T, SAFE_MATH > > x)=0
Transform node to derivative.
const bool contains_pseudo
Node contains pseudo variables.
Definition node.hpp:373
virtual std::shared_ptr< leaf_node< T, SAFE_MATH > > remove_pseudo()
Remove pseudo variable nodes.
Definition node.hpp:627
virtual void set(const backend::buffer< T > &d)
Set the value of variable data.
Definition node.hpp:512
static thread_local caches_t caches
A per thread instance of the cache structure.
Definition node.hpp:666
virtual bool is_constant() const
Test if node is a constant.
Definition node.hpp:534
virtual void set(const size_t index, const T d)
Set the value of variable data.
Definition node.hpp:497
virtual std::shared_ptr< leaf_node< T, SAFE_MATH > > to_vizgraph(std::stringstream &stream, jit::register_map &registers)=0
Convert the node to vizgraph.
virtual std::shared_ptr< leaf_node< T, SAFE_MATH > > get_power_base()
Get the base of a power.
Definition node.hpp:581
leaf_node(const std::string s, const size_t count, const bool pseudo)
Construct a basic node.
Definition node.hpp:383
virtual bool has_constant_zero() const
Test the constant node has a zero.
Definition node.hpp:543
virtual void set(const T d)
Set the value of variable data.
Definition node.hpp:489
const size_t complexity
Graph complexity.
Definition node.hpp:369
virtual backend::buffer< T > evaluate()=0
Evaluate method.
std::map< size_t, std::shared_ptr< leaf_node< T, SAFE_MATH > > > df_cache
Cache derivative terms.
Definition node.hpp:371
bool is_power_base_match(std::shared_ptr< leaf_node< T, SAFE_MATH > > x)
Check if the base of the powers match.
Definition node.hpp:480
virtual bool is_all_variables() const =0
Test if all the sub-nodes terminate in variables.
virtual std::shared_ptr< leaf_node > reduce()=0
Reduction method.
size_t get_complexity() const
Get the number of nodes in the subgraph.
Definition node.hpp:609
virtual bool has_pseudo() const
Query if the node contains pseudo variables.
Definition node.hpp:618
const size_t hash
Hash for node.
Definition node.hpp:367
virtual void compile_preamble(std::ostringstream &stream, jit::register_map &registers, jit::visiter_map &visited, jit::register_usage &usage, jit::texture1d_list &textures1d, jit::texture2d_list &textures2d, int &avail_const_mem)
Compile preamble.
Definition node.hpp:433
bool is_normal()
Test if the result is normal.
Definition node.hpp:552
size_t get_hash() const
Get the hash for the node.
Definition node.hpp:600
Class representing a subexpression that acts like a variable.
Definition node.hpp:1765
virtual void to_latex() const
Convert the node to latex.
Definition node.hpp:1810
virtual shared_leaf< T, SAFE_MATH > df(shared_leaf< T, SAFE_MATH > x)
Transform node to derivative.
Definition node.hpp:1803
virtual bool is_power_like() const
Test if the node acts like a power of variable.
Definition node.hpp:1830
virtual shared_leaf< T, SAFE_MATH > to_vizgraph(std::stringstream &stream, jit::register_map &registers)
Convert the node to vizgraph.
Definition node.hpp:1877
virtual shared_leaf< T, SAFE_MATH > get_power_exponent() const
Get the exponent of a power.
Definition node.hpp:1848
virtual bool is_all_variables() const
Test if node acts like a variable.
Definition node.hpp:1821
virtual shared_leaf< T, SAFE_MATH > remove_pseudo()
Remove pseudo variable nodes.
Definition node.hpp:1866
virtual shared_leaf< T, SAFE_MATH > get_power_base()
Get the base of a power.
Definition node.hpp:1839
virtual bool has_pseudo() const
Query if the node contains pseudo variables.
Definition node.hpp:1857
virtual shared_leaf< T, SAFE_MATH > reduce()
Reduction method.
Definition node.hpp:1793
pseudo_variable_node(shared_leaf< T, SAFE_MATH > a)
Construct a pseudo variable node.
Definition node.hpp:1783
Class representing a straight node.
Definition node.hpp:1060
virtual shared_leaf< T, SAFE_MATH > get_power_exponent() const
Get the exponent of a power.
Definition node.hpp:1156
shared_leaf< T, SAFE_MATH > arg
Argument.
Definition node.hpp:1063
virtual backend::buffer< T > evaluate()
Evaluate method.
Definition node.hpp:1082
shared_leaf< T, SAFE_MATH > get_arg()
Get the argument.
Definition node.hpp:1138
virtual void compile_preamble(std::ostringstream &stream, jit::register_map &registers, jit::visiter_map &visited, jit::register_usage &usage, jit::texture1d_list &textures1d, jit::texture2d_list &textures2d, int &avail_const_mem)
Compile preamble.
Definition node.hpp:1097
straight_node(shared_leaf< T, SAFE_MATH > a, const std::string s)
Construct a straight node.
Definition node.hpp:1072
virtual bool is_all_variables() const
Test if node acts like a variable.
Definition node.hpp:1147
virtual shared_leaf< T, SAFE_MATH > compile(std::ostringstream &stream, jit::register_map &registers, jit::register_map &indices, const jit::register_usage &usage)
Compile the node.
Definition node.hpp:1128
Class representing a triple branch node.
Definition node.hpp:1298
virtual void compile_preamble(std::ostringstream &stream, jit::register_map &registers, jit::visiter_map &visited, jit::register_usage &usage, jit::texture1d_list &textures1d, jit::texture2d_list &textures2d, int &avail_const_mem)
Compile preamble.
Definition node.hpp:1337
shared_leaf< T, SAFE_MATH > get_middle()
Get the middle branch.
Definition node.hpp:1369
virtual bool is_all_variables() const
Test if node acts like a variable.
Definition node.hpp:1378
triple_node(shared_leaf< T, SAFE_MATH > l, shared_leaf< T, SAFE_MATH > m, shared_leaf< T, SAFE_MATH > r, const std::string s)
Reduces and assigns the left, middle, and right branches.
Definition node.hpp:1313
shared_leaf< T, SAFE_MATH > middle
Middle branch of the tree.
Definition node.hpp:1301
Class representing data that can change.
Definition node.hpp:1395
virtual shared_leaf< T, SAFE_MATH > get_power_base()
Get the base of a power.
Definition node.hpp:1651
variable_node(const backend::buffer< T > &d, const std::string &symbol)
Construct a variable node from backend buffer.
Definition node.hpp:1457
virtual bool is_all_variables() const
Test if node acts like a variable.
Definition node.hpp:1633
virtual shared_leaf< T, SAFE_MATH > df(shared_leaf< T, SAFE_MATH > x)
Transform node to derivative.
Definition node.hpp:1490
variable_node(const size_t s, const T d, const std::string &symbol)
Construct a variable node from a scalar.
Definition node.hpp:1431
virtual void to_latex() const
Convert the node to latex.
Definition node.hpp:1588
virtual shared_leaf< T, SAFE_MATH > reduce()
Reduction method.
Definition node.hpp:1480
size_t size()
Get the size of the variable buffer.
Definition node.hpp:1615
virtual shared_leaf< T, SAFE_MATH > compile(std::ostringstream &stream, jit::register_map &registers, jit::register_map &indices, const jit::register_usage &usage)
Compile the node.
Definition node.hpp:1534
virtual shared_leaf< T, SAFE_MATH > to_vizgraph(std::stringstream &stream, jit::register_map &registers)
Convert the node to vizgraph.
Definition node.hpp:1599
virtual void compile_preamble(std::ostringstream &stream, jit::register_map &registers, jit::visiter_map &visited, jit::register_usage &usage, jit::texture1d_list &textures1d, jit::texture2d_list &textures2d, int &avail_const_mem)
Compile preamble.
Definition node.hpp:1508
virtual backend::buffer< T > evaluate()
Evaluate method.
Definition node.hpp:1469
T * data()
Get a pointer to raw buffer.
Definition node.hpp:1624
virtual void set(const T d)
Set the value of variable data.
Definition node.hpp:1546
virtual bool is_power_like() const
Test if the node acts like a power of variable.
Definition node.hpp:1642
virtual void set(const backend::buffer< T > &d)
Set the value of variable data.
Definition node.hpp:1574
virtual void set(const size_t index, const T d)
Set the value of variable data.
Definition node.hpp:1556
variable_node(const size_t s, const std::string &symbol)
Construct a variable node with a size.
Definition node.hpp:1419
virtual void set(const std::vector< T > &d)
Set the value of variable data.
Definition node.hpp:1565
variable_node(const std::vector< T > &d, const std::string &symbol)
Construct a variable node from a vector.
Definition node.hpp:1444
std::string get_symbol() const
Get Symbol.
Definition node.hpp:1581
virtual shared_leaf< T, SAFE_MATH > get_power_exponent() const
Get the exponent of a power.
Definition node.hpp:1660
Complex scalar concept.
Definition register.hpp:24
subroutine assert(test, message)
Assert check.
Definition f_binding_test.f90:38
Name space for graph nodes.
Definition arithmetic.hpp:13
constexpr shared_leaf< T, SAFE_MATH > zero()
Forward declare for zero.
Definition node.hpp:995
std::shared_ptr< variable_node< T, SAFE_MATH > > shared_variable
Convenience type alias for shared variable nodes.
Definition node.hpp:1728
constexpr shared_leaf< T, SAFE_MATH > none()
Create a one constant.
Definition node.hpp:1021
void make_vizgraph(shared_leaf< T, SAFE_MATH > node)
Build the vizgraph input.
Definition node.hpp:704
constexpr shared_leaf< T, SAFE_MATH > one()
Forward declare for one.
Definition node.hpp:1008
std::vector< shared_variable< T, SAFE_MATH > > input_nodes
Convenience type alias for a vector of inputs.
Definition node.hpp:1731
shared_constant< T, SAFE_MATH > constant_cast(shared_leaf< T, SAFE_MATH > x)
Cast to a constant node.
Definition node.hpp:1043
shared_variable< T, SAFE_MATH > variable_cast(shared_leaf< T, SAFE_MATH > x)
Cast to a variable node.
Definition node.hpp:1747
constexpr T i
Convenience type for imaginary constant.
Definition node.hpp:1027
shared_leaf< T, SAFE_MATH > constant(const backend::buffer< T > &d)
Construct a constant.
Definition node.hpp:952
shared_leaf< T, SAFE_MATH > variable(const size_t s, const std::string &symbol)
Construct a variable.
Definition node.hpp:1675
std::shared_ptr< pseudo_variable_node< T, SAFE_MATH > > shared_pseudo_variable
Convenience type alias for shared pseudo variable nodes.
Definition node.hpp:1909
std::shared_ptr< leaf_node< T, SAFE_MATH > > shared_leaf
Convenience type alias for shared leaf nodes.
Definition node.hpp:674
std::vector< std::pair< shared_leaf< T, SAFE_MATH >, shared_variable< T, SAFE_MATH > > > map_nodes
Convenience type alias for mapping end codes back to inputs.
Definition node.hpp:1735
constexpr shared_leaf< T, SAFE_MATH > null_leaf()
Create a null leaf.
Definition node.hpp:684
std::shared_ptr< constant_node< T, SAFE_MATH > > shared_constant
Convenience type alias for shared constant nodes.
Definition node.hpp:1031
std::vector< shared_leaf< T, SAFE_MATH > > output_nodes
Convenience type alias for a vector of output nodes.
Definition node.hpp:689
shared_leaf< T, SAFE_MATH > pseudo_variable(shared_leaf< T, SAFE_MATH > x)
Define pseudo variable convenience function.
Definition node.hpp:1903
shared_pseudo_variable< T, SAFE_MATH > pseudo_variable_cast(shared_leaf< T, SAFE_MATH > &x)
Cast to a pseudo variable node.
Definition node.hpp:1921
std::map< void *, size_t > texture1d_list
Type alias for indexing 1D textures.
Definition register.hpp:263
std::string format_to_string(const T value)
Convert a value to a string while avoiding locale.
Definition register.hpp:212
std::map< void *, std::array< size_t, 2 > > texture2d_list
Type alias for indexing 2D textures.
Definition register.hpp:265
std::map< void *, size_t > register_usage
Type alias for counting register usage.
Definition register.hpp:259
std::map< void *, std::string > register_map
Type alias for mapping node pointers to register names.
Definition register.hpp:257
std::set< void * > visiter_map
Type alias for listing visited nodes.
Definition register.hpp:261
std::string to_string(const char prefix, const NODE *pointer)
Convert a graph::leaf_node pointer to a string.
Definition register.hpp:246
Data structure to contain the two caches.
Definition node.hpp:658
std::map< size_t, std::shared_ptr< leaf_node< T, SAFE_MATH > > > nodes
Cache of node.
Definition node.hpp:660
std::map< size_t, backend::buffer< T > > backends
Cache of backend buffers.
Definition node.hpp:662