Graph Framework
Loading...
Searching...
No Matches
metal_context.hpp
Go to the documentation of this file.
1//------------------------------------------------------------------------------
6//------------------------------------------------------------------------------
7
8#ifndef metal_context_h
9#define metal_context_h
10
11#include <unordered_set>
12
13#import <Metal/Metal.h>
14
15#include "random.hpp"
16
18namespace gpu {
19//------------------------------------------------------------------------------
23//------------------------------------------------------------------------------
24 template<bool SAFE_MATH=false>
26 private:
28 id<MTLDevice> device;
30 id<MTLCommandQueue> queue;
32 std::map<graph::leaf_node<float, SAFE_MATH> *, id<MTLBuffer>> kernel_arguments;
34 std::map<void *, id<MTLTexture>> texture_arguments;
36 id<MTLCommandBuffer> command_buffer;
38 id<MTLLibrary> library;
40 std::map<std::string, std::vector<MTLMutability>> bufferMutability;
41
42 public:
44 constexpr static size_t random_state_size = 1024;
45
48
49//------------------------------------------------------------------------------
53//------------------------------------------------------------------------------
//------------------------------------------------------------------------------
/// @brief Get the maximum number of concurrent instances.
///
/// One context is created per Metal device, so the concurrency limit is the
/// number of devices visible to the system.
///
/// @returns The number of available Metal devices.
//------------------------------------------------------------------------------
        static size_t max_concurrency() {
            NSArray<id<MTLDevice>> *devices = MTLCopyAllDevices();
            return static_cast<size_t> (devices.count);
        }
57
58//------------------------------------------------------------------------------
60//------------------------------------------------------------------------------
//------------------------------------------------------------------------------
/// @brief Device description.
///
/// @returns The fixed human readable description of this backend.
//------------------------------------------------------------------------------
        static std::string device_type() {
            return std::string{"Metal GPU"};
        }
64
65//------------------------------------------------------------------------------
69//------------------------------------------------------------------------------
//------------------------------------------------------------------------------
/// @brief Construct a metal context.
///
/// Selects the indexth Metal device and creates a command queue on it.
///
/// @param[in] index Device index. Out-of-range values raise an Objective-C
///                  range exception, as with the original objectAtIndex: call.
//------------------------------------------------------------------------------
        metal_context(const size_t index) {
            NSArray<id<MTLDevice>> *devices = MTLCopyAllDevices();
            device = devices[index];
            queue = [device newCommandQueue];
        }
73
74//------------------------------------------------------------------------------
80//------------------------------------------------------------------------------
//------------------------------------------------------------------------------
/// @brief Compile the kernels.
///
/// Compiles the generated Metal source into the context library.
///
/// @param[in] kernel_source Source code of the kernels to compile.
/// @param[in] names         Kernel names. Unused by this backend; kept for
///                          interface parity with other gpu contexts.
/// @param[in] add_reduction Reduction flag. Unused by this backend.
//------------------------------------------------------------------------------
        void compile(const std::string kernel_source,
                     std::vector<std::string> names,
                     const bool add_reduction=false) {
            NSError *error = nil;
            library = [device newLibraryWithSource:[NSString stringWithCString:kernel_source.c_str()
                                                                      encoding:NSUTF8StringEncoding]
                                           options:compile_options()
                                             error:&error];

//  Check the returned library, not only the error pointer: Metal may set a
//  non-nil error for warnings on success, and the nil library is the real
//  failure signal. Log the error object in either case.
            if (!library || error) {
                NSLog(@"%@", error);
            }

            if (jit::verbose) {
                std::cout << "Metal GPU info." << std::endl;
            }
        }
98
99//------------------------------------------------------------------------------
110//------------------------------------------------------------------------------
//------------------------------------------------------------------------------
/// @brief Create a kernel calling function.
///
/// Builds a compute pipeline for the named kernel, creates and binds the
/// argument buffers and textures, and returns a closure that encodes and
/// commits one launch of the kernel.
///
/// @param[in] kernel_name Name of the kernel function in the library.
/// @param[in] inputs      Input variable nodes bound as kernel buffers.
/// @param[in] outputs     Output nodes bound as kernel buffers.
/// @param[in] state       Random state node; null when the kernel uses none.
/// @param[in] num_rays    Total problem size in elements.
/// @param[in] tex1d_list  1D texture data to upload and bind.
/// @param[in] tex2d_list  2D texture data to upload and bind.
/// @returns A function object that launches the kernel.
//------------------------------------------------------------------------------
        std::function<void(void)> create_kernel_call(const std::string kernel_name,
                                                     graph::input_nodes<float, SAFE_MATH> inputs,
                                                     graph::output_nodes<float, SAFE_MATH> outputs,
                                                     graph::shared_random_state<float, SAFE_MATH> state,
                                                     const size_t num_rays,
                                                     const jit::texture1d_list &tex1d_list,
                                                     const jit::texture2d_list &tex2d_list) {
            NSError *error = nil;

            id<MTLFunction> function = [library newFunctionWithName:[NSString stringWithCString:kernel_name.c_str()
                                                                                       encoding:NSUTF8StringEncoding]];

            MTLComputePipelineDescriptor *compute = [MTLComputePipelineDescriptor new];
            compute.threadGroupSizeIsMultipleOfThreadExecutionWidth = YES;
            compute.computeFunction = function;
            compute.maxTotalThreadsPerThreadgroup = 1024;
//  Apply the per-buffer mutability recorded when the kernel prefix was built.
            for (size_t i = 0, ie = bufferMutability[kernel_name].size(); i < ie; i++) {
                compute.buffers[i].mutability = bufferMutability[kernel_name][i];
            }

            id<MTLComputePipelineState> pipeline = [device newComputePipelineStateWithDescriptor:compute
                                                                                         options:MTLPipelineOptionNone
                                                                                      reflection:NULL
                                                                                           error:&error];
//  Check the returned pipeline, not only the error pointer.
            if (!pipeline || error) {
                NSLog(@"%@", error);
            }

            std::vector<id<MTLBuffer>> buffers;
            std::set<graph::leaf_node<float, SAFE_MATH> *> needed_buffers;

//  Create a device buffer for each input not seen before, then bind each
//  distinct argument exactly once, in argument order.
            const size_t buffer_element_size = sizeof(float);
            for (graph::shared_variable<float, SAFE_MATH> &input : inputs) {
                if (!kernel_arguments.contains(input.get())) {
                    backend::buffer<float> buffer = input->evaluate();
                    kernel_arguments[input.get()] = [device newBufferWithBytes:buffer.data()
                                                                        length:buffer.size()*buffer_element_size
                                                                       options:MTLResourceStorageModeShared];
                }
                if (!needed_buffers.contains(input.get())) {
                    buffers.push_back(kernel_arguments[input.get()]);
                    needed_buffers.insert(input.get());
                }
            }
            for (graph::shared_leaf<float, SAFE_MATH> &output : outputs) {
                if (!kernel_arguments.contains(output.get())) {
                    kernel_arguments[output.get()] = [device newBufferWithLength:num_rays*sizeof(float)
                                                                         options:MTLResourceStorageModeShared];
                }
                if (!needed_buffers.contains(output.get())) {
                    buffers.push_back(kernel_arguments[output.get()]);
                    needed_buffers.insert(output.get());
                }
            }
//  The random state buffer, when present, is always bound last. It is
//  written by the CPU once, so write-combined, untracked shared storage
//  is used.
            if (state.get()) {
                if (!kernel_arguments.contains(state.get())) {
                    kernel_arguments[state.get()] = [device newBufferWithBytes:state->data()
                                                                        length:state->get_size_bytes()
                                                                       options:MTLResourceCPUCacheModeWriteCombined |
                                                                               MTLResourceStorageModeShared |
                                                                               MTLResourceHazardTrackingModeUntracked];
                }
                buffers.push_back(kernel_arguments[state.get()]);
            }

//  Upload texture data once per unique pointer and optimize the layout for
//  GPU access with a blit pass.
            std::vector<id<MTLTexture>> textures;
            command_buffer = [queue commandBuffer];
            id<MTLBlitCommandEncoder> encoder = [command_buffer blitCommandEncoder];
            for (auto &[data, size] : tex1d_list) {
                if (!texture_arguments.contains(data)) {
                    MTLTextureDescriptor *descriptor = [MTLTextureDescriptor new];
                    descriptor.textureType = MTLTextureType1D;
                    descriptor.pixelFormat = MTLPixelFormatR32Float;
                    descriptor.width = size;
                    descriptor.storageMode = MTLStorageModeManaged;
                    descriptor.cpuCacheMode = MTLCPUCacheModeWriteCombined;
                    descriptor.hazardTrackingMode = MTLHazardTrackingModeUntracked;
                    descriptor.usage = MTLTextureUsageShaderRead;
                    texture_arguments[data] = [device newTextureWithDescriptor:descriptor];
                    [texture_arguments[data] replaceRegion:MTLRegionMake1D(0, size)
                                               mipmapLevel:0
                                                 withBytes:reinterpret_cast<float *> (data)
                                               bytesPerRow:4*size];

                    [encoder optimizeContentsForGPUAccess:texture_arguments[data]];
                }
                textures.push_back(texture_arguments[data]);
            }
            for (auto &[data, size] : tex2d_list) {
                if (!texture_arguments.contains(data)) {
                    MTLTextureDescriptor *descriptor = [MTLTextureDescriptor new];
                    descriptor.textureType = MTLTextureType2D;
                    descriptor.pixelFormat = MTLPixelFormatR32Float;
//  2D sizes are stored as {height, width}.
                    descriptor.width = size[1];
                    descriptor.height = size[0];
                    descriptor.storageMode = MTLStorageModeManaged;
                    descriptor.cpuCacheMode = MTLCPUCacheModeWriteCombined;
                    descriptor.hazardTrackingMode = MTLHazardTrackingModeUntracked;
                    descriptor.usage = MTLTextureUsageShaderRead;
                    texture_arguments[data] = [device newTextureWithDescriptor:descriptor];
                    [texture_arguments[data] replaceRegion:MTLRegionMake2D(0, 0, size[1], size[0])
                                               mipmapLevel:0
                                                 withBytes:reinterpret_cast<float *> (data)
                                               bytesPerRow:4*size[1]];

                    [encoder optimizeContentsForGPUAccess:texture_arguments[data]];
                }
                textures.push_back(texture_arguments[data]);
            }
            [encoder endEncoding];
            [command_buffer commit];

            std::vector<NSUInteger> offsets(buffers.size(), 0);
            NSRange range = NSMakeRange(0, buffers.size());
            NSRange tex_range = NSMakeRange(0, textures.size());

            NSUInteger threads_per_group = pipeline.maxTotalThreadsPerThreadgroup;
            NSUInteger thread_width = pipeline.threadExecutionWidth;
            NSUInteger thread_groups = num_rays/threads_per_group + (num_rays%threads_per_group ? 1 : 0);

            if (jit::verbose) {
                std::cout << "  Kernel name : " << kernel_name << std::endl;
                std::cout << "    Thread execution width : " << thread_width << std::endl;
                std::cout << "    Threads per group      : " << threads_per_group << std::endl;
                std::cout << "    Number of groups       : " << thread_groups << std::endl;
                std::cout << "    Total problem size     : " << threads_per_group*thread_groups << std::endl;
            }

            if (state.get()) {
//  Random state kernels run one threadgroup at a time so that each thread
//  reuses its private state slot; the uint32_t offset uniform selects the
//  slice of the problem each launch works on.
                return [this, num_rays, pipeline, buffers, offsets, range, tex_range, threads_per_group, textures] () mutable {
                    command_buffer = [queue commandBuffer];
                    for (uint32_t i = 0; i < num_rays; i += threads_per_group) {
                        id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoderWithDispatchType:MTLDispatchTypeSerial];

//  Offset every argument buffer except the trailing state buffer.
                        for (size_t j = 0, je = buffers.size() - 1; j < je; j++) {
                            offsets[j] = i*sizeof(float);
                        }

                        [encoder setComputePipelineState:pipeline];
                        [encoder setBuffers:buffers.data()
                                    offsets:offsets.data()
                                  withRange:range];
                        [encoder setBytes:&i
                                   length:sizeof(uint32_t)
                                  atIndex:buffers.size()];
                        [encoder setTextures:textures.data()
                                   withRange:tex_range];

                        [encoder dispatchThreadgroups:MTLSizeMake(1, 1, 1)
                                threadsPerThreadgroup:MTLSizeMake(threads_per_group, 1, 1)];
                        [encoder endEncoding];
                    }

                    [command_buffer commit];
                };
            } else {
                return [this, pipeline, buffers, offsets, range, tex_range, thread_groups, threads_per_group, textures] () mutable {
                    command_buffer = [queue commandBuffer];
                    id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoderWithDispatchType:MTLDispatchTypeSerial];

                    [encoder setComputePipelineState:pipeline];
                    [encoder setBuffers:buffers.data()
                                offsets:offsets.data()
                              withRange:range];
                    [encoder setTextures:textures.data()
                               withRange:tex_range];

                    [encoder dispatchThreadgroups:MTLSizeMake(thread_groups, 1, 1)
                            threadsPerThreadgroup:MTLSizeMake(threads_per_group, 1, 1)];
                    [encoder endEncoding];

                    [command_buffer commit];
                };
            }
        }
291
292//------------------------------------------------------------------------------
298//------------------------------------------------------------------------------
//------------------------------------------------------------------------------
/// @brief Create a max compute kernel calling function.
///
/// Builds a pipeline for the generated max_reduction kernel and returns a
/// closure that runs the supplied kernel, reduces the argument buffer to its
/// maximum, and returns that maximum.
///
/// @param[in] argument Node whose buffer is reduced. Must already have a
///                     buffer registered by create_kernel_call.
/// @param[in] run      Function that launches the kernel producing argument.
/// @returns A function object returning the buffer maximum.
//------------------------------------------------------------------------------
        std::function<float(void)> create_max_call(graph::shared_leaf<float, SAFE_MATH> &argument,
                                                   std::function<void(void)> run) {
            MTLComputePipelineDescriptor *compute = [MTLComputePipelineDescriptor new];
            compute.threadGroupSizeIsMultipleOfThreadExecutionWidth = YES;
            compute.computeFunction = [library newFunctionWithName:@"max_reduction"];
            compute.maxTotalThreadsPerThreadgroup = 1024;
//  Buffer 0 is declared constant in the generated kernel, so it is immutable.
            compute.buffers[0].mutability = MTLMutabilityImmutable;

            NSError *error = nil;
            id<MTLComputePipelineState> max_state = [device newComputePipelineStateWithDescriptor:compute
                                                                                          options:MTLPipelineOptionNone
                                                                                       reflection:NULL
                                                                                            error:&error];
//  Check the returned pipeline, not only the error pointer.
            if (!max_state || error) {
                NSLog(@"%@", error);
            }

            id<MTLBuffer> result = [device newBufferWithLength:sizeof(float)
                                                       options:MTLResourceStorageModeShared];

            id<MTLBuffer> buffer = kernel_arguments[argument.get()];

            NSUInteger threads_per_group = max_state.maxTotalThreadsPerThreadgroup;
            NSUInteger thread_width = max_state.threadExecutionWidth;
            if (jit::verbose) {
                std::cout << "  Kernel name : max_reduction" << std::endl;
                std::cout << "    Thread execution width : " << thread_width << std::endl;
                std::cout << "    Threads per group      : " << threads_per_group << std::endl;
                std::cout << "    Number of groups       : " << 1 << std::endl;
                std::cout << "    Total problem size     : " << threads_per_group*1 << std::endl;
            }

            return [this, run, buffer, result, max_state] () mutable {
                run();
                command_buffer = [queue commandBuffer];

                id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoderWithDispatchType:MTLDispatchTypeSerial];

                [encoder setComputePipelineState:max_state];
                [encoder setBuffer:buffer offset:0 atIndex:0];
                [encoder setBuffer:result offset:0 atIndex:1];
//  NOTE: the 1024 here must match both maxTotalThreadsPerThreadgroup above
//  and the hard-coded 1024 stride in the source create_reduction generates;
//  do not replace it with threads_per_group independently of that kernel.
                [encoder dispatchThreadgroups:MTLSizeMake(1, 1, 1)
                        threadsPerThreadgroup:MTLSizeMake(1024, 1, 1)];
                [encoder endEncoding];

                [command_buffer commit];
                [command_buffer waitUntilCompleted];

                return static_cast<float *> (result.contents)[0];
            };
        }
350
351//------------------------------------------------------------------------------
353//------------------------------------------------------------------------------
//------------------------------------------------------------------------------
/// @brief Get the compile options.
///
/// Enables fast math and fast floating-point library functions for kernel
/// compilation.
///
/// @returns The configured MTLCompileOptions instance.
//------------------------------------------------------------------------------
        MTLCompileOptions *compile_options() {
            MTLCompileOptions *options = [MTLCompileOptions new];
            options.mathFloatingPointFunctions = MTLMathFloatingPointFunctionsFast;
            options.mathMode = MTLMathModeFast;
            return options;
        }
360
361//------------------------------------------------------------------------------
363//------------------------------------------------------------------------------
//------------------------------------------------------------------------------
/// @brief Hold the current thread until the command buffer has completed.
///
/// An empty command buffer committed to a serial queue completes only after
/// all previously committed work, so waiting on it drains the queue.
//------------------------------------------------------------------------------
        void wait() {
            command_buffer = [queue commandBuffer];
            [command_buffer commit];
            [command_buffer waitUntilCompleted];
        }
370
371//------------------------------------------------------------------------------
376//------------------------------------------------------------------------------
//------------------------------------------------------------------------------
/// @brief Print out the results.
///
/// Waits for the GPU to finish, then prints the indexth element of each
/// node's buffer on one line.
///
/// @param[in] index Element index to print.
/// @param[in] nodes Nodes whose buffers are printed.
//------------------------------------------------------------------------------
        void print_results(const size_t index,
                           const graph::output_nodes<float, SAFE_MATH> &nodes) {
            wait();
            for (auto &out : nodes) {
                std::cout << static_cast<float *> ([kernel_arguments[out.get()] contents])[index] << " ";
            }
            std::cout << std::endl;
        }
385
386//------------------------------------------------------------------------------
392//------------------------------------------------------------------------------
//------------------------------------------------------------------------------
/// @brief Check the value.
///
/// Waits for the GPU to finish, then reads one element from the node's
/// buffer.
///
/// @param[in] index Element index to read.
/// @param[in] node  Node whose buffer is read.
/// @returns The indexth element of the buffer.
//------------------------------------------------------------------------------
        float check_value(const size_t index,
                          const graph::shared_leaf<float, SAFE_MATH> &node) {
            wait();
            return static_cast<float *> ([kernel_arguments[node.get()] contents])[index];
        }
398
399//------------------------------------------------------------------------------
404//------------------------------------------------------------------------------
//------------------------------------------------------------------------------
/// @brief Copy buffer contents to the device.
///
/// @param[in] node   Node whose device buffer is overwritten.
/// @param[in] source Host memory to copy from; must hold at least the
///                   buffer's length in bytes.
//------------------------------------------------------------------------------
        void copy_to_device(graph::shared_leaf<float, SAFE_MATH> node,
                            float *source) {
//  Drain in-flight GPU work before overwriting the shared-storage buffer,
//  mirroring the synchronization copy_to_host performs before reading.
            wait();
            const size_t size = [kernel_arguments[node.get()] length];
            memcpy([kernel_arguments[node.get()] contents],
                   source, size);
        }
411
412//------------------------------------------------------------------------------
417//------------------------------------------------------------------------------
//------------------------------------------------------------------------------
/// @brief Copy buffer contents to host.
///
/// @param[in]     node        Node whose device buffer is read.
/// @param[in,out] destination Host memory to copy into; must hold at least
///                            the buffer's length in bytes.
//------------------------------------------------------------------------------
        void copy_to_host(graph::shared_leaf<float, SAFE_MATH> node,
                          float *destination) {
//  wait() drains the queue; this replaces an inline copy of the same
//  commit/waitUntilCompleted sequence.
            wait();

            memcpy(destination,
                   kernel_arguments[node.get()].contents,
                   kernel_arguments[node.get()].length);
        }
429
430//------------------------------------------------------------------------------
434//------------------------------------------------------------------------------
//------------------------------------------------------------------------------
/// @brief Create the source header.
///
/// Emits the Metal standard-library includes and using directive that every
/// generated kernel source needs.
///
/// @param[in,out] source_buffer Buffer the header is appended to.
//------------------------------------------------------------------------------
        void create_header(std::ostringstream &source_buffer) {
            source_buffer << "#include <metal_stdlib>" << std::endl
                          << "#include <metal_simdgroup>" << std::endl
                          << "using namespace metal;" << std::endl;
        }
440
441//------------------------------------------------------------------------------
455//------------------------------------------------------------------------------
//------------------------------------------------------------------------------
/// @brief Create kernel prefix.
///
/// Emits the kernel signature (buffers, textures, thread indices), the bounds
/// guard, and the register bindings for the inputs and random state. Also
/// records each buffer's mutability for later pipeline construction.
///
/// @param[in,out] source_buffer Buffer the source is appended to.
/// @param[in]     name          Kernel name.
/// @param[in]     inputs        Input variable nodes.
/// @param[in]     outputs       Output nodes.
/// @param[in]     state         Random state node; null when unused.
/// @param[in]     size          Total problem size guarded against.
/// @param[in]     is_constant   Per-input flag selecting the constant
///                              address space.
/// @param[in,out] registers     Map of node pointers to register names.
/// @param[in]     usage         Register use counts.
/// @param[in]     textures1d    1D textures to declare.
/// @param[in]     textures2d    2D textures to declare.
//------------------------------------------------------------------------------
        void create_kernel_prefix(std::ostringstream &source_buffer,
                                  const std::string name,
                                  graph::input_nodes<float, SAFE_MATH> &inputs,
                                  graph::output_nodes<float, SAFE_MATH> &outputs,
                                  graph::shared_random_state<float, SAFE_MATH> state,
                                  const size_t size,
                                  const std::vector<bool> &is_constant,
                                  jit::register_map &registers,
                                  const jit::register_usage &usage,
                                  jit::texture1d_list &textures1d,
                                  jit::texture2d_list &textures2d) {
            source_buffer << std::endl;
            source_buffer << "kernel void " << name << "(" << std::endl;

            bufferMutability[name] = std::vector<MTLMutability> ();

            size_t buffer_count = 0;
            std::unordered_set<void *> used_args;
            for (size_t i = 0, ie = inputs.size(); i < ie; i++) {
                if (!used_args.contains(inputs[i].get())) {
//  A constant address-space buffer is read only and therefore immutable to
//  the pipeline; writable device buffers must stay mutable. The original
//  mapping was inverted relative to the address space emitted below and to
//  the immutable constant buffer configured in create_max_call.
                    bufferMutability[name].push_back(is_constant[i] ? MTLMutabilityImmutable : MTLMutabilityMutable);
                    source_buffer << " " << (is_constant[i] ? "constant" : "device")
                                  << " float *"
                                  << jit::to_string('v', inputs[i].get())
                                  << " [[buffer(" << buffer_count++ << ")]], // "
                                  << inputs[i]->get_symbol()
#ifndef USE_INPUT_CACHE
#ifdef SHOW_USE_COUNT
                                  << " used " << usage.at(inputs[i].get())
#endif
#endif
                                  << std::endl;
                    used_args.insert(inputs[i].get());
                }
            }
            for (size_t i = 0, ie = outputs.size(); i < ie; i++) {
                if (!used_args.contains(outputs[i].get())) {
                    bufferMutability[name].push_back(MTLMutabilityMutable);
                    source_buffer << " device float *"
                                  << jit::to_string('o', outputs[i].get())
                                  << " [[buffer(" << buffer_count++ << ")]],"
                                  << std::endl;
                    used_args.insert(outputs[i].get());
                }
            }
//  The state buffer and the launch offset uniform occupy the final buffer
//  slots, matching the binding order in create_kernel_call.
            if (state.get()) {
                bufferMutability[name].push_back(MTLMutabilityMutable);
                source_buffer << " device mt_state *"
                              << jit::to_string('s', state.get())
                              << " [[buffer(" << buffer_count++ << ")]],"
                              << std::endl
                              << " constant uint32_t &offset [[buffer("
                              << buffer_count++ << ")]],"
                              << std::endl;
            }
            size_t index = 0;
            for (auto &[key, value] : textures1d) {
                source_buffer << " const texture1d<float, access::read> "
                              << jit::to_string('a', key)
                              << " [[texture(" << index++ << ")]],"
                              << std::endl;
            }
            for (auto &[key, value] : textures2d) {
                source_buffer << " const texture2d<float, access::read> "
                              << jit::to_string('a', key)
                              << " [[texture(" << index++ << ")]],"
                              << std::endl;
            }
            if (state.get()) {
                source_buffer << " uint thread_index [[thread_index_in_threadgroup]],"
                              << std::endl;
            }
            source_buffer << " uint index [[thread_position_in_grid]]) {" << std::endl
                          << " if (";
            if (state.get()) {
                source_buffer << "offset + ";
            }
            source_buffer << "index < " << size << ") {" << std::endl;

//  Bind each input either to a cached register read or directly to the
//  indexed buffer expression.
            for (auto &input : inputs) {
#ifdef USE_INPUT_CACHE
                if (usage.at(input.get())) {
                    registers[input.get()] = jit::to_string('r', input.get());
                    source_buffer << " const ";
                    jit::add_type<float> (source_buffer);
                    source_buffer << " " << registers[input.get()] << " = "
                                  << jit::to_string('v', input.get())
                                  << "[index]; // " << input->get_symbol()
#ifdef SHOW_USE_COUNT
                                  << " used " << usage.at(input.get())
#endif
                                  << std::endl;
                }
#else
                registers[input.get()] = jit::to_string('v', input.get()) + "[index]";
#endif
            }
            if (state.get()) {
#ifdef USE_INPUT_CACHE
                registers[state.get()] = jit::to_string('r', state.get());
                source_buffer << " device mt_state &" << registers[state.get()]
                              << " = " << jit::to_string('s', state.get())
                              << "[thread_index];"
#ifdef SHOW_USE_COUNT
//  Fixed: the original referenced the out-of-scope loop variable input here.
                              << " // used " << usage.at(state.get())
#endif
                              << std::endl;
#else
                registers[state.get()] = jit::to_string('s', state.get()) + "[thread_index]";
#endif
            }
        }
568
569//------------------------------------------------------------------------------
579//------------------------------------------------------------------------------
//------------------------------------------------------------------------------
/// @brief Create kernel postfix.
///
/// Compiles the setter and output expressions and emits the stores that write
/// them back to their buffers, then closes the bounds guard and the kernel.
///
/// @param[in,out] source_buffer Buffer the source is appended to.
/// @param[in]     outputs       Output nodes to store.
/// @param[in]     setters       Pairs mapping expressions back to input
///                              variables.
/// @param[in]     state         Random state node; unused here, kept for
///                              interface parity with other gpu contexts.
/// @param[in,out] registers     Map of node pointers to register names.
/// @param[in,out] indices       Map of node pointers to index names.
/// @param[in]     usage         Register use counts.
//------------------------------------------------------------------------------
        void create_kernel_postfix(std::ostringstream &source_buffer,
                                   graph::output_nodes<float, SAFE_MATH> &outputs,
                                   graph::map_nodes<float, SAFE_MATH> &setters,
                                   graph::shared_random_state<float, SAFE_MATH> state,
                                   jit::register_map &registers,
                                   jit::register_map &indices,
                                   const jit::register_usage &usage) {
            std::unordered_set<void *> out_registers;

//  Compile node and emit "destination[index] = value;", replacing NaN
//  results with zero when SAFE_MATH is enabled.
            auto emit_store = [&] (graph::shared_leaf<float, SAFE_MATH> node,
                                   const std::string &destination) {
                graph::shared_leaf<float, SAFE_MATH> a = node->compile(source_buffer,
                                                                       registers,
                                                                       indices,
                                                                       usage);
                source_buffer << " "
                              << destination
                              << "[index] = ";
                if constexpr (SAFE_MATH) {
                    source_buffer << "isnan(" << registers[a.get()]
                                  << ") ? 0.0 : ";
                }
                source_buffer << registers[a.get()] << ";" << std::endl;
            };

            for (auto &[out, in] : setters) {
                if (!out->is_match(in) &&
                    !out_registers.contains(out.get())) {
                    emit_store(out, jit::to_string('v', in.get()));
                    out_registers.insert(out.get());
                }
            }

            for (auto &out : outputs) {
                if (!graph::variable_cast(out).get() &&
                    !out_registers.contains(out.get())) {
                    emit_store(out, jit::to_string('o', out.get()));
                    out_registers.insert(out.get());
                }
            }

            source_buffer << " }" << std::endl << "}" << std::endl;
        }
627
628//------------------------------------------------------------------------------
633//------------------------------------------------------------------------------
//------------------------------------------------------------------------------
/// @brief Create reduction.
///
/// Emits the max_reduction kernel source. Each thread strides the input by
/// 1024, simd_max reduces within each simdgroup, and simdgroup zero reduces
/// the 32 partial maxima into *result. The 1024 stride and the thread_max[32]
/// size are coupled to the 1024-thread dispatch in create_max_call.
///
/// NOTE(review): the threadgroup_barrier is emitted inside the
/// "if (i < size)" block, so for size < 1024 not all threads reach it —
/// confirm against the dispatch sizes actually used.
///
/// @param[in,out] source_buffer Buffer the source is appended to.
/// @param[in]     size          Number of elements to reduce.
//------------------------------------------------------------------------------
        void create_reduction(std::ostringstream &source_buffer,
                              const size_t size) {
            source_buffer << std::endl
                          << "kernel void max_reduction(" << std::endl
                          << " constant float *input [[buffer(0)]]," << std::endl
                          << " device float *result [[buffer(1)]]," << std::endl
                          << " uint i [[thread_position_in_grid]]," << std::endl
                          << " uint j [[simdgroup_index_in_threadgroup]]," << std::endl
                          << " uint k [[thread_index_in_simdgroup]]) {" << std::endl;
            source_buffer << " if (i < " << size << ") {" << std::endl
                          << " float sub_max = input[i];" << std::endl
                          << " for (size_t index = i + 1024; index < " << size << "; index += 1024) {" << std::endl
                          << " sub_max = max(sub_max, input[index]);" << std::endl
                          << " }" << std::endl;
            source_buffer << " threadgroup float thread_max[32];" << std::endl
                          << " thread_max[j] = simd_max(sub_max);" << std::endl
                          << " threadgroup_barrier(mem_flags::mem_threadgroup);" << std::endl
                          << " if (j == 0) {" << std::endl
                          << " *result = simd_max(thread_max[k]);" << std::endl
                          << " }" << std::endl
                          << " }" << std::endl
                          << "}" << std::endl << std::endl;
        }
657
658//------------------------------------------------------------------------------
662//------------------------------------------------------------------------------
//------------------------------------------------------------------------------
/// @brief Get the buffer for a node.
///
/// Returns a direct pointer into the node's shared-storage buffer. No GPU
/// synchronization is performed; callers needing completed results should
/// call wait() first.
///
/// @param[in] node Node whose buffer is fetched.
/// @returns Pointer to the buffer contents.
//------------------------------------------------------------------------------
        float *get_buffer(graph::shared_leaf<float, SAFE_MATH> &node) {
            return static_cast<float *> ([kernel_arguments[node.get()] contents]);
        }
666 };
667}
668
669#endif /* metal_context_h */
Class representing a generic buffer.
Definition backend.hpp:29
size_t size() const
Get size of the buffer.
Definition backend.hpp:116
T * data()
Get a pointer to the basic memory buffer.
Definition backend.hpp:270
Class representing a metal gpu context.
Definition metal_context.hpp:25
void create_kernel_postfix(std::ostringstream &source_buffer, graph::output_nodes< float, SAFE_MATH > &outputs, graph::map_nodes< float, SAFE_MATH > &setters, graph::shared_random_state< float, SAFE_MATH > state, jit::register_map &registers, jit::register_map &indices, const jit::register_usage &usage)
Create kernel postfix.
Definition metal_context.hpp:580
void create_reduction(std::ostringstream &source_buffer, const size_t size)
Create reduction.
Definition metal_context.hpp:634
static size_t max_concurrency()
Get the maximum number of concurrent instances.
Definition metal_context.hpp:54
MTLCompileOptions * compile_options()
Get the compile options.
Definition metal_context.hpp:354
void create_kernel_prefix(std::ostringstream &source_buffer, const std::string name, graph::input_nodes< float, SAFE_MATH > &inputs, graph::output_nodes< float, SAFE_MATH > &outputs, graph::shared_random_state< float, SAFE_MATH > state, const size_t size, const std::vector< bool > &is_constant, jit::register_map &registers, const jit::register_usage &usage, jit::texture1d_list &textures1d, jit::texture2d_list &textures2d)
Create kernel prefix.
Definition metal_context.hpp:456
void wait()
Hold the current thread until the command buffer has completed.
Definition metal_context.hpp:364
float check_value(const size_t index, const graph::shared_leaf< float, SAFE_MATH > &node)
Check the value.
Definition metal_context.hpp:393
std::function< void(void)> create_kernel_call(const std::string kernel_name, graph::input_nodes< float, SAFE_MATH > inputs, graph::output_nodes< float, SAFE_MATH > outputs, graph::shared_random_state< float, SAFE_MATH > state, const size_t num_rays, const jit::texture1d_list &tex1d_list, const jit::texture2d_list &tex2d_list)
Create a kernel calling function.
Definition metal_context.hpp:111
void compile(const std::string kernel_source, std::vector< std::string > names, const bool add_reduction=false)
Compile the kernels.
Definition metal_context.hpp:81
void print_results(const size_t index, const graph::output_nodes< float, SAFE_MATH > &nodes)
Print out the results.
Definition metal_context.hpp:377
float * get_buffer(graph::shared_leaf< float, SAFE_MATH > &node)
Get the buffer for a node.
Definition metal_context.hpp:663
static std::string device_type()
Device description.
Definition metal_context.hpp:61
void copy_to_device(graph::shared_leaf< float, SAFE_MATH > node, float *source)
Copy buffer contents to the device.
Definition metal_context.hpp:405
metal_context(const size_t index)
Construct a metal context.
Definition metal_context.hpp:70
static constexpr size_t random_state_size
Size of random state needed.
Definition metal_context.hpp:44
int remaining_const_memory
Remaining constant memory in bytes. NOT USED.
Definition metal_context.hpp:47
std::function< float(void)> create_max_call(graph::shared_leaf< float, SAFE_MATH > &argument, std::function< void(void)> run)
Create a max compute kernel calling function.
Definition metal_context.hpp:299
void copy_to_host(graph::shared_leaf< float, SAFE_MATH > node, float *destination)
Copy buffer contents to host.
Definition metal_context.hpp:418
void create_header(std::ostringstream &source_buffer)
Create the source header.
Definition metal_context.hpp:435
Name space for GPU backends.
Definition cpu_context.hpp:51
std::shared_ptr< variable_node< T, SAFE_MATH > > shared_variable
Convenience type alias for shared variable nodes.
Definition node.hpp:1708
std::shared_ptr< random_state_node< T, SAFE_MATH > > shared_random_state
Convenience type alias for shared random state nodes.
Definition random.hpp:263
std::vector< shared_variable< T, SAFE_MATH > > input_nodes
Convenience type alias for a vector of inputs.
Definition node.hpp:1711
shared_variable< T, SAFE_MATH > variable_cast(shared_leaf< T, SAFE_MATH > x)
Cast to a variable node.
Definition node.hpp:1727
std::shared_ptr< leaf_node< T, SAFE_MATH > > shared_leaf
Convenience type alias for shared leaf nodes.
Definition node.hpp:676
std::vector< std::pair< shared_leaf< T, SAFE_MATH >, shared_variable< T, SAFE_MATH > > > map_nodes
Convenience type alias for mapping end codes back to inputs.
Definition node.hpp:1715
std::vector< shared_leaf< T, SAFE_MATH > > output_nodes
Convenience type alias for a vector of output nodes.
Definition node.hpp:691
std::map< void *, size_t > texture1d_list
Type alias for indexing 1D textures.
Definition register.hpp:263
std::map< void *, std::array< size_t, 2 > > texture2d_list
Type alias for indexing 2D textures.
Definition register.hpp:265
std::map< void *, size_t > register_usage
Type alias for counting register usage.
Definition register.hpp:259
std::map< void *, std::string > register_map
Type alias for mapping node pointers to register names.
Definition register.hpp:257
std::string to_string(const char prefix, const NODE *pointer)
Convert a graph::leaf_node pointer to a string.
Definition register.hpp:246
Name space for output files.
Definition output.hpp:16
Random constants and distributions.