pipt.loop.assimilation

Assimilation loop for iterative ensemble-based data assimilation methods (the `Assimilate` class).

  1"""Descriptive description."""
  2
  3# External imports
  4import numpy as np
  5from tqdm import tqdm
  6from p_tqdm import p_map
  7import pickle
  8from copy import deepcopy
  9import sys
 10import os
 11from shutil import rmtree
 12import datetime as dt
 13import random
 14import psutil
 15from copy import copy
 16from importlib import import_module
 17
 18# Internal imports
 19from pipt.misc_tools import qaqc_tools
 20from pipt.loop.ensemble import Ensemble
 21from misc.system_tools.environ_var import OpenBlasSingleThread
 22from pipt.misc_tools import analysis_tools as at
 23
 24
 25class Assimilate:
 26    """
 27    Class for iterative ensemble-based methods. This loop is similar/equal to a deterministic/optimization loop, but
 28    since we use ensemble-based method, we need to invoke `pipt.fwd_sim.ensemble.Ensemble` to get correct hierarchy of
 29    classes. The iterative loop will go until the max. iterations OR convergence has been met. Parameters for both these
 30    stopping criteria have to be given by the user through methods in their `pipt.update_schemes` class. Note that only
 31    iterative ensemble smoothers can be implemented with this loop (at the moment). Methods needed to be provided by
 32    user in their update_schemes class:  
 33
 34    `calc_analysis`  
 35    `check_convergence`  
 36
 37    % Copyright (c) 2019-2022 NORCE, All Rights Reserved. 4DSEIS
 38    """
 39    # TODO: Sequential iterative loop
 40
 41    def __init__(self, ensemble: Ensemble):
 42        """
 43        Initialize by passing the PIPT init. file up the hierarchy.
 44
 45        Parameters
 46        ----------
 47        init_file: str
 48            PIPT init. filename
 49        """
 50        # Internalize ensemble and simulator class instances
 51        self.ensemble = ensemble
 52
 53        if self.ensemble.restart is False:
 54            # Default max. iter if not defined in the ensemble
 55            if hasattr(ensemble, 'max_iter'):
 56                self.max_iter = self.ensemble.max_iter
 57            else:
 58                self.max_iter = self._ext_max_iter()
 59
 60            # Within variables
 61            self.why_stop = None    # Output of why iter. loop stopped
 62
 63            self.scale_val = []  # Used to scale seismic data
 64
 65            # This feature is removed
 66            # Initialize temporary storage of state variable during the assimilation (if option is supplied in DATAASSIM
 67            # part). Save initially regardless of which option you have chosen as long as it is not 'no'
 68            # if 'tempsave' in self.ensemble.keys_da and self.ensemble.keys_da['tempsave'] != 'no':
 69            #     self.ensemble.save_temp_state_iter(0, self.max_iter)  # save init. ensemble
 70
 71    def run(self):
 72        """
 73        The general loop implemented here is:
 74
 75        <ol>
 76            <li>Forecast/forward simulation</li>
 77            <li>Check for convergence</li>
 78            <li>If convergence have not been achieved, do analysis/update</li>
 79        </ol>
 80
 81        % Copyright (c) 2019-2022 NORCE, All Rights Reserved. 4DSEIS
 82        """
 83        # TODO: Implement a 'calc_sensitivity' method in the loop. For now it is assumed that the sensitivity is
 84        # calculated in 'calc_analysis' using some kind of ensemble approximation.
 85
 86        # Init. while loop condition variable
 87        conv = False
 88        success_iter = True
 89
 90        # Initiallize progressbar
 91        pbar_out = tqdm(total=self.max_iter,
 92                        desc='Iterations (Obj. func. val: )', position=0)
 93
 94        # Check if we want to perform a Quality Assurance of the forecast
 95        qaqc = None
 96        if 'qa' in self.ensemble.sim.input_dict or 'qc' in self.ensemble.keys_da:
 97            qaqc = qaqc_tools.QAQC({**self.ensemble.keys_da, **self.ensemble.sim.input_dict},
 98                                   self.ensemble.obs_data, self.ensemble.datavar, self.ensemble.logger,
 99                                   self.ensemble.prior_info, self.ensemble.sim, self.ensemble.prior_state)
100
101        # Run a while loop until max. iterations or convergence is reached
102        while self.ensemble.iteration < self.max_iter and conv is False:
103            # Add a check to see if this is the prior model
104            if self.ensemble.iteration == 0:
105                # Calc forecast for prior model
106                # Inset 0 as input to forecast all data
107                self.calc_forecast()
108
109                # remove outliers
110                if 'remove_outliers' in self.ensemble.sim.input_dict:
111                    self.remove_outliers()
112
113                if 'qa' in self.ensemble.keys_da:  # Check if we want to perform a Quality Assurance of the forecast
114                    # set updated prediction, state and lam
115                    qaqc.set(self.ensemble.pred_data,
116                             self.ensemble.state, self.ensemble.lam)
117                    # Level 1,2 all data, and subspace
118                    qaqc.calc_mahalanobis((1, 'time', 2, 'time', 1, None, 2, None))
119                    qaqc.calc_coverage()  # Compute data coverage
120                    qaqc.calc_kg({'plot_all_kg': True, 'only_log': False,
121                                 'num_store': 5})  # Compute kalman gain
122
123                success_iter = True
124
125                # always store prior forcast, unless specifically told not to
126                if 'nosave' not in self.ensemble.keys_da:
127                    np.savez('prior_forecast.npz', **
128                             {'pred_data': self.ensemble.pred_data})
129
130            # For the remaining iterations we start by applying the analysis and finish by running the forecast
131            else:
132                # Analysis (in the update_scheme class)
133                self.ensemble.calc_analysis()
134
135                if 'qa' in self.ensemble.keys_da and 'screendata' in self.ensemble.keys_da and \
136                        self.ensemble.keys_da['screendata'] == 'yes' and self.ensemble.iteration == 1:
137                    #  need to update datavar, and recompute mahalanobis measures
138                    self.logger.info(
139                        'Recomputing Mahalanobis distance with updated datavar')
140                    qaqc.datavar = self.datavar  # this is updated from calc_analysis
141                    # Level 1,2 all data, and subspace
142                    qaqc.calc_mahalanobis((1, 'time', 2, 'time', 1, None, 2, None))
143
144                # Forecast with the updated state
145                self.calc_forecast()
146
147                if 'remove_outliers' in self.ensemble.keys_da:
148                    self.remove_outliers()
149
150                # Check convergence (in the update_scheme class). Outputs logical variable to tell the while loop to
151                # stop, and a variable telling what criteria for convergence was reached.
152                # Also check if the objective function has been reduced, and use this function to accept the state and
153                # update the lambda values.
154                #
155                conv, success_iter, self.why_stop = self.ensemble.check_convergence()
156
157            # if reduction of objective function -> save the state
158            if success_iter:
159                # More general method to save all relevant information from an iteration analysis/forecast step
160                if 'iterinfo' in self.ensemble.keys_da:
161                    #
162                    self._save_iteration_information()
163                if self.ensemble.iteration > 0:
164                    # Temporary save state if options in TEMPSAVE have been given and the option is not 'no'
165                    if 'tempsave' in self.ensemble.keys_da and self.ensemble.keys_da['tempsave'] != 'no':
166                        self._save_during_iteration(self.ensemble.keys_da['tempsave'])
167                    if 'analysisdebug' in self.ensemble.keys_da:
168                        self._save_analysis_debug()
169                    if 'qc' in self.ensemble.keys_da:  # Check if we want to perform a Quality Control of the updated state
170                        # set updated prediction, state and lam
171                        qaqc.set(self.ensemble.pred_data,
172                                 self.ensemble.state, self.ensemble.lam)
173                        qaqc.calc_da_stat()  # Compute statistics for updated parameters
174                    if 'qa' in self.ensemble.keys_da:  # Check if we want to perform a Quality Assurance of the forecast
175                        # set updated prediction, state and lam
176                        qaqc.set(self.ensemble.pred_data,
177                                 self.ensemble.state, self.ensemble.lam)
178                        qaqc.calc_mahalanobis(
179                            (1, 'time', 2, 'time', 1, None, 2, None))  # Level 1,2 all data, and subspace
180                        #  qaqc.calc_coverage()  # Compute data coverage
181                        qaqc.calc_kg()  # Compute kalman gain
182
183            # Update iteration counter if iteration was successful
184            if self.ensemble.iteration >= 0 and success_iter is True:
185                if self.ensemble.iteration == 0:
186                    self.ensemble.iteration += 1
187                    pbar_out.update(1)
188                    # pbar_out.set_description(f'Iterations (Obj. func. val:{self.data_misfit:.1f})')
189                    # self.prior_data_misfit = self.data_misfit
190                    # self.pbar_out.refresh()
191                else:
192                    self.ensemble.iteration += 1
193                    pbar_out.update(1)
194                    pbar_out.set_description(
195                        f'Iterations (Obj. func. val:{self.ensemble.data_misfit:.1f}'
196                        f' Reduced: {100 * (1 - (self.ensemble.data_misfit / self.ensemble.prev_data_misfit)):.0f} %)')
197                    # self.pbar_out.refresh()
198
199            if 'restartsave' in self.ensemble.keys_da and self.ensemble.keys_da['restartsave'] == 'yes':
200                self.ensemble.save()
201
202        # always store posterior forcast and state, unless specifically told not to
203        if 'nosave' not in self.ensemble.keys_da:
204            try: # first try to save as npz file
205                np.savez('posterior_state_estimate.npz', **self.ensemble.state)
206                np.savez('posterior_forecast.npz', **{'pred_data': self.ensemble.pred_data})
207            except: # If this fails, store as pickle
208                with open('posterior_state_estimate.p', 'wb') as file:
209                    pickle.dump(self.ensemble.state, file)
210                with open('posterior_forecast.p', 'wb') as file:
211                    pickle.dump(self.ensemble.pred_data, file)
212
213        # If none of the convergence criteria were met, max. iteration was the reason iterations stopped.
214        if conv is False:
215            reason = 'Iterations stopped due to max iterations reached!'
216        else:
217            reason = 'Convergence was met :)'
218
219        # Save why_stop in Numpy save file
220        # savez('why_iter_loop_stopped', why=self.why_stop, conv_string=reason)
221
222        # Save why_stop in pickle save file
223        why = self.why_stop
224        if why is not None:
225            why['conv_string'] = reason
226        with open('why_iter_loop_stopped.p', 'wb') as f:
227            pickle.dump(why, f, protocol=4)
228        # pbar.close()
229        pbar_out.close()
230        if self.ensemble.prev_data_misfit is not None:
231            out_str = 'Convergence was met.'
232            if self.ensemble.prior_data_misfit > self.ensemble.data_misfit:
233                out_str += f' Obj. function reduced from {self.ensemble.prior_data_misfit:0.1f} ' \
234                           f'to {self.ensemble.data_misfit:0.1f}'
235            tqdm.write(out_str)
236            self.ensemble.logger.info(out_str)
237
238    def remove_outliers(self):
239
240        # function to remove ouliers
241
242        # get the cov data
243        prod_obs = np.array([])
244
245        prod_cov = np.array([])
246        prod_pred = np.empty([0, self.ensemble.ne])
247        for i in range(len(self.ensemble.obs_data)):
248            for key in self.ensemble.obs_data[i].keys():
249                if self.ensemble.obs_data[i][key] is not None and self.ensemble.obs_data[i][key].shape == (1,):
250                    prod_obs = np.concatenate((prod_obs, self.ensemble.obs_data[i][key]))
251                    prod_cov = np.concatenate((prod_cov, self.ensemble.datavar[i][key]))
252                    prod_pred = np.concatenate(
253                        (prod_pred, self.ensemble.pred_data[i][key]))
254
255        mat_prod_obs = np.dot(prod_obs.reshape((len(prod_obs), 1)),
256                              np.ones((1, self.ensemble.ne)))
257
258        hm = np.diag(np.dot((prod_pred - mat_prod_obs).T, np.dot(np.expand_dims(prod_cov ** (-1), axis=1),
259                                                                 np.ones((1, self.ensemble.ne))) * (prod_pred - mat_prod_obs)))
260        hm_std = np.std(hm)
261        hm_mean = np.mean(hm)
262        outliers = np.argwhere(np.abs(hm - hm_mean) > 4 * hm_std)
263        print('Outliers: ' + str(np.squeeze(outliers)))
264        members = np.arange(self.ensemble.ne)
265        members = np.delete(members, outliers)
266        for index in outliers.flatten():
267
268            new_index = np.random.choice(members)
269
270            # replace state
271            for el in self.ensemble.state.keys():
272                self.ensemble.state[el][:, index] = deepcopy(
273                    self.ensemble.state[el][:, new_index])
274
275            # replace the failed forecast
276            for i, data_ind in enumerate(self.ensemble.pred_data):
277                if self.ensemble.pred_data[i] is not None:
278                    for el in data_ind.keys():
279                        if self.ensemble.pred_data[i][el] is not None:
280                            if type(self.ensemble.pred_data[i][el]) is list:
281                                self.ensemble.pred_data[i][el][index] = deepcopy(
282                                    self.ensemble.pred_data[i][el][new_index])
283                            else:
284                                self.ensemble.pred_data[i][el][:, index] = deepcopy(
285                                    self.ensemble.pred_data[i][el][:, new_index])
286
287    def _ext_max_iter(self):
288        """
289        Extract max iterations from ITERATION keyword in DATAASSIM part (mandatory keyword for iteration loops).
290
291        Parameters
292        ----------
293        keys_da : dict
294            A dictionary containing all keywords from DATAASSIM part.
295            - 'iteration' : object
296                Information for iterative methods.
297
298        Returns
299        -------
300        max_iter : int
301            The maximum number of iterations allowed before abort.
302
303        Changelog
304        ---------
305        - ST 7/6-16
306        """
307        if 'iteration' in self.ensemble.keys_da:
308            # Make sure ITERATION is a list
309            if not isinstance(self.ensemble.keys_da['iteration'][0], list):
310                iter_opts = [self.ensemble.keys_da['iteration']]
311            else:
312                iter_opts = self.ensemble.keys_da['iteration']
313
314            # Check if 'max_iter' has been given; if not, give error (mandatory in ITERATION)
315            assert 'max_iter' in list(
316                zip(*iter_opts))[0], 'MAX_ITER has not been given in ITERATION!'
317
318            # Extract max. iter
319            max_iter = [item[1] for item in iter_opts if item[0] == 'max_iter'][0]
320
321        elif 'mda' in self.ensemble.keys_da:
322            # Make sure ITERATION is a list
323            if not isinstance(self.ensemble.keys_da['mda'][0], list):
324                iter_opts = [self.ensemble.keys_da['mda']]
325            else:
326                iter_opts = self.ensemble.keys_da['mda']
327
328            # Check if 'max_iter' has been given; if not, give error (mandatory in ITERATION)
329            assert 'tot_assim_steps' in list(
330                zip(*iter_opts))[0], 'TOT_ASSIM_STEPS has not been given in MDA!'
331
332            # Extract max. iter
333            max_iter = [item[1] for item in iter_opts if item[0] == 'tot_assim_steps'][0]
334
335        else:
336            max_iter = 1
337        # Return max. iter
338        return max_iter
339
340    def _save_iteration_information(self):
341        """
342        More general method for saving all relevant information from a analysis/forecast step. Note that this is
343        only performed when there is a reduction in objective function.
344
345        Parameters
346        ----------
347        values : list
348            List of values to be saved. It can also contain a separate Python file.
349
350        If one reads a python file, it is
351        """
352        # Make sure "ANALYSISDEBUG" gives a list
353        if isinstance(self.ensemble.keys_da['iterinfo'], list):
354            saveinfo = self.ensemble.keys_da['iterinfo']
355        else:
356            saveinfo = [self.ensemble.keys_da['iterinfo']]
357
358        for el in saveinfo:
359            if '.py' in el:  # This is a unique python file
360                iter_info_func = import_module(el.strip('.py'))
361                # Note: the function must be named main, and we pass the full current instance of the object.
362                iter_info_func.main(self)
363
364    def _save_during_iteration(self, tempsave):
365        """
366        Save during an iteration. How often is determined by the `TEMPSAVE` keyword; confer the manual for all the
367        different options.
368
369        Parameters
370        ----------
371        tempsave: list
372            Info. from the TEMPSAVE keyword
373        """
374        self.ensemble.logger.info(
375            'The TEMPSAVE feature is no longer supported. Please you debug_analyses, or iterinfo.')
376        # Save at specific points
377        # if isinstance(tempsave, list):
378        #     # Save at regular intervals
379        #     if tempsave[0] == 'each' or tempsave[0] == 'every' and self.ensemble.iteration % tempsave[1] == 0:
380        #         self.ensemble.save_temp_state_iter(self.ensemble.iteration + 1, self.max_iter)
381        #
382        #     # Save at points given by input
383        #     elif tempsave[0] == 'list' or tempsave[0] == 'at':
384        #         # Check if one or more save points have been given, and save if we are at that point
385        #         savepoint = tempsave[1] if isinstance(tempsave[1], list) else [tempsave[1]]
386        #         if self.ensemble.iteration in savepoint:
387        #             self.ensemble.save_temp_state_iter(self.ensemble.iteration + 1, self.max_iter)
388        #
389        # # Save at all assimilation steps
390        # elif tempsave == 'yes' or tempsave == 'all':
391        #     self.ensemble.save_temp_state_iter(self.ensemble.iteration + 1, self.max_iter)
392
393    def _save_analysis_debug(self):
394        """
395        Moved Old analysis debug here to retain consistency.
396
397        .. danger:: only class variables can be stored now.
398        """
399        # Init dict. of variables to save
400        save_dict = {}
401
402        # Make sure "ANALYSISDEBUG" gives a list
403        if isinstance(self.ensemble.keys_da['analysisdebug'], list):
404            analysisdebug = self.ensemble.keys_da['analysisdebug']
405        else:
406            analysisdebug = [self.ensemble.keys_da['analysisdebug']]
407
408        # Loop over variables to store in save list
409        for save_typ in analysisdebug:
410            if hasattr(self, save_typ):
411                save_dict[save_typ] = eval('self.{}'.format(save_typ))
412            elif hasattr(self.ensemble, save_typ):
413                save_dict[save_typ] = eval('self.ensemble.{}'.format(save_typ))
414            # Save with key equal variable name and the actual variable
415            else:
416                print(f'Cannot save {save_typ}, because it is a local variable!\n\n')
417
418        # Save the variables
419        at.save_analysisdebug(self.ensemble.iteration, **save_dict)
420
421    def calc_forecast(self):
422        """
423        Calculate the forecast step.
424
425        Run the forward simulator, generating predicted data for the analysis step. First input to the simulator
426        instances is the ensemble of (joint) state to be run and how many to run in parallel. The forward runs are done
427        in a while-loop consisting of the following steps:
428
429                1. Run the simulator for each ensemble member in the background.
430                2. Check for errors during run (if error, correct and run again or abort).
431                3. Check if simulation has ended; if yes, run simulation for the next ensemble members.
432                4. Get results from successfully ended simulations.
433
434        The procedure here is general, hence a simulator used here must contain the initial step of setting up the
435        parameters and steps i-iv, if not an error will be outputted. Initialization of the simulator is done when
436        initializing the Ensemble class (see __init__). The names of the mandatory methods in a simulator are:
437
438                > setup_fwd_sim
439                > run_fwd_sim
440                > check_sim_end
441                > get_sim_results
442
443        Parameters
444        ----------
445        assim_step : int
446                     Current assimilation step.
447
448        Notes
449        -----
450        Parallel run in "ampersand" mode means that it will be started in the background and run independently of the
451        Python script. Hence, check for simulation finished or error must be conducted!
452
453        .. info:: It is only necessary to get the results from the forward simulations that corresponds to the observed
454        data at the particular assimilation step. That is, results from all data types are not necessary to
455        extract at step iv; if they are not present in the obs_data (indicated by a None type) then this result does
456        not need to be extracted.
457
458        .. info:: It is assumed that no underscore is inputted in DATATYPE. If there are underscores in DATATYPE
459        entries, well, then we may have a problem when finding out which response to extract in get_sim_results below.
460        """
461        # Add an option to load existing sim results. The user must actively create the restart file by renaming an
462        # existing sim_results.p file to restart_sim_results.p.
463        if os.path.exists('restart_sim_results.p'):
464            with open('restart_sim_results.p', 'rb') as f:
465                self.ensemble.pred_data = pickle.load(f)
466            os.rename('restart_sim_results.p', 'sim_results.p')
467            print('--- Restart sim results used ---')
468            return
469
470        # If we are doing an sequential assimilation, such as enkf, we loop over assimilation steps
471        if len(self.ensemble.keys_da['assimindex']) > 1:
472            assim_step = self.ensemble.iteration
473        else:
474            assim_step = 0
475
476        # Get assimilation order as a list where first entry are the string(s) in OBSNAME and second entry are
477        # the associated array(s)
478        if assim_step == 0 or assim_step == len(self.ensemble.keys_da['assimindex']):
479            assim_ind = [self.ensemble.keys_da['obsname'], list(
480                np.concatenate(self.ensemble.keys_da['assimindex']))]
481        else:
482            assim_ind = [self.ensemble.keys_da['obsname'],
483                         self.ensemble.keys_da['assimindex'][assim_step]]
484
485        # Get TRUEDATAINDEX
486        true_order = [self.ensemble.keys_da['obsname'],
487                      self.ensemble.keys_da['truedataindex']]
488
489        # List assim. index
490        if isinstance(true_order[1], list):  # Check if true data prim. ind. is a list
491            true_prim = [true_order[0], [x for x in true_order[1]]]
492        else:  # Float
493            true_prim = [true_order[0], [true_order[1]]]
494        if isinstance(assim_ind[1], list):  # Check if prim. ind. is a list
495            l_prim = [int(x) for x in assim_ind[1]]
496        else:  # Float
497            l_prim = [int(assim_ind[1])]
498
499        # Run forecast. Predicted data solved in self.ensemble.pred_data
500        self.ensemble.calc_prediction()
501
502        # Filter pred. data needed at current assimilation step. This essentially means deleting pred. data not
503        # contained in the assim. indices for current assim. step or does not have obs. data at this index
504        self.ensemble.pred_data = [elem for i, elem in enumerate(self.ensemble.pred_data) if i in l_prim or
505                                   true_prim[1][i] is not None]
506
507        # Scale data if required (currently only one group of data can be scaled)
508        if 'scale' in self.ensemble.keys_da:
509            for pred_data in self.ensemble.pred_data:
510                for key in pred_data:
511                    if key in self.ensemble.keys_da['scale'][0]:
512                        pred_data[key] *= self.ensemble.keys_da['scale'][1]
513
514        # Post process predicted data if wanted
515        if 'post_process_forecast' in self.ensemble.keys_da and self.ensemble.keys_da['post_process_forecast'] == 'yes':
516            self.post_process_forecast()
517
518        # If we have dynamic variables, and we are in the first assimilation step, we must convert lists to (2D)
519        # numpy arrays
520        if 'dynamicvar' in self.ensemble.keys_da and assim_step == 0:
521            for dyn_state in self.ensemble.keys_da['dynamicvar']:
522                self.ensemble.state[dyn_state] = np.array(
523                    self.ensemble.state[dyn_state]).T
524
525        # Extra option debug
526        if 'saveforecast' in self.ensemble.sim.input_dict:
527            with open('sim_results.p', 'wb') as f:
528                pickle.dump(self.ensemble.pred_data, f)
529
    def post_process_forecast(self):
        """
        Post processing of predicted data after a forecast run.

        Works in place on `self.ensemble.pred_data` and performs, in order:

        1. If `self.ensemble.sparse_info` is not None, collect the predicted data whose first dimension matches
           the number of active entries in a `sparse_info['mask']` entry into a temporary structure.
        2. If the file `scale_results.p` exists, divide all data types with 'sim2seis' in their name by a
           scaling value. The value is computed once (mean of the first entry of the pickled content) and
           cached in `self.scale_val`.
        3. If `self.ensemble.sparse_info` is truthy, replace the collected predicted data by the output of
           `self.ensemble.compress`, member by member.
        4. If 'saveforecast' is in the simulator input and `self.ensemble.sparse_data` is truthy, transpose the
           entries of `self.ensemble.data_rec` and pickle them to `rec_results.p`.
        """
        # Temporary storage of seismic data that need to be scaled; one slot (None or dict) per pred_data entry
        pred_data_tmp = [None for _ in self.ensemble.pred_data]

        # Loop over pred data and store temporary
        if self.ensemble.sparse_info is not None:
            for i, pred_data in enumerate(self.ensemble.pred_data):
                for key in pred_data:
                    # Reset vintage
                    # NOTE(review): vintage is reset for every key, so the test below always compares against
                    # mask[0] and the increment at the end has no lasting effect. The compression loop further
                    # down initializes vintage once, outside both loops - confirm which behaviour is intended.
                    vintage = 0

                    # Store according to sparse_info
                    if vintage < len(self.ensemble.sparse_info['mask']) and \
                            pred_data[key].shape[0] == int(np.sum(self.ensemble.sparse_info['mask'][vintage])):

                        # If first entry in pred_data_tmp
                        if pred_data_tmp[i] is None:
                            pred_data_tmp[i] = {key: pred_data[key]}

                        else:
                            pred_data_tmp[i][key] = pred_data[key]

                        # Update vintage
                        vintage += 1

        # Scaling used in sim2seis
        if os.path.exists('scale_results.p'):
            # self.scale_val starts out as an empty list, so the file is only read while no value is cached
            if not self.scale_val:
                with open('scale_results.p', 'rb') as f:
                    scale = pickle.load(f)
                # base the scaling on the first dataset and the first iteration
                self.scale_val = np.sum(scale[0]) / len(scale[0])

            if self.ensemble.sparse_info is not None:
                # Scale the temporary copies; they are compressed into pred_data below
                for i in range(len(pred_data_tmp)):  # INDEX
                    if pred_data_tmp[i] is not None:
                        for k in pred_data_tmp[i]:  # DATATYPE
                            if 'sim2seis' in k and pred_data_tmp[i][k] is not None:
                                pred_data_tmp[i][k] = pred_data_tmp[i][k] / self.scale_val

            else:
                # No sparse representation: scale the predicted data directly
                for i in range(len(self.ensemble.pred_data)):  # TRUEDATAINDEX
                    for k in self.ensemble.pred_data[i]:  # DATATYPE
                        if 'sim2seis' in k and self.ensemble.pred_data[i][k] is not None:
                            self.ensemble.pred_data[i][k] = self.ensemble.pred_data[i][k] / \
                                self.scale_val

        # If wavelet compression is based on the simulated data, we need to recompute obs_data, datavar and pred_data.
        if self.ensemble.sparse_info:
            vintage = 0
            self.ensemble.data_rec = []
            for i in range(len(pred_data_tmp)):  # INDEX
                if pred_data_tmp[i] is not None:
                    for k in pred_data_tmp[i]:  # DATATYPE
                        if vintage < len(self.ensemble.sparse_info['mask']) and \
                                len(pred_data_tmp[i][k]) == int(np.sum(self.ensemble.sparse_info['mask'][vintage])):
                            # Compressed data has one row per observation and one column per ensemble member
                            self.ensemble.pred_data[i][k] = np.zeros(
                                (len(self.ensemble.obs_data[i][k]), self.ensemble.ne))
                            # Compress member by member (column by column)
                            for m in range(pred_data_tmp[i][k].shape[1]):
                                data_array = self.ensemble.compress(pred_data_tmp[i][k][:, m], vintage,
                                                                    self.ensemble.sparse_info['use_ensemble'])
                                self.ensemble.pred_data[i][k][:, m] = data_array
                            vintage = vintage + 1
            # The no-argument compress call is made only while 'use_ensemble' is truthy; the flag is cleared
            # right after, so later calls of this method skip it
            if self.ensemble.sparse_info['use_ensemble']:
                self.ensemble.compress()
                self.ensemble.sparse_info['use_ensemble'] = None

        # Extra option debug
        if 'saveforecast' in self.ensemble.sim.input_dict:
            # Save the reconstructed signal for later analysis
            # NOTE(review): this tests `sparse_data` while the rest of the method uses `sparse_info`, and
            # `data_rec` is only (re)created above when `sparse_info` is truthy - confirm the attribute name.
            if self.ensemble.sparse_data:
                for vint in np.arange(len(self.ensemble.data_rec)):
                    self.ensemble.data_rec[vint] = np.asarray(
                        self.ensemble.data_rec[vint]).T
                with open('rec_results.p', 'wb') as f:
                    pickle.dump(self.ensemble.data_rec, f)
class Assimilate:
 26class Assimilate:
 27    """
 28    Class for iterative ensemble-based methods. This loop is similar/equal to a deterministic/optimization loop, but
 29    since we use ensemble-based method, we need to invoke `pipt.fwd_sim.ensemble.Ensemble` to get correct hierarchy of
 30    classes. The iterative loop will go until the max. iterations OR convergence has been met. Parameters for both these
 31    stopping criteria have to be given by the user through methods in their `pipt.update_schemes` class. Note that only
 32    iterative ensemble smoothers can be implemented with this loop (at the moment). Methods needed to be provided by
 33    user in their update_schemes class:  
 34
 35    `calc_analysis`  
 36    `check_convergence`  
 37
 38    % Copyright (c) 2019-2022 NORCE, All Rights Reserved. 4DSEIS
 39    """
 40    # TODO: Sequential iterative loop
 41
 42    def __init__(self, ensemble: Ensemble):
 43        """
 44        Initialize by passing the PIPT init. file up the hierarchy.
 45
 46        Parameters
 47        ----------
 48        init_file: str
 49            PIPT init. filename
 50        """
 51        # Internalize ensemble and simulator class instances
 52        self.ensemble = ensemble
 53
 54        if self.ensemble.restart is False:
 55            # Default max. iter if not defined in the ensemble
 56            if hasattr(ensemble, 'max_iter'):
 57                self.max_iter = self.ensemble.max_iter
 58            else:
 59                self.max_iter = self._ext_max_iter()
 60
 61            # Within variables
 62            self.why_stop = None    # Output of why iter. loop stopped
 63
 64            self.scale_val = []  # Used to scale seismic data
 65
 66            # This feature is removed
 67            # Initialize temporary storage of state variable during the assimilation (if option is supplied in DATAASSIM
 68            # part). Save initially regardless of which option you have chosen as long as it is not 'no'
 69            # if 'tempsave' in self.ensemble.keys_da and self.ensemble.keys_da['tempsave'] != 'no':
 70            #     self.ensemble.save_temp_state_iter(0, self.max_iter)  # save init. ensemble
 71
 72    def run(self):
 73        """
 74        The general loop implemented here is:
 75
 76        <ol>
 77            <li>Forecast/forward simulation</li>
 78            <li>Check for convergence</li>
 79            <li>If convergence have not been achieved, do analysis/update</li>
 80        </ol>
 81
 82        % Copyright (c) 2019-2022 NORCE, All Rights Reserved. 4DSEIS
 83        """
 84        # TODO: Implement a 'calc_sensitivity' method in the loop. For now it is assumed that the sensitivity is
 85        # calculated in 'calc_analysis' using some kind of ensemble approximation.
 86
 87        # Init. while loop condition variable
 88        conv = False
 89        success_iter = True
 90
 91        # Initiallize progressbar
 92        pbar_out = tqdm(total=self.max_iter,
 93                        desc='Iterations (Obj. func. val: )', position=0)
 94
 95        # Check if we want to perform a Quality Assurance of the forecast
 96        qaqc = None
 97        if 'qa' in self.ensemble.sim.input_dict or 'qc' in self.ensemble.keys_da:
 98            qaqc = qaqc_tools.QAQC({**self.ensemble.keys_da, **self.ensemble.sim.input_dict},
 99                                   self.ensemble.obs_data, self.ensemble.datavar, self.ensemble.logger,
100                                   self.ensemble.prior_info, self.ensemble.sim, self.ensemble.prior_state)
101
102        # Run a while loop until max. iterations or convergence is reached
103        while self.ensemble.iteration < self.max_iter and conv is False:
104            # Add a check to see if this is the prior model
105            if self.ensemble.iteration == 0:
106                # Calc forecast for prior model
107                # Inset 0 as input to forecast all data
108                self.calc_forecast()
109
110                # remove outliers
111                if 'remove_outliers' in self.ensemble.sim.input_dict:
112                    self.remove_outliers()
113
114                if 'qa' in self.ensemble.keys_da:  # Check if we want to perform a Quality Assurance of the forecast
115                    # set updated prediction, state and lam
116                    qaqc.set(self.ensemble.pred_data,
117                             self.ensemble.state, self.ensemble.lam)
118                    # Level 1,2 all data, and subspace
119                    qaqc.calc_mahalanobis((1, 'time', 2, 'time', 1, None, 2, None))
120                    qaqc.calc_coverage()  # Compute data coverage
121                    qaqc.calc_kg({'plot_all_kg': True, 'only_log': False,
122                                 'num_store': 5})  # Compute kalman gain
123
124                success_iter = True
125
126                # always store prior forcast, unless specifically told not to
127                if 'nosave' not in self.ensemble.keys_da:
128                    np.savez('prior_forecast.npz', **
129                             {'pred_data': self.ensemble.pred_data})
130
131            # For the remaining iterations we start by applying the analysis and finish by running the forecast
132            else:
133                # Analysis (in the update_scheme class)
134                self.ensemble.calc_analysis()
135
136                if 'qa' in self.ensemble.keys_da and 'screendata' in self.ensemble.keys_da and \
137                        self.ensemble.keys_da['screendata'] == 'yes' and self.ensemble.iteration == 1:
138                    #  need to update datavar, and recompute mahalanobis measures
139                    self.logger.info(
140                        'Recomputing Mahalanobis distance with updated datavar')
141                    qaqc.datavar = self.datavar  # this is updated from calc_analysis
142                    # Level 1,2 all data, and subspace
143                    qaqc.calc_mahalanobis((1, 'time', 2, 'time', 1, None, 2, None))
144
145                # Forecast with the updated state
146                self.calc_forecast()
147
148                if 'remove_outliers' in self.ensemble.keys_da:
149                    self.remove_outliers()
150
151                # Check convergence (in the update_scheme class). Outputs logical variable to tell the while loop to
152                # stop, and a variable telling what criteria for convergence was reached.
153                # Also check if the objective function has been reduced, and use this function to accept the state and
154                # update the lambda values.
155                #
156                conv, success_iter, self.why_stop = self.ensemble.check_convergence()
157
158            # if reduction of objective function -> save the state
159            if success_iter:
160                # More general method to save all relevant information from an iteration analysis/forecast step
161                if 'iterinfo' in self.ensemble.keys_da:
162                    #
163                    self._save_iteration_information()
164                if self.ensemble.iteration > 0:
165                    # Temporary save state if options in TEMPSAVE have been given and the option is not 'no'
166                    if 'tempsave' in self.ensemble.keys_da and self.ensemble.keys_da['tempsave'] != 'no':
167                        self._save_during_iteration(self.ensemble.keys_da['tempsave'])
168                    if 'analysisdebug' in self.ensemble.keys_da:
169                        self._save_analysis_debug()
170                    if 'qc' in self.ensemble.keys_da:  # Check if we want to perform a Quality Control of the updated state
171                        # set updated prediction, state and lam
172                        qaqc.set(self.ensemble.pred_data,
173                                 self.ensemble.state, self.ensemble.lam)
174                        qaqc.calc_da_stat()  # Compute statistics for updated parameters
175                    if 'qa' in self.ensemble.keys_da:  # Check if we want to perform a Quality Assurance of the forecast
176                        # set updated prediction, state and lam
177                        qaqc.set(self.ensemble.pred_data,
178                                 self.ensemble.state, self.ensemble.lam)
179                        qaqc.calc_mahalanobis(
180                            (1, 'time', 2, 'time', 1, None, 2, None))  # Level 1,2 all data, and subspace
181                        #  qaqc.calc_coverage()  # Compute data coverage
182                        qaqc.calc_kg()  # Compute kalman gain
183
184            # Update iteration counter if iteration was successful
185            if self.ensemble.iteration >= 0 and success_iter is True:
186                if self.ensemble.iteration == 0:
187                    self.ensemble.iteration += 1
188                    pbar_out.update(1)
189                    # pbar_out.set_description(f'Iterations (Obj. func. val:{self.data_misfit:.1f})')
190                    # self.prior_data_misfit = self.data_misfit
191                    # self.pbar_out.refresh()
192                else:
193                    self.ensemble.iteration += 1
194                    pbar_out.update(1)
195                    pbar_out.set_description(
196                        f'Iterations (Obj. func. val:{self.ensemble.data_misfit:.1f}'
197                        f' Reduced: {100 * (1 - (self.ensemble.data_misfit / self.ensemble.prev_data_misfit)):.0f} %)')
198                    # self.pbar_out.refresh()
199
200            if 'restartsave' in self.ensemble.keys_da and self.ensemble.keys_da['restartsave'] == 'yes':
201                self.ensemble.save()
202
203        # always store posterior forcast and state, unless specifically told not to
204        if 'nosave' not in self.ensemble.keys_da:
205            try: # first try to save as npz file
206                np.savez('posterior_state_estimate.npz', **self.ensemble.state)
207                np.savez('posterior_forecast.npz', **{'pred_data': self.ensemble.pred_data})
208            except: # If this fails, store as pickle
209                with open('posterior_state_estimate.p', 'wb') as file:
210                    pickle.dump(self.ensemble.state, file)
211                with open('posterior_forecast.p', 'wb') as file:
212                    pickle.dump(self.ensemble.pred_data, file)
213
214        # If none of the convergence criteria were met, max. iteration was the reason iterations stopped.
215        if conv is False:
216            reason = 'Iterations stopped due to max iterations reached!'
217        else:
218            reason = 'Convergence was met :)'
219
220        # Save why_stop in Numpy save file
221        # savez('why_iter_loop_stopped', why=self.why_stop, conv_string=reason)
222
223        # Save why_stop in pickle save file
224        why = self.why_stop
225        if why is not None:
226            why['conv_string'] = reason
227        with open('why_iter_loop_stopped.p', 'wb') as f:
228            pickle.dump(why, f, protocol=4)
229        # pbar.close()
230        pbar_out.close()
231        if self.ensemble.prev_data_misfit is not None:
232            out_str = 'Convergence was met.'
233            if self.ensemble.prior_data_misfit > self.ensemble.data_misfit:
234                out_str += f' Obj. function reduced from {self.ensemble.prior_data_misfit:0.1f} ' \
235                           f'to {self.ensemble.data_misfit:0.1f}'
236            tqdm.write(out_str)
237            self.ensemble.logger.info(out_str)
238
239    def remove_outliers(self):
240
241        # function to remove ouliers
242
243        # get the cov data
244        prod_obs = np.array([])
245
246        prod_cov = np.array([])
247        prod_pred = np.empty([0, self.ensemble.ne])
248        for i in range(len(self.ensemble.obs_data)):
249            for key in self.ensemble.obs_data[i].keys():
250                if self.ensemble.obs_data[i][key] is not None and self.ensemble.obs_data[i][key].shape == (1,):
251                    prod_obs = np.concatenate((prod_obs, self.ensemble.obs_data[i][key]))
252                    prod_cov = np.concatenate((prod_cov, self.ensemble.datavar[i][key]))
253                    prod_pred = np.concatenate(
254                        (prod_pred, self.ensemble.pred_data[i][key]))
255
256        mat_prod_obs = np.dot(prod_obs.reshape((len(prod_obs), 1)),
257                              np.ones((1, self.ensemble.ne)))
258
259        hm = np.diag(np.dot((prod_pred - mat_prod_obs).T, np.dot(np.expand_dims(prod_cov ** (-1), axis=1),
260                                                                 np.ones((1, self.ensemble.ne))) * (prod_pred - mat_prod_obs)))
261        hm_std = np.std(hm)
262        hm_mean = np.mean(hm)
263        outliers = np.argwhere(np.abs(hm - hm_mean) > 4 * hm_std)
264        print('Outliers: ' + str(np.squeeze(outliers)))
265        members = np.arange(self.ensemble.ne)
266        members = np.delete(members, outliers)
267        for index in outliers.flatten():
268
269            new_index = np.random.choice(members)
270
271            # replace state
272            for el in self.ensemble.state.keys():
273                self.ensemble.state[el][:, index] = deepcopy(
274                    self.ensemble.state[el][:, new_index])
275
276            # replace the failed forecast
277            for i, data_ind in enumerate(self.ensemble.pred_data):
278                if self.ensemble.pred_data[i] is not None:
279                    for el in data_ind.keys():
280                        if self.ensemble.pred_data[i][el] is not None:
281                            if type(self.ensemble.pred_data[i][el]) is list:
282                                self.ensemble.pred_data[i][el][index] = deepcopy(
283                                    self.ensemble.pred_data[i][el][new_index])
284                            else:
285                                self.ensemble.pred_data[i][el][:, index] = deepcopy(
286                                    self.ensemble.pred_data[i][el][:, new_index])
287
288    def _ext_max_iter(self):
289        """
290        Extract max iterations from ITERATION keyword in DATAASSIM part (mandatory keyword for iteration loops).
291
292        Parameters
293        ----------
294        keys_da : dict
295            A dictionary containing all keywords from DATAASSIM part.
296            - 'iteration' : object
297                Information for iterative methods.
298
299        Returns
300        -------
301        max_iter : int
302            The maximum number of iterations allowed before abort.
303
304        Changelog
305        ---------
306        - ST 7/6-16
307        """
308        if 'iteration' in self.ensemble.keys_da:
309            # Make sure ITERATION is a list
310            if not isinstance(self.ensemble.keys_da['iteration'][0], list):
311                iter_opts = [self.ensemble.keys_da['iteration']]
312            else:
313                iter_opts = self.ensemble.keys_da['iteration']
314
315            # Check if 'max_iter' has been given; if not, give error (mandatory in ITERATION)
316            assert 'max_iter' in list(
317                zip(*iter_opts))[0], 'MAX_ITER has not been given in ITERATION!'
318
319            # Extract max. iter
320            max_iter = [item[1] for item in iter_opts if item[0] == 'max_iter'][0]
321
322        elif 'mda' in self.ensemble.keys_da:
323            # Make sure ITERATION is a list
324            if not isinstance(self.ensemble.keys_da['mda'][0], list):
325                iter_opts = [self.ensemble.keys_da['mda']]
326            else:
327                iter_opts = self.ensemble.keys_da['mda']
328
329            # Check if 'max_iter' has been given; if not, give error (mandatory in ITERATION)
330            assert 'tot_assim_steps' in list(
331                zip(*iter_opts))[0], 'TOT_ASSIM_STEPS has not been given in MDA!'
332
333            # Extract max. iter
334            max_iter = [item[1] for item in iter_opts if item[0] == 'tot_assim_steps'][0]
335
336        else:
337            max_iter = 1
338        # Return max. iter
339        return max_iter
340
341    def _save_iteration_information(self):
342        """
343        More general method for saving all relevant information from a analysis/forecast step. Note that this is
344        only performed when there is a reduction in objective function.
345
346        Parameters
347        ----------
348        values : list
349            List of values to be saved. It can also contain a separate Python file.
350
351        If one reads a python file, it is
352        """
353        # Make sure "ANALYSISDEBUG" gives a list
354        if isinstance(self.ensemble.keys_da['iterinfo'], list):
355            saveinfo = self.ensemble.keys_da['iterinfo']
356        else:
357            saveinfo = [self.ensemble.keys_da['iterinfo']]
358
359        for el in saveinfo:
360            if '.py' in el:  # This is a unique python file
361                iter_info_func = import_module(el.strip('.py'))
362                # Note: the function must be named main, and we pass the full current instance of the object.
363                iter_info_func.main(self)
364
365    def _save_during_iteration(self, tempsave):
366        """
367        Save during an iteration. How often is determined by the `TEMPSAVE` keyword; confer the manual for all the
368        different options.
369
370        Parameters
371        ----------
372        tempsave: list
373            Info. from the TEMPSAVE keyword
374        """
375        self.ensemble.logger.info(
376            'The TEMPSAVE feature is no longer supported. Please you debug_analyses, or iterinfo.')
377        # Save at specific points
378        # if isinstance(tempsave, list):
379        #     # Save at regular intervals
380        #     if tempsave[0] == 'each' or tempsave[0] == 'every' and self.ensemble.iteration % tempsave[1] == 0:
381        #         self.ensemble.save_temp_state_iter(self.ensemble.iteration + 1, self.max_iter)
382        #
383        #     # Save at points given by input
384        #     elif tempsave[0] == 'list' or tempsave[0] == 'at':
385        #         # Check if one or more save points have been given, and save if we are at that point
386        #         savepoint = tempsave[1] if isinstance(tempsave[1], list) else [tempsave[1]]
387        #         if self.ensemble.iteration in savepoint:
388        #             self.ensemble.save_temp_state_iter(self.ensemble.iteration + 1, self.max_iter)
389        #
390        # # Save at all assimilation steps
391        # elif tempsave == 'yes' or tempsave == 'all':
392        #     self.ensemble.save_temp_state_iter(self.ensemble.iteration + 1, self.max_iter)
393
394    def _save_analysis_debug(self):
395        """
396        Moved Old analysis debug here to retain consistency.
397
398        .. danger:: only class variables can be stored now.
399        """
400        # Init dict. of variables to save
401        save_dict = {}
402
403        # Make sure "ANALYSISDEBUG" gives a list
404        if isinstance(self.ensemble.keys_da['analysisdebug'], list):
405            analysisdebug = self.ensemble.keys_da['analysisdebug']
406        else:
407            analysisdebug = [self.ensemble.keys_da['analysisdebug']]
408
409        # Loop over variables to store in save list
410        for save_typ in analysisdebug:
411            if hasattr(self, save_typ):
412                save_dict[save_typ] = eval('self.{}'.format(save_typ))
413            elif hasattr(self.ensemble, save_typ):
414                save_dict[save_typ] = eval('self.ensemble.{}'.format(save_typ))
415            # Save with key equal variable name and the actual variable
416            else:
417                print(f'Cannot save {save_typ}, because it is a local variable!\n\n')
418
419        # Save the variables
420        at.save_analysisdebug(self.ensemble.iteration, **save_dict)
421
    def calc_forecast(self):
        """
        Calculate the forecast step.

        Runs the forward simulations through `self.ensemble.calc_prediction()` and prepares
        `self.ensemble.pred_data` for the analysis step. The method takes no arguments; the assimilation step is
        derived from `self.ensemble.iteration` and the ASSIMINDEX keyword. The steps carried out are:

                1. If a file `restart_sim_results.p` exists in the working directory, its content is loaded as
                   `pred_data`, the file is renamed to `sim_results.p`, and the method returns.
                2. The assimilation step and the assimilation/true-data indices are determined.
                3. The forecast is run.
                4. `pred_data` is filtered to the entries needed at the current assimilation step.
                5. Optional steps, each controlled by a keyword: scaling (SCALE), post-processing
                   (POST_PROCESS_FORECAST), conversion of dynamic variables to arrays (DYNAMICVAR), and saving
                   the predicted data to `sim_results.p` (SAVEFORECAST in the simulator input).

        Notes
        -----
        The points below are carried over from the earlier documentation of this method and describe the
        simulator side, which is not visible here -- TODO confirm they still hold:

        - The simulator is expected to provide `setup_fwd_sim`, `run_fwd_sim`, `check_sim_end` and
          `get_sim_results`.
        - Parallel runs in "ampersand" mode are started in the background and run independently of the Python
          script, hence a check for finished or failed simulations must be conducted.
        - Only results matching the observed data at the particular assimilation step need to be extracted;
          data types without observations (indicated by None) can be skipped.
        - It is assumed that DATATYPE entries contain no underscore.
        """
        # Add an option to load existing sim results. The user must actively create the restart file by renaming an
        # existing sim_results.p file to restart_sim_results.p.
        # NOTE(review): pickle.load executes arbitrary code from the file; only use trusted restart files.
        if os.path.exists('restart_sim_results.p'):
            with open('restart_sim_results.p', 'rb') as f:
                self.ensemble.pred_data = pickle.load(f)
            # Rename so that the restart file is only picked up once
            os.rename('restart_sim_results.p', 'sim_results.p')
            print('--- Restart sim results used ---')
            return

        # If we are doing an sequential assimilation, such as enkf, we loop over assimilation steps
        if len(self.ensemble.keys_da['assimindex']) > 1:
            assim_step = self.ensemble.iteration
        else:
            assim_step = 0

        # Get assimilation order as a list where first entry are the string(s) in OBSNAME and second entry are
        # the associated array(s). At step 0, and at the step equal to the number of ASSIMINDEX entries, all
        # indices are concatenated into one list.
        if assim_step == 0 or assim_step == len(self.ensemble.keys_da['assimindex']):
            assim_ind = [self.ensemble.keys_da['obsname'], list(
                np.concatenate(self.ensemble.keys_da['assimindex']))]
        else:
            assim_ind = [self.ensemble.keys_da['obsname'],
                         self.ensemble.keys_da['assimindex'][assim_step]]

        # Get TRUEDATAINDEX
        true_order = [self.ensemble.keys_da['obsname'],
                      self.ensemble.keys_da['truedataindex']]

        # List assim. index
        if isinstance(true_order[1], list):  # Check if true data prim. ind. is a list
            true_prim = [true_order[0], [x for x in true_order[1]]]
        else:  # Float
            true_prim = [true_order[0], [true_order[1]]]
        if isinstance(assim_ind[1], list):  # Check if prim. ind. is a list
            l_prim = [int(x) for x in assim_ind[1]]
        else:  # Float
            l_prim = [int(assim_ind[1])]

        # Run forecast. Predicted data solved in self.ensemble.pred_data
        self.ensemble.calc_prediction()

        # Filter pred. data needed at current assimilation step. This essentially means deleting pred. data not
        # contained in the assim. indices for current assim. step or does not have obs. data at this index
        # NOTE(review): true_prim[1] is indexed by the position in pred_data; this raises IndexError when a
        # position is not in l_prim and pred_data is longer than TRUEDATAINDEX -- confirm the lengths agree.
        self.ensemble.pred_data = [elem for i, elem in enumerate(self.ensemble.pred_data) if i in l_prim or
                                   true_prim[1][i] is not None]

        # Scale data if required (currently only one group of data can be scaled). SCALE holds the data
        # type(s) to scale in its first entry and the factor in its second entry; scaling is done in place.
        if 'scale' in self.ensemble.keys_da:
            for pred_data in self.ensemble.pred_data:
                for key in pred_data:
                    if key in self.ensemble.keys_da['scale'][0]:
                        pred_data[key] *= self.ensemble.keys_da['scale'][1]

        # Post process predicted data if wanted
        if 'post_process_forecast' in self.ensemble.keys_da and self.ensemble.keys_da['post_process_forecast'] == 'yes':
            self.post_process_forecast()

        # If we have dynamic variables, and we are in the first assimilation step, we must convert lists to (2D)
        # numpy arrays
        if 'dynamicvar' in self.ensemble.keys_da and assim_step == 0:
            for dyn_state in self.ensemble.keys_da['dynamicvar']:
                self.ensemble.state[dyn_state] = np.array(
                    self.ensemble.state[dyn_state]).T

        # Extra option debug: store the (filtered and possibly scaled) predicted data
        if 'saveforecast' in self.ensemble.sim.input_dict:
            with open('sim_results.p', 'wb') as f:
                pickle.dump(self.ensemble.pred_data, f)
530
    def post_process_forecast(self):
        """
        Post processing of predicted data after a forecast run.

        Works on `self.ensemble.pred_data` in three stages:

        1. When `self.ensemble.sparse_info` is not None, predicted data whose first dimension equals the number
           of active entries in a sparse mask are collected in a temporary structure.
        2. When a file `scale_results.p` exists, every data type with 'sim2seis' in its name is divided by a
           scaling value. The value is read from the file the first time it is needed and then kept in
           `self.scale_val`.
        3. When `self.ensemble.sparse_info` is truthy, the collected data are compressed member by member with
           `self.ensemble.compress` and written back to `self.ensemble.pred_data`.

        With SAVEFORECAST in the simulator input and a truthy `self.ensemble.sparse_data`, the reconstructed
        signals in `self.ensemble.data_rec` are transposed and pickled to `rec_results.p`.
        """
        # Temporary storage of seismic data that need to be scaled
        pred_data_tmp = [None for _ in self.ensemble.pred_data]

        # Loop over pred data and store temporary
        if self.ensemble.sparse_info is not None:
            for i, pred_data in enumerate(self.ensemble.pred_data):
                for key in pred_data:
                    # Reset vintage
                    # NOTE(review): vintage is reset for every key, so the check below always uses mask[0] and
                    # the increment further down has no effect -- confirm this is intended.
                    vintage = 0

                    # Store according to sparse_info
                    if vintage < len(self.ensemble.sparse_info['mask']) and \
                            pred_data[key].shape[0] == int(np.sum(self.ensemble.sparse_info['mask'][vintage])):

                        # If first entry in pred_data_tmp
                        if pred_data_tmp[i] is None:
                            pred_data_tmp[i] = {key: pred_data[key]}

                        else:
                            pred_data_tmp[i][key] = pred_data[key]

                        # Update vintage
                        vintage += 1

        # Scaling used in sim2seis
        if os.path.exists('scale_results.p'):
            # Read the scaling only while self.scale_val is still falsy
            if not self.scale_val:
                with open('scale_results.p', 'rb') as f:
                    scale = pickle.load(f)
                # base the scaling on the first dataset and the first iteration
                self.scale_val = np.sum(scale[0]) / len(scale[0])

            # With sparse data the temporary copies are scaled (they are compressed and written back below);
            # otherwise pred_data is scaled directly.
            if self.ensemble.sparse_info is not None:
                for i in range(len(pred_data_tmp)):  # INDEX
                    if pred_data_tmp[i] is not None:
                        for k in pred_data_tmp[i]:  # DATATYPE
                            if 'sim2seis' in k and pred_data_tmp[i][k] is not None:
                                pred_data_tmp[i][k] = pred_data_tmp[i][k] / self.scale_val

            else:
                for i in range(len(self.ensemble.pred_data)):  # TRUEDATAINDEX
                    for k in self.ensemble.pred_data[i]:  # DATATYPE
                        if 'sim2seis' in k and self.ensemble.pred_data[i][k] is not None:
                            self.ensemble.pred_data[i][k] = self.ensemble.pred_data[i][k] / \
                                self.scale_val

        # If wavelet compression is based on the simulated data, we need to recompute obs_data, datavar and pred_data.
        # NOTE(review): this test is on truthiness, whereas the blocks above test `is not None`; an empty
        # sparse_info is therefore treated differently here -- confirm.
        if self.ensemble.sparse_info:
            vintage = 0
            self.ensemble.data_rec = []
            for i in range(len(pred_data_tmp)):  # INDEX
                if pred_data_tmp[i] is not None:
                    for k in pred_data_tmp[i]:  # DATATYPE
                        if vintage < len(self.ensemble.sparse_info['mask']) and \
                                len(pred_data_tmp[i][k]) == int(np.sum(self.ensemble.sparse_info['mask'][vintage])):
                            # Compressed data get the same length as the observed data for this index/type
                            self.ensemble.pred_data[i][k] = np.zeros(
                                (len(self.ensemble.obs_data[i][k]), self.ensemble.ne))
                            # Compress one ensemble member (column) at a time
                            for m in range(pred_data_tmp[i][k].shape[1]):
                                data_array = self.ensemble.compress(pred_data_tmp[i][k][:, m], vintage,
                                                                    self.ensemble.sparse_info['use_ensemble'])
                                self.ensemble.pred_data[i][k][:, m] = data_array
                            vintage = vintage + 1
            # presumably the argument-free call handles the ensemble-based part of the compression; it is run
            # once, since the flag is cleared afterwards -- verify against Ensemble.compress
            if self.ensemble.sparse_info['use_ensemble']:
                self.ensemble.compress()
                self.ensemble.sparse_info['use_ensemble'] = None

        # Extra option debug
        if 'saveforecast' in self.ensemble.sim.input_dict:
            # Save the reconstructed signal for later analysis
            # NOTE(review): this branch reads `sparse_data` while the rest of the method uses `sparse_info`,
            # and `data_rec` is only (re)created above when sparse_info is truthy -- confirm.
            if self.ensemble.sparse_data:
                for vint in np.arange(len(self.ensemble.data_rec)):
                    self.ensemble.data_rec[vint] = np.asarray(
                        self.ensemble.data_rec[vint]).T
                with open('rec_results.p', 'wb') as f:
                    pickle.dump(self.ensemble.data_rec, f)

Class for iterative ensemble-based methods. This loop is similar/equal to a deterministic/optimization loop, but since we use an ensemble-based method, we need to invoke pipt.loop.ensemble.Ensemble to get the correct hierarchy of classes. The iterative loop will run until the max. iterations OR convergence has been met. Parameters for both these stopping criteria have to be given by the user through methods in their pipt.update_schemes class. Note that only iterative ensemble smoothers can be implemented with this loop (at the moment). Methods that must be provided by the user in their update_schemes class:

calc_analysis
check_convergence

% Copyright (c) 2019-2022 NORCE, All Rights Reserved. 4DSEIS

Assimilate(ensemble: pipt.loop.ensemble.Ensemble)
42    def __init__(self, ensemble: Ensemble):
43        """
44        Initialize by passing the PIPT init. file up the hierarchy.
45
46        Parameters
47        ----------
48        init_file: str
49            PIPT init. filename
50        """
51        # Internalize ensemble and simulator class instances
52        self.ensemble = ensemble
53
54        if self.ensemble.restart is False:
55            # Default max. iter if not defined in the ensemble
56            if hasattr(ensemble, 'max_iter'):
57                self.max_iter = self.ensemble.max_iter
58            else:
59                self.max_iter = self._ext_max_iter()
60
61            # Within variables
62            self.why_stop = None    # Output of why iter. loop stopped
63
64            self.scale_val = []  # Used to scale seismic data
65
66            # This feature is removed
67            # Initialize temporary storage of state variable during the assimilation (if option is supplied in DATAASSIM
68            # part). Save initially regardless of which option you have chosen as long as it is not 'no'
69            # if 'tempsave' in self.ensemble.keys_da and self.ensemble.keys_da['tempsave'] != 'no':
70            #     self.ensemble.save_temp_state_iter(0, self.max_iter)  # save init. ensemble

Initialize the loop with the ensemble instance it operates on.

Parameters
  • ensemble (pipt.loop.ensemble.Ensemble): ensemble instance, stored as self.ensemble
ensemble
def run(self):
 72    def run(self):
 73        """
 74        The general loop implemented here is:
 75
 76        <ol>
 77            <li>Forecast/forward simulation</li>
 78            <li>Check for convergence</li>
 79            <li>If convergence have not been achieved, do analysis/update</li>
 80        </ol>
 81
 82        % Copyright (c) 2019-2022 NORCE, All Rights Reserved. 4DSEIS
 83        """
 84        # TODO: Implement a 'calc_sensitivity' method in the loop. For now it is assumed that the sensitivity is
 85        # calculated in 'calc_analysis' using some kind of ensemble approximation.
 86
 87        # Init. while loop condition variable
 88        conv = False
 89        success_iter = True
 90
 91        # Initiallize progressbar
 92        pbar_out = tqdm(total=self.max_iter,
 93                        desc='Iterations (Obj. func. val: )', position=0)
 94
 95        # Check if we want to perform a Quality Assurance of the forecast
 96        qaqc = None
 97        if 'qa' in self.ensemble.sim.input_dict or 'qc' in self.ensemble.keys_da:
 98            qaqc = qaqc_tools.QAQC({**self.ensemble.keys_da, **self.ensemble.sim.input_dict},
 99                                   self.ensemble.obs_data, self.ensemble.datavar, self.ensemble.logger,
100                                   self.ensemble.prior_info, self.ensemble.sim, self.ensemble.prior_state)
101
102        # Run a while loop until max. iterations or convergence is reached
103        while self.ensemble.iteration < self.max_iter and conv is False:
104            # Add a check to see if this is the prior model
105            if self.ensemble.iteration == 0:
106                # Calc forecast for prior model
107                # Inset 0 as input to forecast all data
108                self.calc_forecast()
109
110                # remove outliers
111                if 'remove_outliers' in self.ensemble.sim.input_dict:
112                    self.remove_outliers()
113
114                if 'qa' in self.ensemble.keys_da:  # Check if we want to perform a Quality Assurance of the forecast
115                    # set updated prediction, state and lam
116                    qaqc.set(self.ensemble.pred_data,
117                             self.ensemble.state, self.ensemble.lam)
118                    # Level 1,2 all data, and subspace
119                    qaqc.calc_mahalanobis((1, 'time', 2, 'time', 1, None, 2, None))
120                    qaqc.calc_coverage()  # Compute data coverage
121                    qaqc.calc_kg({'plot_all_kg': True, 'only_log': False,
122                                 'num_store': 5})  # Compute kalman gain
123
124                success_iter = True
125
126                # always store prior forcast, unless specifically told not to
127                if 'nosave' not in self.ensemble.keys_da:
128                    np.savez('prior_forecast.npz', **
129                             {'pred_data': self.ensemble.pred_data})
130
131            # For the remaining iterations we start by applying the analysis and finish by running the forecast
132            else:
133                # Analysis (in the update_scheme class)
134                self.ensemble.calc_analysis()
135
136                if 'qa' in self.ensemble.keys_da and 'screendata' in self.ensemble.keys_da and \
137                        self.ensemble.keys_da['screendata'] == 'yes' and self.ensemble.iteration == 1:
138                    #  need to update datavar, and recompute mahalanobis measures
139                    self.logger.info(
140                        'Recomputing Mahalanobis distance with updated datavar')
141                    qaqc.datavar = self.datavar  # this is updated from calc_analysis
142                    # Level 1,2 all data, and subspace
143                    qaqc.calc_mahalanobis((1, 'time', 2, 'time', 1, None, 2, None))
144
145                # Forecast with the updated state
146                self.calc_forecast()
147
148                if 'remove_outliers' in self.ensemble.keys_da:
149                    self.remove_outliers()
150
151                # Check convergence (in the update_scheme class). Outputs logical variable to tell the while loop to
152                # stop, and a variable telling what criteria for convergence was reached.
153                # Also check if the objective function has been reduced, and use this function to accept the state and
154                # update the lambda values.
155                #
156                conv, success_iter, self.why_stop = self.ensemble.check_convergence()
157
158            # if reduction of objective function -> save the state
159            if success_iter:
160                # More general method to save all relevant information from an iteration analysis/forecast step
161                if 'iterinfo' in self.ensemble.keys_da:
162                    #
163                    self._save_iteration_information()
164                if self.ensemble.iteration > 0:
165                    # Temporary save state if options in TEMPSAVE have been given and the option is not 'no'
166                    if 'tempsave' in self.ensemble.keys_da and self.ensemble.keys_da['tempsave'] != 'no':
167                        self._save_during_iteration(self.ensemble.keys_da['tempsave'])
168                    if 'analysisdebug' in self.ensemble.keys_da:
169                        self._save_analysis_debug()
170                    if 'qc' in self.ensemble.keys_da:  # Check if we want to perform a Quality Control of the updated state
171                        # set updated prediction, state and lam
172                        qaqc.set(self.ensemble.pred_data,
173                                 self.ensemble.state, self.ensemble.lam)
174                        qaqc.calc_da_stat()  # Compute statistics for updated parameters
175                    if 'qa' in self.ensemble.keys_da:  # Check if we want to perform a Quality Assurance of the forecast
176                        # set updated prediction, state and lam
177                        qaqc.set(self.ensemble.pred_data,
178                                 self.ensemble.state, self.ensemble.lam)
179                        qaqc.calc_mahalanobis(
180                            (1, 'time', 2, 'time', 1, None, 2, None))  # Level 1,2 all data, and subspace
181                        #  qaqc.calc_coverage()  # Compute data coverage
182                        qaqc.calc_kg()  # Compute kalman gain
183
184            # Update iteration counter if iteration was successful
185            if self.ensemble.iteration >= 0 and success_iter is True:
186                if self.ensemble.iteration == 0:
187                    self.ensemble.iteration += 1
188                    pbar_out.update(1)
189                    # pbar_out.set_description(f'Iterations (Obj. func. val:{self.data_misfit:.1f})')
190                    # self.prior_data_misfit = self.data_misfit
191                    # self.pbar_out.refresh()
192                else:
193                    self.ensemble.iteration += 1
194                    pbar_out.update(1)
195                    pbar_out.set_description(
196                        f'Iterations (Obj. func. val:{self.ensemble.data_misfit:.1f}'
197                        f' Reduced: {100 * (1 - (self.ensemble.data_misfit / self.ensemble.prev_data_misfit)):.0f} %)')
198                    # self.pbar_out.refresh()
199
200            if 'restartsave' in self.ensemble.keys_da and self.ensemble.keys_da['restartsave'] == 'yes':
201                self.ensemble.save()
202
203        # always store posterior forcast and state, unless specifically told not to
204        if 'nosave' not in self.ensemble.keys_da:
205            try: # first try to save as npz file
206                np.savez('posterior_state_estimate.npz', **self.ensemble.state)
207                np.savez('posterior_forecast.npz', **{'pred_data': self.ensemble.pred_data})
208            except: # If this fails, store as pickle
209                with open('posterior_state_estimate.p', 'wb') as file:
210                    pickle.dump(self.ensemble.state, file)
211                with open('posterior_forecast.p', 'wb') as file:
212                    pickle.dump(self.ensemble.pred_data, file)
213
214        # If none of the convergence criteria were met, max. iteration was the reason iterations stopped.
215        if conv is False:
216            reason = 'Iterations stopped due to max iterations reached!'
217        else:
218            reason = 'Convergence was met :)'
219
220        # Save why_stop in Numpy save file
221        # savez('why_iter_loop_stopped', why=self.why_stop, conv_string=reason)
222
223        # Save why_stop in pickle save file
224        why = self.why_stop
225        if why is not None:
226            why['conv_string'] = reason
227        with open('why_iter_loop_stopped.p', 'wb') as f:
228            pickle.dump(why, f, protocol=4)
229        # pbar.close()
230        pbar_out.close()
231        if self.ensemble.prev_data_misfit is not None:
232            out_str = 'Convergence was met.'
233            if self.ensemble.prior_data_misfit > self.ensemble.data_misfit:
234                out_str += f' Obj. function reduced from {self.ensemble.prior_data_misfit:0.1f} ' \
235                           f'to {self.ensemble.data_misfit:0.1f}'
236            tqdm.write(out_str)
237            self.ensemble.logger.info(out_str)

The general loop implemented here is:

  1. Forecast/forward simulation
  2. Check for convergence
  3. If convergence has not been achieved, do analysis/update

% Copyright (c) 2019-2022 NORCE, All Rights Reserved. 4DSEIS

def remove_outliers(self):
239    def remove_outliers(self):
240
241        # function to remove ouliers
242
243        # get the cov data
244        prod_obs = np.array([])
245
246        prod_cov = np.array([])
247        prod_pred = np.empty([0, self.ensemble.ne])
248        for i in range(len(self.ensemble.obs_data)):
249            for key in self.ensemble.obs_data[i].keys():
250                if self.ensemble.obs_data[i][key] is not None and self.ensemble.obs_data[i][key].shape == (1,):
251                    prod_obs = np.concatenate((prod_obs, self.ensemble.obs_data[i][key]))
252                    prod_cov = np.concatenate((prod_cov, self.ensemble.datavar[i][key]))
253                    prod_pred = np.concatenate(
254                        (prod_pred, self.ensemble.pred_data[i][key]))
255
256        mat_prod_obs = np.dot(prod_obs.reshape((len(prod_obs), 1)),
257                              np.ones((1, self.ensemble.ne)))
258
259        hm = np.diag(np.dot((prod_pred - mat_prod_obs).T, np.dot(np.expand_dims(prod_cov ** (-1), axis=1),
260                                                                 np.ones((1, self.ensemble.ne))) * (prod_pred - mat_prod_obs)))
261        hm_std = np.std(hm)
262        hm_mean = np.mean(hm)
263        outliers = np.argwhere(np.abs(hm - hm_mean) > 4 * hm_std)
264        print('Outliers: ' + str(np.squeeze(outliers)))
265        members = np.arange(self.ensemble.ne)
266        members = np.delete(members, outliers)
267        for index in outliers.flatten():
268
269            new_index = np.random.choice(members)
270
271            # replace state
272            for el in self.ensemble.state.keys():
273                self.ensemble.state[el][:, index] = deepcopy(
274                    self.ensemble.state[el][:, new_index])
275
276            # replace the failed forecast
277            for i, data_ind in enumerate(self.ensemble.pred_data):
278                if self.ensemble.pred_data[i] is not None:
279                    for el in data_ind.keys():
280                        if self.ensemble.pred_data[i][el] is not None:
281                            if type(self.ensemble.pred_data[i][el]) is list:
282                                self.ensemble.pred_data[i][el][index] = deepcopy(
283                                    self.ensemble.pred_data[i][el][new_index])
284                            else:
285                                self.ensemble.pred_data[i][el][:, index] = deepcopy(
286                                    self.ensemble.pred_data[i][el][:, new_index])
def calc_forecast(self):
422    def calc_forecast(self):
423        """
424        Calculate the forecast step.
425
426        Run the forward simulator, generating predicted data for the analysis step. First input to the simulator
427        instances is the ensemble of (joint) state to be run and how many to run in parallel. The forward runs are done
428        in a while-loop consisting of the following steps:
429
430                1. Run the simulator for each ensemble member in the background.
431                2. Check for errors during run (if error, correct and run again or abort).
432                3. Check if simulation has ended; if yes, run simulation for the next ensemble members.
433                4. Get results from successfully ended simulations.
434
435        The procedure here is general, hence a simulator used here must contain the initial step of setting up the
436        parameters and steps i-iv, if not an error will be outputted. Initialization of the simulator is done when
437        initializing the Ensemble class (see __init__). The names of the mandatory methods in a simulator are:
438
439                > setup_fwd_sim
440                > run_fwd_sim
441                > check_sim_end
442                > get_sim_results
443
444        Parameters
445        ----------
446        assim_step : int
447                     Current assimilation step.
448
449        Notes
450        -----
451        Parallel run in "ampersand" mode means that it will be started in the background and run independently of the
452        Python script. Hence, check for simulation finished or error must be conducted!
453
454        .. info:: It is only necessary to get the results from the forward simulations that corresponds to the observed
455        data at the particular assimilation step. That is, results from all data types are not necessary to
456        extract at step iv; if they are not present in the obs_data (indicated by a None type) then this result does
457        not need to be extracted.
458
459        .. info:: It is assumed that no underscore is inputted in DATATYPE. If there are underscores in DATATYPE
460        entries, well, then we may have a problem when finding out which response to extract in get_sim_results below.
461        """
462        # Add an option to load existing sim results. The user must actively create the restart file by renaming an
463        # existing sim_results.p file to restart_sim_results.p.
464        if os.path.exists('restart_sim_results.p'):
465            with open('restart_sim_results.p', 'rb') as f:
466                self.ensemble.pred_data = pickle.load(f)
467            os.rename('restart_sim_results.p', 'sim_results.p')
468            print('--- Restart sim results used ---')
469            return
470
471        # If we are doing an sequential assimilation, such as enkf, we loop over assimilation steps
472        if len(self.ensemble.keys_da['assimindex']) > 1:
473            assim_step = self.ensemble.iteration
474        else:
475            assim_step = 0
476
477        # Get assimilation order as a list where first entry are the string(s) in OBSNAME and second entry are
478        # the associated array(s)
479        if assim_step == 0 or assim_step == len(self.ensemble.keys_da['assimindex']):
480            assim_ind = [self.ensemble.keys_da['obsname'], list(
481                np.concatenate(self.ensemble.keys_da['assimindex']))]
482        else:
483            assim_ind = [self.ensemble.keys_da['obsname'],
484                         self.ensemble.keys_da['assimindex'][assim_step]]
485
486        # Get TRUEDATAINDEX
487        true_order = [self.ensemble.keys_da['obsname'],
488                      self.ensemble.keys_da['truedataindex']]
489
490        # List assim. index
491        if isinstance(true_order[1], list):  # Check if true data prim. ind. is a list
492            true_prim = [true_order[0], [x for x in true_order[1]]]
493        else:  # Float
494            true_prim = [true_order[0], [true_order[1]]]
495        if isinstance(assim_ind[1], list):  # Check if prim. ind. is a list
496            l_prim = [int(x) for x in assim_ind[1]]
497        else:  # Float
498            l_prim = [int(assim_ind[1])]
499
500        # Run forecast. Predicted data solved in self.ensemble.pred_data
501        self.ensemble.calc_prediction()
502
503        # Filter pred. data needed at current assimilation step. This essentially means deleting pred. data not
504        # contained in the assim. indices for current assim. step or does not have obs. data at this index
505        self.ensemble.pred_data = [elem for i, elem in enumerate(self.ensemble.pred_data) if i in l_prim or
506                                   true_prim[1][i] is not None]
507
508        # Scale data if required (currently only one group of data can be scaled)
509        if 'scale' in self.ensemble.keys_da:
510            for pred_data in self.ensemble.pred_data:
511                for key in pred_data:
512                    if key in self.ensemble.keys_da['scale'][0]:
513                        pred_data[key] *= self.ensemble.keys_da['scale'][1]
514
515        # Post process predicted data if wanted
516        if 'post_process_forecast' in self.ensemble.keys_da and self.ensemble.keys_da['post_process_forecast'] == 'yes':
517            self.post_process_forecast()
518
519        # If we have dynamic variables, and we are in the first assimilation step, we must convert lists to (2D)
520        # numpy arrays
521        if 'dynamicvar' in self.ensemble.keys_da and assim_step == 0:
522            for dyn_state in self.ensemble.keys_da['dynamicvar']:
523                self.ensemble.state[dyn_state] = np.array(
524                    self.ensemble.state[dyn_state]).T
525
526        # Extra option debug
527        if 'saveforecast' in self.ensemble.sim.input_dict:
528            with open('sim_results.p', 'wb') as f:
529                pickle.dump(self.ensemble.pred_data, f)

Calculate the forecast step.

Run the forward simulator, generating predicted data for the analysis step. First input to the simulator instances is the ensemble of (joint) state to be run and how many to run in parallel. The forward runs are done in a while-loop consisting of the following steps:

    1. Run the simulator for each ensemble member in the background.
    2. Check for errors during run (if error, correct and run again or abort).
    3. Check if simulation has ended; if yes, run simulation for the next ensemble members.
    4. Get results from successfully ended simulations.

The procedure here is general, hence a simulator used here must contain the initial step of setting up the parameters and steps 1-4; if not, an error will be raised. Initialization of the simulator is done when initializing the Ensemble class (see __init__). The names of the mandatory methods in a simulator are:

    > setup_fwd_sim
    > run_fwd_sim
    > check_sim_end
    > get_sim_results
Parameters
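  • assim_step (int): Current assimilation step. This is not an argument of calc_forecast; it is derived inside the method from the ensemble's iteration counter.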
Notes

Parallel run in "ampersand" mode means that it will be started in the background and run independently of the Python script. Hence, check for simulation finished or error must be conducted!

.. info:: It is only necessary to get the results from the forward simulations that correspond to the observed data at the particular assimilation step. That is, results from all data types are not necessary to extract at step 4; if they are not present in the obs_data (indicated by a None type) then this result does not need to be extracted.

.. info:: It is assumed that no underscore is inputted in DATATYPE. If there are underscores in DATATYPE entries, well, then we may have a problem when finding out which response to extract in get_sim_results below.

def post_process_forecast(self):
531    def post_process_forecast(self):
532        """
533        Post processing of predicted data after a forecast run
534        """
535        # Temporary storage of seismic data that need to be scaled
536        pred_data_tmp = [None for _ in self.ensemble.pred_data]
537
538        # Loop over pred data and store temporary
539        if self.ensemble.sparse_info is not None:
540            for i, pred_data in enumerate(self.ensemble.pred_data):
541                for key in pred_data:
542                    # Reset vintage
543                    vintage = 0
544
545                    # Store according to sparse_info
546                    if vintage < len(self.ensemble.sparse_info['mask']) and \
547                            pred_data[key].shape[0] == int(np.sum(self.ensemble.sparse_info['mask'][vintage])):
548
549                        # If first entry in pred_data_tmp
550                        if pred_data_tmp[i] is None:
551                            pred_data_tmp[i] = {key: pred_data[key]}
552
553                        else:
554                            pred_data_tmp[i][key] = pred_data[key]
555
556                        # Update vintage
557                        vintage += 1
558
559        # Scaling used in sim2seis
560        if os.path.exists('scale_results.p'):
561            if not self.scale_val:
562                with open('scale_results.p', 'rb') as f:
563                    scale = pickle.load(f)
564                # base the scaling on the first dataset and the first iteration
565                self.scale_val = np.sum(scale[0]) / len(scale[0])
566
567            if self.ensemble.sparse_info is not None:
568                for i in range(len(pred_data_tmp)):  # INDEX
569                    if pred_data_tmp[i] is not None:
570                        for k in pred_data_tmp[i]:  # DATATYPE
571                            if 'sim2seis' in k and pred_data_tmp[i][k] is not None:
572                                pred_data_tmp[i][k] = pred_data_tmp[i][k] / self.scale_val
573
574            else:
575                for i in range(len(self.ensemble.pred_data)):  # TRUEDATAINDEX
576                    for k in self.ensemble.pred_data[i]:  # DATATYPE
577                        if 'sim2seis' in k and self.ensemble.pred_data[i][k] is not None:
578                            self.ensemble.pred_data[i][k] = self.ensemble.pred_data[i][k] / \
579                                self.scale_val
580
581        # If wavelet compression is based on the simulated data, we need to recompute obs_data, datavar and pred_data.
582        if self.ensemble.sparse_info:
583            vintage = 0
584            self.ensemble.data_rec = []
585            for i in range(len(pred_data_tmp)):  # INDEX
586                if pred_data_tmp[i] is not None:
587                    for k in pred_data_tmp[i]:  # DATATYPE
588                        if vintage < len(self.ensemble.sparse_info['mask']) and \
589                                len(pred_data_tmp[i][k]) == int(np.sum(self.ensemble.sparse_info['mask'][vintage])):
590                            self.ensemble.pred_data[i][k] = np.zeros(
591                                (len(self.ensemble.obs_data[i][k]), self.ensemble.ne))
592                            for m in range(pred_data_tmp[i][k].shape[1]):
593                                data_array = self.ensemble.compress(pred_data_tmp[i][k][:, m], vintage,
594                                                                    self.ensemble.sparse_info['use_ensemble'])
595                                self.ensemble.pred_data[i][k][:, m] = data_array
596                            vintage = vintage + 1
597            if self.ensemble.sparse_info['use_ensemble']:
598                self.ensemble.compress()
599                self.ensemble.sparse_info['use_ensemble'] = None
600
601        # Extra option debug
602        if 'saveforecast' in self.ensemble.sim.input_dict:
603            # Save the reconstructed signal for later analysis
604            if self.ensemble.sparse_data:
605                for vint in np.arange(len(self.ensemble.data_rec)):
606                    self.ensemble.data_rec[vint] = np.asarray(
607                        self.ensemble.data_rec[vint]).T
608                with open('rec_results.p', 'wb') as f:
609                    pickle.dump(self.ensemble.data_rec, f)

Post processing of predicted data after a forecast run