Graph Framework
Loading...
Searching...
No Matches
metal_context.hpp
Go to the documentation of this file.
1//------------------------------------------------------------------------------
6//------------------------------------------------------------------------------
7
8#ifndef metal_context_h
9#define metal_context_h
10
11#include <unordered_set>
12
13#import <Metal/Metal.h>
14
15#include "random.hpp"
16
18namespace gpu {
19//------------------------------------------------------------------------------
23//------------------------------------------------------------------------------
24 template<bool SAFE_MATH=false>
26 private:
28 id<MTLDevice> device;
30 id<MTLCommandQueue> queue;
32 std::map<graph::leaf_node<float, SAFE_MATH> *, id<MTLBuffer>> kernel_arguments;
34 std::map<void *, id<MTLTexture>> texture_arguments;
36 id<MTLCommandBuffer> command_buffer;
38 id<MTLLibrary> library;
40 std::map<std::string, std::vector<MTLMutability>> bufferMutability;
41
42 public:
44 constexpr static size_t random_state_size = 1024;
45
48
49//------------------------------------------------------------------------------
53//------------------------------------------------------------------------------
54 static size_t max_concurrency() {
55 return MTLCopyAllDevices().count;
56 }
57
58//------------------------------------------------------------------------------
60//------------------------------------------------------------------------------
//------------------------------------------------------------------------------
///  @brief Device description.
///
///  @returns A human readable string identifying this backend.
//------------------------------------------------------------------------------
        static std::string device_type() {
            const std::string description("Metal GPU");
            return description;
        }
64
65//------------------------------------------------------------------------------
69//------------------------------------------------------------------------------
//------------------------------------------------------------------------------
///  @brief Construct a metal context.
///
///  Selects the Metal device at the requested index and creates a command
///  queue on it. Initializer order matters: queue is built from device, so
///  device must be initialized first (member declaration order guarantees
///  this).
///
///  @param[in] index Device index into MTLCopyAllDevices(). NOTE(review): no
///                   bounds check — an out of range index throws an
///                   NSRangeException from objectAtIndex:.
//------------------------------------------------------------------------------
        metal_context(const size_t index) :
        device([MTLCopyAllDevices() objectAtIndex:index]),
        queue([device newCommandQueue]) {}
73
74//------------------------------------------------------------------------------
80//------------------------------------------------------------------------------
//------------------------------------------------------------------------------
///  @brief Compile the kernel source into the Metal library.
///
///  @param[in] kernel_source Metal Shading Language source to compile.
///  @param[in] names         Kernel entry-point names. NOTE(review): unused in
///                           this method — presumably consumed by callers or
///                           other backends; confirm intent.
///  @param[in] add_reduction Flag to add the reduction kernel. NOTE(review):
///                           also unused here — confirm intent.
//------------------------------------------------------------------------------
        void compile(const std::string kernel_source,
                     std::vector<std::string> names,
                     const bool add_reduction=false) {
//  Explicitly nil the error so a stale pointer is never logged.
            NSError *error = nil;
            library = [device newLibraryWithSource:[NSString stringWithCString:kernel_source.c_str()
                                                                      encoding:NSUTF8StringEncoding]
                                           options:compile_options()
                                             error:&error];

//  Cocoa convention: failure is signaled by a nil return value. The error
//  pointer may be populated even on success (e.g. with warnings), so test the
//  library instead of the error.
            if (!library) {
                NSLog(@"%@", error);
            }

            if (jit::verbose) {
                std::cout << "Metal GPU info." << std::endl;
            }
        }
98
99//------------------------------------------------------------------------------
110//------------------------------------------------------------------------------
//------------------------------------------------------------------------------
///  @brief Create a kernel calling function.
///
///  Builds a compute pipeline for the named kernel, lazily creates and caches
///  Metal buffers for graph inputs/outputs/state and Metal textures for the
///  texture lists, then returns a closure that encodes and commits one launch.
///
///  NOTE(review): this listing appears garbled — the parameter lines for the
///  graph inputs, outputs, and random state, and the loop header over the
///  outputs, are missing from this rendering. Verify against the original
///  source before editing.
//------------------------------------------------------------------------------
        std::function<void(void)> create_kernel_call(const std::string kernel_name,
                                                     const size_t num_rays,
                                                     const jit::texture1d_list &tex1d_list,
                                                     const jit::texture2d_list &tex2d_list) {
            NSError *error;

//  Look up the kernel entry point in the compiled library by name.
            id<MTLFunction> function = [library newFunctionWithName:[NSString stringWithCString:kernel_name.c_str()
                                                                                       encoding:NSUTF8StringEncoding]];

//  Configure the pipeline. Buffer mutability hints were recorded per kernel by
//  create_kernel_prefix and allow Metal to optimize argument access.
            MTLComputePipelineDescriptor *compute = [MTLComputePipelineDescriptor new];
            compute.threadGroupSizeIsMultipleOfThreadExecutionWidth = YES;
            compute.computeFunction = function;
            compute.maxTotalThreadsPerThreadgroup = 1024;
            for (size_t i = 0, ie = bufferMutability[kernel_name].size(); i < ie; i++) {
                compute.buffers[i].mutability = bufferMutability[kernel_name][i];
            }

            id<MTLComputePipelineState> pipline = [device newComputePipelineStateWithDescriptor:compute
                                                                                        options:MTLPipelineOptionNone
                                                                                     reflection:NULL
                                                                                          error:&error];

//  NOTE(review): Cocoa convention is to test the returned pipeline for nil
//  rather than the error pointer, which may be set on success.
            if (error) {
                NSLog(@"%@", error);
            }

            std::vector<id<MTLBuffer>> buffers;

//  Lazily create a shared-storage buffer per input variable, seeded with the
//  variable's current evaluation. Buffers are cached in kernel_arguments so
//  repeated kernels reuse the same device memory.
            const size_t buffer_element_size = sizeof(float);
            for (graph::shared_variable<float, SAFE_MATH> &input : inputs) {
                if (!kernel_arguments.contains(input.get())) {
                    backend::buffer<float> buffer = input->evaluate();
                    kernel_arguments[input.get()] = [device newBufferWithBytes:buffer.data()
                                                                        length:buffer.size()*buffer_element_size
                                                                       options:MTLResourceStorageModeShared];
                    buffers.push_back(kernel_arguments[input.get()]);
                }
            }
//  NOTE(review): the enclosing loop header over the outputs is missing from
//  this rendering; the brace structure below reflects the garbled listing.
                if (!kernel_arguments.contains(output.get())) {
                    kernel_arguments[output.get()] = [device newBufferWithLength:num_rays*sizeof(float)
                                                                         options:MTLResourceStorageModeShared];
                    buffers.push_back(kernel_arguments[output.get()]);
                }
            }
//  Random state buffer: written by the GPU only, so write-combined CPU cache
//  mode and untracked hazards are acceptable here.
            if (state.get()) {
                if (!kernel_arguments.contains(state.get())) {
                    kernel_arguments[state.get()] = [device newBufferWithBytes:state->data()
                                                                        length:state->get_size_bytes()
                                                                       options:MTLResourceCPUCacheModeWriteCombined |
                                                                               MTLResourceStorageModeShared |
                                                                               MTLResourceHazardTrackingModeUntracked];
                }
                buffers.push_back(kernel_arguments[state.get()]);
            }

//  Lazily create read-only textures for the 1D and 2D texture lists and blit-
//  optimize their contents for GPU access. Cached in texture_arguments.
            std::vector<id<MTLTexture>> textures;
            command_buffer = [queue commandBuffer];
            id<MTLBlitCommandEncoder> encoder = [command_buffer blitCommandEncoder];
            for (auto &[data, size] : tex1d_list) {
                if (!texture_arguments.contains(data)) {
                    MTLTextureDescriptor *discriptor = [MTLTextureDescriptor new];
                    discriptor.textureType = MTLTextureType1D;
                    discriptor.pixelFormat = MTLPixelFormatR32Float;
                    discriptor.width = size;
                    discriptor.storageMode = MTLStorageModeManaged;
                    discriptor.cpuCacheMode = MTLCPUCacheModeWriteCombined;
                    discriptor.hazardTrackingMode = MTLHazardTrackingModeUntracked;
                    discriptor.usage = MTLTextureUsageShaderRead;
                    texture_arguments[data] = [device newTextureWithDescriptor:discriptor];
//  4 bytes per row element since the pixel format is R32Float.
                    [texture_arguments[data] replaceRegion:MTLRegionMake1D(0, size)
                                               mipmapLevel:0
                                                 withBytes:reinterpret_cast<float *> (data)
                                               bytesPerRow:4*size];

                    [encoder optimizeContentsForGPUAccess:texture_arguments[data]];
                }
                textures.push_back(texture_arguments[data]);
            }
            for (auto &[data, size] : tex2d_list) {
                if (!texture_arguments.contains(data)) {
                    MTLTextureDescriptor *discriptor = [MTLTextureDescriptor new];
                    discriptor.textureType = MTLTextureType2D;
                    discriptor.pixelFormat = MTLPixelFormatR32Float;
//  size is {rows, columns}: width is the fast-varying dimension.
                    discriptor.width = size[1];
                    discriptor.height = size[0];
                    discriptor.storageMode = MTLStorageModeManaged;
                    discriptor.cpuCacheMode = MTLCPUCacheModeWriteCombined;
                    discriptor.hazardTrackingMode = MTLHazardTrackingModeUntracked;
                    discriptor.usage = MTLTextureUsageShaderRead;
                    texture_arguments[data] = [device newTextureWithDescriptor:discriptor];
                    [texture_arguments[data] replaceRegion:MTLRegionMake2D(0, 0, size[1], size[0])
                                               mipmapLevel:0
                                                 withBytes:reinterpret_cast<float *> (data)
                                               bytesPerRow:4*size[1]];

                    [encoder optimizeContentsForGPUAccess:texture_arguments[data]];
                }
                textures.push_back(texture_arguments[data]);
            }
            [encoder endEncoding];
            [command_buffer commit];

//  Precompute the argument tables and launch geometry captured by the closure.
            std::vector<NSUInteger> offsets(buffers.size(), 0);
            NSRange range = NSMakeRange(0, buffers.size());
            NSRange tex_range = NSMakeRange(0, textures.size());

            NSUInteger threads_per_group = pipline.maxTotalThreadsPerThreadgroup;
            NSUInteger thread_width = pipline.threadExecutionWidth;
//  Ceiling division: one extra group for the remainder rays.
            NSUInteger thread_groups = num_rays/threads_per_group + (num_rays%threads_per_group ? 1 : 0);

            if (jit::verbose) {
                std::cout << " Kernel name : " << kernel_name << std::endl;
                std::cout << " Thread execution width : " << thread_width << std::endl;
                std::cout << " Threads per group : " << threads_per_group << std::endl;
                std::cout << " Number of groups : " << thread_groups << std::endl;
                std::cout << " Total problem size : " << threads_per_group*thread_groups << std::endl;
            }

            if (state.get()) {
//  Stateful variant: the random state is indexed by thread position within a
//  single group, so the grid is walked one threadgroup at a time, re-basing
//  the buffer offsets and passing the ray offset as a kernel constant.
                return [this, num_rays, pipline, buffers, offsets, range, tex_range, thread_groups, threads_per_group, textures] () mutable {
                    command_buffer = [queue commandBuffer];
                    for (uint32_t i = 0; i < num_rays; i += threads_per_group) {
                        id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoderWithDispatchType:MTLDispatchTypeSerial];

//  Offset every buffer except the last (the state buffer) to the current ray
//  window.
                        for (size_t j = 0, je = buffers.size() - 1; j < je; j++) {
                            offsets[j] = i*sizeof(float);
                        }

                        [encoder setComputePipelineState:pipline];
                        [encoder setBuffers:buffers.data()
                                    offsets:offsets.data()
                                  withRange:range];
//  The ray offset is bound immediately after the last buffer slot.
                        [encoder setBytes:&i
                                   length:sizeof(uint32_t)
                                  atIndex:buffers.size()];
                        [encoder setTextures:textures.data()
                                   withRange:tex_range];

                        [encoder dispatchThreadgroups:MTLSizeMake(1, 1, 1)
                                threadsPerThreadgroup:MTLSizeMake(threads_per_group, 1, 1)];
                        [encoder endEncoding];
                    }

                    [command_buffer commit];
                };
            } else {
//  Stateless variant: a single dispatch over the full grid.
                return [this, pipline, buffers, offsets, range, tex_range, thread_groups, threads_per_group, textures] () mutable {
                    command_buffer = [queue commandBuffer];
                    id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoderWithDispatchType:MTLDispatchTypeSerial];

                    [encoder setComputePipelineState:pipline];
                    [encoder setBuffers:buffers.data()
                                offsets:offsets.data()
                              withRange:range];
                    [encoder setTextures:textures.data()
                               withRange:tex_range];

                    [encoder dispatchThreadgroups:MTLSizeMake(thread_groups, 1, 1)
                            threadsPerThreadgroup:MTLSizeMake(threads_per_group, 1, 1)];
                    [encoder endEncoding];

                    [command_buffer commit];
                };
            }
        }
280
281//------------------------------------------------------------------------------
287//------------------------------------------------------------------------------
//------------------------------------------------------------------------------
///  @brief Create a max compute kernel calling function.
///
///  Builds a pipeline for the generated "max_reduction" kernel and returns a
///  closure that runs the supplied kernel call, reduces the argument buffer to
///  its maximum on the GPU, and blocks until the scalar result is available.
///
///  @param[in] argument Graph node whose cached device buffer is reduced.
///                      Must already have a buffer in kernel_arguments.
///  @param[in] run      Kernel calling function to execute before reducing.
///  @returns A function that returns the maximum value of the argument buffer.
//------------------------------------------------------------------------------
        std::function<float(void)> create_max_call(graph::shared_leaf<float, SAFE_MATH> &argument,
                                                   std::function<void(void)> run) {
            MTLComputePipelineDescriptor *compute = [[MTLComputePipelineDescriptor alloc] init];
            compute.threadGroupSizeIsMultipleOfThreadExecutionWidth = YES;
            compute.computeFunction = [library newFunctionWithName:@"max_reduction"];
            compute.maxTotalThreadsPerThreadgroup = 1024;
            compute.buffers[0].mutability = MTLMutabilityImmutable;

            NSError *error = nil;
            id<MTLComputePipelineState> max_state = [device newComputePipelineStateWithDescriptor:compute
                                                                                          options:MTLPipelineOptionNone
                                                                                       reflection:NULL
                                                                                            error:&error];
//  Cocoa convention: test the nil return for failure, not the error pointer.
            if (!max_state) {
                NSLog(@"%@", error);
            }

//  Single-float shared buffer to receive the reduced maximum.
            id<MTLBuffer> result = [device newBufferWithLength:sizeof(float)
                                                       options:MTLResourceStorageModeShared];

            id<MTLBuffer> buffer = kernel_arguments[argument.get()];

            NSUInteger threads_per_group = max_state.maxTotalThreadsPerThreadgroup;
            NSUInteger thread_width = max_state.threadExecutionWidth;
            if (jit::verbose) {
                std::cout << " Kernel name : max_reduction" << std::endl;
                std::cout << " Thread execution width : " << thread_width << std::endl;
                std::cout << " Threads per group : " << threads_per_group << std::endl;
                std::cout << " Number of groups : " << 1 << std::endl;
                std::cout << " Total problem size : " << threads_per_group*1 << std::endl;
            }

            return [this, run, buffer, result, max_state] () mutable {
                run();
                command_buffer = [queue commandBuffer];

                id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoderWithDispatchType:MTLDispatchTypeSerial];

                [encoder setComputePipelineState:max_state];
                [encoder setBuffer:buffer offset:0 atIndex:0];
                [encoder setBuffer:result offset:0 atIndex:1];
//  One group of 1024 threads: the generated max_reduction kernel strides its
//  input by a hard-coded 1024 (see create_reduction), so this launch size
//  must stay in sync with that value.
                [encoder dispatchThreadgroups:MTLSizeMake(1, 1, 1)
                        threadsPerThreadgroup:MTLSizeMake(1024, 1, 1)];
                [encoder endEncoding];

                [command_buffer commit];
                [command_buffer waitUntilCompleted];

//  Shared storage mode: the CPU can read the result directly once the
//  command buffer completes.
                return static_cast<float *> (result.contents)[0];
            };
        }
339
340//------------------------------------------------------------------------------
342//------------------------------------------------------------------------------
//------------------------------------------------------------------------------
///  @brief Get the compile options.
///
///  Enables fast math for both arithmetic and the floating point function
///  library when compiling kernel source.
///
///  @returns A configured MTLCompileOptions instance.
//------------------------------------------------------------------------------
        MTLCompileOptions *compile_options() {
//  Prefer alloc/init over +new per Cocoa convention.
            MTLCompileOptions *options = [[MTLCompileOptions alloc] init];
            options.mathMode = MTLMathModeFast;
            options.mathFloatingPointFunctions = MTLMathFloatingPointFunctionsFast;
            return options;
        }
349
350//------------------------------------------------------------------------------
352//------------------------------------------------------------------------------
//------------------------------------------------------------------------------
///  @brief Hold the current thread until the command buffer has completed.
///
///  Works by committing an empty command buffer: Metal command queues execute
///  buffers in submission order, so waiting on this empty buffer guarantees
///  every previously committed buffer on the queue has finished.
//------------------------------------------------------------------------------
        void wait() {
            command_buffer = [queue commandBuffer];

            [command_buffer commit];
            [command_buffer waitUntilCompleted];
        }
359
360//------------------------------------------------------------------------------
365//------------------------------------------------------------------------------
//------------------------------------------------------------------------------
///  @brief Print out the results.
///
///  Waits for the queue to drain, then prints the value at the given ray index
///  from each node's cached device buffer.
///
///  NOTE(review): the parameter line declaring the nodes collection is missing
///  from this rendering — verify against the original source.
//------------------------------------------------------------------------------
        void print_results(const size_t index,
            wait();
            for (auto &out : nodes) {
//  Shared storage buffers are directly readable by the CPU after wait().
                std::cout << static_cast<float *> ([kernel_arguments[out.get()] contents])[index] << " ";
            }
            std::cout << std::endl;
        }
374
375//------------------------------------------------------------------------------
381//------------------------------------------------------------------------------
//------------------------------------------------------------------------------
///  @brief Check the value.
///
///  Waits for the queue to drain, then reads the value at the given ray index
///  from the node's cached device buffer.
///
///  NOTE(review): the parameter line declaring the node is missing from this
///  rendering — verify against the original source.
//------------------------------------------------------------------------------
        float check_value(const size_t index,
            wait();
            return static_cast<float *> ([kernel_arguments[node.get()] contents])[index];
        }
387
388//------------------------------------------------------------------------------
393//------------------------------------------------------------------------------
                            float *source) {
//  Copy host memory into the node's cached Metal buffer. The copy length is
//  taken from the buffer itself, so source must hold at least that many bytes.
//  NOTE(review): no synchronization with in-flight GPU work is performed here
//  — callers presumably ensure the buffer is idle; confirm.
            const size_t size = [kernel_arguments[node.get()] length];
            memcpy([kernel_arguments[node.get()] contents],
                   source, size);
        }
400
401//------------------------------------------------------------------------------
406//------------------------------------------------------------------------------
                          float *destination) {
//  Fence the queue with an empty command buffer (same pattern as wait()) so
//  all previously committed GPU work has finished before reading back.
            command_buffer = [queue commandBuffer];

            [command_buffer commit];
            [command_buffer waitUntilCompleted];

//  Destination must be large enough for the full buffer length in bytes.
            memcpy(destination,
                   kernel_arguments[node.get()].contents,
                   kernel_arguments[node.get()].length);
        }
418
419//------------------------------------------------------------------------------
423//------------------------------------------------------------------------------
//------------------------------------------------------------------------------
///  @brief Create the source header.
///
///  Emits the standard Metal includes and namespace directive that every
///  generated kernel source needs.
///
///  @param[in,out] source_buffer Stream to write the header lines to.
//------------------------------------------------------------------------------
        void create_header(std::ostringstream &source_buffer) {
            static const char *const header_lines[] = {
                "#include <metal_stdlib>",
                "#include <metal_simdgroup>",
                "using namespace metal;"
            };
            for (const char *line : header_lines) {
                source_buffer << line << std::endl;
            }
        }
429
430//------------------------------------------------------------------------------
444//------------------------------------------------------------------------------
//------------------------------------------------------------------------------
///  @brief Create kernel prefix.
///
///  Emits the kernel signature (buffer, texture, and thread-index arguments),
///  records per-buffer mutability for the later pipeline build, opens the
///  bounds-check guard, and binds inputs/state to registers.
///
///  NOTE(review): this rendering is missing the parameter lines declaring the
///  inputs, outputs, and random state — verify against the original source.
//------------------------------------------------------------------------------
        void create_kernel_prefix(std::ostringstream &source_buffer,
                                  const std::string name,
                                  const size_t size,
                                  const std::vector<bool> &is_constant,
                                  jit::register_map &registers,
                                  const jit::register_usage &usage,
                                  jit::texture1d_list &textures1d,
                                  jit::texture2d_list &textures2d) {
            source_buffer << std::endl;
            source_buffer << "kernel void " << name << "(" << std::endl;

//  Reset the mutability table for this kernel; create_kernel_call replays it
//  onto the pipeline descriptor.
            bufferMutability[name] = std::vector<MTLMutability> ();

//  Deduplicate arguments: a node appearing more than once binds one buffer.
            size_t buffer_count = 0;
            std::unordered_set<void *> used_args;
            for (size_t i = 0, ie = inputs.size(); i < ie; i++) {
                if (!used_args.contains(inputs[i].get())) {
//  NOTE(review): this ternary looks inverted relative to the address-space
//  ternary two lines below ("constant" inputs marked MTLMutabilityMutable and
//  vice versa) — confirm intended mapping.
                    bufferMutability[name].push_back(is_constant[i] ? MTLMutabilityMutable : MTLMutabilityImmutable);
                    source_buffer << " " << (is_constant[i] ? "constant" : "device")
                                  << " float *"
                                  << jit::to_string('v', inputs[i].get())
                                  << " [[buffer(" << buffer_count++ << ")]], // "
                                  << inputs[i]->get_symbol()
#ifndef USE_INPUT_CACHE
#ifdef SHOW_USE_COUNT
                                  << " used " << usage.at(inputs[i].get())
#endif
#endif
                                  << std::endl;
                    used_args.insert(inputs[i].get());
                }
            }
//  Output buffers are always writable device pointers.
            for (size_t i = 0, ie = outputs.size(); i < ie; i++) {
                if (!used_args.contains(outputs[i].get())) {
                    bufferMutability[name].push_back(MTLMutabilityMutable);
                    source_buffer << " device float *"
                                  << jit::to_string('o', outputs[i].get())
                                  << " [[buffer(" << buffer_count++ << ")]],"
                                  << std::endl;
                    used_args.insert(outputs[i].get());
                }
            }
//  Random state adds the state buffer plus the ray-offset constant used by
//  the chunked dispatch loop in create_kernel_call.
            if (state.get()) {
                bufferMutability[name].push_back(MTLMutabilityMutable);
                source_buffer << " device mt_state *"
                              << jit::to_string('s', state.get())
                              << " [[buffer(" << buffer_count++ << ")]],"
                              << std::endl
                              << " constant uint32_t &offset [[buffer("
                              << buffer_count++ << ")]],"
                              << std::endl;
            }
//  Texture arguments share one index sequence across 1D and 2D lists.
            size_t index = 0;
            for (auto &[key, value] : textures1d) {
                source_buffer << " const texture1d<float, access::read> "
                              << jit::to_string('a', key)
                              << " [[texture(" << index++ << ")]],"
                              << std::endl;
            }
            for (auto &[key, value] : textures2d) {
                source_buffer << " const texture2d<float, access::read> "
                              << jit::to_string('a', key)
                              << " [[texture(" << index++ << ")]],"
                              << std::endl;
            }
            if (state.get()) {
                source_buffer << " uint thread_index [[thread_index_in_threadgroup]],"
                              << std::endl;
            }
//  Open the out-of-bounds guard; the matching braces are closed by
//  create_kernel_postfix.
            source_buffer << " uint index [[thread_position_in_grid]]) {" << std::endl
                          << " if (";
            if (state.get()) {
                source_buffer << "offset + ";
            }
            source_buffer << "index < " << size << ") {" << std::endl;

//  Bind each input either to a cached register load (USE_INPUT_CACHE) or to a
//  direct indexed pointer expression.
            for (auto &input : inputs) {
#ifdef USE_INPUT_CACHE
                if (usage.at(input.get())) {
                    registers[input.get()] = jit::to_string('r', input.get());
                    source_buffer << " const ";
                    jit::add_type<float> (source_buffer);
                    source_buffer << " " << registers[input.get()] << " = "
                                  << jit::to_string('v', input.get())
                                  << "[index]; // " << input->get_symbol()
#ifdef SHOW_USE_COUNT
                                  << " used " << usage.at(input.get())
#endif
                                  << std::endl;
                }
#else
                registers[input.get()] = jit::to_string('v', input.get()) + "[index]";
#endif
            }
            if (state.get()) {
#ifdef USE_INPUT_CACHE
                registers[state.get()] = jit::to_string('r', state.get());
                source_buffer << " device mt_state &" << registers[state.get()]
                              << " = " << jit::to_string('s', state.get())
                              << "[thread_index];"
#ifdef SHOW_USE_COUNT
//  NOTE(review): `input` is the loop variable of the finished inputs loop and
//  is out of scope here — this line cannot compile when both USE_INPUT_CACHE
//  and SHOW_USE_COUNT are defined; likely should be state.get(). Confirm.
                              << " // used " << usage.at(input.get())
#endif
                              << std::endl;
#else
                registers[state.get()] = jit::to_string('s', state.get()) + "[thread_index]";
#endif
            }
        }
557
558//------------------------------------------------------------------------------
568//------------------------------------------------------------------------------
//------------------------------------------------------------------------------
///  @brief Create kernel postfix.
///
///  Compiles the setter mappings and output expressions into store statements,
///  then closes the bounds-check guard and kernel body opened by
///  create_kernel_prefix.
///
///  NOTE(review): this rendering is missing the parameter lines declaring the
///  outputs, setters, and random state — verify against the original source.
//------------------------------------------------------------------------------
        void create_kernel_postfix(std::ostringstream &source_buffer,
                                   jit::register_map &registers,
                                   jit::register_map &indices,
                                   const jit::register_usage &usage) {
//  Track which destinations were already written so each is stored once.
            std::unordered_set<void *> out_registers;
            for (auto &[out, in] : setters) {
//  Skip identity mappings (expression already is the input variable).
                if (!out->is_match(in) &&
                    !out_registers.contains(out.get())) {
                    graph::shared_leaf<float, SAFE_MATH> a = out->compile(source_buffer,
                                                                          registers,
                                                                          indices,
                                                                          usage);
                    source_buffer << " "
                                  << jit::to_string('v', in.get())
                                  << "[index] = ";
//  SAFE_MATH guards stores against NaN by writing 0.0 instead.
                    if constexpr (SAFE_MATH) {
                        source_buffer << "isnan(" << registers[a.get()]
                                      << ") ? 0.0 : ";
                    }
                    source_buffer << registers[a.get()] << ";" << std::endl;
                    out_registers.insert(out.get());
                }
            }

//  Outputs that are plain variables already live in a bound buffer; only
//  compile and store non-variable expressions.
            for (auto &out : outputs) {
                if (!graph::variable_cast(out).get() &&
                    !out_registers.contains(out.get())) {
                    graph::shared_leaf<float, SAFE_MATH> a = out->compile(source_buffer,
                                                                          registers,
                                                                          indices,
                                                                          usage);
                    source_buffer << " " << jit::to_string('o', out.get())
                                  << "[index] = ";
                    if constexpr (SAFE_MATH) {
                        source_buffer << "isnan(" << registers[a.get()]
                                      << ") ? 0.0 : ";
                    }
                    source_buffer << registers[a.get()] << ";" << std::endl;
                    out_registers.insert(out.get());
                }
            }

//  Close the bounds guard and the kernel body.
            source_buffer << " }" << std::endl << "}" << std::endl;
        }
616
617//------------------------------------------------------------------------------
622//------------------------------------------------------------------------------
//------------------------------------------------------------------------------
///  @brief Create reduction.
///
///  Emits the "max_reduction" Metal kernel: each of 1024 threads strides the
///  input and keeps a running max, each simdgroup reduces via simd_max into
///  threadgroup memory, and simdgroup 0 reduces the 32 partials to the final
///  result. The 1024 stride must match the launch size in create_max_call.
///
///  @param[in,out] source_buffer Stream to write the kernel source to.
///  @param[in]     size          Number of input elements to reduce.
//------------------------------------------------------------------------------
        void create_reduction(std::ostringstream &source_buffer,
                              const size_t size) {
            source_buffer << std::endl;
            source_buffer << "kernel void max_reduction(" << std::endl;
            source_buffer << " constant float *input [[buffer(0)]]," << std::endl;
            source_buffer << " device float *result [[buffer(1)]]," << std::endl;
            source_buffer << " uint i [[thread_position_in_grid]]," << std::endl;
            source_buffer << " uint j [[simdgroup_index_in_threadgroup]]," << std::endl;
            source_buffer << " uint k [[thread_index_in_simdgroup]]) {" << std::endl;
            source_buffer << " if (i < " << size << ") {" << std::endl;
//  Grid-stride-style serial reduction: thread i covers i, i+1024, i+2048, ...
            source_buffer << " float sub_max = input[i];" << std::endl;
            source_buffer << " for (size_t index = i + 1024; index < " << size <<"; index += 1024) {" << std::endl;
            source_buffer << " sub_max = max(sub_max, input[index]);" << std::endl;
            source_buffer << " }" << std::endl;
//  One partial per simdgroup (1024 threads / 32-wide simdgroups = 32 slots).
//  NOTE(review): when size < 1024 some thread_max slots are never written
//  (their threads fail the i < size guard) yet simdgroup 0 reads all 32 —
//  potentially reducing over uninitialized threadgroup memory. Confirm that
//  callers guarantee size >= 1024.
            source_buffer << " threadgroup float thread_max[32];" << std::endl;
            source_buffer << " thread_max[j] = simd_max(sub_max);" << std::endl;
            source_buffer << " threadgroup_barrier(mem_flags::mem_threadgroup);" << std::endl;
            source_buffer << " if (j == 0) {" << std::endl;
            source_buffer << " *result = simd_max(thread_max[k]);" << std::endl;
            source_buffer << " }" << std::endl;
            source_buffer << " }" << std::endl;
            source_buffer << "}" << std::endl << std::endl;
        }
646
647//------------------------------------------------------------------------------
651//------------------------------------------------------------------------------
//  Expose the node's cached Metal buffer contents as a raw float pointer.
//  Valid only while the buffer exists; no synchronization is performed here.
//  NOTE(review): the signature line is missing from this rendering — verify
//  against the original source.
            return static_cast<float *> ([kernel_arguments[node.get()] contents]);
        }
655 };
656}
657
658#endif /* metal_context_h */
Class representing a generic buffer.
Definition backend.hpp:29
size_t size() const
Get size of the buffer.
Definition backend.hpp:116
T * data()
Get a pointer to the basic memory buffer.
Definition backend.hpp:270
Class representing a metal gpu context.
Definition metal_context.hpp:25
void create_kernel_postfix(std::ostringstream &source_buffer, graph::output_nodes< float, SAFE_MATH > &outputs, graph::map_nodes< float, SAFE_MATH > &setters, graph::shared_random_state< float, SAFE_MATH > state, jit::register_map &registers, jit::register_map &indices, const jit::register_usage &usage)
Create kernel postfix.
Definition metal_context.hpp:569
void create_reduction(std::ostringstream &source_buffer, const size_t size)
Create reduction.
Definition metal_context.hpp:623
static size_t max_concurrency()
Get the maximum number of concurrent instances.
Definition metal_context.hpp:54
MTLCompileOptions * compile_options()
Get the compile options.
Definition metal_context.hpp:343
void create_kernel_prefix(std::ostringstream &source_buffer, const std::string name, graph::input_nodes< float, SAFE_MATH > &inputs, graph::output_nodes< float, SAFE_MATH > &outputs, graph::shared_random_state< float, SAFE_MATH > state, const size_t size, const std::vector< bool > &is_constant, jit::register_map &registers, const jit::register_usage &usage, jit::texture1d_list &textures1d, jit::texture2d_list &textures2d)
Create kernel prefix.
Definition metal_context.hpp:445
void wait()
Hold the current thread until the command buffer has completed.
Definition metal_context.hpp:353
float check_value(const size_t index, const graph::shared_leaf< float, SAFE_MATH > &node)
Check the value.
Definition metal_context.hpp:382
std::function< void(void)> create_kernel_call(const std::string kernel_name, graph::input_nodes< float, SAFE_MATH > inputs, graph::output_nodes< float, SAFE_MATH > outputs, graph::shared_random_state< float, SAFE_MATH > state, const size_t num_rays, const jit::texture1d_list &tex1d_list, const jit::texture2d_list &tex2d_list)
Create a kernel calling function.
Definition metal_context.hpp:111
void compile(const std::string kernel_source, std::vector< std::string > names, const bool add_reduction=false)
Compile the kernels.
Definition metal_context.hpp:81
void print_results(const size_t index, const graph::output_nodes< float, SAFE_MATH > &nodes)
Print out the results.
Definition metal_context.hpp:366
float * get_buffer(graph::shared_leaf< float, SAFE_MATH > &node)
Get the buffer for a node.
Definition metal_context.hpp:652
static std::string device_type()
Device description.
Definition metal_context.hpp:61
void copy_to_device(graph::shared_leaf< float, SAFE_MATH > node, float *source)
Copy buffer contents to the device.
Definition metal_context.hpp:394
metal_context(const size_t index)
Construct a metal context.
Definition metal_context.hpp:70
static constexpr size_t random_state_size
Size of random state needed.
Definition metal_context.hpp:44
int remaining_const_memory
Remaining constant memory in bytes. NOT USED.
Definition metal_context.hpp:47
std::function< float(void)> create_max_call(graph::shared_leaf< float, SAFE_MATH > &argument, std::function< void(void)> run)
Create a max compute kernel calling function.
Definition metal_context.hpp:288
void copy_to_host(graph::shared_leaf< float, SAFE_MATH > node, float *destination)
Copy buffer contents to host.
Definition metal_context.hpp:407
void create_header(std::ostringstream &source_buffer)
Create the source header.
Definition metal_context.hpp:424
Name space for GPU backends.
Definition cpu_context.hpp:51
std::shared_ptr< variable_node< T, SAFE_MATH > > shared_variable
Convenience type alias for shared variable nodes.
Definition node.hpp:1727
std::shared_ptr< random_state_node< T, SAFE_MATH > > shared_random_state
Convenience type alias for shared random state nodes.
Definition random.hpp:272
std::vector< shared_variable< T, SAFE_MATH > > input_nodes
Convenience type alias for a vector of inputs.
Definition node.hpp:1730
shared_variable< T, SAFE_MATH > variable_cast(shared_leaf< T, SAFE_MATH > x)
Cast to a variable node.
Definition node.hpp:1746
std::shared_ptr< leaf_node< T, SAFE_MATH > > shared_leaf
Convenience type alias for shared leaf nodes.
Definition node.hpp:673
std::vector< std::pair< shared_leaf< T, SAFE_MATH >, shared_variable< T, SAFE_MATH > > > map_nodes
Convenience type alias for mapping end codes back to inputs.
Definition node.hpp:1734
std::vector< shared_leaf< T, SAFE_MATH > > output_nodes
Convenience type alias for a vector of output nodes.
Definition node.hpp:688
std::map< void *, size_t > texture1d_list
Type alias for indexing 1D textures.
Definition register.hpp:262
std::map< void *, std::array< size_t, 2 > > texture2d_list
Type alias for indexing 2D textures.
Definition register.hpp:264
std::map< void *, size_t > register_usage
Type alias for counting register usage.
Definition register.hpp:258
std::map< void *, std::string > register_map
Type alias for mapping node pointers to register names.
Definition register.hpp:256
std::string to_string(const char prefix, const NODE *pointer)
Convert a graph::leaf_node pointer to a string.
Definition register.hpp:245
Name space for output files.
Definition output.hpp:16
Random constants and distributions.