Graph Framework
Loading...
Searching...
No Matches
output.hpp
Go to the documentation of this file.
1//------------------------------------------------------------------------------
4//------------------------------------------------------------------------------
5
6#ifndef output_h
7#define output_h
8
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdio>
#include <mutex>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>

#include <netcdf.h>

#include "jit.hpp"
14
16namespace output {
/// Serializes every NetCDF library call made through this namespace.
/// Declared inline (C++17) so all translation units that include this header
/// share one mutex; a namespace-scope static would give each translation unit
/// its own private copy and the calls would not actually be serialized.
inline std::mutex sync;
19
20//------------------------------------------------------------------------------
24//------------------------------------------------------------------------------
25 static void check_error(const int status) {
26 assert(status == NC_NOERR && nc_strerror(status));
27 }
28
29//------------------------------------------------------------------------------
31//------------------------------------------------------------------------------
33 private:
35 int ncid;
37 int unlimited_dim;
39 int num_rays_dim;
41 size_t num_rays;
42
43 public:
44//------------------------------------------------------------------------------
49//------------------------------------------------------------------------------
50 result_file(const std::string &filename,
51 const size_t num_rays) :
52 num_rays(num_rays) {
53 const std::string temp = filename.empty() ? jit::format_to_string(reinterpret_cast<size_t> (this)) :
54 filename;
55
56 sync.lock();
57 check_error(nc_create(temp.c_str(),
58 filename.empty() || num_rays == 0 ? NC_DISKLESS : NC_CLOBBER,
59 &ncid));
60
61 check_error(nc_def_dim(ncid, "time", NC_UNLIMITED, &unlimited_dim));
62 check_error(nc_def_dim(ncid, "num_rays",
63 num_rays ? num_rays : 1,
64 &num_rays_dim));
65 sync.unlock();
66 }
67
68//------------------------------------------------------------------------------
72//------------------------------------------------------------------------------
73 result_file(const std::string &filename) {
74 sync.lock();
75 check_error(nc_open(filename.c_str(), NC_WRITE, &ncid));
76
77 check_error(nc_inq_dimid(ncid, "time", &unlimited_dim));
78 check_error(nc_inq_dimid(ncid, "num_rays", &num_rays_dim));
79 check_error(nc_inq_dimlen(ncid, num_rays_dim, &num_rays));
80 check_error(nc_redef(ncid));
81 sync.unlock();
82 }
83
84//------------------------------------------------------------------------------
86//------------------------------------------------------------------------------
88 check_error(nc_close(ncid));
89 }
90
91//------------------------------------------------------------------------------
93//------------------------------------------------------------------------------
94 void end_define_mode() const {
95 sync.lock();
96 check_error(nc_enddef(ncid));
97 sync.unlock();
98 }
99
100//------------------------------------------------------------------------------
104//------------------------------------------------------------------------------
105 int get_ncid() const {
106 return ncid;
107 }
108
109//------------------------------------------------------------------------------
113//------------------------------------------------------------------------------
114 size_t get_num_rays() const {
115 return num_rays;
116 }
117
118//------------------------------------------------------------------------------
122//------------------------------------------------------------------------------
123 int get_num_rays_dim() const {
124 return num_rays_dim;
125 }
126
127//------------------------------------------------------------------------------
131//------------------------------------------------------------------------------
132 int get_unlimited_dim() const {
133 return unlimited_dim;
134 }
135
136//------------------------------------------------------------------------------
140//------------------------------------------------------------------------------
141 size_t get_unlimited_size() const {
142 size_t size;
143 sync.lock();
144 check_error(nc_inq_dimlen(ncid, unlimited_dim, &size));
145 sync.unlock();
146
147 return size;
148 }
149
150//------------------------------------------------------------------------------
152//------------------------------------------------------------------------------
153 void sync_file() const {
154 sync.lock();
155 check_error(nc_sync(ncid));
156 sync.unlock();
157 }
158 };
159
160//------------------------------------------------------------------------------
164//------------------------------------------------------------------------------
165 template<jit::float_scalar T>
166 class data_set {
167 private:
169 int ray_dim;
171 std::array<int, 3> dims;
173 std::array<size_t, 3> count;
175 static constexpr size_t ray_dim_size = 1 + jit::complex_scalar<T>;
177 static constexpr nc_type type = jit::float_base<T> ? NC_FLOAT : NC_DOUBLE;
178
179//------------------------------------------------------------------------------
181//------------------------------------------------------------------------------
182 struct variable {
184 int id;
186 T *buffer;
187 };
189 std::vector<variable> variables;
190
191//------------------------------------------------------------------------------
193//------------------------------------------------------------------------------
194 struct reference {
196 int id;
198 T *buffer;
200 size_t ray_dim_size;
202 std::ptrdiff_t stride;
204 size_t index;
205 };
207 std::vector<reference> references;
208
209 public:
210//------------------------------------------------------------------------------
214//------------------------------------------------------------------------------
215 data_set(const result_file &result) {
216 sync.lock();
217 if constexpr (jit::complex_scalar<T>) {
218 if (NC_NOERR != nc_inq_dimid(result.get_ncid(),
219 "ray_dim_cplx",
220 &ray_dim)) {
221 check_error(nc_def_dim(result.get_ncid(),
222 "ray_dim_cplx", ray_dim_size,
223 &ray_dim));
224 }
225 } else {
226 if (NC_NOERR != nc_inq_dimid(result.get_ncid(),
227 "ray_dim",
228 &ray_dim)) {
229 check_error(nc_def_dim(result.get_ncid(),
230 "ray_dim", ray_dim_size,
231 &ray_dim));
232 }
233 }
234 sync.unlock();
235
236 dims = {
237 result.get_unlimited_dim(),
238 result.get_num_rays_dim(),
239 ray_dim
240 };
241
242 count = {
243 1,
244 result.get_num_rays(),
245 ray_dim_size
246 };
247 }
248
249//------------------------------------------------------------------------------
258//------------------------------------------------------------------------------
259 template<bool SAFE_MATH=false>
260 void create_variable(const result_file &result,
261 const std::string &name,
264 variable var;
265 sync.lock();
266 check_error(nc_def_var(result.get_ncid(), name.c_str(), type,
267 static_cast<int> (dims.size()), dims.data(),
268 &var.id));
269 sync.unlock();
270
271 var.buffer = context.get_buffer(node);
272 variables.push_back(var);
273 }
274
275//------------------------------------------------------------------------------
283//------------------------------------------------------------------------------
284 template<bool SAFE_MATH=false>
285 void reference_variable(const result_file &result,
286 const std::string &name,
288 reference ref;
289 nc_type type;
290 std::array<int, 3> ref_dims;
291
292 sync.lock();
293 check_error(nc_inq_varid(result.get_ncid(),
294 name.c_str(),
295 &ref.id));
296 check_error(nc_inq_var(result.get_ncid(), ref.id, NULL, &type,
297 NULL, ref_dims.data(), NULL));
298 check_error(nc_inq_dimlen(result.get_ncid(), ref_dims[2],
299 &ref.ray_dim_size));
300 sync.unlock();
301
302 assert(ref.ray_dim_size <= ray_dim_size &&
303 "Context variable too small to read reference.");
304
305 ref.stride = ref.ray_dim_size < ray_dim_size ? 2 : 1;
306 ref.buffer = node->data();
307 ref.index = 0;
308 references.push_back(ref);
309 }
310
311//------------------------------------------------------------------------------
319//------------------------------------------------------------------------------
320 template<bool SAFE_MATH=false>
322 const std::string &name,
324 reference ref;
325 nc_type type;
326 std::array<int, 3> ref_dims;
327
328 sync.lock();
329 check_error(nc_inq_varid(result.get_ncid(),
330 name.c_str(),
331 &ref.id));
332 check_error(nc_inq_var(result.get_ncid(), ref.id, NULL, &type,
333 NULL, ref_dims.data(), NULL));
334 check_error(nc_inq_dimlen(result.get_ncid(), ref_dims[2],
335 &ref.ray_dim_size));
336 sync.unlock();
337
338 assert(ref.ray_dim_size == 2 &&
339 "Not a complex variable.");
340
341 ref.ray_dim_size = 1;
342 ref.stride = ref.ray_dim_size < ray_dim_size ? 2 : 1;
343 ref.buffer = node->data();
344 ref.index = 1;
345 references.push_back(ref);
346 }
347
348//------------------------------------------------------------------------------
352//------------------------------------------------------------------------------
353 void write(const result_file &result) {
354 write(result, result.get_unlimited_size());
355 }
356
357//------------------------------------------------------------------------------
362//------------------------------------------------------------------------------
363 void write(const result_file &result,
364 const size_t index) {
365 const std::array<size_t, 3> start = {
366 index, 0, 0
367 };
368
369 for (variable &var : variables) {
370 sync.lock();
371 if constexpr (jit::float_base<T>) {
372 if constexpr (jit::complex_scalar<T>) {
373 check_error(nc_put_vara_float(result.get_ncid(),
374 var.id,
375 start.data(),
376 count.data(),
377 reinterpret_cast<float *> (var.buffer)));
378 } else {
379 check_error(nc_put_vara_float(result.get_ncid(),
380 var.id,
381 start.data(),
382 count.data(),
383 var.buffer));
384 }
385 } else {
386 if constexpr (jit::complex_scalar<T>) {
387 check_error(nc_put_vara_double(result.get_ncid(),
388 var.id,
389 start.data(),
390 count.data(),
391 reinterpret_cast<double *> (var.buffer)));
392 } else {
393 check_error(nc_put_vara_double(result.get_ncid(),
394 var.id,
395 start.data(),
396 count.data(),
397 var.buffer));
398 }
399 }
400 sync.unlock();
401 }
402
403 result.sync_file();
404 }
405
406//------------------------------------------------------------------------------
411//------------------------------------------------------------------------------
412 void read(const result_file &result,
413 const size_t index) {
414 const std::array<std::ptrdiff_t, 3> stride = {
415 1, 1, 1
416 };
417
418 for (reference &ref : references) {
419 const std::array<size_t, 3> ref_start = {
420 index, 0, ref.index
421 };
422 const std::array<size_t, 3> ref_count = {
423 1,
424 result.get_num_rays(),
425 ref.ray_dim_size
426 };
427 const std::array<std::ptrdiff_t, 3> map = {
428 1, ref.stride, 1
429 };
430
431 sync.lock();
432 if constexpr (jit::float_base<T>) {
433 if constexpr (jit::complex_scalar<T>) {
434 check_error(nc_get_varm_float(result.get_ncid(),
435 ref.id,
436 ref_start.data(),
437 ref_count.data(),
438 stride.data(),
439 map.data(),
440 reinterpret_cast<float *> (ref.buffer)));
441 } else {
442 check_error(nc_get_varm_float(result.get_ncid(),
443 ref.id,
444 ref_start.data(),
445 ref_count.data(),
446 stride.data(),
447 map.data(),
448 ref.buffer));
449 }
450 } else {
451 if constexpr (jit::complex_scalar<T>) {
452 check_error(nc_get_varm_double(result.get_ncid(),
453 ref.id,
454 ref_start.data(),
455 ref_count.data(),
456 stride.data(),
457 map.data(),
458 reinterpret_cast<double *> (ref.buffer)));
459 } else {
460 check_error(nc_get_varm_double(result.get_ncid(),
461 ref.id,
462 ref_start.data(),
463 ref_count.data(),
464 stride.data(),
465 map.data(),
466 ref.buffer));
467 }
468 }
469
470 sync.unlock();
471 }
472 }
473 };
474}
475
476#endif /* output_h */
Class for JIT compiling the GPU kernels.
Definition jit.hpp:49
T * get_buffer(graph::shared_leaf< T, SAFE_MATH > &node)
Get buffer from the gpu_context.
Definition jit.hpp:328
Class representing a netcdf dataset.
Definition output.hpp:166
void read(const result_file &result, const size_t index)
Read step.
Definition output.hpp:412
data_set(const result_file &result)
Construct a dataset.
Definition output.hpp:215
void reference_imag_variable(const result_file &result, const std::string &name, graph::shared_variable< T, SAFE_MATH > &&node)
Load imaginary reference.
Definition output.hpp:321
void write(const result_file &result, const size_t index)
Write step.
Definition output.hpp:363
void reference_variable(const result_file &result, const std::string &name, graph::shared_variable< T, SAFE_MATH > &&node)
Load reference.
Definition output.hpp:285
void create_variable(const result_file &result, const std::string &name, graph::shared_leaf< T, SAFE_MATH > &node, jit::context< T, SAFE_MATH > &context)
Create a variable.
Definition output.hpp:260
void write(const result_file &result)
Write step.
Definition output.hpp:353
Class representing a netcdf based output file.
Definition output.hpp:32
void end_define_mode() const
End define mode.
Definition output.hpp:94
~result_file()
Destructor.
Definition output.hpp:87
size_t get_num_rays() const
Get the number of rays.
Definition output.hpp:114
int get_num_rays_dim() const
Get the ID of the number of rays dimension.
Definition output.hpp:123
result_file(const std::string &filename)
Open an existing result file.
Definition output.hpp:73
int get_ncid() const
Get ncid.
Definition output.hpp:105
size_t get_unlimited_size() const
Get the current length of the unlimited dimension.
Definition output.hpp:141
int get_unlimited_dim() const
Get unlimited dimension.
Definition output.hpp:132
result_file(const std::string &filename, const size_t num_rays)
Construct a new result file.
Definition output.hpp:50
void sync_file() const
Sync the file.
Definition output.hpp:153
Complex scalar concept.
Definition register.hpp:24
Float base concept.
Definition register.hpp:37
subroutine assert(test, message)
Assert check.
Definition f_binding_test.f90:38
Class to just in time compile a kernel.
std::shared_ptr< variable_node< T, SAFE_MATH > > shared_variable
Convenience type alias for shared variable nodes.
Definition node.hpp:1727
std::shared_ptr< leaf_node< T, SAFE_MATH > > shared_leaf
Convenience type alias for shared leaf nodes.
Definition node.hpp:673
std::string format_to_string(const T value)
Convert a value to a string while avoiding locale.
Definition register.hpp:211
Namespace for output files.
Definition output.hpp:16