Graph Framework
Loading...
Searching...
No Matches
piecewise.hpp
Go to the documentation of this file.
1//------------------------------------------------------------------------------
6//------------------------------------------------------------------------------
7
8#ifndef piecewise_h
9#define piecewise_h
10
11#include "node.hpp"
12
13namespace graph {
14//------------------------------------------------------------------------------
24//------------------------------------------------------------------------------
25template<jit::float_scalar T>
26void compile_index(std::ostringstream &stream,
27 const std::string &register_name,
28 const size_t length,
29 const T scale,
30 const T offset) {
31 const std::string type = jit::smallest_int_type<T> (length);
32 stream << "min(max(("
33 << type
34 << ")";
35 if constexpr (jit::complex_scalar<T>) {
36 stream << "real(";
37 }
38 stream << "((" << register_name << " - ";
39 if constexpr (jit::complex_scalar<T>) {
41 }
42 stream << offset << ")/";
43 if constexpr (jit::complex_scalar<T>) {
45 }
46 stream << scale << ")";
47 if constexpr (jit::complex_scalar<T>) {
48 stream << ")";
49 }
50 stream << ",(" << type << ")0),("
51 << type << ")" << length - 1 << ")";
52}
53
54//******************************************************************************
55// 1D Piecewise node.
56//******************************************************************************
57//------------------------------------------------------------------------------
90//------------------------------------------------------------------------------
91 template<jit::float_scalar T, bool SAFE_MATH=false>
92 class piecewise_1D_node final : public straight_node<T, SAFE_MATH> {
94 const T scale;
96 const T offset;
97
98 private:
99//------------------------------------------------------------------------------
104//------------------------------------------------------------------------------
105 static std::string to_string(const backend::buffer<T> &d) {
106 std::string temp;
107 for (size_t i = 0, ie = d.size(); i < ie; i++) {
109 }
110
111 return temp;
112 }
113
114//------------------------------------------------------------------------------
120//------------------------------------------------------------------------------
121 static std::string to_string(const backend::buffer<T> &d,
123 return piecewise_1D_node::to_string(d) +
124 jit::format_to_string(x->get_hash());
125 }
126
127//------------------------------------------------------------------------------
132//------------------------------------------------------------------------------
133 static size_t hash_data(const backend::buffer<T> &d) {
134 const size_t h = std::hash<std::string>{} (piecewise_1D_node::to_string(d));
135 for (size_t i = h; i < std::numeric_limits<size_t>::max(); i++) {
136 if (leaf_node<T, SAFE_MATH>::caches.backends.find(i) ==
137 leaf_node<T, SAFE_MATH>::caches.backends.end()) {
139 return i;
140 } else if (d == leaf_node<T, SAFE_MATH>::caches.backends[i]) {
141 return i;
142 }
143 }
144#if defined(__clang__) || defined(__GNUC__)
146#else
147 assert(false && "Should never reach.");
148#endif
149 }
150
152 const size_t data_hash;
153
154 public:
155//------------------------------------------------------------------------------
162//------------------------------------------------------------------------------
165 const T scale,
166 const T offset) :
167 straight_node<T, SAFE_MATH> (x, piecewise_1D_node::to_string(d, x)),
168 data_hash(piecewise_1D_node::hash_data(d)), scale(scale),
169 offset(offset) {}
170
171//------------------------------------------------------------------------------
179//------------------------------------------------------------------------------
181 return leaf_node<T, SAFE_MATH>::caches.backends[data_hash];
182 }
183
184//------------------------------------------------------------------------------
191//------------------------------------------------------------------------------
193 if (evaluate().is_same()) {
194 return constant<T, SAFE_MATH> (evaluate().at(0));
195 }
196 return this->shared_from_this();
197 }
198
199//------------------------------------------------------------------------------
204//------------------------------------------------------------------------------
206 return constant<T, SAFE_MATH> (static_cast<T> (this->is_match(x)));
207 }
208
209//------------------------------------------------------------------------------
219//------------------------------------------------------------------------------
220 virtual void compile_preamble(std::ostringstream &stream,
221 jit::register_map &registers,
226 int &avail_const_mem) {
227 if (visited.find(this) == visited.end()) {
228 this->arg->compile_preamble(stream, registers,
229 visited, usage,
232 if (registers.find(leaf_node<T, SAFE_MATH>::caches.backends[data_hash].data()) == registers.end()) {
233 registers[leaf_node<T, SAFE_MATH>::caches.backends[data_hash].data()] =
234 jit::to_string('a', leaf_node<T, SAFE_MATH>::caches.backends[data_hash].data());
235 const size_t length = leaf_node<T, SAFE_MATH>::caches.backends[data_hash].size();
236 if constexpr (jit::use_metal<T> ()) {
237 textures1d.try_emplace(leaf_node<T, SAFE_MATH>::caches.backends[data_hash].data(),
238 length);
239#ifdef USE_CUDA_TEXTURES
240 } else if constexpr (jit::use_cuda()) {
241 textures1d.try_emplace(leaf_node<T, SAFE_MATH>::caches.backends[data_hash].data(),
242 length);
243#endif
244 } else {
245 if constexpr (jit::use_cuda()) {
246 const int buffer_size = length*sizeof(T);
247 if (avail_const_mem - buffer_size > 0) {
249 stream << "__constant__ ";
250 }
251 }
252 stream << "const ";
253 jit::add_type<T> (stream);
254 stream << " " << registers[leaf_node<T, SAFE_MATH>::caches.backends[data_hash].data()] << "[] = {";
255 if constexpr (jit::complex_scalar<T>) {
256 jit::add_type<T> (stream);
257 }
258 stream << leaf_node<T, SAFE_MATH>::caches.backends[data_hash][0];
259 for (size_t i = 1; i < length; i++) {
260 stream << ", ";
261 if constexpr (jit::complex_scalar<T>) {
262 jit::add_type<T> (stream);
263 }
264 stream << leaf_node<T, SAFE_MATH>::caches.backends[data_hash][i];
265 }
266 stream << "};" << std::endl;
267 }
268 } else {
269// When using textures, the register can be defined in a previous kernel. We
270// need to add the textures again.
271 const size_t length = leaf_node<T, SAFE_MATH>::caches.backends[data_hash].size();
272 if constexpr (jit::use_metal<T> ()) {
273 textures1d.try_emplace(leaf_node<T, SAFE_MATH>::caches.backends[data_hash].data(),
274 length);
275#ifdef USE_CUDA_TEXTURES
276 } else if constexpr (jit::use_cuda()) {
277 textures1d.try_emplace(leaf_node<T, SAFE_MATH>::caches.backends[data_hash].data(),
278 length);
279#endif
280 }
281 }
282 visited.insert(this);
283#ifdef SHOW_USE_COUNT
284 usage[this] = 1;
285 } else {
286 ++usage[this];
287#endif
288 }
289 }
290
291//------------------------------------------------------------------------------
311//------------------------------------------------------------------------------
313 compile(std::ostringstream &stream,
314 jit::register_map &registers,
316 const jit::register_usage &usage) {
317 if (registers.find(this) == registers.end()) {
318#ifdef USE_INDEX_CACHE
319 if (indices.find(this->arg.get()) == indices.end()) {
320#endif
321 const size_t length = leaf_node<T, SAFE_MATH>::caches.backends[data_hash].size();
322 shared_leaf<T, SAFE_MATH> a = this->arg->compile(stream,
323 registers,
324 indices,
325 usage);
326#ifdef USE_INDEX_CACHE
327 indices[a.get()] = jit::to_string('i', a.get());
328 stream << " const "
329 << jit::smallest_int_type<T> (length) << " "
330 << indices[a.get()] << " = ";
331 compile_index<T> (stream, registers[a.get()], length,
332 scale, offset);
333 a->endline(stream, usage);
334 }
335#endif
336
337 registers[this] = jit::to_string('r', this);
338 stream << " const ";
339 jit::add_type<T> (stream);
340 stream << " " << registers[this] << " = ";
341#ifdef USE_CUDA_TEXTURES
342 if constexpr (jit::use_cuda()) {
343 if constexpr (float_base<T>) {
344 if constexpr (complex_scalar<T>) {
345 stream << "to_cmp_float(tex1D<float2> (";
346 } else {
347 stream << "tex1D<float> (";
348 }
349 } else {
350 if constexpr (complex_scalar<T>) {
351 stream << "to_cmp_double(tex1D<uint4> (";
352 } else {
353 stream << "to_double(tex1D<uint2> (";
354 }
355 }
356 }
357#endif
358 stream << registers[leaf_node<T, SAFE_MATH>::caches.backends[data_hash].data()];
359 if constexpr (jit::use_metal<T> ()) {
360#ifdef USE_INDEX_CACHE
361 stream << ".read("
362 << indices[this->arg.get()]
363 << ").r";
364#else
365 stream << ".read(";
366 compile_index<T> (stream, registers[a.get()], length,
367 scale, offset);
368 stream << ").r";
369#endif
370#ifdef USE_CUDA_TEXTURES
371 } else if constexpr (jit::use_cuda()) {
372#ifdef USE_INDEX_CACHE
373 stream << ", "
374 << indices[this->arg.get()];
375#else
376 stream << ", ";
377 compile_index<T> (stream, registers[a.get()], length,
378 scale, offset);
379#endif
381 stream << ")";
382 }
383 stream << ")";
384#endif
385 } else {
386#ifdef USE_INDEX_CACHE
387 stream << "["
388 << indices[this->arg.get()]
389 << "]";
390#else
391 stream << "[";
392 compile_index<T> (stream, registers[a.get()], length,
393 scale, offset);
394 stream << "]";
395#endif
396 }
397 this->endline(stream, usage);
398 }
399
400 return this->shared_from_this();
401 }
402
403//------------------------------------------------------------------------------
411//------------------------------------------------------------------------------
413 auto x_cast = piecewise_1D_cast(x);
414
415 if (x_cast.get()) {
416 return this->data_hash == x_cast->data_hash &&
417 this->arg->is_match(x_cast->get_arg());
418 }
419
420 return false;
421 }
422
423//------------------------------------------------------------------------------
425//------------------------------------------------------------------------------
426 virtual void to_latex() const {
427 std::cout << "r\\_" << reinterpret_cast<size_t> (this) << "_{i}";
428 }
429
430//------------------------------------------------------------------------------
436//------------------------------------------------------------------------------
437 virtual shared_leaf<T, SAFE_MATH> to_vizgraph(std::stringstream &stream,
438 jit::register_map &registers) {
439 if (registers.find(this) == registers.end()) {
440 const std::string name = jit::to_string('r', this);
441 registers[this] = name;
442 stream << " " << name
443 << " [label = \"r_" << reinterpret_cast<size_t> (this)
444 << "_{i}\", shape = hexagon, style = filled, fillcolor = black, fontcolor = white];" << std::endl;
445
446 auto a = this->arg->to_vizgraph(stream, registers);
447 stream << " " << name << " -- " << registers[a.get()] << ";" << std::endl;
448 }
449
450 return this->shared_from_this();
451 }
452
453//------------------------------------------------------------------------------
457//------------------------------------------------------------------------------
458 virtual bool is_constant() const {
459 return true;
460 }
461
462//------------------------------------------------------------------------------
466//------------------------------------------------------------------------------
467 virtual bool has_constant_zero() const {
468 return leaf_node<T, SAFE_MATH>::caches.backends[data_hash].has_zero();
469 }
470
471//------------------------------------------------------------------------------
475//------------------------------------------------------------------------------
476 virtual bool is_all_variables() const {
477 return false;
478 }
479
480//------------------------------------------------------------------------------
484//------------------------------------------------------------------------------
485 virtual bool is_power_like() const {
486 return true;
487 }
488
489//------------------------------------------------------------------------------
493//------------------------------------------------------------------------------
495 return this->shared_from_this();
496 }
497
498//------------------------------------------------------------------------------
502//------------------------------------------------------------------------------
504 return one<T, SAFE_MATH> ();
505 }
506
507//------------------------------------------------------------------------------
512//------------------------------------------------------------------------------
514 auto temp = piecewise_1D_cast(x);
515 return temp.get() &&
516 this->arg->is_match(temp->get_arg()) &&
517 (temp->get_size() == this->get_size()) &&
518 (temp->get_scale() == this->scale) &&
519 (temp->get_offset() == this->offset);
520 }
521
522//------------------------------------------------------------------------------
526//------------------------------------------------------------------------------
527 T get_scale() const {
528 return scale;
529 }
530
531//------------------------------------------------------------------------------
535//------------------------------------------------------------------------------
536 T get_offset() const {
537 return offset;
538 }
539
540//------------------------------------------------------------------------------
544//------------------------------------------------------------------------------
545 size_t get_size() const {
546 return leaf_node<T, SAFE_MATH>::caches.backends[data_hash].size();
547 }
548 };
549
550//------------------------------------------------------------------------------
561//------------------------------------------------------------------------------
562 template<jit::float_scalar T, bool SAFE_MATH=false>
565 const T scale,
566 const T offset) {
567 auto temp = std::make_shared<piecewise_1D_node<T, SAFE_MATH>> (d, x,
568 scale,
569 offset)->reduce();
570// Test for hash collisions.
571 for (size_t i = temp->get_hash(); i < std::numeric_limits<size_t>::max(); i++) {
572 if (leaf_node<T, SAFE_MATH>::caches.nodes.find(i) ==
575 return temp;
576 } else if (temp->is_match(leaf_node<T, SAFE_MATH>::caches.nodes[i])) {
578 }
579 }
580#if defined(__clang__) || defined(__GNUC__)
582#else
583 assert(false && "Should never reach.");
584#endif
585 }
586
588 template<jit::float_scalar T, bool SAFE_MATH=false>
589 using shared_piecewise_1D = std::shared_ptr<piecewise_1D_node<T, SAFE_MATH>>;
590
591//------------------------------------------------------------------------------
599//------------------------------------------------------------------------------
600 template<jit::float_scalar T, bool SAFE_MATH=false>
602 return std::dynamic_pointer_cast<piecewise_1D_node<T, SAFE_MATH>> (x);
603 }
604
605//******************************************************************************
606// 2D Piecewise node.
607//******************************************************************************
608//------------------------------------------------------------------------------
648//------------------------------------------------------------------------------
649 template<jit::float_scalar T, bool SAFE_MATH=false>
650 class piecewise_2D_node final : public branch_node<T, SAFE_MATH> {
651 private:
653 const T x_scale;
655 const T x_offset;
657 const T y_scale;
659 const T y_offset;
660
661//------------------------------------------------------------------------------
666//------------------------------------------------------------------------------
667 static std::string to_string(const backend::buffer<T> &d) {
668 std::string temp;
669 for (size_t i = 0, ie = d.size(); i < ie; i++) {
671 }
672
673 return temp;
674 }
675
676//------------------------------------------------------------------------------
683//------------------------------------------------------------------------------
684 static std::string to_string(const backend::buffer<T> &d,
687 return piecewise_2D_node::to_string(d) +
688 jit::format_to_string(x->get_hash()) +
689 jit::format_to_string(y->get_hash());
690 }
691
692//------------------------------------------------------------------------------
697//------------------------------------------------------------------------------
698 static size_t hash_data(const backend::buffer<T> &d) {
699 const size_t h = std::hash<std::string>{} (piecewise_2D_node::to_string(d));
700 for (size_t i = h; i < std::numeric_limits<size_t>::max(); i++) {
701 if (leaf_node<T, SAFE_MATH>::caches.backends.find(i) ==
702 leaf_node<T, SAFE_MATH>::caches.backends.end()) {
704 return i;
705 } else if (d == leaf_node<T, SAFE_MATH>::caches.backends[i]) {
706 return i;
707 }
708 }
709#if defined(__clang__) || defined(__GNUC__)
711#else
712 assert(false && "Should never reach.");
713#endif
714 }
715
717 const size_t data_hash;
719 const size_t num_columns;
720
721 public:
722//------------------------------------------------------------------------------
733//------------------------------------------------------------------------------
735 const size_t n,
737 const T x_scale,
738 const T x_offset,
740 const T y_scale,
741 const T y_offset) :
742 branch_node<T, SAFE_MATH> (x, y, piecewise_2D_node::to_string(d, x, y)),
743 data_hash(piecewise_2D_node::hash_data(d)),
744 num_columns(n), x_scale(x_scale), x_offset(x_offset), y_scale(y_scale),
745 y_offset(y_offset) {
746 assert(d.size()%n == 0 &&
747 "Expected the data buffer to be a multiple of the number of columns.");
748 }
749
750//------------------------------------------------------------------------------
754//------------------------------------------------------------------------------
755 size_t get_num_columns() const {
756 return num_columns;
757 }
758
759//------------------------------------------------------------------------------
763//------------------------------------------------------------------------------
764 size_t get_num_rows() const {
765 return leaf_node<T, SAFE_MATH>::caches.backends[data_hash].size() /
766 num_columns;
767 }
768
769//------------------------------------------------------------------------------
773//------------------------------------------------------------------------------
774 T get_x_scale() const {
775 return x_scale;
776 }
777
778//------------------------------------------------------------------------------
782//------------------------------------------------------------------------------
783 T get_x_offset() const {
784 return x_offset;
785 }
786
787//------------------------------------------------------------------------------
791//------------------------------------------------------------------------------
792 T get_y_scale() const {
793 return y_scale;
794 }
795
796//------------------------------------------------------------------------------
800//------------------------------------------------------------------------------
801 T get_y_offset() const {
802 return y_offset;
803 }
804
805//------------------------------------------------------------------------------
813//------------------------------------------------------------------------------
815 return leaf_node<T, SAFE_MATH>::caches.backends[data_hash];
816 }
817
818//------------------------------------------------------------------------------
825//------------------------------------------------------------------------------
827 if (evaluate().is_same()) {
828 return constant<T, SAFE_MATH> (evaluate().at(0));
829 }
830
831 return this->shared_from_this();
832 }
833
834//------------------------------------------------------------------------------
839//------------------------------------------------------------------------------
841 return constant<T, SAFE_MATH> (static_cast<T> (this->is_match(x)));
842 }
843
844//------------------------------------------------------------------------------
854//------------------------------------------------------------------------------
855 virtual void compile_preamble(std::ostringstream &stream,
856 jit::register_map &registers,
861 int &avail_const_mem) {
862 if (visited.find(this) == visited.end()) {
863 this->left->compile_preamble(stream, registers,
864 visited, usage,
867 this->right->compile_preamble(stream, registers,
868 visited, usage,
871 if (registers.find(leaf_node<T, SAFE_MATH>::caches.backends[data_hash].data()) == registers.end()) {
872 registers[leaf_node<T, SAFE_MATH>::caches.backends[data_hash].data()] =
873 jit::to_string('a', leaf_node<T, SAFE_MATH>::caches.backends[data_hash].data());
874 const size_t length = leaf_node<T, SAFE_MATH>::caches.backends[data_hash].size();
875 if constexpr (jit::use_metal<T> ()) {
876 textures2d.try_emplace(leaf_node<T, SAFE_MATH>::caches.backends[data_hash].data(),
877 std::array<size_t, 2> ({length/num_columns, num_columns}));
878#ifdef USE_CUDA_TEXTURES
879 } else if constexpr (jit::use_cuda()) {
880 textures2d.try_emplace(leaf_node<T, SAFE_MATH>::caches.backends[data_hash].data(),
881 std::array<size_t, 2> ({length/num_columns, num_columns}));
882#endif
883 } else {
884 if constexpr (jit::use_cuda()) {
885 const int buffer_size = length*sizeof(T);
886 if (avail_const_mem - buffer_size > 0) {
888 stream << "__constant__ ";
889 }
890 }
891 stream << "const ";
892 jit::add_type<T> (stream);
893 stream << " " << registers[leaf_node<T, SAFE_MATH>::caches.backends[data_hash].data()] << "[] = {";
894 if constexpr (jit::complex_scalar<T>) {
895 jit::add_type<T> (stream);
896 }
897 stream << leaf_node<T, SAFE_MATH>::caches.backends[data_hash][0];
898 for (size_t i = 1; i < length; i++) {
899 stream << ", ";
900 if constexpr (jit::complex_scalar<T>) {
901 jit::add_type<T> (stream);
902 }
903 stream << leaf_node<T, SAFE_MATH>::caches.backends[data_hash][i];
904 }
905 stream << "};" << std::endl;
906 }
907 } else {
908// When using textures, the register can be defined in a previous kernel. We
909// need to add the textures again.
910 const size_t length = leaf_node<T, SAFE_MATH>::caches.backends[data_hash].size();
911 if constexpr (jit::use_metal<T> ()) {
912 textures2d.try_emplace(leaf_node<T, SAFE_MATH>::caches.backends[data_hash].data(),
913 std::array<size_t, 2> ({length/num_columns, num_columns}));
914#ifdef USE_CUDA_TEXTURES
915 } else if constexpr (jit::use_cuda()) {
916 textures2d.try_emplace(leaf_node<T, SAFE_MATH>::caches.backends[data_hash].data(),
917 std::array<size_t, 2> ({length/num_columns, num_columns}));
918#endif
919 }
920 }
921 visited.insert(this);
922#ifdef SHOW_USE_COUNT
923 usage[this] = 1;
924 } else {
925 ++usage[this];
926#endif
927 }
928 }
929
930//------------------------------------------------------------------------------
963//------------------------------------------------------------------------------
965 compile(std::ostringstream &stream,
966 jit::register_map &registers,
968 const jit::register_usage &usage) {
969 if (registers.find(this) == registers.end()) {
970 const size_t length = leaf_node<T, SAFE_MATH>::caches.backends[data_hash].size();
971 const size_t num_rows = length/num_columns;
972
973 shared_leaf<T, SAFE_MATH> x = this->left->compile(stream,
974 registers,
975 indices,
976 usage);
977 shared_leaf<T, SAFE_MATH> y = this->right->compile(stream,
978 registers,
979 indices,
980 usage);
981
982#ifdef USE_INDEX_CACHE
983 if (indices.find(x.get()) == indices.end()) {
984 indices[x.get()] = jit::to_string('i', x.get());
985 stream << " const "
986 << jit::smallest_int_type<T> (num_rows) << " "
987 << indices[x.get()] << " = ";
988 compile_index<T> (stream, registers[x.get()], num_rows,
989 x_scale, x_offset);
990 x->endline(stream, usage);
991 }
992 if (indices.find(y.get()) == indices.end()) {
993 indices[y.get()] = jit::to_string('i', y.get());
994 stream << " const "
995 << jit::smallest_int_type<T> (num_columns) << " "
996 << indices[y.get()] << " = ";
997 compile_index<T> (stream, registers[y.get()], num_columns,
998 y_scale, y_offset);
999 y->endline(stream, usage);
1000 }
1001
1002 auto temp = this->left + this->right;
1003 if constexpr (!jit::use_metal<T> ()
1004#ifdef USE_CUDA_TEXTURES
1005 || !jit::use_cuda()
1006#endif
1007 ) {
1008 if (indices.find(temp.get()) == indices.end()) {
1009 indices[temp.get()] = jit::to_string('i', temp.get());
1010 const size_t length = leaf_node<T, SAFE_MATH>::caches.backends[data_hash].size();
1011 stream << " const "
1012 << jit::smallest_int_type<T> (length) << " "
1013 << indices[temp.get()] << " = "
1014 << indices[x.get()]
1015 << "*" << num_columns << " + "
1016 << indices[y.get()]
1017 << ";" << std::endl;
1018 }
1019 }
1020#endif
1021
1022 registers[this] = jit::to_string('r', this);
1023 stream << " const ";
1024 jit::add_type<T> (stream);
1025 stream << " " << registers[this] << " = ";
1026#ifdef USE_CUDA_TEXTURES
1027 if constexpr (jit::use_cuda()) {
1028 if constexpr (float_base<T>) {
1029 if constexpr (complex_scalar<T>) {
1030 stream << "to_cmp_float(tex1D<float2> (";
1031 } else {
1032 stream << "tex1D<float> (";
1033 }
1034 } else {
1035 if constexpr (complex_scalar<T>) {
1036 stream << "to_cmp_double(tex1D<uint4> (";
1037 } else {
1038 stream << "to_double(tex1D<uint2> (";
1039 }
1040 }
1041 }
1042#endif
1043 stream << registers[leaf_node<T, SAFE_MATH>::caches.backends[data_hash].data()];
1044 if constexpr (jit::use_metal<T> ()) {
1045#ifdef USE_INDEX_CACHE
1046 stream << ".read("
1047 << jit::smallest_int_type<T> (std::max(num_rows,
1048 num_columns))
1049 << "2("
1050 << indices[y.get()]
1051 << ","
1052 << indices[x.get()]
1053 << ")).r";
1054#else
1055 stream << ".read(uint2(";
1056 compile_index<T> (stream, registers[y.get()], num_columns,
1057 y_scale, y_offset);
1058 stream << ",";
1059 compile_index<T> (stream, registers[x.get()], num_rows,
1060 x_scale, x_offset);
1061 stream << ")).r";
1062#endif
1063#ifdef USE_CUDA_TEXTURES
1064 } else if constexpr (jit::use_cuda()) {
1065#ifdef USE_INDEX_CACHE
1066 stream << ", "
1067 << indices[y.get()]
1068 << ", "
1069 << indices[x.get()];
1070#else
1071 stream << ", ";
1072 compile_index<T> (stream, registers[y.get()], num_columns,
1073 y_scale, y_offset);
1074 stream << ", ";
1075 compile_index<T> (stream, registers[x.get()], num_rows,
1076 x_scale, x_offset);
1077#endif
1079 stream << ")";
1080 }
1081 stream << ")";
1082#endif
1083 } else {
1084#ifdef USE_INDEX_CACHE
1085 stream << "["
1086 << indices[temp.get()]
1087 << "]";
1088#else
1089 stream << "[";
1090 compile_index<T> (stream, registers[x.get()], num_rows,
1091 x_scale, x_offset);
1092 stream << "*" << num_columns << " + ";
1093 compile_index<T> (stream, registers[y.get()], num_columns,
1094 y_scale, y_offset);
1095 stream << "]";
1096#endif
1097 }
1098 this->endline(stream, usage);
1099 }
1100
1101 return this->shared_from_this();
1102 }
1103
1104//------------------------------------------------------------------------------
1111//------------------------------------------------------------------------------
1113 auto x_cast = piecewise_2D_cast(x);
1114
1115 if (x_cast.get()) {
1116 return this->data_hash == x_cast->data_hash &&
1117 this->left->is_match(x_cast->get_left()) &&
1118 this->right->is_match(x_cast->get_right());
1119 }
1120
1121 return false;
1122 }
1123
1124//------------------------------------------------------------------------------
1128//------------------------------------------------------------------------------
1129 virtual void to_latex() const {
1130 std::cout << "r\\_" << reinterpret_cast<size_t> (this) << "_{ij}";
1131 }
1132
1133//------------------------------------------------------------------------------
1139//------------------------------------------------------------------------------
1140 virtual shared_leaf<T, SAFE_MATH> to_vizgraph(std::stringstream &stream,
1141 jit::register_map &registers) {
1142 if (registers.find(this) == registers.end()) {
1143 const std::string name = jit::to_string('r', this);
1144 registers[this] = name;
1145 stream << " " << name
1146 << " [label = \"r_" << reinterpret_cast<size_t> (this)
1147 << "_{ij}\", shape = hexagon, style = filled, fillcolor = black, fontcolor = white];" << std::endl;
1148
1149 auto l = this->left->to_vizgraph(stream, registers);
1150 stream << " " << name << " -- " << registers[l.get()] << ";" << std::endl;
1151 auto r = this->right->to_vizgraph(stream, registers);
1152 stream << " " << name << " -- " << registers[r.get()] << ";" << std::endl;
1153 }
1154
1155 return this->shared_from_this();
1156 }
1157
1158//------------------------------------------------------------------------------
1162//------------------------------------------------------------------------------
1163 virtual bool is_constant() const {
1164 return true;
1165 }
1166
1167//------------------------------------------------------------------------------
1171//------------------------------------------------------------------------------
1172 virtual bool has_constant_zero() const {
1173 return leaf_node<T, SAFE_MATH>::caches.backends[data_hash].has_zero();
1174 }
1175
1176//------------------------------------------------------------------------------
1180//------------------------------------------------------------------------------
1181 virtual bool is_all_variables() const {
1182 return false;
1183 }
1184
1185//------------------------------------------------------------------------------
1189//------------------------------------------------------------------------------
1190 virtual bool is_power_like() const {
1191 return true;
1192 }
1193
1194//------------------------------------------------------------------------------
1198//------------------------------------------------------------------------------
1200 return this->shared_from_this();
1201 }
1202
1203//------------------------------------------------------------------------------
1207//------------------------------------------------------------------------------
1209 return one<T, SAFE_MATH> ();
1210 }
1211
1212//------------------------------------------------------------------------------
1217//------------------------------------------------------------------------------
1219 auto temp = piecewise_2D_cast(x);
1220 return temp.get() &&
1221 this->left->is_match(temp->get_left()) &&
1222 this->right->is_match(temp->get_right()) &&
1223 (temp->get_num_rows() == this->get_num_rows()) &&
1224 (temp->get_num_columns() == this->get_num_columns()) &&
1225 (temp->get_x_scale() == this->x_scale) &&
1226 (temp->get_x_offset() == this->x_offset) &&
1227 (temp->get_y_scale() == this->y_scale) &&
1228 (temp->get_y_offset() == this->y_offset);
1229 }
1230
1231//------------------------------------------------------------------------------
1236//------------------------------------------------------------------------------
1238 auto temp = piecewise_1D_cast(x);
1239 return temp.get() &&
1240 this->left->is_match(temp->get_arg()) &&
1241 (temp->get_size() == this->get_num_rows()) &&
1242 (temp->get_scale() == this->x_scale) &&
1243 (temp->get_offset() == this->x_offset);
1244 }
1245
1246//------------------------------------------------------------------------------
1253//------------------------------------------------------------------------------
1255 auto temp = piecewise_1D_cast(x);
1256 return temp.get() &&
1257 this->right->is_match(temp->get_arg()) &&
1258 (temp->get_size() == this->get_num_columns()) &&
1259 (temp->get_scale() == this->y_scale) &&
1260 (temp->get_offset() == this->y_offset);
1261 }
1262 };
1263
1264//------------------------------------------------------------------------------
1279//------------------------------------------------------------------------------
1280 template<jit::float_scalar T, bool SAFE_MATH=false>
1282 const size_t n,
1284 const T x_scale,
1285 const T x_offset,
1287 const T y_scale,
1288 const T y_offset) {
1289 auto temp = std::make_shared<piecewise_2D_node<T, SAFE_MATH>> (d, n,
1290 x, x_scale, x_offset,
1291 y, y_scale, y_offset)->reduce();
1292// Test for hash collisions.
1293 for (size_t i = temp->get_hash(); i < std::numeric_limits<size_t>::max(); i++) {
1294 if (leaf_node<T, SAFE_MATH>::caches.nodes.find(i) ==
1295 leaf_node<T, SAFE_MATH>::caches.nodes.end()) {
1297 return temp;
1298 } else if (temp->is_match(leaf_node<T, SAFE_MATH>::caches.nodes[i])) {
1299 return leaf_node<T, SAFE_MATH>::caches.nodes[i];
1300 }
1301 }
1302#if defined(__clang__) || defined(__GNUC__)
1304#else
1305 assert(false && "Should never reach.");
1306#endif
1307 }
1308
1310 template<jit::float_scalar T, bool SAFE_MATH=false>
1311 using shared_piecewise_2D = std::shared_ptr<piecewise_2D_node<T, SAFE_MATH>>;
1312
1313//------------------------------------------------------------------------------
1321//------------------------------------------------------------------------------
1322 template<jit::float_scalar T, bool SAFE_MATH=false>
1324 return std::dynamic_pointer_cast<piecewise_2D_node<T, SAFE_MATH>> (x);
1325 }
1326}
1327
1328#endif /* piecewise_h */
Class representing a generic buffer.
Definition backend.hpp:29
Class representing a branch node.
Definition node.hpp:1173
shared_leaf< T, SAFE_MATH > right
Right branch of the tree.
Definition node.hpp:1178
shared_leaf< T, SAFE_MATH > left
Left branch of the tree.
Definition node.hpp:1176
Class representing a node leaf.
Definition node.hpp:364
virtual void endline(std::ostringstream &stream, const jit::register_usage &usage) const final
End a line in the kernel source.
Definition node.hpp:637
virtual std::shared_ptr< leaf_node< T, SAFE_MATH > > compile(std::ostringstream &stream, jit::register_map &registers, jit::register_map &indices, const jit::register_usage &usage)=0
Compile the node.
Class representing a 1D piecewise constant.
Definition piecewise.hpp:92
virtual void compile_preamble(std::ostringstream &stream, jit::register_map &registers, jit::visiter_map &visited, jit::register_usage &usage, jit::texture1d_list &textures1d, jit::texture2d_list &textures2d, int &avail_const_mem)
Compile preamble.
Definition piecewise.hpp:220
virtual bool is_power_like() const
Test if the node acts like a power of a variable.
Definition piecewise.hpp:485
piecewise_1D_node(const backend::buffer< T > &d, shared_leaf< T, SAFE_MATH > x, const T scale, const T offset)
Construct a 1D piecewise constant node.
Definition piecewise.hpp:163
virtual bool is_constant() const
Test if node is a constant.
Definition piecewise.hpp:458
virtual bool is_match(shared_leaf< T, SAFE_MATH > x)
Query if the nodes match.
Definition piecewise.hpp:412
virtual shared_leaf< T, SAFE_MATH > get_power_exponent() const
Get the exponent of a power.
Definition piecewise.hpp:503
virtual shared_leaf< T, SAFE_MATH > reduce()
Reduction method.
Definition piecewise.hpp:192
virtual backend::buffer< T > evaluate()
Evaluate the results of the piecewise constant.
Definition piecewise.hpp:180
virtual shared_leaf< T, SAFE_MATH > df(shared_leaf< T, SAFE_MATH > x)
Transform node to derivative.
Definition piecewise.hpp:205
T get_scale() const
Get x argument scale.
Definition piecewise.hpp:527
virtual void to_latex() const
Convert the node to latex.
Definition piecewise.hpp:426
virtual shared_leaf< T, SAFE_MATH > get_power_base()
Get the base of a power.
Definition piecewise.hpp:494
size_t get_size() const
Get the size of the buffer.
Definition piecewise.hpp:545
virtual bool has_constant_zero() const
Test if the constant node has a zero.
Definition piecewise.hpp:467
virtual shared_leaf< T, SAFE_MATH > to_vizgraph(std::stringstream &stream, jit::register_map &registers)
Convert the node to vizgraph.
Definition piecewise.hpp:437
virtual shared_leaf< T, SAFE_MATH > compile(std::ostringstream &stream, jit::register_map &registers, jit::register_map &indices, const jit::register_usage &usage)
Compile the node.
Definition piecewise.hpp:313
T get_offset() const
Get x argument offset.
Definition piecewise.hpp:536
virtual bool is_all_variables() const
Test if node acts like a variable.
Definition piecewise.hpp:476
bool is_arg_match(shared_leaf< T, SAFE_MATH > x)
Check if the args match.
Definition piecewise.hpp:513
Class representing a 2D piecewise constant.
Definition piecewise.hpp:650
T get_y_scale() const
Get y argument scale.
Definition piecewise.hpp:792
virtual shared_leaf< T, SAFE_MATH > df(shared_leaf< T, SAFE_MATH > x)
Transform node to derivative.
Definition piecewise.hpp:840
bool is_arg_match(shared_leaf< T, SAFE_MATH > x)
Check if the args match.
Definition piecewise.hpp:1218
T get_y_offset() const
Get y argument offset.
Definition piecewise.hpp:801
virtual bool is_power_like() const
Test if the node acts like a power of variable.
Definition piecewise.hpp:1190
T get_x_scale() const
Get x argument scale.
Definition piecewise.hpp:774
virtual shared_leaf< T, SAFE_MATH > compile(std::ostringstream &stream, jit::register_map &registers, jit::register_map &indices, const jit::register_usage &usage)
Compile the node.
Definition piecewise.hpp:965
size_t get_num_columns() const
Get the number of columns.
Definition piecewise.hpp:755
virtual shared_leaf< T, SAFE_MATH > get_power_base()
Get the base of a power.
Definition piecewise.hpp:1199
virtual bool is_match(shared_leaf< T, SAFE_MATH > x)
Query if the nodes match.
Definition piecewise.hpp:1112
virtual bool is_all_variables() const
Test if node acts like a variable.
Definition piecewise.hpp:1181
virtual bool is_constant() const
Test if node is a constant.
Definition piecewise.hpp:1163
size_t get_num_rows() const
Get the number of rows.
Definition piecewise.hpp:764
bool is_col_match(shared_leaf< T, SAFE_MATH > x)
Check if the columns match.
Definition piecewise.hpp:1254
T get_x_offset() const
Get x argument offset.
Definition piecewise.hpp:783
virtual void to_latex() const
Convert the node to latex.
Definition piecewise.hpp:1129
virtual shared_leaf< T, SAFE_MATH > to_vizgraph(std::stringstream &stream, jit::register_map &registers)
Convert the node to vizgraph.
Definition piecewise.hpp:1140
bool is_row_match(shared_leaf< T, SAFE_MATH > x)
Check if the rows match.
Definition piecewise.hpp:1237
virtual shared_leaf< T, SAFE_MATH > get_power_exponent() const
Get the exponent of a power.
Definition piecewise.hpp:1208
virtual shared_leaf< T, SAFE_MATH > reduce()
Reduction method.
Definition piecewise.hpp:826
piecewise_2D_node(const backend::buffer< T > &d, const size_t n, shared_leaf< T, SAFE_MATH > x, const T x_scale, const T x_offset, shared_leaf< T, SAFE_MATH > y, const T y_scale, const T y_offset)
Construct a 2D piecewise constant node.
Definition piecewise.hpp:734
virtual backend::buffer< T > evaluate()
Evaluate the results of the piecewise constant.
Definition piecewise.hpp:814
virtual bool has_constant_zero() const
Test if the constant node has a zero.
Definition piecewise.hpp:1172
virtual void compile_preamble(std::ostringstream &stream, jit::register_map &registers, jit::visiter_map &visited, jit::register_usage &usage, jit::texture1d_list &textures1d, jit::texture2d_list &textures2d, int &avail_const_mem)
Compile preamble.
Definition piecewise.hpp:855
Class representing a straight node.
Definition node.hpp:1059
shared_leaf< T, SAFE_MATH > arg
Argument.
Definition node.hpp:1062
Complex scalar concept.
Definition register.hpp:24
Double base concept.
Definition register.hpp:42
subroutine assert(test, message)
Assert check.
Definition f_binding_test.f90:38
Name space for graph nodes.
Definition arithmetic.hpp:13
shared_piecewise_2D< T, SAFE_MATH > piecewise_2D_cast(shared_leaf< T, SAFE_MATH > x)
Cast to a piecewise 2D node.
Definition piecewise.hpp:1323
constexpr shared_leaf< T, SAFE_MATH > zero()
Forward declare for zero.
Definition node.hpp:994
shared_piecewise_1D< T, SAFE_MATH > piecewise_1D_cast(shared_leaf< T, SAFE_MATH > x)
Cast to a piecewise 1D node.
Definition piecewise.hpp:601
std::shared_ptr< piecewise_2D_node< T, SAFE_MATH > > shared_piecewise_2D
Convenience type alias for shared piecewise 2D nodes.
Definition piecewise.hpp:1311
std::shared_ptr< piecewise_1D_node< T, SAFE_MATH > > shared_piecewise_1D
Convenience type alias for shared piecewise 1D nodes.
Definition piecewise.hpp:589
constexpr T i
Convenience constant for the imaginary unit.
Definition node.hpp:1026
std::shared_ptr< leaf_node< T, SAFE_MATH > > shared_leaf
Convenience type alias for shared leaf nodes.
Definition node.hpp:673
void compile_index(std::ostringstream &stream, const std::string &register_name, const size_t length, const T scale, const T offset)
Compile an index.
Definition piecewise.hpp:26
std::map< void *, size_t > texture1d_list
Type alias for indexing 1D textures.
Definition register.hpp:262
std::string format_to_string(const T value)
Convert a value to a string while avoiding locale.
Definition register.hpp:211
std::map< void *, std::array< size_t, 2 > > texture2d_list
Type alias for indexing 2D textures.
Definition register.hpp:264
std::map< void *, size_t > register_usage
Type alias for counting register usage.
Definition register.hpp:258
std::map< void *, std::string > register_map
Type alias for mapping node pointers to register names.
Definition register.hpp:256
constexpr bool use_cuda()
Test to use Cuda.
Definition register.hpp:67
std::set< void * > visiter_map
Type alias for listing visited nodes.
Definition register.hpp:260
std::string to_string(const char prefix, const NODE *pointer)
Convert a graph::leaf_node pointer to a string.
Definition register.hpp:245
Base nodes of graph computation framework.
void piecewise_1D()
Tests for 1D piecewise nodes.
Definition piecewise_test.cpp:80
void piecewise_2D()
Tests for 2D piecewise nodes.
Definition piecewise_test.cpp:283