// ---------------------------------------------------------------------------
// NOTE(review): this region is an extraction fragment.  The integers baked
// into the text (115, 120, 121, ...) are line numbers from the original file,
// and the jumps between them (e.g. 128 -> 131, 132 -> 140) show that
// statements are missing from this view (closing braces, error handling,
// branch headers).  Code is left byte-for-byte untouched; the comments below
// describe only what the visible lines demonstrate.
// ---------------------------------------------------------------------------
// Tail of the parameter list of what appears to be a kernel-dispatch builder:
// num_rays sizes the per-ray output buffers and the dispatch below.
115 const size_t num_rays,
// Look up the compiled MTLFunction for kernel_name in the MTLLibrary.
120 id<MTLFunction> function = [library newFunctionWithName:[NSString stringWithCString:kernel_name.c_str()
121 encoding:NSUTF8StringEncoding]];
// Build the compute pipeline descriptor: the threadgroup size is promised to
// be a multiple of the execution width, capped at 1024 threads per group.
123 MTLComputePipelineDescriptor *compute = [MTLComputePipelineDescriptor
new];
124 compute.threadGroupSizeIsMultipleOfThreadExecutionWidth = YES;
125 compute.computeFunction = function;
126 compute.maxTotalThreadsPerThreadgroup = 1024;
// Apply the per-kernel buffer mutability list recorded when the kernel
// source was generated (see the source-generation fragment further down).
127 for (
size_t i = 0, ie = bufferMutability[kernel_name].size(); i < ie; i++) {
128 compute.buffers[i].mutability = bufferMutability[kernel_name][i];
// NOTE(review): "pipline" is a typo for "pipeline".  It is used consistently
// throughout this fragment; renaming it safely would require the lines that
// are missing from this view, so it is only flagged here.
131 id<MTLComputePipelineState> pipline = [device newComputePipelineStateWithDescriptor:compute
132 options:MTLPipelineOptionNone
// Collect the argument buffers to bind, deduplicated through needed_buffers
// so a node appearing more than once is only bound once.
140 std::vector<id<MTLBuffer>> buffers;
141 std::set<graph::leaf_node<float, SAFE_MATH> *> needed_buffers;
143 const size_t buffer_element_size =
sizeof(float);
// Inputs: the first time a node is seen, create a shared-storage buffer from
// its host-side data; afterwards reuse the cached kernel_arguments entry.
145 if (!kernel_arguments.contains(input.get())) {
147 kernel_arguments[input.get()] = [device newBufferWithBytes:buffer.
data()
148 length:buffer.
size()*buffer_element_size
149 options:MTLResourceStorageModeShared];
150 buffers.push_back(kernel_arguments[input.get()]);
151 needed_buffers.insert(input.get());
153 if (!needed_buffers.contains(input.get())) {
154 buffers.push_back(kernel_arguments[input.get()]);
155 needed_buffers.insert(input.get());
// Outputs: uninitialized shared buffers sized one float per ray.
159 if (!kernel_arguments.contains(
output.get())) {
160 kernel_arguments[
output.get()] = [device newBufferWithLength:num_rays*
sizeof(float)
161 options:MTLResourceStorageModeShared];
162 buffers.push_back(kernel_arguments[
output.get()]);
163 needed_buffers.insert(
output.get());
165 if (!needed_buffers.contains(
output.get())) {
166 buffers.push_back(kernel_arguments[
output.get()]);
167 needed_buffers.insert(
output.get());
// State buffer: CPU write-combined, shared storage, hazard tracking disabled
// — presumably the kernels manage any hazards themselves; confirm against
// the shader source before relying on this.
171 if (!kernel_arguments.contains(state.get())) {
172 kernel_arguments[state.get()] = [device newBufferWithBytes:state->data()
173 length:state->get_size_bytes()
174 options:MTLResourceCPUCacheModeWriteCombined |
175 MTLResourceStorageModeShared |
176 MTLResourceHazardTrackingModeUntracked];
178 buffers.push_back(kernel_arguments[state.get()]);
// 1D textures: R32Float, managed storage, read-only in the shader.  After
// the CPU-side upload via replaceRegion, a blit pass optimizes the texture
// contents for GPU access.
181 std::vector<id<MTLTexture>> textures;
182 command_buffer = [queue commandBuffer];
183 id<MTLBlitCommandEncoder> encoder = [command_buffer blitCommandEncoder];
184 for (
auto &[data, size] : tex1d_list) {
185 if (!texture_arguments.contains(data)) {
186 MTLTextureDescriptor *descriptor = [MTLTextureDescriptor
new];
187 descriptor.textureType = MTLTextureType1D;
188 descriptor.pixelFormat = MTLPixelFormatR32Float;
189 descriptor.width = size;
190 descriptor.storageMode = MTLStorageModeManaged;
191 descriptor.cpuCacheMode = MTLCPUCacheModeWriteCombined;
192 descriptor.hazardTrackingMode = MTLHazardTrackingModeUntracked;
193 descriptor.usage = MTLTextureUsageShaderRead;
194 texture_arguments[data] = [device newTextureWithDescriptor:descriptor];
195 [texture_arguments[data] replaceRegion:MTLRegionMake1D(0, size)
197 withBytes:
reinterpret_cast<float *
> (data)
200 [encoder optimizeContentsForGPUAccess:texture_arguments[data]];
202 textures.push_back(texture_arguments[data]);
// 2D textures: size[0] is the height (rows) and size[1] the width (columns);
// the row stride is 4 bytes per R32Float texel times the width.
204 for (
auto &[data, size] : tex2d_list) {
205 if (!texture_arguments.contains(data)) {
206 MTLTextureDescriptor *descriptor = [MTLTextureDescriptor
new];
207 descriptor.textureType = MTLTextureType2D;
208 descriptor.pixelFormat = MTLPixelFormatR32Float;
209 descriptor.width = size[1];
210 descriptor.height = size[0];
211 descriptor.storageMode = MTLStorageModeManaged;
212 descriptor.cpuCacheMode = MTLCPUCacheModeWriteCombined;
213 descriptor.hazardTrackingMode = MTLHazardTrackingModeUntracked;
214 descriptor.usage = MTLTextureUsageShaderRead;
215 texture_arguments[data] = [device newTextureWithDescriptor:descriptor];
216 [texture_arguments[data] replaceRegion:MTLRegionMake2D(0, 0, size[1], size[0])
218 withBytes:
reinterpret_cast<float *
> (data)
219 bytesPerRow:4*size[1]];
221 [encoder optimizeContentsForGPUAccess:texture_arguments[data]];
223 textures.push_back(texture_arguments[data]);
// Submit the upload/optimize blit pass.  No waitUntilCompleted here —
// presumably later command buffers on the same queue serialize behind it;
// confirm before reordering.
225 [encoder endEncoding];
226 [command_buffer commit];
// Precompute binding ranges and per-buffer offsets (all initially zero).
228 std::vector<NSUInteger> offsets(buffers.size(), 0);
229 NSRange range = NSMakeRange(0, buffers.size());
230 NSRange tex_range = NSMakeRange(0, textures.size());
// Ceiling division: enough threadgroups to cover num_rays threads.
232 NSUInteger threads_per_group = pipline.maxTotalThreadsPerThreadgroup;
233 NSUInteger thread_width = pipline.threadExecutionWidth;
234 NSUInteger thread_groups = num_rays/threads_per_group + (num_rays%threads_per_group ? 1 : 0);
// Diagnostic dump of the dispatch configuration.
237 std::cout <<
" Kernel name : " << kernel_name << std::endl;
238 std::cout <<
" Thread execution width : " << thread_width << std::endl;
239 std::cout <<
" Threads per group : " << threads_per_group << std::endl;
240 std::cout <<
" Number of groups : " << thread_groups << std::endl;
241 std::cout <<
" Total problem size : " << threads_per_group*thread_groups << std::endl;
// Chunked runner lambda: dispatches ONE threadgroup per step, sliding the
// buffer byte offsets by i*sizeof(float) each iteration.  Note the inner
// loop bound je = buffers.size() - 1: the last bound buffer (the state
// buffer pushed above) keeps offset 0 — presumably state is indexed by
// thread, not by ray; confirm against the generated kernel source.
// The branch that selects this runner over the one below sits in the
// missing original lines (242-244).
245 return [
this, num_rays, pipline, buffers, offsets, range, tex_range, thread_groups, threads_per_group, textures] ()
mutable {
246 command_buffer = [queue commandBuffer];
247 for (uint32_t i = 0; i < num_rays; i += threads_per_group) {
248 id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoderWithDispatchType:MTLDispatchTypeSerial];
250 for (
size_t j = 0, je = buffers.size() - 1; j < je; j++) {
251 offsets[j] = i*
sizeof(float);
254 [encoder setComputePipelineState:pipline];
255 [encoder setBuffers:buffers.data()
256 offsets:offsets.data()
// The chunk offset i is also passed by value as the buffer slot just past
// the bound buffers (atIndex:buffers.size()); the setBytes: keyword itself
// is in the missing lines 257-258.
259 length:
sizeof(uint32_t)
260 atIndex:buffers.size()];
261 [encoder setTextures:textures.data()
262 withRange:tex_range];
264 [encoder dispatchThreadgroups:MTLSizeMake(1, 1, 1)
265 threadsPerThreadgroup:MTLSizeMake(threads_per_group, 1, 1)];
266 [encoder endEncoding];
// Commit without waiting — completion is presumably awaited by a later
// waitUntilCompleted elsewhere (the max-reduction runner below does wait).
269 [command_buffer commit];
// Full-grid runner lambda: a single dispatch of thread_groups groups with
// all buffer offsets left at zero.
272 return [
this, pipline, buffers, offsets, range, tex_range, thread_groups, threads_per_group, textures] ()
mutable {
273 command_buffer = [queue commandBuffer];
274 id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoderWithDispatchType:MTLDispatchTypeSerial];
276 [encoder setComputePipelineState:pipline];
277 [encoder setBuffers:buffers.data()
278 offsets:offsets.data()
280 [encoder setTextures:textures.data()
281 withRange:tex_range];
283 [encoder dispatchThreadgroups:MTLSizeMake(thread_groups, 1, 1)
284 threadsPerThreadgroup:MTLSizeMake(threads_per_group, 1, 1)];
285 [encoder endEncoding];
287 [command_buffer commit];
// ---------------------------------------------------------------------------
// Fragment of a builder for a "max_reduction" kernel call.  The function's
// name and return type live in the missing original lines (288-299); only
// the final parameter — a void() callback named run — is visible here.
// ---------------------------------------------------------------------------
300 std::function<
void(
void)> run) {
// Pipeline for the library's "max_reduction" function.  Buffer 0 (the data
// being reduced) is marked immutable, matching its read-only use below.
301 MTLComputePipelineDescriptor *compute = [MTLComputePipelineDescriptor
new];
302 compute.threadGroupSizeIsMultipleOfThreadExecutionWidth = YES;
303 compute.computeFunction = [library newFunctionWithName:
@"max_reduction"];
304 compute.maxTotalThreadsPerThreadgroup = 1024;
305 compute.buffers[0].mutability = MTLMutabilityImmutable;
308 id<MTLComputePipelineState> max_state = [device newComputePipelineStateWithDescriptor:compute
309 options:MTLPipelineOptionNone
// Single-float shared result buffer, read back on the CPU after the wait.
316 id<MTLBuffer> result = [device newBufferWithLength:
sizeof(float)
317 options:MTLResourceStorageModeShared];
// The reduction input reuses the cached buffer for the argument node.
319 id<MTLBuffer> buffer = kernel_arguments[argument.get()];
321 NSUInteger threads_per_group = max_state.maxTotalThreadsPerThreadgroup;
322 NSUInteger thread_width = max_state.threadExecutionWidth;
// Diagnostic dump; a single threadgroup is used (Number of groups : 1).
324 std::cout <<
" Kernel name : max_reduction" << std::endl;
325 std::cout <<
" Thread execution width : " << thread_width << std::endl;
326 std::cout <<
" Threads per group : " << threads_per_group << std::endl;
327 std::cout <<
" Number of groups : " << 1 << std::endl;
328 std::cout <<
" Total problem size : " << threads_per_group*1 << std::endl;
// Runner lambda: encodes one serial dispatch, blocks until the GPU is done,
// then returns the scalar maximum from the shared result buffer.
// NOTE(review): the captured run callback is not invoked in any visible
// line — presumably the missing original line 334 calls run() before
// encoding; confirm.
// NOTE(review): the dispatch hard-codes 1024 threads instead of reusing
// threads_per_group queried above (which was also set to 1024 on the
// descriptor) — consider unifying; confirm the kernel assumes exactly 1024.
331 return [
this, run, buffer, result, max_state] ()
mutable {
333 command_buffer = [queue commandBuffer];
335 id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoderWithDispatchType:MTLDispatchTypeSerial];
337 [encoder setComputePipelineState:max_state];
338 [encoder setBuffer:buffer offset:0 atIndex:0];
339 [encoder setBuffer:result offset:0 atIndex:1];
340 [encoder dispatchThreadgroups:MTLSizeMake(1, 1, 1)
341 threadsPerThreadgroup:MTLSizeMake(1024, 1, 1)];
342 [encoder endEncoding];
344 [command_buffer commit];
345 [command_buffer waitUntilCompleted];
347 return static_cast<float *
> (result.contents)[0];
// ---------------------------------------------------------------------------
// Fragment of the Metal Shading Language source generator.  It streams a
// "kernel void <name>(...)" signature into source_buffer, assigning
// sequential [[buffer(n)]] / [[texture(n)]] binding slots, and records each
// buffer's MTLMutability into bufferMutability[name] for the pipeline
// builder above.  The function's signature starts, and its body continues,
// in original lines outside this view.
// ---------------------------------------------------------------------------
457 const std::string name,
462 const std::vector<bool> &is_constant,
467 source_buffer << std::endl;
468 source_buffer <<
"kernel void " << name <<
"(" << std::endl;
// Reset the mutability record for this kernel; buffer_count numbers the
// binding slots, and used_args deduplicates nodes bound more than once.
470 bufferMutability[name] = std::vector<MTLMutability> ();
472 size_t buffer_count = 0;
473 std::unordered_set<void *> used_args;
474 for (
size_t i = 0, ie = inputs.size(); i < ie; i++) {
475 if (!used_args.contains(inputs[i].get())) {
// NOTE(review): this ternary looks INVERTED.  The same is_constant[i]
// condition emits the read-only "constant" address-space qualifier two
// lines below, yet records MTLMutabilityMutable here; compare the
// max_reduction pipeline, which marks its read-only input buffer 0
// MTLMutabilityImmutable.  Expected:
//   is_constant[i] ? MTLMutabilityImmutable : MTLMutabilityMutable
// Not changed here because the surrounding missing lines prevent safe
// verification — confirm against the pipeline builder and fix.
476 bufferMutability[name].push_back(is_constant[i] ? MTLMutabilityMutable : MTLMutabilityImmutable);
477 source_buffer <<
" " << (is_constant[i] ?
"constant" :
"device")
480 <<
" [[buffer(" << buffer_count++ <<
")]], // "
481 << inputs[i]->get_symbol()
// Without the input cache, the trailing comment also records usage counts.
482#ifndef USE_INPUT_CACHE
484 <<
" used " << usage.at(inputs[i].get())
488 used_args.insert(inputs[i].get());
// Output buffers: always writable device float pointers.
491 for (
size_t i = 0, ie = outputs.size(); i < ie; i++) {
492 if (!used_args.contains(outputs[i].get())) {
493 bufferMutability[name].push_back(MTLMutabilityMutable);
494 source_buffer <<
" device float *"
496 <<
" [[buffer(" << buffer_count++ <<
")]],"
498 used_args.insert(outputs[i].get());
// State buffer (mt_state) plus a constant uint32_t chunk offset — the
// counterpart of the setBytes/atIndex:buffers.size() binding in the chunked
// runner above.
502 bufferMutability[name].push_back(MTLMutabilityMutable);
503 source_buffer <<
" device mt_state *"
505 <<
" [[buffer(" << buffer_count++ <<
")]],"
507 <<
" constant uint32_t &offset [[buffer("
508 << buffer_count++ <<
")]],"
// Read-only texture parameters, numbered by a separate texture index.
512 for (
auto &[key, value] : textures1d) {
513 source_buffer <<
" const texture1d<float, access::read> "
515 <<
" [[texture(" << index++ <<
")]],"
518 for (
auto &[key, value] : textures2d) {
519 source_buffer <<
" const texture2d<float, access::read> "
521 <<
" [[texture(" << index++ <<
")]],"
// Built-in thread identifiers close the parameter list.
525 source_buffer <<
" uint thread_index [[thread_index_in_threadgroup]],"
528 source_buffer <<
" uint index [[thread_position_in_grid]]) {" << std::endl
// Bounds guard: "if (offset + index < size) {" — the "offset + " prefix is
// emitted conditionally (the branch header sits in missing lines).
531 source_buffer <<
"offset + ";
533 source_buffer <<
"index < " << size <<
") {" << std::endl;
// Input access: with USE_INPUT_CACHE, each used input is loaded once into a
// const local register; otherwise the register string is the raw indexed
// buffer expression "v<ptr>[index]".
535 for (
auto &input : inputs) {
536#ifdef USE_INPUT_CACHE
537 if (usage.at(input.get())) {
539 source_buffer <<
" const ";
540 jit::add_type<float> (source_buffer);
541 source_buffer <<
" " << registers[input.get()] <<
" = "
543 <<
"[index]; // " << input->get_symbol()
545 <<
" used " << usage.at(input.get())
550 registers[input.get()] =
jit::to_string(
'v', input.get()) +
"[index]";
// State register: indexed per-threadgroup thread, "s<ptr>[thread_index]".
554#ifdef USE_INPUT_CACHE
556 source_buffer <<
" device mt_state &" << registers[state.get()]
// NOTE(review): this comment line keys usage.at on input.get() while
// documenting the STATE buffer — input appears to belong to the loop above;
// looks like a copy-paste slip (expected state.get() or removal).  Confirm
// that input is even in scope here before fixing.
560 <<
" // used " << usage.at(input.get())
564 registers[state.get()] =
jit::to_string(
's', state.get()) +
"[thread_index]";