 11 #include <unordered_set>

 13 #import <Metal/Metal.h>

 24 template<bool SAFE_MATH=false>

 30         id<MTLCommandQueue> queue;

 32         std::map<graph::leaf_node<float, SAFE_MATH> *, id<MTLBuffer>> kernel_arguments;

 34         std::map<void *, id<MTLTexture>> texture_arguments;

 36         id<MTLCommandBuffer> command_buffer;

 38         id<MTLLibrary> library;

 40         std::map<std::string, std::vector<MTLMutability>> bufferMutability;

 55             return MTLCopyAllDevices().count;

 71         device([MTLCopyAllDevices() objectAtIndex:index]),
 72         queue([device newCommandQueue]) {}
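// ---------------------------------------------------------------------------
// [Editor's example, not part of the listing] A minimal construction sketch,
// assuming the enclosing namespace is gpu (the "Namespace for GPU backends"
// below); the function name is hypothetical. max_concurrency() simply counts
// MTLCopyAllDevices(), so any index below that count selects a distinct GPU.
void example_construct() {
    if (gpu::metal_context<>::max_concurrency() > 0) {
        gpu::metal_context<> context(0); // First Metal device; SAFE_MATH = false.
    }
}
// ---------------------------------------------------------------------------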
 81         void compile(const std::string kernel_source,
 82                      std::vector<std::string> names,
 83                      const bool add_reduction=false) {

 85             library = [device newLibraryWithSource:[NSString stringWithCString:kernel_source.c_str()
 86                                                                        encoding:NSUTF8StringEncoding]

 95             std::cout << "Metal GPU info." << std::endl;
115                                                  const size_t num_rays,

120             id<MTLFunction> function = [library newFunctionWithName:[NSString stringWithCString:kernel_name.c_str()
121                                                                                         encoding:NSUTF8StringEncoding]];

123             MTLComputePipelineDescriptor *compute = [MTLComputePipelineDescriptor new];
124             compute.threadGroupSizeIsMultipleOfThreadExecutionWidth = YES;
125             compute.computeFunction = function;
126             compute.maxTotalThreadsPerThreadgroup = 1024;
127             for (size_t i = 0, ie = bufferMutability[kernel_name].size(); i < ie; i++) {
128                 compute.buffers[i].mutability = bufferMutability[kernel_name][i];

131             id<MTLComputePipelineState> pipeline = [device newComputePipelineStateWithDescriptor:compute
132                                                                                          options:MTLPipelineOptionNone

140             std::vector<id<MTLBuffer>> buffers;

142             const size_t buffer_element_size = sizeof(float);

144                 if (!kernel_arguments.contains(input.get())) {

146                     kernel_arguments[input.get()] = [device newBufferWithBytes:buffer.data()
147                                                                         length:buffer.size()*buffer_element_size
148                                                                        options:MTLResourceStorageModeShared];
149                 buffers.push_back(kernel_arguments[input.get()]);

153                 if (!kernel_arguments.contains(output.get())) {
154                     kernel_arguments[output.get()] = [device newBufferWithLength:num_rays*sizeof(float)
155                                                                          options:MTLResourceStorageModeShared];
156                 buffers.push_back(kernel_arguments[output.get()]);

160             if (!kernel_arguments.contains(state.get())) {
161                 kernel_arguments[state.get()] = [device newBufferWithBytes:state->data()
162                                                                     length:state->get_size_bytes()
163                                                                    options:MTLResourceCPUCacheModeWriteCombined |
164                                                                            MTLResourceStorageModeShared |
165                                                                            MTLResourceHazardTrackingModeUntracked];

167             buffers.push_back(kernel_arguments[state.get()]);
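// ---------------------------------------------------------------------------
// [Editor's example, not part of the listing] Every argument buffer above uses
// MTLResourceStorageModeShared, i.e. one allocation visible to both CPU and
// GPU; that is what later lets check_value() and copy_to_host() read results
// directly from the buffer's contents pointer. A minimal sketch:
void example_shared_buffer(id<MTLDevice> device) {
    float value = 3.0f;
    id<MTLBuffer> buffer = [device newBufferWithBytes:&value
                                               length:sizeof(float)
                                              options:MTLResourceStorageModeShared];
    static_cast<float *>(buffer.contents)[0] = 4.0f; // CPU write, no blit required.
}
// ---------------------------------------------------------------------------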
170             std::vector<id<MTLTexture>> textures;
171             command_buffer = [queue commandBuffer];
172             id<MTLBlitCommandEncoder> encoder = [command_buffer blitCommandEncoder];
173             for (auto &[data, size] : tex1d_list) {
174                 if (!texture_arguments.contains(data)) {
175                     MTLTextureDescriptor *descriptor = [MTLTextureDescriptor new];
176                     descriptor.textureType = MTLTextureType1D;
177                     descriptor.pixelFormat = MTLPixelFormatR32Float;
178                     descriptor.width = size;
179                     descriptor.storageMode = MTLStorageModeManaged;
180                     descriptor.cpuCacheMode = MTLCPUCacheModeWriteCombined;
181                     descriptor.hazardTrackingMode = MTLHazardTrackingModeUntracked;
182                     descriptor.usage = MTLTextureUsageShaderRead;
183                     texture_arguments[data] = [device newTextureWithDescriptor:descriptor];
184                     [texture_arguments[data] replaceRegion:MTLRegionMake1D(0, size)
185                                                mipmapLevel:0
186                                                  withBytes:reinterpret_cast<float *> (data)
187                                                bytesPerRow:4*size];

189                     [encoder optimizeContentsForGPUAccess:texture_arguments[data]];

191                 textures.push_back(texture_arguments[data]);

193             for (auto &[data, size] : tex2d_list) {
194                 if (!texture_arguments.contains(data)) {
195                     MTLTextureDescriptor *descriptor = [MTLTextureDescriptor new];
196                     descriptor.textureType = MTLTextureType2D;
197                     descriptor.pixelFormat = MTLPixelFormatR32Float;
198                     descriptor.width = size[1];
199                     descriptor.height = size[0];
200                     descriptor.storageMode = MTLStorageModeManaged;
201                     descriptor.cpuCacheMode = MTLCPUCacheModeWriteCombined;
202                     descriptor.hazardTrackingMode = MTLHazardTrackingModeUntracked;
203                     descriptor.usage = MTLTextureUsageShaderRead;
204                     texture_arguments[data] = [device newTextureWithDescriptor:descriptor];
205                     [texture_arguments[data] replaceRegion:MTLRegionMake2D(0, 0, size[1], size[0])
206                                                mipmapLevel:0
207                                                  withBytes:reinterpret_cast<float *> (data)
208                                                bytesPerRow:4*size[1]];

210                     [encoder optimizeContentsForGPUAccess:texture_arguments[data]];

212                 textures.push_back(texture_arguments[data]);

214             [encoder endEncoding];
215             [command_buffer commit];
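// [Editor's note] The blit pass above serves the MTLStorageModeManaged
// textures: replaceRegion: fills the CPU-side copy,
// optimizeContentsForGPUAccess: lets Metal re-tile the data into a
// GPU-friendly layout, and committing the blit up front orders the uploads
// before any compute kernel samples the textures.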
217             std::vector<NSUInteger> offsets(buffers.size(), 0);
218             NSRange range = NSMakeRange(0, buffers.size());
219             NSRange tex_range = NSMakeRange(0, textures.size());

221             NSUInteger threads_per_group = pipeline.maxTotalThreadsPerThreadgroup;
222             NSUInteger thread_width = pipeline.threadExecutionWidth;
223             NSUInteger thread_groups = num_rays/threads_per_group + (num_rays%threads_per_group ? 1 : 0);
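// [Editor's note] Worked example of the group-count arithmetic above: with
// num_rays = 2500 and threads_per_group = 1024, 2500/1024 = 2 remainder 452,
// so thread_groups = 3 and 3*1024 = 3072 threads are launched; the generated
// "if (index < size)" guard masks the 572 excess threads.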
226             std::cout << " Kernel name : " << kernel_name << std::endl;
227             std::cout << " Thread execution width : " << thread_width << std::endl;
228             std::cout << " Threads per group : " << threads_per_group << std::endl;
229             std::cout << " Number of groups : " << thread_groups << std::endl;
230             std::cout << " Total problem size : " << threads_per_group*thread_groups << std::endl;
234             return [this, num_rays, pipeline, buffers, offsets, range, tex_range, thread_groups, threads_per_group, textures] () mutable {
235                 command_buffer = [queue commandBuffer];
236                 for (uint32_t i = 0; i < num_rays; i += threads_per_group) {
237                     id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoderWithDispatchType:MTLDispatchTypeSerial];

239                     for (size_t j = 0, je = buffers.size() - 1; j < je; j++) {
240                         offsets[j] = i*sizeof(float);

243                     [encoder setComputePipelineState:pipeline];
244                     [encoder setBuffers:buffers.data()
245                                 offsets:offsets.data()
246                               withRange:range];
247                     [encoder setBytes:&i
248                                length:sizeof(uint32_t)
249                               atIndex:buffers.size()];
250                     [encoder setTextures:textures.data()
251                                withRange:tex_range];

253                     [encoder dispatchThreadgroups:MTLSizeMake(1, 1, 1)
254                         threadsPerThreadgroup:MTLSizeMake(threads_per_group, 1, 1)];
255                     [encoder endEncoding];

258                 [command_buffer commit];

261             return [this, pipeline, buffers, offsets, range, tex_range, thread_groups, threads_per_group, textures] () mutable {
262                 command_buffer = [queue commandBuffer];
263                 id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoderWithDispatchType:MTLDispatchTypeSerial];

265                 [encoder setComputePipelineState:pipeline];
266                 [encoder setBuffers:buffers.data()
267                             offsets:offsets.data()
268                           withRange:range];
269                 [encoder setTextures:textures.data()
270                            withRange:tex_range];

272                 [encoder dispatchThreadgroups:MTLSizeMake(thread_groups, 1, 1)
273                     threadsPerThreadgroup:MTLSizeMake(threads_per_group, 1, 1)];
274                 [encoder endEncoding];

276                 [command_buffer commit];
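// ---------------------------------------------------------------------------
// [Editor's example, not part of the listing] Either branch hands back a
// std::function<void(void)> that re-encodes and commits one dispatch per call;
// the caller invokes it per step and synchronizes only when results are
// needed. The helper name and the gpu:: namespace are assumptions:
void example_step(gpu::metal_context<> &context, std::function<void(void)> run) {
    run();          // Encode one kernel dispatch and commit the command buffer.
    context.wait(); // Block until the committed work has completed.
}
// ---------------------------------------------------------------------------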
289                                                  std::function<void(void)> run) {
290             MTLComputePipelineDescriptor *compute = [MTLComputePipelineDescriptor new];
291             compute.threadGroupSizeIsMultipleOfThreadExecutionWidth = YES;
292             compute.computeFunction = [library newFunctionWithName:@"max_reduction"];
293             compute.maxTotalThreadsPerThreadgroup = 1024;
294             compute.buffers[0].mutability = MTLMutabilityImmutable;

297             id<MTLComputePipelineState> max_state = [device newComputePipelineStateWithDescriptor:compute
298                                                                                           options:MTLPipelineOptionNone

305             id<MTLBuffer> result = [device newBufferWithLength:sizeof(float)
306                                                        options:MTLResourceStorageModeShared];

308             id<MTLBuffer> buffer = kernel_arguments[argument.get()];

310             NSUInteger threads_per_group = max_state.maxTotalThreadsPerThreadgroup;
311             NSUInteger thread_width = max_state.threadExecutionWidth;
313             std::cout << " Kernel name : max_reduction" << std::endl;
314             std::cout << " Thread execution width : " << thread_width << std::endl;
315             std::cout << " Threads per group : " << threads_per_group << std::endl;
316             std::cout << " Number of groups : " << 1 << std::endl;
317             std::cout << " Total problem size : " << threads_per_group*1 << std::endl;
320             return [this, run, buffer, result, max_state] () mutable {

322                 command_buffer = [queue commandBuffer];

324                 id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoderWithDispatchType:MTLDispatchTypeSerial];

326                 [encoder setComputePipelineState:max_state];
327                 [encoder setBuffer:buffer offset:0 atIndex:0];
328                 [encoder setBuffer:result offset:0 atIndex:1];
329                 [encoder dispatchThreadgroups:MTLSizeMake(1, 1, 1)
330                     threadsPerThreadgroup:MTLSizeMake(1024, 1, 1)];
331                 [encoder endEncoding];

333                 [command_buffer commit];
334                 [command_buffer waitUntilCompleted];

336                 return static_cast<float *> (result.contents)[0];
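// ---------------------------------------------------------------------------
// [Editor's example, not part of the listing] The returned functor reduces the
// argument's buffer synchronously and hands back the scalar maximum; the
// helper name is hypothetical:
float example_max(std::function<float(void)> find_max) {
    return find_max(); // Commits the reduction, waits, then reads result.contents.
}
// ---------------------------------------------------------------------------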
344             MTLCompileOptions *options = [MTLCompileOptions new];
345             options.mathMode = MTLMathModeFast;
346             options.mathFloatingPointFunctions = MTLMathFloatingPointFunctionsFast;
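// [Editor's note] mathMode and mathFloatingPointFunctions are the current
// MTLCompileOptions spellings for fast-math compilation; together they
// replace the older boolean fastMathEnabled flag.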
354             command_buffer = [queue commandBuffer];

356             [command_buffer commit];
357             [command_buffer waitUntilCompleted];
369             for (auto &out : nodes) {
370                 std::cout << static_cast<float *> ([kernel_arguments[out.get()] contents])[index] << " ";

372             std::cout << std::endl;
385             return static_cast<float *> ([kernel_arguments[node.get()] contents])[index];
396             const size_t size = [kernel_arguments[node.get()] length];
397             memcpy([kernel_arguments[node.get()] contents],
408                           float *destination) {
409             command_buffer = [queue commandBuffer];

411             [command_buffer commit];
412             [command_buffer waitUntilCompleted];

414             memcpy(destination,
415                    kernel_arguments[node.get()].contents,
416                    kernel_arguments[node.get()].length);
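// ---------------------------------------------------------------------------
// [Editor's example, not part of the listing] A host-side round trip built on
// the two methods above; the vector must hold as many elements as the node's
// device buffer. The function name and gpu:: namespace are assumptions:
void example_round_trip(gpu::metal_context<> &context,
                        graph::shared_leaf<float, false> node,
                        std::vector<float> &host) {
    context.copy_to_device(node, host.data()); // Host array -> MTLBuffer contents.
    context.copy_to_host(node, host.data());   // Waits, then MTLBuffer -> host array.
}
// ---------------------------------------------------------------------------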
425             source_buffer << "#include <metal_stdlib>" << std::endl;
426             source_buffer << "#include <metal_simdgroup>" << std::endl;
427             source_buffer << "using namespace metal;" << std::endl;
446                                   const std::string name,

451                                   const std::vector<bool> &is_constant,

456             source_buffer << std::endl;
457             source_buffer << "kernel void " << name << "(" << std::endl;

459             bufferMutability[name] = std::vector<MTLMutability> ();

461             size_t buffer_count = 0;
462             std::unordered_set<void *> used_args;
463             for (size_t i = 0, ie = inputs.size(); i < ie; i++) {
464                 if (!used_args.contains(inputs[i].get())) {
465                     bufferMutability[name].push_back(is_constant[i] ? MTLMutabilityImmutable : MTLMutabilityMutable);
466                     source_buffer << "    " << (is_constant[i] ? "constant" : "device")

469                                   << " [[buffer(" << buffer_count++ << ")]], // "
470                                   << inputs[i]->get_symbol()
471 #ifndef USE_INPUT_CACHE

473                                   << " used " << usage.at(inputs[i].get())

477                     used_args.insert(inputs[i].get());

480             for (size_t i = 0, ie = outputs.size(); i < ie; i++) {
481                 if (!used_args.contains(outputs[i].get())) {
482                     bufferMutability[name].push_back(MTLMutabilityMutable);
483                     source_buffer << "    device float *"

485                                   << " [[buffer(" << buffer_count++ << ")]],"

487                     used_args.insert(outputs[i].get());

491             bufferMutability[name].push_back(MTLMutabilityMutable);
492             source_buffer << "    device mt_state *"

494                           << " [[buffer(" << buffer_count++ << ")]],"

496                           << "    constant uint32_t &offset [[buffer("
497                           << buffer_count++ << ")]],"

501             for (auto &[key, value] : textures1d) {
502                 source_buffer << "    const texture1d<float, access::read> "

504                               << " [[texture(" << index++ << ")]],"

507             for (auto &[key, value] : textures2d) {
508                 source_buffer << "    const texture2d<float, access::read> "

510                               << " [[texture(" << index++ << ")]],"

514             source_buffer << "    uint thread_index [[thread_index_in_threadgroup]],"

517             source_buffer << "    uint index [[thread_position_in_grid]]) {" << std::endl

520                 source_buffer << "offset + ";

522             source_buffer << "index < " << size << ") {" << std::endl;

524             for (auto &input : inputs) {
525 #ifdef USE_INPUT_CACHE
526                 if (usage.at(input.get())) {

528                     source_buffer << "        const ";
529                     jit::add_type<float> (source_buffer);
530                     source_buffer << " " << registers[input.get()] << " = "

532                                   << "[index]; // " << input->get_symbol()

534                                   << " used " << usage.at(input.get())

539                     registers[input.get()] = jit::to_string('v', input.get()) + "[index]";

543 #ifdef USE_INPUT_CACHE

545             source_buffer << "    device mt_state &" << registers[state.get()]

549                           << " // used " << usage.at(input.get())

553             registers[state.get()] = jit::to_string('s', state.get()) + "[thread_index]";
576             std::unordered_set<void *> out_registers;
577             for (auto &[out, in] : setters) {
578                 if (!out->is_match(in) &&
579                     !out_registers.contains(out.get())) {

587                     if constexpr (SAFE_MATH) {
588                         source_buffer << "isnan(" << registers[a.get()]

591                     source_buffer << registers[a.get()] << ";" << std::endl;
592                     out_registers.insert(out.get());

596             for (auto &out : outputs) {

598                     !out_registers.contains(out.get())) {

605                     if constexpr (SAFE_MATH) {
606                         source_buffer << "isnan(" << registers[a.get()]

609                     source_buffer << registers[a.get()] << ";" << std::endl;
610                     out_registers.insert(out.get());

614             source_buffer << "    }" << std::endl << "}" << std::endl;
625             source_buffer << std::endl;
626             source_buffer << "kernel void max_reduction(" << std::endl;
627             source_buffer << "    constant float *input [[buffer(0)]]," << std::endl;
628             source_buffer << "    device float *result [[buffer(1)]]," << std::endl;
629             source_buffer << "    uint i [[thread_position_in_grid]]," << std::endl;
630             source_buffer << "    uint j [[simdgroup_index_in_threadgroup]]," << std::endl;
631             source_buffer << "    uint k [[thread_index_in_simdgroup]]) {" << std::endl;
632             source_buffer << "    if (i < " << size << ") {" << std::endl;
633             source_buffer << "        float sub_max = input[i];" << std::endl;
634             source_buffer << "        for (size_t index = i + 1024; index < " << size << "; index += 1024) {" << std::endl;
635             source_buffer << "            sub_max = max(sub_max, input[index]);" << std::endl;
636             source_buffer << "        }" << std::endl;
637             source_buffer << "        threadgroup float thread_max[32];" << std::endl;
638             source_buffer << "        thread_max[j] = simd_max(sub_max);" << std::endl;
639             source_buffer << "        threadgroup_barrier(mem_flags::mem_threadgroup);" << std::endl;
640             source_buffer << "        if (j == 0) {" << std::endl;
641             source_buffer << "            *result = simd_max(thread_max[k]);" << std::endl;
642             source_buffer << "        }" << std::endl;
643             source_buffer << "    }" << std::endl;
644             source_buffer << "}" << std::endl << std::endl;
653             return static_cast<float *> ([kernel_arguments[node.get()] contents]);
Class representing a generic buffer.
Definition backend.hpp:29
size_t size() const
Get size of the buffer.
Definition backend.hpp:116
T * data()
Get a pointer to the basic memory buffer.
Definition backend.hpp:270
Class representing a metal gpu context.
Definition metal_context.hpp:25
void create_kernel_postfix(std::ostringstream &source_buffer, graph::output_nodes< float, SAFE_MATH > &outputs, graph::map_nodes< float, SAFE_MATH > &setters, graph::shared_random_state< float, SAFE_MATH > state, jit::register_map &registers, jit::register_map &indices, const jit::register_usage &usage)
Create kernel postfix.
Definition metal_context.hpp:569
void create_reduction(std::ostringstream &source_buffer, const size_t size)
Create reduction.
Definition metal_context.hpp:623
static size_t max_concurrency()
Get the maximum number of concurrent instances.
Definition metal_context.hpp:54
MTLCompileOptions * compile_options()
Get the compile options.
Definition metal_context.hpp:343
void create_kernel_prefix(std::ostringstream &source_buffer, const std::string name, graph::input_nodes< float, SAFE_MATH > &inputs, graph::output_nodes< float, SAFE_MATH > &outputs, graph::shared_random_state< float, SAFE_MATH > state, const size_t size, const std::vector< bool > &is_constant, jit::register_map &registers, const jit::register_usage &usage, jit::texture1d_list &textures1d, jit::texture2d_list &textures2d)
Create kernel prefix.
Definition metal_context.hpp:445
void wait()
Hold the current thread until the command buffer has completed.
Definition metal_context.hpp:353
float check_value(const size_t index, const graph::shared_leaf< float, SAFE_MATH > &node)
Check the value.
Definition metal_context.hpp:382
std::function< void(void)> create_kernel_call(const std::string kernel_name, graph::input_nodes< float, SAFE_MATH > inputs, graph::output_nodes< float, SAFE_MATH > outputs, graph::shared_random_state< float, SAFE_MATH > state, const size_t num_rays, const jit::texture1d_list &tex1d_list, const jit::texture2d_list &tex2d_list)
Create a kernel calling function.
Definition metal_context.hpp:111
void compile(const std::string kernel_source, std::vector< std::string > names, const bool add_reduction=false)
Compile the kernels.
Definition metal_context.hpp:81
void print_results(const size_t index, const graph::output_nodes< float, SAFE_MATH > &nodes)
Print out the results.
Definition metal_context.hpp:366
float * get_buffer(graph::shared_leaf< float, SAFE_MATH > &node)
Get the buffer for a node.
Definition metal_context.hpp:652
static std::string device_type()
Device description.
Definition metal_context.hpp:61
void copy_to_device(graph::shared_leaf< float, SAFE_MATH > node, float *source)
Copy buffer contents to the device.
Definition metal_context.hpp:394
metal_context(const size_t index)
Construct a metal context.
Definition metal_context.hpp:70
static constexpr size_t random_state_size
Size of random state needed.
Definition metal_context.hpp:44
int remaining_const_memory
Remaining constant memory in bytes. NOT USED.
Definition metal_context.hpp:47
std::function< float(void)> create_max_call(graph::shared_leaf< float, SAFE_MATH > &argument, std::function< void(void)> run)
Create a max compute kernel calling function.
Definition metal_context.hpp:288
void copy_to_host(graph::shared_leaf< float, SAFE_MATH > node, float *destination)
Copy buffer contents to host.
Definition metal_context.hpp:407
void create_header(std::ostringstream &source_buffer)
Create the source header.
Definition metal_context.hpp:424
Namespace for GPU backends.
Definition cpu_context.hpp:51
std::shared_ptr< variable_node< T, SAFE_MATH > > shared_variable
Convenience type alias for shared variable nodes.
Definition node.hpp:1727
std::shared_ptr< random_state_node< T, SAFE_MATH > > shared_random_state
Convenience type alias for shared random state nodes.
Definition random.hpp:272
std::vector< shared_variable< T, SAFE_MATH > > input_nodes
Convenience type alias for a vector of inputs.
Definition node.hpp:1730
shared_variable< T, SAFE_MATH > variable_cast(shared_leaf< T, SAFE_MATH > x)
Cast to a variable node.
Definition node.hpp:1746
std::shared_ptr< leaf_node< T, SAFE_MATH > > shared_leaf
Convenience type alias for shared leaf nodes.
Definition node.hpp:673
std::vector< std::pair< shared_leaf< T, SAFE_MATH >, shared_variable< T, SAFE_MATH > > > map_nodes
Convenience type alias for mapping end codes back to inputs.
Definition node.hpp:1734
std::vector< shared_leaf< T, SAFE_MATH > > output_nodes
Convenience type alias for a vector of output nodes.
Definition node.hpp:688
std::map< void *, size_t > texture1d_list
Type alias for indexing 1D textures.
Definition register.hpp:262
std::map< void *, std::array< size_t, 2 > > texture2d_list
Type alias for indexing 2D textures.
Definition register.hpp:264
std::map< void *, size_t > register_usage
Type alias for counting register usage.
Definition register.hpp:258
std::map< void *, std::string > register_map
Type alias for mapping node pointers to register names.
Definition register.hpp:256
std::string to_string(const char prefix, const NODE *pointer)
Convert a graph::leaf_node pointer to a string.
Definition register.hpp:245
Namespace for output files.
Definition output.hpp:16
Random constants and distributions.