Graph Framework
Loading...
Searching...
No Matches
piecewise.hpp
Go to the documentation of this file.
1//------------------------------------------------------------------------------
6//------------------------------------------------------------------------------
7
8#ifndef piecewise_h
9#define piecewise_h
10
11#include "node.hpp"
12
13namespace graph {
14//------------------------------------------------------------------------------
24//------------------------------------------------------------------------------
25template<jit::float_scalar T>
26void compile_index(std::ostringstream &stream,
27 const std::string &register_name,
28 const size_t length,
29 const T scale,
30 const T offset) {
31 const std::string type = jit::type_to_string<T> ();
32 stream << "(" << jit::smallest_uint_type<T> (length) << ")min";
33 if constexpr (!jit::use_metal<T> ()) {
34 stream << "<" << type << ">";
35 }
36 stream << "(max";
37 if constexpr (!jit::use_metal<T> ()) {
38 stream << "<" << type << ">";
39 }
40 stream << "(";
41 if constexpr (jit::complex_scalar<T>) {
42 stream << "real(";
43 }
44 stream << "(" << register_name << " - ";
45 if constexpr (jit::complex_scalar<T>) {
47 }
48 stream << offset << ")/";
49 if constexpr (jit::complex_scalar<T>) {
51 }
52 stream << scale;
53 if constexpr (jit::complex_scalar<T>) {
54 stream << ")";
55 }
56 stream << ",";
57 if constexpr (jit::use_metal<T> ()) {
58 stream << "(" << type << ")";
59 }
60 stream << "0),";
61 if constexpr (jit::use_metal<T> ()) {
62 stream << "(" << type << ")";
63 }
64 stream << length - 1 << ")";
65}
66
67//******************************************************************************
68// 1D Piecewise node.
69//******************************************************************************
70//------------------------------------------------------------------------------
103//------------------------------------------------------------------------------
104 template<jit::float_scalar T, bool SAFE_MATH=false>
105 class piecewise_1D_node final : public straight_node<T, SAFE_MATH> {
106 private:
108 const T scale;
110 const T offset;
111
112//------------------------------------------------------------------------------
117//------------------------------------------------------------------------------
118 static std::string to_string(const backend::buffer<T> &d) {
119 std::string temp;
120 for (size_t i = 0, ie = d.size(); i < ie; i++) {
122 }
123
124 return temp;
125 }
126
127//------------------------------------------------------------------------------
133//------------------------------------------------------------------------------
134 static std::string to_string(const backend::buffer<T> &d,
136 return piecewise_1D_node::to_string(d) +
137 jit::format_to_string(x->get_hash());
138 }
139
140//------------------------------------------------------------------------------
145//------------------------------------------------------------------------------
146 static size_t hash_data(const backend::buffer<T> &d) {
147 const size_t h = std::hash<std::string>{} (piecewise_1D_node::to_string(d));
148 for (size_t i = h; i < std::numeric_limits<size_t>::max(); i++) {
149 if (leaf_node<T, SAFE_MATH>::caches.backends.find(i) ==
150 leaf_node<T, SAFE_MATH>::caches.backends.end()) {
152 return i;
153 } else if (d == leaf_node<T, SAFE_MATH>::caches.backends[i]) {
154 return i;
155 }
156 }
157#if defined(__clang__) || defined(__GNUC__)
159#else
160 assert(false && "Should never reach.");
161#endif
162 }
163
165 const size_t data_hash;
166
167 public:
168//------------------------------------------------------------------------------
175//------------------------------------------------------------------------------
178 const T scale,
179 const T offset) :
180 straight_node<T, SAFE_MATH> (x, piecewise_1D_node::to_string(d, x)),
181 data_hash(piecewise_1D_node::hash_data(d)), scale(scale),
182 offset(offset) {}
183
184//------------------------------------------------------------------------------
192//------------------------------------------------------------------------------
194 return leaf_node<T, SAFE_MATH>::caches.backends[data_hash];
195 }
196
197//------------------------------------------------------------------------------
204//------------------------------------------------------------------------------
206 if (constant_cast(this->arg).get()) {
207 const T arg = (this->arg->evaluate().at(0) + offset)/scale;
208 if constexpr (jit::float_base<T>) {
209 const size_t i = std::max<float> (std::min<float> (std::real(arg),
210 this->get_size() - 1),
211 0);
212 return constant<T, SAFE_MATH> (leaf_node<T, SAFE_MATH>::caches.backends[data_hash][i]);
213 } else {
214 const size_t i = std::max<double> (std::min<double> (std::real(arg),
215 this->get_size() - 1),
216 0);
217 return constant<T, SAFE_MATH> (leaf_node<T, SAFE_MATH>::caches.backends[data_hash][i]);
218 }
219 }
220
221 if (evaluate().is_same()) {
222 return constant<T, SAFE_MATH> (evaluate().at(0));
223 }
224 return this->shared_from_this();
225 }
226
227//------------------------------------------------------------------------------
232//------------------------------------------------------------------------------
234 return constant<T, SAFE_MATH> (static_cast<T> (this->is_match(x)));
235 }
236
237//------------------------------------------------------------------------------
247//------------------------------------------------------------------------------
248 virtual void compile_preamble(std::ostringstream &stream,
249 jit::register_map &registers,
254 int &avail_const_mem) {
255 if (visited.find(this) == visited.end()) {
256 this->arg->compile_preamble(stream, registers,
257 visited, usage,
260 if (registers.find(leaf_node<T, SAFE_MATH>::caches.backends[data_hash].data()) == registers.end()) {
261 registers[leaf_node<T, SAFE_MATH>::caches.backends[data_hash].data()] =
262 jit::to_string('a', leaf_node<T, SAFE_MATH>::caches.backends[data_hash].data());
263 const size_t length = leaf_node<T, SAFE_MATH>::caches.backends[data_hash].size();
264 if constexpr (jit::use_metal<T> ()) {
265 textures1d.try_emplace(leaf_node<T, SAFE_MATH>::caches.backends[data_hash].data(),
266 length);
267#ifdef USE_CUDA_TEXTURES
268 } else if constexpr (jit::use_cuda()) {
269 textures1d.try_emplace(leaf_node<T, SAFE_MATH>::caches.backends[data_hash].data(),
270 length);
271#endif
272 } else {
273 if constexpr (jit::use_cuda()) {
274 const int buffer_size = length*sizeof(T);
275 if (avail_const_mem - buffer_size > 0) {
277 stream << "__constant__ ";
278 }
279 }
280 stream << "const ";
281 jit::add_type<T> (stream);
282 stream << " " << registers[leaf_node<T, SAFE_MATH>::caches.backends[data_hash].data()] << "[] = {";
283 if constexpr (jit::complex_scalar<T>) {
284 jit::add_type<T> (stream);
285 }
286 stream << leaf_node<T, SAFE_MATH>::caches.backends[data_hash][0];
287 for (size_t i = 1; i < length; i++) {
288 stream << ", ";
289 if constexpr (jit::complex_scalar<T>) {
290 jit::add_type<T> (stream);
291 }
292 stream << leaf_node<T, SAFE_MATH>::caches.backends[data_hash][i];
293 }
294 stream << "};" << std::endl;
295 }
296 } else {
297// When using textures, the register can be defined in a previous kernel. We
298// need to add the textures again.
299 const size_t length = leaf_node<T, SAFE_MATH>::caches.backends[data_hash].size();
300 if constexpr (jit::use_metal<T> ()) {
301 textures1d.try_emplace(leaf_node<T, SAFE_MATH>::caches.backends[data_hash].data(),
302 length);
303#ifdef USE_CUDA_TEXTURES
304 } else if constexpr (jit::use_cuda()) {
305 textures1d.try_emplace(leaf_node<T, SAFE_MATH>::caches.backends[data_hash].data(),
306 length);
307#endif
308 }
309 }
310 visited.insert(this);
311#ifdef SHOW_USE_COUNT
312 usage[this] = 1;
313 } else {
314 ++usage[this];
315#endif
316 }
317 }
318
319//------------------------------------------------------------------------------
339//------------------------------------------------------------------------------
341 compile(std::ostringstream &stream,
342 jit::register_map &registers,
344 const jit::register_usage &usage) {
345 if (registers.find(this) == registers.end()) {
346#ifdef USE_INDEX_CACHE
347 if (indices.find(this->arg.get()) == indices.end()) {
348#endif
349 const size_t length = leaf_node<T, SAFE_MATH>::caches.backends[data_hash].size();
350 shared_leaf<T, SAFE_MATH> a = this->arg->compile(stream,
351 registers,
352 indices,
353 usage);
354#ifdef USE_INDEX_CACHE
355 indices[a.get()] = jit::to_string('i', a.get());
356 stream << " const "
357 << jit::smallest_uint_type<T> (length) << " "
358 << indices[a.get()] << " = ";
359 compile_index<T> (stream, registers[a.get()], length,
360 scale, offset);
361 a->endline(stream, usage);
362 }
363#endif
364
365 registers[this] = jit::to_string('r', this);
366 stream << " const ";
367 jit::add_type<T> (stream);
368 stream << " " << registers[this] << " = ";
369#ifdef USE_CUDA_TEXTURES
370 if constexpr (jit::use_cuda()) {
371 if constexpr (jit::float_base<T>) {
372 if constexpr (complex_scalar<T>) {
373 stream << "to_cmp_float(tex1D<float2> (";
374 } else {
375 stream << "tex1D<float> (";
376 }
377 } else {
378 if constexpr (complex_scalar<T>) {
379 stream << "to_cmp_double(tex1D<uint4> (";
380 } else {
381 stream << "to_double(tex1D<uint2> (";
382 }
383 }
384 }
385#endif
386 stream << registers[leaf_node<T, SAFE_MATH>::caches.backends[data_hash].data()];
387 if constexpr (jit::use_metal<T> ()) {
388#ifdef USE_INDEX_CACHE
389 stream << ".read("
390 << indices[this->arg.get()]
391 << ").r";
392#else
393 stream << ".read(";
394 compile_index<T> (stream, registers[a.get()], length,
395 scale, offset);
396 stream << ").r";
397#endif
398#ifdef USE_CUDA_TEXTURES
399 } else if constexpr (jit::use_cuda()) {
400#ifdef USE_INDEX_CACHE
401 stream << ", "
402 << indices[this->arg.get()];
403#else
404 stream << ", ";
405 compile_index<T> (stream, registers[a.get()], length,
406 scale, offset);
407#endif
409 stream << ")";
410 }
411 stream << ")";
412#endif
413 } else {
414#ifdef USE_INDEX_CACHE
415 stream << "["
416 << indices[this->arg.get()]
417 << "]";
418#else
419 stream << "[";
420 compile_index<T> (stream, registers[a.get()], length,
421 scale, offset);
422 stream << "]";
423#endif
424 }
425 this->endline(stream, usage);
426 }
427
428 return this->shared_from_this();
429 }
430
431//------------------------------------------------------------------------------
439//------------------------------------------------------------------------------
441 auto x_cast = piecewise_1D_cast(x);
442
443 if (x_cast.get()) {
444 return this->data_hash == x_cast->data_hash &&
445 this->arg->is_match(x_cast->get_arg());
446 }
447
448 return false;
449 }
450
451//------------------------------------------------------------------------------
///  @brief Write a LaTeX fragment for this node to standard output.
///
///  The emitted text is `r\_` followed by this object's address (converted to
///  a size_t) and the subscript `_{i}`. Nothing about the argument or the data
///  buffer is printed.
453//------------------------------------------------------------------------------
454 virtual void to_latex() const {
455 std::cout << "r\\_" << reinterpret_cast<size_t> (this) << "_{i}";
456 }
457
458//------------------------------------------------------------------------------
///  @brief Emit this node and its argument into a graph description stream.
///
///  When this node has no entry in @p registers yet, a name is generated for
///  it, recorded in @p registers, and a node statement plus one edge to the
///  argument are written to @p stream. When the node is already registered
///  nothing is written, so each node is emitted at most once per register map.
///  NOTE(review): the emitted syntax (`[label = ..., shape = hexagon ...]`,
///  `--` edges) looks like Graphviz dot for an undirected graph; confirm
///  against the consumer of this stream.
///
///  @param[in,out] stream    String stream that receives the graph statements.
///  @param[in,out] registers Map from node pointers to the names already
///                           assigned; updated with this node's name.
///  @returns A shared pointer to this node.
464//------------------------------------------------------------------------------
465 virtual shared_leaf<T, SAFE_MATH> to_vizgraph(std::stringstream &stream,
466 jit::register_map &registers) {
//  Only emit the node the first time it is visited.
467 if (registers.find(this) == registers.end()) {
468 const std::string name = jit::to_string('r', this);
469 registers[this] = name;
//  Node statement: the label embeds this object's address with an "i" subscript.
470 stream << " " << name
471 << " [label = \"r_" << reinterpret_cast<size_t> (this)
472 << "_{i}\", shape = hexagon, style = filled, fillcolor = black, fontcolor = white];" << std::endl;
473
//  Emit the argument first so that its name is available for the edge.
474 auto a = this->arg->to_vizgraph(stream, registers);
475 stream << " " << name << " -- " << registers[a.get()] << ";" << std::endl;
476 }
477
478 return this->shared_from_this();
479 }
480
481//------------------------------------------------------------------------------
///  @brief Report whether this node is treated as a constant.
///
///  @returns Always true; the result does not depend on the argument.
485//------------------------------------------------------------------------------
486 virtual bool is_constant() const {
487 return true;
488 }
489
490//------------------------------------------------------------------------------
///  @brief Query the cached data buffer for a zero.
///
///  Forwards to has_zero() on the backend buffer stored in the leaf node cache
///  under this node's data hash.
///  NOTE(review): presumably has_zero() is true when any buffer element is
///  zero; confirm against backend::buffer.
///
///  @returns The result of has_zero() on the cached buffer.
494//------------------------------------------------------------------------------
495 virtual bool has_constant_zero() const {
496 return leaf_node<T, SAFE_MATH>::caches.backends[data_hash].has_zero();
497 }
498
499//------------------------------------------------------------------------------
///  @brief Report whether this node consists only of variables.
///
///  @returns Always false.
503//------------------------------------------------------------------------------
504 virtual bool is_all_variables() const {
505 return false;
506 }
507
508//------------------------------------------------------------------------------
///  @brief Report whether this node can be treated like a power expression.
///
///  NOTE(review): the exact meaning of "power like" is defined by the callers
///  of this predicate, which are not visible here; confirm before relying on
///  it.
///
///  @returns Always true.
512//------------------------------------------------------------------------------
513 virtual bool is_power_like() const {
514 return true;
515 }
516
517//------------------------------------------------------------------------------
521//------------------------------------------------------------------------------
523 return this->shared_from_this();
524 }
525
526//------------------------------------------------------------------------------
530//------------------------------------------------------------------------------
532 return one<T, SAFE_MATH> ();
533 }
534
535//------------------------------------------------------------------------------
540//------------------------------------------------------------------------------
542 auto temp = piecewise_1D_cast(x);
543 return temp.get() &&
544 this->arg->is_match(temp->get_arg()) &&
545 (temp->get_size() == this->get_size()) &&
546 (temp->get_scale() == this->scale) &&
547 (temp->get_offset() == this->offset);
548 }
549
550//------------------------------------------------------------------------------
///  @brief Get the scale used when converting the argument to an index.
///
///  The stored value is the divisor applied to the shifted argument when the
///  lookup index is computed.
///
///  @returns The scale member by value.
554//------------------------------------------------------------------------------
555 T get_scale() const {
556 return scale;
557 }
558
559//------------------------------------------------------------------------------
///  @brief Get the offset used when converting the argument to an index.
///
///  The stored value is combined with the argument before the division by the
///  scale.
///  NOTE(review): the constant folding path in this class computes
///  `(arg + offset)/scale` while compile_index() emits `(arg - offset)/scale`;
///  confirm which sign convention is intended.
///
///  @returns The offset member by value.
563//------------------------------------------------------------------------------
564 T get_offset() const {
565 return offset;
566 }
567
568//------------------------------------------------------------------------------
///  @brief Get the number of entries in the data buffer.
///
///  The buffer is looked up in the leaf node backend cache using this node's
///  data hash.
///
///  @returns The size() of the cached buffer.
572//------------------------------------------------------------------------------
573 size_t get_size() const {
574 return leaf_node<T, SAFE_MATH>::caches.backends[data_hash].size();
575 }
576 };
577
578//------------------------------------------------------------------------------
589//------------------------------------------------------------------------------
590 template<jit::float_scalar T, bool SAFE_MATH=false>
593 const T scale,
594 const T offset) {
595 auto temp = std::make_shared<piecewise_1D_node<T, SAFE_MATH>> (d, x,
596 scale,
597 offset)->reduce();
598// Test for hash collisions.
599 for (size_t i = temp->get_hash(); i < std::numeric_limits<size_t>::max(); i++) {
600 if (leaf_node<T, SAFE_MATH>::caches.nodes.find(i) ==
603 return temp;
604 } else if (temp->is_match(leaf_node<T, SAFE_MATH>::caches.nodes[i])) {
606 }
607 }
608#if defined(__clang__) || defined(__GNUC__)
610#else
611 assert(false && "Should never reach.");
612#endif
613 }
614
///  Convenience alias for a std::shared_ptr to a piecewise_1D_node with the
///  same template arguments.
616 template<jit::float_scalar T, bool SAFE_MATH=false>
617 using shared_piecewise_1D = std::shared_ptr<piecewise_1D_node<T, SAFE_MATH>>;
618
619//------------------------------------------------------------------------------
627//------------------------------------------------------------------------------
628 template<jit::float_scalar T, bool SAFE_MATH=false>
630 return std::dynamic_pointer_cast<piecewise_1D_node<T, SAFE_MATH>> (x);
631 }
632
633//******************************************************************************
634// 2D Piecewise node.
635//******************************************************************************
636//------------------------------------------------------------------------------
676//------------------------------------------------------------------------------
677 template<jit::float_scalar T, bool SAFE_MATH=false>
678 class piecewise_2D_node final : public branch_node<T, SAFE_MATH> {
679 private:
681 const T x_scale;
683 const T x_offset;
685 const T y_scale;
687 const T y_offset;
688
689//------------------------------------------------------------------------------
694//------------------------------------------------------------------------------
695 static std::string to_string(const backend::buffer<T> &d) {
696 std::string temp;
697 for (size_t i = 0, ie = d.size(); i < ie; i++) {
699 }
700
701 return temp;
702 }
703
704//------------------------------------------------------------------------------
711//------------------------------------------------------------------------------
712 static std::string to_string(const backend::buffer<T> &d,
715 return piecewise_2D_node::to_string(d) +
716 jit::format_to_string(x->get_hash()) +
717 jit::format_to_string(y->get_hash());
718 }
719
720//------------------------------------------------------------------------------
725//------------------------------------------------------------------------------
726 static size_t hash_data(const backend::buffer<T> &d) {
727 const size_t h = std::hash<std::string>{} (piecewise_2D_node::to_string(d));
728 for (size_t i = h; i < std::numeric_limits<size_t>::max(); i++) {
729 if (leaf_node<T, SAFE_MATH>::caches.backends.find(i) ==
730 leaf_node<T, SAFE_MATH>::caches.backends.end()) {
732 return i;
733 } else if (d == leaf_node<T, SAFE_MATH>::caches.backends[i]) {
734 return i;
735 }
736 }
737#if defined(__clang__) || defined(__GNUC__)
739#else
740 assert(false && "Should never reach.");
741#endif
742 }
743
745 const size_t data_hash;
747 const size_t num_columns;
748
749 public:
750//------------------------------------------------------------------------------
761//------------------------------------------------------------------------------
763 const size_t n,
765 const T x_scale,
766 const T x_offset,
768 const T y_scale,
769 const T y_offset) :
770 branch_node<T, SAFE_MATH> (x, y, piecewise_2D_node::to_string(d, x, y)),
771 data_hash(piecewise_2D_node::hash_data(d)),
772 num_columns(n), x_scale(x_scale), x_offset(x_offset), y_scale(y_scale),
773 y_offset(y_offset) {
774 assert(d.size()%n == 0 &&
775 "Expected the data buffer to be a multiple of the number of columns.");
776 }
777
778//------------------------------------------------------------------------------
///  @brief Get the number of columns of the 2D table.
///
///  @returns The num_columns member that was set at construction.
782//------------------------------------------------------------------------------
783 size_t get_num_columns() const {
784 return num_columns;
785 }
786
787//------------------------------------------------------------------------------
///  @brief Get the number of rows of the 2D table.
///
///  Computed as the size of the cached data buffer divided by the number of
///  columns. The constructor asserts that the buffer size is a multiple of the
///  column count, so in builds with assertions enabled the division is exact.
///
///  @returns Buffer size divided by num_columns.
791//------------------------------------------------------------------------------
792 size_t get_num_rows() const {
793 return leaf_node<T, SAFE_MATH>::caches.backends[data_hash].size() /
794 num_columns;
795 }
796
797//------------------------------------------------------------------------------
///  @brief Get the scale used for the x (left) argument.
///
///  This value divides the shifted left argument when the row index is
///  computed.
///
///  @returns The x_scale member by value.
801//------------------------------------------------------------------------------
802 T get_x_scale() const {
803 return x_scale;
804 }
805
806//------------------------------------------------------------------------------
///  @brief Get the offset used for the x (left) argument.
///
///  This value is combined with the left argument before the division by the
///  x scale when the row index is computed.
///
///  @returns The x_offset member by value.
810//------------------------------------------------------------------------------
811 T get_x_offset() const {
812 return x_offset;
813 }
814
815//------------------------------------------------------------------------------
///  @brief Get the scale used for the y (right) argument.
///
///  This value divides the shifted right argument when the column index is
///  computed.
///
///  @returns The y_scale member by value.
819//------------------------------------------------------------------------------
820 T get_y_scale() const {
821 return y_scale;
822 }
823
824//------------------------------------------------------------------------------
///  @brief Get the offset used for the y (right) argument.
///
///  This value is combined with the right argument before the division by the
///  y scale when the column index is computed.
///
///  @returns The y_offset member by value.
828//------------------------------------------------------------------------------
829 T get_y_offset() const {
830 return y_offset;
831 }
832
833//------------------------------------------------------------------------------
841//------------------------------------------------------------------------------
843 return leaf_node<T, SAFE_MATH>::caches.backends[data_hash];
844 }
845
846//------------------------------------------------------------------------------
853//------------------------------------------------------------------------------
855 if (constant_cast(this->left).get() &&
856 constant_cast(this->right).get()) {
857 const T l = (this->left->evaluate().at(0) + x_offset)/x_scale;
858 const T r = (this->right->evaluate().at(0) + y_offset)/y_scale;
859
860 if constexpr (jit::float_base<T>) {
861 const size_t i = std::max<float> (std::min<float> (std::real(l),
862 this->get_num_rows() - 1),
863 0);
864 const size_t j = std::max<float> (std::min<float> (std::real(r),
865 this->get_num_columns() - 1),
866 0);
867 return constant<T, SAFE_MATH> (leaf_node<T, SAFE_MATH>::caches.backends[data_hash][i*this->get_num_columns() + j]);
868 } else {
869 const size_t i = std::max<double> (std::min<double> (std::real(l),
870 this->get_num_rows() - 1),
871 0);
872 const size_t j = std::max<double> (std::min<double> (std::real(r),
873 this->get_num_columns() - 1),
874 0);
875 return constant<T, SAFE_MATH> (leaf_node<T, SAFE_MATH>::caches.backends[data_hash][i*this->get_num_columns() + j]);
876 }
877 } else if (constant_cast(this->left).get()) {
878 const T l = (this->left->evaluate().at(0) + x_offset)/x_scale;
879
880 if constexpr (jit::float_base<T>) {
881 const size_t i = std::max<float> (std::min<float> (std::real(l),
882 this->get_num_rows() - 1),
883 0);
884 return piecewise_1D(leaf_node<T, SAFE_MATH>::caches.backends[data_hash].index_row(i, this->get_num_columns()),
885 this->right, y_scale, y_offset);
886 } else {
887 const size_t i = std::max<double> (std::min<double> (std::real(l),
888 this->get_num_rows() - 1),
889 0);
890 return piecewise_1D(leaf_node<T, SAFE_MATH>::caches.backends[data_hash].index_row(i, this->get_num_columns()),
891 this->right, y_scale, y_offset);
892 }
893 } else if (constant_cast(this->right).get()) {
894 const T r = (this->right->evaluate().at(0) + y_offset)/y_scale;
895
896 if constexpr (jit::float_base<T>) {
897 const size_t j = std::max<float> (std::min<float> (std::real(r),
898 this->get_num_columns() - 1),
899 0);
900 return piecewise_1D(leaf_node<T, SAFE_MATH>::caches.backends[data_hash].index_column(j, this->get_num_columns()),
901 this->left, x_scale, x_offset);
902 } else {
903 const size_t j = std::max<double> (std::min<double> (std::real(r),
904 this->get_num_columns() - 1),
905 0);
906 return piecewise_1D(leaf_node<T, SAFE_MATH>::caches.backends[data_hash].index_column(j, this->get_num_columns()),
907 this->left, x_scale, x_offset);
908 }
909 }
910
911 if (evaluate().is_same()) {
912 return constant<T, SAFE_MATH> (evaluate().at(0));
913 }
914
915 return this->shared_from_this();
916 }
917
918//------------------------------------------------------------------------------
923//------------------------------------------------------------------------------
925 return constant<T, SAFE_MATH> (static_cast<T> (this->is_match(x)));
926 }
927
928//------------------------------------------------------------------------------
938//------------------------------------------------------------------------------
939 virtual void compile_preamble(std::ostringstream &stream,
940 jit::register_map &registers,
945 int &avail_const_mem) {
946 if (visited.find(this) == visited.end()) {
947 this->left->compile_preamble(stream, registers,
948 visited, usage,
951 this->right->compile_preamble(stream, registers,
952 visited, usage,
955 if (registers.find(leaf_node<T, SAFE_MATH>::caches.backends[data_hash].data()) == registers.end()) {
956 registers[leaf_node<T, SAFE_MATH>::caches.backends[data_hash].data()] =
957 jit::to_string('a', leaf_node<T, SAFE_MATH>::caches.backends[data_hash].data());
958 const size_t length = leaf_node<T, SAFE_MATH>::caches.backends[data_hash].size();
959 if constexpr (jit::use_metal<T> ()) {
960 textures2d.try_emplace(leaf_node<T, SAFE_MATH>::caches.backends[data_hash].data(),
961 std::array<size_t, 2> ({length/num_columns, num_columns}));
962#ifdef USE_CUDA_TEXTURES
963 } else if constexpr (jit::use_cuda()) {
964 textures2d.try_emplace(leaf_node<T, SAFE_MATH>::caches.backends[data_hash].data(),
965 std::array<size_t, 2> ({length/num_columns, num_columns}));
966#endif
967 } else {
968 if constexpr (jit::use_cuda()) {
969 const int buffer_size = length*sizeof(T);
970 if (avail_const_mem - buffer_size > 0) {
972 stream << "__constant__ ";
973 }
974 }
975 stream << "const ";
976 jit::add_type<T> (stream);
977 stream << " " << registers[leaf_node<T, SAFE_MATH>::caches.backends[data_hash].data()] << "[] = {";
978 if constexpr (jit::complex_scalar<T>) {
979 jit::add_type<T> (stream);
980 }
981 stream << leaf_node<T, SAFE_MATH>::caches.backends[data_hash][0];
982 for (size_t i = 1; i < length; i++) {
983 stream << ", ";
984 if constexpr (jit::complex_scalar<T>) {
985 jit::add_type<T> (stream);
986 }
987 stream << leaf_node<T, SAFE_MATH>::caches.backends[data_hash][i];
988 }
989 stream << "};" << std::endl;
990 }
991 } else {
992// When using textures, the register can be defined in a previous kernel. We
993// need to add the textures again.
994 const size_t length = leaf_node<T, SAFE_MATH>::caches.backends[data_hash].size();
995 if constexpr (jit::use_metal<T> ()) {
996 textures2d.try_emplace(leaf_node<T, SAFE_MATH>::caches.backends[data_hash].data(),
997 std::array<size_t, 2> ({length/num_columns, num_columns}));
998#ifdef USE_CUDA_TEXTURES
999 } else if constexpr (jit::use_cuda()) {
1000 textures2d.try_emplace(leaf_node<T, SAFE_MATH>::caches.backends[data_hash].data(),
1001 std::array<size_t, 2> ({length/num_columns, num_columns}));
1002#endif
1003 }
1004 }
1005 visited.insert(this);
1006#ifdef SHOW_USE_COUNT
1007 usage[this] = 1;
1008 } else {
1009 ++usage[this];
1010#endif
1011 }
1012 }
1013
1014//------------------------------------------------------------------------------
1047//------------------------------------------------------------------------------
1049 compile(std::ostringstream &stream,
1050 jit::register_map &registers,
1052 const jit::register_usage &usage) {
1053 if (registers.find(this) == registers.end()) {
1054 const size_t length = leaf_node<T, SAFE_MATH>::caches.backends[data_hash].size();
1055 const size_t num_rows = length/num_columns;
1056
1057 shared_leaf<T, SAFE_MATH> x = this->left->compile(stream,
1058 registers,
1059 indices,
1060 usage);
1061 shared_leaf<T, SAFE_MATH> y = this->right->compile(stream,
1062 registers,
1063 indices,
1064 usage);
1065
1066#ifdef USE_INDEX_CACHE
1067 if (indices.find(x.get()) == indices.end()) {
1068 indices[x.get()] = jit::to_string('i', x.get());
1069 stream << " const "
1070 << jit::smallest_uint_type<T> (num_rows) << " "
1071 << indices[x.get()] << " = ";
1072 compile_index<T> (stream, registers[x.get()], num_rows,
1073 x_scale, x_offset);
1074 x->endline(stream, usage);
1075 }
1076 if (indices.find(y.get()) == indices.end()) {
1077 indices[y.get()] = jit::to_string('i', y.get());
1078 stream << " const "
1079 << jit::smallest_uint_type<T> (num_columns) << " "
1080 << indices[y.get()] << " = ";
1081 compile_index<T> (stream, registers[y.get()], num_columns,
1082 y_scale, y_offset);
1083 y->endline(stream, usage);
1084 }
1085
1086 auto temp = this->left + this->right;
1087 if constexpr (!jit::use_metal<T> ()
1088#ifdef USE_CUDA_TEXTURES
1089 || !jit::use_cuda()
1090#endif
1091 ) {
1092 if (indices.find(temp.get()) == indices.end()) {
1093 indices[temp.get()] = jit::to_string('i', temp.get());
1094 stream << " const "
1095 << jit::smallest_uint_type<T> (length) << " "
1096 << indices[temp.get()] << " = "
1097 << indices[x.get()]
1098 << "*" << num_columns << " + "
1099 << indices[y.get()]
1100 << ";" << std::endl;
1101 }
1102 }
1103#endif
1104
1105 registers[this] = jit::to_string('r', this);
1106 stream << " const ";
1107 jit::add_type<T> (stream);
1108 stream << " " << registers[this] << " = ";
1109#ifdef USE_CUDA_TEXTURES
1110 if constexpr (jit::use_cuda()) {
1111 if constexpr (jit::float_base<T>) {
1112 if constexpr (complex_scalar<T>) {
1113 stream << "to_cmp_float(tex1D<float2> (";
1114 } else {
1115 stream << "tex1D<float> (";
1116 }
1117 } else {
1118 if constexpr (complex_scalar<T>) {
1119 stream << "to_cmp_double(tex1D<uint4> (";
1120 } else {
1121 stream << "to_double(tex1D<uint2> (";
1122 }
1123 }
1124 }
1125#endif
1126 stream << registers[leaf_node<T, SAFE_MATH>::caches.backends[data_hash].data()];
1127 if constexpr (jit::use_metal<T> ()) {
1128#ifdef USE_INDEX_CACHE
1129 stream << ".read("
1130 << jit::smallest_uint_type<T> (std::max(num_rows,
1131 num_columns))
1132 << "2("
1133 << indices[y.get()]
1134 << ","
1135 << indices[x.get()]
1136 << ")).r";
1137#else
1138 stream << ".read(uint2(";
1139 compile_index<T> (stream, registers[y.get()], num_columns,
1140 y_scale, y_offset);
1141 stream << ",";
1142 compile_index<T> (stream, registers[x.get()], num_rows,
1143 x_scale, x_offset);
1144 stream << ")).r";
1145#endif
1146#ifdef USE_CUDA_TEXTURES
1147 } else if constexpr (jit::use_cuda()) {
1148#ifdef USE_INDEX_CACHE
1149 stream << ", "
1150 << indices[y.get()]
1151 << ", "
1152 << indices[x.get()];
1153#else
1154 stream << ", ";
1155 compile_index<T> (stream, registers[y.get()], num_columns,
1156 y_scale, y_offset);
1157 stream << ", ";
1158 compile_index<T> (stream, registers[x.get()], num_rows,
1159 x_scale, x_offset);
1160#endif
1162 stream << ")";
1163 }
1164 stream << ")";
1165#endif
1166 } else {
1167#ifdef USE_INDEX_CACHE
1168 stream << "["
1169 << indices[temp.get()]
1170 << "]";
1171#else
1172 stream << "[";
1173 compile_index<T> (stream, registers[x.get()], num_rows,
1174 x_scale, x_offset);
1175 stream << "*" << num_columns << " + ";
1176 compile_index<T> (stream, registers[y.get()], num_columns,
1177 y_scale, y_offset);
1178 stream << "]";
1179#endif
1180 }
1181 this->endline(stream, usage);
1182 }
1183
1184 return this->shared_from_this();
1185 }
1186
1187//------------------------------------------------------------------------------
1194//------------------------------------------------------------------------------
1196 auto x_cast = piecewise_2D_cast(x);
1197
1198 if (x_cast.get()) {
1199 return this->data_hash == x_cast->data_hash &&
1200 this->left->is_match(x_cast->get_left()) &&
1201 this->right->is_match(x_cast->get_right());
1202 }
1203
1204 return false;
1205 }
1206
1207//------------------------------------------------------------------------------
///  @brief Write a LaTeX fragment for this node to standard output.
///
///  The emitted text is `r\_` followed by this object's address (converted to
///  a size_t) and the subscript `_{ij}`. Nothing about the arguments or the
///  data buffer is printed.
1211//------------------------------------------------------------------------------
1212 virtual void to_latex() const {
1213 std::cout << "r\\_" << reinterpret_cast<size_t> (this) << "_{ij}";
1214 }
1215
1216//------------------------------------------------------------------------------
///  @brief Emit this node and both of its arguments into a graph description
///         stream.
///
///  When this node has no entry in @p registers yet, a name is generated for
///  it, recorded in @p registers, and a node statement plus one edge to each of
///  the left and right arguments are written to @p stream. When the node is
///  already registered nothing is written.
///  NOTE(review): the emitted syntax (`[label = ..., shape = hexagon ...]`,
///  `--` edges) looks like Graphviz dot for an undirected graph; confirm
///  against the consumer of this stream.
///
///  @param[in,out] stream    String stream that receives the graph statements.
///  @param[in,out] registers Map from node pointers to the names already
///                           assigned; updated with this node's name.
///  @returns A shared pointer to this node.
1222//------------------------------------------------------------------------------
1223 virtual shared_leaf<T, SAFE_MATH> to_vizgraph(std::stringstream &stream,
1224 jit::register_map &registers) {
//  Only emit the node the first time it is visited.
1225 if (registers.find(this) == registers.end()) {
1226 const std::string name = jit::to_string('r', this);
1227 registers[this] = name;
//  Node statement: the label embeds this object's address with an "ij" subscript.
1228 stream << " " << name
1229 << " [label = \"r_" << reinterpret_cast<size_t> (this)
1230 << "_{ij}\", shape = hexagon, style = filled, fillcolor = black, fontcolor = white];" << std::endl;
1231
//  Emit each argument before its edge so that its name is available.
1232 auto l = this->left->to_vizgraph(stream, registers);
1233 stream << " " << name << " -- " << registers[l.get()] << ";" << std::endl;
1234 auto r = this->right->to_vizgraph(stream, registers);
1235 stream << " " << name << " -- " << registers[r.get()] << ";" << std::endl;
1236 }
1237
1238 return this->shared_from_this();
1239 }
1240
1241//------------------------------------------------------------------------------
///  @brief Report whether this node is treated as a constant.
///
///  @returns Always true; the result does not depend on the arguments.
1245//------------------------------------------------------------------------------
1246 virtual bool is_constant() const {
1247 return true;
1248 }
1249
1250//------------------------------------------------------------------------------
///  @brief Query the cached data buffer for a zero.
///
///  Forwards to has_zero() on the backend buffer stored in the leaf node cache
///  under this node's data hash.
///  NOTE(review): presumably has_zero() is true when any buffer element is
///  zero; confirm against backend::buffer.
///
///  @returns The result of has_zero() on the cached buffer.
1254//------------------------------------------------------------------------------
1255 virtual bool has_constant_zero() const {
1256 return leaf_node<T, SAFE_MATH>::caches.backends[data_hash].has_zero();
1257 }
1258
1259//------------------------------------------------------------------------------
1263//------------------------------------------------------------------------------
1264 virtual bool is_all_variables() const {
1265 return false;
1266 }
1267
1268//------------------------------------------------------------------------------
1272//------------------------------------------------------------------------------
1273 virtual bool is_power_like() const {
1274 return true;
1275 }
1276
1277//------------------------------------------------------------------------------
1281//------------------------------------------------------------------------------
1283 return this->shared_from_this();
1284 }
1285
1286//------------------------------------------------------------------------------
1290//------------------------------------------------------------------------------
1292 return one<T, SAFE_MATH> ();
1293 }
1294
1295//------------------------------------------------------------------------------
1300//------------------------------------------------------------------------------
1302 auto temp = piecewise_2D_cast(x);
1303 return temp.get() &&
1304 this->left->is_match(temp->get_left()) &&
1305 this->right->is_match(temp->get_right()) &&
1306 (temp->get_num_rows() == this->get_num_rows()) &&
1307 (temp->get_num_columns() == this->get_num_columns()) &&
1308 (temp->get_x_scale() == this->x_scale) &&
1309 (temp->get_x_offset() == this->x_offset) &&
1310 (temp->get_y_scale() == this->y_scale) &&
1311 (temp->get_y_offset() == this->y_offset);
1312 }
1313
1314//------------------------------------------------------------------------------
1319//------------------------------------------------------------------------------
1321 auto temp = piecewise_1D_cast(x);
1322 return temp.get() &&
1323 this->left->is_match(temp->get_arg()) &&
1324 (temp->get_size() == this->get_num_rows()) &&
1325 (temp->get_scale() == this->x_scale) &&
1326 (temp->get_offset() == this->x_offset);
1327 }
1328
1329//------------------------------------------------------------------------------
1336//------------------------------------------------------------------------------
1338 auto temp = piecewise_1D_cast(x);
1339 return temp.get() &&
1340 this->right->is_match(temp->get_arg()) &&
1341 (temp->get_size() == this->get_num_columns()) &&
1342 (temp->get_scale() == this->y_scale) &&
1343 (temp->get_offset() == this->y_offset);
1344 }
1345 };
1346
1347//------------------------------------------------------------------------------
1362//------------------------------------------------------------------------------
1363 template<jit::float_scalar T, bool SAFE_MATH=false>
1365 const size_t n,
1367 const T x_scale,
1368 const T x_offset,
1370 const T y_scale,
1371 const T y_offset) {
1372 auto temp = std::make_shared<piecewise_2D_node<T, SAFE_MATH>> (d, n,
1373 x, x_scale, x_offset,
1374 y, y_scale, y_offset)->reduce();
1375// Test for hash collisions.
1376 for (size_t i = temp->get_hash(); i < std::numeric_limits<size_t>::max(); i++) {
1377 if (leaf_node<T, SAFE_MATH>::caches.nodes.find(i) ==
1378 leaf_node<T, SAFE_MATH>::caches.nodes.end()) {
1380 return temp;
1381 } else if (temp->is_match(leaf_node<T, SAFE_MATH>::caches.nodes[i])) {
1382 return leaf_node<T, SAFE_MATH>::caches.nodes[i];
1383 }
1384 }
1385#if defined(__clang__) || defined(__GNUC__)
1387#else
1388 assert(false && "Should never reach.");
1389#endif
1390 }
1391
1393 template<jit::float_scalar T, bool SAFE_MATH=false>
1394 using shared_piecewise_2D = std::shared_ptr<piecewise_2D_node<T, SAFE_MATH>>;
1395
1396//------------------------------------------------------------------------------
1404//------------------------------------------------------------------------------
1405 template<jit::float_scalar T, bool SAFE_MATH=false>
1407 return std::dynamic_pointer_cast<piecewise_2D_node<T, SAFE_MATH>> (x);
1408 }
1409
1410//******************************************************************************
1411// 1D Index node.
1412//******************************************************************************
1413//------------------------------------------------------------------------------
1424//------------------------------------------------------------------------------
1425 template<jit::float_scalar T, bool SAFE_MATH=false>
1426 class index_1D_node final : public branch_node<T, SAFE_MATH> {
1427 private:
1429 const T scale;
1431 const T offset;
1432
1433//------------------------------------------------------------------------------
1439//------------------------------------------------------------------------------
1440 static std::string to_string(shared_leaf<T, SAFE_MATH> v,
1442 return jit::format_to_string(v->get_hash()) + "[" +
1443 jit::format_to_string(x->get_hash()) + "]";
1444 }
1445
1446 public:
1447//------------------------------------------------------------------------------
1454//------------------------------------------------------------------------------
1457 const T scale,
1458 const T offset) :
1459 branch_node<T, SAFE_MATH> (var, x, index_1D_node::to_string(var, x)),
1460 scale(scale), offset(offset) {}
1461
1462//------------------------------------------------------------------------------
1470//------------------------------------------------------------------------------
1472 return this->right->evaluate();
1473 }
1474
1475//------------------------------------------------------------------------------
1480//------------------------------------------------------------------------------
1482 return constant<T, SAFE_MATH> (static_cast<T> (this->is_match(x)));
1483 }
1484
1485//------------------------------------------------------------------------------
1498//------------------------------------------------------------------------------
1500 compile(std::ostringstream &stream,
1501 jit::register_map &registers,
1503 const jit::register_usage &usage) {
1504 if (registers.find(this) == registers.end()) {
1505#ifdef USE_INDEX_CACHE
1506 if (indices.find(this->right.get()) == indices.end()) {
1507#endif
1508 const size_t length = variable_cast(this->left)->size();
1509 shared_leaf<T, SAFE_MATH> a = this->right->compile(stream,
1510 registers,
1511 indices,
1512 usage);
1513#ifdef USE_INDEX_CACHE
1514 indices[a.get()] = jit::to_string('i', a.get());
1515 stream << " const "
1516 << jit::smallest_uint_type<T> (length) << " "
1517 << indices[a.get()] << " = ";
1518 compile_index<T> (stream, registers[a.get()], length,
1519 scale, offset);
1520 a->endline(stream, usage);
1521 }
1522#endif
1523
1524 registers[this] = jit::to_string('r', this);
1525 stream << " const ";
1526 jit::add_type<T> (stream);
1527 auto var = this->left->compile(stream,
1528 registers,
1529 indices,
1530 usage);
1531 stream << " " << registers[this] << " = "
1532 << jit::to_string('v', var.get());
1533#ifdef USE_INDEX_CACHE
1534 stream << "[" << indices[this->right.get()] << "]";
1535#else
1536 stream << "[";
1537 compile_index<T> (stream, registers[a.get()], length,
1538 scale, offset);
1539 stream << "]";
1540#endif
1541 this->endline(stream, usage);
1542 }
1543
1544 return this->shared_from_this();
1545 }
1546
1547//------------------------------------------------------------------------------
1549//------------------------------------------------------------------------------
1550 virtual void to_latex() const {
1551 std::cout << "r\\_" << reinterpret_cast<size_t> (this->left.get())
1552 << "\\left[i\\_"
1553 << reinterpret_cast<size_t> (this->right.get())
1554 << "\\right]";
1555 }
1556
1557//------------------------------------------------------------------------------
1563//------------------------------------------------------------------------------
1564 virtual shared_leaf<T, SAFE_MATH> to_vizgraph(std::stringstream &stream,
1565 jit::register_map &registers) {
1566 if (registers.find(this) == registers.end()) {
1567 const std::string name = jit::to_string('r', this);
1568 registers[this] = name;
1569 stream << " " << name
1570 << " [label = \"r_" << reinterpret_cast<size_t> (this->left.get())
1571 << "\", shape = hexagon, style = filled, fillcolor = black, fontcolor = white];" << std::endl;
1572
1573 auto l = this->left->to_vizgraph(stream, registers);
1574 stream << " " << name << " -- " << registers[l.get()] << ";" << std::endl;
1575 auto r = this->right->to_vizgraph(stream, registers);
1576 stream << " " << name << " -- " << registers[r.get()] << ";" << std::endl;
1577 }
1578
1579 return this->shared_from_this();
1580 }
1581
1582//------------------------------------------------------------------------------
1586//------------------------------------------------------------------------------
1587 virtual bool is_constant() const {
1588 return false;
1589 }
1590
1591//------------------------------------------------------------------------------
1595//------------------------------------------------------------------------------
1596 virtual bool is_all_variables() const {
1597 return false;
1598 }
1599
1600//------------------------------------------------------------------------------
1604//------------------------------------------------------------------------------
1605 virtual bool is_power_like() const {
1606 return true;
1607 }
1608
1609//------------------------------------------------------------------------------
1613//------------------------------------------------------------------------------
1615 return one<T, SAFE_MATH> ();
1616 }
1617
1618//------------------------------------------------------------------------------
1623//------------------------------------------------------------------------------
1625 auto temp = index_1D_cast(x);
1626 return temp.get() &&
1627 this->right->is_match(temp->get_right()) &&
1628 (temp->get_size() == this->get_size()) &&
1629 (temp->get_scale() == this->scale) &&
1630 (temp->get_offset() == this->offset);
1631 }
1632
1633//------------------------------------------------------------------------------
1637//------------------------------------------------------------------------------
1638 T get_scale() const {
1639 return scale;
1640 }
1641
1642//------------------------------------------------------------------------------
1646//------------------------------------------------------------------------------
1647 T get_offset() const {
1648 return offset;
1649 }
1650
1651//------------------------------------------------------------------------------
1655//------------------------------------------------------------------------------
1656 size_t get_size() const {
1657 return variable_cast(this->left)->size();
1658 }
1659 };
1660
1661//------------------------------------------------------------------------------
1672//------------------------------------------------------------------------------
1673 template<jit::float_scalar T, bool SAFE_MATH=false>
1676 const T scale,
1677 const T offset) {
1679 "index_1D requires a variable node for first arg.");
1680 auto temp = std::make_shared<index_1D_node<T, SAFE_MATH>> (v, x,
1681 scale,
1682 offset)->reduce();
1683// Test for hash collisions.
1684 for (size_t i = temp->get_hash(); i < std::numeric_limits<size_t>::max(); i++) {
1685 if (leaf_node<T, SAFE_MATH>::caches.nodes.find(i) ==
1686 leaf_node<T, SAFE_MATH>::caches.nodes.end()) {
1688 return temp;
1689 } else if (temp->is_match(leaf_node<T, SAFE_MATH>::caches.nodes[i])) {
1690 return leaf_node<T, SAFE_MATH>::caches.nodes[i];
1691 }
1692 }
1693#if defined(__clang__) || defined(__GNUC__)
1695#else
1696 assert(false && "Should never reach.");
1697#endif
1698 }
1699
1701 template<jit::float_scalar T, bool SAFE_MATH=false>
1702 using shared_index_1D = std::shared_ptr<index_1D_node<T, SAFE_MATH>>;
1703
1704//------------------------------------------------------------------------------
1712//------------------------------------------------------------------------------
1713 template<jit::float_scalar T, bool SAFE_MATH=false>
1715 return std::dynamic_pointer_cast<index_1D_node<T, SAFE_MATH>> (x);
1716 }
1717
1718//******************************************************************************
1719// 2D Index node.
1720//******************************************************************************
1721//------------------------------------------------------------------------------
1733//------------------------------------------------------------------------------
1734 template<jit::float_scalar T, bool SAFE_MATH=false>
1735 class index_2D_node final : public triple_node<T, SAFE_MATH> {
1736 private:
1738 const T x_scale;
1740 const T x_offset;
1742 const T y_scale;
1744 const T y_offset;
1746 const size_t num_columns;
1747
1748//------------------------------------------------------------------------------
1755//------------------------------------------------------------------------------
1756 static std::string to_string(shared_leaf<T, SAFE_MATH> v,
1759 return jit::format_to_string(v->get_hash()) + "[" +
1760 jit::format_to_string(x->get_hash()) + "," +
1761 jit::format_to_string(y->get_hash()) + "]";
1762 }
1763
1764 public:
1765//------------------------------------------------------------------------------
1776//------------------------------------------------------------------------------
1778 const size_t n,
1780 const T x_scale,
1781 const T x_offset,
1783 const T y_scale,
1784 const T y_offset) :
1785 triple_node<T, SAFE_MATH> (var, x, y, index_2D_node::to_string(var, x, y)),
1786 num_columns(n), x_scale(x_scale), x_offset(x_offset), y_scale(y_scale),
1787 y_offset(y_offset) {
1788 assert(variable_cast(this->left)->size()%n == 0 &&
1789 "Expected the data buffer to be a multiple of the number of columns.");
1790 }
1791
1792//------------------------------------------------------------------------------
1800//------------------------------------------------------------------------------
1802 return this->left->evaluate();
1803 }
1804
1805//------------------------------------------------------------------------------
1810//------------------------------------------------------------------------------
1812 return constant<T, SAFE_MATH> (static_cast<T> (this->is_match(x)));
1813 }
1814
1815//------------------------------------------------------------------------------
1829//------------------------------------------------------------------------------
1831 compile(std::ostringstream &stream,
1832 jit::register_map &registers,
1834 const jit::register_usage &usage) {
1835 if (registers.find(this) == registers.end()) {
1836 const size_t length = variable_cast(this->left)->size();
1837 const size_t num_rows = length/num_columns;
1838
1839 shared_leaf<T, SAFE_MATH> x = this->middle->compile(stream,
1840 registers,
1841 indices,
1842 usage);
1843 shared_leaf<T, SAFE_MATH> y = this->right->compile(stream,
1844 registers,
1845 indices,
1846 usage);
1847
1848#ifdef USE_INDEX_CACHE
1849 if (indices.find(x.get()) == indices.end()) {
1850 indices[x.get()] = jit::to_string('i', x.get());
1851 stream << " const "
1852 << jit::smallest_uint_type<T> (num_rows) << " "
1853 << indices[x.get()] << " = ";
1854 compile_index<T> (stream, registers[x.get()], num_rows,
1855 x_scale, x_offset);
1856 x->endline(stream, usage);
1857 }
1858 if (indices.find(y.get()) == indices.end()) {
1859 indices[y.get()] = jit::to_string('i', y.get());
1860 stream << " const "
1861 << jit::smallest_uint_type<T> (num_columns) << " "
1862 << indices[y.get()] << " = ";
1863 compile_index<T> (stream, registers[y.get()], num_columns,
1864 y_scale, y_offset);
1865 y->endline(stream, usage);
1866 }
1867
1868 auto temp = this->middle + this->right;
1869 if constexpr (!jit::use_metal<T> () ||
1870 !jit::use_cuda()) {
1871 if (indices.find(temp.get()) == indices.end()) {
1872 indices[temp.get()] = jit::to_string('i', temp.get());
1873 stream << " const "
1874 << jit::smallest_uint_type<T> (length) << " "
1875 << indices[temp.get()] << " = "
1876 << indices[x.get()]
1877 << "*" << num_columns << " + "
1878 << indices[y.get()]
1879 << ";" << std::endl;
1880 }
1881 }
1882#endif
1883
1884 registers[this] = jit::to_string('r', this);
1885 stream << " const ";
1886 jit::add_type<T> (stream);
1887 auto var = this->left->compile(stream,
1888 registers,
1889 indices,
1890 usage);
1891 stream << " " << registers[this] << " = "
1892 << jit::to_string('v', var.get());
1893#ifdef USE_INDEX_CACHE
1894 stream << "["
1895 << indices[temp.get()]
1896 << "]";
1897#else
1898 stream << "[";
1899 compile_index<T> (stream, registers[x.get()], num_rows,
1900 x_scale, x_offset);
1901 stream << "*" << num_columns << " + ";
1902 compile_index<T> (stream, registers[y.get()], num_columns,
1903 y_scale, y_offset);
1904 stream << "]";
1905#endif
1906 this->endline(stream, usage);
1907 }
1908
1909 return this->shared_from_this();
1910 }
1911
1912//------------------------------------------------------------------------------
1914//------------------------------------------------------------------------------
1915 virtual void to_latex() const {
1916 std::cout << "r\\_" << reinterpret_cast<size_t> (this->left.get())
1917 << "\\left[i\\_"
1918 << reinterpret_cast<size_t> (this->middle.get())
1919 << ",j\\_"
1920 << reinterpret_cast<size_t> (this->right.get())
1921 << "\\right]";
1922 }
1923
1924//------------------------------------------------------------------------------
1930//------------------------------------------------------------------------------
1931 virtual shared_leaf<T, SAFE_MATH> to_vizgraph(std::stringstream &stream,
1932 jit::register_map &registers) {
1933 if (registers.find(this) == registers.end()) {
1934 const std::string name = jit::to_string('r', this);
1935 registers[this] = name;
1936 stream << " " << name
1937 << " [label = \"r_" << reinterpret_cast<size_t> (this->left.get())
1938 << "\", shape = hexagon, style = filled, fillcolor = black, fontcolor = white];" << std::endl;
1939
1940 auto l = this->left->to_vizgraph(stream, registers);
1941 stream << " " << name << " -- " << registers[l.get()] << ";" << std::endl;
1942 auto m = this->middle->to_vizgraph(stream, registers);
1943 stream << " " << name << " -- " << registers[m.get()] << ";" << std::endl;
1944 auto r = this->right->to_vizgraph(stream, registers);
1945 stream << " " << name << " -- " << registers[r.get()] << ";" << std::endl;
1946 }
1947
1948 return this->shared_from_this();
1949 }
1950
1951//------------------------------------------------------------------------------
1955//------------------------------------------------------------------------------
1956 virtual bool is_constant() const {
1957 return false;
1958 }
1959
1960//------------------------------------------------------------------------------
1964//------------------------------------------------------------------------------
1965 virtual bool is_all_variables() const {
1966 return false;
1967 }
1968
1969//------------------------------------------------------------------------------
1973//------------------------------------------------------------------------------
1974 virtual bool is_power_like() const {
1975 return true;
1976 }
1977
1978//------------------------------------------------------------------------------
1982//------------------------------------------------------------------------------
1984 return one<T, SAFE_MATH> ();
1985 }
1986
1987//------------------------------------------------------------------------------
1992//------------------------------------------------------------------------------
1994 auto temp = index_1D_cast(x);
1995 return temp.get() &&
1996 this->right->is_match(temp->get_right()) &&
1997 (temp->get_size() == this->get_size()) &&
1998 (temp->get_scale() == this->scale) &&
1999 (temp->get_offset() == this->offset);
2000 }
2001
2002//------------------------------------------------------------------------------
2006//------------------------------------------------------------------------------
2007 T get_x_scale() const {
2008 return x_scale;
2009 }
2010
2011//------------------------------------------------------------------------------
2015//------------------------------------------------------------------------------
2016 T get_x_offset() const {
2017 return x_offset;
2018 }
2019
2020//------------------------------------------------------------------------------
2024//------------------------------------------------------------------------------
2025 T get_y_scale() const {
2026 return y_scale;
2027 }
2028
2029//------------------------------------------------------------------------------
2033//------------------------------------------------------------------------------
2034 T get_y_offset() const {
2035 return y_offset;
2036 }
2037
2038//------------------------------------------------------------------------------
2042//------------------------------------------------------------------------------
2043 size_t get_size() const {
2044 return variable_cast(this->left)->size();
2045 }
2046 };
2047
2048//------------------------------------------------------------------------------
2063//------------------------------------------------------------------------------
2064 template<jit::float_scalar T, bool SAFE_MATH=false>
2066 const size_t n,
2068 const T x_scale,
2069 const T x_offset,
2071 const T y_scale,
2072 const T y_offset) {
2074 "index_2D requires a variable node for first arg.");
2075 auto temp = std::make_shared<index_2D_node<T, SAFE_MATH>> (v, n,
2076 x, x_scale, x_offset,
2077 y, y_scale, y_offset)->reduce();
2078// Test for hash collisions.
2079 for (size_t i = temp->get_hash(); i < std::numeric_limits<size_t>::max(); i++) {
2080 if (leaf_node<T, SAFE_MATH>::caches.nodes.find(i) ==
2081 leaf_node<T, SAFE_MATH>::caches.nodes.end()) {
2083 return temp;
2084 } else if (temp->is_match(leaf_node<T, SAFE_MATH>::caches.nodes[i])) {
2085 return leaf_node<T, SAFE_MATH>::caches.nodes[i];
2086 }
2087 }
2088#if defined(__clang__) || defined(__GNUC__)
2090#else
2091 assert(false && "Should never reach.");
2092#endif
2093 }
2094
2096 template<jit::float_scalar T, bool SAFE_MATH=false>
2097 using shared_index_2D = std::shared_ptr<index_2D_node<T, SAFE_MATH>>;
2098
2099//------------------------------------------------------------------------------
2107//------------------------------------------------------------------------------
2108 template<jit::float_scalar T, bool SAFE_MATH=false>
2110 return std::dynamic_pointer_cast<index_2D_node<T, SAFE_MATH>> (x);
2111 }
2112}
2113
2114#endif /* piecewise_h */
Class representing a generic buffer.
Definition backend.hpp:29
Class representing a branch node.
Definition node.hpp:1165
shared_leaf< T, SAFE_MATH > right
Right branch of the tree.
Definition node.hpp:1170
shared_leaf< T, SAFE_MATH > left
Left branch of the tree.
Definition node.hpp:1168
Class representing a 1D index.
Definition piecewise.hpp:1426
virtual shared_leaf< T, SAFE_MATH > get_power_exponent() const
Get the exponent of a power.
Definition piecewise.hpp:1614
virtual backend::buffer< T > evaluate()
Evaluate the results of the piecewise constant.
Definition piecewise.hpp:1471
size_t get_size() const
Get the size of the buffer.
Definition piecewise.hpp:1656
virtual bool is_power_like() const
Test if the node acts like a power of variable.
Definition piecewise.hpp:1605
virtual shared_leaf< T, SAFE_MATH > compile(std::ostringstream &stream, jit::register_map &registers, jit::register_map &indices, const jit::register_usage &usage)
Compile the node.
Definition piecewise.hpp:1500
T get_scale() const
Get x argument scale.
Definition piecewise.hpp:1638
bool is_arg_match(shared_leaf< T, SAFE_MATH > x)
Check if the args match.
Definition piecewise.hpp:1624
T get_offset() const
Get x argument offset.
Definition piecewise.hpp:1647
virtual shared_leaf< T, SAFE_MATH > to_vizgraph(std::stringstream &stream, jit::register_map &registers)
Convert the node to vizgraph.
Definition piecewise.hpp:1564
index_1D_node(shared_leaf< T, SAFE_MATH > var, shared_leaf< T, SAFE_MATH > x, const T scale, const T offset)
Construct a 1D index.
Definition piecewise.hpp:1455
virtual bool is_all_variables() const
Test if node acts like a variable.
Definition piecewise.hpp:1596
virtual shared_leaf< T, SAFE_MATH > df(shared_leaf< T, SAFE_MATH > x)
Transform node to derivative.
Definition piecewise.hpp:1481
virtual bool is_constant() const
Test if node is a constant.
Definition piecewise.hpp:1587
virtual void to_latex() const
Convert the node to latex.
Definition piecewise.hpp:1550
Class representing a 2D index.
Definition piecewise.hpp:1735
size_t get_size() const
Get the size of the buffer.
Definition piecewise.hpp:2043
virtual shared_leaf< T, SAFE_MATH > df(shared_leaf< T, SAFE_MATH > x)
Transform node to derivative.
Definition piecewise.hpp:1811
virtual bool is_power_like() const
Test if the node acts like a power of variable.
Definition piecewise.hpp:1974
virtual backend::buffer< T > evaluate()
Evaluate the results of the piecewise constant.
Definition piecewise.hpp:1801
virtual void to_latex() const
Convert the node to latex.
Definition piecewise.hpp:1915
T get_x_offset() const
Get x argument offset.
Definition piecewise.hpp:2016
virtual shared_leaf< T, SAFE_MATH > compile(std::ostringstream &stream, jit::register_map &registers, jit::register_map &indices, const jit::register_usage &usage)
Compile the node.
Definition piecewise.hpp:1831
T get_y_scale() const
Get y argument scale.
Definition piecewise.hpp:2025
virtual shared_leaf< T, SAFE_MATH > get_power_exponent() const
Get the exponent of a power.
Definition piecewise.hpp:1983
bool is_arg_match(shared_leaf< T, SAFE_MATH > x)
Check if the args match.
Definition piecewise.hpp:1993
T get_x_scale() const
Get x argument scale.
Definition piecewise.hpp:2007
virtual shared_leaf< T, SAFE_MATH > to_vizgraph(std::stringstream &stream, jit::register_map &registers)
Convert the node to vizgraph.
Definition piecewise.hpp:1931
virtual bool is_all_variables() const
Test if node acts like a variable.
Definition piecewise.hpp:1965
T get_y_offset() const
Get y argument offset.
Definition piecewise.hpp:2034
index_2D_node(shared_leaf< T, SAFE_MATH > var, const size_t n, shared_leaf< T, SAFE_MATH > x, const T x_scale, const T x_offset, shared_leaf< T, SAFE_MATH > y, const T y_scale, const T y_offset)
Construct a 2D index.
Definition piecewise.hpp:1777
virtual bool is_constant() const
Test if node is a constant.
Definition piecewise.hpp:1956
Class representing a node leaf.
Definition node.hpp:364
virtual void endline(std::ostringstream &stream, const jit::register_usage &usage) const final
End a line in the kernel source.
Definition node.hpp:639
virtual bool is_match(std::shared_ptr< leaf_node< T, SAFE_MATH > > x)
Query if the nodes match.
Definition node.hpp:472
virtual std::shared_ptr< leaf_node< T, SAFE_MATH > > compile(std::ostringstream &stream, jit::register_map &registers, jit::register_map &indices, const jit::register_usage &usage)=0
Compile the node.
Class representing a 1D piecewise constant.
Definition piecewise.hpp:105
virtual void compile_preamble(std::ostringstream &stream, jit::register_map &registers, jit::visiter_map &visited, jit::register_usage &usage, jit::texture1d_list &textures1d, jit::texture2d_list &textures2d, int &avail_const_mem)
Compile preamble.
Definition piecewise.hpp:248
virtual bool is_power_like() const
Test if the node acts like a power of variable.
Definition piecewise.hpp:513
piecewise_1D_node(const backend::buffer< T > &d, shared_leaf< T, SAFE_MATH > x, const T scale, const T offset)
Construct 1D a piecewise constant node.
Definition piecewise.hpp:176
virtual bool is_constant() const
Test if node is a constant.
Definition piecewise.hpp:486
virtual bool is_match(shared_leaf< T, SAFE_MATH > x)
Query if the nodes match.
Definition piecewise.hpp:440
virtual shared_leaf< T, SAFE_MATH > get_power_exponent() const
Get the exponent of a power.
Definition piecewise.hpp:531
virtual shared_leaf< T, SAFE_MATH > reduce()
Reduction method.
Definition piecewise.hpp:205
virtual backend::buffer< T > evaluate()
Evaluate the results of the piecewise constant.
Definition piecewise.hpp:193
virtual shared_leaf< T, SAFE_MATH > df(shared_leaf< T, SAFE_MATH > x)
Transform node to derivative.
Definition piecewise.hpp:233
T get_scale() const
Get x argument scale.
Definition piecewise.hpp:555
virtual void to_latex() const
Convert the node to latex.
Definition piecewise.hpp:454
virtual shared_leaf< T, SAFE_MATH > get_power_base()
Get the base of a power.
Definition piecewise.hpp:522
size_t get_size() const
Get the size of the buffer.
Definition piecewise.hpp:573
virtual bool has_constant_zero() const
Test the constant node has a zero.
Definition piecewise.hpp:495
virtual shared_leaf< T, SAFE_MATH > to_vizgraph(std::stringstream &stream, jit::register_map &registers)
Convert the node to vizgraph.
Definition piecewise.hpp:465
virtual shared_leaf< T, SAFE_MATH > compile(std::ostringstream &stream, jit::register_map &registers, jit::register_map &indices, const jit::register_usage &usage)
Compile the node.
Definition piecewise.hpp:341
T get_offset() const
Get x argument offset.
Definition piecewise.hpp:564
virtual bool is_all_variables() const
Test if node acts like a variable.
Definition piecewise.hpp:504
bool is_arg_match(shared_leaf< T, SAFE_MATH > x)
Check if the args match.
Definition piecewise.hpp:541
Class representing a 2D piecewise constant.
Definition piecewise.hpp:678
T get_y_scale() const
Get y argument scale.
Definition piecewise.hpp:820
virtual shared_leaf< T, SAFE_MATH > df(shared_leaf< T, SAFE_MATH > x)
Transform node to derivative.
Definition piecewise.hpp:924
bool is_arg_match(shared_leaf< T, SAFE_MATH > x)
Check if the args match.
Definition piecewise.hpp:1301
T get_y_offset() const
Get y argument offset.
Definition piecewise.hpp:829
virtual bool is_power_like() const
Test if the node acts like a power of variable.
Definition piecewise.hpp:1273
T get_x_scale() const
Get x argument scale.
Definition piecewise.hpp:802
virtual shared_leaf< T, SAFE_MATH > compile(std::ostringstream &stream, jit::register_map &registers, jit::register_map &indices, const jit::register_usage &usage)
Compile the node.
Definition piecewise.hpp:1049
size_t get_num_columns() const
Get the number of columns.
Definition piecewise.hpp:783
virtual shared_leaf< T, SAFE_MATH > get_power_base()
Get the base of a power.
Definition piecewise.hpp:1282
virtual bool is_match(shared_leaf< T, SAFE_MATH > x)
Query if the nodes match.
Definition piecewise.hpp:1195
virtual bool is_all_variables() const
Test if node acts like a variable.
Definition piecewise.hpp:1264
virtual bool is_constant() const
Test if node is a constant.
Definition piecewise.hpp:1246
size_t get_num_rows() const
Get the number of rows.
Definition piecewise.hpp:792
bool is_col_match(shared_leaf< T, SAFE_MATH > x)
Do the columns match.
Definition piecewise.hpp:1337
T get_x_offset() const
Get x argument offset.
Definition piecewise.hpp:811
virtual void to_latex() const
Convert the node to latex.
Definition piecewise.hpp:1212
virtual shared_leaf< T, SAFE_MATH > to_vizgraph(std::stringstream &stream, jit::register_map &registers)
Convert the node to vizgraph.
Definition piecewise.hpp:1223
bool is_row_match(shared_leaf< T, SAFE_MATH > x)
Do the rows match.
Definition piecewise.hpp:1320
virtual shared_leaf< T, SAFE_MATH > get_power_exponent() const
Get the exponent of a power.
Definition piecewise.hpp:1291
virtual shared_leaf< T, SAFE_MATH > reduce()
Reduction method.
Definition piecewise.hpp:854
piecewise_2D_node(const backend::buffer< T > &d, const size_t n, shared_leaf< T, SAFE_MATH > x, const T x_scale, const T x_offset, shared_leaf< T, SAFE_MATH > y, const T y_scale, const T y_offset)
Construct 2D a piecewise constant node.
Definition piecewise.hpp:762
virtual backend::buffer< T > evaluate()
Evaluate the results of the piecewise constant.
Definition piecewise.hpp:842
virtual bool has_constant_zero() const
Test the constant node has a zero.
Definition piecewise.hpp:1255
virtual void compile_preamble(std::ostringstream &stream, jit::register_map &registers, jit::visiter_map &visited, jit::register_usage &usage, jit::texture1d_list &textures1d, jit::texture2d_list &textures2d, int &avail_const_mem)
Compile preamble.
Definition piecewise.hpp:939
Class representing a straight node.
Definition node.hpp:1051
shared_leaf< T, SAFE_MATH > arg
Argument.
Definition node.hpp:1054
Class representing a triple branch node.
Definition node.hpp:1289
shared_leaf< T, SAFE_MATH > middle
Middle branch of the tree.
Definition node.hpp:1292
Complex scalar concept.
Definition register.hpp:24
Double base concept.
Definition register.hpp:42
Float base concept.
Definition register.hpp:37
subroutine assert(test, message)
Assert check.
Definition f_binding_test.f90:38
Name space for graph nodes.
Definition arithmetic.hpp:13
shared_piecewise_2D< T, SAFE_MATH > piecewise_2D_cast(shared_leaf< T, SAFE_MATH > x)
Cast to a piecewise 2D node.
Definition piecewise.hpp:1406
constexpr shared_leaf< T, SAFE_MATH > zero()
Forward declare for zero.
Definition node.hpp:986
shared_index_2D< T, SAFE_MATH > index_2D_cast(shared_leaf< T, SAFE_MATH > x)
Cast to an index 2D node.
Definition piecewise.hpp:2109
std::shared_ptr< index_2D_node< T, SAFE_MATH > > shared_index_2D
Convenience type alias for shared index 2D nodes.
Definition piecewise.hpp:2097
shared_piecewise_1D< T, SAFE_MATH > piecewise_1D_cast(shared_leaf< T, SAFE_MATH > x)
Cast to a piecewise 1D node.
Definition piecewise.hpp:629
std::shared_ptr< piecewise_2D_node< T, SAFE_MATH > > shared_piecewise_2D
Convenience type alias for shared piecewise 2D nodes.
Definition piecewise.hpp:1394
shared_constant< T, SAFE_MATH > constant_cast(shared_leaf< T, SAFE_MATH > x)
Cast to a constant node.
Definition node.hpp:1034
shared_variable< T, SAFE_MATH > variable_cast(shared_leaf< T, SAFE_MATH > x)
Cast to a variable node.
Definition node.hpp:1727
std::shared_ptr< piecewise_1D_node< T, SAFE_MATH > > shared_piecewise_1D
Convenience type alias for shared piecewise 1D nodes.
Definition piecewise.hpp:617
constexpr T i
Convenience type for imaginary constant.
Definition node.hpp:1018
std::shared_ptr< leaf_node< T, SAFE_MATH > > shared_leaf
Convenience type alias for shared leaf nodes.
Definition node.hpp:676
std::shared_ptr< index_1D_node< T, SAFE_MATH > > shared_index_1D
Convenience type alias for shared index 1D nodes.
Definition piecewise.hpp:1702
void compile_index(std::ostringstream &stream, const std::string &register_name, const size_t length, const T scale, const T offset)
Compile an index.
Definition piecewise.hpp:26
shared_index_1D< T, SAFE_MATH > index_1D_cast(shared_leaf< T, SAFE_MATH > x)
Cast to an index 1D node.
Definition piecewise.hpp:1714
std::map< void *, size_t > texture1d_list
Type alias for indexing 1D textures.
Definition register.hpp:263
std::string format_to_string(const T value)
Convert a value to a string while avoiding locale.
Definition register.hpp:212
std::map< void *, std::array< size_t, 2 > > texture2d_list
Type alias for indexing 2D textures.
Definition register.hpp:265
std::map< void *, size_t > register_usage
Type alias for counting register usage.
Definition register.hpp:259
std::map< void *, std::string > register_map
Type alias for mapping node pointers to register names.
Definition register.hpp:257
constexpr bool use_cuda()
Test to use Cuda.
Definition register.hpp:67
std::set< void * > visiter_map
Type alias for listing visited nodes.
Definition register.hpp:261
std::string to_string(const char prefix, const NODE *pointer)
Convert a graph::leaf_node pointer to a string.
Definition register.hpp:246
Base nodes of graph computation framework.
void index_2D()
Tests for 2D index nodes.
Definition piecewise_test.cpp:834
void piecewise_1D()
Tests for 1D piecewise nodes.
Definition piecewise_test.cpp:80
void index_1D()
Tests for 1D index nodes.
Definition piecewise_test.cpp:807
void piecewise_2D()
Tests for 2D piecewise nodes.
Definition piecewise_test.cpp:306