Graph Framework
Loading...
Searching...
No Matches
node.hpp
Go to the documentation of this file.
1//------------------------------------------------------------------------------
6//------------------------------------------------------------------------------
7//------------------------------------------------------------------------------
340//------------------------------------------------------------------------------
341#ifndef node_h
342#define node_h
343
344#include <iostream>
345#include <string>
346#include <memory>
347#include <iomanip>
348#include <functional>
349
350#include "backend.hpp"
351
353namespace graph {
354//******************************************************************************
355// Base leaf node.
356//******************************************************************************
357//------------------------------------------------------------------------------
362//------------------------------------------------------------------------------
363 template<jit::float_scalar T, bool SAFE_MATH=false>
364 class leaf_node : public std::enable_shared_from_this<leaf_node<T, SAFE_MATH>> {
365 protected:
367 const size_t hash;
369 const size_t complexity;
371 std::map<size_t, std::shared_ptr<leaf_node<T, SAFE_MATH>>> df_cache;
373 const bool contains_pseudo;
374
375 public:
376//------------------------------------------------------------------------------
382//------------------------------------------------------------------------------
383 leaf_node(const std::string s,
384 const size_t count,
385 const bool pseudo) :
386 hash(std::hash<std::string>{} (s)),
388
389//------------------------------------------------------------------------------
391//------------------------------------------------------------------------------
392 virtual ~leaf_node() {}
393
394//------------------------------------------------------------------------------
398//------------------------------------------------------------------------------
400
401//------------------------------------------------------------------------------
407//------------------------------------------------------------------------------
408 virtual std::shared_ptr<leaf_node<T, SAFE_MATH>> reduce() {
409 return this->shared_from_this();
410 }
411
412//------------------------------------------------------------------------------
417//------------------------------------------------------------------------------
418 virtual std::shared_ptr<leaf_node<T, SAFE_MATH>>
419 df(std::shared_ptr<leaf_node<T, SAFE_MATH>> x) = 0;
420
421//------------------------------------------------------------------------------
434//------------------------------------------------------------------------------
435 virtual void compile_preamble(std::ostringstream &stream,
436 jit::register_map &registers,
441 int &avail_const_mem) {
442#ifdef SHOW_USE_COUNT
443 if (usage.find(this) == usage.end()) {
444 usage[this] = 1;
445 } else {
446 ++usage[this];
447 }
448#endif
449 }
450
451//------------------------------------------------------------------------------
459//------------------------------------------------------------------------------
460 virtual std::shared_ptr<leaf_node<T, SAFE_MATH>>
461 compile(std::ostringstream &stream,
462 jit::register_map &registers,
464 const jit::register_usage &usage) = 0;
465
466//------------------------------------------------------------------------------
471//------------------------------------------------------------------------------
472 virtual bool is_match(std::shared_ptr<leaf_node<T, SAFE_MATH>> x) {
473 return this == x.get();
474 }
475
476//------------------------------------------------------------------------------
481//------------------------------------------------------------------------------
483 return this->get_power_base()->is_match(x->get_power_base());
484 }
485
486//------------------------------------------------------------------------------
490//------------------------------------------------------------------------------
491 virtual void set(const T d) {}
492
493//------------------------------------------------------------------------------
498//------------------------------------------------------------------------------
499 virtual void set(const size_t index,
500 const T d) {}
501
502//------------------------------------------------------------------------------
506//------------------------------------------------------------------------------
507 virtual void set(const std::vector<T> &d) {}
508
509//------------------------------------------------------------------------------
513//------------------------------------------------------------------------------
514 virtual void set(const backend::buffer<T> &d) {}
515
516//------------------------------------------------------------------------------
518//------------------------------------------------------------------------------
519 virtual void to_latex() const = 0;
520
521//------------------------------------------------------------------------------
527//------------------------------------------------------------------------------
528 virtual std::shared_ptr<leaf_node<T, SAFE_MATH>> to_vizgraph(std::stringstream &stream,
529 jit::register_map &registers) = 0;
530
531//------------------------------------------------------------------------------
535//------------------------------------------------------------------------------
536 virtual bool is_constant() const {
537 return false;
538 }
539
540//------------------------------------------------------------------------------
544//------------------------------------------------------------------------------
545 virtual bool has_constant_zero() const {
546 return false;
547 }
548
549//------------------------------------------------------------------------------
553//------------------------------------------------------------------------------
554 bool is_normal() {
555 return this->evaluate().is_normal();
556 }
557
558//------------------------------------------------------------------------------
562//------------------------------------------------------------------------------
563 virtual bool is_all_variables() const = 0;
564
565//------------------------------------------------------------------------------
571//------------------------------------------------------------------------------
572 virtual bool is_power_like() const {
573 return false;
574 }
575
576//------------------------------------------------------------------------------
582//------------------------------------------------------------------------------
583 virtual std::shared_ptr<leaf_node<T, SAFE_MATH>> get_power_base() {
584 return this->shared_from_this();
585 }
586
587//------------------------------------------------------------------------------
594//------------------------------------------------------------------------------
595 virtual std::shared_ptr<leaf_node<T, SAFE_MATH>> get_power_exponent() const = 0;
596
597//------------------------------------------------------------------------------
601//------------------------------------------------------------------------------
602 size_t get_hash() const {
603 return hash;
604 }
605
606//------------------------------------------------------------------------------
610//------------------------------------------------------------------------------
611 size_t get_complexity() const {
612 return complexity;
613 }
614
615//------------------------------------------------------------------------------
619//------------------------------------------------------------------------------
620 virtual bool has_pseudo() const {
621 return contains_pseudo;
622 }
623
624//------------------------------------------------------------------------------
628//------------------------------------------------------------------------------
629 virtual std::shared_ptr<leaf_node<T, SAFE_MATH>> remove_pseudo() {
630 return this->shared_from_this();
631 }
632
633//------------------------------------------------------------------------------
638//------------------------------------------------------------------------------
639 virtual void endline(std::ostringstream &stream,
641#ifndef SHOW_USE_COUNT
642 const
643#endif
644 final {
645 stream << ";"
646#ifdef SHOW_USE_COUNT
647 << " // used " << usage.at(this)
648#endif
649 << std::endl;
650 }
651
652// Create one struct that holds both caches: for constructed nodes and for the
653// backend buffers
654//------------------------------------------------------------------------------
659//------------------------------------------------------------------------------
660 struct caches_t {
662 std::map<size_t, std::shared_ptr<leaf_node<T, SAFE_MATH>>> nodes;
664 std::map<size_t, backend::buffer<T>> backends;
665 };
666
668 inline static thread_local caches_t caches;
669
671 typedef T base;
672 };
673
675 template<jit::float_scalar T, bool SAFE_MATH=false>
676 using shared_leaf = std::shared_ptr<leaf_node<T, SAFE_MATH>>;
677//------------------------------------------------------------------------------
684//------------------------------------------------------------------------------
685 template<jit::float_scalar T, bool SAFE_MATH=false>
690 template<jit::float_scalar T, bool SAFE_MATH=false>
691 using output_nodes = std::vector<shared_leaf<T, SAFE_MATH>>;
692
694 template<jit::float_scalar T, bool SAFE_MATH=false>
697 template<jit::float_scalar T, bool SAFE_MATH=false>
698 constexpr shared_leaf<T, SAFE_MATH> one();
699
700//------------------------------------------------------------------------------
704//------------------------------------------------------------------------------
705 template<jit::float_scalar T, bool SAFE_MATH=false>
707 std::stringstream stream;
708 jit::register_map registers;
710
711 stream << "graph \"\" {" << std::endl;
712 stream << " node [fontname = \"Helvetica\", ordering = out]" << std::endl << std::endl;
713 node->to_vizgraph(stream, registers);
714 stream << "}" << std::endl;
715
716 std::cout << stream.str() << std::endl;
717 }
718
719//******************************************************************************
720// Constant node.
721//******************************************************************************
722//------------------------------------------------------------------------------
727//------------------------------------------------------------------------------
728 template<jit::float_scalar T, bool SAFE_MATH=false>
729 class constant_node final : public leaf_node<T, SAFE_MATH> {
730//------------------------------------------------------------------------------
735//------------------------------------------------------------------------------
736 static std::string to_string(const T d) {
737 return jit::format_to_string<T> (d);
738 }
739
740 private:
742 const backend::buffer<T> data;
743
744 public:
745//------------------------------------------------------------------------------
749//------------------------------------------------------------------------------
751 leaf_node<T, SAFE_MATH> (constant_node::to_string(d.at(0)), 1, false), data(d) {
752 assert(d.size() == 1 && "Constants need to be scalar functions.");
753 }
754
755//------------------------------------------------------------------------------
759//------------------------------------------------------------------------------
761 return data;
762 }
763
764//------------------------------------------------------------------------------
769//------------------------------------------------------------------------------
773
774//------------------------------------------------------------------------------
782//------------------------------------------------------------------------------
784 compile(std::ostringstream &stream,
785 jit::register_map &registers,
787 const jit::register_usage &usage) {
788 if (registers.find(this) == registers.end()) {
789#ifdef USE_CONSTANT_CACHE
790 registers[this] = jit::to_string('r', this);
791 stream << " const ";
792 jit::add_type<T> (stream);
793 const T temp = this->evaluate().at(0);
794
795 stream << " " << registers[this] << " = ";
796 if constexpr (jit::complex_scalar<T>) {
797 jit::add_type<T> (stream);
798 }
799 stream << temp;
800 this->endline(stream, usage);
801#else
802 if constexpr (jit::complex_scalar<T>) {
803 registers[this] = jit::get_type_string<T> () + "("
804 + jit::format_to_string(this->evaluate().at(0))
805 + ")";
806 } else {
807 registers[this] = "(" + jit::get_type_string<T> () + ")"
808 + jit::format_to_string(this->evaluate().at(0));
809 }
810#endif
811 }
812
813 return this->shared_from_this();
814 }
815
816//------------------------------------------------------------------------------
821//------------------------------------------------------------------------------
823 if (this == x.get()) {
824 return true;
825 }
826
827 auto x_cast = constant_cast(x);
828 if (x_cast.get()) {
829 return this->evaluate() == x_cast->evaluate();
830 }
831
832 return false;
833 }
834
835//------------------------------------------------------------------------------
837//------------------------------------------------------------------------------
838 bool is(const T d) {
839 return data.size() == 1 && data.at(0) == d;
840 }
841
842//------------------------------------------------------------------------------
844//------------------------------------------------------------------------------
845 bool is_integer() {
846 const auto temp = this->evaluate().at(0);
847 return std::imag(temp) == 0 &&
848 fmod(std::real(temp), 1.0) == 0.0;
849 }
850
851//------------------------------------------------------------------------------
853//------------------------------------------------------------------------------
854 virtual void to_latex() const {
855 std::cout << data.at(0);
856 }
857
858//------------------------------------------------------------------------------
864//------------------------------------------------------------------------------
865 virtual shared_leaf<T, SAFE_MATH> to_vizgraph(std::stringstream &stream,
866 jit::register_map &registers) {
867 if (registers.find(this) == registers.end()) {
868 const std::string name = jit::to_string('r', this);
869 registers[this] = name;
870 stream << " " << name
871 << " [label = \"" << this->evaluate().at(0)
872 << "\", shape = box, style = \"rounded,filled\", fillcolor = black, fontcolor = white];" << std::endl;
873 }
874
875 return this->shared_from_this();
876 }
877
878//------------------------------------------------------------------------------
882//------------------------------------------------------------------------------
883 virtual bool is_constant() const {
884 return true;
885 }
886
887//------------------------------------------------------------------------------
891//------------------------------------------------------------------------------
892 virtual bool has_constant_zero() const {
893 return data.has_zero();
894 }
895
896//------------------------------------------------------------------------------
900//------------------------------------------------------------------------------
901 virtual bool is_all_variables() const {
902 return false;
903 }
904
905//------------------------------------------------------------------------------
909//------------------------------------------------------------------------------
910 virtual bool is_power_like() const {
911 return true;
912 }
913
914//------------------------------------------------------------------------------
918//------------------------------------------------------------------------------
920 return this->shared_from_this();
921 }
922
923//------------------------------------------------------------------------------
927//------------------------------------------------------------------------------
929 return one<T, SAFE_MATH> ();
930 }
931 };
932
933//------------------------------------------------------------------------------
941//------------------------------------------------------------------------------
942 template<jit::float_scalar T, bool SAFE_MATH=false>
944 auto temp = std::make_shared<constant_node<T, SAFE_MATH>> (d);
945// Test for hash collisions.
946 for (size_t i = temp->get_hash(); i < std::numeric_limits<size_t>::max(); i++) {
947 if (leaf_node<T, SAFE_MATH>::caches.nodes.find(i) ==
950 return temp;
951 } else if (temp->is_match(leaf_node<T, SAFE_MATH>::caches.nodes[i])) {
953 }
954 }
955#if defined(__clang__) || defined(__GNUC__)
957#else
958 assert(false && "Should never reach.");
959#endif
960 }
961
962//------------------------------------------------------------------------------
970//------------------------------------------------------------------------------
971 template<jit::float_scalar T, bool SAFE_MATH=false>
975
976// Define some common constants.
977//------------------------------------------------------------------------------
984//------------------------------------------------------------------------------
985 template<jit::float_scalar T, bool SAFE_MATH>
987 return constant<T, SAFE_MATH> (static_cast<T> (0.0));
988 }
989
990//------------------------------------------------------------------------------
997//------------------------------------------------------------------------------
998 template<jit::float_scalar T, bool SAFE_MATH>
1000 return constant<T, SAFE_MATH> (static_cast<T> (1.0));
1001 }
1002
1003//------------------------------------------------------------------------------
1010//------------------------------------------------------------------------------
1011 template<jit::float_scalar T, bool SAFE_MATH=false>
1013 return constant<T, SAFE_MATH> (static_cast<T> (-1.0));
1014 }
1015
1017 template<jit::complex_scalar T>
1018 constexpr T i = T(0.0, 1.0);
1019
1021 template<jit::float_scalar T, bool SAFE_MATH=false>
1022 using shared_constant = std::shared_ptr<constant_node<T, SAFE_MATH>>;
1023
1024//------------------------------------------------------------------------------
1032//------------------------------------------------------------------------------
1033 template<jit::float_scalar T, bool SAFE_MATH=false>
1035 return std::dynamic_pointer_cast<constant_node<T, SAFE_MATH>> (x);
1036 }
1037
1038//******************************************************************************
1039// Base straight node.
1040//******************************************************************************
1041//------------------------------------------------------------------------------
1049//------------------------------------------------------------------------------
1050 template<jit::float_scalar T, bool SAFE_MATH=false>
1051 class straight_node : public leaf_node<T, SAFE_MATH> {
1052 protected:
1055
1056 public:
1057//------------------------------------------------------------------------------
1062//------------------------------------------------------------------------------
1064 const std::string s) :
1065 leaf_node<T, SAFE_MATH> (s, a->get_complexity() + 1, a->has_pseudo()),
1066 arg(a) {}
1067
1068//------------------------------------------------------------------------------
1072//------------------------------------------------------------------------------
1074 return this->arg->evaluate();
1075 }
1076
1077//------------------------------------------------------------------------------
1087//------------------------------------------------------------------------------
1088 virtual void compile_preamble(std::ostringstream &stream,
1089 jit::register_map &registers,
1094 int &avail_const_mem) {
1095 if (visited.find(this) == visited.end()) {
1096 this->arg->compile_preamble(stream, registers,
1097 visited, usage,
1100 visited.insert(this);
1101#ifdef SHOW_USE_COUNT
1102 usage[this] = 1;
1103 } else {
1104 ++usage[this];
1105#endif
1106 }
1107 }
1108
1109//------------------------------------------------------------------------------
1117//------------------------------------------------------------------------------
1119 compile(std::ostringstream &stream,
1120 jit::register_map &registers,
1122 const jit::register_usage &usage) {
1123 return this->arg->compile(stream, registers, indices, usage);
1124 }
1125
1126//------------------------------------------------------------------------------
1128//------------------------------------------------------------------------------
1130 return this->arg;
1131 }
1132
1133//------------------------------------------------------------------------------
1137//------------------------------------------------------------------------------
1138 virtual bool is_all_variables() const {
1139 return this->arg->is_all_variables();
1140 }
1141
1142//------------------------------------------------------------------------------
1146//------------------------------------------------------------------------------
1148 return one<T, SAFE_MATH> ();
1149 }
1150 };
1151
1152//******************************************************************************
1153// Base branch node.
1154//******************************************************************************
1155//------------------------------------------------------------------------------
1163//------------------------------------------------------------------------------
1164 template<jit::float_scalar T, bool SAFE_MATH=false>
1165 class branch_node : public leaf_node<T, SAFE_MATH> {
1166 protected:
1171
1172 public:
1173
1174//------------------------------------------------------------------------------
1180//------------------------------------------------------------------------------
1183 const std::string s) :
1185 l->has_pseudo() || r->has_pseudo()),
1186 left(l), right(r) {}
1187
1188//------------------------------------------------------------------------------
1196//------------------------------------------------------------------------------
1199 const std::string s,
1200 const size_t count,
1201 const bool pseudo) :
1202 leaf_node<T, SAFE_MATH> (s, count, pseudo),
1203 left(l), right(r) {}
1204
1205//------------------------------------------------------------------------------
1215//------------------------------------------------------------------------------
1216 virtual void compile_preamble(std::ostringstream &stream,
1217 jit::register_map &registers,
1222 int &avail_const_mem) {
1223 if (visited.find(this) == visited.end()) {
1224 this->left->compile_preamble(stream, registers,
1225 visited, usage,
1228 this->right->compile_preamble(stream, registers,
1229 visited, usage,
1232 visited.insert(this);
1233#ifdef SHOW_USE_COUNT
1234 usage[this] = 1;
1235 } else {
1236 ++usage[this];
1237#endif
1238 }
1239 }
1240
1241//------------------------------------------------------------------------------
1243//------------------------------------------------------------------------------
1245 return this->left;
1246 }
1247
1248//------------------------------------------------------------------------------
1250//------------------------------------------------------------------------------
1252 return this->right;
1253 }
1254
1255//------------------------------------------------------------------------------
1259//------------------------------------------------------------------------------
1260 virtual bool is_all_variables() const {
1261 return this->left->is_all_variables() &&
1262 this->right->is_all_variables();
1263 }
1264
1265//------------------------------------------------------------------------------
1269//------------------------------------------------------------------------------
1270 virtual std::shared_ptr<leaf_node<T, SAFE_MATH>>
1272 return one<T, SAFE_MATH> ();
1273 }
1274 };
1275
1276//******************************************************************************
1277// Base triple node.
1278//******************************************************************************
1279//------------------------------------------------------------------------------
1287//------------------------------------------------------------------------------
1288 template<jit::float_scalar T, bool SAFE_MATH=false>
1289 class triple_node : public branch_node<T, SAFE_MATH> {
1290 protected:
1293
1294 public:
1295
1296//------------------------------------------------------------------------------
1303//------------------------------------------------------------------------------
1307 const std::string s) :
1308 branch_node<T, SAFE_MATH> (l, r, s,
1309 l->get_complexity() +
1310 m->get_complexity() +
1311 r->get_complexity(),
1312 l->has_pseudo() ||
1313 m->has_pseudo() ||
1314 r->has_pseudo()),
1315 middle(m) {}
1316
1317//------------------------------------------------------------------------------
1327//------------------------------------------------------------------------------
1328 virtual void compile_preamble(std::ostringstream &stream,
1329 jit::register_map &registers,
1334 int &avail_const_mem) {
1335 if (visited.find(this) == visited.end()) {
1336 this->left->compile_preamble(stream, registers,
1337 visited, usage,
1340 this->middle->compile_preamble(stream, registers,
1341 visited, usage,
1344 this->right->compile_preamble(stream, registers,
1345 visited, usage,
1348 visited.insert(this);
1349#ifdef SHOW_USE_COUNT
1350 usage[this] = 1;
1351 } else {
1352 ++usage[this];
1353#endif
1354 }
1355 }
1356
1357//------------------------------------------------------------------------------
1359//------------------------------------------------------------------------------
1361 return this->middle;
1362 }
1363
1364//------------------------------------------------------------------------------
1368//------------------------------------------------------------------------------
1369 virtual bool is_all_variables() const {
1370 return this->left->is_all_variables() &&
1371 this->middle->is_all_variables() &&
1372 this->right->is_all_variables();
1373 }
1374 };
1375
1376//******************************************************************************
1377// Variable node.
1378//******************************************************************************
1379//------------------------------------------------------------------------------
1384//------------------------------------------------------------------------------
1385 template<jit::float_scalar T, bool SAFE_MATH=false>
1386 class variable_node final : public leaf_node<T, SAFE_MATH> {
1387 private:
1389 backend::buffer<T> buffer;
1391 const std::string symbol;
1392
1393//------------------------------------------------------------------------------
1398//------------------------------------------------------------------------------
1399 static std::string to_string(variable_node<T, SAFE_MATH> *p) {
1400 return jit::format_to_string(reinterpret_cast<size_t> (p));
1401 }
1402
1403 public:
1404//------------------------------------------------------------------------------
1409//------------------------------------------------------------------------------
1410 variable_node(const size_t s,
1411 const std::string &symbol) :
1412 leaf_node<T, SAFE_MATH> (variable_node::to_string(this), 1, false),
1413 buffer(s), symbol(symbol) {}
1414
1415//------------------------------------------------------------------------------
1421//------------------------------------------------------------------------------
1422 variable_node(const size_t s, const T d,
1423 const std::string &symbol) :
1424 leaf_node<T, SAFE_MATH> (variable_node::to_string(this), 1, false),
1425 buffer(s, d), symbol(symbol) {
1426 assert(buffer.is_normal() && "NaN or Inf value.");
1427 }
1428
1429//------------------------------------------------------------------------------
1434//------------------------------------------------------------------------------
1435 variable_node(const std::vector<T> &d,
1436 const std::string &symbol) :
1437 leaf_node<T, SAFE_MATH> (variable_node::to_string(this), 1, false),
1438 buffer(d), symbol(symbol) {
1439 assert(buffer.is_normal() && "NaN or Inf value.");
1440 }
1441
1442//------------------------------------------------------------------------------
1447//------------------------------------------------------------------------------
1449 const std::string &symbol) :
1450 leaf_node<T, SAFE_MATH> (variable_node::to_string(this), 1, false),
1451 buffer(d), symbol(symbol) {
1452 assert(buffer.is_normal() && "NaN or Inf value.");
1453 }
1454
1455//------------------------------------------------------------------------------
1459//------------------------------------------------------------------------------
1461 return buffer;
1462 }
1463
1464//------------------------------------------------------------------------------
1469//------------------------------------------------------------------------------
1471 return constant<T, SAFE_MATH> (static_cast<T> (this->is_match(x)));
1472 }
1473
1474//------------------------------------------------------------------------------
1487//------------------------------------------------------------------------------
1488 virtual void compile_preamble(std::ostringstream &stream,
1489 jit::register_map &registers,
1494 int &avail_const_mem) {
1495 if (usage.find(this) == usage.end()) {
1496 usage[this] = 1;
1497#ifdef SHOW_USE_COUNT
1498 } else {
1499 ++usage[this];
1500#endif
1501 }
1502 }
1503
1504//------------------------------------------------------------------------------
1512//------------------------------------------------------------------------------
1514 compile(std::ostringstream &stream,
1515 jit::register_map &registers,
1517 const jit::register_usage &usage) {
1518 return this->shared_from_this();
1519 }
1520
1521//------------------------------------------------------------------------------
1525//------------------------------------------------------------------------------
1526 virtual void set(const T d) {
1527 buffer.set(d);
1528 }
1529
1530//------------------------------------------------------------------------------
1535//------------------------------------------------------------------------------
1536 virtual void set(const size_t index, const T d) {
1537 buffer[index] = d;
1538 }
1539
1540//------------------------------------------------------------------------------
1544//------------------------------------------------------------------------------
1545 virtual void set(const std::vector<T> &d) {
1546 buffer.set(d);
1547 }
1548
1549//------------------------------------------------------------------------------
1553//------------------------------------------------------------------------------
1554 virtual void set(const backend::buffer<T> &d) {
1555 buffer = d;
1556 }
1557
1558//------------------------------------------------------------------------------
1560//------------------------------------------------------------------------------
1561 std::string get_symbol() const {
1562 return symbol;
1563 }
1564
1565//------------------------------------------------------------------------------
1567//------------------------------------------------------------------------------
1568 virtual void to_latex() const {
1569 std::cout << get_symbol();
1570 }
1571
1572//------------------------------------------------------------------------------
1578//------------------------------------------------------------------------------
1579 virtual shared_leaf<T, SAFE_MATH> to_vizgraph(std::stringstream &stream,
1580 jit::register_map &registers) {
1581 if (registers.find(this) == registers.end()) {
1582 const std::string name = jit::to_string('r', this);
1583 registers[this] = name;
1584 stream << " " << name
1585 << " [label = \"" << this->get_symbol()
1586 << "\", shape = box];" << std::endl;
1587 }
1588
1589 return this->shared_from_this();
1590 }
1591
1592//------------------------------------------------------------------------------
1594//------------------------------------------------------------------------------
1595 size_t size() {
1596 return buffer.size();
1597 }
1598
1599//------------------------------------------------------------------------------
1603//------------------------------------------------------------------------------
1604 T *data() {
1605 return buffer.data();
1606 }
1607
1608//------------------------------------------------------------------------------
1612//------------------------------------------------------------------------------
1613 virtual bool is_all_variables() const {
1614 return true;
1615 }
1616
1617//------------------------------------------------------------------------------
1621//------------------------------------------------------------------------------
1622 virtual bool is_power_like() const {
1623 return true;
1624 }
1625
1626//------------------------------------------------------------------------------
1630//------------------------------------------------------------------------------
1632 return this->shared_from_this();
1633 }
1634
1635//------------------------------------------------------------------------------
1639//------------------------------------------------------------------------------
1641 return one<T, SAFE_MATH> ();
1642 }
1643 };
1644
1645//------------------------------------------------------------------------------
1653//------------------------------------------------------------------------------
1654 template<jit::float_scalar T, bool SAFE_MATH=false>
1656 const std::string &symbol) {
1657 return std::make_shared<variable_node<T, SAFE_MATH>> (s, symbol);
1658 }
1659
1660//------------------------------------------------------------------------------
1669//------------------------------------------------------------------------------
1670 template<jit::float_scalar T, bool SAFE_MATH=false>
1671 shared_leaf<T, SAFE_MATH> variable(const size_t s, const T d,
1672 const std::string &symbol) {
1673 return std::make_shared<variable_node<T, SAFE_MATH>> (s, d, symbol);
1674 }
1675
1676//------------------------------------------------------------------------------
1684//------------------------------------------------------------------------------
1685 template<jit::float_scalar T, bool SAFE_MATH=false>
1686 shared_leaf<T, SAFE_MATH> variable(const std::vector<T> &d,
1687 const std::string &symbol) {
1688 return std::make_shared<variable_node<T, SAFE_MATH>> (d, symbol);
1689 }
1690
1691//------------------------------------------------------------------------------
1699//------------------------------------------------------------------------------
1700 template<jit::float_scalar T, bool SAFE_MATH=false>
1702 const std::string &symbol) {
1703 return std::make_shared<variable_node<T, SAFE_MATH>> (d, symbol);
1704 }
1705
1707 template<jit::float_scalar T, bool SAFE_MATH=false>
1708 using shared_variable = std::shared_ptr<variable_node<T, SAFE_MATH>>;
1710 template<jit::float_scalar T, bool SAFE_MATH=false>
1711 using input_nodes = std::vector<shared_variable<T, SAFE_MATH>>;
1713 template<jit::float_scalar T, bool SAFE_MATH=false>
1714 using map_nodes = std::vector<std::pair<shared_leaf<T, SAFE_MATH>,
1716
1717//------------------------------------------------------------------------------
1725//------------------------------------------------------------------------------
1726 template<jit::float_scalar T, bool SAFE_MATH=false>
1728 return std::dynamic_pointer_cast<variable_node<T, SAFE_MATH>> (x);
1729 }
1730
1731//******************************************************************************
1732// Pseudo variable node.
1733//******************************************************************************
1734//------------------------------------------------------------------------------
1743//------------------------------------------------------------------------------
1744 template<jit::float_scalar T, bool SAFE_MATH=false>
1745 class pseudo_variable_node final : public straight_node<T, SAFE_MATH> {
1746 private:
1747//------------------------------------------------------------------------------
1752//------------------------------------------------------------------------------
1753 static std::string to_string(leaf_node<T, SAFE_MATH> *p) {
1754 return jit::format_to_string(reinterpret_cast<size_t> (p));
1755 }
1756
1757 public:
1758//------------------------------------------------------------------------------
1762//------------------------------------------------------------------------------
1765
1766//------------------------------------------------------------------------------
1771//------------------------------------------------------------------------------
1773 return constant<T, SAFE_MATH> (static_cast<T> (this->is_match(x)));
1774 }
1775
1776//------------------------------------------------------------------------------
1778//------------------------------------------------------------------------------
1779 virtual void to_latex() const {
1780 std::cout << "\\left(";
1781 this->arg->to_latex();
1782 std::cout << "\\right)";
1783 }
1784
1785//------------------------------------------------------------------------------
1789//------------------------------------------------------------------------------
1790 virtual bool is_all_variables() const {
1791 return true;
1792 }
1793
1794//------------------------------------------------------------------------------
1798//------------------------------------------------------------------------------
1799 virtual bool is_power_like() const {
1800 return true;
1801 }
1802
1803//------------------------------------------------------------------------------
1807//------------------------------------------------------------------------------
1809 return this->arg->get_power_base();
1810 }
1811
1812//------------------------------------------------------------------------------
1816//------------------------------------------------------------------------------
1818 return this->arg->get_power_exponent();
1819 }
1820
1821//------------------------------------------------------------------------------
1825//------------------------------------------------------------------------------
1826 virtual bool has_pseudo() const {
1827 return true;
1828 }
1829
1830//------------------------------------------------------------------------------
1834//------------------------------------------------------------------------------
1836 return this->arg;
1837 }
1838
1839//------------------------------------------------------------------------------
1845//------------------------------------------------------------------------------
1846 virtual shared_leaf<T, SAFE_MATH> to_vizgraph(std::stringstream &stream,
1847 jit::register_map &registers) {
1848 if (registers.find(this) == registers.end()) {
1849 const std::string name = jit::to_string('r', this);
1850 registers[this] = name;
1851 stream << " " << name
1852 << " [label = \"pseudo_variable\", shape = oval, style = filled, fillcolor = blue, fontcolor = white];" << std::endl;
1853
1854 auto a = this->arg->to_vizgraph(stream, registers);
1855 stream << " " << name << " -- " << registers[a.get()] << ";" << std::endl;
1856 }
1857
1858 return this->shared_from_this();
1859 }
1860 };
1861
1862//------------------------------------------------------------------------------
1870//------------------------------------------------------------------------------
1871 template<jit::float_scalar T, bool SAFE_MATH=false>
1873 return std::make_shared<pseudo_variable_node<T, SAFE_MATH>> (x);
1874 }
1875
1877 template<jit::float_scalar T, bool SAFE_MATH=false>
1878 using shared_pseudo_variable = std::shared_ptr<pseudo_variable_node<T, SAFE_MATH>>;
1879
1880//------------------------------------------------------------------------------
1888//------------------------------------------------------------------------------
1889 template<jit::float_scalar T, bool SAFE_MATH=false>
1891 return std::dynamic_pointer_cast<pseudo_variable_node<T, SAFE_MATH>> (x);
1892 }
1893}
1894
1895#endif /* node_h */
Class signature to implement compute backends.
Class representing a generic buffer.
Definition backend.hpp:29
bool is_normal() const
Check for normal values.
Definition backend.hpp:279
const T at(const size_t index) const
Get value at.
Definition backend.hpp:91
bool has_zero() const
Is any element zero.
Definition backend.hpp:156
size_t size() const
Get size of the buffer.
Definition backend.hpp:116
void set(const T d)
Assign a constant value.
Definition backend.hpp:100
T * data()
Get a pointer to the basic memory buffer.
Definition backend.hpp:270
Class representing a branch node.
Definition node.hpp:1165
shared_leaf< T, SAFE_MATH > get_left()
Get the left branch.
Definition node.hpp:1244
virtual bool is_all_variables() const
Test if node acts like a variable.
Definition node.hpp:1260
virtual void compile_preamble(std::ostringstream &stream, jit::register_map &registers, jit::visiter_map &visited, jit::register_usage &usage, jit::texture1d_list &textures1d, jit::texture2d_list &textures2d, int &avail_const_mem)
Compile preamble.
Definition node.hpp:1216
branch_node(shared_leaf< T, SAFE_MATH > l, shared_leaf< T, SAFE_MATH > r, const std::string s, const size_t count, const bool pseudo)
Assigns the left and right branches.
Definition node.hpp:1197
shared_leaf< T, SAFE_MATH > get_right()
Get the right branch.
Definition node.hpp:1251
virtual std::shared_ptr< leaf_node< T, SAFE_MATH > > get_power_exponent() const
Get the exponent of a power.
Definition node.hpp:1271
branch_node(shared_leaf< T, SAFE_MATH > l, shared_leaf< T, SAFE_MATH > r, const std::string s)
Assigns the left and right branches.
Definition node.hpp:1181
shared_leaf< T, SAFE_MATH > right
Right branch of the tree.
Definition node.hpp:1170
shared_leaf< T, SAFE_MATH > left
Left branch of the tree.
Definition node.hpp:1168
Class representing data that cannot change.
Definition node.hpp:729
virtual shared_leaf< T, SAFE_MATH > get_power_exponent() const
Get the exponent of a power.
Definition node.hpp:928
virtual void to_latex() const
Convert the node to latex.
Definition node.hpp:854
virtual shared_leaf< T, SAFE_MATH > compile(std::ostringstream &stream, jit::register_map &registers, jit::register_map &indices, const jit::register_usage &usage)
Compile the node.
Definition node.hpp:784
virtual shared_leaf< T, SAFE_MATH > get_power_base()
Get the base of a power.
Definition node.hpp:919
bool is(const T d)
Check if the constant is value.
Definition node.hpp:838
virtual backend::buffer< T > evaluate()
Evaluate method.
Definition node.hpp:760
virtual bool is_all_variables() const
Test if node acts like a variable.
Definition node.hpp:901
virtual bool is_match(shared_leaf< T, SAFE_MATH > x)
Query if the nodes match.
Definition node.hpp:822
constant_node(const backend::buffer< T > &d)
Construct a constant node from a backend buffer.
Definition node.hpp:750
virtual shared_leaf< T, SAFE_MATH > to_vizgraph(std::stringstream &stream, jit::register_map &registers)
Convert the node to vizgraph.
Definition node.hpp:865
virtual shared_leaf< T, SAFE_MATH > df(shared_leaf< T, SAFE_MATH > x)
Transform node to derivative.
Definition node.hpp:770
virtual bool is_constant() const
Test if node is a constant.
Definition node.hpp:883
virtual bool has_constant_zero() const
Test if the constant node has a zero.
Definition node.hpp:892
bool is_integer()
Check if the value is an integer.
Definition node.hpp:845
virtual bool is_power_like() const
Test if the node acts like a power of a variable.
Definition node.hpp:910
Class representing a node leaf.
Definition node.hpp:364
virtual ~leaf_node()
Destructor.
Definition node.hpp:392
virtual void endline(std::ostringstream &stream, const jit::register_usage &usage) const final
End a line in the kernel source.
Definition node.hpp:639
virtual void set(const std::vector< T > &d)
Set the value of variable data.
Definition node.hpp:507
virtual bool is_power_like() const
Test if the node acts like a power of a variable.
Definition node.hpp:572
virtual bool is_match(std::shared_ptr< leaf_node< T, SAFE_MATH > > x)
Query if the nodes match.
Definition node.hpp:472
virtual std::shared_ptr< leaf_node< T, SAFE_MATH > > compile(std::ostringstream &stream, jit::register_map &registers, jit::register_map &indices, const jit::register_usage &usage)=0
Compile the node.
virtual std::shared_ptr< leaf_node< T, SAFE_MATH > > get_power_exponent() const =0
Get the exponent of a power.
virtual void to_latex() const =0
Convert the node to latex.
T base
Typedef to retrieve the backend type.
Definition node.hpp:671
virtual std::shared_ptr< leaf_node< T, SAFE_MATH > > df(std::shared_ptr< leaf_node< T, SAFE_MATH > > x)=0
Transform node to derivative.
const bool contains_pseudo
Node contains pseudo variables.
Definition node.hpp:373
virtual std::shared_ptr< leaf_node< T, SAFE_MATH > > remove_pseudo()
Remove pseudo variable nodes.
Definition node.hpp:629
virtual void set(const backend::buffer< T > &d)
Set the value of variable data.
Definition node.hpp:514
static thread_local caches_t caches
A per thread instance of the cache structure.
Definition node.hpp:668
virtual bool is_constant() const
Test if node is a constant.
Definition node.hpp:536
virtual void set(const size_t index, const T d)
Set the value of variable data.
Definition node.hpp:499
virtual std::shared_ptr< leaf_node< T, SAFE_MATH > > to_vizgraph(std::stringstream &stream, jit::register_map &registers)=0
Convert the node to vizgraph.
virtual std::shared_ptr< leaf_node< T, SAFE_MATH > > get_power_base()
Get the base of a power.
Definition node.hpp:583
virtual std::shared_ptr< leaf_node< T, SAFE_MATH > > reduce()
Reduction method.
Definition node.hpp:408
leaf_node(const std::string s, const size_t count, const bool pseudo)
Construct a basic node.
Definition node.hpp:383
virtual bool has_constant_zero() const
Test if the constant node has a zero.
Definition node.hpp:545
virtual void set(const T d)
Set the value of variable data.
Definition node.hpp:491
const size_t complexity
Graph complexity.
Definition node.hpp:369
virtual backend::buffer< T > evaluate()=0
Evaluate method.
std::map< size_t, std::shared_ptr< leaf_node< T, SAFE_MATH > > > df_cache
Cache derivative terms.
Definition node.hpp:371
bool is_power_base_match(std::shared_ptr< leaf_node< T, SAFE_MATH > > x)
Check if the base of the powers match.
Definition node.hpp:482
virtual bool is_all_variables() const =0
Test if all the sub-nodes terminate in variables.
size_t get_complexity() const
Get the number of nodes in the subgraph.
Definition node.hpp:611
virtual bool has_pseudo() const
Query if the node contains pseudo variables.
Definition node.hpp:620
const size_t hash
Hash for node.
Definition node.hpp:367
virtual void compile_preamble(std::ostringstream &stream, jit::register_map &registers, jit::visiter_map &visited, jit::register_usage &usage, jit::texture1d_list &textures1d, jit::texture2d_list &textures2d, int &avail_const_mem)
Compile preamble.
Definition node.hpp:435
bool is_normal()
Test if the result is normal.
Definition node.hpp:554
size_t get_hash() const
Get the hash for the node.
Definition node.hpp:602
Class representing a subexpression that acts like a variable.
Definition node.hpp:1745
virtual void to_latex() const
Convert the node to latex.
Definition node.hpp:1779
virtual shared_leaf< T, SAFE_MATH > df(shared_leaf< T, SAFE_MATH > x)
Transform node to derivative.
Definition node.hpp:1772
virtual bool is_power_like() const
Test if the node acts like a power of a variable.
Definition node.hpp:1799
virtual shared_leaf< T, SAFE_MATH > to_vizgraph(std::stringstream &stream, jit::register_map &registers)
Convert the node to vizgraph.
Definition node.hpp:1846
virtual shared_leaf< T, SAFE_MATH > get_power_exponent() const
Get the exponent of a power.
Definition node.hpp:1817
virtual bool is_all_variables() const
Test if node acts like a variable.
Definition node.hpp:1790
virtual shared_leaf< T, SAFE_MATH > remove_pseudo()
Remove pseudo variable nodes.
Definition node.hpp:1835
virtual shared_leaf< T, SAFE_MATH > get_power_base()
Get the base of a power.
Definition node.hpp:1808
virtual bool has_pseudo() const
Query if the node contains pseudo variables.
Definition node.hpp:1826
pseudo_variable_node(shared_leaf< T, SAFE_MATH > a)
Construct a pseudo variable node.
Definition node.hpp:1763
Class representing a straight node.
Definition node.hpp:1051
virtual shared_leaf< T, SAFE_MATH > get_power_exponent() const
Get the exponent of a power.
Definition node.hpp:1147
shared_leaf< T, SAFE_MATH > arg
Argument.
Definition node.hpp:1054
virtual backend::buffer< T > evaluate()
Evaluate method.
Definition node.hpp:1073
shared_leaf< T, SAFE_MATH > get_arg()
Get the argument.
Definition node.hpp:1129
virtual void compile_preamble(std::ostringstream &stream, jit::register_map &registers, jit::visiter_map &visited, jit::register_usage &usage, jit::texture1d_list &textures1d, jit::texture2d_list &textures2d, int &avail_const_mem)
Compile preamble.
Definition node.hpp:1088
straight_node(shared_leaf< T, SAFE_MATH > a, const std::string s)
Construct a straight node.
Definition node.hpp:1063
virtual bool is_all_variables() const
Test if node acts like a variable.
Definition node.hpp:1138
virtual shared_leaf< T, SAFE_MATH > compile(std::ostringstream &stream, jit::register_map &registers, jit::register_map &indices, const jit::register_usage &usage)
Compile the node.
Definition node.hpp:1119
Class representing a triple branch node.
Definition node.hpp:1289
virtual void compile_preamble(std::ostringstream &stream, jit::register_map &registers, jit::visiter_map &visited, jit::register_usage &usage, jit::texture1d_list &textures1d, jit::texture2d_list &textures2d, int &avail_const_mem)
Compile preamble.
Definition node.hpp:1328
shared_leaf< T, SAFE_MATH > get_middle()
Get the middle branch.
Definition node.hpp:1360
virtual bool is_all_variables() const
Test if node acts like a variable.
Definition node.hpp:1369
triple_node(shared_leaf< T, SAFE_MATH > l, shared_leaf< T, SAFE_MATH > m, shared_leaf< T, SAFE_MATH > r, const std::string s)
Reduces and assigns the left, middle and right branches.
Definition node.hpp:1304
shared_leaf< T, SAFE_MATH > middle
Middle branch of the tree.
Definition node.hpp:1292
Class representing data that can change.
Definition node.hpp:1386
virtual shared_leaf< T, SAFE_MATH > get_power_base()
Get the base of a power.
Definition node.hpp:1631
variable_node(const backend::buffer< T > &d, const std::string &symbol)
Construct a variable node from a backend buffer.
Definition node.hpp:1448
virtual bool is_all_variables() const
Test if node acts like a variable.
Definition node.hpp:1613
virtual shared_leaf< T, SAFE_MATH > df(shared_leaf< T, SAFE_MATH > x)
Transform node to derivative.
Definition node.hpp:1470
variable_node(const size_t s, const T d, const std::string &symbol)
Construct a variable node from a scalar.
Definition node.hpp:1422
virtual void to_latex() const
Convert the node to latex.
Definition node.hpp:1568
size_t size()
Get the size of the variable buffer.
Definition node.hpp:1595
virtual shared_leaf< T, SAFE_MATH > compile(std::ostringstream &stream, jit::register_map &registers, jit::register_map &indices, const jit::register_usage &usage)
Compile the node.
Definition node.hpp:1514
virtual shared_leaf< T, SAFE_MATH > to_vizgraph(std::stringstream &stream, jit::register_map &registers)
Convert the node to vizgraph.
Definition node.hpp:1579
virtual void compile_preamble(std::ostringstream &stream, jit::register_map &registers, jit::visiter_map &visited, jit::register_usage &usage, jit::texture1d_list &textures1d, jit::texture2d_list &textures2d, int &avail_const_mem)
Compile preamble.
Definition node.hpp:1488
virtual backend::buffer< T > evaluate()
Evaluate method.
Definition node.hpp:1460
T * data()
Get a pointer to raw buffer.
Definition node.hpp:1604
virtual void set(const T d)
Set the value of variable data.
Definition node.hpp:1526
virtual bool is_power_like() const
Test if the node acts like a power of a variable.
Definition node.hpp:1622
virtual void set(const backend::buffer< T > &d)
Set the value of variable data.
Definition node.hpp:1554
virtual void set(const size_t index, const T d)
Set the value of variable data.
Definition node.hpp:1536
variable_node(const size_t s, const std::string &symbol)
Construct a variable node with a size.
Definition node.hpp:1410
virtual void set(const std::vector< T > &d)
Set the value of variable data.
Definition node.hpp:1545
variable_node(const std::vector< T > &d, const std::string &symbol)
Construct a variable node from a vector.
Definition node.hpp:1435
std::string get_symbol() const
Get Symbol.
Definition node.hpp:1561
virtual shared_leaf< T, SAFE_MATH > get_power_exponent() const
Get the exponent of a power.
Definition node.hpp:1640
Complex scalar concept.
Definition register.hpp:24
subroutine assert(test, message)
Assert check.
Definition f_binding_test.f90:38
Name space for graph nodes.
Definition arithmetic.hpp:13
constexpr shared_leaf< T, SAFE_MATH > zero()
Forward declare for zero.
Definition node.hpp:986
std::shared_ptr< variable_node< T, SAFE_MATH > > shared_variable
Convenience type alias for shared variable nodes.
Definition node.hpp:1708
constexpr shared_leaf< T, SAFE_MATH > none()
Create a negative one constant.
Definition node.hpp:1012
void make_vizgraph(shared_leaf< T, SAFE_MATH > node)
Build the vizgraph input.
Definition node.hpp:706
constexpr shared_leaf< T, SAFE_MATH > one()
Forward declare for one.
Definition node.hpp:999
std::vector< shared_variable< T, SAFE_MATH > > input_nodes
Convenience type alias for a vector of inputs.
Definition node.hpp:1711
shared_constant< T, SAFE_MATH > constant_cast(shared_leaf< T, SAFE_MATH > x)
Cast to a constant node.
Definition node.hpp:1034
shared_variable< T, SAFE_MATH > variable_cast(shared_leaf< T, SAFE_MATH > x)
Cast to a variable node.
Definition node.hpp:1727
constexpr T i
Convenience constant for the imaginary unit.
Definition node.hpp:1018
shared_leaf< T, SAFE_MATH > constant(const backend::buffer< T > &d)
Construct a constant.
Definition node.hpp:943
shared_leaf< T, SAFE_MATH > variable(const size_t s, const std::string &symbol)
Construct a variable.
Definition node.hpp:1655
std::shared_ptr< pseudo_variable_node< T, SAFE_MATH > > shared_pseudo_variable
Convenience type alias for shared pseudo variable nodes.
Definition node.hpp:1878
std::shared_ptr< leaf_node< T, SAFE_MATH > > shared_leaf
Convenience type alias for shared leaf nodes.
Definition node.hpp:676
std::vector< std::pair< shared_leaf< T, SAFE_MATH >, shared_variable< T, SAFE_MATH > > > map_nodes
Convenience type alias for mapping end codes back to inputs.
Definition node.hpp:1715
constexpr shared_leaf< T, SAFE_MATH > null_leaf()
Create a null leaf.
Definition node.hpp:686
std::shared_ptr< constant_node< T, SAFE_MATH > > shared_constant
Convenience type alias for shared constant nodes.
Definition node.hpp:1022
std::vector< shared_leaf< T, SAFE_MATH > > output_nodes
Convenience type alias for a vector of output nodes.
Definition node.hpp:691
shared_leaf< T, SAFE_MATH > pseudo_variable(shared_leaf< T, SAFE_MATH > x)
Define pseudo variable convenience function.
Definition node.hpp:1872
shared_pseudo_variable< T, SAFE_MATH > pseudo_variable_cast(shared_leaf< T, SAFE_MATH > &x)
Cast to a pseudo variable node.
Definition node.hpp:1890
std::map< void *, size_t > texture1d_list
Type alias for indexing 1D textures.
Definition register.hpp:263
std::string format_to_string(const T value)
Convert a value to a string while avoiding locale.
Definition register.hpp:212
std::map< void *, std::array< size_t, 2 > > texture2d_list
Type alias for indexing 2D textures.
Definition register.hpp:265
std::map< void *, size_t > register_usage
Type alias for counting register usage.
Definition register.hpp:259
std::map< void *, std::string > register_map
Type alias for mapping node pointers to register names.
Definition register.hpp:257
std::set< void * > visiter_map
Type alias for listing visited nodes.
Definition register.hpp:261
std::string to_string(const char prefix, const NODE *pointer)
Convert a graph::leaf_node pointer to a string.
Definition register.hpp:246
Data structure to contain the two caches.
Definition node.hpp:660
std::map< size_t, std::shared_ptr< leaf_node< T, SAFE_MATH > > > nodes
Cache of node.
Definition node.hpp:662
std::map< size_t, backend::buffer< T > > backends
Cache of backend buffers.
Definition node.hpp:664