Graph Framework
Loading...
Searching...
No Matches
random.hpp
Go to the documentation of this file.
1//------------------------------------------------------------------------------
6//------------------------------------------------------------------------------
7
8#ifndef random_h
9#define random_h
10
11#include "node.hpp"
12
13namespace graph {
14//******************************************************************************
16//******************************************************************************
17//------------------------------------------------------------------------------
22//------------------------------------------------------------------------------
23 template<jit::float_scalar T, bool SAFE_MATH=false>
24 class random_state_node final : public leaf_node<T, SAFE_MATH> {
25 public:
26//------------------------------------------------------------------------------
28//------------------------------------------------------------------------------
29 struct mt_state {
31 std::array<uint32_t, 624> array;
34#ifdef USE_CUDA
37#endif
38 };
39
40//------------------------------------------------------------------------------
45//------------------------------------------------------------------------------
46 random_state_node(const size_t size,
47 const uint32_t seed=0) :
48 leaf_node<T, SAFE_MATH> (random_state_node::to_string(), 1, false) {
49 for (uint32_t i = 0; i < size; i++) {
50 states.push_back(initalize_state(seed + i));
51 }
52 }
53
54//------------------------------------------------------------------------------
58//------------------------------------------------------------------------------
60 backend::buffer<T> result;
61 return result;
62 }
63
64//------------------------------------------------------------------------------
68//------------------------------------------------------------------------------
70 return this->shared_from_this();
71 }
72
73//------------------------------------------------------------------------------
78//------------------------------------------------------------------------------
82
83//------------------------------------------------------------------------------
93//------------------------------------------------------------------------------
94 virtual void compile_preamble(std::ostringstream &stream,
95 jit::register_map &registers,
100 int &avail_const_mem) {
101 if (visited.find(this) == visited.end()) {
102 stream << "struct mt_state {" << std::endl
103 << " array<uint32_t, 624> array;" << std::endl
104 << " uint16_t index;" << std::endl
105#ifdef USE_CUDA
106 << " uint16_t padding[3];" << std::endl
107#endif
108 << "};" << std::endl;
109
110 visited.insert(this);
111#ifdef SHOW_USE_COUNT
112 usage[this] = 1;
113 } else {
114 ++usage[this];
115#endif
116 }
117 }
118
119//------------------------------------------------------------------------------
127//------------------------------------------------------------------------------
129 compile(std::ostringstream &stream,
130 jit::register_map &registers,
132 const jit::register_usage &usage) {
133 return this->shared_from_this();
134 }
135
136//------------------------------------------------------------------------------
138//------------------------------------------------------------------------------
139 virtual void to_latex() const {
140 std::cout << "state";
141 }
142
143//------------------------------------------------------------------------------
149//------------------------------------------------------------------------------
150 virtual shared_leaf<T, SAFE_MATH> to_vizgraph(std::stringstream &stream,
151 jit::register_map &registers) {
152 if (registers.find(this) == registers.end()) {
153 const std::string name = jit::to_string('r', this);
154 registers[this] = name;
155 stream << " " << name
156 << " [label = \"state\", shape = box, style = \"rounded,filled\", fillcolor = black, fontcolor = white];" << std::endl;
157 }
158
159 return this->shared_from_this();
160 }
161
162//------------------------------------------------------------------------------
166//------------------------------------------------------------------------------
167 virtual bool is_all_variables() const {
168 return false;
169 }
170
171//------------------------------------------------------------------------------
175//------------------------------------------------------------------------------
177 return constant<T, SAFE_MATH> (static_cast<T> (1.0));
178 }
179
180//------------------------------------------------------------------------------
184//------------------------------------------------------------------------------
185 size_t size() {
186 return states.size();
187 }
188
189//------------------------------------------------------------------------------
193//------------------------------------------------------------------------------
194 size_t get_size_bytes() {
195 return size()*sizeof(mt_state);
196 }
197
198//------------------------------------------------------------------------------
202//------------------------------------------------------------------------------
204 return states.data();
205 }
206
207 private:
209 std::vector<mt_state> states;
210
211//------------------------------------------------------------------------------
215//------------------------------------------------------------------------------
///  @brief Name used to identify random state nodes.
///
///  @returns The string "random_state".
static std::string to_string() {
    const std::string name = "random_state";
    return name;
}
219
220//------------------------------------------------------------------------------
225//------------------------------------------------------------------------------
226 mt_state initalize_state(const uint32_t seed) {
227 mt_state state;
228 state.array[0] = seed;
229 for (uint16_t i = 1, ie = state.array.size(); i < ie; i++) {
230 state.array[i] = 1812433253U*(state.array[i - 1]^(state.array[i - 1] >> 30)) + i;
231 }
232 state.index = 0;
233
234 return state;
235 }
236 };
237
238//------------------------------------------------------------------------------
247//------------------------------------------------------------------------------
248 template<jit::float_scalar T, bool SAFE_MATH=false>
250 const uint32_t seed=0) {
251 auto temp = std::make_shared<random_state_node<T, SAFE_MATH>> (size, seed)->reduce();
252// Test for hash collisions.
253 for (size_t i = temp->get_hash();
255 if (leaf_node<T, SAFE_MATH>::caches.nodes.find(i) ==
258 return temp;
259 } else if (temp->is_match(leaf_node<T, SAFE_MATH>::caches.nodes[i])) {
261 }
262 }
263#if defined(__clang__) || defined(__GNUC__)
265#else
266 assert(false && "Should never reach.");
267#endif
268 }
269
271 template<jit::float_scalar T, bool SAFE_MATH=false>
272 using shared_random_state = std::shared_ptr<random_state_node<T, SAFE_MATH>>;
273
274//------------------------------------------------------------------------------
282//------------------------------------------------------------------------------
283 template<jit::float_scalar T, bool SAFE_MATH=false>
285 return std::dynamic_pointer_cast<random_state_node<T, SAFE_MATH>> (x);
286 }
287
288//******************************************************************************
289// Random constant.
290//******************************************************************************
291//------------------------------------------------------------------------------
296//------------------------------------------------------------------------------
297 template<jit::float_scalar T, bool SAFE_MATH=false>
298 class random_node final : public straight_node<T, SAFE_MATH> {
299 private:
300
301//------------------------------------------------------------------------------
305//------------------------------------------------------------------------------
///  @brief Name used to identify random nodes.
///
///  @returns The string "random".
static std::string to_string() {
    const std::string name = "random";
    return name;
}
309
310 public:
311//------------------------------------------------------------------------------
315//------------------------------------------------------------------------------
318
319//------------------------------------------------------------------------------
323//------------------------------------------------------------------------------
325 backend::buffer<T> result;
326 return result;
327 }
328
329//------------------------------------------------------------------------------
333//------------------------------------------------------------------------------
335 return this->shared_from_this();
336 }
337
338//------------------------------------------------------------------------------
345//------------------------------------------------------------------------------
349
350//------------------------------------------------------------------------------
360//------------------------------------------------------------------------------
361 virtual void compile_preamble(std::ostringstream &stream,
362 jit::register_map &registers,
367 int &avail_const_mem) {
368 if (visited.find(this) == visited.end()) {
369 this->arg->compile_preamble(stream, registers,
370 visited, usage,
373
374 jit::add_type<T> (stream);
375 stream << " random(";
376 if constexpr (jit::use_metal<T> ()) {
377 stream << "device ";
378 }
379 stream <<"mt_state &state) {" << std::endl
380 << " uint16_t k = state.index;" << std::endl
381 << " uint16_t j = (k + 1) % 624;" << std::endl
382 << " uint32_t x = (state.array[k] & 0x80000000U) |" << std::endl
383 << " (state.array[j] & 0x7fffffffU);" << std::endl
384 << " uint32_t xA = x >> 1;" << std::endl
385 << " if (x & 0x00000001U) {" << std::endl
386 << " xA ^= 0x9908b0dfU;" << std::endl
387 << " }" << std::endl
388 << " j = (k + 397) % 624;" << std::endl
389 << " x = state.array[j]^xA;" << std::endl
390 << " state.array[k] = x;" << std::endl
391 << " state.index = (k + 1) % 624;" << std::endl
392 << " uint32_t y = x^(x >> 11);" << std::endl
393 << " y = y^((y << 7) & 0x9d2c5680U);" << std::endl
394 << " y = y^((y << 15) & 0xefc60000U);" << std::endl
395 << " return static_cast<";
396 jit::add_type<T> (stream);
397 stream << "> (y^(y >> 18));" << std::endl
398 << "}" << std::endl;
399
400 visited.insert(this);
401#ifdef SHOW_USE_COUNT
402 usage[this] = 1;
403 } else {
404 ++usage[this];
405#endif
406 }
407 }
408
409//------------------------------------------------------------------------------
417//------------------------------------------------------------------------------
419 compile(std::ostringstream &stream,
420 jit::register_map &registers,
422 const jit::register_usage &usage) {
423 if (registers.find(this) == registers.end()) {
424 shared_leaf<T, SAFE_MATH> a = this->arg->compile(stream,
425 registers,
426 indices,
427 usage);
428
429 registers[this] = "random(" + registers[a.get()] + ")";
430 }
431
432 return this->shared_from_this();
433 }
434
435//------------------------------------------------------------------------------
445//------------------------------------------------------------------------------
447 return false;
448 }
449
450//------------------------------------------------------------------------------
452//------------------------------------------------------------------------------
453 virtual void to_latex() const {
454 std::cout << "state";
455 }
456
457//------------------------------------------------------------------------------
463//------------------------------------------------------------------------------
464 virtual shared_leaf<T, SAFE_MATH> to_vizgraph(std::stringstream &stream,
465 jit::register_map &registers) {
466 if (registers.find(this) == registers.end()) {
467 const std::string name = jit::to_string('r', this);
468 registers[this] = name;
469 stream << " " << name
470 << " [label = \"state\", shape = box, style = \"rounded,filled\", fillcolor = black, fontcolor = white];" << std::endl;
471
472 auto a = this->arg->to_vizgraph(stream, registers);
473 stream << " " << name << " -- " << registers[a.get()] << ";" << std::endl;
474 }
475
476 return this->shared_from_this();
477 }
478
479//------------------------------------------------------------------------------
483//------------------------------------------------------------------------------
484 virtual bool is_all_variables() const {
485 return false;
486 }
487
488//------------------------------------------------------------------------------
492//------------------------------------------------------------------------------
494 return constant<T, SAFE_MATH> (static_cast<T> (1.0));
495 }
496 };
497
498//------------------------------------------------------------------------------
506//------------------------------------------------------------------------------
507 template<jit::float_scalar T, bool SAFE_MATH=false>
509 auto temp = std::make_shared<random_node<T, SAFE_MATH>> (state)->reduce();
510// Test for hash collisions.
511 for (size_t i = temp->get_hash();
513 if (leaf_node<T, SAFE_MATH>::caches.nodes.find(i) ==
516 return temp;
517 } else if (temp->is_match(leaf_node<T, SAFE_MATH>::caches.nodes[i])) {
519 }
520 }
521#if defined(__clang__) || defined(__GNUC__)
523#else
524 assert(false && "Should never reach.");
525#endif
526 }
527
529 template<jit::float_scalar T, bool SAFE_MATH=false>
530 using shared_random = std::shared_ptr<random_node<T, SAFE_MATH>>;
531
532//------------------------------------------------------------------------------
540//------------------------------------------------------------------------------
541 template<jit::float_scalar T, bool SAFE_MATH=false>
543 return std::dynamic_pointer_cast<random_node<T, SAFE_MATH>> (x);
544 }
545
546//------------------------------------------------------------------------------
553//------------------------------------------------------------------------------
554 template<jit::float_scalar T, bool SAFE_MATH=false>
556 return constant<T, SAFE_MATH> (static_cast<T> (std::numeric_limits<uint32_t>::max()));
557 }
558}
559
560#endif /* random_h */
Class representing a generic buffer.
Definition backend.hpp:29
Class representing a node leaf.
Definition node.hpp:364
Class representing a random_node leaf.
Definition random.hpp:298
virtual bool is_match(shared_leaf< T, SAFE_MATH > x)
Query if the nodes match.
Definition random.hpp:446
random_node(shared_random_state< T, SAFE_MATH > x)
Construct a random node that draws values from a random state node.
Definition random.hpp:316
virtual backend::buffer< T > evaluate()
Evaluate the results of random node.
Definition random.hpp:324
virtual shared_leaf< T, SAFE_MATH > get_power_exponent() const
Get the exponent of a power.
Definition random.hpp:493
virtual shared_leaf< T, SAFE_MATH > reduce()
Reduce the random node.
Definition random.hpp:334
virtual void to_latex() const
Convert the node to latex.
Definition random.hpp:453
virtual bool is_all_variables() const
Test if all the subnodes terminate in variables.
Definition random.hpp:484
virtual shared_leaf< T, SAFE_MATH > to_vizgraph(std::stringstream &stream, jit::register_map &registers)
Convert the node to vizgraph.
Definition random.hpp:464
virtual void compile_preamble(std::ostringstream &stream, jit::register_map &registers, jit::visiter_map &visited, jit::register_usage &usage, jit::texture1d_list &textures1d, jit::texture2d_list &textures2d, int &avail_const_mem)
Compile preamble.
Definition random.hpp:361
virtual shared_leaf< T, SAFE_MATH > compile(std::ostringstream &stream, jit::register_map &registers, jit::register_map &indices, const jit::register_usage &usage)
Compile the node.
Definition random.hpp:419
virtual shared_leaf< T, SAFE_MATH > df(shared_leaf< T, SAFE_MATH > x)
Transform node to derivative.
Definition random.hpp:346
Random state.
Definition random.hpp:24
virtual shared_leaf< T, SAFE_MATH > to_vizgraph(std::stringstream &stream, jit::register_map &registers)
Convert the node to vizgraph.
Definition random.hpp:150
size_t size()
Get the number of states in the random state vector.
Definition random.hpp:185
virtual shared_leaf< T, SAFE_MATH > compile(std::ostringstream &stream, jit::register_map &registers, jit::register_map &indices, const jit::register_usage &usage)
Compile the node.
Definition random.hpp:129
random_state_node(const size_t size, const uint32_t seed=0)
Construct a random state node holding size generator states, seeded with seed, seed+1, and so on.
Definition random.hpp:46
virtual bool is_all_variables() const
Test if all the subnodes terminate in variables.
Definition random.hpp:167
virtual shared_leaf< T, SAFE_MATH > df(shared_leaf< T, SAFE_MATH > x)
Transform node to derivative.
Definition random.hpp:79
virtual void to_latex() const
Convert the node to latex.
Definition random.hpp:139
virtual backend::buffer< T > evaluate()
Evaluate the results of random_state_node.
Definition random.hpp:59
virtual void compile_preamble(std::ostringstream &stream, jit::register_map &registers, jit::visiter_map &visited, jit::register_usage &usage, jit::texture1d_list &textures1d, jit::texture2d_list &textures2d, int &avail_const_mem)
Compile preamble.
Definition random.hpp:94
size_t get_size_bytes()
Get the size of the random state vector in bytes.
Definition random.hpp:194
virtual shared_leaf< T, SAFE_MATH > get_power_exponent() const
Get the exponent of a power.
Definition random.hpp:176
virtual shared_leaf< T, SAFE_MATH > reduce()
Reduce the random_state_node.
Definition random.hpp:69
mt_state * data()
Get a pointer to the random state data.
Definition random.hpp:203
Class representing a straight node.
Definition node.hpp:1059
shared_leaf< T, SAFE_MATH > arg
Argument.
Definition node.hpp:1062
subroutine assert(test, message)
Assert check.
Definition f_binding_test.f90:38
Name space for graph nodes.
Definition arithmetic.hpp:13
constexpr shared_leaf< T, SAFE_MATH > random_scale()
Create a random_scale constant.
Definition random.hpp:555
shared_random_state< T, SAFE_MATH > random_state_cast(shared_leaf< T, SAFE_MATH > x)
Cast to a random_state node.
Definition random.hpp:284
constexpr shared_leaf< T, SAFE_MATH > zero()
Forward declare for zero.
Definition node.hpp:994
shared_leaf< T, SAFE_MATH > random(shared_random_state< T, SAFE_MATH > state)
Define random convenience function.
Definition random.hpp:508
std::shared_ptr< random_state_node< T, SAFE_MATH > > shared_random_state
Convenience type alias for shared random state nodes.
Definition random.hpp:272
std::shared_ptr< random_node< T, SAFE_MATH > > shared_random
Convenience type alias for shared random nodes.
Definition random.hpp:530
constexpr T i
Convenience constant for the imaginary unit.
Definition node.hpp:1026
std::shared_ptr< leaf_node< T, SAFE_MATH > > shared_leaf
Convenience type alias for shared leaf nodes.
Definition node.hpp:673
shared_random< T, SAFE_MATH > random_cast(shared_leaf< T, SAFE_MATH > x)
Cast to a random node.
Definition random.hpp:542
shared_leaf< T, SAFE_MATH > random_state(const size_t size, const uint32_t seed=0)
Define random_state convenience function.
Definition random.hpp:249
std::map< void *, size_t > texture1d_list
Type alias for indexing 1D textures.
Definition register.hpp:262
std::map< void *, std::array< size_t, 2 > > texture2d_list
Type alias for indexing 2D textures.
Definition register.hpp:264
std::map< void *, size_t > register_usage
Type alias for counting register usage.
Definition register.hpp:258
std::map< void *, std::string > register_map
Type alias for mapping node pointers to register names.
Definition register.hpp:256
std::set< void * > visiter_map
Type alias for listing visited nodes.
Definition register.hpp:260
std::string to_string(const char prefix, const NODE *pointer)
Convert a graph::leaf_node pointer to a string.
Definition register.hpp:245
Base nodes of graph computation framework.
Random state structure.
Definition random.hpp:29
uint16_t index
State index.
Definition random.hpp:33
std::array< uint32_t, 624 > array
State array.
Definition random.hpp:31