Graph Framework
Loading...
Searching...
No Matches
random.hpp
Go to the documentation of this file.
1//------------------------------------------------------------------------------
6//------------------------------------------------------------------------------
7
8#ifndef random_h
9#define random_h
10
11#include "node.hpp"
12
13namespace graph {
14//******************************************************************************
16//******************************************************************************
17//------------------------------------------------------------------------------
22//------------------------------------------------------------------------------
23 template<jit::float_scalar T, bool SAFE_MATH=false>
24 class random_state_node final : public leaf_node<T, SAFE_MATH> {
25 public:
26//------------------------------------------------------------------------------
28//------------------------------------------------------------------------------
29 struct mt_state {
31 std::array<uint32_t, 624> array;
34#ifdef USE_CUDA
37#endif
38 };
39
40//------------------------------------------------------------------------------
45//------------------------------------------------------------------------------
46 random_state_node(const size_t size,
47 const uint32_t seed=0) :
48 leaf_node<T, SAFE_MATH> (random_state_node::to_string(), 1, false) {
49 for (uint32_t i = 0; i < size; i++) {
50 states.push_back(initialize_state(seed + i));
51 }
52 }
53
54//------------------------------------------------------------------------------
58//------------------------------------------------------------------------------
60 backend::buffer<T> result;
61 return result;
62 }
63
64//------------------------------------------------------------------------------
69//------------------------------------------------------------------------------
73
74//------------------------------------------------------------------------------
84//------------------------------------------------------------------------------
85 virtual void compile_preamble(std::ostringstream &stream,
86 jit::register_map &registers,
91 int &avail_const_mem) {
92 if (visited.find(this) == visited.end()) {
93 stream << "struct mt_state {" << std::endl
94 << " array<uint32_t, 624> array;" << std::endl
95 << " uint16_t index;" << std::endl
96#ifdef USE_CUDA
97 << " uint16_t padding[3];" << std::endl
98#endif
99 << "};" << std::endl;
100
101 visited.insert(this);
102#ifdef SHOW_USE_COUNT
103 usage[this] = 1;
104 } else {
105 ++usage[this];
106#endif
107 }
108 }
109
110//------------------------------------------------------------------------------
118//------------------------------------------------------------------------------
120 compile(std::ostringstream &stream,
121 jit::register_map &registers,
123 const jit::register_usage &usage) {
124 return this->shared_from_this();
125 }
126
127//------------------------------------------------------------------------------
129//------------------------------------------------------------------------------
130 virtual void to_latex() const {
131 std::cout << "state";
132 }
133
134//------------------------------------------------------------------------------
140//------------------------------------------------------------------------------
141 virtual shared_leaf<T, SAFE_MATH> to_vizgraph(std::stringstream &stream,
142 jit::register_map &registers) {
143 if (registers.find(this) == registers.end()) {
144 const std::string name = jit::to_string('r', this);
145 registers[this] = name;
146 stream << " " << name
147 << " [label = \"state\", shape = box, style = \"rounded,filled\", fillcolor = black, fontcolor = white];" << std::endl;
148 }
149
150 return this->shared_from_this();
151 }
152
153//------------------------------------------------------------------------------
157//------------------------------------------------------------------------------
158 virtual bool is_all_variables() const {
159 return false;
160 }
161
162//------------------------------------------------------------------------------
166//------------------------------------------------------------------------------
168 return constant<T, SAFE_MATH> (static_cast<T> (1.0));
169 }
170
171//------------------------------------------------------------------------------
175//------------------------------------------------------------------------------
176 size_t size() {
177 return states.size();
178 }
179
180//------------------------------------------------------------------------------
184//------------------------------------------------------------------------------
185 size_t get_size_bytes() {
186 return size()*sizeof(mt_state);
187 }
188
189//------------------------------------------------------------------------------
193//------------------------------------------------------------------------------
195 return states.data();
196 }
197
198 private:
200 std::vector<mt_state> states;
201
202//------------------------------------------------------------------------------
206//------------------------------------------------------------------------------
///  @brief Get the name used to identify random state nodes.
///
///  @returns The string "random_state".
static std::string to_string() {
    const std::string node_name{"random_state"};
    return node_name;
}
210
211//------------------------------------------------------------------------------
216//------------------------------------------------------------------------------
217 mt_state initialize_state(const uint32_t seed) {
218 mt_state state;
219 state.array[0] = seed;
220 for (uint16_t i = 1, ie = state.array.size(); i < ie; i++) {
221 state.array[i] = 1812433253U*(state.array[i - 1]^(state.array[i - 1] >> 30)) + i;
222 }
223 state.index = 0;
224
225 return state;
226 }
227 };
228
229//------------------------------------------------------------------------------
238//------------------------------------------------------------------------------
239 template<jit::float_scalar T, bool SAFE_MATH=false>
241 const uint32_t seed=0) {
242 auto temp = std::make_shared<random_state_node<T, SAFE_MATH>> (size, seed)->reduce();
243// Test for hash collisions.
244 for (size_t i = temp->get_hash();
246 if (leaf_node<T, SAFE_MATH>::caches.nodes.find(i) ==
249 return temp;
250 } else if (temp->is_match(leaf_node<T, SAFE_MATH>::caches.nodes[i])) {
252 }
253 }
254#if defined(__clang__) || defined(__GNUC__)
256#else
257 assert(false && "Should never reach.");
258#endif
259 }
260
262 template<jit::float_scalar T, bool SAFE_MATH=false>
263 using shared_random_state = std::shared_ptr<random_state_node<T, SAFE_MATH>>;
264
265//------------------------------------------------------------------------------
273//------------------------------------------------------------------------------
274 template<jit::float_scalar T, bool SAFE_MATH=false>
276 return std::dynamic_pointer_cast<random_state_node<T, SAFE_MATH>> (x);
277 }
278
279//******************************************************************************
280// Random constant.
281//******************************************************************************
282//------------------------------------------------------------------------------
287//------------------------------------------------------------------------------
288 template<jit::float_scalar T, bool SAFE_MATH=false>
289 class random_node final : public straight_node<T, SAFE_MATH> {
290 private:
291
292//------------------------------------------------------------------------------
296//------------------------------------------------------------------------------
///  @brief Get the name used to identify random nodes.
///
///  @returns The string "random".
static std::string to_string() {
    const std::string node_name{"random"};
    return node_name;
}
300
301 public:
302//------------------------------------------------------------------------------
306//------------------------------------------------------------------------------
309
310//------------------------------------------------------------------------------
314//------------------------------------------------------------------------------
316 backend::buffer<T> result;
317 return result;
318 }
319
320//------------------------------------------------------------------------------
327//------------------------------------------------------------------------------
331
332//------------------------------------------------------------------------------
342//------------------------------------------------------------------------------
343 virtual void compile_preamble(std::ostringstream &stream,
344 jit::register_map &registers,
349 int &avail_const_mem) {
350 if (visited.find(this) == visited.end()) {
351 this->arg->compile_preamble(stream, registers,
352 visited, usage,
355
356 jit::add_type<T> (stream);
357 stream << " random(";
358 if constexpr (jit::use_metal<T> ()) {
359 stream << "device ";
360 }
361 stream <<"mt_state &state) {" << std::endl
362 << " uint16_t k = state.index;" << std::endl
363 << " uint16_t j = (k + 1) % 624;" << std::endl
364 << " uint32_t x = (state.array[k] & 0x80000000U) |" << std::endl
365 << " (state.array[j] & 0x7fffffffU);" << std::endl
366 << " uint32_t xA = x >> 1;" << std::endl
367 << " if (x & 0x00000001U) {" << std::endl
368 << " xA ^= 0x9908b0dfU;" << std::endl
369 << " }" << std::endl
370 << " j = (k + 397) % 624;" << std::endl
371 << " x = state.array[j]^xA;" << std::endl
372 << " state.array[k] = x;" << std::endl
373 << " state.index = (k + 1) % 624;" << std::endl
374 << " uint32_t y = x^(x >> 11);" << std::endl
375 << " y = y^((y << 7) & 0x9d2c5680U);" << std::endl
376 << " y = y^((y << 15) & 0xefc60000U);" << std::endl
377 << " return static_cast<";
378 jit::add_type<T> (stream);
379 stream << "> (y^(y >> 18));" << std::endl
380 << "}" << std::endl;
381
382 visited.insert(this);
383#ifdef SHOW_USE_COUNT
384 usage[this] = 1;
385 } else {
386 ++usage[this];
387#endif
388 }
389 }
390
391//------------------------------------------------------------------------------
399//------------------------------------------------------------------------------
401 compile(std::ostringstream &stream,
402 jit::register_map &registers,
404 const jit::register_usage &usage) {
405 if (registers.find(this) == registers.end()) {
406 shared_leaf<T, SAFE_MATH> a = this->arg->compile(stream,
407 registers,
408 indices,
409 usage);
410
411 registers[this] = "random(" + registers[a.get()] + ")";
412 }
413
414 return this->shared_from_this();
415 }
416
417//------------------------------------------------------------------------------
427//------------------------------------------------------------------------------
429 return false;
430 }
431
432//------------------------------------------------------------------------------
434//------------------------------------------------------------------------------
435 virtual void to_latex() const {
436 std::cout << "state";
437 }
438
439//------------------------------------------------------------------------------
445//------------------------------------------------------------------------------
446 virtual shared_leaf<T, SAFE_MATH> to_vizgraph(std::stringstream &stream,
447 jit::register_map &registers) {
448 if (registers.find(this) == registers.end()) {
449 const std::string name = jit::to_string('r', this);
450 registers[this] = name;
451 stream << " " << name
452 << " [label = \"state\", shape = box, style = \"rounded,filled\", fillcolor = black, fontcolor = white];" << std::endl;
453
454 auto a = this->arg->to_vizgraph(stream, registers);
455 stream << " " << name << " -- " << registers[a.get()] << ";" << std::endl;
456 }
457
458 return this->shared_from_this();
459 }
460
461//------------------------------------------------------------------------------
465//------------------------------------------------------------------------------
466 virtual bool is_all_variables() const {
467 return false;
468 }
469
470//------------------------------------------------------------------------------
474//------------------------------------------------------------------------------
476 return constant<T, SAFE_MATH> (static_cast<T> (1.0));
477 }
478 };
479
480//------------------------------------------------------------------------------
488//------------------------------------------------------------------------------
489 template<jit::float_scalar T, bool SAFE_MATH=false>
491 auto temp = std::make_shared<random_node<T, SAFE_MATH>> (state)->reduce();
492// Test for hash collisions.
493 for (size_t i = temp->get_hash();
495 if (leaf_node<T, SAFE_MATH>::caches.nodes.find(i) ==
498 return temp;
499 } else if (temp->is_match(leaf_node<T, SAFE_MATH>::caches.nodes[i])) {
501 }
502 }
503#if defined(__clang__) || defined(__GNUC__)
505#else
506 assert(false && "Should never reach.");
507#endif
508 }
509
511 template<jit::float_scalar T, bool SAFE_MATH=false>
512 using shared_random = std::shared_ptr<random_node<T, SAFE_MATH>>;
513
514//------------------------------------------------------------------------------
522//------------------------------------------------------------------------------
523 template<jit::float_scalar T, bool SAFE_MATH=false>
525 return std::dynamic_pointer_cast<random_node<T, SAFE_MATH>> (x);
526 }
527
528//------------------------------------------------------------------------------
535//------------------------------------------------------------------------------
536 template<jit::float_scalar T, bool SAFE_MATH=false>
538 return constant<T, SAFE_MATH> (static_cast<T> (std::numeric_limits<uint32_t>::max()));
539 }
540}
541
542#endif /* random_h */
Class representing a generic buffer.
Definition backend.hpp:29
Class representing a node leaf.
Definition node.hpp:364
Class representing a random_node leaf.
Definition random.hpp:289
virtual bool is_match(shared_leaf< T, SAFE_MATH > x)
Query if the nodes match.
Definition random.hpp:428
random_node(shared_random_state< T, SAFE_MATH > x)
Construct a random node from a random state.
Definition random.hpp:307
virtual backend::buffer< T > evaluate()
Evaluate the results of random node.
Definition random.hpp:315
virtual shared_leaf< T, SAFE_MATH > get_power_exponent() const
Get the exponent of a power.
Definition random.hpp:475
virtual void to_latex() const
Convert the node to latex.
Definition random.hpp:435
virtual bool is_all_variables() const
Test if all the sub-nodes terminate in variables.
Definition random.hpp:466
virtual shared_leaf< T, SAFE_MATH > to_vizgraph(std::stringstream &stream, jit::register_map &registers)
Convert the node to vizgraph.
Definition random.hpp:446
virtual void compile_preamble(std::ostringstream &stream, jit::register_map &registers, jit::visiter_map &visited, jit::register_usage &usage, jit::texture1d_list &textures1d, jit::texture2d_list &textures2d, int &avail_const_mem)
Compile preamble.
Definition random.hpp:343
virtual shared_leaf< T, SAFE_MATH > compile(std::ostringstream &stream, jit::register_map &registers, jit::register_map &indices, const jit::register_usage &usage)
Compile the node.
Definition random.hpp:401
virtual shared_leaf< T, SAFE_MATH > df(shared_leaf< T, SAFE_MATH > x)
Transform node to derivative.
Definition random.hpp:328
Random state.
Definition random.hpp:24
virtual shared_leaf< T, SAFE_MATH > to_vizgraph(std::stringstream &stream, jit::register_map &registers)
Convert the node to vizgraph.
Definition random.hpp:141
size_t size()
Get the number of random states.
Definition random.hpp:176
virtual shared_leaf< T, SAFE_MATH > compile(std::ostringstream &stream, jit::register_map &registers, jit::register_map &indices, const jit::register_usage &usage)
Compile the node.
Definition random.hpp:120
random_state_node(const size_t size, const uint32_t seed=0)
Construct a random state node with the given number of states and seed.
Definition random.hpp:46
virtual bool is_all_variables() const
Test if all the sub-nodes terminate in variables.
Definition random.hpp:158
virtual shared_leaf< T, SAFE_MATH > df(shared_leaf< T, SAFE_MATH > x)
Transform node to derivative.
Definition random.hpp:70
virtual void to_latex() const
Convert the node to latex.
Definition random.hpp:130
virtual backend::buffer< T > evaluate()
Evaluate the results of random_state_node.
Definition random.hpp:59
virtual void compile_preamble(std::ostringstream &stream, jit::register_map &registers, jit::visiter_map &visited, jit::register_usage &usage, jit::texture1d_list &textures1d, jit::texture2d_list &textures2d, int &avail_const_mem)
Compile preamble.
Definition random.hpp:85
size_t get_size_bytes()
Get the size of the random state vector in bytes.
Definition random.hpp:185
virtual shared_leaf< T, SAFE_MATH > get_power_exponent() const
Get the exponent of a power.
Definition random.hpp:167
mt_state * data()
Get a pointer to the random state data.
Definition random.hpp:194
Class representing a straight node.
Definition node.hpp:1051
shared_leaf< T, SAFE_MATH > arg
Argument.
Definition node.hpp:1054
subroutine assert(test, message)
Assert check.
Definition f_binding_test.f90:38
Name space for graph nodes.
Definition arithmetic.hpp:13
constexpr shared_leaf< T, SAFE_MATH > random_scale()
Create a random_scale constant.
Definition random.hpp:537
shared_random_state< T, SAFE_MATH > random_state_cast(shared_leaf< T, SAFE_MATH > x)
Cast to a random_state node.
Definition random.hpp:275
constexpr shared_leaf< T, SAFE_MATH > zero()
Forward declare for zero.
Definition node.hpp:986
shared_leaf< T, SAFE_MATH > random(shared_random_state< T, SAFE_MATH > state)
Define random convenience function.
Definition random.hpp:490
std::shared_ptr< random_state_node< T, SAFE_MATH > > shared_random_state
Convenience type alias for shared random state nodes.
Definition random.hpp:263
std::shared_ptr< random_node< T, SAFE_MATH > > shared_random
Convenience type alias for shared random nodes.
Definition random.hpp:512
constexpr T i
Convenience type for imaginary constant.
Definition node.hpp:1018
std::shared_ptr< leaf_node< T, SAFE_MATH > > shared_leaf
Convenience type alias for shared leaf nodes.
Definition node.hpp:676
shared_random< T, SAFE_MATH > random_cast(shared_leaf< T, SAFE_MATH > x)
Cast to a random node.
Definition random.hpp:524
shared_leaf< T, SAFE_MATH > random_state(const size_t size, const uint32_t seed=0)
Define random_state convenience function.
Definition random.hpp:240
std::map< void *, size_t > texture1d_list
Type alias for indexing 1D textures.
Definition register.hpp:263
std::map< void *, std::array< size_t, 2 > > texture2d_list
Type alias for indexing 2D textures.
Definition register.hpp:265
std::map< void *, size_t > register_usage
Type alias for counting register usage.
Definition register.hpp:259
std::map< void *, std::string > register_map
Type alias for mapping node pointers to register names.
Definition register.hpp:257
std::set< void * > visiter_map
Type alias for listing visited nodes.
Definition register.hpp:261
std::string to_string(const char prefix, const NODE *pointer)
Convert a graph::leaf_node pointer to a string.
Definition register.hpp:246
Base nodes of graph computation framework.
Random state structure.
Definition random.hpp:29
uint16_t index
State index.
Definition random.hpp:33
std::array< uint32_t, 624 > array
State array.
Definition random.hpp:31