Graph Framework
Loading...
Searching...
No Matches
piecewise.hpp
Go to the documentation of this file.
1//------------------------------------------------------------------------------
6//------------------------------------------------------------------------------
7
8#ifndef piecewise_h
9#define piecewise_h
10
11#include "node.hpp"
12
13namespace graph {
14//------------------------------------------------------------------------------
24//------------------------------------------------------------------------------
// Emit a clamped integer index expression for a piecewise lookup table.
//
// Writes to stream an expression of the form
//   min(max((type)((reg - offset)/scale),(type)0),(type)(length - 1))
// where type = jit::smallest_int_type<T>(length). Coordinates outside the
// table therefore select the first or last entry. For complex T the quotient
// is wrapped in real(...), so only its real part selects the index.
//
// stream        - Kernel source being assembled.
// register_name - Name of the register holding the coordinate value.
// length        - Number of entries along the indexed axis.
// scale         - Divisor applied after subtracting offset.
// offset        - Value subtracted from the coordinate before scaling.
//
// NOTE(review): the constant-folding paths in piecewise_1D_node::reduce() and
// piecewise_2D_node::reduce() compute (x + offset)/scale, while this emits
// (x - offset)/scale. One of the two sign conventions is likely wrong.
// NOTE(review): the cast to `type` is applied before the max() clamp. If
// smallest_int_type yields an unsigned type, a negative quotient is converted
// before it is clamped (undefined in C-family kernel languages) - confirm.
25template<jit::float_scalar T>
26void compile_index(std::ostringstream &stream,
27 const std::string &register_name,
28 const size_t length,
29 const T scale,
30 const T offset) {
31 const std::string type = jit::smallest_int_type<T> (length);
32 stream << "min(max(("
33 << type
34 << ")";
35 if constexpr (jit::complex_scalar<T>) {
36 stream << "real(";
37 }
38 stream << "((" << register_name << " - ";
39 if constexpr (jit::complex_scalar<T>) {
41 }
42 stream << offset << ")/";
43 if constexpr (jit::complex_scalar<T>) {
45 }
46 stream << scale << ")";
47 if constexpr (jit::complex_scalar<T>) {
// Closes the real( opened above.
48 stream << ")";
49 }
// Lower clamp to 0 and upper clamp to the last valid index.
50 stream << ",(" << type << ")0),("
51 << type << ")" << length - 1 << ")";
52}
53
54//******************************************************************************
55// 1D Piecewise node.
56//******************************************************************************
57//------------------------------------------------------------------------------
90//------------------------------------------------------------------------------
// 1D piecewise-constant lookup node.
//
// The data buffer lives in leaf_node<T, SAFE_MATH>::caches.backends under the
// key data_hash. The argument value is mapped to a clamped integer index
// (see compile_index), and the node evaluates to the table entry at that index.
91 template<jit::float_scalar T, bool SAFE_MATH=false>
92 class piecewise_1D_node final : public straight_node<T, SAFE_MATH> {
93 private:
// Divisor applied when mapping the argument to a table index.
95 const T scale;
// Offset applied to the argument before dividing by scale.
97 const T offset;
98
99//------------------------------------------------------------------------------
104//------------------------------------------------------------------------------
// Build a string by iterating over every element of d. It feeds both the
// buffer hash (hash_data) and the node hash string (to_string(d, x)).
105 static std::string to_string(const backend::buffer<T> &d) {
106 std::string temp;
107 for (size_t i = 0, ie = d.size(); i < ie; i++) {
109 }
110
111 return temp;
112 }
113
114//------------------------------------------------------------------------------
120//------------------------------------------------------------------------------
// Node hash string: the data string followed by the argument's hash.
// NOTE(review): scale and offset are not part of this string.
121 static std::string to_string(const backend::buffer<T> &d,
123 return piecewise_1D_node::to_string(d) +
124 jit::format_to_string(x->get_hash());
125 }
126
127//------------------------------------------------------------------------------
132//------------------------------------------------------------------------------
// Return the key in caches.backends for buffer d. Starting at the string
// hash of d, probe successive keys until one is unused or already holds a
// buffer equal to d.
133 static size_t hash_data(const backend::buffer<T> &d) {
134 const size_t h = std::hash<std::string>{} (piecewise_1D_node::to_string(d));
135 for (size_t i = h; i < std::numeric_limits<size_t>::max(); i++) {
136 if (leaf_node<T, SAFE_MATH>::caches.backends.find(i) ==
137 leaf_node<T, SAFE_MATH>::caches.backends.end()) {
139 return i;
140 } else if (d == leaf_node<T, SAFE_MATH>::caches.backends[i]) {
141 return i;
142 }
143 }
144#if defined(__clang__) || defined(__GNUC__)
146#else
147 assert(false && "Should never reach.");
148#endif
149 }
150
// Key of this node's buffer in caches.backends. It is declared after scale
// and offset, so it is initialized after them regardless of the order in the
// constructor's initializer list.
152 const size_t data_hash;
153
154 public:
155//------------------------------------------------------------------------------
162//------------------------------------------------------------------------------
// Constructor: the base node receives to_string(d, x), and the buffer key is
// resolved through hash_data(d).
165 const T scale,
166 const T offset) :
167 straight_node<T, SAFE_MATH> (x, piecewise_1D_node::to_string(d, x)),
168 data_hash(piecewise_1D_node::hash_data(d)), scale(scale),
169 offset(offset) {}
170
171//------------------------------------------------------------------------------
179//------------------------------------------------------------------------------
// Returns the node's data buffer from the backend cache.
181 return leaf_node<T, SAFE_MATH>::caches.backends[data_hash];
182 }
183
184//------------------------------------------------------------------------------
191//------------------------------------------------------------------------------
// Simplify this node:
// - Constant argument: fold to the single table entry it selects.
//   NOTE(review): this computes (x + offset)/scale, while compile_index emits
//   (x - offset)/scale. One of the sign conventions is likely wrong.
//   NOTE(review): only the upper end is clamped. A real part <= -1 makes the
//   static_cast<size_t> undefined, unlike the min/max clamp in compile_index.
// - Buffer reports is_same() (presumably every entry equal): fold to a
//   constant of its first entry.
// - Otherwise return this node unchanged.
193 if (constant_cast(this->arg).get()) {
194 const T arg = (this->arg->evaluate().at(0) + offset)/scale;
195 const size_t i = std::min(static_cast<size_t> (std::real(arg)),
196 this->get_size() - 1);
197 return constant<T, SAFE_MATH> (leaf_node<T, SAFE_MATH>::caches.backends[data_hash][i]);
198 }
199
200 if (evaluate().is_same()) {
201 return constant<T, SAFE_MATH> (evaluate().at(0));
202 }
203 return this->shared_from_this();
204 }
205
206//------------------------------------------------------------------------------
211//------------------------------------------------------------------------------
// Returns constant 1 when x matches this node and 0 otherwise. Presumably
// this is the derivative, treating the table as piecewise constant - confirm.
213 return constant<T, SAFE_MATH> (static_cast<T> (this->is_match(x)));
214 }
215
216//------------------------------------------------------------------------------
226//------------------------------------------------------------------------------
// Emit declarations needed before the kernel body (once per node).
// On first visit: recurse into the argument. If the data buffer has no
// register yet, name it via jit::to_string('a', ...) and then either:
// - record a 1D texture of the buffer length (Metal, or CUDA when
//   USE_CUDA_TEXTURES is defined), or
// - write the buffer out as a const array literal, prefixed with
//   __constant__ on CUDA when avail_const_mem leaves room.
// If the buffer already has a register, only the texture entry is re-added.
227 virtual void compile_preamble(std::ostringstream &stream,
228 jit::register_map &registers,
233 int &avail_const_mem) {
234 if (visited.find(this) == visited.end()) {
235 this->arg->compile_preamble(stream, registers,
236 visited, usage,
239 if (registers.find(leaf_node<T, SAFE_MATH>::caches.backends[data_hash].data()) == registers.end()) {
240 registers[leaf_node<T, SAFE_MATH>::caches.backends[data_hash].data()] =
241 jit::to_string('a', leaf_node<T, SAFE_MATH>::caches.backends[data_hash].data());
242 const size_t length = leaf_node<T, SAFE_MATH>::caches.backends[data_hash].size();
243 if constexpr (jit::use_metal<T> ()) {
244 textures1d.try_emplace(leaf_node<T, SAFE_MATH>::caches.backends[data_hash].data(),
245 length);
246#ifdef USE_CUDA_TEXTURES
247 } else if constexpr (jit::use_cuda()) {
248 textures1d.try_emplace(leaf_node<T, SAFE_MATH>::caches.backends[data_hash].data(),
249 length);
250#endif
251 } else {
252 if constexpr (jit::use_cuda()) {
// NOTE(review): size_t -> int narrowing; a very large table would overflow.
253 const int buffer_size = length*sizeof(T);
254 if (avail_const_mem - buffer_size > 0) {
256 stream << "__constant__ ";
257 }
258 }
// Array literal of every buffer entry. Complex entries are each prefixed
// with the type name written by jit::add_type.
259 stream << "const ";
260 jit::add_type<T> (stream);
261 stream << " " << registers[leaf_node<T, SAFE_MATH>::caches.backends[data_hash].data()] << "[] = {";
262 if constexpr (jit::complex_scalar<T>) {
263 jit::add_type<T> (stream);
264 }
265 stream << leaf_node<T, SAFE_MATH>::caches.backends[data_hash][0];
266 for (size_t i = 1; i < length; i++) {
267 stream << ", ";
268 if constexpr (jit::complex_scalar<T>) {
269 jit::add_type<T> (stream);
270 }
271 stream << leaf_node<T, SAFE_MATH>::caches.backends[data_hash][i];
272 }
273 stream << "};" << std::endl;
274 }
275 } else {
276// When using textures, the register can be defined in a previous kernel. We
277// need to add the textures again.
278 const size_t length = leaf_node<T, SAFE_MATH>::caches.backends[data_hash].size();
279 if constexpr (jit::use_metal<T> ()) {
280 textures1d.try_emplace(leaf_node<T, SAFE_MATH>::caches.backends[data_hash].data(),
281 length);
282#ifdef USE_CUDA_TEXTURES
283 } else if constexpr (jit::use_cuda()) {
284 textures1d.try_emplace(leaf_node<T, SAFE_MATH>::caches.backends[data_hash].data(),
285 length);
286#endif
287 }
288 }
289 visited.insert(this);
290#ifdef SHOW_USE_COUNT
291 usage[this] = 1;
292 } else {
293 ++usage[this];
294#endif
295 }
296 }
297
298//------------------------------------------------------------------------------
318//------------------------------------------------------------------------------
// Emit `const T r... = <table lookup>;` for this node if it has no register.
// The index comes from compile_index applied to the argument's register. With
// USE_INDEX_CACHE it is emitted once as its own variable and reused.
// Lookup forms:
// - Metal: .read(idx).r
// - CUDA with USE_CUDA_TEXTURES: tex1D<...>(tex, idx), plus conversion
//   wrappers for double/complex
// - otherwise: array subscript [idx]
// NOTE(review): the index cache is keyed only by the argument node, but the
// index expression also depends on length, scale and offset. Two piecewise
// nodes sharing an argument but differing in grid would reuse the first
// node's index - confirm this cannot happen.
// NOTE(review): the USE_CUDA_TEXTURES branch uses unqualified float_base /
// complex_scalar, while the rest of the file uses jit::complex_scalar; verify
// that these names resolve.
320 compile(std::ostringstream &stream,
321 jit::register_map &registers,
323 const jit::register_usage &usage) {
324 if (registers.find(this) == registers.end()) {
325#ifdef USE_INDEX_CACHE
326 if (indices.find(this->arg.get()) == indices.end()) {
327#endif
328 const size_t length = leaf_node<T, SAFE_MATH>::caches.backends[data_hash].size();
329 shared_leaf<T, SAFE_MATH> a = this->arg->compile(stream,
330 registers,
331 indices,
332 usage);
333#ifdef USE_INDEX_CACHE
334 indices[a.get()] = jit::to_string('i', a.get());
335 stream << " const "
336 << jit::smallest_int_type<T> (length) << " "
337 << indices[a.get()] << " = ";
338 compile_index<T> (stream, registers[a.get()], length,
339 scale, offset);
340 a->endline(stream, usage);
341 }
342#endif
343
344 registers[this] = jit::to_string('r', this);
345 stream << " const ";
346 jit::add_type<T> (stream);
347 stream << " " << registers[this] << " = ";
348#ifdef USE_CUDA_TEXTURES
349 if constexpr (jit::use_cuda()) {
350 if constexpr (float_base<T>) {
351 if constexpr (complex_scalar<T>) {
352 stream << "to_cmp_float(tex1D<float2> (";
353 } else {
354 stream << "tex1D<float> (";
355 }
356 } else {
357 if constexpr (complex_scalar<T>) {
358 stream << "to_cmp_double(tex1D<uint4> (";
359 } else {
360 stream << "to_double(tex1D<uint2> (";
361 }
362 }
363 }
364#endif
365 stream << registers[leaf_node<T, SAFE_MATH>::caches.backends[data_hash].data()];
366 if constexpr (jit::use_metal<T> ()) {
367#ifdef USE_INDEX_CACHE
368 stream << ".read("
369 << indices[this->arg.get()]
370 << ").r";
371#else
372 stream << ".read(";
373 compile_index<T> (stream, registers[a.get()], length,
374 scale, offset);
375 stream << ").r";
376#endif
377#ifdef USE_CUDA_TEXTURES
378 } else if constexpr (jit::use_cuda()) {
379#ifdef USE_INDEX_CACHE
380 stream << ", "
381 << indices[this->arg.get()];
382#else
383 stream << ", ";
384 compile_index<T> (stream, registers[a.get()], length,
385 scale, offset);
386#endif
388 stream << ")";
389 }
390 stream << ")";
391#endif
392 } else {
393#ifdef USE_INDEX_CACHE
394 stream << "["
395 << indices[this->arg.get()]
396 << "]";
397#else
398 stream << "[";
399 compile_index<T> (stream, registers[a.get()], length,
400 scale, offset);
401 stream << "]";
402#endif
403 }
404 this->endline(stream, usage);
405 }
406
407 return this->shared_from_this();
408 }
409
410//------------------------------------------------------------------------------
418//------------------------------------------------------------------------------
// True when x is a piecewise_1D_node with the same data buffer key and a
// matching argument.
// NOTE(review): scale and offset are not compared, and the node hash string
// omits them too. Nodes that differ only in scale/offset therefore match, and
// the piecewise_1D factory's cache lookup may hand back a node with the wrong
// scale/offset.
420 auto x_cast = piecewise_1D_cast(x);
421
422 if (x_cast.get()) {
423 return this->data_hash == x_cast->data_hash &&
424 this->arg->is_match(x_cast->get_arg());
425 }
426
427 return false;
428 }
429
430//------------------------------------------------------------------------------
432//------------------------------------------------------------------------------
// Print the LaTeX symbol r\_<address>_{i} for this node.
433 virtual void to_latex() const {
434 std::cout << "r\\_" << reinterpret_cast<size_t> (this) << "_{i}";
435 }
436
437//------------------------------------------------------------------------------
443//------------------------------------------------------------------------------
// Emit a graphviz node (hexagon labelled r_<address>_{i}) plus an edge to the
// argument. Only done the first time this node is seen in registers.
444 virtual shared_leaf<T, SAFE_MATH> to_vizgraph(std::stringstream &stream,
445 jit::register_map &registers) {
446 if (registers.find(this) == registers.end()) {
447 const std::string name = jit::to_string('r', this);
448 registers[this] = name;
449 stream << " " << name
450 << " [label = \"r_" << reinterpret_cast<size_t> (this)
451 << "_{i}\", shape = hexagon, style = filled, fillcolor = black, fontcolor = white];" << std::endl;
452
453 auto a = this->arg->to_vizgraph(stream, registers);
454 stream << " " << name << " -- " << registers[a.get()] << ";" << std::endl;
455 }
456
457 return this->shared_from_this();
458 }
459
460//------------------------------------------------------------------------------
464//------------------------------------------------------------------------------
// Always reports this node as constant.
465 virtual bool is_constant() const {
466 return true;
467 }
468
469//------------------------------------------------------------------------------
473//------------------------------------------------------------------------------
// True when the data buffer contains a zero (delegates to the buffer's has_zero()).
474 virtual bool has_constant_zero() const {
475 return leaf_node<T, SAFE_MATH>::caches.backends[data_hash].has_zero();
476 }
477
478//------------------------------------------------------------------------------
482//------------------------------------------------------------------------------
// Always false.
483 virtual bool is_all_variables() const {
484 return false;
485 }
486
487//------------------------------------------------------------------------------
491//------------------------------------------------------------------------------
// Always true.
492 virtual bool is_power_like() const {
493 return true;
494 }
495
496//------------------------------------------------------------------------------
500//------------------------------------------------------------------------------
// Returns this node itself (presumably the power base, for x^1) - confirm.
502 return this->shared_from_this();
503 }
504
505//------------------------------------------------------------------------------
509//------------------------------------------------------------------------------
// Returns the constant one (presumably the power exponent) - confirm.
511 return one<T, SAFE_MATH> ();
512 }
513
514//------------------------------------------------------------------------------
519//------------------------------------------------------------------------------
// True when x is a piecewise_1D_node with a matching argument and the same
// table size, scale and offset. The table contents are not compared.
521 auto temp = piecewise_1D_cast(x);
522 return temp.get() &&
523 this->arg->is_match(temp->get_arg()) &&
524 (temp->get_size() == this->get_size()) &&
525 (temp->get_scale() == this->scale) &&
526 (temp->get_offset() == this->offset);
527 }
528
529//------------------------------------------------------------------------------
533//------------------------------------------------------------------------------
// Scale used to map the argument to an index.
534 T get_scale() const {
535 return scale;
536 }
537
538//------------------------------------------------------------------------------
542//------------------------------------------------------------------------------
// Offset used to map the argument to an index.
543 T get_offset() const {
544 return offset;
545 }
546
547//------------------------------------------------------------------------------
551//------------------------------------------------------------------------------
// Number of entries in the data buffer.
552 size_t get_size() const {
553 return leaf_node<T, SAFE_MATH>::caches.backends[data_hash].size();
554 }
555 };
556
557//------------------------------------------------------------------------------
568//------------------------------------------------------------------------------
// Factory for 1D piecewise nodes.
// Builds a piecewise_1D_node and reduces it, which may fold to a constant.
// It then deduplicates through leaf_node::caches.nodes by probing keys from
// the reduced node's hash until it finds either:
// - an unused key, where the new node is returned, or
// - an existing node that is_match()es it (presumably returned instead, as
//   in piecewise_2D).
569 template<jit::float_scalar T, bool SAFE_MATH=false>
572 const T scale,
573 const T offset) {
574 auto temp = std::make_shared<piecewise_1D_node<T, SAFE_MATH>> (d, x,
575 scale,
576 offset)->reduce();
577// Test for hash collisions.
578 for (size_t i = temp->get_hash(); i < std::numeric_limits<size_t>::max(); i++) {
579 if (leaf_node<T, SAFE_MATH>::caches.nodes.find(i) ==
582 return temp;
583 } else if (temp->is_match(leaf_node<T, SAFE_MATH>::caches.nodes[i])) {
585 }
586 }
587#if defined(__clang__) || defined(__GNUC__)
589#else
590 assert(false && "Should never reach.");
591#endif
592 }
593
// Shared-pointer alias for piecewise_1D_node.
595 template<jit::float_scalar T, bool SAFE_MATH=false>
596 using shared_piecewise_1D = std::shared_ptr<piecewise_1D_node<T, SAFE_MATH>>;
597
598//------------------------------------------------------------------------------
606//------------------------------------------------------------------------------
// Downcast to piecewise_1D_node via dynamic_pointer_cast; yields a null
// pointer when x is not a piecewise_1D_node.
607 template<jit::float_scalar T, bool SAFE_MATH=false>
609 return std::dynamic_pointer_cast<piecewise_1D_node<T, SAFE_MATH>> (x);
610 }
611
612//******************************************************************************
613// 2D Piecewise node.
614//******************************************************************************
615//------------------------------------------------------------------------------
655//------------------------------------------------------------------------------
656 template<jit::float_scalar T, bool SAFE_MATH=false>
657 class piecewise_2D_node final : public branch_node<T, SAFE_MATH> {
658 private:
660 const T x_scale;
662 const T x_offset;
664 const T y_scale;
666 const T y_offset;
667
668//------------------------------------------------------------------------------
673//------------------------------------------------------------------------------
674 static std::string to_string(const backend::buffer<T> &d) {
675 std::string temp;
676 for (size_t i = 0, ie = d.size(); i < ie; i++) {
678 }
679
680 return temp;
681 }
682
683//------------------------------------------------------------------------------
690//------------------------------------------------------------------------------
691 static std::string to_string(const backend::buffer<T> &d,
694 return piecewise_2D_node::to_string(d) +
695 jit::format_to_string(x->get_hash()) +
696 jit::format_to_string(y->get_hash());
697 }
698
699//------------------------------------------------------------------------------
704//------------------------------------------------------------------------------
705 static size_t hash_data(const backend::buffer<T> &d) {
706 const size_t h = std::hash<std::string>{} (piecewise_2D_node::to_string(d));
707 for (size_t i = h; i < std::numeric_limits<size_t>::max(); i++) {
708 if (leaf_node<T, SAFE_MATH>::caches.backends.find(i) ==
709 leaf_node<T, SAFE_MATH>::caches.backends.end()) {
711 return i;
712 } else if (d == leaf_node<T, SAFE_MATH>::caches.backends[i]) {
713 return i;
714 }
715 }
716#if defined(__clang__) || defined(__GNUC__)
718#else
719 assert(false && "Should never reach.");
720#endif
721 }
722
724 const size_t data_hash;
726 const size_t num_columns;
727
728 public:
729//------------------------------------------------------------------------------
740//------------------------------------------------------------------------------
742 const size_t n,
744 const T x_scale,
745 const T x_offset,
747 const T y_scale,
748 const T y_offset) :
749 branch_node<T, SAFE_MATH> (x, y, piecewise_2D_node::to_string(d, x, y)),
750 data_hash(piecewise_2D_node::hash_data(d)),
751 num_columns(n), x_scale(x_scale), x_offset(x_offset), y_scale(y_scale),
752 y_offset(y_offset) {
753 assert(d.size()%n == 0 &&
754 "Expected the data buffer to be a multiple of the number of columns.");
755 }
756
757//------------------------------------------------------------------------------
761//------------------------------------------------------------------------------
762 size_t get_num_columns() const {
763 return num_columns;
764 }
765
766//------------------------------------------------------------------------------
770//------------------------------------------------------------------------------
771 size_t get_num_rows() const {
772 return leaf_node<T, SAFE_MATH>::caches.backends[data_hash].size() /
773 num_columns;
774 }
775
776//------------------------------------------------------------------------------
780//------------------------------------------------------------------------------
781 T get_x_scale() const {
782 return x_scale;
783 }
784
785//------------------------------------------------------------------------------
789//------------------------------------------------------------------------------
790 T get_x_offset() const {
791 return x_offset;
792 }
793
794//------------------------------------------------------------------------------
798//------------------------------------------------------------------------------
799 T get_y_scale() const {
800 return y_scale;
801 }
802
803//------------------------------------------------------------------------------
807//------------------------------------------------------------------------------
808 T get_y_offset() const {
809 return y_offset;
810 }
811
812//------------------------------------------------------------------------------
820//------------------------------------------------------------------------------
822 return leaf_node<T, SAFE_MATH>::caches.backends[data_hash];
823 }
824
825//------------------------------------------------------------------------------
832//------------------------------------------------------------------------------
834 if (constant_cast(this->left).get() &&
835 constant_cast(this->right).get()) {
836 const T l = (this->left->evaluate().at(0) + x_offset)/x_scale;
837 const T r = (this->right->evaluate().at(0) + y_offset)/y_scale;
838 const size_t i = std::min(static_cast<size_t> (std::real(l)),
839 this->get_num_rows() - 1);
840 const size_t j = std::min(static_cast<size_t> (std::real(r)),
841 this->get_num_columns() - 1);
842 return constant<T, SAFE_MATH> (leaf_node<T, SAFE_MATH>::caches.backends[data_hash][i*this->get_num_columns() + j]);
843 } else if (constant_cast(this->left).get()) {
844 const T l = (this->left->evaluate().at(0) + x_offset)/x_scale;
845 const size_t i = std::min(static_cast<size_t> (std::real(l)),
846 this->get_num_rows() - 1);
847
848 return piecewise_1D(leaf_node<T, SAFE_MATH>::caches.backends[data_hash].index_row(i, this->get_num_columns()),
849 this->right, y_scale, y_offset);
850 } else if (constant_cast(this->right).get()) {
851 const T r = (this->right->evaluate().at(0) + y_offset)/y_scale;
852 const size_t j = std::min(static_cast<size_t> (std::real(r)),
853 this->get_num_columns() - 1);
854
855 return piecewise_1D(leaf_node<T, SAFE_MATH>::caches.backends[data_hash].index_column(j, this->get_num_columns()),
856 this->left, x_scale, x_offset);
857 }
858
859 if (evaluate().is_same()) {
860 return constant<T, SAFE_MATH> (evaluate().at(0));
861 }
862
863 return this->shared_from_this();
864 }
865
866//------------------------------------------------------------------------------
871//------------------------------------------------------------------------------
873 return constant<T, SAFE_MATH> (static_cast<T> (this->is_match(x)));
874 }
875
876//------------------------------------------------------------------------------
886//------------------------------------------------------------------------------
887 virtual void compile_preamble(std::ostringstream &stream,
888 jit::register_map &registers,
893 int &avail_const_mem) {
894 if (visited.find(this) == visited.end()) {
895 this->left->compile_preamble(stream, registers,
896 visited, usage,
899 this->right->compile_preamble(stream, registers,
900 visited, usage,
903 if (registers.find(leaf_node<T, SAFE_MATH>::caches.backends[data_hash].data()) == registers.end()) {
904 registers[leaf_node<T, SAFE_MATH>::caches.backends[data_hash].data()] =
905 jit::to_string('a', leaf_node<T, SAFE_MATH>::caches.backends[data_hash].data());
906 const size_t length = leaf_node<T, SAFE_MATH>::caches.backends[data_hash].size();
907 if constexpr (jit::use_metal<T> ()) {
908 textures2d.try_emplace(leaf_node<T, SAFE_MATH>::caches.backends[data_hash].data(),
909 std::array<size_t, 2> ({length/num_columns, num_columns}));
910#ifdef USE_CUDA_TEXTURES
911 } else if constexpr (jit::use_cuda()) {
912 textures2d.try_emplace(leaf_node<T, SAFE_MATH>::caches.backends[data_hash].data(),
913 std::array<size_t, 2> ({length/num_columns, num_columns}));
914#endif
915 } else {
916 if constexpr (jit::use_cuda()) {
917 const int buffer_size = length*sizeof(T);
918 if (avail_const_mem - buffer_size > 0) {
920 stream << "__constant__ ";
921 }
922 }
923 stream << "const ";
924 jit::add_type<T> (stream);
925 stream << " " << registers[leaf_node<T, SAFE_MATH>::caches.backends[data_hash].data()] << "[] = {";
926 if constexpr (jit::complex_scalar<T>) {
927 jit::add_type<T> (stream);
928 }
929 stream << leaf_node<T, SAFE_MATH>::caches.backends[data_hash][0];
930 for (size_t i = 1; i < length; i++) {
931 stream << ", ";
932 if constexpr (jit::complex_scalar<T>) {
933 jit::add_type<T> (stream);
934 }
935 stream << leaf_node<T, SAFE_MATH>::caches.backends[data_hash][i];
936 }
937 stream << "};" << std::endl;
938 }
939 } else {
940// When using textures, the register can be defined in a previous kernel. We
941// need to add the textures again.
942 const size_t length = leaf_node<T, SAFE_MATH>::caches.backends[data_hash].size();
943 if constexpr (jit::use_metal<T> ()) {
944 textures2d.try_emplace(leaf_node<T, SAFE_MATH>::caches.backends[data_hash].data(),
945 std::array<size_t, 2> ({length/num_columns, num_columns}));
946#ifdef USE_CUDA_TEXTURES
947 } else if constexpr (jit::use_cuda()) {
948 textures2d.try_emplace(leaf_node<T, SAFE_MATH>::caches.backends[data_hash].data(),
949 std::array<size_t, 2> ({length/num_columns, num_columns}));
950#endif
951 }
952 }
953 visited.insert(this);
954#ifdef SHOW_USE_COUNT
955 usage[this] = 1;
956 } else {
957 ++usage[this];
958#endif
959 }
960 }
961
962//------------------------------------------------------------------------------
995//------------------------------------------------------------------------------
997 compile(std::ostringstream &stream,
998 jit::register_map &registers,
1000 const jit::register_usage &usage) {
1001 if (registers.find(this) == registers.end()) {
1002 const size_t length = leaf_node<T, SAFE_MATH>::caches.backends[data_hash].size();
1003 const size_t num_rows = length/num_columns;
1004
1005 shared_leaf<T, SAFE_MATH> x = this->left->compile(stream,
1006 registers,
1007 indices,
1008 usage);
1009 shared_leaf<T, SAFE_MATH> y = this->right->compile(stream,
1010 registers,
1011 indices,
1012 usage);
1013
1014#ifdef USE_INDEX_CACHE
1015 if (indices.find(x.get()) == indices.end()) {
1016 indices[x.get()] = jit::to_string('i', x.get());
1017 stream << " const "
1018 << jit::smallest_int_type<T> (num_rows) << " "
1019 << indices[x.get()] << " = ";
1020 compile_index<T> (stream, registers[x.get()], num_rows,
1021 x_scale, x_offset);
1022 x->endline(stream, usage);
1023 }
1024 if (indices.find(y.get()) == indices.end()) {
1025 indices[y.get()] = jit::to_string('i', y.get());
1026 stream << " const "
1027 << jit::smallest_int_type<T> (num_columns) << " "
1028 << indices[y.get()] << " = ";
1029 compile_index<T> (stream, registers[y.get()], num_columns,
1030 y_scale, y_offset);
1031 y->endline(stream, usage);
1032 }
1033
1034 auto temp = this->left + this->right;
1035 if constexpr (!jit::use_metal<T> ()
1036#ifdef USE_CUDA_TEXTURES
1037 || !jit::use_cuda()
1038#endif
1039 ) {
1040 if (indices.find(temp.get()) == indices.end()) {
1041 indices[temp.get()] = jit::to_string('i', temp.get());
1042 stream << " const "
1043 << jit::smallest_int_type<T> (length) << " "
1044 << indices[temp.get()] << " = "
1045 << indices[x.get()]
1046 << "*" << num_columns << " + "
1047 << indices[y.get()]
1048 << ";" << std::endl;
1049 }
1050 }
1051#endif
1052
1053 registers[this] = jit::to_string('r', this);
1054 stream << " const ";
1055 jit::add_type<T> (stream);
1056 stream << " " << registers[this] << " = ";
1057#ifdef USE_CUDA_TEXTURES
1058 if constexpr (jit::use_cuda()) {
1059 if constexpr (float_base<T>) {
1060 if constexpr (complex_scalar<T>) {
1061 stream << "to_cmp_float(tex1D<float2> (";
1062 } else {
1063 stream << "tex1D<float> (";
1064 }
1065 } else {
1066 if constexpr (complex_scalar<T>) {
1067 stream << "to_cmp_double(tex1D<uint4> (";
1068 } else {
1069 stream << "to_double(tex1D<uint2> (";
1070 }
1071 }
1072 }
1073#endif
1074 stream << registers[leaf_node<T, SAFE_MATH>::caches.backends[data_hash].data()];
1075 if constexpr (jit::use_metal<T> ()) {
1076#ifdef USE_INDEX_CACHE
1077 stream << ".read("
1078 << jit::smallest_int_type<T> (std::max(num_rows,
1079 num_columns))
1080 << "2("
1081 << indices[y.get()]
1082 << ","
1083 << indices[x.get()]
1084 << ")).r";
1085#else
1086 stream << ".read(uint2(";
1087 compile_index<T> (stream, registers[y.get()], num_columns,
1088 y_scale, y_offset);
1089 stream << ",";
1090 compile_index<T> (stream, registers[x.get()], num_rows,
1091 x_scale, x_offset);
1092 stream << ")).r";
1093#endif
1094#ifdef USE_CUDA_TEXTURES
1095 } else if constexpr (jit::use_cuda()) {
1096#ifdef USE_INDEX_CACHE
1097 stream << ", "
1098 << indices[y.get()]
1099 << ", "
1100 << indices[x.get()];
1101#else
1102 stream << ", ";
1103 compile_index<T> (stream, registers[y.get()], num_columns,
1104 y_scale, y_offset);
1105 stream << ", ";
1106 compile_index<T> (stream, registers[x.get()], num_rows,
1107 x_scale, x_offset);
1108#endif
1110 stream << ")";
1111 }
1112 stream << ")";
1113#endif
1114 } else {
1115#ifdef USE_INDEX_CACHE
1116 stream << "["
1117 << indices[temp.get()]
1118 << "]";
1119#else
1120 stream << "[";
1121 compile_index<T> (stream, registers[x.get()], num_rows,
1122 x_scale, x_offset);
1123 stream << "*" << num_columns << " + ";
1124 compile_index<T> (stream, registers[y.get()], num_columns,
1125 y_scale, y_offset);
1126 stream << "]";
1127#endif
1128 }
1129 this->endline(stream, usage);
1130 }
1131
1132 return this->shared_from_this();
1133 }
1134
1135//------------------------------------------------------------------------------
1142//------------------------------------------------------------------------------
1144 auto x_cast = piecewise_2D_cast(x);
1145
1146 if (x_cast.get()) {
1147 return this->data_hash == x_cast->data_hash &&
1148 this->left->is_match(x_cast->get_left()) &&
1149 this->right->is_match(x_cast->get_right());
1150 }
1151
1152 return false;
1153 }
1154
1155//------------------------------------------------------------------------------
1159//------------------------------------------------------------------------------
1160 virtual void to_latex() const {
1161 std::cout << "r\\_" << reinterpret_cast<size_t> (this) << "_{ij}";
1162 }
1163
1164//------------------------------------------------------------------------------
1170//------------------------------------------------------------------------------
1171 virtual shared_leaf<T, SAFE_MATH> to_vizgraph(std::stringstream &stream,
1172 jit::register_map &registers) {
1173 if (registers.find(this) == registers.end()) {
1174 const std::string name = jit::to_string('r', this);
1175 registers[this] = name;
1176 stream << " " << name
1177 << " [label = \"r_" << reinterpret_cast<size_t> (this)
1178 << "_{ij}\", shape = hexagon, style = filled, fillcolor = black, fontcolor = white];" << std::endl;
1179
1180 auto l = this->left->to_vizgraph(stream, registers);
1181 stream << " " << name << " -- " << registers[l.get()] << ";" << std::endl;
1182 auto r = this->right->to_vizgraph(stream, registers);
1183 stream << " " << name << " -- " << registers[r.get()] << ";" << std::endl;
1184 }
1185
1186 return this->shared_from_this();
1187 }
1188
1189//------------------------------------------------------------------------------
1193//------------------------------------------------------------------------------
1194 virtual bool is_constant() const {
1195 return true;
1196 }
1197
1198//------------------------------------------------------------------------------
1202//------------------------------------------------------------------------------
1203 virtual bool has_constant_zero() const {
1204 return leaf_node<T, SAFE_MATH>::caches.backends[data_hash].has_zero();
1205 }
1206
1207//------------------------------------------------------------------------------
1211//------------------------------------------------------------------------------
1212 virtual bool is_all_variables() const {
1213 return false;
1214 }
1215
1216//------------------------------------------------------------------------------
1220//------------------------------------------------------------------------------
1221 virtual bool is_power_like() const {
1222 return true;
1223 }
1224
1225//------------------------------------------------------------------------------
1229//------------------------------------------------------------------------------
1231 return this->shared_from_this();
1232 }
1233
1234//------------------------------------------------------------------------------
1238//------------------------------------------------------------------------------
1240 return one<T, SAFE_MATH> ();
1241 }
1242
1243//------------------------------------------------------------------------------
1248//------------------------------------------------------------------------------
1250 auto temp = piecewise_2D_cast(x);
1251 return temp.get() &&
1252 this->left->is_match(temp->get_left()) &&
1253 this->right->is_match(temp->get_right()) &&
1254 (temp->get_num_rows() == this->get_num_rows()) &&
1255 (temp->get_num_columns() == this->get_num_columns()) &&
1256 (temp->get_x_scale() == this->x_scale) &&
1257 (temp->get_x_offset() == this->x_offset) &&
1258 (temp->get_y_scale() == this->y_scale) &&
1259 (temp->get_y_offset() == this->y_offset);
1260 }
1261
1262//------------------------------------------------------------------------------
1267//------------------------------------------------------------------------------
// is_row_match(x): the signature line is not shown in this listing.
// True when x is a piecewise_1D node whose argument matches this node's row
// argument (left), whose size equals the number of rows, and whose
// scale/offset equal x_scale/x_offset.
1269 auto temp = piecewise_1D_cast(x);
1270 return temp.get() &&
1271 this->left->is_match(temp->get_arg()) &&
1272 (temp->get_size() == this->get_num_rows()) &&
1273 (temp->get_scale() == this->x_scale) &&
1274 (temp->get_offset() == this->x_offset);
1275 }
1276
1277//------------------------------------------------------------------------------
1284//------------------------------------------------------------------------------
// is_col_match(x): the signature line is not shown in this listing.
// True when x is a piecewise_1D node whose argument matches this node's
// column argument (right), whose size equals the number of columns, and whose
// scale/offset equal y_scale/y_offset.
1286 auto temp = piecewise_1D_cast(x);
1287 return temp.get() &&
1288 this->right->is_match(temp->get_arg()) &&
1289 (temp->get_size() == this->get_num_columns()) &&
1290 (temp->get_scale() == this->y_scale) &&
1291 (temp->get_offset() == this->y_offset);
1292 }
1293 };
1294
1295//------------------------------------------------------------------------------
1310//------------------------------------------------------------------------------
// Factory for a 2D piecewise constant node. Several signature lines are not
// shown in this listing. The visible parameters are n, x_scale, x_offset,
// y_scale and y_offset; the constructor call also uses d, x and y.
1311 template<jit::float_scalar T, bool SAFE_MATH=false>
1313 const size_t n,
1315 const T x_scale,
1316 const T x_offset,
1318 const T y_scale,
1319 const T y_offset) {
// Build the node and apply its reduce() simplification first.
1320 auto temp = std::make_shared<piecewise_2D_node<T, SAFE_MATH>> (d, n,
1321 x, x_scale, x_offset,
1322 y, y_scale, y_offset)->reduce();
1323// Test for hash collisions.
// Probe the node cache linearly starting at the node's hash:
//  - an empty slot means no equivalent node exists yet (the stripped line
//    1327 presumably stores temp there; confirm), so temp is returned;
//  - a slot holding a matching node returns that cached node, so equal
//    expressions share one instance;
//  - a non-matching occupant is a hash collision and the next slot is tried.
1324 for (size_t i = temp->get_hash(); i < std::numeric_limits<size_t>::max(); i++) {
1325 if (leaf_node<T, SAFE_MATH>::caches.nodes.find(i) ==
1326 leaf_node<T, SAFE_MATH>::caches.nodes.end()) {
1328 return temp;
1329 } else if (temp->is_match(leaf_node<T, SAFE_MATH>::caches.nodes[i])) {
1330 return leaf_node<T, SAFE_MATH>::caches.nodes[i];
1331 }
1332 }
// Not expected to be reached. The GCC/Clang branch (line 1334, not shown)
// presumably marks this unreachable. NOTE(review): on other compilers with
// NDEBUG the assert vanishes and control falls off the end of a non-void
// function. Confirm.
1333#if defined(__clang__) || defined(__GNUC__)
1335#else
1336 assert(false && "Should never reach.");
1337#endif
1338 }
1339
// Shared-pointer alias for piecewise_2D_node.
1341 template<jit::float_scalar T, bool SAFE_MATH=false>
1342 using shared_piecewise_2D = std::shared_ptr<piecewise_2D_node<T, SAFE_MATH>>;
1343
1344//------------------------------------------------------------------------------
1352//------------------------------------------------------------------------------
// Downcasts x to a piecewise_2D_node. Following std::dynamic_pointer_cast,
// the result is null when x is not a piecewise_2D_node. The signature line is
// not shown in this listing.
1353 template<jit::float_scalar T, bool SAFE_MATH=false>
1355 return std::dynamic_pointer_cast<piecewise_2D_node<T, SAFE_MATH>> (x);
1356 }
1357
1358//******************************************************************************
1359// 1D Index node.
1360//******************************************************************************
1361//------------------------------------------------------------------------------
1372//------------------------------------------------------------------------------
1373 template<jit::float_scalar T, bool SAFE_MATH=false>
1374 class index_1D_node final : public branch_node<T, SAFE_MATH> {
// A 1D index node reads one element of the variable held in `left`. The
// element's position is computed from `right` by compile_index:
// (int)((right - offset)/scale), clamped to [0, size - 1].
1375 private:
// Scale and offset that map the index argument to an integer position.
1377 const T scale;
1379 const T offset;
1380
1381//------------------------------------------------------------------------------
1387//------------------------------------------------------------------------------
// Builds the node's identifying string "<var hash>[<index hash>]". The second
// parameter line (`x`) is not shown in this listing.
1388 static std::string to_string(shared_leaf<T, SAFE_MATH> v,
1390 return jit::format_to_string(v->get_hash()) + "[" +
1391 jit::format_to_string(x->get_hash()) + "]";
1392 }
1393
1394 public:
1395//------------------------------------------------------------------------------
1402//------------------------------------------------------------------------------
// Constructor; its first signature lines are not shown in this listing.
// Passes the variable (left) and the index argument (right) to branch_node
// and stores scale and offset.
1405 const T scale,
1406 const T offset) :
1407 branch_node<T, SAFE_MATH> (var, x, index_1D_node::to_string(var, x)),
1408 scale(scale), offset(offset) {}
1409
1410//------------------------------------------------------------------------------
1418//------------------------------------------------------------------------------
// evaluate(): the signature line is not shown in this listing.
// NOTE(review): this returns the evaluation of the index argument (`right`),
// not the indexed element of `left`. index_2D_node::evaluate returns
// left->evaluate() instead. Confirm which behavior is intended.
1420 return this->right->evaluate();
1421 }
1422
1423//------------------------------------------------------------------------------
1428//------------------------------------------------------------------------------
// df(x): the signature line is not shown in this listing.
// Returns 1 when x matches this node and 0 otherwise.
1430 return constant<T, SAFE_MATH> (static_cast<T> (this->is_match(x)));
1431 }
1432
1433//------------------------------------------------------------------------------
1446//------------------------------------------------------------------------------
// compile(...): the return-type and `indices` parameter lines are not shown
// in this listing. The kernel line is emitted only on the first visit, which
// is guarded by `registers`.
1448 compile(std::ostringstream &stream,
1449 jit::register_map &registers,
1451 const jit::register_usage &usage) {
1452 if (registers.find(this) == registers.end()) {
// With USE_INDEX_CACHE, the clamped integer index is emitted once into its
// own variable and recorded in `indices`. Without it, the index expression is
// inlined into the subscript below.
// NOTE(review): the cache key is only the argument node; it is stored under
// a.get() and looked up under this->right.get(). Two issues to confirm:
//  - a second index node over the same argument but with a different size,
//    scale or offset would reuse this index;
//  - the two keys agree only if right->compile() returns `right` itself.
1453#ifdef USE_INDEX_CACHE
1454 if (indices.find(this->right.get()) == indices.end()) {
1455#endif
1456 const size_t length = variable_cast(this->left)->size();
1457 shared_leaf<T, SAFE_MATH> a = this->right->compile(stream,
1458 registers,
1459 indices,
1460 usage);
1461#ifdef USE_INDEX_CACHE
1462 indices[a.get()] = jit::to_string('i', a.get());
1463 stream << " const "
1464 << jit::smallest_int_type<T> (length) << " "
1465 << indices[a.get()] << " = ";
1466 compile_index<T> (stream, registers[a.get()], length,
1467 scale, offset);
1468 a->endline(stream, usage);
1469 }
1470#endif
1471
// Emit: const <type> r<this> = v<var>[<index>];
1472 registers[this] = jit::to_string('r', this);
1473 stream << " const ";
1474 jit::add_type<T> (stream);
1475 auto var = this->left->compile(stream,
1476 registers,
1477 indices,
1478 usage);
1479 stream << " " << registers[this] << " = "
1480 << jit::to_string('v', var.get());
1481#ifdef USE_INDEX_CACHE
1482 stream << "[" << indices[this->right.get()] << "]";
1483#else
1484 stream << "[";
1485 compile_index<T> (stream, registers[a.get()], length,
1486 scale, offset);
1487 stream << "]";
1488#endif
1489 this->endline(stream, usage);
1490 }
1491
1492 return this->shared_from_this();
1493 }
1494
1495//------------------------------------------------------------------------------
1497//------------------------------------------------------------------------------
// Prints r\_<variable address>\left[i\_<argument address>\right] to stdout.
1498 virtual void to_latex() const {
1499 std::cout << "r\\_" << reinterpret_cast<size_t> (this->left.get())
1500 << "\\left[i\\_"
1501 << reinterpret_cast<size_t> (this->right.get())
1502 << "\\right]";
1503 }
1504
1505//------------------------------------------------------------------------------
1511//------------------------------------------------------------------------------
// On the first visit, emits a filled black hexagon labelled with the
// variable's address, followed by edges to the variable and to the index
// argument.
1512 virtual shared_leaf<T, SAFE_MATH> to_vizgraph(std::stringstream &stream,
1513 jit::register_map &registers) {
1514 if (registers.find(this) == registers.end()) {
1515 const std::string name = jit::to_string('r', this);
1516 registers[this] = name;
1517 stream << " " << name
1518 << " [label = \"r_" << reinterpret_cast<size_t> (this->left.get())
1519 << "\", shape = hexagon, style = filled, fillcolor = black, fontcolor = white];" << std::endl;
1520
1521 auto l = this->left->to_vizgraph(stream, registers);
1522 stream << " " << name << " -- " << registers[l.get()] << ";" << std::endl;
1523 auto r = this->right->to_vizgraph(stream, registers);
1524 stream << " " << name << " -- " << registers[r.get()] << ";" << std::endl;
1525 }
1526
1527 return this->shared_from_this();
1528 }
1529
1530//------------------------------------------------------------------------------
1534//------------------------------------------------------------------------------
// Never constant: the result depends on the index argument.
1535 virtual bool is_constant() const {
1536 return false;
1537 }
1538
1539//------------------------------------------------------------------------------
1543//------------------------------------------------------------------------------
// Never treated as a variable.
1544 virtual bool is_all_variables() const {
1545 return false;
1546 }
1547
1548//------------------------------------------------------------------------------
1552//------------------------------------------------------------------------------
// Reported as power-like; the exponent is one.
1553 virtual bool is_power_like() const {
1554 return true;
1555 }
1556
1557//------------------------------------------------------------------------------
1561//------------------------------------------------------------------------------
// get_power_exponent(): the signature line is not shown in this listing.
// Always returns one.
1563 return one<T, SAFE_MATH> ();
1564 }
1565
1566//------------------------------------------------------------------------------
1571//------------------------------------------------------------------------------
// is_arg_match(x): the signature line is not shown in this listing.
// True when x is a 1D index node with a matching index argument and the same
// buffer size, scale and offset. The indexed variable (`left`) is compared
// only through its size.
1573 auto temp = index_1D_cast(x);
1574 return temp.get() &&
1575 this->right->is_match(temp->get_right()) &&
1576 (temp->get_size() == this->get_size()) &&
1577 (temp->get_scale() == this->scale) &&
1578 (temp->get_offset() == this->offset);
1579 }
1580
1581//------------------------------------------------------------------------------
1585//------------------------------------------------------------------------------
// Returns the scale applied to the index argument.
1586 T get_scale() const {
1587 return scale;
1588 }
1589
1590//------------------------------------------------------------------------------
1594//------------------------------------------------------------------------------
// Returns the offset subtracted from the index argument.
1595 T get_offset() const {
1596 return offset;
1597 }
1598
1599//------------------------------------------------------------------------------
1603//------------------------------------------------------------------------------
// Returns the number of elements in the indexed variable. This relies on
// `left` being a variable node; variable_cast would otherwise yield null. The
// factory's message states that a variable is required.
1604 size_t get_size() const {
1605 return variable_cast(this->left)->size();
1606 }
1607 };
1608
1609//------------------------------------------------------------------------------
1620//------------------------------------------------------------------------------
// Factory for a 1D index node. The first signature lines and the start of the
// precondition check (lines 1622-1623 and 1626) are not shown in this listing.
// The message below belongs to that check, which requires the first argument
// to be a variable node.
1621 template<jit::float_scalar T, bool SAFE_MATH=false>
1624 const T scale,
1625 const T offset) {
1627 "index_1D requires a variable node for first arg.");
1628 auto temp = std::make_shared<index_1D_node<T, SAFE_MATH>> (v, x,
1629 scale,
1630 offset)->reduce();
1631// Test for hash collisions.
// Probe the node cache linearly from the node's hash:
//  - an empty slot returns temp (the stripped line 1635 presumably stores it
//    there; confirm);
//  - a matching cached node is returned instead of temp;
//  - a non-matching occupant is skipped.
1632 for (size_t i = temp->get_hash(); i < std::numeric_limits<size_t>::max(); i++) {
1633 if (leaf_node<T, SAFE_MATH>::caches.nodes.find(i) ==
1634 leaf_node<T, SAFE_MATH>::caches.nodes.end()) {
1636 return temp;
1637 } else if (temp->is_match(leaf_node<T, SAFE_MATH>::caches.nodes[i])) {
1638 return leaf_node<T, SAFE_MATH>::caches.nodes[i];
1639 }
1640 }
// Not expected to be reached (line 1642, not shown, is the GCC/Clang branch).
1641#if defined(__clang__) || defined(__GNUC__)
1643#else
1644 assert(false && "Should never reach.");
1645#endif
1646 }
1647
// Shared-pointer alias for index_1D_node.
1649 template<jit::float_scalar T, bool SAFE_MATH=false>
1650 using shared_index_1D = std::shared_ptr<index_1D_node<T, SAFE_MATH>>;
1651
1652//------------------------------------------------------------------------------
1660//------------------------------------------------------------------------------
// Downcasts x to an index_1D_node. Following std::dynamic_pointer_cast, the
// result is null when x is not an index_1D_node. The signature line is not
// shown in this listing.
1661 template<jit::float_scalar T, bool SAFE_MATH=false>
1663 return std::dynamic_pointer_cast<index_1D_node<T, SAFE_MATH>> (x);
1664 }
1665
1666//******************************************************************************
1667// 2D Index node.
1668//******************************************************************************
1669//------------------------------------------------------------------------------
1681//------------------------------------------------------------------------------
1682 template<jit::float_scalar T, bool SAFE_MATH=false>
1683 class index_2D_node final : public triple_node<T, SAFE_MATH> {
// A 2D index node reads one element of the variable held in `left`, viewed as
// a row-major table with num_columns columns. The row comes from `middle` and
// the column from `right`; each is clamped by compile_index, and the flat
// position is row*num_columns + column.
1684 private:
// Scale/offset for the row (x) and column (y) arguments, and the table width.
1686 const T x_scale;
1688 const T x_offset;
1690 const T y_scale;
1692 const T y_offset;
1694 const size_t num_columns;
1695
1696//------------------------------------------------------------------------------
1703//------------------------------------------------------------------------------
// Builds the node's identifying string "<var hash>[<x hash>,<y hash>]". The
// `x`/`y` parameter lines are not shown in this listing.
1704 static std::string to_string(shared_leaf<T, SAFE_MATH> v,
1707 return jit::format_to_string(v->get_hash()) + "[" +
1708 jit::format_to_string(x->get_hash()) + "," +
1709 jit::format_to_string(y->get_hash()) + "]";
1710 }
1711
1712 public:
1713//------------------------------------------------------------------------------
1724//------------------------------------------------------------------------------
// Constructor; the `var`, `x` and `y` parameter lines are not shown in this
// listing. It asserts that the variable's size is a whole number of rows of
// n columns.
// NOTE(review): the members are declared x_scale, x_offset, y_scale,
// y_offset, num_columns, but num_columns is listed first below. Members are
// always initialised in declaration order, and compilers warn about this with
// -Wreorder. It is harmless here because the initialisers are independent.
1726 const size_t n,
1728 const T x_scale,
1729 const T x_offset,
1731 const T y_scale,
1732 const T y_offset) :
1733 triple_node<T, SAFE_MATH> (var, x, y, index_2D_node::to_string(var, x, y)),
1734 num_columns(n), x_scale(x_scale), x_offset(x_offset), y_scale(y_scale),
1735 y_offset(y_offset) {
1736 assert(variable_cast(this->left)->size()%n == 0 &&
1737 "Expected the data buffer to be a multiple of the number of columns.");
1738 }
1739
1740//------------------------------------------------------------------------------
1748//------------------------------------------------------------------------------
// evaluate(): the signature line is not shown in this listing.
// NOTE(review): this returns the whole indexed variable (`left`), not a
// single indexed element. index_1D_node::evaluate returns its index argument
// instead. Confirm which behavior is intended.
1750 return this->left->evaluate();
1751 }
1752
1753//------------------------------------------------------------------------------
1758//------------------------------------------------------------------------------
// df(x): the signature line is not shown in this listing.
// Returns 1 when x matches this node and 0 otherwise.
1760 return constant<T, SAFE_MATH> (static_cast<T> (this->is_match(x)));
1761 }
1762
1763//------------------------------------------------------------------------------
1777//------------------------------------------------------------------------------
// compile(...): the return-type and `indices` parameter lines are not shown
// in this listing. On the first visit it compiles the row argument (middle)
// and the column argument (right), then emits v[row*num_columns + col].
1779 compile(std::ostringstream &stream,
1780 jit::register_map &registers,
1782 const jit::register_usage &usage) {
1783 if (registers.find(this) == registers.end()) {
1784 const size_t length = variable_cast(this->left)->size();
1785 const size_t num_rows = length/num_columns;
1786
1787 shared_leaf<T, SAFE_MATH> x = this->middle->compile(stream,
1788 registers,
1789 indices,
1790 usage);
1791 shared_leaf<T, SAFE_MATH> y = this->right->compile(stream,
1792 registers,
1793 indices,
1794 usage);
1795
// With USE_INDEX_CACHE, the clamped row and column indices are emitted once
// into their own variables.
// NOTE(review): they are keyed only by the argument node. Another index node
// sharing x (or y) with a different shape, scale or offset would reuse this
// index. Confirm that this cannot happen.
1796#ifdef USE_INDEX_CACHE
1797 if (indices.find(x.get()) == indices.end()) {
1798 indices[x.get()] = jit::to_string('i', x.get());
1799 stream << " const "
1800 << jit::smallest_int_type<T> (num_rows) << " "
1801 << indices[x.get()] << " = ";
1802 compile_index<T> (stream, registers[x.get()], num_rows,
1803 x_scale, x_offset);
1804 x->endline(stream, usage);
1805 }
1806 if (indices.find(y.get()) == indices.end()) {
1807 indices[y.get()] = jit::to_string('i', y.get());
1808 stream << " const "
1809 << jit::smallest_int_type<T> (num_columns) << " "
1810 << indices[y.get()] << " = ";
1811 compile_index<T> (stream, registers[y.get()], num_columns,
1812 y_scale, y_offset);
1813 y->endline(stream, usage);
1814 }
1815
// The flattened index is cached under a `middle + right` node that serves
// only as a key. NOTE(review): two points to confirm.
//  - Presumably operator+ returns a node kept alive by the node cache, so this
//    address stays a valid key after `temp` is destroyed.
//  - The condition below is false only when both use_metal<T>() and
//    use_cuda() are true. In that case indices[temp.get()] is never set, and
//    the subscript emitted further down would be empty. Check whether `||`
//    (rather than `&&`) is intended.
1816 auto temp = this->middle + this->right;
1817 if constexpr (!jit::use_metal<T> () ||
1818 !jit::use_cuda()) {
1819 if (indices.find(temp.get()) == indices.end()) {
1820 indices[temp.get()] = jit::to_string('i', temp.get());
1821 stream << " const "
1822 << jit::smallest_int_type<T> (length) << " "
1823 << indices[temp.get()] << " = "
1824 << indices[x.get()]
1825 << "*" << num_columns << " + "
1826 << indices[y.get()]
1827 << ";" << std::endl;
1828 }
1829 }
1830#endif
1831
// Emit: const <type> r<this> = v<var>[<flat index>];
1832 registers[this] = jit::to_string('r', this);
1833 stream << " const ";
1834 jit::add_type<T> (stream);
1835 auto var = this->left->compile(stream,
1836 registers,
1837 indices,
1838 usage);
1839 stream << " " << registers[this] << " = "
1840 << jit::to_string('v', var.get());
1841#ifdef USE_INDEX_CACHE
1842 stream << "["
1843 << indices[temp.get()]
1844 << "]";
1845#else
1846 stream << "[";
1847 compile_index<T> (stream, registers[x.get()], num_rows,
1848 x_scale, x_offset);
1849 stream << "*" << num_columns << " + ";
1850 compile_index<T> (stream, registers[y.get()], num_columns,
1851 y_scale, y_offset);
1852 stream << "]";
1853#endif
1854 this->endline(stream, usage);
1855 }
1856
1857 return this->shared_from_this();
1858 }
1859
1860//------------------------------------------------------------------------------
1862//------------------------------------------------------------------------------
// Prints r\_<variable address>\left[i\_<x address>,j\_<y address>\right] to
// stdout.
1863 virtual void to_latex() const {
1864 std::cout << "r\\_" << reinterpret_cast<size_t> (this->left.get())
1865 << "\\left[i\\_"
1866 << reinterpret_cast<size_t> (this->middle.get())
1867 << ",j\\_"
1868 << reinterpret_cast<size_t> (this->right.get())
1869 << "\\right]";
1870 }
1871
1872//------------------------------------------------------------------------------
1878//------------------------------------------------------------------------------
// On the first visit, emits a filled black hexagon labelled with the
// variable's address, followed by edges to the variable, the row argument and
// the column argument.
1879 virtual shared_leaf<T, SAFE_MATH> to_vizgraph(std::stringstream &stream,
1880 jit::register_map &registers) {
1881 if (registers.find(this) == registers.end()) {
1882 const std::string name = jit::to_string('r', this);
1883 registers[this] = name;
1884 stream << " " << name
1885 << " [label = \"r_" << reinterpret_cast<size_t> (this->left.get())
1886 << "\", shape = hexagon, style = filled, fillcolor = black, fontcolor = white];" << std::endl;
1887
1888 auto l = this->left->to_vizgraph(stream, registers);
1889 stream << " " << name << " -- " << registers[l.get()] << ";" << std::endl;
1890 auto m = this->middle->to_vizgraph(stream, registers);
1891 stream << " " << name << " -- " << registers[m.get()] << ";" << std::endl;
1892 auto r = this->right->to_vizgraph(stream, registers);
1893 stream << " " << name << " -- " << registers[r.get()] << ";" << std::endl;
1894 }
1895
1896 return this->shared_from_this();
1897 }
1898
1899//------------------------------------------------------------------------------
1903//------------------------------------------------------------------------------
// Never constant: the result depends on the index arguments.
1904 virtual bool is_constant() const {
1905 return false;
1906 }
1907
1908//------------------------------------------------------------------------------
1912//------------------------------------------------------------------------------
// Never treated as a variable.
1913 virtual bool is_all_variables() const {
1914 return false;
1915 }
1916
1917//------------------------------------------------------------------------------
1921//------------------------------------------------------------------------------
// Reported as power-like; the exponent is one.
1922 virtual bool is_power_like() const {
1923 return true;
1924 }
1925
1926//------------------------------------------------------------------------------
1930//------------------------------------------------------------------------------
// get_power_exponent(): the signature line is not shown in this listing.
// Always returns one.
1932 return one<T, SAFE_MATH> ();
1933 }
1934
1935//------------------------------------------------------------------------------
1940//------------------------------------------------------------------------------
// is_arg_match(x): the signature line is not shown in this listing.
// NOTE(review): this body appears to be copied from index_1D_node and is
// broken:
//  - index_2D_node declares no `scale`/`offset` members (it has x_scale,
//    x_offset, y_scale and y_offset), so instantiating this member fails to
//    compile;
//  - index_1D_cast can never succeed on an index_2D node.
// It presumably should use index_2D_cast and compare middle/right, the size,
// num_columns and all four scale/offset values.
1942 auto temp = index_1D_cast(x);
1943 return temp.get() &&
1944 this->right->is_match(temp->get_right()) &&
1945 (temp->get_size() == this->get_size()) &&
1946 (temp->get_scale() == this->scale) &&
1947 (temp->get_offset() == this->offset);
1948 }
1949
1950//------------------------------------------------------------------------------
1954//------------------------------------------------------------------------------
// Returns the scale applied to the row argument.
1955 T get_x_scale() const {
1956 return x_scale;
1957 }
1958
1959//------------------------------------------------------------------------------
1963//------------------------------------------------------------------------------
// Returns the offset subtracted from the row argument.
1964 T get_x_offset() const {
1965 return x_offset;
1966 }
1967
1968//------------------------------------------------------------------------------
1972//------------------------------------------------------------------------------
// Returns the scale applied to the column argument.
1973 T get_y_scale() const {
1974 return y_scale;
1975 }
1976
1977//------------------------------------------------------------------------------
1981//------------------------------------------------------------------------------
// Returns the offset subtracted from the column argument.
1982 T get_y_offset() const {
1983 return y_offset;
1984 }
1985
1986//------------------------------------------------------------------------------
1990//------------------------------------------------------------------------------
// Returns the total number of elements in the indexed variable (rows times
// columns). This relies on `left` being a variable node.
1991 size_t get_size() const {
1992 return variable_cast(this->left)->size();
1993 }
1994 };
1995
1996//------------------------------------------------------------------------------
2011//------------------------------------------------------------------------------
// Factory for a 2D index node. Several signature lines and the start of the
// precondition check are not shown in this listing. The message below belongs
// to that check, which requires the first argument to be a variable node.
2012 template<jit::float_scalar T, bool SAFE_MATH=false>
2014 const size_t n,
2016 const T x_scale,
2017 const T x_offset,
2019 const T y_scale,
2020 const T y_offset) {
2022 "index_2D requires a variable node for first arg.");
2023 auto temp = std::make_shared<index_2D_node<T, SAFE_MATH>> (v, n,
2024 x, x_scale, x_offset,
2025 y, y_scale, y_offset)->reduce();
2026// Test for hash collisions.
// Probe the node cache linearly from the node's hash:
//  - an empty slot returns temp (the stripped line 2030 presumably stores it
//    there; confirm);
//  - a matching cached node is returned instead of temp;
//  - a non-matching occupant is skipped.
2027 for (size_t i = temp->get_hash(); i < std::numeric_limits<size_t>::max(); i++) {
2028 if (leaf_node<T, SAFE_MATH>::caches.nodes.find(i) ==
2029 leaf_node<T, SAFE_MATH>::caches.nodes.end()) {
2031 return temp;
2032 } else if (temp->is_match(leaf_node<T, SAFE_MATH>::caches.nodes[i])) {
2033 return leaf_node<T, SAFE_MATH>::caches.nodes[i];
2034 }
2035 }
// Not expected to be reached (line 2037, not shown, is the GCC/Clang branch).
2036#if defined(__clang__) || defined(__GNUC__)
2038#else
2039 assert(false && "Should never reach.");
2040#endif
2041 }
2042
// Shared-pointer alias for index_2D_node.
2044 template<jit::float_scalar T, bool SAFE_MATH=false>
2045 using shared_index_2D = std::shared_ptr<index_2D_node<T, SAFE_MATH>>;
2046
2047//------------------------------------------------------------------------------
2055//------------------------------------------------------------------------------
// Downcasts x to an index_2D_node. Following std::dynamic_pointer_cast, the
// result is null when x is not an index_2D_node. The signature line is not
// shown in this listing.
2056 template<jit::float_scalar T, bool SAFE_MATH=false>
2058 return std::dynamic_pointer_cast<index_2D_node<T, SAFE_MATH>> (x);
2059 }
2060}
2061
2062#endif /* piecewise_h */
Class representing a generic buffer.
Definition backend.hpp:29
Class representing a branch node.
Definition node.hpp:1165
shared_leaf< T, SAFE_MATH > right
Right branch of the tree.
Definition node.hpp:1170
shared_leaf< T, SAFE_MATH > left
Left branch of the tree.
Definition node.hpp:1168
Class representing a 1D index.
Definition piecewise.hpp:1374
virtual shared_leaf< T, SAFE_MATH > get_power_exponent() const
Get the exponent of a power.
Definition piecewise.hpp:1562
virtual backend::buffer< T > evaluate()
Evaluate the results of the 1D index node.
Definition piecewise.hpp:1419
size_t get_size() const
Get the size of the buffer.
Definition piecewise.hpp:1604
virtual bool is_power_like() const
Test if the node acts like a power of variable.
Definition piecewise.hpp:1553
virtual shared_leaf< T, SAFE_MATH > compile(std::ostringstream &stream, jit::register_map &registers, jit::register_map &indices, const jit::register_usage &usage)
Compile the node.
Definition piecewise.hpp:1448
T get_scale() const
Get x argument scale.
Definition piecewise.hpp:1586
bool is_arg_match(shared_leaf< T, SAFE_MATH > x)
Check if the args match.
Definition piecewise.hpp:1572
T get_offset() const
Get x argument offset.
Definition piecewise.hpp:1595
virtual shared_leaf< T, SAFE_MATH > to_vizgraph(std::stringstream &stream, jit::register_map &registers)
Convert the node to vizgraph.
Definition piecewise.hpp:1512
index_1D_node(shared_leaf< T, SAFE_MATH > var, shared_leaf< T, SAFE_MATH > x, const T scale, const T offset)
Construct a 1D index.
Definition piecewise.hpp:1403
virtual bool is_all_variables() const
Test if node acts like a variable.
Definition piecewise.hpp:1544
virtual shared_leaf< T, SAFE_MATH > df(shared_leaf< T, SAFE_MATH > x)
Transform node to derivative.
Definition piecewise.hpp:1429
virtual bool is_constant() const
Test if node is a constant.
Definition piecewise.hpp:1535
virtual void to_latex() const
Convert the node to latex.
Definition piecewise.hpp:1498
Class representing a 2D index.
Definition piecewise.hpp:1683
size_t get_size() const
Get the size of the buffer.
Definition piecewise.hpp:1991
virtual shared_leaf< T, SAFE_MATH > df(shared_leaf< T, SAFE_MATH > x)
Transform node to derivative.
Definition piecewise.hpp:1759
virtual bool is_power_like() const
Test if the node acts like a power of variable.
Definition piecewise.hpp:1922
virtual backend::buffer< T > evaluate()
Evaluate the results of the 2D index node.
Definition piecewise.hpp:1749
virtual void to_latex() const
Convert the node to latex.
Definition piecewise.hpp:1863
T get_x_offset() const
Get x argument offset.
Definition piecewise.hpp:1964
virtual shared_leaf< T, SAFE_MATH > compile(std::ostringstream &stream, jit::register_map &registers, jit::register_map &indices, const jit::register_usage &usage)
Compile the node.
Definition piecewise.hpp:1779
T get_y_scale() const
Get y argument scale.
Definition piecewise.hpp:1973
virtual shared_leaf< T, SAFE_MATH > get_power_exponent() const
Get the exponent of a power.
Definition piecewise.hpp:1931
bool is_arg_match(shared_leaf< T, SAFE_MATH > x)
Check if the args match.
Definition piecewise.hpp:1941
T get_x_scale() const
Get x argument scale.
Definition piecewise.hpp:1955
virtual shared_leaf< T, SAFE_MATH > to_vizgraph(std::stringstream &stream, jit::register_map &registers)
Convert the node to vizgraph.
Definition piecewise.hpp:1879
virtual bool is_all_variables() const
Test if node acts like a variable.
Definition piecewise.hpp:1913
T get_y_offset() const
Get y argument offset.
Definition piecewise.hpp:1982
index_2D_node(shared_leaf< T, SAFE_MATH > var, const size_t n, shared_leaf< T, SAFE_MATH > x, const T x_scale, const T x_offset, shared_leaf< T, SAFE_MATH > y, const T y_scale, const T y_offset)
Construct a 2D index.
Definition piecewise.hpp:1725
virtual bool is_constant() const
Test if node is a constant.
Definition piecewise.hpp:1904
Class representing a node leaf.
Definition node.hpp:364
virtual void endline(std::ostringstream &stream, const jit::register_usage &usage) const final
End a line in the kernel source.
Definition node.hpp:639
virtual bool is_match(std::shared_ptr< leaf_node< T, SAFE_MATH > > x)
Query if the nodes match.
Definition node.hpp:472
virtual std::shared_ptr< leaf_node< T, SAFE_MATH > > compile(std::ostringstream &stream, jit::register_map &registers, jit::register_map &indices, const jit::register_usage &usage)=0
Compile the node.
Class representing a 1D piecewise constant.
Definition piecewise.hpp:92
virtual void compile_preamble(std::ostringstream &stream, jit::register_map &registers, jit::visiter_map &visited, jit::register_usage &usage, jit::texture1d_list &textures1d, jit::texture2d_list &textures2d, int &avail_const_mem)
Compile preamble.
Definition piecewise.hpp:227
virtual bool is_power_like() const
Test if the node acts like a power of variable.
Definition piecewise.hpp:492
piecewise_1D_node(const backend::buffer< T > &d, shared_leaf< T, SAFE_MATH > x, const T scale, const T offset)
Construct a 1D piecewise constant node.
Definition piecewise.hpp:163
virtual bool is_constant() const
Test if node is a constant.
Definition piecewise.hpp:465
virtual bool is_match(shared_leaf< T, SAFE_MATH > x)
Query if the nodes match.
Definition piecewise.hpp:419
virtual shared_leaf< T, SAFE_MATH > get_power_exponent() const
Get the exponent of a power.
Definition piecewise.hpp:510
virtual shared_leaf< T, SAFE_MATH > reduce()
Reduction method.
Definition piecewise.hpp:192
virtual backend::buffer< T > evaluate()
Evaluate the results of the piecewise constant.
Definition piecewise.hpp:180
virtual shared_leaf< T, SAFE_MATH > df(shared_leaf< T, SAFE_MATH > x)
Transform node to derivative.
Definition piecewise.hpp:212
T get_scale() const
Get x argument scale.
Definition piecewise.hpp:534
virtual void to_latex() const
Convert the node to latex.
Definition piecewise.hpp:433
virtual shared_leaf< T, SAFE_MATH > get_power_base()
Get the base of a power.
Definition piecewise.hpp:501
size_t get_size() const
Get the size of the buffer.
Definition piecewise.hpp:552
virtual bool has_constant_zero() const
Test the constant node has a zero.
Definition piecewise.hpp:474
virtual shared_leaf< T, SAFE_MATH > to_vizgraph(std::stringstream &stream, jit::register_map &registers)
Convert the node to vizgraph.
Definition piecewise.hpp:444
virtual shared_leaf< T, SAFE_MATH > compile(std::ostringstream &stream, jit::register_map &registers, jit::register_map &indices, const jit::register_usage &usage)
Compile the node.
Definition piecewise.hpp:320
T get_offset() const
Get x argument offset.
Definition piecewise.hpp:543
virtual bool is_all_variables() const
Test if node acts like a variable.
Definition piecewise.hpp:483
bool is_arg_match(shared_leaf< T, SAFE_MATH > x)
Check if the args match.
Definition piecewise.hpp:520
Class representing a 2D piecewise constant.
Definition piecewise.hpp:657
T get_y_scale() const
Get y argument scale.
Definition piecewise.hpp:799
virtual shared_leaf< T, SAFE_MATH > df(shared_leaf< T, SAFE_MATH > x)
Transform node to derivative.
Definition piecewise.hpp:872
bool is_arg_match(shared_leaf< T, SAFE_MATH > x)
Check if the args match.
Definition piecewise.hpp:1249
T get_y_offset() const
Get y argument offset.
Definition piecewise.hpp:808
virtual bool is_power_like() const
Test if the node acts like a power of variable.
Definition piecewise.hpp:1221
T get_x_scale() const
Get x argument scale.
Definition piecewise.hpp:781
virtual shared_leaf< T, SAFE_MATH > compile(std::ostringstream &stream, jit::register_map &registers, jit::register_map &indices, const jit::register_usage &usage)
Compile the node.
Definition piecewise.hpp:997
size_t get_num_columns() const
Get the number of columns.
Definition piecewise.hpp:762
virtual shared_leaf< T, SAFE_MATH > get_power_base()
Get the base of a power.
Definition piecewise.hpp:1230
virtual bool is_match(shared_leaf< T, SAFE_MATH > x)
Query if the nodes match.
Definition piecewise.hpp:1143
virtual bool is_all_variables() const
Test if node acts like a variable.
Definition piecewise.hpp:1212
virtual bool is_constant() const
Test if node is a constant.
Definition piecewise.hpp:1194
size_t get_num_rows() const
Get the number of rows.
Definition piecewise.hpp:771
bool is_col_match(shared_leaf< T, SAFE_MATH > x)
Do the columns match.
Definition piecewise.hpp:1285
T get_x_offset() const
Get x argument offset.
Definition piecewise.hpp:790
virtual void to_latex() const
Convert the node to latex.
Definition piecewise.hpp:1160
virtual shared_leaf< T, SAFE_MATH > to_vizgraph(std::stringstream &stream, jit::register_map &registers)
Convert the node to vizgraph.
Definition piecewise.hpp:1171
bool is_row_match(shared_leaf< T, SAFE_MATH > x)
Do the rows match.
Definition piecewise.hpp:1268
virtual shared_leaf< T, SAFE_MATH > get_power_exponent() const
Get the exponent of a power.
Definition piecewise.hpp:1239
virtual shared_leaf< T, SAFE_MATH > reduce()
Reduction method.
Definition piecewise.hpp:833
piecewise_2D_node(const backend::buffer< T > &d, const size_t n, shared_leaf< T, SAFE_MATH > x, const T x_scale, const T x_offset, shared_leaf< T, SAFE_MATH > y, const T y_scale, const T y_offset)
Construct a 2D piecewise constant node.
Definition piecewise.hpp:741
virtual backend::buffer< T > evaluate()
Evaluate the results of the piecewise constant.
Definition piecewise.hpp:821
virtual bool has_constant_zero() const
Test the constant node has a zero.
Definition piecewise.hpp:1203
virtual void compile_preamble(std::ostringstream &stream, jit::register_map &registers, jit::visiter_map &visited, jit::register_usage &usage, jit::texture1d_list &textures1d, jit::texture2d_list &textures2d, int &avail_const_mem)
Compile preamble.
Definition piecewise.hpp:887
Class representing a straight node.
Definition node.hpp:1051
shared_leaf< T, SAFE_MATH > arg
Argument.
Definition node.hpp:1054
Class representing a triple branch node.
Definition node.hpp:1289
shared_leaf< T, SAFE_MATH > middle
Middle branch of the tree.
Definition node.hpp:1292
Complex scalar concept.
Definition register.hpp:24
Double base concept.
Definition register.hpp:42
subroutine assert(test, message)
Assert check.
Definition f_binding_test.f90:38
Name space for graph nodes.
Definition arithmetic.hpp:13
shared_piecewise_2D< T, SAFE_MATH > piecewise_2D_cast(shared_leaf< T, SAFE_MATH > x)
Cast to a piecewise 2D node.
Definition piecewise.hpp:1354
constexpr shared_leaf< T, SAFE_MATH > zero()
Forward declare for zero.
Definition node.hpp:986
shared_index_2D< T, SAFE_MATH > index_2D_cast(shared_leaf< T, SAFE_MATH > x)
Cast to an index 2D node.
Definition piecewise.hpp:2057
std::shared_ptr< index_2D_node< T, SAFE_MATH > > shared_index_2D
Convenience type alias for shared index 2D nodes.
Definition piecewise.hpp:2045
shared_piecewise_1D< T, SAFE_MATH > piecewise_1D_cast(shared_leaf< T, SAFE_MATH > x)
Cast to a piecewise 1D node.
Definition piecewise.hpp:608
std::shared_ptr< piecewise_2D_node< T, SAFE_MATH > > shared_piecewise_2D
Convenience type alias for shared piecewise 2D nodes.
Definition piecewise.hpp:1342
shared_constant< T, SAFE_MATH > constant_cast(shared_leaf< T, SAFE_MATH > x)
Cast to a constant node.
Definition node.hpp:1034
shared_variable< T, SAFE_MATH > variable_cast(shared_leaf< T, SAFE_MATH > x)
Cast to a variable node.
Definition node.hpp:1727
std::shared_ptr< piecewise_1D_node< T, SAFE_MATH > > shared_piecewise_1D
Convenience type alias for shared piecewise 1D nodes.
Definition piecewise.hpp:596
constexpr T i
Convenience type for imaginary constant.
Definition node.hpp:1018
std::shared_ptr< leaf_node< T, SAFE_MATH > > shared_leaf
Convenience type alias for shared leaf nodes.
Definition node.hpp:676
std::shared_ptr< index_1D_node< T, SAFE_MATH > > shared_index_1D
Convenience type alias for shared index 1D nodes.
Definition piecewise.hpp:1650
void compile_index(std::ostringstream &stream, const std::string &register_name, const size_t length, const T scale, const T offset)
Compile an index.
Definition piecewise.hpp:26
shared_index_1D< T, SAFE_MATH > index_1D_cast(shared_leaf< T, SAFE_MATH > x)
Cast to an index 1D node.
Definition piecewise.hpp:1662
std::map< void *, size_t > texture1d_list
Type alias for indexing 1D textures.
Definition register.hpp:263
std::string format_to_string(const T value)
Convert a value to a string while avoiding locale.
Definition register.hpp:212
std::map< void *, std::array< size_t, 2 > > texture2d_list
Type alias for indexing 2D textures.
Definition register.hpp:265
std::map< void *, size_t > register_usage
Type alias for counting register usage.
Definition register.hpp:259
std::map< void *, std::string > register_map
Type alias for mapping node pointers to register names.
Definition register.hpp:257
constexpr bool use_cuda()
Test to use Cuda.
Definition register.hpp:67
std::set< void * > visiter_map
Type alias for listing visited nodes.
Definition register.hpp:261
std::string to_string(const char prefix, const NODE *pointer)
Convert a graph::leaf_node pointer to a string.
Definition register.hpp:246
Base nodes of graph computation framework.
void index_2D()
Tests for 2D index nodes.
Definition piecewise_test.cpp:758
void piecewise_1D()
Tests for 1D piecewise nodes.
Definition piecewise_test.cpp:80
void index_1D()
Tests for 1D index nodes.
Definition piecewise_test.cpp:738
void piecewise_2D()
Tests for 2D piecewise nodes.
Definition piecewise_test.cpp:292