Graph Framework
Loading...
Searching...
No Matches
node.hpp
Go to the documentation of this file.
1//------------------------------------------------------------------------------
6//------------------------------------------------------------------------------
7//------------------------------------------------------------------------------
340//------------------------------------------------------------------------------
341#ifndef node_h
342#define node_h
343
344#include <iostream>
345#include <string>
346#include <memory>
347#include <iomanip>
348#include <functional>
349
350#include "backend.hpp"
351
353namespace graph {
354//******************************************************************************
355// Base leaf node.
356//******************************************************************************
357//------------------------------------------------------------------------------
///  @brief Abstract base class for every node of the expression graph.
///
///  Each node stores a hash of a string chosen by the subclass constructor, a
///  complexity value and a flag recording whether it contains pseudo
///  variables. Nodes are meant to be owned by std::shared_ptr: several members
///  call shared_from_this(), which requires shared ownership.
///
///  @tparam T         Scalar type of the graph (must satisfy jit::float_scalar).
///  @tparam SAFE_MATH Compile-time flag carried by every node type; it is not
///                    read anywhere in this part of the file.
362//------------------------------------------------------------------------------
363 template<jit::float_scalar T, bool SAFE_MATH=false>
364 class leaf_node : public std::enable_shared_from_this<leaf_node<T, SAFE_MATH>> {
365 protected:
/// std::hash of the constructor string; constant() starts probing
/// caches.nodes at this value.
367 const size_t hash;
/// Complexity value returned by get_complexity().
369 const size_t complexity;
/// Map of nodes keyed by size_t. Not used in this part of the file;
/// presumably caches derivatives keyed by the variable's hash — confirm.
371 std::map<size_t, std::shared_ptr<leaf_node<T, SAFE_MATH>>> df_cache;
/// Pseudo-variable flag returned by has_pseudo().
373 const bool contains_pseudo;
374
375 public:
376//------------------------------------------------------------------------------
///  @brief Initialise the state shared by all nodes.
///
///  @param[in] s      String hashed with std::hash<std::string> to form the
///                    node hash.
///  @param[in] count  Complexity of the node (leaf subclasses pass 1,
///                    straight_node passes its argument's complexity + 1).
///  @param[in] pseudo Whether the node contains pseudo variables.
///
///  NOTE(review): count and pseudo presumably initialise complexity and
///  contains_pseudo — confirm against the full initialiser list.
382//------------------------------------------------------------------------------
383 leaf_node(const std::string s,
384 const size_t count,
385 const bool pseudo) :
386 hash(std::hash<std::string>{} (s)),
388
389//------------------------------------------------------------------------------
///  @brief Virtual destructor so derived nodes are destroyed correctly when
///         released through a leaf_node pointer.
391//------------------------------------------------------------------------------
392 virtual ~leaf_node() {}
393
394//------------------------------------------------------------------------------
398//------------------------------------------------------------------------------
400
401//------------------------------------------------------------------------------
///  @brief Reduce this node (pure virtual; every subclass defines its own
///         rules).
407//------------------------------------------------------------------------------
408 virtual std::shared_ptr<leaf_node> reduce() = 0;
409
410//------------------------------------------------------------------------------
///  @brief Presumably the derivative of this node with respect to @p x
///         (pure virtual).
///
///  @param[in] x Node to differentiate with respect to.
415//------------------------------------------------------------------------------
416 virtual std::shared_ptr<leaf_node<T, SAFE_MATH>>
417 df(std::shared_ptr<leaf_node<T, SAFE_MATH>> x) = 0;
418
419//------------------------------------------------------------------------------
///  @brief Default preamble pass for nodes without children.
///
///  Writes nothing. When SHOW_USE_COUNT is defined, records this node in the
///  usage map with a count of 1 on the first visit and increments the count
///  on every later visit.
///
///  @param[in,out] stream          Generated-code stream (unused here).
///  @param[in,out] registers       Map from node to register name (unused here).
///  @param[in,out] avail_const_mem Presumably the remaining constant-memory
///                                 budget (unused here).
432//------------------------------------------------------------------------------
433 virtual void compile_preamble(std::ostringstream &stream,
434 jit::register_map &registers,
439 int &avail_const_mem) {
440#ifdef SHOW_USE_COUNT
441 if (usage.find(this) == usage.end()) {
442 usage[this] = 1;
443 } else {
444 ++usage[this];
445 }
446#endif
447 }
448
449//------------------------------------------------------------------------------
///  @brief Emit the code for this node (pure virtual).
///
///  @param[in,out] stream    Generated-code stream.
///  @param[in,out] registers Map from node to register name.
///  @param[in]     usage     Use counts, presumably the ones gathered by
///                           compile_preamble.
///  @returns A node handle; constant and variable nodes return themselves.
457//------------------------------------------------------------------------------
458 virtual std::shared_ptr<leaf_node<T, SAFE_MATH>>
459 compile(std::ostringstream &stream,
460 jit::register_map &registers,
462 const jit::register_usage &usage) = 0;
463
464//------------------------------------------------------------------------------
///  @brief Test whether @p x matches this node.
///
///  The default accepts only the very same object (pointer identity);
///  constant_node also accepts constants with equal values.
469//------------------------------------------------------------------------------
470 virtual bool is_match(std::shared_ptr<leaf_node<T, SAFE_MATH>> x) {
471 return this == x.get();
472 }
473
474//------------------------------------------------------------------------------
///  Compares the power base of this node with the power base of @p x
///  (get_power_base() of each) using is_match().
479//------------------------------------------------------------------------------
481 return this->get_power_base()->is_match(x->get_power_base());
482 }
483
484//------------------------------------------------------------------------------
///  @brief Set the node's values from the scalar @p d. No-op by default;
///         variable_node overrides it.
488//------------------------------------------------------------------------------
489 virtual void set(const T d) {}
490
491//------------------------------------------------------------------------------
///  @brief Set the element at @p index to @p d. No-op by default;
///         variable_node overrides it.
496//------------------------------------------------------------------------------
497 virtual void set(const size_t index,
498 const T d) {}
499
500//------------------------------------------------------------------------------
///  @brief Set the node's values from @p d. No-op by default;
///         variable_node overrides it.
504//------------------------------------------------------------------------------
505 virtual void set(const std::vector<T> &d) {}
506
507//------------------------------------------------------------------------------
///  @brief Replace the node's values with the buffer @p d. No-op by default;
///         variable_node overrides it.
511//------------------------------------------------------------------------------
512 virtual void set(const backend::buffer<T> &d) {}
513
514//------------------------------------------------------------------------------
///  @brief Print this node to std::cout (pure virtual); constant and variable
///         nodes print their value and symbol respectively.
516//------------------------------------------------------------------------------
517 virtual void to_latex() const = 0;
518
519//------------------------------------------------------------------------------
///  @brief Append a Graphviz node description to @p stream (pure virtual).
///
///  @param[in,out] stream    Graphviz text being built.
///  @param[in,out] registers Nodes already written, mapped to their Graphviz
///                           names; lets shared nodes be written only once.
///  @returns This node in the constant and variable implementations.
525//------------------------------------------------------------------------------
526 virtual std::shared_ptr<leaf_node<T, SAFE_MATH>> to_vizgraph(std::stringstream &stream,
527 jit::register_map &registers) = 0;
528
529//------------------------------------------------------------------------------
///  @brief True only for constant nodes; false by default.
533//------------------------------------------------------------------------------
534 virtual bool is_constant() const {
535 return false;
536 }
537
538//------------------------------------------------------------------------------
///  @brief True if the node's constant data contains a zero; false by
///         default (constant_node asks its buffer).
542//------------------------------------------------------------------------------
543 virtual bool has_constant_zero() const {
544 return false;
545 }
546
547//------------------------------------------------------------------------------
///  @brief Evaluate the node and return backend::buffer::is_normal() of the
///         result — presumably "no NaN or Inf values", matching the
///         variable_node constructor asserts.
///
///  Note that this evaluates the node on every call.
551//------------------------------------------------------------------------------
552 bool is_normal() {
553 return this->evaluate().is_normal();
554 }
555
556//------------------------------------------------------------------------------
///  @brief True if the expression consists only of variables (pure virtual).
///         Variables return true, constants false, and straight/branch/triple
///         nodes require all of their arguments to return true.
560//------------------------------------------------------------------------------
561 virtual bool is_all_variables() const = 0;
562
563//------------------------------------------------------------------------------
///  @brief Power-like test; false by default (constant and variable nodes
///         return true).
569//------------------------------------------------------------------------------
570 virtual bool is_power_like() const {
571 return false;
572 }
573
574//------------------------------------------------------------------------------
///  @brief Base of this node viewed as a power; by default the node itself.
580//------------------------------------------------------------------------------
581 virtual std::shared_ptr<leaf_node<T, SAFE_MATH>> get_power_base() {
582 return this->shared_from_this();
583 }
584
585//------------------------------------------------------------------------------
///  @brief Exponent of this node viewed as a power (pure virtual); constant,
///         variable, straight and branch nodes return one().
592//------------------------------------------------------------------------------
593 virtual std::shared_ptr<leaf_node<T, SAFE_MATH>> get_power_exponent() const = 0;
594
595//------------------------------------------------------------------------------
///  @brief Node hash computed in the constructor.
599//------------------------------------------------------------------------------
600 size_t get_hash() const {
601 return hash;
602 }
603
604//------------------------------------------------------------------------------
///  @brief Complexity value of the node.
608//------------------------------------------------------------------------------
609 size_t get_complexity() const {
610 return complexity;
611 }
612
613//------------------------------------------------------------------------------
///  @brief Whether the node contains pseudo variables.
617//------------------------------------------------------------------------------
618 virtual bool has_pseudo() const {
619 return contains_pseudo;
620 }
621
622//------------------------------------------------------------------------------
///  @brief Default: return this node unchanged.
626//------------------------------------------------------------------------------
627 virtual std::shared_ptr<leaf_node<T, SAFE_MATH>> remove_pseudo() {
628 return this->shared_from_this();
629 }
630
631//------------------------------------------------------------------------------
///  @brief Terminate a generated statement.
///
///  Writes ";" then std::endl (which also flushes @p stream). When
///  SHOW_USE_COUNT is defined, a trailing " // used N" comment is added from
///  usage.at(this) and the method is not const; otherwise it is const.
///  Declared final, so no subclass can override it.
///
///  NOTE(review): at() presumably throws if this node was never recorded in
///  usage — confirm the preamble pass always runs first.
636//------------------------------------------------------------------------------
637 virtual void endline(std::ostringstream &stream,
639#ifndef SHOW_USE_COUNT
640 const
641#endif
642 final {
643 stream << ";"
644#ifdef SHOW_USE_COUNT
645 << " // used " << usage.at(this)
646#endif
647 << std::endl;
648 }
649
650// Create one struct that holds both caches: for constructed nodes and for the backend buffers
651//------------------------------------------------------------------------------
///  @brief Caches shared by all nodes of one (T, SAFE_MATH) instantiation.
656//------------------------------------------------------------------------------
657 struct caches_t {
/// Constructed nodes keyed by hash slot; probed by constant().
659 std::map<size_t, std::shared_ptr<leaf_node<T, SAFE_MATH>>> nodes;
/// Backend buffers keyed by size_t.
661 std::map<size_t, backend::buffer<T>> backends;
662 };
663
/// thread_local: each thread has its own caches, so nodes built on one
/// thread are never found in another thread's cache.
665 inline static thread_local caches_t caches;
666
/// Alias of the scalar type T.
668 typedef T base;
669 };
670
672 template<jit::float_scalar T, bool SAFE_MATH=false>
673 using shared_leaf = std::shared_ptr<leaf_node<T, SAFE_MATH>>;
674//------------------------------------------------------------------------------
681//------------------------------------------------------------------------------
682 template<jit::float_scalar T, bool SAFE_MATH=false>
687 template<jit::float_scalar T, bool SAFE_MATH=false>
688 using output_nodes = std::vector<shared_leaf<T, SAFE_MATH>>;
689
691 template<jit::float_scalar T, bool SAFE_MATH=false>
694 template<jit::float_scalar T, bool SAFE_MATH=false>
695 constexpr shared_leaf<T, SAFE_MATH> one();
696
697//------------------------------------------------------------------------------
///  @brief Print a Graphviz description of the graph rooted at @p node.
///
///  Builds an undirected `graph` block with Helvetica labels, lets each node
///  append itself through to_vizgraph() and writes the text to std::cout.
701//------------------------------------------------------------------------------
702 template<jit::float_scalar T, bool SAFE_MATH=false>
704 std::stringstream stream;
// Shared by the whole traversal so each node is written only once.
705 jit::register_map registers;
707
708 stream << "graph \"\" {" << std::endl;
709 stream << " node [fontname = \"Helvetica\", ordering = out]" << std::endl << std::endl;
710 node->to_vizgraph(stream, registers);
711 stream << "}" << std::endl;
712
713 std::cout << stream.str() << std::endl;
714 }
715
716//******************************************************************************
717// Constant node.
718//******************************************************************************
719//------------------------------------------------------------------------------
///  @brief Leaf node holding a single constant scalar.
///
///  The node hash is derived from the formatted value, so equal values start
///  probing the node cache at the same slot (see constant()).
724//------------------------------------------------------------------------------
725 template<jit::float_scalar T, bool SAFE_MATH=false>
726 class constant_node final : public leaf_node<T, SAFE_MATH> {
727//------------------------------------------------------------------------------
///  @brief Format @p d with jit::format_to_string; the result is hashed by
///         leaf_node to form this node's hash.
732//------------------------------------------------------------------------------
733 static std::string to_string(const T d) {
734 return jit::format_to_string<T> (d);
735 }
736
737 private:
/// The stored value; the constructor asserts it holds exactly one element.
739 const backend::buffer<T> data;
740
741 public:
742//------------------------------------------------------------------------------
///  @brief Construct a constant from a one-element buffer.
///
///  Passes complexity count 1 and pseudo = false to leaf_node.
///
///  NOTE(review): d.at(0) is evaluated for the base-class hash before the
///  size assertion in the body runs, so an empty buffer is only caught if
///  backend::buffer::at() checks its bounds — confirm.
746//------------------------------------------------------------------------------
748 leaf_node<T, SAFE_MATH> (constant_node::to_string(d.at(0)), 1, false), data(d) {
749 assert(d.size() == 1 && "Constants need to be scalar functions.");
750 }
751
752//------------------------------------------------------------------------------
///  Returns the stored one-element buffer.
756//------------------------------------------------------------------------------
758 return data;
759 }
760
761//------------------------------------------------------------------------------
///  Returns this node itself.
767//------------------------------------------------------------------------------
769 return this->shared_from_this();
770 }
771
772//------------------------------------------------------------------------------
777//------------------------------------------------------------------------------
781
782//------------------------------------------------------------------------------
///  @brief Register this constant in the generated code, declaring it first
///         when USE_CONSTANT_CACHE is defined.
///
///  Acts only the first time the node is seen in @p registers:
///  - with USE_CONSTANT_CACHE, a `const <type> <name> = value` statement is
///    written to @p stream (name built by jit::to_string('r', this)) and the
///    name is registered;
///  - otherwise nothing is written and the register entry is an inline
///    literal: `<type>(value)` for complex T, `(<type>)value` for real T.
///
///  @returns This node.
790//------------------------------------------------------------------------------
792 compile(std::ostringstream &stream,
793 jit::register_map &registers,
795 const jit::register_usage &usage) {
796 if (registers.find(this) == registers.end()) {
797#ifdef USE_CONSTANT_CACHE
798 registers[this] = jit::to_string('r', this);
799 stream << " const ";
800 jit::add_type<T> (stream);
801 const T temp = this->evaluate().at(0);
802
803 stream << " " << registers[this] << " = ";
// For complex T the type name precedes the streamed value; assuming T is
// std::complex (which streams as "(re,im)"), this reads as a constructor call.
804 if constexpr (jit::complex_scalar<T>) {
805 jit::add_type<T> (stream);
806 }
807 stream << temp;
808 this->endline(stream, usage);
809#else
810 if constexpr (jit::complex_scalar<T>) {
811 registers[this] = jit::get_type_string<T> () + "("
812 + jit::format_to_string(this->evaluate().at(0))
813 + ")";
814 } else {
815 registers[this] = "(" + jit::get_type_string<T> () + ")"
816 + jit::format_to_string(this->evaluate().at(0));
817 }
818#endif
819 }
820
821 return this->shared_from_this();
822 }
823
824//------------------------------------------------------------------------------
///  @brief Match @p x if it is this node, or another constant whose
///         evaluated buffer compares equal to this one.
829//------------------------------------------------------------------------------
831 if (this == x.get()) {
832 return true;
833 }
834
835 auto x_cast = constant_cast(x);
836 if (x_cast.get()) {
837 return this->evaluate() == x_cast->evaluate();
838 }
839
840 return false;
841 }
842
843//------------------------------------------------------------------------------
///  @brief True if the stored value equals @p d exactly (== comparison, no
///         tolerance).
845//------------------------------------------------------------------------------
846 bool is(const T d) {
847 return data.size() == 1 && data.at(0) == d;
848 }
849
850//------------------------------------------------------------------------------
///  @brief True if the value has a zero imaginary part and a whole-number
///         real part.
///
///  NOTE(review): fmod is called unqualified. std::fmod would guarantee an
///  overload matching the real type of T, instead of possibly converting to
///  double (a precision issue for wider types) — consider qualifying it.
852//------------------------------------------------------------------------------
853 bool is_integer() {
854 const auto temp = this->evaluate().at(0);
855 return std::imag(temp) == 0 &&
856 fmod(std::real(temp), 1.0) == 0.0;
857 }
858
859//------------------------------------------------------------------------------
///  @brief Print the stored value to std::cout with operator<<.
861//------------------------------------------------------------------------------
862 virtual void to_latex() const {
863 std::cout << data.at(0);
864 }
865
866//------------------------------------------------------------------------------
///  @brief On first visit, register a Graphviz name built by
///         jit::to_string('r', this) and write a filled black box labelled
///         with the value.
///
///  @returns This node.
872//------------------------------------------------------------------------------
873 virtual shared_leaf<T, SAFE_MATH> to_vizgraph(std::stringstream &stream,
874 jit::register_map &registers) {
875 if (registers.find(this) == registers.end()) {
876 const std::string name = jit::to_string('r', this);
877 registers[this] = name;
878 stream << " " << name
879 << " [label = \"" << this->evaluate().at(0)
880 << "\", shape = box, style = \"rounded,filled\", fillcolor = black, fontcolor = white];" << std::endl;
881 }
882
883 return this->shared_from_this();
884 }
885
886//------------------------------------------------------------------------------
///  @brief Always true for constant nodes.
890//------------------------------------------------------------------------------
891 virtual bool is_constant() const {
892 return true;
893 }
894
895//------------------------------------------------------------------------------
///  @brief Forwards to backend::buffer::has_zero() on the stored value.
899//------------------------------------------------------------------------------
900 virtual bool has_constant_zero() const {
901 return data.has_zero();
902 }
903
904//------------------------------------------------------------------------------
///  @brief A constant is never a variable.
908//------------------------------------------------------------------------------
909 virtual bool is_all_variables() const {
910 return false;
911 }
912
913//------------------------------------------------------------------------------
///  @brief Constants report power-like.
917//------------------------------------------------------------------------------
918 virtual bool is_power_like() const {
919 return true;
920 }
921
922//------------------------------------------------------------------------------
///  Returns this node itself.
926//------------------------------------------------------------------------------
928 return this->shared_from_this();
929 }
930
931//------------------------------------------------------------------------------
///  Returns the constant one<T, SAFE_MATH>().
935//------------------------------------------------------------------------------
937 return one<T, SAFE_MATH> ();
938 }
939 };
940
941//------------------------------------------------------------------------------
///  @brief Create a constant node, presumably reusing an equal node already
///         in the per-thread cache.
///
///  A new constant_node is built first; caches.nodes is then probed starting
///  at its hash. An empty slot returns the new node (presumably after storing
///  it). A slot holding a node that is_match()es the new one ends the search
///  (presumably returning the cached node). Any other occupant is a hash
///  collision, and the probe moves on to the next slot.
///
///  @param[in] d Value passed to the constant_node constructor (must hold a
///               single element).
///
///  NOTE(review): the probe never wraps around past SIZE_MAX. Also,
///  std::numeric_limits needs <limits>, which is not among this file's
///  visible includes — confirm it arrives via backend.hpp.
949//------------------------------------------------------------------------------
950 template<jit::float_scalar T, bool SAFE_MATH=false>
952 auto temp = std::make_shared<constant_node<T, SAFE_MATH>> (d);
953// Test for hash collisions.
954 for (size_t i = temp->get_hash(); i < std::numeric_limits<size_t>::max(); i++) {
955 if (leaf_node<T, SAFE_MATH>::caches.nodes.find(i) ==
958 return temp;
959 } else if (temp->is_match(leaf_node<T, SAFE_MATH>::caches.nodes[i])) {
961 }
962 }
// The probe loop is expected to return before exhausting its range.
963#if defined(__clang__) || defined(__GNUC__)
965#else
966 assert(false && "Should never reach.");
967#endif
968 }
969
970//------------------------------------------------------------------------------
978//------------------------------------------------------------------------------
979 template<jit::float_scalar T, bool SAFE_MATH=false>
983
984// Define some common constants.
985//------------------------------------------------------------------------------
///  @brief Constant node with value 0, built through constant().
///
///  The SAFE_MATH default is presumably supplied by the earlier forward
///  declaration, since a default template argument may only be given once.
992//------------------------------------------------------------------------------
993 template<jit::float_scalar T, bool SAFE_MATH>
995 return constant<T, SAFE_MATH> (static_cast<T> (0.0));
996 }
997
998//------------------------------------------------------------------------------
///  @brief Constant node with value 1, built through constant().
///
///  No default for SAFE_MATH here: it was already given on the forward
///  declaration of one(), and a default template argument may only be
///  specified once.
1005//------------------------------------------------------------------------------
1006 template<jit::float_scalar T, bool SAFE_MATH>
1008 return constant<T, SAFE_MATH> (static_cast<T> (1.0));
1009 }
1010
1011//------------------------------------------------------------------------------
///  @brief Constant node with value -1, built through constant().
1018//------------------------------------------------------------------------------
1019 template<jit::float_scalar T, bool SAFE_MATH=false>
1021 return constant<T, SAFE_MATH> (static_cast<T> (-1.0));
1022 }
1023
1025 template<jit::complex_scalar T>
1026 constexpr T i = T(0.0, 1.0);
1027
1029 template<jit::float_scalar T, bool SAFE_MATH=false>
1030 using shared_constant = std::shared_ptr<constant_node<T, SAFE_MATH>>;
1031
1032//------------------------------------------------------------------------------
///  @brief Downcast a node to constant_node.
///
///  @returns The node as a constant_node, or an empty pointer when @p x is
///           null or not a constant_node (std::dynamic_pointer_cast
///           semantics).
1040//------------------------------------------------------------------------------
1041 template<jit::float_scalar T, bool SAFE_MATH=false>
1043 return std::dynamic_pointer_cast<constant_node<T, SAFE_MATH>> (x);
1044 }
1045
1046//******************************************************************************
1047// Base straight node.
1048//******************************************************************************
1049//------------------------------------------------------------------------------
///  @brief Base class for nodes with exactly one argument, arg.
///
///  Evaluation, compilation and the all-variables test are forwarded to the
///  argument by default.
1057//------------------------------------------------------------------------------
1058 template<jit::float_scalar T, bool SAFE_MATH=false>
1059 class straight_node : public leaf_node<T, SAFE_MATH> {
1060 protected:
1063
1064 public:
1065//------------------------------------------------------------------------------
///  @brief Construct from argument @p a and hash string @p s.
///
///  Passes the argument's complexity + 1 as the complexity count, and
///  inherits the pseudo flag from the argument.
1070//------------------------------------------------------------------------------
1072 const std::string s) :
1073 leaf_node<T, SAFE_MATH> (s, a->get_complexity() + 1, a->has_pseudo()),
1074 arg(a) {}
1075
1076//------------------------------------------------------------------------------
///  Returns the argument's evaluated values unchanged.
1080//------------------------------------------------------------------------------
1082 return this->arg->evaluate();
1083 }
1084
1085//------------------------------------------------------------------------------
///  @brief Run the preamble pass on the argument, once per node.
///
///  @p visited prevents a shared node from being processed twice. With
///  SHOW_USE_COUNT defined, the first visit sets usage to 1 and later visits
///  increment it.
1095//------------------------------------------------------------------------------
1096 virtual void compile_preamble(std::ostringstream &stream,
1097 jit::register_map &registers,
1102 int &avail_const_mem) {
1103 if (visited.find(this) == visited.end()) {
1104 this->arg->compile_preamble(stream, registers,
1105 visited, usage,
1108 visited.insert(this);
1109#ifdef SHOW_USE_COUNT
1110 usage[this] = 1;
1111 } else {
1112 ++usage[this];
1113#endif
1114 }
1115 }
1116
1117//------------------------------------------------------------------------------
///  @brief Default: compile the argument and return its result.
1125//------------------------------------------------------------------------------
1127 compile(std::ostringstream &stream,
1128 jit::register_map &registers,
1130 const jit::register_usage &usage) {
1131 return this->arg->compile(stream, registers, indices, usage);
1132 }
1133
1134//------------------------------------------------------------------------------
///  Returns the argument node.
1136//------------------------------------------------------------------------------
1138 return this->arg;
1139 }
1140
1141//------------------------------------------------------------------------------
///  @brief True when the argument consists only of variables.
1145//------------------------------------------------------------------------------
1146 virtual bool is_all_variables() const {
1147 return this->arg->is_all_variables();
1148 }
1149
1150//------------------------------------------------------------------------------
///  Returns the constant one<T, SAFE_MATH>().
1154//------------------------------------------------------------------------------
1156 return one<T, SAFE_MATH> ();
1157 }
1158 };
1159
1160//******************************************************************************
1161// Base branch node.
1162//******************************************************************************
1163//------------------------------------------------------------------------------
///  @brief Base class for nodes with a left and a right argument.
1171//------------------------------------------------------------------------------
1172 template<jit::float_scalar T, bool SAFE_MATH=false>
1173 class branch_node : public leaf_node<T, SAFE_MATH> {
1174 protected:
1179
1180 public:
1181
1182//------------------------------------------------------------------------------
///  @brief Construct from left and right arguments and hash string @p s.
///
///  The pseudo flag is set if either argument has pseudo variables.
1188//------------------------------------------------------------------------------
1191 const std::string s) :
1193 l->has_pseudo() || r->has_pseudo()),
1194 left(l), right(r) {}
1195
1196//------------------------------------------------------------------------------
///  @brief Construct with an explicitly supplied complexity @p count and
///         @p pseudo flag (triple_node uses this to account for its middle
///         argument).
1204//------------------------------------------------------------------------------
1207 const std::string s,
1208 const size_t count,
1209 const bool pseudo) :
1210 leaf_node<T, SAFE_MATH> (s, count, pseudo),
1211 left(l), right(r) {}
1212
1213//------------------------------------------------------------------------------
///  @brief Run the preamble pass on the left, then the right argument, once
///         per node.
///
///  @p visited prevents a shared node from being processed twice. With
///  SHOW_USE_COUNT defined, the first visit sets usage to 1 and later visits
///  increment it.
1223//------------------------------------------------------------------------------
1224 virtual void compile_preamble(std::ostringstream &stream,
1225 jit::register_map &registers,
1230 int &avail_const_mem) {
1231 if (visited.find(this) == visited.end()) {
1232 this->left->compile_preamble(stream, registers,
1233 visited, usage,
1236 this->right->compile_preamble(stream, registers,
1237 visited, usage,
1240 visited.insert(this);
1241#ifdef SHOW_USE_COUNT
1242 usage[this] = 1;
1243 } else {
1244 ++usage[this];
1245#endif
1246 }
1247 }
1248
1249//------------------------------------------------------------------------------
///  Returns the left argument.
1251//------------------------------------------------------------------------------
1253 return this->left;
1254 }
1255
1256//------------------------------------------------------------------------------
///  Returns the right argument.
1258//------------------------------------------------------------------------------
1260 return this->right;
1261 }
1262
1263//------------------------------------------------------------------------------
///  @brief True when both arguments consist only of variables.
1267//------------------------------------------------------------------------------
1268 virtual bool is_all_variables() const {
1269 return this->left->is_all_variables() &&
1270 this->right->is_all_variables();
1271 }
1272
1273//------------------------------------------------------------------------------
///  Returns the constant one<T, SAFE_MATH>().
1277//------------------------------------------------------------------------------
1278 virtual std::shared_ptr<leaf_node<T, SAFE_MATH>>
1280 return one<T, SAFE_MATH> ();
1281 }
1282 };
1283
1284//******************************************************************************
1285// Base triple node.
1286//******************************************************************************
1287//------------------------------------------------------------------------------
///  @brief Base class for nodes with left, middle and right arguments; the
///         left and right arguments are stored by branch_node.
1295//------------------------------------------------------------------------------
1296 template<jit::float_scalar T, bool SAFE_MATH=false>
1297 class triple_node : public branch_node<T, SAFE_MATH> {
1298 protected:
1301
1302 public:
1303
1304//------------------------------------------------------------------------------
///  @brief Construct from three arguments and hash string @p s.
///
///  The complexity count is the sum of the three argument complexities, and
///  the pseudo flag is set if any argument has pseudo variables.
///
///  NOTE(review): straight_node adds 1 for itself on top of its argument's
///  complexity; this sum adds nothing for the triple node — confirm that is
///  intentional.
1311//------------------------------------------------------------------------------
1315 const std::string s) :
1316 branch_node<T, SAFE_MATH> (l, r, s,
1317 l->get_complexity() +
1318 m->get_complexity() +
1319 r->get_complexity(),
1320 l->has_pseudo() ||
1321 m->has_pseudo() ||
1322 r->has_pseudo()),
1323 middle(m) {}
1324
1325//------------------------------------------------------------------------------
///  @brief Run the preamble pass on the left, middle and right arguments (in
///         that order), once per node.
///
///  @p visited prevents a shared node from being processed twice. With
///  SHOW_USE_COUNT defined, the first visit sets usage to 1 and later visits
///  increment it.
1335//------------------------------------------------------------------------------
1336 virtual void compile_preamble(std::ostringstream &stream,
1337 jit::register_map &registers,
1342 int &avail_const_mem) {
1343 if (visited.find(this) == visited.end()) {
1344 this->left->compile_preamble(stream, registers,
1345 visited, usage,
1348 this->middle->compile_preamble(stream, registers,
1349 visited, usage,
1352 this->right->compile_preamble(stream, registers,
1353 visited, usage,
1356 visited.insert(this);
1357#ifdef SHOW_USE_COUNT
1358 usage[this] = 1;
1359 } else {
1360 ++usage[this];
1361#endif
1362 }
1363 }
1364
1365//------------------------------------------------------------------------------
///  Returns the middle argument.
1367//------------------------------------------------------------------------------
1369 return this->middle;
1370 }
1371
1372//------------------------------------------------------------------------------
///  @brief True when all three arguments consist only of variables.
1376//------------------------------------------------------------------------------
1377 virtual bool is_all_variables() const {
1378 return this->left->is_all_variables() &&
1379 this->middle->is_all_variables() &&
1380 this->right->is_all_variables();
1381 }
1382 };
1383
1384//******************************************************************************
1385// Variable node.
1386//******************************************************************************
1387//------------------------------------------------------------------------------
///  @brief Leaf node holding a mutable buffer of values and a display symbol.
///
///  The hash string is built from the object's own address, so every live
///  variable object gets its own hash string.
1392//------------------------------------------------------------------------------
1393 template<jit::float_scalar T, bool SAFE_MATH=false>
1394 class variable_node final : public leaf_node<T, SAFE_MATH> {
1395 private:
/// Current values of the variable.
1397 backend::buffer<T> buffer;
/// Display name used by to_latex() and to_vizgraph().
1399 const std::string symbol;
1400
1401//------------------------------------------------------------------------------
///  @brief Format the address @p p as an integer string (used as the hash
///         string).
///
///  NOTE(review): std::uintptr_t is the integer type intended to hold a
///  pointer value; reinterpret_cast to size_t is ill-formed on platforms
///  where size_t is narrower than a pointer.
1406//------------------------------------------------------------------------------
1407 static std::string to_string(variable_node<T, SAFE_MATH> *p) {
1408 return jit::format_to_string(reinterpret_cast<size_t> (p));
1409 }
1410
1411 public:
1412//------------------------------------------------------------------------------
///  @brief Variable with @p s elements (constructed as buffer(s)) and
///         display name @p symbol. Values are not checked for NaN/Inf.
1417//------------------------------------------------------------------------------
1418 variable_node(const size_t s,
1419 const std::string &symbol) :
1420 leaf_node<T, SAFE_MATH> (variable_node::to_string(this), 1, false),
1421 buffer(s), symbol(symbol) {}
1422
1423//------------------------------------------------------------------------------
///  @brief Variable built as buffer(s, d) — presumably @p s copies of @p d —
///         with display name @p symbol. Asserts the values are normal.
1429//------------------------------------------------------------------------------
1430 variable_node(const size_t s, const T d,
1431 const std::string &symbol) :
1432 leaf_node<T, SAFE_MATH> (variable_node::to_string(this), 1, false),
1433 buffer(s, d), symbol(symbol) {
1434 assert(buffer.is_normal() && "NaN or Inf value.");
1435 }
1436
1437//------------------------------------------------------------------------------
///  @brief Variable initialised from the values in @p d with display name
///         @p symbol. Asserts the values are normal.
1442//------------------------------------------------------------------------------
1443 variable_node(const std::vector<T> &d,
1444 const std::string &symbol) :
1445 leaf_node<T, SAFE_MATH> (variable_node::to_string(this), 1, false),
1446 buffer(d), symbol(symbol) {
1447 assert(buffer.is_normal() && "NaN or Inf value.");
1448 }
1449
1450//------------------------------------------------------------------------------
///  @brief Variable initialised from @p d (presumably an existing
///         backend::buffer) with display name @p symbol. Asserts the values
///         are normal.
1455//------------------------------------------------------------------------------
1457 const std::string &symbol) :
1458 leaf_node<T, SAFE_MATH> (variable_node::to_string(this), 1, false),
1459 buffer(d), symbol(symbol) {
1460 assert(buffer.is_normal() && "NaN or Inf value.");
1461 }
1462
1463//------------------------------------------------------------------------------
///  Returns the variable's buffer.
1467//------------------------------------------------------------------------------
1469 return buffer;
1470 }
1471
1472//------------------------------------------------------------------------------
///  Returns this node itself.
1478//------------------------------------------------------------------------------
1480 return this->shared_from_this();
1481 }
1482
1483//------------------------------------------------------------------------------
///  Returns constant 1 if @p x matches this variable (is_match()), otherwise
///  constant 0.
1488//------------------------------------------------------------------------------
1490 return constant<T, SAFE_MATH> (static_cast<T> (this->is_match(x)));
1491 }
1492
1493//------------------------------------------------------------------------------
///  @brief Record this variable in usage.
///
///  The first sighting always sets the count to 1. Later sightings increment
///  it only when SHOW_USE_COUNT is defined. No code is written and @p visited
///  is not consulted.
1506//------------------------------------------------------------------------------
1507 virtual void compile_preamble(std::ostringstream &stream,
1508 jit::register_map &registers,
1513 int &avail_const_mem) {
1514 if (usage.find(this) == usage.end()) {
1515 usage[this] = 1;
1516#ifdef SHOW_USE_COUNT
1517 } else {
1518 ++usage[this];
1519#endif
1520 }
1521 }
1522
1523//------------------------------------------------------------------------------
///  @brief Variables emit no code here; returns this node.
1531//------------------------------------------------------------------------------
1533 compile(std::ostringstream &stream,
1534 jit::register_map &registers,
1536 const jit::register_usage &usage) {
1537 return this->shared_from_this();
1538 }
1539
1540//------------------------------------------------------------------------------
///  @brief Forward the scalar @p d to backend::buffer::set.
1544//------------------------------------------------------------------------------
1545 virtual void set(const T d) {
1546 buffer.set(d);
1547 }
1548
1549//------------------------------------------------------------------------------
///  @brief Write @p d to element @p index.
///
///  NOTE(review): no bounds check and no NaN/Inf check, unlike the
///  constructors' is_normal() assertion.
1554//------------------------------------------------------------------------------
1555 virtual void set(const size_t index, const T d) {
1556 buffer[index] = d;
1557 }
1558
1559//------------------------------------------------------------------------------
///  @brief Forward the values in @p d to backend::buffer::set.
1563//------------------------------------------------------------------------------
1564 virtual void set(const std::vector<T> &d) {
1565 buffer.set(d);
1566 }
1567
1568//------------------------------------------------------------------------------
///  @brief Replace the buffer with @p d by assignment.
1572//------------------------------------------------------------------------------
1573 virtual void set(const backend::buffer<T> &d) {
1574 buffer = d;
1575 }
1576
1577//------------------------------------------------------------------------------
///  Returns a copy of the display symbol.
1579//------------------------------------------------------------------------------
1580 std::string get_symbol() const {
1581 return symbol;
1582 }
1583
1584//------------------------------------------------------------------------------
///  @brief Print the display symbol to std::cout.
1586//------------------------------------------------------------------------------
1587 virtual void to_latex() const {
1588 std::cout << get_symbol();
1589 }
1590
1591//------------------------------------------------------------------------------
///  @brief On first visit, register a Graphviz name built by
///         jit::to_string('r', this) and write a box labelled with the
///         symbol.
///
///  @returns This node.
1597//------------------------------------------------------------------------------
1598 virtual shared_leaf<T, SAFE_MATH> to_vizgraph(std::stringstream &stream,
1599 jit::register_map &registers) {
1600 if (registers.find(this) == registers.end()) {
1601 const std::string name = jit::to_string('r', this);
1602 registers[this] = name;
1603 stream << " " << name
1604 << " [label = \"" << this->get_symbol()
1605 << "\", shape = box];" << std::endl;
1606 }
1607
1608 return this->shared_from_this();
1609 }
1610
1611//------------------------------------------------------------------------------
///  Returns the number of elements in the buffer.
1613//------------------------------------------------------------------------------
1614 size_t size() {
1615 return buffer.size();
1616 }
1617
1618//------------------------------------------------------------------------------
///  Returns a mutable raw pointer to the buffer's storage.
1622//------------------------------------------------------------------------------
1623 T *data() {
1624 return buffer.data();
1625 }
1626
1627//------------------------------------------------------------------------------
///  @brief A variable node is always "all variables".
1631//------------------------------------------------------------------------------
1632 virtual bool is_all_variables() const {
1633 return true;
1634 }
1635
1636//------------------------------------------------------------------------------
///  @brief Variables report power-like.
1640//------------------------------------------------------------------------------
1641 virtual bool is_power_like() const {
1642 return true;
1643 }
1644
1645//------------------------------------------------------------------------------
///  Returns this node itself.
1649//------------------------------------------------------------------------------
1651 return this->shared_from_this();
1652 }
1653
1654//------------------------------------------------------------------------------
///  Returns the constant one<T, SAFE_MATH>().
1658//------------------------------------------------------------------------------
1660 return one<T, SAFE_MATH> ();
1661 }
1662 };
1663
1664//------------------------------------------------------------------------------
1672//------------------------------------------------------------------------------
1673 template<jit::float_scalar T, bool SAFE_MATH=false>
1675 const std::string &symbol) {
1676 return std::make_shared<variable_node<T, SAFE_MATH>> (s, symbol);
1677 }
1678
1679//------------------------------------------------------------------------------
1688//------------------------------------------------------------------------------
1689 template<jit::float_scalar T, bool SAFE_MATH=false>
1690 shared_leaf<T, SAFE_MATH> variable(const size_t s, const T d,
1691 const std::string &symbol) {
1692 return std::make_shared<variable_node<T, SAFE_MATH>> (s, d, symbol);
1693 }
1694
1695//------------------------------------------------------------------------------
1703//------------------------------------------------------------------------------
1704 template<jit::float_scalar T, bool SAFE_MATH=false>
1705 shared_leaf<T, SAFE_MATH> variable(const std::vector<T> &d,
1706 const std::string &symbol) {
1707 return std::make_shared<variable_node<T, SAFE_MATH>> (d, symbol);
1708 }
1709
1710//------------------------------------------------------------------------------
1718//------------------------------------------------------------------------------
1719 template<jit::float_scalar T, bool SAFE_MATH=false>
1721 const std::string &symbol) {
1722 return std::make_shared<variable_node<T, SAFE_MATH>> (d, symbol);
1723 }
1724
1726 template<jit::float_scalar T, bool SAFE_MATH=false>
1727 using shared_variable = std::shared_ptr<variable_node<T, SAFE_MATH>>;
1729 template<jit::float_scalar T, bool SAFE_MATH=false>
1730 using input_nodes = std::vector<shared_variable<T, SAFE_MATH>>;
1732 template<jit::float_scalar T, bool SAFE_MATH=false>
1733 using map_nodes = std::vector<std::pair<shared_leaf<T, SAFE_MATH>,
1735
1736//------------------------------------------------------------------------------
1744//------------------------------------------------------------------------------
1745 template<jit::float_scalar T, bool SAFE_MATH=false>
1747 return std::dynamic_pointer_cast<variable_node<T, SAFE_MATH>> (x);
1748 }
1749
1750//******************************************************************************
1751// Pseudo variable node.
1752//******************************************************************************
1753//------------------------------------------------------------------------------
1762//------------------------------------------------------------------------------
1763 template<jit::float_scalar T, bool SAFE_MATH=false>
1764 class pseudo_variable_node final : public straight_node<T, SAFE_MATH> {
1765 private:
1766//------------------------------------------------------------------------------
1771//------------------------------------------------------------------------------
1772 static std::string to_string(leaf_node<T, SAFE_MATH> *p) {
1773 return jit::format_to_string(reinterpret_cast<size_t> (p));
1774 }
1775
1776 public:
1777//------------------------------------------------------------------------------
1781//------------------------------------------------------------------------------
1784
1785//------------------------------------------------------------------------------
1791//------------------------------------------------------------------------------
1793 return this->shared_from_this();
1794 }
1795
1796//------------------------------------------------------------------------------
1801//------------------------------------------------------------------------------
1803 return constant<T, SAFE_MATH> (static_cast<T> (this->is_match(x)));
1804 }
1805
1806//------------------------------------------------------------------------------
1808//------------------------------------------------------------------------------
1809 virtual void to_latex() const {
1810 std::cout << "\\left(";
1811 this->arg->to_latex();
1812 std::cout << "\\right)";
1813 }
1814
1815//------------------------------------------------------------------------------
1819//------------------------------------------------------------------------------
1820 virtual bool is_all_variables() const {
1821 return true;
1822 }
1823
1824//------------------------------------------------------------------------------
1828//------------------------------------------------------------------------------
1829 virtual bool is_power_like() const {
1830 return true;
1831 }
1832
1833//------------------------------------------------------------------------------
1837//------------------------------------------------------------------------------
1839 return this->arg->get_power_base();
1840 }
1841
1842//------------------------------------------------------------------------------
1846//------------------------------------------------------------------------------
1848 return this->arg->get_power_exponent();
1849 }
1850
1851//------------------------------------------------------------------------------
1855//------------------------------------------------------------------------------
1856 virtual bool has_pseudo() const {
1857 return true;
1858 }
1859
1860//------------------------------------------------------------------------------
1864//------------------------------------------------------------------------------
1866 return this->arg;
1867 }
1868
1869//------------------------------------------------------------------------------
1875//------------------------------------------------------------------------------
1876 virtual shared_leaf<T, SAFE_MATH> to_vizgraph(std::stringstream &stream,
1877 jit::register_map &registers) {
1878 if (registers.find(this) == registers.end()) {
1879 const std::string name = jit::to_string('r', this);
1880 registers[this] = name;
1881 stream << " " << name
1882 << " [label = \"pseudo_variable\", shape = oval, style = filled, fillcolor = blue, fontcolor = white];" << std::endl;
1883
1884 auto a = this->arg->to_vizgraph(stream, registers);
1885 stream << " " << name << " -- " << registers[a.get()] << ";" << std::endl;
1886 }
1887
1888 return this->shared_from_this();
1889 }
1890 };
1891
1892//------------------------------------------------------------------------------
1900//------------------------------------------------------------------------------
1901 template<jit::float_scalar T, bool SAFE_MATH=false>
1903 return std::make_shared<pseudo_variable_node<T, SAFE_MATH>> (x);
1904 }
1905
1907 template<jit::float_scalar T, bool SAFE_MATH=false>
1908 using shared_pseudo_variable = std::shared_ptr<pseudo_variable_node<T, SAFE_MATH>>;
1909
1910//------------------------------------------------------------------------------
1918//------------------------------------------------------------------------------
1919 template<jit::float_scalar T, bool SAFE_MATH=false>
1921 return std::dynamic_pointer_cast<pseudo_variable_node<T, SAFE_MATH>> (x);
1922 }
1923}
1924
1925#endif /* node_h */
Class signature to implement compute backends.
Class representing a generic buffer.
Definition backend.hpp:29
bool is_normal() const
Check for normal values.
Definition backend.hpp:279
const T at(const size_t index) const
Get value at.
Definition backend.hpp:91
bool has_zero() const
Is any element zero.
Definition backend.hpp:156
size_t size() const
Get size of the buffer.
Definition backend.hpp:116
void set(const T d)
Assign a constant value.
Definition backend.hpp:100
T * data()
Get a pointer to the basic memory buffer.
Definition backend.hpp:270
Class representing a branch node.
Definition node.hpp:1173
shared_leaf< T, SAFE_MATH > get_left()
Get the left branch.
Definition node.hpp:1252
virtual bool is_all_variables() const
Test if node acts like a variable.
Definition node.hpp:1268
virtual void compile_preamble(std::ostringstream &stream, jit::register_map &registers, jit::visiter_map &visited, jit::register_usage &usage, jit::texture1d_list &textures1d, jit::texture2d_list &textures2d, int &avail_const_mem)
Compile preamble.
Definition node.hpp:1224
branch_node(shared_leaf< T, SAFE_MATH > l, shared_leaf< T, SAFE_MATH > r, const std::string s, const size_t count, const bool pseudo)
Assigns the left and right branches.
Definition node.hpp:1205
shared_leaf< T, SAFE_MATH > get_right()
Get the right branch.
Definition node.hpp:1259
virtual std::shared_ptr< leaf_node< T, SAFE_MATH > > get_power_exponent() const
Get the exponent of a power.
Definition node.hpp:1279
branch_node(shared_leaf< T, SAFE_MATH > l, shared_leaf< T, SAFE_MATH > r, const std::string s)
Assigns the left and right branches.
Definition node.hpp:1189
shared_leaf< T, SAFE_MATH > right
Right branch of the tree.
Definition node.hpp:1178
shared_leaf< T, SAFE_MATH > left
Left branch of the tree.
Definition node.hpp:1176
Class representing data that cannot change.
Definition node.hpp:726
virtual shared_leaf< T, SAFE_MATH > get_power_exponent() const
Get the exponent of a power.
Definition node.hpp:936
virtual void to_latex() const
Convert the node to latex.
Definition node.hpp:862
virtual shared_leaf< T, SAFE_MATH > compile(std::ostringstream &stream, jit::register_map &registers, jit::register_map &indices, const jit::register_usage &usage)
Compile the node.
Definition node.hpp:792
virtual shared_leaf< T, SAFE_MATH > get_power_base()
Get the base of a power.
Definition node.hpp:927
bool is(const T d)
Check if the constant is value.
Definition node.hpp:846
virtual backend::buffer< T > evaluate()
Evaluate method.
Definition node.hpp:757
virtual bool is_all_variables() const
Test if node acts like a variable.
Definition node.hpp:909
virtual bool is_match(shared_leaf< T, SAFE_MATH > x)
Query if the nodes match.
Definition node.hpp:830
constant_node(const backend::buffer< T > &d)
Construct a constant node from a backend buffer.
Definition node.hpp:747
virtual shared_leaf< T, SAFE_MATH > to_vizgraph(std::stringstream &stream, jit::register_map &registers)
Convert the node to vizgraph.
Definition node.hpp:873
virtual shared_leaf< T, SAFE_MATH > df(shared_leaf< T, SAFE_MATH > x)
Transform node to derivative.
Definition node.hpp:778
virtual bool is_constant() const
Test if node is a constant.
Definition node.hpp:891
virtual bool has_constant_zero() const
Test the constant node has a zero.
Definition node.hpp:900
virtual shared_leaf< T, SAFE_MATH > reduce()
Reduction method.
Definition node.hpp:768
bool is_integer()
Check if the value is an integer.
Definition node.hpp:853
virtual bool is_power_like() const
Test if the node acts like a power of variable.
Definition node.hpp:918
Class representing a node leaf.
Definition node.hpp:364
virtual ~leaf_node()
Destructor.
Definition node.hpp:392
virtual void endline(std::ostringstream &stream, const jit::register_usage &usage) const final
End a line in the kernel source.
Definition node.hpp:637
virtual void set(const std::vector< T > &d)
Set the value of variable data.
Definition node.hpp:505
virtual bool is_power_like() const
Test if the node acts like a power of variable.
Definition node.hpp:570
virtual bool is_match(std::shared_ptr< leaf_node< T, SAFE_MATH > > x)
Query if the nodes match.
Definition node.hpp:470
virtual std::shared_ptr< leaf_node< T, SAFE_MATH > > compile(std::ostringstream &stream, jit::register_map &registers, jit::register_map &indices, const jit::register_usage &usage)=0
Compile the node.
virtual std::shared_ptr< leaf_node< T, SAFE_MATH > > get_power_exponent() const =0
Get the exponent of a power.
virtual void to_latex() const =0
Convert the node to latex.
T base
Type def to retrieve the backend type.
Definition node.hpp:668
virtual std::shared_ptr< leaf_node< T, SAFE_MATH > > df(std::shared_ptr< leaf_node< T, SAFE_MATH > > x)=0
Transform node to derivative.
const bool contains_pseudo
Node contains pseudo variables.
Definition node.hpp:373
virtual std::shared_ptr< leaf_node< T, SAFE_MATH > > remove_pseudo()
Remove pseudo variable nodes.
Definition node.hpp:627
virtual void set(const backend::buffer< T > &d)
Set the value of variable data.
Definition node.hpp:512
static thread_local caches_t caches
A per thread instance of the cache structure.
Definition node.hpp:665
virtual bool is_constant() const
Test if node is a constant.
Definition node.hpp:534
virtual void set(const size_t index, const T d)
Set the value of variable data.
Definition node.hpp:497
virtual std::shared_ptr< leaf_node< T, SAFE_MATH > > to_vizgraph(std::stringstream &stream, jit::register_map &registers)=0
Convert the node to vizgraph.
virtual std::shared_ptr< leaf_node< T, SAFE_MATH > > get_power_base()
Get the base of a power.
Definition node.hpp:581
leaf_node(const std::string s, const size_t count, const bool pseudo)
Construct a basic node.
Definition node.hpp:383
virtual bool has_constant_zero() const
Test the constant node has a zero.
Definition node.hpp:543
virtual void set(const T d)
Set the value of variable data.
Definition node.hpp:489
const size_t complexity
Graph complexity.
Definition node.hpp:369
virtual backend::buffer< T > evaluate()=0
Evaluate method.
std::map< size_t, std::shared_ptr< leaf_node< T, SAFE_MATH > > > df_cache
Cache derivative terms.
Definition node.hpp:371
bool is_power_base_match(std::shared_ptr< leaf_node< T, SAFE_MATH > > x)
Check if the base of the powers match.
Definition node.hpp:480
virtual bool is_all_variables() const =0
Test if all the subnodes terminate in variables.
virtual std::shared_ptr< leaf_node > reduce()=0
Reduction method.
size_t get_complexity() const
Get the number of nodes in the subgraph.
Definition node.hpp:609
virtual bool has_pseudo() const
Query if the node contains pseudo variables.
Definition node.hpp:618
const size_t hash
Hash for node.
Definition node.hpp:367
virtual void compile_preamble(std::ostringstream &stream, jit::register_map &registers, jit::visiter_map &visited, jit::register_usage &usage, jit::texture1d_list &textures1d, jit::texture2d_list &textures2d, int &avail_const_mem)
Compile preamble.
Definition node.hpp:433
bool is_normal()
Test if the result is normal.
Definition node.hpp:552
size_t get_hash() const
Get the hash for the node.
Definition node.hpp:600
Class representing a subexpression that acts like a variable.
Definition node.hpp:1764
virtual void to_latex() const
Convert the node to latex.
Definition node.hpp:1809
virtual shared_leaf< T, SAFE_MATH > df(shared_leaf< T, SAFE_MATH > x)
Transform node to derivative.
Definition node.hpp:1802
virtual bool is_power_like() const
Test if the node acts like a power of variable.
Definition node.hpp:1829
virtual shared_leaf< T, SAFE_MATH > to_vizgraph(std::stringstream &stream, jit::register_map &registers)
Convert the node to vizgraph.
Definition node.hpp:1876
virtual shared_leaf< T, SAFE_MATH > get_power_exponent() const
Get the exponent of a power.
Definition node.hpp:1847
virtual bool is_all_variables() const
Test if node acts like a variable.
Definition node.hpp:1820
virtual shared_leaf< T, SAFE_MATH > remove_pseudo()
Remove pseudo variable nodes.
Definition node.hpp:1865
virtual shared_leaf< T, SAFE_MATH > get_power_base()
Get the base of a power.
Definition node.hpp:1838
virtual bool has_pseudo() const
Query if the node contains pseudo variables.
Definition node.hpp:1856
virtual shared_leaf< T, SAFE_MATH > reduce()
Reduction method.
Definition node.hpp:1792
pseudo_variable_node(shared_leaf< T, SAFE_MATH > a)
Construct a pseudo variable node.
Definition node.hpp:1782
Class representing a straight node.
Definition node.hpp:1059
virtual shared_leaf< T, SAFE_MATH > get_power_exponent() const
Get the exponent of a power.
Definition node.hpp:1155
shared_leaf< T, SAFE_MATH > arg
Argument.
Definition node.hpp:1062
virtual backend::buffer< T > evaluate()
Evaluate method.
Definition node.hpp:1081
shared_leaf< T, SAFE_MATH > get_arg()
Get the argument.
Definition node.hpp:1137
virtual void compile_preamble(std::ostringstream &stream, jit::register_map &registers, jit::visiter_map &visited, jit::register_usage &usage, jit::texture1d_list &textures1d, jit::texture2d_list &textures2d, int &avail_const_mem)
Compile preamble.
Definition node.hpp:1096
straight_node(shared_leaf< T, SAFE_MATH > a, const std::string s)
Construct a straight node.
Definition node.hpp:1071
virtual bool is_all_variables() const
Test if node acts like a variable.
Definition node.hpp:1146
virtual shared_leaf< T, SAFE_MATH > compile(std::ostringstream &stream, jit::register_map &registers, jit::register_map &indices, const jit::register_usage &usage)
Compile the node.
Definition node.hpp:1127
Class representing a triple branch node.
Definition node.hpp:1297
virtual void compile_preamble(std::ostringstream &stream, jit::register_map &registers, jit::visiter_map &visited, jit::register_usage &usage, jit::texture1d_list &textures1d, jit::texture2d_list &textures2d, int &avail_const_mem)
Compile preamble.
Definition node.hpp:1336
shared_leaf< T, SAFE_MATH > get_middle()
Get the middle branch.
Definition node.hpp:1368
virtual bool is_all_variables() const
Test if node acts like a variable.
Definition node.hpp:1377
triple_node(shared_leaf< T, SAFE_MATH > l, shared_leaf< T, SAFE_MATH > m, shared_leaf< T, SAFE_MATH > r, const std::string s)
Reduces and assigns the left, middle, and right branches.
Definition node.hpp:1312
shared_leaf< T, SAFE_MATH > middle
Middle branch of the tree.
Definition node.hpp:1300
Class representing data that can change.
Definition node.hpp:1394
virtual shared_leaf< T, SAFE_MATH > get_power_base()
Get the base of a power.
Definition node.hpp:1650
variable_node(const backend::buffer< T > &d, const std::string &symbol)
Construct a variable node from backend buffer.
Definition node.hpp:1456
virtual bool is_all_variables() const
Test if node acts like a variable.
Definition node.hpp:1632
virtual shared_leaf< T, SAFE_MATH > df(shared_leaf< T, SAFE_MATH > x)
Transform node to derivative.
Definition node.hpp:1489
variable_node(const size_t s, const T d, const std::string &symbol)
Construct a variable node from a scalar.
Definition node.hpp:1430
virtual void to_latex() const
Convert the node to latex.
Definition node.hpp:1587
virtual shared_leaf< T, SAFE_MATH > reduce()
Reduction method.
Definition node.hpp:1479
size_t size()
Get the size of the variable buffer.
Definition node.hpp:1614
virtual shared_leaf< T, SAFE_MATH > compile(std::ostringstream &stream, jit::register_map &registers, jit::register_map &indices, const jit::register_usage &usage)
Compile the node.
Definition node.hpp:1533
virtual shared_leaf< T, SAFE_MATH > to_vizgraph(std::stringstream &stream, jit::register_map &registers)
Convert the node to vizgraph.
Definition node.hpp:1598
virtual void compile_preamble(std::ostringstream &stream, jit::register_map &registers, jit::visiter_map &visited, jit::register_usage &usage, jit::texture1d_list &textures1d, jit::texture2d_list &textures2d, int &avail_const_mem)
Compile preamble.
Definition node.hpp:1507
virtual backend::buffer< T > evaluate()
Evaluate method.
Definition node.hpp:1468
T * data()
Get a pointer to raw buffer.
Definition node.hpp:1623
virtual void set(const T d)
Set the value of variable data.
Definition node.hpp:1545
virtual bool is_power_like() const
Test if the node acts like a power of variable.
Definition node.hpp:1641
virtual void set(const backend::buffer< T > &d)
Set the value of variable data.
Definition node.hpp:1573
virtual void set(const size_t index, const T d)
Set the value of variable data.
Definition node.hpp:1555
variable_node(const size_t s, const std::string &symbol)
Construct a variable node with a size.
Definition node.hpp:1418
virtual void set(const std::vector< T > &d)
Set the value of variable data.
Definition node.hpp:1564
variable_node(const std::vector< T > &d, const std::string &symbol)
Construct a variable node from a vector.
Definition node.hpp:1443
std::string get_symbol() const
Get Symbol.
Definition node.hpp:1580
virtual shared_leaf< T, SAFE_MATH > get_power_exponent() const
Get the exponent of a power.
Definition node.hpp:1659
Complex scalar concept.
Definition register.hpp:24
subroutine assert(test, message)
Assert check.
Definition f_binding_test.f90:38
Name space for graph nodes.
Definition arithmetic.hpp:13
constexpr shared_leaf< T, SAFE_MATH > zero()
Forward declare for zero.
Definition node.hpp:994
std::shared_ptr< variable_node< T, SAFE_MATH > > shared_variable
Convenience type alias for shared variable nodes.
Definition node.hpp:1727
constexpr shared_leaf< T, SAFE_MATH > none()
Create a one constant.
Definition node.hpp:1020
void make_vizgraph(shared_leaf< T, SAFE_MATH > node)
Build the vizgraph input.
Definition node.hpp:703
constexpr shared_leaf< T, SAFE_MATH > one()
Forward declare for one.
Definition node.hpp:1007
std::vector< shared_variable< T, SAFE_MATH > > input_nodes
Convenience type alias for a vector of inputs.
Definition node.hpp:1730
shared_constant< T, SAFE_MATH > constant_cast(shared_leaf< T, SAFE_MATH > x)
Cast to a constant node.
Definition node.hpp:1042
shared_variable< T, SAFE_MATH > variable_cast(shared_leaf< T, SAFE_MATH > x)
Cast to a variable node.
Definition node.hpp:1746
constexpr T i
Convenience constant for the imaginary unit.
Definition node.hpp:1026
shared_leaf< T, SAFE_MATH > constant(const backend::buffer< T > &d)
Construct a constant.
Definition node.hpp:951
shared_leaf< T, SAFE_MATH > variable(const size_t s, const std::string &symbol)
Construct a variable.
Definition node.hpp:1674
std::shared_ptr< pseudo_variable_node< T, SAFE_MATH > > shared_pseudo_variable
Convenience type alias for shared pseudo variable nodes.
Definition node.hpp:1908
std::shared_ptr< leaf_node< T, SAFE_MATH > > shared_leaf
Convenience type alias for shared leaf nodes.
Definition node.hpp:673
std::vector< std::pair< shared_leaf< T, SAFE_MATH >, shared_variable< T, SAFE_MATH > > > map_nodes
Convenience type alias for mapping end codes back to inputs.
Definition node.hpp:1734
constexpr shared_leaf< T, SAFE_MATH > null_leaf()
Create a null leaf.
Definition node.hpp:683
std::shared_ptr< constant_node< T, SAFE_MATH > > shared_constant
Convenience type alias for shared constant nodes.
Definition node.hpp:1030
std::vector< shared_leaf< T, SAFE_MATH > > output_nodes
Convenience type alias for a vector of output nodes.
Definition node.hpp:688
shared_leaf< T, SAFE_MATH > pseudo_variable(shared_leaf< T, SAFE_MATH > x)
Define pseudo variable convenience function.
Definition node.hpp:1902
shared_pseudo_variable< T, SAFE_MATH > pseudo_variable_cast(shared_leaf< T, SAFE_MATH > &x)
Cast to a pseudo variable node.
Definition node.hpp:1920
std::map< void *, size_t > texture1d_list
Type alias for indexing 1D textures.
Definition register.hpp:262
std::string format_to_string(const T value)
Convert a value to a string while avoiding locale.
Definition register.hpp:211
std::map< void *, std::array< size_t, 2 > > texture2d_list
Type alias for indexing 2D textures.
Definition register.hpp:264
std::map< void *, size_t > register_usage
Type alias for counting register usage.
Definition register.hpp:258
std::map< void *, std::string > register_map
Type alias for mapping node pointers to register names.
Definition register.hpp:256
std::set< void * > visiter_map
Type alias for listing visited nodes.
Definition register.hpp:260
std::string to_string(const char prefix, const NODE *pointer)
Convert a graph::leaf_node pointer to a string.
Definition register.hpp:245
Data structure to contain the two caches.
Definition node.hpp:657
std::map< size_t, std::shared_ptr< leaf_node< T, SAFE_MATH > > > nodes
Cache of node.
Definition node.hpp:659
std::map< size_t, backend::buffer< T > > backends
Cache of backend buffers.
Definition node.hpp:661