23 template<jit::
float_scalar T,
bool SAFE_MATH=false>
31 std::array<uint32_t, 624>
array;
50 states.push_back(initialize_state(
seed +
i));
93 stream <<
"struct mt_state {" << std::endl
94 <<
" array<uint32_t, 624> array;" << std::endl
95 <<
" uint16_t index;" << std::endl
97 <<
" uint16_t padding[3];" << std::endl
131 std::cout <<
"state";
143 if (registers.find(
this) == registers.end()) {
145 registers[
this] =
name;
146 stream <<
" " <<
name
147 <<
" [label = \"state\", shape = box, style = \"rounded,filled\", fillcolor = black, fontcolor = white];" << std::endl;
177 return states.size();
195 return states.data();
200 std::vector<mt_state> states;
207 static std::string to_string() {
208 return "random_state";
219 state.array[0] =
seed;
221 state.array[
i] = 1812433253U*(state.array[
i - 1]^(state.array[
i - 1] >> 30)) +
i;
239 template<jit::
float_scalar T,
bool SAFE_MATH=false>
242 auto temp = std::make_shared<random_state_node<T, SAFE_MATH>> (size,
seed)->reduce();
244 for (
size_t i =
temp->get_hash();
254#if defined(__clang__) || defined(__GNUC__)
257 assert(
false &&
"Should never reach.");
262 template<jit::
float_scalar T,
bool SAFE_MATH=false>
274 template<jit::
float_scalar T,
bool SAFE_MATH=false>
276 return std::dynamic_pointer_cast<random_state_node<T, SAFE_MATH>> (x);
288 template<jit::
float_scalar T,
bool SAFE_MATH=false>
297 static std::string to_string() {
351 this->
arg->compile_preamble(stream, registers,
356 jit::add_type<T> (stream);
357 stream <<
" random(";
358 if constexpr (jit::use_metal<T> ()) {
361 stream <<
"mt_state &state) {" << std::endl
362 <<
" uint16_t k = state.index;" << std::endl
363 <<
" uint16_t j = (k + 1) % 624;" << std::endl
364 <<
" uint32_t x = (state.array[k] & 0x80000000U) |" << std::endl
365 <<
" (state.array[j] & 0x7fffffffU);" << std::endl
366 <<
" uint32_t xA = x >> 1;" << std::endl
367 <<
" if (x & 0x00000001U) {" << std::endl
368 <<
" xA ^= 0x9908b0dfU;" << std::endl
370 <<
" j = (k + 397) % 624;" << std::endl
371 <<
" x = state.array[j]^xA;" << std::endl
372 <<
" state.array[k] = x;" << std::endl
373 <<
" state.index = (k + 1) % 624;" << std::endl
374 <<
" uint32_t y = x^(x >> 11);" << std::endl
375 <<
" y = y^((y << 7) & 0x9d2c5680U);" << std::endl
376 <<
" y = y^((y << 15) & 0xefc60000U);" << std::endl
377 <<
" return static_cast<";
378 jit::add_type<T> (stream);
379 stream <<
"> (y^(y >> 18));" << std::endl
405 if (registers.find(
this) == registers.end()) {
411 registers[
this] =
"random(" + registers[
a.get()] +
")";
436 std::cout <<
"state";
448 if (registers.find(
this) == registers.end()) {
450 registers[
this] =
name;
451 stream <<
" " <<
name
452 <<
" [label = \"state\", shape = box, style = \"rounded,filled\", fillcolor = black, fontcolor = white];" << std::endl;
454 auto a = this->
arg->to_vizgraph(stream, registers);
455 stream <<
" " <<
name <<
" -- " << registers[
a.get()] <<
";" << std::endl;
489 template<jit::
float_scalar T,
bool SAFE_MATH=false>
491 auto temp = std::make_shared<random_node<T, SAFE_MATH>> (state)->reduce();
493 for (
size_t i =
temp->get_hash();
503#if defined(__clang__) || defined(__GNUC__)
506 assert(
false &&
"Should never reach.");
511 template<jit::
float_scalar T,
bool SAFE_MATH=false>
523 template<jit::
float_scalar T,
bool SAFE_MATH=false>
525 return std::dynamic_pointer_cast<random_node<T, SAFE_MATH>> (x);
536 template<jit::
float_scalar T,
bool SAFE_MATH=false>
Class representing a generic buffer.
Definition backend.hpp:29
Class representing a node leaf.
Definition node.hpp:364
Class representing a random_node leaf.
Definition random.hpp:289
virtual bool is_match(shared_leaf< T, SAFE_MATH > x)
Query if the nodes match.
Definition random.hpp:428
random_node(shared_random_state< T, SAFE_MATH > x)
Construct a random node that draws values from the given random state.
Definition random.hpp:307
virtual backend::buffer< T > evaluate()
Evaluate the results of random node.
Definition random.hpp:315
virtual shared_leaf< T, SAFE_MATH > get_power_exponent() const
Get the exponent of a power.
Definition random.hpp:475
virtual void to_latex() const
Convert the node to latex.
Definition random.hpp:435
virtual bool is_all_variables() const
Test if all the sub-nodes terminate in variables.
Definition random.hpp:466
virtual shared_leaf< T, SAFE_MATH > to_vizgraph(std::stringstream &stream, jit::register_map &registers)
Convert the node to vizgraph.
Definition random.hpp:446
virtual void compile_preamble(std::ostringstream &stream, jit::register_map &registers, jit::visiter_map &visited, jit::register_usage &usage, jit::texture1d_list &textures1d, jit::texture2d_list &textures2d, int &avail_const_mem)
Compile preamble.
Definition random.hpp:343
virtual shared_leaf< T, SAFE_MATH > compile(std::ostringstream &stream, jit::register_map &registers, jit::register_map &indices, const jit::register_usage &usage)
Compile the node.
Definition random.hpp:401
virtual shared_leaf< T, SAFE_MATH > df(shared_leaf< T, SAFE_MATH > x)
Transform node to derivative.
Definition random.hpp:328
Random state.
Definition random.hpp:24
virtual shared_leaf< T, SAFE_MATH > to_vizgraph(std::stringstream &stream, jit::register_map &registers)
Convert the node to vizgraph.
Definition random.hpp:141
size_t size()
Get the number of states in the random state vector.
Definition random.hpp:176
virtual shared_leaf< T, SAFE_MATH > compile(std::ostringstream &stream, jit::register_map &registers, jit::register_map &indices, const jit::register_usage &usage)
Compile the node.
Definition random.hpp:120
random_state_node(const size_t size, const uint32_t seed=0)
Construct a random state node holding size states, where state i is initialized from seed + i.
Definition random.hpp:46
virtual bool is_all_variables() const
Test if all the sub-nodes terminate in variables.
Definition random.hpp:158
virtual shared_leaf< T, SAFE_MATH > df(shared_leaf< T, SAFE_MATH > x)
Transform node to derivative.
Definition random.hpp:70
virtual void to_latex() const
Convert the node to latex.
Definition random.hpp:130
virtual backend::buffer< T > evaluate()
Evaluate the results of random_state_node.
Definition random.hpp:59
virtual void compile_preamble(std::ostringstream &stream, jit::register_map &registers, jit::visiter_map &visited, jit::register_usage &usage, jit::texture1d_list &textures1d, jit::texture2d_list &textures2d, int &avail_const_mem)
Compile preamble.
Definition random.hpp:85
size_t get_size_bytes()
Get the size of the random state vector in bytes.
Definition random.hpp:185
virtual shared_leaf< T, SAFE_MATH > get_power_exponent() const
Get the exponent of a power.
Definition random.hpp:167
mt_state * data()
Get a pointer to the underlying random state data.
Definition random.hpp:194
Class representing a straight node.
Definition node.hpp:1051
shared_leaf< T, SAFE_MATH > arg
Argument.
Definition node.hpp:1054
subroutine assert(test, message)
Assert check.
Definition f_binding_test.f90:38
Name space for graph nodes.
Definition arithmetic.hpp:13
constexpr shared_leaf< T, SAFE_MATH > random_scale()
Create a random_scale constant.
Definition random.hpp:537
shared_random_state< T, SAFE_MATH > random_state_cast(shared_leaf< T, SAFE_MATH > x)
Cast to a random_state node.
Definition random.hpp:275
constexpr shared_leaf< T, SAFE_MATH > zero()
Forward declare for zero.
Definition node.hpp:986
shared_leaf< T, SAFE_MATH > random(shared_random_state< T, SAFE_MATH > state)
Define random convenience function.
Definition random.hpp:490
std::shared_ptr< random_state_node< T, SAFE_MATH > > shared_random_state
Convenience type alias for shared random_state nodes.
Definition random.hpp:263
std::shared_ptr< random_node< T, SAFE_MATH > > shared_random
Convenience type alias for shared random nodes.
Definition random.hpp:512
constexpr T i
Convenience type for imaginary constant.
Definition node.hpp:1018
std::shared_ptr< leaf_node< T, SAFE_MATH > > shared_leaf
Convenience type alias for shared leaf nodes.
Definition node.hpp:676
shared_random< T, SAFE_MATH > random_cast(shared_leaf< T, SAFE_MATH > x)
Cast to a random node.
Definition random.hpp:524
shared_leaf< T, SAFE_MATH > random_state(const size_t size, const uint32_t seed=0)
Define random_state convenience function.
Definition random.hpp:240
std::map< void *, size_t > texture1d_list
Type alias for indexing 1D textures.
Definition register.hpp:263
std::map< void *, std::array< size_t, 2 > > texture2d_list
Type alias for indexing 2D textures.
Definition register.hpp:265
std::map< void *, size_t > register_usage
Type alias for counting register usage.
Definition register.hpp:259
std::map< void *, std::string > register_map
Type alias for mapping node pointers to register names.
Definition register.hpp:257
std::set< void * > visiter_map
Type alias for listing visited nodes.
Definition register.hpp:261
std::string to_string(const char prefix, const NODE *pointer)
Convert a graph::leaf_node pointer to a string.
Definition register.hpp:246
Base nodes of graph computation framework.
Random state structure.
Definition random.hpp:29
uint16_t index
State index.
Definition random.hpp:33
std::array< uint32_t, 624 > array
State array.
Definition random.hpp:31