
EnKF type schemes

  2EnKF type schemes
  4# External imports
  5import numpy as np
  6from scipy.linalg import solve
  7from copy import deepcopy
  8from geostat.decomp import Cholesky                     # Making realizations
 10# Internal imports
 11from pipt.loop.ensemble import Ensemble
 12# Misc. tools used in analysis schemes
 13from pipt.misc_tools import analysis_tools as at
 15from pipt.update_schemes.update_methods_ns.approx_update import approx_update
 16from pipt.update_schemes.update_methods_ns.full_update import full_update
 17from pipt.update_schemes.update_methods_ns.subspace_update import subspace_update
 20class enkfMixIn(Ensemble):
 21    """
 22    Straightforward EnKF analysis scheme implementation. The sequential updating can be done with general grouping and
 23    ordering of data. If only one-step EnKF is to be done, use `es` instead.
 24    """
 26    def __init__(self, keys_da, keys_fwd, sim):
 27        """
 28        The class is initialized by passing the PIPT init. file upwards in the hierarchy to be read and parsed in
 29        `pipt.input_output.pipt_init.ReadInitFile`.
 31        Parameters
 32        ----------
 33        init_file: str
 34            PIPT init. file containing info. to run the inversion algorithm
 36        """
 37        # Pass the init_file upwards in the hierarchy
 38        super().__init__(keys_da, keys_fwd, sim)
 40        self.prev_data_misfit = None
 42        if self.restart is False:
 43            self.prior_state = deepcopy(self.state)
 44            self.list_states = list(self.state.keys())
 45            # At the moment, the iterative loop is threated as an iterative smoother an thus we check if assim. indices
 46            # are given as in the Simultaneous loop.
 47            self.check_assimindex_sequential()
 49            # Extract no. assimilation steps from MDA keyword in DATAASSIM part of init. file and set this equal to
 50            # the number of iterations pluss one. Need one additional because the iter=0 is the prior run.
 51            self.max_iter = len(self.keys_da['assimindex'])+1
 52            self.iteration = 0
 53            self.lam = 0  # set LM lamda to zero as we are doing one full update.
 54            if 'energy' in self.keys_da:
 55                # initial energy (Remember to extract this)
 56                self.trunc_energy = self.keys_da['energy']
 57                if self.trunc_energy > 1:  # ensure that it is given as percentage
 58                    self.trunc_energy /= 100.
 59            else:
 60                self.trunc_energy = 0.98
 61            self.current_state = deepcopy(self.state)
 63            self.state_scaling = at.calc_scaling(
 64                self.prior_state, self.list_states, self.prior_info)
 66    def calc_analysis(self):
 67        """
 68        Calculate the analysis step of the EnKF procedure. The updating is done using the Kalman filter equations, using
 69        svd for numerical stability. Localization is available.
 70        """
 71        # If this is initial analysis we calculate the objective function for all data. In the final convergence check
 72        # we calculate the posterior objective function for all data
 73        if not hasattr(self, 'prior_data_misfit'):
 74            assim_index = [self.keys_da['obsname'], list(
 75                np.concatenate(self.keys_da['assimindex']))]
 76            list_datatypes, list_active_dataypes = at.get_list_data_types(
 77                self.obs_data, assim_index)
 78            if not hasattr(self, 'cov_data'):
 79                self.full_cov_data = at.gen_covdata(
 80                    self.datavar, assim_index, list_datatypes)
 81            else:
 82                self.full_cov_data = self.cov_data
 83            obs_data_vector, pred_data = at.aug_obs_pred_data(
 84                self.obs_data, self.pred_data, assim_index, list_datatypes)
 85            # Generate realizations of the observed data
 86            init_en = Cholesky()  # Initialize GeoStat class for generating realizations
 87            self.full_real_obs_data = init_en.gen_real(
 88                obs_data_vector, self.full_cov_data, self.ne)
 90            # Calc. misfit for the initial iteration
 91            data_misfit = at.calc_objectivefun(
 92                self.full_real_obs_data, pred_data, self.full_cov_data)
 94            # Store the (mean) data misfit (also for conv. check)
 95            self.data_misfit = np.mean(data_misfit)
 96            self.prior_data_misfit = np.mean(data_misfit)
 97            self.data_misfit_std = np.std(data_misfit)
 99            self.logger.info(
100                f'Prior run complete with data misfit: {self.prior_data_misfit:0.1f}.')
102        # Get assimilation order as a list
103        # must subtract one to be inline
104        self.assim_index = [self.keys_da['obsname'],
105                            self.keys_da['assimindex'][self.iteration-1]]
107        # Get list of data types to be assimilated and of the free states. Do this once, because listing keys from a
108        # Python dictionary just when needed (in different places) may not yield the same list!
109        self.list_datatypes, list_active_dataypes = at.get_list_data_types(
110            self.obs_data, self.assim_index)
112        # Augment observed and predicted data
113        self.obs_data_vector, self.aug_pred_data = at.aug_obs_pred_data(self.obs_data, self.pred_data, self.assim_index,
114                                                                        self.list_datatypes)
115        self.cov_data = at.gen_covdata(
116            self.datavar, self.assim_index, self.list_datatypes)
118        init_en = Cholesky()  # Initialize GeoStat class for generating realizations
119        self.data_random_state = deepcopy(np.random.get_state())
120        self.real_obs_data, self.scale_data = init_en.gen_real(self.obs_data_vector, self.cov_data, self.ne,
121                                                               return_chol=True)
123        self.E = np.dot(self.real_obs_data, self.proj)
125        if 'localanalysis' in self.keys_da:
126            self.local_analysis_update()
127        else:
128            # Mean pred_data and perturbation matrix with scaling
129            if len(self.scale_data.shape) == 1:
130                self.pert_preddata = np.dot(np.expand_dims(self.scale_data ** (-1), axis=1),
131                                            np.ones((1, self.ne))) * np.dot(self.aug_pred_data, self.proj)
132            else:
133                self.pert_preddata = solve(
134                    self.scale_data, np.dot(self.aug_pred_data, self.proj))
136            aug_state = at.aug_state(self.current_state, self.list_states)
137            self.update()
138            if hasattr(self, 'step'):
139                aug_state_upd = aug_state + self.step
140            if hasattr(self, 'w_step'):
141                self.W = self.current_W + self.w_step
142                aug_prior_state = at.aug_state(self.prior_state, self.list_states)
143                aug_state_upd = np.dot(aug_prior_state, (np.eye(
144                    self.ne) + self.W / np.sqrt(self.ne - 1)))
145            # Extract updated state variables from aug_update
146            self.state = at.update_state(aug_state_upd, self.state, self.list_states)
147            self.state = at.limits(self.state, self.prior_info)
149    def check_convergence(self):
150        """
151        Calculate the "convergence" of the method. Important to
152        """
153        self.prev_data_misfit = self.prior_data_misfit
154        # only calulate for the final (posterior) estimate
155        if self.iteration == len(self.keys_da['assimindex']):
156            assim_index = [self.keys_da['obsname'], list(
157                np.concatenate(self.keys_da['assimindex']))]
158            list_datatypes = self.list_datatypes
159            obs_data_vector, pred_data = at.aug_obs_pred_data(self.obs_data, self.pred_data, assim_index,
160                                                              list_datatypes)
162            data_misfit = at.calc_objectivefun(
163                self.full_real_obs_data, pred_data, self.full_cov_data)
164            self.data_misfit = np.mean(data_misfit)
165            self.data_misfit_std = np.std(data_misfit)
167        else:  # sequential updates not finished. Misfit is not relevant
168            self.data_misfit = self.prior_data_misfit
170        # Logical variables for conv. criteria
171        why_stop = {'rel_data_misfit': 1 - (self.data_misfit / self.prev_data_misfit),
172                    'data_misfit': self.data_misfit,
173                    'prev_data_misfit': self.prev_data_misfit}
175        self.current_state = deepcopy(self.state)
176        if self.data_misfit == self.prev_data_misfit:
177            self.logger.info(
178                f'EnKF update {self.iteration} complete!')
179        else:
180            if self.data_misfit < self.prior_data_misfit:
181                self.logger.info(
182                    f'EnKF update complete! Objective function decreased from {self.prior_data_misfit:0.1f} to {self.data_misfit:0.1f}.')
183            else:
184                self.logger.info(
185                    f'EnKF update complete! Objective function increased from {self.prior_data_misfit:0.1f} to {self.data_misfit:0.1f}.')
186        # Return conv = False, why_stop var.
187        return False, True, why_stop
class enkf_approx(enkfMixIn, approx_update):
    """
    EnKF scheme obtained by combining `enkfMixIn` with the standard (`approx_update`) analysis scheme.

    The class adds nothing of its own; all behaviour is inherited from its two bases.
    """
class enkf_full(enkfMixIn, approx_update):
    """
    EnKF scheme obtained by combining `enkfMixIn` with the standard (`approx_update`) analysis scheme.

    This class is only included for completeness: the EnKF does not iterate, and the standard scheme is therefore
    always applied -- which is why the base is `approx_update` even though the class is named "full".
    """
class enkf_subspace(enkfMixIn, subspace_update):
    """
    EnKF scheme obtained by combining `enkfMixIn` with the subspace (`subspace_update`) analysis scheme.

    The class adds nothing of its own; all behaviour is inherited from its two bases.
    """
