24 template<jit::
float_scalar T,
bool SAFE_MATH=false>
80 ap2->get_num_columns(),
81 ap2->get_left(),
ap2->get_x_scale(),
ap2->get_x_offset(),
82 ap2->get_right(),
ap2->get_y_scale(),
ap2->get_y_offset());
88 return temp->get_right() /
90 temp->get_right()*
temp->get_right()));
97 if (
lc.get() &&
lc->evaluate().is_negative()) {
119 const size_t hash =
reinterpret_cast<size_t> (x.get());
139 if (registers.find(
this) == registers.end()) {
147 jit::add_type<T> (stream);
148 stream <<
" " << registers[
this] <<
" = sin("
149 << registers[
a.get()] <<
")";
163 if (
this == x.get()) {
169 return this->
arg->is_match(
x_cast->get_arg());
179 std::cout <<
"\\sin\\left(";
180 this->
arg->to_latex();
181 std::cout <<
"\\right)";
191 return sin(this->
arg->remove_pseudo());
205 if (registers.find(
this) == registers.end()) {
207 registers[
this] =
name;
208 stream <<
" " <<
name
209 <<
" [label = \"sin\", shape = oval, style = filled, fillcolor = blue, fontcolor = white];" << std::endl;
211 auto a = this->
arg->to_vizgraph(stream, registers);
212 stream <<
" " <<
name <<
" -- " << registers[
a.get()] <<
";" << std::endl;
228 template<jit::
float_scalar T,
bool SAFE_MATH=false>
230 auto temp = std::make_shared<sine_node<T, SAFE_MATH>> (x)->reduce();
241#if defined(__clang__) || defined(__GNUC__)
244 assert(
false &&
"Should never reach.");
249 template<jit::
float_scalar T,
bool SAFE_MATH=false>
261 template<jit::
float_scalar T,
bool SAFE_MATH=false>
263 return std::dynamic_pointer_cast<sine_node<T, SAFE_MATH>> (x);
275 template<jit::
float_scalar T,
bool SAFE_MATH=false>
331 ap2->get_num_columns(),
332 ap2->get_left(),
ap2->get_x_scale(),
ap2->get_x_offset(),
333 ap2->get_right(),
ap2->get_y_scale(),
ap2->get_y_offset());
339 return temp->get_left() /
341 temp->get_right()*
temp->get_right()));
348 if (
lc.get() &&
lc->evaluate().is_negative()) {
370 const size_t hash =
reinterpret_cast<size_t> (x.get());
391 if (registers.find(
this) == registers.end()) {
399 jit::add_type<T> (stream);
400 stream <<
" " << registers[
this] <<
" = cos("
401 << registers[
a.get()] <<
")";
415 if (
this == x.get()) {
421 return this->
arg->is_match(
x_cast->get_arg());
431 std::cout <<
"\\cos\\left(";
432 this->
arg->to_latex();
433 std::cout <<
"\\right)";
443 return cos(this->
arg->remove_pseudo());
457 if (registers.find(
this) == registers.end()) {
459 registers[
this] =
name;
460 stream <<
" " <<
name
461 <<
" [label = \"cos\", shape = oval, style = filled, fillcolor = blue, fontcolor = white];" << std::endl;
463 auto a = this->
arg->to_vizgraph(stream, registers);
464 stream <<
" " <<
name <<
" -- " << registers[
a.get()] <<
";" << std::endl;
480 template<jit::
float_scalar T,
bool SAFE_MATH=false>
482 auto temp = std::make_shared<cosine_node<T, SAFE_MATH>> (x)->reduce();
493#if defined(__clang__) || defined(__GNUC__)
496 assert(
false &&
"Should never reach.");
501 template<jit::
float_scalar T,
bool SAFE_MATH=false>
513 template<jit::
float_scalar T,
bool SAFE_MATH=false>
515 return std::dynamic_pointer_cast<cosine_node<T, SAFE_MATH>> (x);
532 template<jit::
float_scalar T,
bool SAFE_MATH=false>
546 template<jit::
float_scalar T,
bool SAFE_MATH=false>
594 if (l.get() &&
r.get()) {
601 if (
pl1.get() && (r.get() ||
pl1->is_arg_match(
this->right))) {
603 pl1->get_scale(),
pl1->get_offset());
604 }
else if (
pr1.get() && (l.get() ||
pr1->is_arg_match(
this->left))) {
606 pr1->get_scale(),
pr1->get_offset());
612 if (
pl2.get() && (r.get() ||
pl2->is_arg_match(
this->right))) {
614 pl2->get_num_columns(),
615 pl2->get_left(),
pl2->get_x_scale(),
pl2->get_x_offset(),
616 pl2->get_right(),
pl2->get_y_scale(),
pl2->get_y_offset());
617 }
else if (
pr2.get() && (l.get() ||
pr2->is_arg_match(
this->left))) {
619 pr2->get_num_columns(),
620 pr2->get_left(),
pr2->get_x_scale(),
pr2->get_x_offset(),
621 pr2->get_right(),
pr2->get_y_scale(),
pr2->get_y_offset());
625 if (
pr2.get() &&
pr2->is_row_match(
this->left)) {
629 pr2->get_num_columns(),
630 pr2->get_left(),
pr2->get_x_scale(),
pr2->get_x_offset(),
631 pr2->get_right(),
pr2->get_y_scale(),
pr2->get_y_offset());
632 }
else if (
pr2.get() &&
pr2->is_col_match(
this->left)) {
636 pr2->get_num_columns(),
637 pr2->get_left(),
pr2->get_x_scale(),
pr2->get_x_offset(),
638 pr2->get_right(),
pr2->get_y_scale(),
pr2->get_y_offset());
639 }
else if (
pl2.get() &&
pl2->is_row_match(
this->right)) {
643 pl2->get_num_columns(),
644 pl2->get_left(),
pl2->get_x_scale(),
pl2->get_x_offset(),
645 pl2->get_right(),
pl2->get_y_scale(),
pl2->get_y_offset());
646 }
else if (
pl2.get() &&
pl2->is_col_match(
this->right)) {
650 pl2->get_num_columns(),
651 pl2->get_left(),
pl2->get_x_scale(),
pl2->get_x_offset(),
652 pl2->get_right(),
pl2->get_y_scale(),
pl2->get_y_offset());
672 const size_t hash =
reinterpret_cast<size_t> (x.get());
694 if (registers.find(
this) == registers.end()) {
706 jit::add_type<T> (stream);
708 stream <<
" " << registers[
this] <<
" = atan("
709 << registers[
r.get()] <<
"/"
710 << registers[
l.get()];
712 stream <<
" " << registers[
this] <<
" = atan2("
713 << registers[
r.get()] <<
","
714 << registers[
l.get()];
730 if (
this == x.get()) {
736 return this->
left->is_match(
x_cast->get_left()) &&
747 std::cout <<
"atan\\left(";
748 this->
left->to_latex();
750 this->
right->to_latex();
751 std::cout <<
"\\right)";
761 return atan(this->
left->remove_pseudo(),
762 this->right->remove_pseudo());
776 if (registers.find(
this) == registers.end()) {
778 registers[
this] =
name;
779 stream <<
" " <<
name
780 <<
" [label = \"atan\", shape = oval, style = filled, fillcolor = blue, fontcolor = white];" << std::endl;
782 auto l = this->
left->to_vizgraph(stream, registers);
783 stream <<
" " <<
name <<
" -- " << registers[
l.get()] <<
";" << std::endl;
784 auto r = this->
right->to_vizgraph(stream, registers);
785 stream <<
" " <<
name <<
" -- " << registers[
r.get()] <<
";" << std::endl;
801 template<jit::
float_scalar T,
bool SAFE_MATH=false>
804 auto temp = std::make_shared<arctan_node<T, SAFE_MATH>> (
l,
r)->reduce();
815#if defined(__clang__) || defined(__GNUC__)
818 assert(
false &&
"Should never reach.");
832 template<jit::
float_scalar T, jit::
float_scalar L,
bool SAFE_MATH=false>
848 template<jit::
float_scalar T, jit::
float_scalar R,
bool SAFE_MATH=false>
855 template<jit::
float_scalar T,
bool SAFE_MATH=false>
867 template<jit::
float_scalar T,
bool SAFE_MATH=false>
869 return std::dynamic_pointer_cast<arctan_node<T, SAFE_MATH>> (x);
Class representing a generic buffer.
Definition backend.hpp:29
void sin()
Take sin.
Definition backend.hpp:241
void atan_col(const buffer< T > &x)
Atan col operation.
Definition backend.hpp:635
void atan_row(const buffer< T > &x)
Atan row operation.
Definition backend.hpp:591
void cos()
Take cos.
Definition backend.hpp:250
Class representing an arctan_node leaf.
Definition trigonometry.hpp:547
virtual shared_leaf< T, SAFE_MATH > df(shared_leaf< T, SAFE_MATH > x)
Transform node to derivative.
Definition trigonometry.hpp:667
arctan_node(shared_leaf< T, SAFE_MATH > x, shared_leaf< T, SAFE_MATH > y)
Construct an arctan_node node.
Definition trigonometry.hpp:569
virtual shared_leaf< T, SAFE_MATH > reduce()
Reduce an arctan node.
Definition trigonometry.hpp:591
virtual shared_leaf< T, SAFE_MATH > compile(std::ostringstream &stream, jit::register_map &registers, jit::register_map &indices, const jit::register_usage &usage)
Compile the node.
Definition trigonometry.hpp:690
virtual shared_leaf< T, SAFE_MATH > remove_pseudo()
Remove pseudo variable nodes.
Definition trigonometry.hpp:759
virtual bool is_match(shared_leaf< T, SAFE_MATH > x)
Query if the nodes match.
Definition trigonometry.hpp:729
virtual shared_leaf< T, SAFE_MATH > to_vizgraph(std::stringstream &stream, jit::register_map &registers)
Convert the node to vizgraph.
Definition trigonometry.hpp:774
virtual backend::buffer< T > evaluate()
Evaluate the results of arctan.
Definition trigonometry.hpp:580
virtual void to_latex() const
Convert the node to latex.
Definition trigonometry.hpp:746
Class representing a branch node.
Definition node.hpp:1173
shared_leaf< T, SAFE_MATH > right
Right branch of the tree.
Definition node.hpp:1178
shared_leaf< T, SAFE_MATH > left
Left branch of the tree.
Definition node.hpp:1176
Class representing a cosine_node leaf.
Definition trigonometry.hpp:276
virtual shared_leaf< T, SAFE_MATH > to_vizgraph(std::stringstream &stream, jit::register_map &registers)
Convert the node to vizgraph.
Definition trigonometry.hpp:455
virtual shared_leaf< T, SAFE_MATH > compile(std::ostringstream &stream, jit::register_map &registers, jit::register_map &indices, const jit::register_usage &usage)
Compile the node.
Definition trigonometry.hpp:387
virtual shared_leaf< T, SAFE_MATH > df(shared_leaf< T, SAFE_MATH > x)
Transform node to derivative.
Definition trigonometry.hpp:365
cosine_node(shared_leaf< T, SAFE_MATH > x)
Construct a cosine_node node.
Definition trigonometry.hpp:294
virtual shared_leaf< T, SAFE_MATH > reduce()
Reduce the cos(x).
Definition trigonometry.hpp:315
virtual bool is_match(shared_leaf< T, SAFE_MATH > x)
Query if the nodes match.
Definition trigonometry.hpp:414
virtual void to_latex() const
Convert the node to latex.
Definition trigonometry.hpp:430
virtual backend::buffer< T > evaluate()
Evaluate the results of cosine.
Definition trigonometry.hpp:304
virtual shared_leaf< T, SAFE_MATH > remove_pseudo()
Remove pseudo variable nodes.
Definition trigonometry.hpp:441
Class representing a node leaf.
Definition node.hpp:364
virtual void endline(std::ostringstream &stream, const jit::register_usage &usage) const final
End a line in the kernel source.
Definition node.hpp:637
std::map< size_t, std::shared_ptr< leaf_node< T, SAFE_MATH > > > df_cache
Cache derivative terms.
Definition node.hpp:371
virtual bool has_pseudo() const
Query if the node contains pseudo variables.
Definition node.hpp:618
const size_t hash
Hash for node.
Definition node.hpp:367
Class representing a sine_node leaf.
Definition trigonometry.hpp:25
virtual shared_leaf< T, SAFE_MATH > remove_pseudo()
Remove pseudo variable nodes.
Definition trigonometry.hpp:189
virtual shared_leaf< T, SAFE_MATH > to_vizgraph(std::stringstream &stream, jit::register_map &registers)
Convert the node to vizgraph.
Definition trigonometry.hpp:203
virtual void to_latex() const
Convert the node to latex.
Definition trigonometry.hpp:178
virtual backend::buffer< T > evaluate()
Evaluate the results of sine.
Definition trigonometry.hpp:53
virtual shared_leaf< T, SAFE_MATH > reduce()
Reduce the sin(x).
Definition trigonometry.hpp:64
sine_node(shared_leaf< T, SAFE_MATH > x)
Construct a sine_node node.
Definition trigonometry.hpp:43
virtual shared_leaf< T, SAFE_MATH > df(shared_leaf< T, SAFE_MATH > x)
Transform node to derivative.
Definition trigonometry.hpp:114
virtual bool is_match(shared_leaf< T, SAFE_MATH > x)
Query if the nodes match.
Definition trigonometry.hpp:162
virtual shared_leaf< T, SAFE_MATH > compile(std::ostringstream &stream, jit::register_map &registers, jit::register_map &indices, const jit::register_usage &usage)
Compile the node.
Definition trigonometry.hpp:135
Class representing a straight node.
Definition node.hpp:1059
shared_leaf< T, SAFE_MATH > arg
Argument.
Definition node.hpp:1062
Complex scalar concept.
Definition register.hpp:24
subroutine assert(test, message)
Assert check.
Definition f_binding_test.f90:38
buffer< T > atan(buffer< T > &x, buffer< T > &y)
Take the inverse tangent.
Definition backend.hpp:1098
Name space for graph nodes.
Definition arithmetic.hpp:13
shared_piecewise_2D< T, SAFE_MATH > piecewise_2D_cast(shared_leaf< T, SAFE_MATH > x)
Cast to a piecewise 2D node.
Definition piecewise.hpp:1323
shared_leaf< T, SAFE_MATH > tan(shared_leaf< T, SAFE_MATH > x)
Define tangent convenience function.
Definition trigonometry.hpp:533
std::shared_ptr< arctan_node< T, SAFE_MATH > > shared_atan
Convenience type alias for shared arctan nodes.
Definition trigonometry.hpp:856
constexpr shared_leaf< T, SAFE_MATH > zero()
Forward declare for zero.
Definition node.hpp:994
std::shared_ptr< sine_node< T, SAFE_MATH > > shared_sine
Convenience type alias for shared sine nodes.
Definition trigonometry.hpp:250
shared_leaf< T, SAFE_MATH > atan(shared_leaf< T, SAFE_MATH > l, shared_leaf< T, SAFE_MATH > r)
Build arctan node.
Definition trigonometry.hpp:802
shared_sine< T, SAFE_MATH > sin_cast(shared_leaf< T, SAFE_MATH > x)
Cast to a sine node.
Definition trigonometry.hpp:262
shared_piecewise_1D< T, SAFE_MATH > piecewise_1D_cast(shared_leaf< T, SAFE_MATH > x)
Cast to a piecewise 1D node.
Definition piecewise.hpp:601
shared_atan< T, SAFE_MATH > atan_cast(shared_leaf< T, SAFE_MATH > x)
Cast to an arctan node.
Definition trigonometry.hpp:868
shared_multiply< T, SAFE_MATH > multiply_cast(shared_leaf< T, SAFE_MATH > x)
Cast to a multiply node.
Definition arithmetic.hpp:2723
shared_constant< T, SAFE_MATH > constant_cast(shared_leaf< T, SAFE_MATH > x)
Cast to a constant node.
Definition node.hpp:1042
constexpr T i
Convenience type for imaginary constant.
Definition node.hpp:1026
shared_leaf< T, SAFE_MATH > sin(shared_leaf< T, SAFE_MATH > x)
Define sine convenience function.
Definition trigonometry.hpp:229
shared_leaf< T, SAFE_MATH > sqrt(shared_leaf< T, SAFE_MATH > x)
Define sqrt convience function.
Definition math.hpp:279
std::shared_ptr< cosine_node< T, SAFE_MATH > > shared_cosine
Convenience type alias for shared cosine nodes.
Definition trigonometry.hpp:502
std::shared_ptr< leaf_node< T, SAFE_MATH > > shared_leaf
Convenience type alias for shared leaf nodes.
Definition node.hpp:673
shared_leaf< T, SAFE_MATH > cos(shared_leaf< T, SAFE_MATH > x)
Define cosine convenience function.
Definition trigonometry.hpp:481
shared_cosine< T, SAFE_MATH > cos_cast(shared_leaf< T, SAFE_MATH > x)
Cast to a cosine node.
Definition trigonometry.hpp:514
std::string format_to_string(const T value)
Convert a value to a string while avoiding locale.
Definition register.hpp:211
std::map< void *, size_t > register_usage
Type alias for counting register usage.
Definition register.hpp:258
std::map< void *, std::string > register_map
Type alias for mapping node pointers to register names.
Definition register.hpp:256
std::string to_string(const char prefix, const NODE *pointer)
Convert a graph::leaf_node pointer to a string.
Definition register.hpp:245
Base nodes of graph computation framework.
void piecewise_1D()
Tests for 1D piecewise nodes.
Definition piecewise_test.cpp:80
void piecewise_2D()
Tests for 2D piecewise nodes.
Definition piecewise_test.cpp:283