25template<jit::
float_scalar T>
31 const std::string type = jit::smallest_int_type<T> (length);
42 stream << offset <<
")/";
46 stream << scale <<
")";
50 stream <<
",(" << type <<
")0),("
51 << type <<
")" << length - 1 <<
")";
91 template<jit::
float_scalar T,
bool SAFE_MATH=false>
107 for (
size_t i = 0,
ie =
d.size();
i <
ie;
i++) {
123 return piecewise_1D_node::to_string(
d) +
134 const size_t h = std::hash<std::string>{} (piecewise_1D_node::to_string(
d));
144#if defined(__clang__) || defined(__GNUC__)
147 assert(
false &&
"Should never reach.");
152 const size_t data_hash;
194 const T
arg = (this->arg->evaluate().at(0) + offset)/scale;
195 const size_t i = std::min(
static_cast<size_t> (std::real(
arg)),
235 this->
arg->compile_preamble(stream, registers,
243 if constexpr (jit::use_metal<T> ()) {
246#ifdef USE_CUDA_TEXTURES
256 stream <<
"__constant__ ";
260 jit::add_type<T> (stream);
263 jit::add_type<T> (stream);
266 for (
size_t i = 1;
i < length;
i++) {
269 jit::add_type<T> (stream);
273 stream <<
"};" << std::endl;
279 if constexpr (jit::use_metal<T> ()) {
282#ifdef USE_CUDA_TEXTURES
324 if (registers.find(
this) == registers.end()) {
325#ifdef USE_INDEX_CACHE
333#ifdef USE_INDEX_CACHE
336 << jit::smallest_int_type<T> (length) <<
" "
340 a->endline(stream,
usage);
346 jit::add_type<T> (stream);
347 stream <<
" " << registers[
this] <<
" = ";
348#ifdef USE_CUDA_TEXTURES
352 stream <<
"to_cmp_float(tex1D<float2> (";
354 stream <<
"tex1D<float> (";
358 stream <<
"to_cmp_double(tex1D<uint4> (";
360 stream <<
"to_double(tex1D<uint2> (";
366 if constexpr (jit::use_metal<T> ()) {
367#ifdef USE_INDEX_CACHE
377#ifdef USE_CUDA_TEXTURES
379#ifdef USE_INDEX_CACHE
393#ifdef USE_INDEX_CACHE
423 return this->data_hash ==
x_cast->data_hash &&
434 std::cout <<
"r\\_" <<
reinterpret_cast<size_t> (
this) <<
"_{i}";
446 if (registers.find(
this) == registers.end()) {
448 registers[
this] =
name;
449 stream <<
" " <<
name
450 <<
" [label = \"r_" <<
reinterpret_cast<size_t> (
this)
451 <<
"_{i}\", shape = hexagon, style = filled, fillcolor = black, fontcolor = white];" << std::endl;
453 auto a = this->
arg->to_vizgraph(stream, registers);
454 stream <<
" " <<
name <<
" -- " << registers[
a.get()] <<
";" << std::endl;
523 this->
arg->is_match(
temp->get_arg()) &&
525 (
temp->get_scale() == this->scale) &&
526 (
temp->get_offset() == this->offset);
569 template<jit::
float_scalar T,
bool SAFE_MATH=false>
574 auto temp = std::make_shared<piecewise_1D_node<T, SAFE_MATH>> (
d, x,
587#if defined(__clang__) || defined(__GNUC__)
590 assert(
false &&
"Should never reach.");
595 template<jit::
float_scalar T,
bool SAFE_MATH=false>
607 template<jit::
float_scalar T,
bool SAFE_MATH=false>
609 return std::dynamic_pointer_cast<piecewise_1D_node<T, SAFE_MATH>> (x);
656 template<jit::
float_scalar T,
bool SAFE_MATH=false>
676 for (
size_t i = 0,
ie =
d.size();
i <
ie;
i++) {
694 return piecewise_2D_node::to_string(
d) +
706 const size_t h = std::hash<std::string>{} (piecewise_2D_node::to_string(
d));
716#if defined(__clang__) || defined(__GNUC__)
719 assert(
false &&
"Should never reach.");
724 const size_t data_hash;
726 const size_t num_columns;
751 num_columns(
n), x_scale(x_scale), x_offset(x_offset), y_scale(y_scale),
754 "Expected the data buffer to be a multiple of the number of columns.");
836 const T
l = (this->
left->evaluate().at(0) + x_offset)/x_scale;
837 const T
r = (this->
right->evaluate().at(0) + y_offset)/y_scale;
838 const size_t i = std::min(
static_cast<size_t> (std::real(
l)),
840 const size_t j = std::min(
static_cast<size_t> (std::real(
r)),
844 const T
l = (this->
left->evaluate().at(0) + x_offset)/x_scale;
845 const size_t i = std::min(
static_cast<size_t> (std::real(
l)),
849 this->
right, y_scale, y_offset);
851 const T
r = (this->
right->evaluate().at(0) + y_offset)/y_scale;
852 const size_t j = std::min(
static_cast<size_t> (std::real(
r)),
856 this->
left, x_scale, x_offset);
895 this->
left->compile_preamble(stream, registers,
899 this->
right->compile_preamble(stream, registers,
907 if constexpr (jit::use_metal<T> ()) {
909 std::array<size_t, 2> ({length/num_columns, num_columns}));
910#ifdef USE_CUDA_TEXTURES
913 std::array<size_t, 2> ({length/num_columns, num_columns}));
920 stream <<
"__constant__ ";
924 jit::add_type<T> (stream);
927 jit::add_type<T> (stream);
930 for (
size_t i = 1;
i < length;
i++) {
933 jit::add_type<T> (stream);
937 stream <<
"};" << std::endl;
943 if constexpr (jit::use_metal<T> ()) {
945 std::array<size_t, 2> ({length/num_columns, num_columns}));
946#ifdef USE_CUDA_TEXTURES
949 std::array<size_t, 2> ({length/num_columns, num_columns}));
1001 if (registers.find(
this) == registers.end()) {
1003 const size_t num_rows = length/num_columns;
1014#ifdef USE_INDEX_CACHE
1018 << jit::smallest_int_type<T> (
num_rows) <<
" "
1022 x->endline(stream,
usage);
1027 << jit::smallest_int_type<T> (num_columns) <<
" "
1031 y->endline(stream,
usage);
1035 if constexpr (!jit::use_metal<T> ()
1036#ifdef USE_CUDA_TEXTURES
1044 << jit::smallest_int_type<T> (length) <<
" "
1047 <<
"*" << num_columns <<
" + "
1049 <<
";" << std::endl;
1055 stream <<
" const ";
1056 jit::add_type<T> (stream);
1057 stream <<
" " << registers[
this] <<
" = ";
1058#ifdef USE_CUDA_TEXTURES
1062 stream <<
"to_cmp_float(tex1D<float2> (";
1064 stream <<
"tex1D<float> (";
1068 stream <<
"to_cmp_double(tex1D<uint4> (";
1070 stream <<
"to_double(tex1D<uint2> (";
1076 if constexpr (jit::use_metal<T> ()) {
1077#ifdef USE_INDEX_CACHE
1079 << jit::smallest_int_type<T> (std::max(
num_rows,
1087 stream <<
".read(uint2(";
1095#ifdef USE_CUDA_TEXTURES
1097#ifdef USE_INDEX_CACHE
1116#ifdef USE_INDEX_CACHE
1124 stream <<
"*" << num_columns <<
" + ";
1148 return this->data_hash ==
x_cast->data_hash &&
1162 std::cout <<
"r\\_" <<
reinterpret_cast<size_t> (
this) <<
"_{ij}";
1174 if (registers.find(
this) == registers.end()) {
1176 registers[
this] =
name;
1177 stream <<
" " <<
name
1178 <<
" [label = \"r_" <<
reinterpret_cast<size_t> (
this)
1179 <<
"_{ij}\", shape = hexagon, style = filled, fillcolor = black, fontcolor = white];" << std::endl;
1181 auto l = this->
left->to_vizgraph(stream, registers);
1182 stream <<
" " <<
name <<
" -- " << registers[
l.get()] <<
";" << std::endl;
1183 auto r = this->
right->to_vizgraph(stream, registers);
1184 stream <<
" " <<
name <<
" -- " << registers[
r.get()] <<
";" << std::endl;
1252 return temp.get() &&
1253 this->
left->is_match(
temp->get_left()) &&
1254 this->
right->is_match(
temp->get_right()) &&
1257 (
temp->get_x_scale() == this->x_scale) &&
1258 (
temp->get_x_offset() == this->x_offset) &&
1259 (
temp->get_y_scale() == this->y_scale) &&
1260 (
temp->get_y_offset() == this->y_offset);
1271 return temp.get() &&
1272 this->
left->is_match(
temp->get_arg()) &&
1274 (
temp->get_scale() == this->x_scale) &&
1275 (
temp->get_offset() == this->x_offset);
1288 return temp.get() &&
1289 this->
right->is_match(
temp->get_arg()) &&
1291 (
temp->get_scale() == this->y_scale) &&
1292 (
temp->get_offset() == this->y_offset);
1312 template<jit::
float_scalar T,
bool SAFE_MATH=false>
1321 auto temp = std::make_shared<piecewise_2D_node<T, SAFE_MATH>> (
d,
n,
1322 x, x_scale, x_offset,
1323 y, y_scale, y_offset)->reduce();
1334#if defined(__clang__) || defined(__GNUC__)
1337 assert(
false &&
"Should never reach.");
1342 template<jit::
float_scalar T,
bool SAFE_MATH=false>
1354 template<jit::
float_scalar T,
bool SAFE_MATH=false>
1356 return std::dynamic_pointer_cast<piecewise_2D_node<T, SAFE_MATH>> (x);
1374 template<jit::
float_scalar T,
bool SAFE_MATH=false>
1409 scale(scale), offset(offset) {}
1421 return this->
right->evaluate();
1465 if (registers.find(
this) == registers.end()) {
1466#ifdef USE_INDEX_CACHE
1474#ifdef USE_INDEX_CACHE
1477 << jit::smallest_int_type<T> (length) <<
" "
1481 a->endline(stream,
usage);
1486 stream <<
" const ";
1487 jit::add_type<T> (stream);
1488 auto var = this->
left->compile(stream,
1492 stream <<
" " << registers[
this] <<
" = "
1494#ifdef USE_INDEX_CACHE
1512 std::cout <<
"r\\_" <<
reinterpret_cast<size_t> (this->
left.get())
1514 <<
reinterpret_cast<size_t> (this->
right.get())
1527 if (registers.find(
this) == registers.end()) {
1529 registers[
this] =
name;
1530 stream <<
" " <<
name
1531 <<
" [label = \"r_" <<
reinterpret_cast<size_t> (this->
left.get())
1532 <<
"\", shape = hexagon, style = filled, fillcolor = black, fontcolor = white];" << std::endl;
1534 auto l = this->
left->to_vizgraph(stream, registers);
1535 stream <<
" " <<
name <<
" -- " << registers[
l.get()] <<
";" << std::endl;
1536 auto r = this->
right->to_vizgraph(stream, registers);
1537 stream <<
" " <<
name <<
" -- " << registers[
r.get()] <<
";" << std::endl;
1587 return temp.get() &&
1588 this->
right->is_match(
temp->get_right()) &&
1590 (
temp->get_scale() == this->scale) &&
1591 (
temp->get_offset() == this->offset);
1618 return this->
left->size();
1634 template<jit::
float_scalar T,
bool SAFE_MATH=false>
1640 auto temp = std::make_shared<index_1D_node<T, SAFE_MATH>> (
v, x,
1653#if defined(__clang__) || defined(__GNUC__)
1656 assert(
false &&
"Should never reach.");
1661 template<jit::
float_scalar T,
bool SAFE_MATH=false>
1673 template<jit::
float_scalar T,
bool SAFE_MATH=false>
1675 return std::dynamic_pointer_cast<index_1D_node<T, SAFE_MATH>> (x);
Class representing a generic buffer.
Definition backend.hpp:29
Class representing a branch node.
Definition node.hpp:1174
shared_leaf< T, SAFE_MATH > right
Right branch of the tree.
Definition node.hpp:1179
shared_leaf< T, SAFE_MATH > left
Left branch of the tree.
Definition node.hpp:1177
Class representing a 1D index.
Definition piecewise.hpp:1375
virtual shared_leaf< T, SAFE_MATH > get_power_exponent() const
Get the exponent of a power.
Definition piecewise.hpp:1575
virtual backend::buffer< T > evaluate()
Evaluate the results of the piecewise constant.
Definition piecewise.hpp:1420
size_t get_size() const
Get the size of the buffer.
Definition piecewise.hpp:1617
virtual bool is_power_like() const
Test if the node acts like a power of variable.
Definition piecewise.hpp:1566
virtual shared_leaf< T, SAFE_MATH > compile(std::ostringstream &stream, jit::register_map &registers, jit::register_map &indices, const jit::register_usage &usage)
Compile the node.
Definition piecewise.hpp:1461
T get_scale() const
Get x argument scale.
Definition piecewise.hpp:1599
virtual shared_leaf< T, SAFE_MATH > reduce()
Reduction method.
Definition piecewise.hpp:1432
bool is_arg_match(shared_leaf< T, SAFE_MATH > x)
Check if the args match.
Definition piecewise.hpp:1585
T get_offset() const
Get x argument offset.
Definition piecewise.hpp:1608
virtual shared_leaf< T, SAFE_MATH > to_vizgraph(std::stringstream &stream, jit::register_map &registers)
Convert the node to vizgraph.
Definition piecewise.hpp:1525
index_1D_node(shared_leaf< T, SAFE_MATH > var, shared_leaf< T, SAFE_MATH > x, const T scale, const T offset)
Construct a 1D index.
Definition piecewise.hpp:1404
virtual bool is_all_variables() const
Test if node acts like a variable.
Definition piecewise.hpp:1557
virtual shared_leaf< T, SAFE_MATH > df(shared_leaf< T, SAFE_MATH > x)
Transform node to derivative.
Definition piecewise.hpp:1442
virtual bool is_constant() const
Test if node is a constant.
Definition piecewise.hpp:1548
virtual void to_latex() const
Convert the node to latex.
Definition piecewise.hpp:1511
Class representing a node leaf.
Definition node.hpp:364
virtual void endline(std::ostringstream &stream, const jit::register_usage &usage) const final
End a line in the kernel source.
Definition node.hpp:637
virtual bool is_match(std::shared_ptr< leaf_node< T, SAFE_MATH > > x)
Query if the nodes match.
Definition node.hpp:470
virtual std::shared_ptr< leaf_node< T, SAFE_MATH > > compile(std::ostringstream &stream, jit::register_map &registers, jit::register_map &indices, const jit::register_usage &usage)=0
Compile the node.
Class representing a 1D piecewise constant.
Definition piecewise.hpp:92
virtual void compile_preamble(std::ostringstream &stream, jit::register_map &registers, jit::visiter_map &visited, jit::register_usage &usage, jit::texture1d_list &textures1d, jit::texture2d_list &textures2d, int &avail_const_mem)
Compile preamble.
Definition piecewise.hpp:227
virtual bool is_power_like() const
Test if the node acts like a power of variable.
Definition piecewise.hpp:492
piecewise_1D_node(const backend::buffer< T > &d, shared_leaf< T, SAFE_MATH > x, const T scale, const T offset)
Construct a 1D piecewise constant node.
Definition piecewise.hpp:163
virtual bool is_constant() const
Test if node is a constant.
Definition piecewise.hpp:465
virtual bool is_match(shared_leaf< T, SAFE_MATH > x)
Query if the nodes match.
Definition piecewise.hpp:419
virtual shared_leaf< T, SAFE_MATH > get_power_exponent() const
Get the exponent of a power.
Definition piecewise.hpp:510
virtual shared_leaf< T, SAFE_MATH > reduce()
Reduction method.
Definition piecewise.hpp:192
virtual backend::buffer< T > evaluate()
Evaluate the results of the piecewise constant.
Definition piecewise.hpp:180
virtual shared_leaf< T, SAFE_MATH > df(shared_leaf< T, SAFE_MATH > x)
Transform node to derivative.
Definition piecewise.hpp:212
T get_scale() const
Get x argument scale.
Definition piecewise.hpp:534
virtual void to_latex() const
Convert the node to latex.
Definition piecewise.hpp:433
virtual shared_leaf< T, SAFE_MATH > get_power_base()
Get the base of a power.
Definition piecewise.hpp:501
size_t get_size() const
Get the size of the buffer.
Definition piecewise.hpp:552
virtual bool has_constant_zero() const
Test the constant node has a zero.
Definition piecewise.hpp:474
virtual shared_leaf< T, SAFE_MATH > to_vizgraph(std::stringstream &stream, jit::register_map &registers)
Convert the node to vizgraph.
Definition piecewise.hpp:444
virtual shared_leaf< T, SAFE_MATH > compile(std::ostringstream &stream, jit::register_map &registers, jit::register_map &indices, const jit::register_usage &usage)
Compile the node.
Definition piecewise.hpp:320
T get_offset() const
Get x argument offset.
Definition piecewise.hpp:543
virtual bool is_all_variables() const
Test if node acts like a variable.
Definition piecewise.hpp:483
bool is_arg_match(shared_leaf< T, SAFE_MATH > x)
Check if the args match.
Definition piecewise.hpp:520
Class representing a 2D piecewise constant.
Definition piecewise.hpp:657
T get_y_scale() const
Get y argument scale.
Definition piecewise.hpp:799
virtual shared_leaf< T, SAFE_MATH > df(shared_leaf< T, SAFE_MATH > x)
Transform node to derivative.
Definition piecewise.hpp:872
bool is_arg_match(shared_leaf< T, SAFE_MATH > x)
Check if the args match.
Definition piecewise.hpp:1250
T get_y_offset() const
Get y argument offset.
Definition piecewise.hpp:808
virtual bool is_power_like() const
Test if the node acts like a power of variable.
Definition piecewise.hpp:1222
T get_x_scale() const
Get x argument scale.
Definition piecewise.hpp:781
virtual shared_leaf< T, SAFE_MATH > compile(std::ostringstream &stream, jit::register_map &registers, jit::register_map &indices, const jit::register_usage &usage)
Compile the node.
Definition piecewise.hpp:997
size_t get_num_columns() const
Get the number of columns.
Definition piecewise.hpp:762
virtual shared_leaf< T, SAFE_MATH > get_power_base()
Get the base of a power.
Definition piecewise.hpp:1231
virtual bool is_match(shared_leaf< T, SAFE_MATH > x)
Query if the nodes match.
Definition piecewise.hpp:1144
virtual bool is_all_variables() const
Test if node acts like a variable.
Definition piecewise.hpp:1213
virtual bool is_constant() const
Test if node is a constant.
Definition piecewise.hpp:1195
size_t get_num_rows() const
Get the number of rows.
Definition piecewise.hpp:771
bool is_col_match(shared_leaf< T, SAFE_MATH > x)
Do the columns match.
Definition piecewise.hpp:1286
T get_x_offset() const
Get x argument offset.
Definition piecewise.hpp:790
virtual void to_latex() const
Convert the node to latex.
Definition piecewise.hpp:1161
virtual shared_leaf< T, SAFE_MATH > to_vizgraph(std::stringstream &stream, jit::register_map &registers)
Convert the node to vizgraph.
Definition piecewise.hpp:1172
bool is_row_match(shared_leaf< T, SAFE_MATH > x)
Do the rows match.
Definition piecewise.hpp:1269
virtual shared_leaf< T, SAFE_MATH > get_power_exponent() const
Get the exponent of a power.
Definition piecewise.hpp:1240
virtual shared_leaf< T, SAFE_MATH > reduce()
Reduction method.
Definition piecewise.hpp:833
piecewise_2D_node(const backend::buffer< T > &d, const size_t n, shared_leaf< T, SAFE_MATH > x, const T x_scale, const T x_offset, shared_leaf< T, SAFE_MATH > y, const T y_scale, const T y_offset)
Construct a 2D piecewise constant node.
Definition piecewise.hpp:741
virtual backend::buffer< T > evaluate()
Evaluate the results of the piecewise constant.
Definition piecewise.hpp:821
virtual bool has_constant_zero() const
Test the constant node has a zero.
Definition piecewise.hpp:1204
virtual void compile_preamble(std::ostringstream &stream, jit::register_map &registers, jit::visiter_map &visited, jit::register_usage &usage, jit::texture1d_list &textures1d, jit::texture2d_list &textures2d, int &avail_const_mem)
Compile preamble.
Definition piecewise.hpp:887
Class representing a straight node.
Definition node.hpp:1060
shared_leaf< T, SAFE_MATH > arg
Argument.
Definition node.hpp:1063
Complex scalar concept.
Definition register.hpp:24
Double base concept.
Definition register.hpp:42
subroutine assert(test, message)
Assert check.
Definition f_binding_test.f90:38
Name space for graph nodes.
Definition arithmetic.hpp:13
shared_piecewise_2D< T, SAFE_MATH > piecewise_2D_cast(shared_leaf< T, SAFE_MATH > x)
Cast to a piecewise 2D node.
Definition piecewise.hpp:1355
constexpr shared_leaf< T, SAFE_MATH > zero()
Forward declare for zero.
Definition node.hpp:995
shared_piecewise_1D< T, SAFE_MATH > piecewise_1D_cast(shared_leaf< T, SAFE_MATH > x)
Cast to a piecewise 1D node.
Definition piecewise.hpp:608
std::shared_ptr< piecewise_2D_node< T, SAFE_MATH > > shared_piecewise_2D
Convenience type alias for shared piecewise 2D nodes.
Definition piecewise.hpp:1343
shared_constant< T, SAFE_MATH > constant_cast(shared_leaf< T, SAFE_MATH > x)
Cast to a constant node.
Definition node.hpp:1043
shared_variable< T, SAFE_MATH > variable_cast(shared_leaf< T, SAFE_MATH > x)
Cast to a variable node.
Definition node.hpp:1747
std::shared_ptr< piecewise_1D_node< T, SAFE_MATH > > shared_piecewise_1D
Convenience type alias for shared piecewise 1D nodes.
Definition piecewise.hpp:596
constexpr T i
Convenience type for imaginary constant.
Definition node.hpp:1027
std::shared_ptr< leaf_node< T, SAFE_MATH > > shared_leaf
Convenience type alias for shared leaf nodes.
Definition node.hpp:674
std::shared_ptr< index_1D_node< T, SAFE_MATH > > shared_index_1D
Convenience type alias for shared index 1D nodes.
Definition piecewise.hpp:1662
void compile_index(std::ostringstream &stream, const std::string &register_name, const size_t length, const T scale, const T offset)
Compile an index.
Definition piecewise.hpp:26
shared_index_1D< T, SAFE_MATH > index_1D_cast(shared_leaf< T, SAFE_MATH > x)
Cast to an index 1D node.
Definition piecewise.hpp:1674
std::map< void *, size_t > texture1d_list
Type alias for indexing 1D textures.
Definition register.hpp:263
std::string format_to_string(const T value)
Convert a value to a string while avoiding locale.
Definition register.hpp:212
std::map< void *, std::array< size_t, 2 > > texture2d_list
Type alias for indexing 2D textures.
Definition register.hpp:265
std::map< void *, size_t > register_usage
Type alias for counting register usage.
Definition register.hpp:259
std::map< void *, std::string > register_map
Type alias for mapping node pointers to register names.
Definition register.hpp:257
constexpr bool use_cuda()
Test to use Cuda.
Definition register.hpp:67
std::set< void * > visiter_map
Type alias for listing visited nodes.
Definition register.hpp:261
std::string to_string(const char prefix, const NODE *pointer)
Convert a graph::leaf_node pointer to a string.
Definition register.hpp:246
Base nodes of graph computation framework.
void piecewise_1D()
Tests for 1D piecewise nodes.
Definition piecewise_test.cpp:80
void index_1D()
Tests for 1D index nodes.
Definition piecewise_test.cpp:738
void piecewise_2D()
Tests for 2D piecewise nodes.
Definition piecewise_test.cpp:292