33 INTEGER :: flags = model_state_all_off
40 REAL (rprec) :: tolerance
43 REAL (rprec),
DIMENSION(:,:),
POINTER :: kls => null()
45 REAL (rprec),
DIMENSION(:,:),
POINTER :: kll => null()
47 REAL (rprec),
DIMENSION(:,:),
POINTER :: work => null()
50 REAL (rprec) :: cholesky_fac
53 INTEGER :: profile_index = -1
56 REAL (rprec),
DIMENSION(:),
POINTER :: fpoints => null()
101 & profile_index, vrnc, tolerance, &
110 TYPE (model_class),
INTENT(in) :: a_model
111 INTEGER,
INTENT(in) :: n_signals
112 INTEGER,
INTENT(in) :: profile_index
113 CHARACTER(len=*),
INTENT(in) :: gaussp_type
114 REAL (rprec),
DIMENSION(:) :: vrnc
115 REAL (rprec) :: tolerance
116 REAL (rprec) :: cholesky_fac
122 INTEGER,
DIMENSION(:,:),
ALLOCATABLE :: indices
123 CHARACTER (len=data_name_length) :: param_name
125 CHARACTER (len=15) :: log_file
126 REAL (rprec) :: start_time
129 CHARACTER (len=*),
DIMENSION(2),
PARAMETER :: range_type = &
131 INTEGER,
DIMENSION(2,2),
PARAMETER :: range_indices = 0
132 real (rprec),
DIMENSION(2),
PARAMETER :: range_value = 0.0
135 start_time = profiler_get_start_time()
139 SELECT CASE (trim(gaussp_type))
142 n_params = model_get_gp_sxrem_num_hyper_param(a_model,
145 & model_state_sxrem_flag +
149 & model_get_sxrem_af(a_model, profile_index)
150 param_name =
'pp_sxrem_b_a'
154 indices(1,i) = profile_index
158 WRITE (log_file,1001)
'sxrem', profile_index
161 n_params = model_get_gp_te_num_hyper_param(a_model)
163 & model_state_te_flag)
165 param_name =
'pp_te_b'
172 WRITE (log_file,1000)
'te'
175 n_params = model_get_gp_te_num_hyper_param(a_model)
177 & model_state_ti_flag)
179 param_name =
'pp_te_b'
186 WRITE (log_file,1000)
'ti'
189 n_params = model_get_gp_ne_num_hyper_param(a_model)
191 & model_state_ne_flag)
193 param_name =
'pp_ne_b'
200 WRITE (log_file,1000)
'ne'
205 &
'replace',
'formatted', delim_in=
'none')
225 & range_type, range_indices, range_value,
226 & n_signals, n_params)
231 CALL profiler_set_stop_time(
'gaussp_construct', start_time)
233 1000
FORMAT(a,
'_gp.log')
234 1001
FORMAT(a,
'_',i0.2,
'_gp.log')
255 TYPE (gaussp_class),
POINTER :: this
264 DO i = 1,
SIZE(this%signals)
265 this%signals(i)%p => null()
268 IF (
ASSOCIATED(this%signals))
THEN
269 DEALLOCATE(this%signals)
270 this%signals => null()
273 IF (
ASSOCIATED(this%kls))
THEN
278 IF (
ASSOCIATED(this%kll))
THEN
283 IF (
ASSOCIATED(this%work))
THEN
284 DEALLOCATE(this%work)
288 this%fpoints => null()
291 IF (
ASSOCIATED(this%params))
THEN
292 DO i = 1,
SIZE(this%params)
293 IF (
ASSOCIATED(this%params(i)%p))
THEN
295 this%params(i)%p => null()
298 DEALLOCATE(this%params)
299 this%params => null()
327 TYPE (gaussp_class),
INTENT(inout) :: this
328 CLASS (signal_class),
POINTER :: signal
329 INTEGER,
INTENT(in) :: index
332 REAL (rprec) :: start_time
335 start_time = profiler_get_start_time()
337 this%signals(index)%p =>
signal
339 CALL profiler_set_stop_time(
'gaussp_set_signal', start_time)
358 TYPE (gaussp_class),
INTENT(inout) :: this
359 TYPE (model_class),
POINTER :: a_model
364 REAL (rprec) :: start_time
365 REAL (rprec),
DIMENSION(:),
ALLOCATABLE :: cached_value
366 REAL (rprec),
DIMENSION(:),
ALLOCATABLE :: gradient
367 REAL (rprec) :: evidence
368 REAL (rprec) :: new_evidence
369 REAL (rprec) :: gamma
370 REAL (rprec) :: temp_param
374 REAL (rprec),
PARAMETER :: gamma_init = 0.01
377 start_time = profiler_get_start_time()
385 IF (btest(a_model%state_flags, model_state_vmec_flag) .or.
386 & btest(a_model%state_flags, model_state_siesta_flag) .or.
387 & btest(a_model%state_flags, model_state_shift_flag) .or.
388 & btest(a_model%state_flags, model_state_signal_flag) .or.
389 & iand(this%flags, a_model%state_flags) .ne. 0)
THEN
395 ALLOCATE(cached_value(
SIZE(this%params)))
396 ALLOCATE(gradient(
SIZE(this%params)))
398 DO i = 1,
SIZE(this%params)
402 WRITE (this%iou,1000)
403 WRITE (this%iou,1001) new_evidence
406 evidence = new_evidence
410 DO i = 1,
SIZE(this%params)
412 & mpi_comm_null, .false.)
414 & evidence)/this%params(i)%p%recon%delta
420 cached_value = cached_value + gamma*gradient
422 DO i = 1,
SIZE(this%params)
424 & cached_value(i), mpi_comm_null,
429 DO WHILE (new_evidence - evidence .lt. 0.0)
431 cached_value = cached_value - gamma*gradient
433 DO i = 1,
SIZE(this%params)
435 & cached_value(i), mpi_comm_null,
441 WRITE (this%iou,1002) new_evidence,
442 & new_evidence - evidence, gamma
445 IF (abs(new_evidence - evidence) .lt. this%tolerance)
THEN
450 WRITE (this%iou,*)
'Final Parameters', cached_value
453 DEALLOCATE(cached_value)
465 DO i = 1,
SIZE(this%fpoints)
466 this%fpoints(i) = dot_product(this%kls(:,i), this%work(:,1))
471 CALL profiler_set_stop_time(
'gaussp_set_profile', start_time)
473 1000
FORMAT(
'New Process')
474 1001
FORMAT(
'Log Evidence : ',es12.5)
475 1002
FORMAT(
'Log Evidence : ',es12.5,
' Change : ',es12.5,
476 &
' Gamma : ',es12.5)
494 USE stel_constants,
ONLY: twopi
501 TYPE (model_class),
POINTER :: a_model
507 REAL (rprec) :: start_time
511 REAL (rprec),
PARAMETER :: log2pi = log(twopi)
514 start_time = profiler_get_start_time()
522 DO j = 1,
SIZE(this%fpoints)
523 DO i = 1,
SIZE(this%signals)
524 signal_obj => this%signals(i)%p
525 this%kls(i,j) = signal_obj%get_gp(a_model, j, this%flags)
527 this%work(:,j + 1) = this%kls(:,j)
533 DO j = 1,
SIZE(this%signals)
534 DO i = j + 1,
SIZE(this%signals)
535 signal_obj => this%signals(i)%p
536 this%kll(i,j) = signal_obj%get_gp(a_model,
539 this%kll(j,i) = this%kll(i,j)
541 signal_obj => this%signals(j)%p
542 this%kll(j,j) = signal_obj%get_gp(a_model, signal_obj,
544 & + signal_obj%get_sigma2()
545 & + this%cholesky_fac
546 this%work(j,1) = signal_obj%get_observed_signal(a_model)
552 CALL dpotrf(
'L',
SIZE(this%signals), this%kll,
553 &
SIZE(this%signals), ierr)
556 IF (ierr .lt. 0)
THEN
557 CALL err_fatal(
'gaussp_get_evidence: DROTRF cannot ' //
558 &
'factor matrix. ierr is negative')
559 ELSE IF (ierr .gt. 0)
THEN
560 CALL err_fatal(
'gaussp_get_evidence: DROTRF cannot ' //
561 &
'factor matrix. ierr is positive. ' //
562 &
'Try increasing gp_cholesky_fact')
569 CALL dpotrs(
'L',
SIZE(this%signals),
SIZE(this%fpoints) + 1,
570 & this%kll,
SIZE(this%signals), this%work,
571 &
SIZE(this%signals), ierr)
572 IF (ierr .lt. 0)
THEN
573 CALL err_fatal(
'gaussp_get_modeled_signal: DROTRS ' //
574 &
'failed to solve the equation')
601 DO j = 1,
SIZE(this%signals)
603 & - log(this%kll(j,j))
605 & this%signals(j)%p, a_model)*
611 CALL profiler_set_stop_time(
'gaussp_get_evidence', start_time)