Graph Framework
Loading...
Searching...
No Matches
cpu_context.hpp
Go to the documentation of this file.
1//------------------------------------------------------------------------------
6//------------------------------------------------------------------------------
7
8#ifndef cpu_context_h
9#define cpu_context_h
10
11#include <fstream>
12#include <cstdlib>
13#include <cstring>
14#include <thread>
15#include <unordered_set>
16
17// Clang headers will define IBAction and IBOutlet, so undefine them here.
18#undef IBAction
19#undef IBOutlet
20#include "llvm/Support/VirtualFileSystem.h"
21#include "clang/Frontend/TextDiagnosticPrinter.h"
22#include "clang/Frontend/CompilerInvocation.h"
23#include "clang/Frontend/CompilerInstance.h"
24#include "clang/Basic/TargetInfo.h"
25#include "clang/CodeGen/CodeGenAction.h"
26#include "clang/Lex/PreprocessorOptions.h"
27#include "llvm/Support/TargetSelect.h"
28#include "llvm/ExecutionEngine/Orc/LLJIT.h"
29#ifndef NDEBUG
30#include "llvm/ExecutionEngine/Orc/Debugging/DebuggerSupport.h"
31#include "llvm/ExecutionEngine/Orc/TargetProcess/JITLoaderGDB.h"
32#endif
33#include "llvm/Support/raw_ostream.h"
34#include "llvm/ADT/IntrusiveRefCntPtr.h"
35#include "llvm/ADT/SmallVector.h"
36#include "llvm/TargetParser/Host.h"
37#include "llvm/ExecutionEngine/Orc/ThreadSafeModule.h"
38
39#include "random.hpp"
40
41#ifndef NDEBUG
42//------------------------------------------------------------------------------
44//------------------------------------------------------------------------------
45LLVM_ATTRIBUTE_USED void linkComponents() {
46 llvm::errs() << (void *)&llvm_orc_registerJITLoaderGDBWrapper
47 << (void *)&llvm_orc_registerJITLoaderGDBAllocAction;
48}
49#endif
50
51namespace gpu {
52//------------------------------------------------------------------------------
61//------------------------------------------------------------------------------
62 llvm::SmallVector<const char *, 8> split_string(char *string) {
63 llvm::SmallVector<const char *, 8> args = {string};
64
65 while (*(++string) != '\0') {
66 if (*string == ' ') {
67 *string = '\0';
68 args.push_back(++string);
69 }
70 }
71
72 return args;
73 }
74
75//------------------------------------------------------------------------------
80//------------------------------------------------------------------------------
81 template<jit::float_scalar T, bool SAFE_MATH=false>
83 private:
85 std::unique_ptr<llvm::orc::LLJIT> jit;
87 std::map<graph::leaf_node<T, SAFE_MATH> *, std::vector<T>> kernel_arguments;
89 std::map<graph::leaf_node<T, SAFE_MATH> *, std::vector<T>> host_buffers;
91 std::map<graph::leaf_node<T, SAFE_MATH> *, size_t> arg_index;
92
93 public:
95 constexpr static size_t random_state_size = 1;
96
99
100//------------------------------------------------------------------------------
104//------------------------------------------------------------------------------
///  @brief Get the maximum number of concurrent instances.
///
///  @returns The hardware thread count, or one when the runtime cannot
///           determine it.
        static size_t max_concurrency() {
//  hardware_concurrency returns zero when the value is not computable. Report
//  at least one so callers never size their work for zero instances.
            const unsigned int hardware_threads = std::thread::hardware_concurrency();
            return hardware_threads == 0 ? static_cast<size_t> (1)
                                         : static_cast<size_t> (hardware_threads);
        }
108
109//------------------------------------------------------------------------------
111//------------------------------------------------------------------------------
///  @brief Device description.
///
///  @returns The string "CPU".
        static std::string device_type() {
            return std::string{"CPU"};
        }
115
116//------------------------------------------------------------------------------
120//------------------------------------------------------------------------------
121 cpu_context(const size_t index) {
122 llvm::InitializeNativeTarget();
123 llvm::InitializeNativeTargetAsmPrinter();
124 }
125
126//------------------------------------------------------------------------------
132//------------------------------------------------------------------------------
133 void compile(const std::string kernel_source,
134 std::vector<std::string> names,
135 const bool add_reduction=false) {
136 std::ostringstream temp_stream;
137 temp_stream << reinterpret_cast<size_t> (this);
138 const std::string thread_id = temp_stream.str();
139
140 temp_stream.str(std::string());
141 temp_stream.clear();
142
143 temp_stream << "temp_" << thread_id << ".cpp";
144
145 const std::string filename = temp_stream.str();
146
147 if (jit::verbose) {
148 std::cout << "CPU info." << std::endl;
149 std::cout << " Command Line : " << std::endl;
150 }
151
152 char arg_string[] = CXX_ARGS;
153 llvm::SmallVector<const char *, 8> args = split_string(arg_string);
154 args.push_back(filename.c_str());
155#ifdef NDEBUG
156 args.push_back("-ffast-math");
157 args.push_back("-O3");
158#else
159 args.push_back("-debug-info-kind=standalone");
160#endif
161 if (jit::verbose) {
162 for (auto &arg : args) {
163 std::cout << " " << arg << std::endl;
164 }
165 }
166
167 clang::DiagnosticOptions diagnostic_options;
168 auto diagnostic_printer = std::make_unique<clang::TextDiagnosticPrinter> (llvm::errs(),
169 diagnostic_options);
170
171 auto diagnostic_ids = llvm::makeIntrusiveRefCnt<clang::DiagnosticIDs> ();
172 clang::DiagnosticsEngine diagnostic_engine(diagnostic_ids,
173 diagnostic_options,
174 diagnostic_printer.release());
175
176 auto invocation = std::make_shared<clang::CompilerInvocation> ();
177 clang::CompilerInvocation::CreateFromArgs(*(invocation.get()), args,
178 diagnostic_engine);
179
180 llvm::StringRef source_code_data(kernel_source);
181 auto buffer = llvm::MemoryBuffer::getMemBuffer(source_code_data);
182 invocation->getPreprocessorOpts().addRemappedFile(filename.c_str(),
183 buffer.release());
184
185 clang::CompilerInstance clang(invocation);
186 clang.createDiagnostics();
187
188 clang::TargetOptions target_options;
189 target_options.Triple = llvm::sys::getProcessTriple();
190 auto *target_info = clang::TargetInfo::CreateTargetInfo(diagnostic_engine,
191 target_options);
192 clang.setTarget(target_info);
193
194 clang::EmitLLVMOnlyAction action;
195 clang.ExecuteAction(action);
196
197 auto ir_module = action.takeModule();
198 auto context = std::unique_ptr<llvm::LLVMContext> (action.takeLLVMContext());
199
200 auto jit_try = llvm::orc::LLJITBuilder()
201#ifndef NDEBUG
202 .setPrePlatformSetup([](llvm::orc::LLJIT &J) {
203 return llvm::orc::enableDebuggerSupport(J);
204 })
205#endif
206 .create();
207 if (auto jiterror = jit_try.takeError()) {
208 std::cerr << "Failed to build JIT : " << toString(std::move(jiterror)) << std::endl;
209 exit(-1);
210 }
211 jit = std::move(jit_try.get());
212
213 auto error = jit->addIRModule(llvm::orc::ThreadSafeModule(std::move(ir_module),
214 llvm::orc::ThreadSafeContext(std::move(context))));
215
216#ifdef MACOS_LIB_RT
217 error = jit->linkStaticLibraryInto(jit->getMainJITDylib(), MACOS_LIB_RT);
218#endif
219 }
220
221//------------------------------------------------------------------------------
232//------------------------------------------------------------------------------
233 std::function<void(void)> create_kernel_call(const std::string kernel_name,
237 const size_t num_rays,
238 const jit::texture1d_list &tex1d_list,
239 const jit::texture2d_list &tex2d_list) {
240 auto entry = std::move(jit->lookup(kernel_name)).get();
241
242 std::map<size_t, T *> buffers;
243
244 for (auto &input : inputs) {
245 if (!kernel_arguments.contains(input.get())) {
246 backend::buffer<T> buffer = input->evaluate();
247 std::vector<T> arg(buffer.size());
248 memcpy(arg.data(), buffer.data(), buffer.size()*sizeof(T));
249 kernel_arguments[input.get()] = arg;
250 }
251 buffers[reinterpret_cast<size_t> (input.get())] = kernel_arguments[input.get()].data();
252 }
253 for (auto &output : outputs) {
254 if (!kernel_arguments.contains(output.get())) {
255 std::vector<T> arg(num_rays);
256 kernel_arguments[output.get()] = arg;
257 }
258 buffers[reinterpret_cast<size_t> (output.get())] = kernel_arguments[output.get()].data();
259 }
260
261 if (state.get()) {
262 auto kernel = entry.toPtr<void(*)(std::map<size_t, T *> &, typename graph::random_state_node<T, SAFE_MATH>::mt_state *)> ();
263
264 if (!kernel) {
265 std::cerr << "Failed to load function. " << kernel_name
266 << std::endl;
267 exit(-1);
268 }
269
270 if (jit::verbose) {
271 std::cout << " Function pointer: "
272 << reinterpret_cast<size_t> (kernel)
273 << std::endl;
274 }
275
276 return [kernel, buffers, state] () mutable {
277 kernel(buffers, state->data());
278 };
279 } else {
280 auto kernel = entry.toPtr<void(*)(std::map<size_t, T *> &)> ();
281
282 if (!kernel) {
283 std::cerr << "Failed to load function. " << kernel_name
284 << std::endl;
285 exit(-1);
286 }
287
288 if (jit::verbose) {
289 std::cout << " Function pointer: "
290 << reinterpret_cast<size_t> (kernel)
291 << std::endl;
292 }
293
294 return [kernel, buffers] () mutable {
295 kernel(buffers);
296 };
297 }
298 }
299
300//------------------------------------------------------------------------------
305//------------------------------------------------------------------------------
306 std::function<T(void)> create_max_call(graph::shared_leaf<T, SAFE_MATH> &argument,
307 std::function<void(void)> run) {
308 auto begin = kernel_arguments[argument.get()].cbegin();
309 auto end = kernel_arguments[argument.get()].cend();
310
311 return [run, begin, end] () mutable {
312 run();
313 if constexpr (jit::complex_scalar<T>) {
314 return *std::max_element(begin, end,
315 [] (const T a, const T b) {
316 return std::abs(a) < std::abs(b);
317 });
318 } else {
319 return *std::max_element(begin, end);
320 }
321 };
322 }
323
324//------------------------------------------------------------------------------
329//------------------------------------------------------------------------------
330 void wait() {
331 for (auto &item : host_buffers) {
332 memcpy(item.second.data(),
333 kernel_arguments[item.first].data(),
334 sizeof(T)*kernel_arguments[item.first].size());
335 }
336 }
337
338//------------------------------------------------------------------------------
343//------------------------------------------------------------------------------
344 void print_results(const size_t index,
346 for (auto &out : nodes) {
347 const T temp = kernel_arguments[out.get()][index];
348 if constexpr (jit::complex_scalar<T>) {
349 std::cout << std::real(temp) << " " << std::imag(temp) << " ";
350 } else {
351 std::cout << temp << " ";
352 }
353 }
354 std::cout << std::endl;
355 }
356
357//------------------------------------------------------------------------------
363//------------------------------------------------------------------------------
364 T check_value(const size_t index,
366 return kernel_arguments[node.get()][index];
367 }
368
369//------------------------------------------------------------------------------
374//------------------------------------------------------------------------------
376 T *source) {
377 memcpy(kernel_arguments[node.get()].data(),
378 source,
379 sizeof(T)*kernel_arguments[node.get()].size());
380 }
381
382//------------------------------------------------------------------------------
387//------------------------------------------------------------------------------
389 T *destination) {
390 memcpy(destination,
391 kernel_arguments[node.get()].data(),
392 sizeof(T)*kernel_arguments[node.get()].size());
393 }
394
395//------------------------------------------------------------------------------
399//------------------------------------------------------------------------------
400 void create_header(std::ostringstream &source_buffer) {
401 source_buffer << "#include <map>" << std::endl
402 << "#include <array>" << std::endl
403 << "#include <cstdint>" << std::endl;
405 source_buffer << "#include <complex>" << std::endl;
406 source_buffer << "#include <special_functions.hpp>" << std::endl;
407 } else {
408 source_buffer << "#include <cmath>" << std::endl;
409 }
410 source_buffer << "using namespace std;" << std::endl;
411 }
412
413//------------------------------------------------------------------------------
427//------------------------------------------------------------------------------
428 void create_kernel_prefix(std::ostringstream &source_buffer,
429 const std::string name,
433 const size_t size,
434 const std::vector<bool> &is_constant,
435 jit::register_map &registers,
436 const jit::register_usage &usage,
437 jit::texture1d_list &textures1d,
438 jit::texture2d_list &textures2d) {
439 source_buffer << std::endl;
440 source_buffer << "extern \"C\" void " << name << "(" << std::endl;
441
442 source_buffer << " map<size_t, ";
443 jit::add_type<T> (source_buffer);
444 source_buffer << " *> &args";
445 if (state.get()) {
446 source_buffer << "," << std::endl
447 << " mt_state *" << jit::to_string('s', state.get());
448 }
449 source_buffer << ") {" << std::endl;
450
451 std::unordered_set<void *> used_args;
452 for (size_t i = 0, ie = inputs.size(); i < ie; i++) {
453 if (!used_args.contains(inputs[i].get())) {
454 source_buffer << " ";
455 if (is_constant[i]) {
456 source_buffer << "const ";
457 }
458 jit::add_type<T> (source_buffer);
459 source_buffer << " *" << jit::to_string('v', inputs[i].get())
460 << " = args["
461 << reinterpret_cast<size_t> (inputs[i].get())
462 << "];" << std::endl;
463 used_args.insert(inputs[i].get());
464 }
465 }
466 for (auto &output : outputs) {
467 if (!used_args.contains(output.get())) {
468 source_buffer << " ";
469 jit::add_type<T> (source_buffer);
470 source_buffer << " *" << jit::to_string('o', output.get())
471 << " = args["
472 << reinterpret_cast<size_t> (output.get())
473 << "];" << std::endl;
474 used_args.insert(output.get());
475 }
476 }
477 if (state.get()) {
478 registers[state.get()] = jit::to_string('r', state.get());
479 source_buffer << " mt_state &"
480 << registers[state.get()] << " = "
481 << jit::to_string('s', state.get()) << "[0];"
482#ifdef SHOW_USE_COUNT
483 << " // used " << usage.at(state.get())
484#endif
485 << std::endl;
486 }
487 source_buffer << " for (size_t i = 0; i < " << size << "; i++) {" << std::endl;
488
489 for (auto &input : inputs) {
490 registers[input.get()] = jit::to_string('r', input.get());
491 source_buffer << " const ";
492 jit::add_type<T> (source_buffer);
493 source_buffer << " " << registers[input.get()]
494 << " = " << jit::to_string('v', input.get())
495 << "[i]; // " << input->get_symbol()
496#ifdef SHOW_USE_COUNT
497 << " used " << usage.at(input.get())
498#endif
499 << std::endl;
500 }
501 }
502
503//------------------------------------------------------------------------------
513//------------------------------------------------------------------------------
514 void create_kernel_postfix(std::ostringstream &source_buffer,
518 jit::register_map &registers,
519 jit::register_map &indices,
520 const jit::register_usage &usage) {
521 std::unordered_set<void *> out_registers;
522 for (auto &[out, in] : setters) {
523 if (!out->is_match(in) &&
524 !out_registers.contains(out.get())) {
525 graph::shared_leaf<T, SAFE_MATH> a = out->compile(source_buffer,
526 registers,
527 indices,
528 usage);
529 source_buffer << " " << jit::to_string('v', in.get());
530 source_buffer << "[i] = ";
531 if constexpr (SAFE_MATH) {
532 if constexpr (jit::complex_scalar<T>) {
533 jit::add_type<T> (source_buffer);
534 source_buffer << " (";
535 source_buffer << "isnan(real(" << registers[a.get()]
536 << ")) ? 0.0 : real(" << registers[a.get()]
537 << "), ";
538 source_buffer << "isnan(imag(" << registers[a.get()]
539 << ")) ? 0.0 : imag(" << registers[a.get()]
540 << "));" << std::endl;
541 } else {
542 source_buffer << "isnan(" << registers[a.get()]
543 << ") ? 0.0 : " << registers[a.get()]
544 << ";" << std::endl;
545 }
546 } else {
547 source_buffer << registers[a.get()] << ";" << std::endl;
548 }
549 out_registers.insert(out.get());
550 }
551 }
552 for (auto &out : outputs) {
553 if (!graph::variable_cast(out).get() &&
554 !out_registers.contains(out.get())) {
555 graph::shared_leaf<T, SAFE_MATH> a = out->compile(source_buffer,
556 registers,
557 indices,
558 usage);
559 source_buffer << " " << jit::to_string('o', out.get());
560 source_buffer << "[i] = ";
561 if constexpr (SAFE_MATH) {
562 if constexpr (jit::complex_scalar<T>) {
563 jit::add_type<T> (source_buffer);
564 source_buffer << " (";
565 source_buffer << "isnan(real(" << registers[a.get()]
566 << ")) ? 0.0 : real(" << registers[a.get()]
567 << "), ";
568 source_buffer << "isnan(imag(" << registers[a.get()]
569 << ")) ? 0.0 : imag(" << registers[a.get()]
570 << "));" << std::endl;
571 } else {
572 source_buffer << "isnan(" << registers[a.get()]
573 << ") ? 0.0 : " << registers[a.get()]
574 << ";" << std::endl;
575 }
576 } else {
577 source_buffer << registers[a.get()] << ";" << std::endl;
578 }
579 out_registers.insert(out.get());
580 }
581 }
582
583 source_buffer << " }" << std::endl;
584 source_buffer << "}" << std::endl;
585 }
586
587//------------------------------------------------------------------------------
592//------------------------------------------------------------------------------
///  @brief Create a reduction kernel.
///
///  This backend emits no source for the reduction; the stream is left
///  untouched.
///
///  @param[in,out] source_buffer Source buffer stream. Not written.
///  @param[in]     size          Size of the input buffer. Not read.
        void create_reduction(std::ostringstream &source_buffer,
                              const size_t size) {
            static_cast<void> (source_buffer);
            static_cast<void> (size);
        }
595
596//------------------------------------------------------------------------------
605//------------------------------------------------------------------------------
607 if (!host_buffers.contains(node.get())) {
608 host_buffers[node.get()] = kernel_arguments[node.get()];
609 }
610 return host_buffers[node.get()].data();
611 }
612 };
613}
614
615#endif /* cpu_context_h */
Class representing a generic buffer.
Definition backend.hpp:29
size_t size() const
Get size of the buffer.
Definition backend.hpp:116
T * data()
Get a pointer to the basic memory buffer.
Definition backend.hpp:270
Class representing a cpu context.
Definition cpu_context.hpp:82
void print_results(const size_t index, const graph::output_nodes< T, SAFE_MATH > &nodes)
Print out the results.
Definition cpu_context.hpp:344
void copy_to_device(graph::shared_leaf< T, SAFE_MATH > node, T *source)
Copy buffer contents to the device.
Definition cpu_context.hpp:375
void create_kernel_prefix(std::ostringstream &source_buffer, const std::string name, graph::input_nodes< T, SAFE_MATH > &inputs, graph::output_nodes< T, SAFE_MATH > &outputs, graph::shared_random_state< T, SAFE_MATH > state, const size_t size, const std::vector< bool > &is_constant, jit::register_map &registers, const jit::register_usage &usage, jit::texture1d_list &textures1d, jit::texture2d_list &textures2d)
Create kernel prefix.
Definition cpu_context.hpp:428
T check_value(const size_t index, const graph::shared_leaf< T, SAFE_MATH > &node)
Check the value.
Definition cpu_context.hpp:364
void create_reduction(std::ostringstream &source_buffer, const size_t size)
Create a reduction kernel.
Definition cpu_context.hpp:593
cpu_context(const size_t index)
Construct a cpu context.
Definition cpu_context.hpp:121
void wait()
Hold the current thread until the command buffer has completed.
Definition cpu_context.hpp:330
static constexpr size_t random_state_size
Size of random state needed.
Definition cpu_context.hpp:95
std::function< void(void)> create_kernel_call(const std::string kernel_name, graph::input_nodes< T, SAFE_MATH > inputs, graph::output_nodes< T, SAFE_MATH > outputs, graph::shared_random_state< T, SAFE_MATH > state, const size_t num_rays, const jit::texture1d_list &tex1d_list, const jit::texture2d_list &tex2d_list)
Create a kernel calling function.
Definition cpu_context.hpp:233
void copy_to_host(const graph::shared_leaf< T, SAFE_MATH > node, T *destination)
Copy buffer contents to host.
Definition cpu_context.hpp:388
static std::string device_type()
Device description.
Definition cpu_context.hpp:112
T * get_buffer(graph::shared_leaf< T, SAFE_MATH > &node)
Get the buffer for a node.
Definition cpu_context.hpp:606
void compile(const std::string kernel_source, std::vector< std::string > names, const bool add_reduction=false)
Compile the kernels.
Definition cpu_context.hpp:133
int remaining_const_memory
Remaining constant memory in bytes. NOT USED.
Definition cpu_context.hpp:98
std::function< T(void)> create_max_call(graph::shared_leaf< T, SAFE_MATH > &argument, std::function< void(void)> run)
Create a max compute pipeline.
Definition cpu_context.hpp:306
void create_header(std::ostringstream &source_buffer)
Create the source header.
Definition cpu_context.hpp:400
static size_t max_concurrency()
Get the maximum number of concurrent instances.
Definition cpu_context.hpp:105
void create_kernel_postfix(std::ostringstream &source_buffer, graph::output_nodes< T, SAFE_MATH > &outputs, graph::map_nodes< T, SAFE_MATH > &setters, graph::shared_random_state< T, SAFE_MATH > state, jit::register_map &registers, jit::register_map &indices, const jit::register_usage &usage)
Create kernel postfix.
Definition cpu_context.hpp:514
Complex scalar concept.
Definition register.hpp:24
LLVM_ATTRIBUTE_USED void linkComponents()
This just exposes the functions so the debugger links.
Definition cpu_context.hpp:45
Name space for GPU backends.
Definition cpu_context.hpp:51
llvm::SmallVector< const char *, 8 > split_string(char *string)
Split a string by the space delimiter.
Definition cpu_context.hpp:62
std::shared_ptr< random_state_node< T, SAFE_MATH > > shared_random_state
Convenience type alias for shared random state nodes.
Definition random.hpp:272
std::vector< shared_variable< T, SAFE_MATH > > input_nodes
Convenience type alias for a vector of inputs.
Definition node.hpp:1730
shared_variable< T, SAFE_MATH > variable_cast(shared_leaf< T, SAFE_MATH > x)
Cast to a variable node.
Definition node.hpp:1746
std::shared_ptr< leaf_node< T, SAFE_MATH > > shared_leaf
Convenience type alias for shared leaf nodes.
Definition node.hpp:673
std::vector< std::pair< shared_leaf< T, SAFE_MATH >, shared_variable< T, SAFE_MATH > > > map_nodes
Convenience type alias for mapping end codes back to inputs.
Definition node.hpp:1734
std::vector< shared_leaf< T, SAFE_MATH > > output_nodes
Convenience type alias for a vector of output nodes.
Definition node.hpp:688
Name space for JIT functions.
Definition jit.hpp:41
std::map< void *, size_t > texture1d_list
Type alias for indexing 1D textures.
Definition register.hpp:262
std::map< void *, std::array< size_t, 2 > > texture2d_list
Type alias for indexing 2D textures.
Definition register.hpp:264
std::map< void *, size_t > register_usage
Type alias for counting register usage.
Definition register.hpp:258
std::map< void *, std::string > register_map
Type alias for mapping node pointers to register names.
Definition register.hpp:256
std::string to_string(const char prefix, const NODE *pointer)
Convert a graph::leaf_node pointer to a string.
Definition register.hpp:245
Name space for output files.
Definition output.hpp:16
Random constants and distributions.
Random state structure.
Definition random.hpp:29