25template<jit::
float_scalar T>
31 const std::string type = jit::smallest_int_type<T> (length);
42 stream << offset <<
")/";
46 stream << scale <<
")";
50 stream <<
",(" << type <<
")0),("
51 << type <<
")" << length - 1 <<
")";
91 template<jit::
float_scalar T,
bool SAFE_MATH=false>
107 for (
size_t i = 0,
ie =
d.size();
i <
ie;
i++) {
123 return piecewise_1D_node::to_string(
d) +
134 const size_t h = std::hash<std::string>{} (piecewise_1D_node::to_string(
d));
144#if defined(__clang__) || defined(__GNUC__)
147 assert(
false &&
"Should never reach.");
152 const size_t data_hash;
228 this->
arg->compile_preamble(stream, registers,
236 if constexpr (jit::use_metal<T> ()) {
239#ifdef USE_CUDA_TEXTURES
249 stream <<
"__constant__ ";
253 jit::add_type<T> (stream);
256 jit::add_type<T> (stream);
259 for (
size_t i = 1;
i < length;
i++) {
262 jit::add_type<T> (stream);
266 stream <<
"};" << std::endl;
272 if constexpr (jit::use_metal<T> ()) {
275#ifdef USE_CUDA_TEXTURES
317 if (registers.find(
this) == registers.end()) {
318#ifdef USE_INDEX_CACHE
326#ifdef USE_INDEX_CACHE
329 << jit::smallest_int_type<T> (length) <<
" "
333 a->endline(stream,
usage);
339 jit::add_type<T> (stream);
340 stream <<
" " << registers[
this] <<
" = ";
341#ifdef USE_CUDA_TEXTURES
345 stream <<
"to_cmp_float(tex1D<float2> (";
347 stream <<
"tex1D<float> (";
351 stream <<
"to_cmp_double(tex1D<uint4> (";
353 stream <<
"to_double(tex1D<uint2> (";
359 if constexpr (jit::use_metal<T> ()) {
360#ifdef USE_INDEX_CACHE
370#ifdef USE_CUDA_TEXTURES
372#ifdef USE_INDEX_CACHE
386#ifdef USE_INDEX_CACHE
416 return this->data_hash ==
x_cast->data_hash &&
427 std::cout <<
"r\\_" <<
reinterpret_cast<size_t> (
this) <<
"_{i}";
439 if (registers.find(
this) == registers.end()) {
441 registers[
this] =
name;
442 stream <<
" " <<
name
443 <<
" [label = \"r_" <<
reinterpret_cast<size_t> (
this)
444 <<
"_{i}\", shape = hexagon, style = filled, fillcolor = black, fontcolor = white];" << std::endl;
446 auto a = this->
arg->to_vizgraph(stream, registers);
447 stream <<
" " <<
name <<
" -- " << registers[
a.get()] <<
";" << std::endl;
516 this->
arg->is_match(
temp->get_arg()) &&
518 (
temp->get_scale() == this->scale) &&
519 (
temp->get_offset() == this->offset);
562 template<jit::
float_scalar T,
bool SAFE_MATH=false>
567 auto temp = std::make_shared<piecewise_1D_node<T, SAFE_MATH>> (
d, x,
580#if defined(__clang__) || defined(__GNUC__)
583 assert(
false &&
"Should never reach.");
588 template<jit::
float_scalar T,
bool SAFE_MATH=false>
600 template<jit::
float_scalar T,
bool SAFE_MATH=false>
602 return std::dynamic_pointer_cast<piecewise_1D_node<T, SAFE_MATH>> (x);
649 template<jit::
float_scalar T,
bool SAFE_MATH=false>
669 for (
size_t i = 0,
ie =
d.size();
i <
ie;
i++) {
687 return piecewise_2D_node::to_string(
d) +
699 const size_t h = std::hash<std::string>{} (piecewise_2D_node::to_string(
d));
709#if defined(__clang__) || defined(__GNUC__)
712 assert(
false &&
"Should never reach.");
717 const size_t data_hash;
719 const size_t num_columns;
744 num_columns(
n), x_scale(x_scale), x_offset(x_offset), y_scale(y_scale),
747 "Expected the data buffer to be a multiple of the number of columns.");
863 this->
left->compile_preamble(stream, registers,
867 this->
right->compile_preamble(stream, registers,
875 if constexpr (jit::use_metal<T> ()) {
877 std::array<size_t, 2> ({length/num_columns, num_columns}));
878#ifdef USE_CUDA_TEXTURES
881 std::array<size_t, 2> ({length/num_columns, num_columns}));
888 stream <<
"__constant__ ";
892 jit::add_type<T> (stream);
895 jit::add_type<T> (stream);
898 for (
size_t i = 1;
i < length;
i++) {
901 jit::add_type<T> (stream);
905 stream <<
"};" << std::endl;
911 if constexpr (jit::use_metal<T> ()) {
913 std::array<size_t, 2> ({length/num_columns, num_columns}));
914#ifdef USE_CUDA_TEXTURES
917 std::array<size_t, 2> ({length/num_columns, num_columns}));
969 if (registers.find(
this) == registers.end()) {
971 const size_t num_rows = length/num_columns;
982#ifdef USE_INDEX_CACHE
986 << jit::smallest_int_type<T> (
num_rows) <<
" "
990 x->endline(stream,
usage);
995 << jit::smallest_int_type<T> (num_columns) <<
" "
999 y->endline(stream,
usage);
1003 if constexpr (!jit::use_metal<T> ()
1004#ifdef USE_CUDA_TEXTURES
1012 << jit::smallest_int_type<T> (length) <<
" "
1015 <<
"*" << num_columns <<
" + "
1017 <<
";" << std::endl;
1023 stream <<
" const ";
1024 jit::add_type<T> (stream);
1025 stream <<
" " << registers[
this] <<
" = ";
1026#ifdef USE_CUDA_TEXTURES
1030 stream <<
"to_cmp_float(tex1D<float2> (";
1032 stream <<
"tex1D<float> (";
1036 stream <<
"to_cmp_double(tex1D<uint4> (";
1038 stream <<
"to_double(tex1D<uint2> (";
1044 if constexpr (jit::use_metal<T> ()) {
1045#ifdef USE_INDEX_CACHE
1047 << jit::smallest_int_type<T> (std::max(
num_rows,
1055 stream <<
".read(uint2(";
1063#ifdef USE_CUDA_TEXTURES
1065#ifdef USE_INDEX_CACHE
1084#ifdef USE_INDEX_CACHE
1092 stream <<
"*" << num_columns <<
" + ";
1116 return this->data_hash ==
x_cast->data_hash &&
1130 std::cout <<
"r\\_" <<
reinterpret_cast<size_t> (
this) <<
"_{ij}";
1142 if (registers.find(
this) == registers.end()) {
1144 registers[
this] =
name;
1145 stream <<
" " <<
name
1146 <<
" [label = \"r_" <<
reinterpret_cast<size_t> (
this)
1147 <<
"_{ij}\", shape = hexagon, style = filled, fillcolor = black, fontcolor = white];" << std::endl;
1149 auto l = this->
left->to_vizgraph(stream, registers);
1150 stream <<
" " <<
name <<
" -- " << registers[
l.get()] <<
";" << std::endl;
1151 auto r = this->
right->to_vizgraph(stream, registers);
1152 stream <<
" " <<
name <<
" -- " << registers[
r.get()] <<
";" << std::endl;
1220 return temp.get() &&
1221 this->
left->is_match(
temp->get_left()) &&
1222 this->
right->is_match(
temp->get_right()) &&
1225 (
temp->get_x_scale() == this->x_scale) &&
1226 (
temp->get_x_offset() == this->x_offset) &&
1227 (
temp->get_y_scale() == this->y_scale) &&
1228 (
temp->get_y_offset() == this->y_offset);
1239 return temp.get() &&
1240 this->
left->is_match(
temp->get_arg()) &&
1242 (
temp->get_scale() == this->x_scale) &&
1243 (
temp->get_offset() == this->x_offset);
1256 return temp.get() &&
1257 this->
right->is_match(
temp->get_arg()) &&
1259 (
temp->get_scale() == this->y_scale) &&
1260 (
temp->get_offset() == this->y_offset);
1280 template<jit::
float_scalar T,
bool SAFE_MATH=false>
1289 auto temp = std::make_shared<piecewise_2D_node<T, SAFE_MATH>> (
d,
n,
1290 x, x_scale, x_offset,
1291 y, y_scale, y_offset)->reduce();
1302#if defined(__clang__) || defined(__GNUC__)
1305 assert(
false &&
"Should never reach.");
1310 template<jit::
float_scalar T,
bool SAFE_MATH=false>
1322 template<jit::
float_scalar T,
bool SAFE_MATH=false>
1324 return std::dynamic_pointer_cast<piecewise_2D_node<T, SAFE_MATH>> (x);
Class representing a generic buffer.
Definition backend.hpp:29
Class representing a branch node.
Definition node.hpp:1173
shared_leaf< T, SAFE_MATH > right
Right branch of the tree.
Definition node.hpp:1178
shared_leaf< T, SAFE_MATH > left
Left branch of the tree.
Definition node.hpp:1176
Class representing a node leaf.
Definition node.hpp:364
virtual void endline(std::ostringstream &stream, const jit::register_usage &usage) const final
End a line in the kernel source.
Definition node.hpp:637
virtual std::shared_ptr< leaf_node< T, SAFE_MATH > > compile(std::ostringstream &stream, jit::register_map &registers, jit::register_map &indices, const jit::register_usage &usage)=0
Compile the node.
Class representing a 1D piecewise constant.
Definition piecewise.hpp:92
virtual void compile_preamble(std::ostringstream &stream, jit::register_map &registers, jit::visiter_map &visited, jit::register_usage &usage, jit::texture1d_list &textures1d, jit::texture2d_list &textures2d, int &avail_const_mem)
Compile preamble.
Definition piecewise.hpp:220
virtual bool is_power_like() const
Test if the node acts like a power of variable.
Definition piecewise.hpp:485
piecewise_1D_node(const backend::buffer< T > &d, shared_leaf< T, SAFE_MATH > x, const T scale, const T offset)
Construct a 1D piecewise constant node.
Definition piecewise.hpp:163
virtual bool is_constant() const
Test if node is a constant.
Definition piecewise.hpp:458
virtual bool is_match(shared_leaf< T, SAFE_MATH > x)
Query if the nodes match.
Definition piecewise.hpp:412
virtual shared_leaf< T, SAFE_MATH > get_power_exponent() const
Get the exponent of a power.
Definition piecewise.hpp:503
virtual shared_leaf< T, SAFE_MATH > reduce()
Reduction method.
Definition piecewise.hpp:192
virtual backend::buffer< T > evaluate()
Evaluate the results of the piecewise constant.
Definition piecewise.hpp:180
virtual shared_leaf< T, SAFE_MATH > df(shared_leaf< T, SAFE_MATH > x)
Transform node to derivative.
Definition piecewise.hpp:205
T get_scale() const
Get x argument scale.
Definition piecewise.hpp:527
virtual void to_latex() const
Convert the node to latex.
Definition piecewise.hpp:426
virtual shared_leaf< T, SAFE_MATH > get_power_base()
Get the base of a power.
Definition piecewise.hpp:494
size_t get_size() const
Get the size of the buffer.
Definition piecewise.hpp:545
virtual bool has_constant_zero() const
Test the constant node has a zero.
Definition piecewise.hpp:467
virtual shared_leaf< T, SAFE_MATH > to_vizgraph(std::stringstream &stream, jit::register_map &registers)
Convert the node to vizgraph.
Definition piecewise.hpp:437
virtual shared_leaf< T, SAFE_MATH > compile(std::ostringstream &stream, jit::register_map &registers, jit::register_map &indices, const jit::register_usage &usage)
Compile the node.
Definition piecewise.hpp:313
T get_offset() const
Get x argument offset.
Definition piecewise.hpp:536
virtual bool is_all_variables() const
Test if node acts like a variable.
Definition piecewise.hpp:476
bool is_arg_match(shared_leaf< T, SAFE_MATH > x)
Check if the args match.
Definition piecewise.hpp:513
Class representing a 2D piecewise constant.
Definition piecewise.hpp:650
T get_y_scale() const
Get y argument scale.
Definition piecewise.hpp:792
virtual shared_leaf< T, SAFE_MATH > df(shared_leaf< T, SAFE_MATH > x)
Transform node to derivative.
Definition piecewise.hpp:840
bool is_arg_match(shared_leaf< T, SAFE_MATH > x)
Check if the args match.
Definition piecewise.hpp:1218
T get_y_offset() const
Get y argument offset.
Definition piecewise.hpp:801
virtual bool is_power_like() const
Test if the node acts like a power of variable.
Definition piecewise.hpp:1190
T get_x_scale() const
Get x argument scale.
Definition piecewise.hpp:774
virtual shared_leaf< T, SAFE_MATH > compile(std::ostringstream &stream, jit::register_map &registers, jit::register_map &indices, const jit::register_usage &usage)
Compile the node.
Definition piecewise.hpp:965
size_t get_num_columns() const
Get the number of columns.
Definition piecewise.hpp:755
virtual shared_leaf< T, SAFE_MATH > get_power_base()
Get the base of a power.
Definition piecewise.hpp:1199
virtual bool is_match(shared_leaf< T, SAFE_MATH > x)
Query if the nodes match.
Definition piecewise.hpp:1112
virtual bool is_all_variables() const
Test if node acts like a variable.
Definition piecewise.hpp:1181
virtual bool is_constant() const
Test if node is a constant.
Definition piecewise.hpp:1163
size_t get_num_rows() const
Get the number of rows.
Definition piecewise.hpp:764
bool is_col_match(shared_leaf< T, SAFE_MATH > x)
Do the columns match.
Definition piecewise.hpp:1254
T get_x_offset() const
Get x argument offset.
Definition piecewise.hpp:783
virtual void to_latex() const
Convert the node to latex.
Definition piecewise.hpp:1129
virtual shared_leaf< T, SAFE_MATH > to_vizgraph(std::stringstream &stream, jit::register_map &registers)
Convert the node to vizgraph.
Definition piecewise.hpp:1140
bool is_row_match(shared_leaf< T, SAFE_MATH > x)
Do the rows match.
Definition piecewise.hpp:1237
virtual shared_leaf< T, SAFE_MATH > get_power_exponent() const
Get the exponent of a power.
Definition piecewise.hpp:1208
virtual shared_leaf< T, SAFE_MATH > reduce()
Reduction method.
Definition piecewise.hpp:826
piecewise_2D_node(const backend::buffer< T > &d, const size_t n, shared_leaf< T, SAFE_MATH > x, const T x_scale, const T x_offset, shared_leaf< T, SAFE_MATH > y, const T y_scale, const T y_offset)
Construct a 2D piecewise constant node.
Definition piecewise.hpp:734
virtual backend::buffer< T > evaluate()
Evaluate the results of the piecewise constant.
Definition piecewise.hpp:814
virtual bool has_constant_zero() const
Test the constant node has a zero.
Definition piecewise.hpp:1172
virtual void compile_preamble(std::ostringstream &stream, jit::register_map &registers, jit::visiter_map &visited, jit::register_usage &usage, jit::texture1d_list &textures1d, jit::texture2d_list &textures2d, int &avail_const_mem)
Compile preamble.
Definition piecewise.hpp:855
Class representing a straight node.
Definition node.hpp:1059
shared_leaf< T, SAFE_MATH > arg
Argument.
Definition node.hpp:1062
Complex scalar concept.
Definition register.hpp:24
Double base concept.
Definition register.hpp:42
subroutine assert(test, message)
Assert check.
Definition f_binding_test.f90:38
Name space for graph nodes.
Definition arithmetic.hpp:13
shared_piecewise_2D< T, SAFE_MATH > piecewise_2D_cast(shared_leaf< T, SAFE_MATH > x)
Cast to a piecewise 2D node.
Definition piecewise.hpp:1323
constexpr shared_leaf< T, SAFE_MATH > zero()
Forward declare for zero.
Definition node.hpp:994
shared_piecewise_1D< T, SAFE_MATH > piecewise_1D_cast(shared_leaf< T, SAFE_MATH > x)
Cast to a piecewise 1D node.
Definition piecewise.hpp:601
std::shared_ptr< piecewise_2D_node< T, SAFE_MATH > > shared_piecewise_2D
Convenience type alias for shared piecewise 2D nodes.
Definition piecewise.hpp:1311
std::shared_ptr< piecewise_1D_node< T, SAFE_MATH > > shared_piecewise_1D
Convenience type alias for shared piecewise 1D nodes.
Definition piecewise.hpp:589
constexpr T i
Convenience type for imaginary constant.
Definition node.hpp:1026
std::shared_ptr< leaf_node< T, SAFE_MATH > > shared_leaf
Convenience type alias for shared leaf nodes.
Definition node.hpp:673
void compile_index(std::ostringstream &stream, const std::string &register_name, const size_t length, const T scale, const T offset)
Compile an index.
Definition piecewise.hpp:26
std::map< void *, size_t > texture1d_list
Type alias for indexing 1D textures.
Definition register.hpp:262
std::string format_to_string(const T value)
Convert a value to a string while avoiding locale.
Definition register.hpp:211
std::map< void *, std::array< size_t, 2 > > texture2d_list
Type alias for indexing 2D textures.
Definition register.hpp:264
std::map< void *, size_t > register_usage
Type alias for counting register usage.
Definition register.hpp:258
std::map< void *, std::string > register_map
Type alias for mapping node pointers to register names.
Definition register.hpp:256
constexpr bool use_cuda()
Test to use Cuda.
Definition register.hpp:67
std::set< void * > visiter_map
Type alias for listing visited nodes.
Definition register.hpp:260
std::string to_string(const char prefix, const NODE *pointer)
Convert a graph::leaf_node pointer to a string.
Definition register.hpp:245
Base nodes of graph computation framework.
void piecewise_1D()
Tests for 1D piecewise nodes.
Definition piecewise_test.cpp:80
void piecewise_2D()
Tests for 2D piecewise nodes.
Definition piecewise_test.cpp:283