V3FIT
guassian_process.f
Go to the documentation of this file.
1 !*******************************************************************************
4 !
5 ! Note: this Doxygen comment block is kept separate so that the detailed
6 ! description is attached to the module, not to the file.
7 !
10 !*******************************************************************************
12  USE signal
13  USE pprofile_t
14  USE v3fit_params
15 
16 !*******************************************************************************
17 ! DERIVED-TYPE DECLARATIONS
18 ! 1) gaussian process base class
19 ! 2) gaussian process pointer type
20 !
21 !*******************************************************************************
22 !-------------------------------------------------------------------------------
24 !-------------------------------------------------------------------------------
32 ! Keep for now, but this might not be needed later.
33  INTEGER :: flags = model_state_all_off
34 
36  TYPE (signal_pointer), DIMENSION(:), POINTER :: &
37  & signals => null()
38 
40  REAL (rprec) :: tolerance
41 
43  REAL (rprec), DIMENSION(:,:), POINTER :: kls => null()
45  REAL (rprec), DIMENSION(:,:), POINTER :: kll => null()
47  REAL (rprec), DIMENSION(:,:), POINTER :: work => null()
50  REAL (rprec) :: cholesky_fac
51 
53  INTEGER :: profile_index = -1
54 
56  REAL (rprec), DIMENSION(:), POINTER :: fpoints => null()
57 
59  TYPE (param_pointer), DIMENSION(:), POINTER :: &
60  & params => null()
61 
63  INTEGER :: iou
64  END TYPE
65 
66 !-------------------------------------------------------------------------------
70 !-------------------------------------------------------------------------------
74  TYPE (gaussp_class), POINTER :: p => null()
75  END TYPE
76 
77  CONTAINS
78 !*******************************************************************************
79 ! CONSTRUCTION SUBROUTINES
80 !*******************************************************************************
81 !-------------------------------------------------------------------------------
99 !-------------------------------------------------------------------------------
100  FUNCTION gaussp_construct(a_model, n_signals, gaussp_type, &
101  & profile_index, vrnc, tolerance, &
102  & cholesky_fac)
104  USE safe_open_mod
105 
106  IMPLICIT NONE
107 
108 ! Declare Arguments
109  TYPE (gaussp_class), POINTER :: gaussp_construct
110  TYPE (model_class), INTENT(in) :: a_model
111  INTEGER, INTENT(in) :: n_signals
112  INTEGER, INTENT(in) :: profile_index
113  CHARACTER(len=*), INTENT(in) :: gaussp_type
114  REAL (rprec), DIMENSION(:) :: vrnc
115  REAL (rprec) :: tolerance
116  REAL (rprec) :: cholesky_fac
117 
118 ! Local Variables
119  INTEGER :: i
120  INTEGER :: n_points
121  INTEGER :: n_params
122  INTEGER, DIMENSION(:,:), ALLOCATABLE :: indices
123  CHARACTER (len=data_name_length) :: param_name
124  INTEGER :: status
125  CHARACTER (len=15) :: log_file
126  REAL (rprec) :: start_time
127 
128 ! Local Parameters
129  CHARACTER (len=*), DIMENSION(2), PARAMETER :: range_type = &
130  & 'infinity'
131  INTEGER, DIMENSION(2,2), PARAMETER :: range_indices = 0
132  real (rprec), DIMENSION(2), PARAMETER :: range_value = 0.0
133 
134 ! Start of executable code
135  start_time = profiler_get_start_time()
136 
137  ALLOCATE(gaussp_construct)
138 
139  SELECT CASE (trim(gaussp_type))
140 
141  CASE ('sxrem')
142  n_params = model_get_gp_sxrem_num_hyper_param(a_model, &
143  & profile_index)
144  gaussp_construct%flags = ibset(gaussp_construct%flags, &
145  & model_state_sxrem_flag + &
146  & profile_index - 1)
147  gaussp_construct%profile_index = profile_index
148  gaussp_construct%fpoints => &
149  & model_get_sxrem_af(a_model, profile_index)
150  param_name = 'pp_sxrem_b_a'
151  ALLOCATE(indices(data_max_indices,n_params))
152  indices = 0
153  DO i = 1, n_params
154  indices(1,i) = profile_index
155  indices(2,i) = i - 1
156  END DO
157  gaussp_construct%iou = 0
158  WRITE (log_file,1001) 'sxrem', profile_index
159 
160  CASE ('te')
161  n_params = model_get_gp_te_num_hyper_param(a_model)
162  gaussp_construct%flags = ibset(gaussp_construct%flags, &
163  & model_state_te_flag)
164  gaussp_construct%fpoints => model_get_te_af(a_model)
165  param_name = 'pp_te_b'
166  ALLOCATE(indices(data_max_indices,n_params))
167  indices = 0
168  DO i = 1, n_params
169  indices(1,i) = i - 1
170  END DO
171  gaussp_construct%iou = 0
172  WRITE (log_file,1000) 'te'
173 
174  CASE ('ti')
175  n_params = model_get_gp_te_num_hyper_param(a_model)
176  gaussp_construct%flags = ibset(gaussp_construct%flags, &
177  & model_state_ti_flag)
178  gaussp_construct%fpoints => model_get_te_af(a_model)
179  param_name = 'pp_te_b'
180  ALLOCATE(indices(data_max_indices,n_params))
181  indices = 0
182  DO i = 1, n_params
183  indices(1,i) = i - 1
184  END DO
185  gaussp_construct%iou = 0
186  WRITE (log_file,1000) 'ti'
187 
188  CASE ('ne')
189  n_params = model_get_gp_ne_num_hyper_param(a_model)
190  gaussp_construct%flags = ibset(gaussp_construct%flags, &
191  & model_state_ne_flag)
192  gaussp_construct%fpoints => model_get_ne_af(a_model)
193  param_name = 'pp_ne_b'
194  ALLOCATE(indices(data_max_indices,n_params))
195  indices = 0
196  DO i = 1, n_params
197  indices(1,i) = i - 1
198  END DO
199  gaussp_construct%iou = 0
200  WRITE (log_file,1000) 'ne'
201 
202  END SELECT
203 
204  CALL safe_open(gaussp_construct%iou, status, trim(log_file), &
205  & 'replace', 'formatted', delim_in='none')
206 
207  n_points = SIZE(gaussp_construct%fpoints)
208 
209  ALLOCATE(gaussp_construct%signals(n_signals))
210 
211  ALLOCATE(gaussp_construct%kls(n_signals,n_points))
212  ALLOCATE(gaussp_construct%kll(n_signals,n_signals))
213 
214  ALLOCATE(gaussp_construct%work(n_signals,n_points + 1))
215 
216  gaussp_construct%tolerance = tolerance
217  gaussp_construct%cholesky_fac = cholesky_fac
218 
219  ALLOCATE(gaussp_construct%params(n_params))
220  DO i = 1, n_params
223  gaussp_construct%params(i)%p => &
224  & param_construct(a_model, param_name, indices(:,i), vrnc(i), &
225  & range_type, range_indices, range_value, &
226  & n_signals, n_params)
227  END DO
228 
229  DEALLOCATE(indices)
230 
231  CALL profiler_set_stop_time('gaussp_construct', start_time)
232 
233 1000 FORMAT(a,'_gp.log')
234 1001 FORMAT(a,'_',i0.2,'_gp.log')
235 
236  END FUNCTION
237 
238 !*******************************************************************************
239 ! DESTRUCTION SUBROUTINES
240 !*******************************************************************************
241 !-------------------------------------------------------------------------------
249 !-------------------------------------------------------------------------------
250  SUBROUTINE gaussp_destruct(this)
251 
252  IMPLICIT NONE
253 
254 ! Declare Arguments
255  TYPE (gaussp_class), POINTER :: this
256 
257 ! Local Variables
258  INTEGER :: i
259 
260 ! Start of executable code
261 
262 ! Null all pointers in the signals array. Do not deallocate the signals in the
263 ! array since this object is not the owner.
264  DO i = 1, SIZE(this%signals)
265  this%signals(i)%p => null()
266  END DO
267 
268  IF (ASSOCIATED(this%signals)) THEN
269  DEALLOCATE(this%signals)
270  this%signals => null()
271  END IF
272 
273  IF (ASSOCIATED(this%kls)) THEN
274  DEALLOCATE(this%kls)
275  this%kls => null()
276  END IF
277 
278  IF (ASSOCIATED(this%kll)) THEN
279  DEALLOCATE(this%kll)
280  this%kll => null()
281  END IF
282 
283  IF (ASSOCIATED(this%work)) THEN
284  DEALLOCATE(this%work)
285  this%work => null()
286  END IF
287 
288  this%fpoints => null()
289 
290 ! Deconstruct and deallocate all the hyper parameters.
291  IF (ASSOCIATED(this%params)) THEN
292  DO i = 1, SIZE(this%params)
293  IF (ASSOCIATED(this%params(i)%p)) THEN
294  CALL param_destruct(this%params(i)%p)
295  this%params(i)%p => null()
296  END IF
297  END DO
298  DEALLOCATE(this%params)
299  this%params => null()
300  END IF
301 
302  CLOSE (this%iou)
303 
304  DEALLOCATE(this)
305 
306  END SUBROUTINE
307 
308 !*******************************************************************************
309 ! SETTER SUBROUTINES
310 !*******************************************************************************
311 !-------------------------------------------------------------------------------
321 !-------------------------------------------------------------------------------
322  SUBROUTINE gaussp_set_signal(this, signal, index)
323 
324  IMPLICIT NONE
325 
326 ! Declare Arguments
327  TYPE (gaussp_class), INTENT(inout) :: this
328  CLASS (signal_class), POINTER :: signal
329  INTEGER, INTENT(in) :: index
330 
331 ! local Variables
332  REAL (rprec) :: start_time
333 
334 ! Start of executable code
335  start_time = profiler_get_start_time()
336 
337  this%signals(index)%p => signal
338 
339  CALL profiler_set_stop_time('gaussp_set_signal', start_time)
340 
341  END SUBROUTINE
342 
343 !-------------------------------------------------------------------------------
352 !-------------------------------------------------------------------------------
353  SUBROUTINE gaussp_set_profile(this, a_model)
354 
355  IMPLICIT NONE
356 
357 ! Declare Arguments
358  TYPE (gaussp_class), INTENT(inout) :: this
359  TYPE (model_class), POINTER :: a_model
360 
361 ! Local Variables
362  INTEGER :: i
363  INTEGER :: j
364  REAL (rprec) :: start_time
365  REAL (rprec), DIMENSION(:), ALLOCATABLE :: cached_value
366  REAL (rprec), DIMENSION(:), ALLOCATABLE :: gradient
367  REAL (rprec) :: evidence
368  REAL (rprec) :: new_evidence
369  REAL (rprec) :: gamma
370  REAL (rprec) :: temp_param
371  INTEGER :: status
372 
373 ! local parameters
374  REAL (rprec), PARAMETER :: gamma_init = 0.01
375 
376 ! Start of executable code
377  start_time = profiler_get_start_time()
378 
379 ! Flags will have a single bit position of a corresponding profile set.
380 !
381 ! this%flags .and. a_model%state_flags
382 !
383 ! will evaluate to true if the coresponding flag is set in the model state
384 ! flags.
385  IF (btest(a_model%state_flags, model_state_vmec_flag) .or. &
386  & btest(a_model%state_flags, model_state_siesta_flag) .or. &
387  & btest(a_model%state_flags, model_state_shift_flag) .or. &
388  & btest(a_model%state_flags, model_state_signal_flag) .or. &
389  & iand(this%flags, a_model%state_flags) .ne. 0) THEN
390 
391 ! Use gradient ascent to fine the optimal hyper parameters by maximizing the
392 ! the evidence.
393  new_evidence = gaussp_get_evidence(this, a_model)
394 
395  ALLOCATE(cached_value(SIZE(this%params)))
396  ALLOCATE(gradient(SIZE(this%params)))
397 
398  DO i = 1, SIZE(this%params)
399  cached_value(i) = param_get_value(this%params(i)%p, a_model)
400  END DO
401 
402  WRITE (this%iou,1000)
403  WRITE (this%iou,1001) new_evidence
404 
405  DO
406  evidence = new_evidence
407  gamma = gamma_init
408 
409 ! Compute gradient.
410  DO i = 1, SIZE(this%params)
411  CALL param_increment(this%params(i)%p, a_model, &
412  & mpi_comm_null, .false.)
413  gradient(i) = (gaussp_get_evidence(this, a_model) - &
414  & evidence)/this%params(i)%p%recon%delta
415  CALL param_decrement(this%params(i)%p, a_model, &
416  & mpi_comm_null)
417  END DO
418 
419 ! Step towards the maximum.
420  cached_value = cached_value + gamma*gradient
421 
422  DO i = 1, SIZE(this%params)
423  CALL param_set_value(this%params(i)%p, a_model, &
424  & cached_value(i), mpi_comm_null, &
425  & .false.)
426  END DO
427  new_evidence = gaussp_get_evidence(this, a_model)
428 
429  DO WHILE (new_evidence - evidence .lt. 0.0)
430  gamma = 0.5*gamma
431  cached_value = cached_value - gamma*gradient
432 
433  DO i = 1, SIZE(this%params)
434  CALL param_set_value(this%params(i)%p, a_model, &
435  & cached_value(i), mpi_comm_null, &
436  & .false.)
437  END DO
438  new_evidence = gaussp_get_evidence(this, a_model)
439  END DO
440 
441  WRITE (this%iou,1002) new_evidence, &
442  & new_evidence - evidence, gamma
443  FLUSH(this%iou)
444 
445  IF (abs(new_evidence - evidence) .lt. this%tolerance) THEN
446  EXIT
447  END IF
448  END DO
449 
450  WRITE (this%iou,*) 'Final Parameters', cached_value
451  FLUSH(this%iou)
452 
453  DEALLOCATE(cached_value)
454  DEALLOCATE(gradient)
455  END IF
456 
457 ! Found the maximum now set the profile. The profile is determined by sampling
458 ! the distribution.
459 !
460 ! af = K_SL * A^-1 (1)
461 !$OMP PARALLEL
462 !$OMP& DEFAULT(SHARED)
463 !$OMP DO
464 !$OMP& SCHEDULE(STATIC)
465  DO i = 1, SIZE(this%fpoints)
466  this%fpoints(i) = dot_product(this%kls(:,i), this%work(:,1))
467  END DO
468 !$OMP END DO
469 !$OMP END PARALLEL
470 
471  CALL profiler_set_stop_time('gaussp_set_profile', start_time)
472 
473 1000 FORMAT('New Process')
474 1001 FORMAT('Log Evidence : ',es12.5)
475 1002 FORMAT('Log Evidence : ',es12.5,' Change : ',es12.5, &
476  & ' Gamma : ',es12.5)
477 
478  END SUBROUTINE
479 
480 !*******************************************************************************
481 ! GETTER SUBROUTINES
482 !*******************************************************************************
483 !-------------------------------------------------------------------------------
492 !-------------------------------------------------------------------------------
493  FUNCTION gaussp_get_evidence(this, a_model)
494  USE stel_constants, ONLY: twopi
495 
496  IMPLICIT NONE
497 
498 ! Declare Arguments
499  REAL (rprec) :: gaussp_get_evidence
500  TYPE (gaussp_class), INTENT(inout) :: this
501  TYPE (model_class), POINTER :: a_model
502 
503 ! Local Variables
504  class(signal_class), POINTER :: signal_obj
505  INTEGER :: i
506  INTEGER :: j
507  REAL (rprec) :: start_time
508  INTEGER :: ierr
509 
510 ! Local Parameters
511  REAL (rprec), PARAMETER :: log2pi = log(twopi)
512 
513 ! Start of executable code
514  start_time = profiler_get_start_time()
515 
516 !$OMP PARALLEL
517 !$OMP& DEFAULT(SHARED)
518 !$OMP& PRIVATE(i,j,signal_obj)
519 
520 !$OMP DO
521 !$OMP& SCHEDULE(DYNAMIC)
522  DO j = 1, SIZE(this%fpoints)
523  DO i = 1, SIZE(this%signals)
524  signal_obj => this%signals(i)%p
525  this%kls(i,j) = signal_obj%get_gp(a_model, j, this%flags)
526  END DO
527  this%work(:,j + 1) = this%kls(:,j)
528  END DO
529 !$OMP END DO
530 
531 !$OMP DO
532 !$OMP& SCHEDULE(DYNAMIC)
533  DO j = 1, SIZE(this%signals)
534  DO i = j + 1, SIZE(this%signals)
535  signal_obj => this%signals(i)%p
536  this%kll(i,j) = signal_obj%get_gp(a_model, &
537  & this%signals(j)%p, &
538  & this%flags)
539  this%kll(j,i) = this%kll(i,j)
540  END DO
541  signal_obj => this%signals(j)%p
542  this%kll(j,j) = signal_obj%get_gp(a_model, signal_obj, &
543  & this%flags) &
544  & + signal_obj%get_sigma2() &
545  & + this%cholesky_fac
546  this%work(j,1) = signal_obj%get_observed_signal(a_model)
547  END DO
548 !$OMP END DO
549 !$OMP END PARALLEL
550 
551 ! CALL LAPACK to do Cholesky factorization of the A=(K_LL + Sigma_y)
552  CALL dpotrf('L', SIZE(this%signals), this%kll, &
553  & SIZE(this%signals), ierr)
554 
555 ! The elements in the cholesky matrix for j > i are meaning less.
556  IF (ierr .lt. 0) THEN
557  CALL err_fatal('gaussp_get_evidence: DROTRF cannot ' // &
558  & 'factor matrix. ierr is negative')
559  ELSE IF (ierr .gt. 0) THEN
560  CALL err_fatal('gaussp_get_evidence: DROTRF cannot ' // &
561  & 'factor matrix. ierr is positive. ' // &
562  & 'Try increasing gp_cholesky_fact')
563  END IF
564 
565 ! IF statement correctly throws an error if I cannot factor the matrix!
566 
567 ! CALL LAPACK to solve A^-1y and A^-1 K_LS simultaneously. The work matrix
568 ! stores the RHS on input, and returns the solution matrix.
569  CALL dpotrs('L', SIZE(this%signals), SIZE(this%fpoints) + 1, &
570  & this%kll, SIZE(this%signals), this%work, &
571  & SIZE(this%signals), ierr)
572  IF (ierr .lt. 0) THEN
573  CALL err_fatal('gaussp_get_modeled_signal: DROTRS ' // &
574  & 'failed to solve the equation')
575  END IF
576 
577 ! Calculate log evidence.
578 !
579 ! P(y|sigma,x) = 1/Sqrt((2Pi)^N*|A|)Exp(-y^T.A^-1.y/2) (1)
580 !
581 ! Taking the ln of the evidence results in.
582 !
583 ! ln(P) = -N/2*ln(2Pi) - ln(|A|)/2 - y^T.A^-1.y/2 (2)
584 !
585 ! The determinant of A is found by the Cholesky factorization and equals the
586 ! product of the squares of the diagonal.
587 !
588 ! |A| = Prod (A_ii)^2 (3)
589 !
590 ! Taking the ln of this results in.
591 !
592 ! ln(|A|) = Sum 2*ln(A_ii) (4)
593 !
594 
595  gaussp_get_evidence = -0.5*SIZE(this%signals)*log2pi
596 !$OMP PARALLEL
597 !$OMP& DEFAULT(SHARED)
598 !$OMP& PRIVATE(j)
599 !$OMP DO
600 !$OMP& REDUCTION(-:gaussp_get_evidence)
601  DO j = 1, SIZE(this%signals)
603  & - log(this%kll(j,j)) &
604  & - 0.5*signal_get_observed_signal( &
605  & this%signals(j)%p, a_model)* &
606  & this%work(j,1)
607  END DO
608 !$OMP END DO
609 !$OMP END PARALLEL
610 
611  CALL profiler_set_stop_time('gaussp_get_evidence', start_time)
612 
613  END FUNCTION
614 
615  END MODULE
guassian_process::gaussp_construct
type(gaussp_class) function, pointer gaussp_construct(a_model, n_signals, gaussp_type, profile_index, vrnc, tolerance, cholesky_fac)
Construct a gaussp_class.
Definition: guassian_process.f:103
guassian_process::gaussp_get_evidence
real(rprec) function gaussp_get_evidence(this, a_model)
Calculates the evidence.
Definition: guassian_process.f:494
guassian_process::gaussp_class_pointer
Pointer to a gaussian process object. Used for creating arrays of gaussian process pointers....
Definition: guassian_process.f:71
signal::signal_get_observed_signal
real(rprec) function signal_get_observed_signal(this, a_model)
Calculates the observed signal.
Definition: signal.f:388
v3fit_params::param_pointer
Pointer to a parameter object. Used for creating arrays of parameter pointers. This is needed because fo...
Definition: v3fit_params.f:122
v3fit_params::param_increment
subroutine param_increment(this, a_model, eq_comm, is_central)
Increments the parameter value.
Definition: v3fit_params.f:986
guassian_process::gaussp_destruct
subroutine gaussp_destruct(this)
Deconstruct a gaussp_class object.
Definition: guassian_process.f:251
data_parameters::data_max_indices
integer, parameter data_max_indices
Max number of parameter indices.
Definition: data_parameters.f:29
guassian_process
Defines the base class of the type guassian_process_class. The guassian_process contains code to comp...
Definition: guassian_process.f:11
v3fit_params::param_destruct
subroutine param_destruct(this)
Deconstruct a param_class object.
Definition: v3fit_params.f:420
v3fit_params::param_get_value
real(rprec) function param_get_value(this, a_model)
Gets the parameter value.
Definition: v3fit_params.f:652
v3fit_params::param_set_value
subroutine param_set_value(this, a_model, value, eq_comm, is_central)
Sets the parameter value.
Definition: v3fit_params.f:493
data_parameters
This modules contains parameters used by equilibrium models.
Definition: data_parameters.f:10
signal::signal_pointer
Pointer to a signal object. Used for creating arrays of signal pointers. This is needed because fortr...
Definition: signal.f:100
guassian_process::gaussp_set_signal
subroutine gaussp_set_signal(this, signal, index)
Set the object and coefficient for an index.
Definition: guassian_process.f:323
v3fit_params
Defines the base class of the type param_class.
Definition: v3fit_params.f:11
v3fit_params::param_decrement
subroutine param_decrement(this, a_model, eq_comm)
Decrements the parameter value.
Definition: v3fit_params.f:1084
pprofile_t
Defines the base class of the type pprofile_class. This module contains all the code necessary to def...
Definition: pprofile_T.f:88
signal::signal_class
Base class representing a signal.
Definition: signal.f:33
guassian_process::gaussp_set_profile
subroutine gaussp_set_profile(this, a_model)
Optimize the hyper parameters by gradient ascent on the log evidence and set the profile.
Definition: guassian_process.f:354
guassian_process::gaussp_class
Base class representing a gaussian process.
Definition: guassian_process.f:25
signal
Defines the base class of the type signal_class.
Definition: signal.f:14
v3fit_params::param_construct
Interface for the construction of param_class types using param_construct_basic or param_construct_re...
Definition: v3fit_params.f:135