pipt.loop.ensemble

Organize and initialize variables and the simulator for ensemble-based inversion (data assimilation) runs.

  1"""Descriptive description."""
  2
  3# External import
  4import logging
  5import os.path
  6
  7import numpy as np
  8
  9import sys
 10from copy import deepcopy, copy
 11from scipy.linalg import solve, cholesky
 12from scipy.spatial import distance
 13import itertools
 14from geostat.decomp import Cholesky
 15
 16# Internal import
 17from ensemble.ensemble import Ensemble as PETEnsemble
 18import misc.read_input_csv as rcsv
 19from pipt.misc_tools import wavelet_tools as wt
 20from pipt.misc_tools import cov_regularization
 21import pipt.misc_tools.analysis_tools as at
 22
 23
 24class Ensemble(PETEnsemble):
 25    """
 26    Class for organizing/initializing misc. variables and simulator for an
 27    ensemble-based inversion run. Inherits the PET ensemble structure
 28    """
 29
 30    def __init__(self, keys_da, keys_en, sim):
 31        """
 32        Parameters
 33        ----------
 34        keys_da : dict
 35            Options for the data assimilation class
 36
 37            - daalg: specification of the method: first the main type (e.g., "enrml"), then the solver (e.g., "gnenrml")
 38            - analysis: update flavour ("approx", "full" or "subspace")
 39            - energy: percent of singular values kept after SVD
 40            - obsvarsave: save the observations as a file (default false)
 41            - restart: restart optimization from a restart file (default false)
 42            - restartsave: save a restart file after each successful iteration (default false)
 43            - analysisdebug: specify which class variables to save to the result files
 44            - truedataindex: order of the simulated data (for timeseries this is points in time)
 45            - obsname: unit for truedataindex (for timeseries this is days or hours or seconds, etc.)
 46            - truedata: the data, e.g., provided as a .csv file
 47            - assimindex: index for the data that will be used for assimilation
 48            - datatype: list with the name of the datatypes
 49            - staticvar: name of the static variables
 50            - datavar: data variance, e.g., provided as a .csv file
 51
 52        keys_en : dict
 53            Options for the ensemble class
 54
 55            - ne: number of perturbations used to compute the gradient
 56            - state: name of state variables passed to the .mako file
 57            - prior_<name>: the prior information for the state variables, including mean, variance, and variable limits
 58
 59        sim : callable
 60            The forward simulator (e.g. flow)
 61        """
 62
 63
 64        # do the initialization of the PET ensemble
 65        super(Ensemble, self).__init__(keys_en, sim)
 66
 67        # set logger
 68        self.logger = logging.getLogger('PET.PIPT')
 69
 70        # write initial information
 71        self.logger.info(f'Starting a {keys_da["daalg"][0]} run with the {keys_da["daalg"][1]} algorithm applying the '
 72                         f'{keys_da["analysis"]} update scheme with {keys_da["energy"]}% energy.')
 73
 74        # Internalize PIPT dictionary
 75        if not hasattr(self, 'keys_da'):
 76            self.keys_da = keys_da
 77        if not hasattr(self, 'keys_en'):
 78            self.keys_en = keys_en
 79
 80        if self.restart is False:
 81            # Init in _init_prediction_output (used in run_prediction)
 82            self.prediction = None
 83            self.temp_state = None  # temporary state saving
 84            self.cov_prior = None  # Prior cov. matrix
 85            self.sparse_info = None  # Init in _org_sparse_representation
 86            self.sparse_data = []  # List of the compression info
 87            self.data_rec = []  # List of reconstructed data
 88            self.scale_val = None  # Use to scale data
 89
 90            # Prepare sparse representation
 91            if 'compress' in self.keys_da:
 92                self._org_sparse_representation()
 93
 94            self._org_obs_data()
 95            self._org_data_var()
 96
 97            # define projection for centring and scaling
 98            self.proj = (np.eye(self.ne) - (1 / self.ne) *
 99                         np.ones((self.ne, self.ne))) / np.sqrt(self.ne - 1)
100
101            # If we have dynamic state variables, we allocate keys for them in 'state'. Since we do not know the size
102            #  of the arrays of the dynamic variables, we only allocate an NE list to be filled in later (in
103            # calc_forecast)
104            if 'dynamicvar' in self.keys_da:
105                dyn_var = self.keys_da['dynamicvar'] if isinstance(self.keys_da['dynamicvar'], list) else \
106                    [self.keys_da['dynamicvar']]
107                for name in dyn_var:
108                    self.state[name] = [None] * self.ne
109
110            # Option to store the dictionaries containing observed data and data variance
111            if 'obsvarsave' in self.keys_da and self.keys_da['obsvarsave'] == 'yes':
112                np.savez('obs_var', obs=self.obs_data, var=self.datavar)
113
114            # Initialize localization
115            if 'localization' in self.keys_da:
116                self.localization = cov_regularization.localization(self.keys_da['localization'],
117                                                                    self.keys_da['truedataindex'],
118                                                                    self.keys_da['datatype'],
119                                                                    self.keys_da['staticvar'],
120                                                                    self.ne)
121            # Initialize local analysis
122            if 'localanalysis' in self.keys_da:
123                self.local_analysis = at.init_local_analysis(
124                    init=self.keys_da['localanalysis'], state=self.state.keys())
125
126            self.pred_data = [{k: np.zeros((1, self.ne), dtype='float32') for k in self.keys_da['datatype']}
127                              for _ in self.obs_data]
128
129            self.cell_index = None  # default value for extracting states
130
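    # --- Editor's sketch (illustrative only, not part of the module) ------
    # A minimal, hypothetical pair of input dictionaries matching the fields
    # documented in __init__; every name and value below is an assumption:
    #
    #   keys_da = {'daalg': ['enrml', 'gnenrml'], 'analysis': 'subspace',
    #              'energy': 98, 'truedataindex': [10, 20, 30],
    #              'obsname': 'days', 'truedata': 'truedata.csv',
    #              'assimindex': [0, 1, 2], 'datatype': ['wopr', 'wwpr'],
    #              'staticvar': 'permx', 'datavar': 'var.csv'}
    #   keys_en = {'ne': 100, 'state': 'permx', 'prior_permx': ...}
    #   ens = Ensemble(keys_da, keys_en, sim)  # sim: forward simulator object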
131    def check_assimindex_sequential(self):
132        """
133        Check if the assimilation indices are given as a 2D list, as needed in sequential updating. If not, make them a 2D list.
134        """
135        # Check if ASSIMINDEX is a list. If not, make it a 2D list
136        if not isinstance(self.keys_da['assimindex'], list):
137            self.keys_da['assimindex'] = [[self.keys_da['assimindex']]]
138
139        # If ASSIMINDEX is a 1D list (either given in as a single row or single column), we reshape to a 2D list
140        elif not isinstance(self.keys_da['assimindex'][0], list):
141            assimindex_temp = [None] * len(self.keys_da['assimindex'])
142
143            for i in range(len(self.keys_da['assimindex'])):
144                assimindex_temp[i] = [self.keys_da['assimindex'][i]]
145
146            self.keys_da['assimindex'] = assimindex_temp
147
148    def check_assimindex_simultaneous(self):
149        """
150        Check if the assimilation indices are given as a 2D list with one row, as needed in simultaneous updating.
151        If not, reshape them into a single row.
152        """
153        # Check if ASSIMINDEX is a list. If not, make it a 2D list with one row
154        if not isinstance(self.keys_da['assimindex'], list):
155            self.keys_da['assimindex'] = [[self.keys_da['assimindex']]]
156
157        # Check if ASSIMINDEX is a 1D list. If true, make it a 2D list with one row
158        elif not isinstance(self.keys_da['assimindex'][0], list):
159            self.keys_da['assimindex'] = [self.keys_da['assimindex']]
160
161        # If ASSIMINDEX is a 2D list, we reshape it to a 2D list with one row
162        elif isinstance(self.keys_da['assimindex'][0], list):
163            self.keys_da['assimindex'] = [
164                [item for sublist in self.keys_da['assimindex'] for item in sublist]]
165
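    # --- Editor's sketch (illustrative only, not part of the module) ------
    # How the two checks above normalize ASSIMINDEX (example values assumed):
    #
    #   check_assimindex_sequential:    5             -> [[5]]
    #                                   [0, 1, 2]     -> [[0], [1], [2]]
    #   check_assimindex_simultaneous:  5             -> [[5]]
    #                                   [0, 1, 2]     -> [[0, 1, 2]]
    #                                   [[0, 1], [2]] -> [[0, 1, 2]]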
166    def _org_obs_data(self):
167        """
168        Organize the input true observed data. The obs_data will be a list with length equal to the length of
169        "TRUEDATAINDEX", and each entry in the list will be a dictionary with keys equal to the "DATATYPE".
170        Also, the pred_data variable (predicted data or forward simulation) will be initialized here with the same
171        structure as the obs_data variable.
172
173        .. warning:: An "N/A" entry in "TRUEDATA" is treated as a None-entry; that is, there is NOT an observed data at this
174        assimilation step.
175
176        .. warning:: The array associated with the first string input in "TRUEDATAINDEX" is assumed to be the "main"
177        index; that is, the length of this array determines the length of the obs_data list! The arrays
178        associated with the subsequent strings in "TRUEDATAINDEX" are then assumed to be subsets of the first
179        string. An example: the first string is SOURCE (e.g., sources in CSEM), where the array is a list numbering
180        the sources; and the second string is FREQ, where the associated array is a list of frequencies.
181
182        .. note:: It is assumed that the number of data associated with a subset is the same for each index in the subset.
183        For example: if two frequencies are input in FREQ, then the number of data for one SOURCE index and one
184        frequency is 1/2 of the total no. of data for that SOURCE index. If three frequencies are input, the number
185        of data for one SOURCE index and one frequency is 1/3 of the total no. of data for that SOURCE index,
186        and so on.
187        """
188
189        # # Check if keys_da['datatype'] is a string or list, and make it a list if single string is given
190        # if isinstance(self.keys_da['datatype'], str):
191        #     datatype = [self.keys_da['datatype']]
192        # else:
193        #     datatype = self.keys_da['datatype']
194        #
195        # # Extract primary indices from "TRUEDATAINDEX"
196        # if isinstance(self.keys_da['truedataindex'], list):  # List of prim. ind
197        #     true_prim = self.keys_da['truedataindex']
198        # else:  # Float
199        #     true_prim = [self.keys_da['truedataindex']]
200        #
201        # # Check if a csv file has been included as "TRUEDATAINDEX". If so, we read it and make a list,
202        # if isinstance(self.keys_da['truedataindex'], str) and self.keys_da['truedataindex'].endswith('.csv'):
203        #     with open(self.keys_da['truedataindex']) as csvfile:
204        #         reader = csv.reader(csvfile)  # get a reader object
205        #         true_prim = []  # Initialize the list of csv data
206        #         for rows in reader:  # Rows is a list of values in the csv file
207        #             csv_data = [None] * len(rows)
208        #             for ind, col in enumerate(rows):
209        #                 csv_data[ind] = int(col)
210        #             true_prim.extend(csv_data)
211        #     self.keys_da['truedataindex'] = true_prim
212        #
213        # # Check if a csv file has been included as "PREDICTION". If so, we read it and make a list,
214        # if 'prediction' in self.keys_da:
215        #     if isinstance(self.keys_da['prediction'], str) and self.keys_da['prediction'].endswith('.csv'):
216        #         with open(self.keys_da['prediction']) as csvfile:
217        #             reader = csv.reader(csvfile)  # get a reader object
218        #             pred_prim = []  # Initialize the list of csv data
219        #             for rows in reader:  # Rows is a list of values in the csv file
220        #                 csv_data = [None] * len(rows)
221        #                 for ind, col in enumerate(rows):
222        #                     csv_data[ind] = int(col)
223        #                 pred_prim.extend(csv_data)
224        #         self.keys_da['prediction'] = pred_prim
225
226        # Extract the observed data from "TRUEDATA"
227        if len(self.keys_da['truedataindex']) == 1:  # Only one assimilation step
228            if isinstance(self.keys_da['truedata'], list):
229                truedata = [self.keys_da['truedata']]
230            else:
231                truedata = [[self.keys_da['truedata']]]
232        else:  # More than one assim. step
233            if isinstance(self.keys_da['truedata'][0], list):  # 2D list
234                truedata = self.keys_da['truedata']
235            else:
236                truedata = [[x] for x in self.keys_da['truedata']]  # Make it a 2D list
237
238        # Initialize obs_data list. List length = len("TRUEDATAINDEX"); dictionary in each list entry = d
239        self.obs_data = [None] * len(self.keys_da['truedataindex'])
240
241        # Check if a csv file has been included in TRUEDATA. If so, we read it and make a 2D list, which we can use
242        # in the below when assigning data to obs_data dictionary
243        if isinstance(self.keys_da['truedata'], str) and self.keys_da['truedata'].endswith('.csv'):
244            truedata = rcsv.read_data_csv(
245                self.keys_da['truedata'], self.keys_da['datatype'], self.keys_da['truedataindex'])
246
247        # # Check if assimindex is given as a csv file. If so, we read and make a potential 2D list (if sequential).
248        # if isinstance(self.keys_da['assimindex'], str) and self.keys_da['assimindex'].endswith('.csv'):
249        #     with open(self.keys_da['assimindex']) as csvfile:
250        #         reader = csv.reader(csvfile)  # get a reader object
251        #         assimindx = []  # Initialize the 2D list of csv data
252        #         for rows in reader:  # Rows is a list of values in the csv file
253        #             csv_data = [None] * len(rows)
254        #             for col in range(len(rows)):
255        #                 csv_data[col] = int(rows[col])
256        #             assimindx.append(csv_data)
257        #     self.keys_da['assimindex'] = assimindx
258
259        # Now we loop over all list entries in obs_data and fill in the observed data from "TRUEDATA".
260        # NOTE: Not all data types may have observed data at each "TRUEDATAINDEX"; in this case it will have a None
261        # entry.
262        # NOTE2: If "TRUEDATA" contains a .npz file, this will be loaded. BUT the array loaded MUST be a 1D numpy
263        # array! So resize BEFORE saving the .npz file!
264        # NOTE3: If CSV file has been included in TRUEDATA, we read the data from this file
265        vintage = 0
266        for i in range(len(self.obs_data)):  # TRUEDATAINDEX
267            # Init. dict. with datatypes (do inside loop to avoid copy of same entry)
268            self.obs_data[i] = {}
269            # Make unified inputs
270            if 'unif_in' in self.keys_da and self.keys_da['unif_in'] == 'yes':
271                if isinstance(truedata[i][0], str) and truedata[i][0].endswith('.npz'):
272                    load_data = np.load(truedata[i][0])  # Load the .npz file
273                    data_array = load_data[load_data.files[0]]
274
275                    # Perform compression if required (we only and always compress signals with same size as number of active cells)
276                    if self.sparse_info is not None and \
277                            vintage < len(self.sparse_info['mask']) and \
278                            len(data_array) == int(np.sum(self.sparse_info['mask'][vintage])):
279                        data_array = self.compress(data_array, vintage, False)
280                        vintage = vintage + 1
281
282                    # Save array in obs_data. If it is an array with single value (not list), then we convert it to a
283                    # list with one entry.
284                    self.obs_data[i][self.keys_da['datatype'][0]] = np.array(
285                        [data_array[()]]) if data_array.shape == () else data_array
286
287                # Entry is N/A, i.e., no data given
288                elif isinstance(truedata[i][0], str) and not truedata[i][0].endswith('.npz') \
289                        and truedata[i][0].lower() == 'n/a':
290                    self.obs_data[i][self.keys_da['datatype'][0]] = None
291
292                # Unknown string entry
293                elif isinstance(truedata[i][0], str) and not truedata[i][0].endswith('.npz') \
294                        and not truedata[i][0].lower() == 'n/a':
295                    print(
296                        '\n\033[1;31mERROR: Cannot load observed data file! Maybe it is not a .npz file?\033[1;m')
297                    sys.exit(1)
298                # Entry is a numerical value
299                elif not isinstance(truedata[i][0], str):  # Some numerical value or None
300                    self.obs_data[i][self.keys_da['datatype'][0]] = np.array(
301                        truedata[i][:])  # no need to make this into a list
302            else:
303                for j in range(len(self.keys_da['datatype'])):  # DATATYPE
304                    # Load a Numpy npz file
305                    if isinstance(truedata[i][j], str) and truedata[i][j].endswith('.npz'):
306                        load_data = np.load(truedata[i][j])  # Load the .npz file
307                        data_array = load_data[load_data.files[0]]
308
309                        # Perform compression if required (we only and always compress signals with same size as number of active cells)
310                        if self.sparse_info is not None and \
311                                vintage < len(self.sparse_info['mask']) and \
312                                len(data_array) == int(np.sum(self.sparse_info['mask'][vintage])):
313                            data_array = self.compress(data_array, vintage, False)
314                            vintage = vintage + 1
315
316                        # Save array in obs_data. If it is an array with single value (not list), then we convert it to a
317                        # list with one entry
318                        self.obs_data[i][self.keys_da['datatype'][j]] = np.array(
319                            [data_array[()]]) if data_array.shape == () else data_array
320
321                    # Entry is N/A, i.e., no data given
322                    elif isinstance(truedata[i][j], str) and not truedata[i][j].endswith('.npz') \
323                            and truedata[i][j].lower() == 'n/a':
324                        self.obs_data[i][self.keys_da['datatype'][j]] = None
325
326                    # Unknown string entry
327                    elif isinstance(truedata[i][j], str) and not truedata[i][j].endswith('.npz') \
328                            and not truedata[i][j].lower() == 'n/a':
329                        print(
330                            '\n\033[1;31mERROR: Cannot load observed data file! Maybe it is not a .npz file?\033[1;m')
331                        sys.exit(1)
332
333                    # Entry is a numerical value
334                    # Some numerical value or None
335                    elif not isinstance(truedata[i][j], str):
336                        if isinstance(truedata[i][j], np.ndarray):
337                            self.obs_data[i][self.keys_da['datatype'][j]] = truedata[i][j]
338                        else:
339                            self.obs_data[i][self.keys_da['datatype'][j]] = np.array([truedata[i][j]])
340
341                    # Scale data if required (currently only one group of data can be scaled)
342                    if 'scale' in self.keys_da and self.keys_da['scale'][0] in self.keys_da['datatype'][j] and \
343                            self.obs_data[i][self.keys_da['datatype'][j]] is not None:
344                        self.obs_data[i][self.keys_da['datatype'][j]] *= \
345                            self.keys_da['scale'][1]
346
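    # --- Editor's sketch (illustrative only, not part of the module) ------
    # Resulting structure of self.obs_data (example values assumed): one
    # dictionary per "TRUEDATAINDEX" entry, keyed by "DATATYPE", holding a
    # 1D array, or None where no observation exists:
    #
    #   self.obs_data = [{'wopr': np.array([104.2]), 'wwpr': None},
    #                    {'wopr': np.array([98.7]), 'wwpr': np.array([12.1])}]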
347    def _org_data_var(self):
348        """
349        Organize the input data variance given by the keyword "DATAVAR" in the "DATAASSIM" part of the init file.
350
351        If a diagonal auto-covariance is to be used to generate data, there are two options for data variance: absolute
352        and relative variance. Absolute is a fixed value for the variance, and relative is a percentage of
353        the observed data as standard deviation which in turn is set as variance. If we want to use an empirical data
354        covariance matrix to generate data, the user must supply a Numpy save file with samples, which is loaded here.
355        If we want to specify the whole covariance matrix, this can also be done. The user must supply a Numpy save file
356        which is loaded here.
357
358        .. warning:: When relative variance is given as input, we set the variance as (true_obs_data*rel_perc*0.01)**2
359        BECAUSE we often want this alternative in cases where we "add some percentage of Gaussian noise to the
360        observed data". Hence, we actually want some percentage of the true observed data as STANDARD DEVIATION since
361        it is ultimately the standard deviation (through square-root decomposition of Cd) that is used when adding
362        noise to observed data. Note that this is ONLY a matter of definition, but we feel that this way of defining
363        relative variance is most common.
364        """
365        # TODO: Change when sub-assim. indices have been re-implemented.
366
367        # Check if keys_da['datatype'] is a string or list, and make it a list if single string is given
368        if isinstance(self.keys_da['datatype'], str):
369            datatype = [self.keys_da['datatype']]
370        else:
371            datatype = self.keys_da['datatype']
372
373        # Extract primary indices from "TRUEDATAINDEX"
374        if isinstance(self.keys_da['truedataindex'], list):  # List of prim. ind
375            true_prim = self.keys_da['truedataindex']
376        else:  # Float
377            true_prim = [self.keys_da['truedataindex']]
378
379        #
380        # Extract the data variance from "DATAVAR"
381        #
382        # Only one assimilation step
383        if len(true_prim) == 1:
384            # More than one DATATYPE, but only one entry in DATAVAR
385            if len(self.keys_da['datavar']) == 2 and len(datatype) > 1:
386                # Copy list entry no. data type times
387                datavar = [self.keys_da['datavar'] * len(datatype)]
388
389            # One DATATYPE
390            else:
391                datavar = [self.keys_da['datavar']]
392
393        # More than one assim. step
394        else:
395            # More than one DATATYPE, but only one entry in DATAVAR
396            if not isinstance(self.keys_da['datavar'][0], list) and len(self.keys_da['datavar']) == 2 and \
397                    len(datatype) > 1:
398                # Need to make a list with entries equal to 2*no. data types (since there are 2 entries in DATAVAR
399                # for one data type). Then we copy this list as many times as we have TRUEDATAINDEX (i.e.,
400                # we get a 2D list)
401                # Copy list entry no. data types times
402                datavar_temp = self.keys_da['datavar'] * len(datatype)
403                datavar = [None] * len(true_prim)  # Init.
404                for i in range(len(true_prim)):
405                    datavar[i] = deepcopy(datavar_temp)
406
407            # Entry for each DATATYPE, but not for each TRUEDATAINDEX
408            elif (len(self.keys_da['datavar'])) / 2 == len(datatype) and \
409                    not isinstance(self.keys_da['datavar'][0], list):
410                # If we have entry for each DATATYPE but NOT for each TRUEDATAINDEX, then we just copy the list of
411                # entries to each TRUEDATAINDEX
412                datavar = [None] * len(true_prim)  # Init.
413                for i in range(len(true_prim)):
414                    datavar[i] = deepcopy(self.keys_da['datavar'])
415
416            else:
417                datavar = self.keys_da['datavar']
418
419        # Check if a csv file has been included in DATAVAR. If so datavar will be redefined and variance info will be
420        #  extracted from the csv file
421        if isinstance(self.keys_da['datavar'], str) and self.keys_da['datavar'].endswith('.csv'):
422            datavar = rcsv.read_var_csv(self.keys_da['datavar'], datatype, true_prim)
423
424        # Initialize datavar output
425        self.datavar = [None] * len(true_prim)
426
427        # Loop over all entries in datavar and fill in values from "DATAVAR" (use obs_data values in the REL variance
428        #  cases)
429        # TODO: Implement loading of data variance from .npz file
430        vintage = 0
431        for i in range(len(self.obs_data)):  # TRUEDATAINDEX
432            # Init. dict. with datatypes (do inside loop to avoid copy of same entry)
433            self.datavar[i] = {}
434            for j in range(len(datatype)):  # DATATYPE
435                # ABS
436                # Absolute var.
437                if datavar[i][2*j] == 'abs' and self.obs_data[i][datatype[j]] is not None:
438                    self.datavar[i][datatype[j]] = datavar[i][2*j+1] * \
439                        np.ones(len(self.obs_data[i][datatype[j]]))
440
441                # REL
442                # Rel. var.
443                elif datavar[i][2*j] == 'rel' and self.obs_data[i][datatype[j]] is not None:
444                    # Rel. var WITH a min. variance tolerance
445                    if isinstance(datavar[i][2*j+1], list):
446                        self.datavar[i][datatype[j]] = (datavar[i][2*j+1][0] * 0.01 *
447                                                        self.obs_data[i][datatype[j]]) ** 2
448                        ind_tol = self.datavar[i][datatype[j]] < datavar[i][2*j+1][1] ** 2
449                        self.datavar[i][datatype[j]][ind_tol] = datavar[i][2*j+1][1] ** 2
450
451                    else:  # Single. rel. var input
452                        var = (datavar[i][2*j+1] * 0.01 * self.obs_data[i][datatype[j]]) ** 2
453                        var = np.clip(var, 1.0e-9, None)  # avoid zero variance
454                        self.datavar[i][datatype[j]] = var
455                # EMP
456                elif datavar[i][2*j] == 'emp' and datavar[i][2*j+1].endswith('.npz') and \
457                        self.obs_data[i][datatype[j]] is not None:  # Empirical var.
458                    load_data = np.load(datavar[i][2*j+1])  # load the numpy savez file
459                    # store in datavar
460                    self.datavar[i][datatype[j]] = load_data[load_data.files[0]]
461
462                # LOAD
463                elif datavar[i][2*j] == 'load' and datavar[i][2*j+1].endswith('.npz') and \
464                        self.obs_data[i][datatype[j]] is not None:  # Load variance. (1d array)
465                    load_data = np.load(datavar[i][2*j+1])  # load the numpy savez file
466                    load_data = load_data[load_data.files[0]]
467                    self.datavar[i][datatype[j]] = load_data  # store in datavar
468
469                # CD the full covariance matrix is given in its correct format. Hence, load once and set as CD
470                elif datavar[i][2 * j] == 'cd' and datavar[i][2 * j + 1].endswith('.npz') and \
471                        self.obs_data[i][datatype[j]] is not None:
472                    if not hasattr(self, 'cov_data'):  # check to populate once
473                        # load the numpy savez file
474                        load_data = np.load(datavar[i][2 * j + 1])
475                        self.cov_data = load_data[load_data.files[0]]
476                    # store the variance
477                    self.datavar[i][datatype[j]] = self.cov_data[i*j, i*j]
478
479                elif self.obs_data[i][datatype[j]] is None:  # No observed data
480                    self.datavar[i][datatype[j]] = None  # Set None type here also
481
482                # Handle case when noise is estimated using wavelets
483                if self.sparse_info is not None and self.datavar[i][datatype[j]] is not None and \
484                        vintage < len(self.sparse_info['mask']) and \
485                        len(self.datavar[i][datatype[j]]) == int(np.sum(self.sparse_info['mask'][vintage])):
486                    # compute var from sparse_data
487                    est_noise = np.power(self.sparse_data[vintage].est_noise, 2)
488                    self.datavar[i][datatype[j]] = est_noise  # override the given value
489                    vintage = vintage + 1
490
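    # --- Editor's sketch (illustrative only, not part of the module) ------
    # Worked example of the REL branch above (numbers assumed): an observed
    # datum of 100.0 with the DATAVAR entry ['rel', 5] gives
    #
    #   std = 5 * 0.01 * 100.0 = 5.0   ->   var = 5.0**2 = 25.0
    #
    # With a minimum tolerance, ['rel', [5, 2.0]], any variance that falls
    # below 2.0**2 = 4.0 is raised to 4.0.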
491    def _org_sparse_representation(self):
492        """
493        Function for reading input to wavelet sparse representation of data.
494        """
495        self.sparse_info = {}
496        parsed_info = self.keys_da['compress']
497        dim = [int(elem) for elem in parsed_info[0][1]]
498        # flip to align with flow / eclipse
499        self.sparse_info['dim'] = [dim[2], dim[1], dim[0]]
500        self.sparse_info['mask'] = []
501        for vint in range(1, len(parsed_info[1])):
502            if not os.path.exists(parsed_info[1][vint]):
503                mask = np.ones(self.sparse_info['dim'], dtype=bool)
504                np.savez(f'mask_{vint-1}.npz', mask=mask)
505            else:
506                mask = np.load(parsed_info[1][vint])['mask']
507            self.sparse_info['mask'].append(mask.flatten())
508        self.sparse_info['level'] = parsed_info[2][1]
509        self.sparse_info['wname'] = parsed_info[3][1]
510        self.sparse_info['colored_noise'] = parsed_info[4][1] == 'yes'
511        self.sparse_info['threshold_rule'] = parsed_info[5][1]
512        self.sparse_info['th_mult'] = parsed_info[6][1]
513        self.sparse_info['use_hard_th'] = parsed_info[7][1] == 'yes'
514        self.sparse_info['keep_ca'] = parsed_info[8][1] == 'yes'
515        self.sparse_info['inactive_value'] = parsed_info[9][1]
516        self.sparse_info['use_ensemble'] = True if parsed_info[10][1] == 'yes' else None
517        self.sparse_info['order'] = parsed_info[11][1]
518        self.sparse_info['min_noise'] = parsed_info[12][1]
519
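    # --- Editor's sketch (illustrative only, not part of the module) ------
    # The COMPRESS keyword is parsed positionally; schematically (layout
    # assumed from the code above):
    #
    #   parsed_info[0][1] -> grid dimensions, reversed to [dim[2], dim[1],
    #                        dim[0]] to align with flow/eclipse ordering
    #   parsed_info[1]    -> one mask file per vintage (e.g. 'mask_0.npz');
    #                        a missing file yields an all-True mask,
    #                        np.ones(dim, dtype=bool), which is saved to disk
    #   parsed_info[2:]   -> scalar options: level, wname, colored_noise, ...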
520    def _ext_obs(self):
521        self.obs_data_vector, _ = at.aug_obs_pred_data(self.obs_data, self.pred_data, self.assim_index,
522                                                       self.list_datatypes)
523        # Generate the data auto-covariance matrix
524        if 'emp_cov' in self.keys_da and self.keys_da['emp_cov'] == 'yes':
525            if hasattr(self, 'cov_data'):  # cd matrix has been imported
526                tmp_E = np.dot(cholesky(self.cov_data).T,
527                               np.random.randn(self.cov_data.shape[0], self.ne))
528            else:
529                tmp_E = at.extract_tot_empirical_cov(
530                    self.datavar, self.assim_index, self.list_datatypes, self.ne)
531            # self.E = (tmp_E - tmp_E.mean(1)[:,np.newaxis])/np.sqrt(self.ne - 1)/
532            if 'screendata' in self.keys_da and self.keys_da['screendata'] == 'yes':
533                tmp_E = at.screen_data(tmp_E, self.aug_pred_data,
534                                       self.obs_data_vector, self.iteration)
535            self.E = tmp_E
536            self.real_obs_data = self.obs_data_vector[:, np.newaxis] - tmp_E
537
538            self.cov_data = np.var(self.E, ddof=1,
539                                   axis=1)  # calculate the variance, to be used for e.g. data misfit calc
540            # self.cov_data = ((self.E * self.E)/(self.ne-1)).sum(axis=1) # calculate the variance, to be used for e.g. data misfit calc
541            self.scale_data = np.sqrt(self.cov_data)
542        else:
543            if not hasattr(self, 'cov_data'):  # if cd is not loaded
544                self.cov_data = at.gen_covdata(
545                    self.datavar, self.assim_index, self.list_datatypes)
546            # data screening
547            if 'screendata' in self.keys_da and self.keys_da['screendata'] == 'yes':
548                self.cov_data = at.screen_data(
549                    self.cov_data, self.aug_pred_data, self.obs_data_vector, self.iteration)
550
551            init_en = Cholesky()  # Initialize GeoStat class for generating realizations
552            self.real_obs_data, self.scale_data = init_en.gen_real(self.obs_data_vector, self.cov_data, self.ne,
553                                                                   return_chol=True)
554
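    # --- Editor's sketch (illustrative only, not part of the module) ------
    # What _ext_obs produces, written out in plain numpy for the diagonal
    # covariance case (the module itself delegates to geostat.decomp.Cholesky;
    # the sign of the noise term below is an assumption):
    #
    #   d = self.obs_data_vector                 # shape (nd,)
    #   E = np.sqrt(cov_data)[:, None] * np.random.randn(len(d), self.ne)
    #   real_obs_data = d[:, None] + E           # one noisy copy per member
    #
    # In the 'emp_cov' branch above, E is instead drawn from the empirical
    # or imported covariance, and real_obs_data = d[:, None] - E.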
555    def _ext_state(self):
556        # get vector of scaling
557        self.state_scaling = at.calc_scaling(
558            self.prior_state, self.list_states, self.prior_info)
559
560        delta_scaled_prior = self.state_scaling[:, None] * \
561            np.dot(at.aug_state(self.prior_state, self.list_states), self.proj)
562
563        u_d, s_d, v_d = np.linalg.svd(delta_scaled_prior, full_matrices=False)
564
565        # remove the last singular value/vector. This is because numpy returns all ne values, while the last is actually
566        # zero. This part is a good place to include any additional truncation.
567        energy = 0
568        trunc_index = len(s_d) - 1  # initialize
569        for c, elem in enumerate(s_d):
570            energy += elem
571            if energy / sum(s_d) >= self.trunc_energy:
572                trunc_index = c  # take the index where all energy is preserved
573                break
574        u_d, s_d, v_d = u_d[:, :trunc_index +
575                            1], s_d[:trunc_index + 1], v_d[:trunc_index + 1, :]
576        self.Am = np.dot(u_d, np.eye(trunc_index+1) *
577                         ((s_d**(-1))[:, None]))  # notation from paper
578
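    # --- Editor's sketch (illustrative only, not part of the module) ------
    # The energy loop above keeps the smallest index r such that
    # sum(s_d[:r+1]) / sum(s_d) >= self.trunc_energy. An equivalent
    # vectorized form (assuming trunc_energy <= 1, so an index is found):
    #
    #   energy_frac = np.cumsum(s_d) / np.sum(s_d)
    #   trunc_index = int(np.searchsorted(energy_frac, self.trunc_energy))
    #
    # Am = U_r diag(1/s_r) is then the scaled left singular basis used in
    # the EnRML-type updates.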
579    def save_temp_state_assim(self, ind_save):
580        """
581        Method to save the state variable during the assimilation. It is stored in a list with length = tot. no.
582        assim. steps + 1 (for the init. ensemble). The list of temporary states are also stored as a .npz file.
583
584        Parameters
585        ----------
586        ind_save : int
587            Assim. step to save (0 = prior)
588        """
589        # Init. temp. save
590        if ind_save == 0:
591            # +1 due to init. ensemble
592            self.temp_state = [None]*(len(self.get_list_assim_steps()) + 1)
593
594        # Save the state
595        self.temp_state[ind_save] = deepcopy(self.state)
596        np.savez('temp_state_assim', self.temp_state)
597
598    def save_temp_state_iter(self, ind_save, max_iter):
599        """
600        Save a snapshot of state at current iteration. It is stored in a list with length equal to max. iteration
601        length + 1 (due to prior state being 0). The list of temporary states are also stored as a .npz file.
602
603        .. warning:: Max. iterations must be defined before invoking this method.
604
605        Parameters
606        ----------
607        ind_save : int
608            Iteration step to save (0 = prior)
609        """
610        # Initial save
611        if ind_save == 0:
612            self.temp_state = [None] * (int(max_iter) + 1)  # +1 due to init. ensemble
613
614        # Save state
615        self.temp_state[ind_save] = deepcopy(self.state)
616        np.savez('temp_state_iter', self.temp_state)
617
618    def save_temp_state_mda(self, ind_save):
619        """
620        Save a snapshot of the state during a MDA loop. The temporary state will be stored as a list with length
621        equal to the tot. no. of assimilations + 1 (init. ensemble saved in 0 entry). The list of temporary states
622        are also stored as a .npz file.
623
624        .. warning:: Tot. no. of assimilations must be defined before invoking this method.
625
626        Parameters
627        ----------
628        ind_save : int
629            Assim. step to save (0 = prior)
630        """
631        # Initial save
632        if ind_save == 0:
633            # +1 due to init. ensemble
634            self.temp_state = [None] * (int(self.tot_assim) + 1)
635
636        # Save state
637        self.temp_state[ind_save] = deepcopy(self.state)
638        np.savez('temp_state_mda', self.temp_state)
639
640    def save_temp_state_ml(self, ind_save):
641        """
642        Save a snapshot of the state during a ML loop. The temporary state will be stored as a list with length
643        equal to the tot. no. of assimilations + 1 (init. ensemble saved in 0 entry). The list of temporary states
644        are also stored as a .npz file.
645
646        .. warning:: Tot. no. of assimilations must be defined before invoking this method.
647
648        Parameters
649        ----------
650        ind_save : int
651            Assim. step to save (0 = prior)
652        """
653        # Initial save
654        if ind_save == 0:
655            # +1 due to init. ensemble
656            self.temp_state = [None] * (int(self.tot_assim) + 1)
657
658        # Save state
659        self.temp_state[ind_save] = deepcopy(self.state)
660        np.savez('temp_state_ml', self.temp_state)
661
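    # --- Editor's sketch (illustrative only, not part of the module) ------
    # All save_temp_state_* methods above store the snapshot list as a
    # positional argument to np.savez, so it comes back under the default
    # key 'arr_0' and needs allow_pickle=True (the entries are state dicts):
    #
    #   snaps = np.load('temp_state_iter.npz', allow_pickle=True)['arr_0']
    #   prior = snaps[0]  # entry 0 holds the initial (prior) ensemble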
662    def compress(self, data=None, vintage=0, aug_coeff=None):
663        """
664        Compress the input data using wavelets.
665
666        Parameters
667        ----------
668        data:
669            data to be compressed
670            If data is `None`, all data (true and simulated) is re-compressed (used if leading indices are updated)
671        vintage: int
672            the time index for the data
673        aug_coeff: bool
674            - False: in this case the leading indices for wavelet coefficients are computed
675            - True: in this case the leading indices are augmented using information from the ensemble
676            - None: in this case simulated data is compressed
677        """
678
679        # If input data is None, we re-compress all data
680        data_array = None
681        if data is None:
682            vintage = 0
683            for i in range(len(self.obs_data)):  # TRUEDATAINDEX
684                for j in self.obs_data[i].keys():  # DATATYPE
685
686                    data_array = self.obs_data[i][j]
687
688                    # Perform compression if required
689                    if data_array is not None and \
690                            vintage < len(self.sparse_info['mask']) and \
691                            len(data_array) == int(np.sum(self.sparse_info['mask'][vintage])):
692                        data_array, wdec_rec = self.sparse_data[vintage].compress(
693                            data_array)  # compress
694                        self.obs_data[i][j] = data_array  # save array in obs_data
695                        rec = self.sparse_data[vintage].reconstruct(
696                            wdec_rec)  # reconstruct the data
697                        s = 'truedata_rec_' + str(vintage) + '.npz'
698                        np.savez(s, rec)  # save reconstructed data
699                        est_noise = np.power(self.sparse_data[vintage].est_noise, 2)
700                        self.datavar[i][j] = est_noise
701
702                        # Update the ensemble
703                        data_sim = self.pred_data[i][j]
704                        self.pred_data[i][j] = np.zeros((len(data_array), self.ne))
705                        self.data_rec.append([])
706                        for m in range(self.pred_data[i][j].shape[1]):
707                            data_array = data_sim[:, m]
708                            data_array, wdec_rec = self.sparse_data[vintage].compress(
709                                data_array)  # compress
710                            self.pred_data[i][j][:, m] = data_array
711                            rec = self.sparse_data[vintage].reconstruct(
712                                wdec_rec)  # reconstruct the data
713                            self.data_rec[vintage].append(rec)
714
715                        # Go to next vintage
716                        vintage = vintage + 1
717
718            # Option to store the dictionaries containing observed data and data variance
719            if 'obsvarsave' in self.keys_da and self.keys_da['obsvarsave'] == 'yes':
720                np.savez('obs_var', obs=self.obs_data, var=self.datavar)
721
722            if 'saveforecast' in self.keys_en:
723                s = 'prior_forecast_rec.npz'
724                np.savez(s, self.data_rec)
725
726            data_array = None
727
728        elif aug_coeff is None:
729
730            data_array, wdec_rec = self.sparse_data[vintage].compress(data)
731            rec = self.sparse_data[vintage].reconstruct(
732                wdec_rec)  # reconstruct the simulated data
733            if len(self.data_rec) == vintage:
734                self.data_rec.append([])
735            self.data_rec[vintage].append(rec)
736
737        elif not aug_coeff:
738
739            options = copy(self.sparse_info)
740            # find the correct mask for the vintage
741            options['mask'] = options['mask'][vintage]
742            if isinstance(options['min_noise'], list):
743                if 0 <= vintage < len(options['min_noise']):
744                    options['min_noise'] = options['min_noise'][vintage]
745                else:
746                    print(
747                        'Error: min_noise must either be scalar or list with one number for each vintage')
748                    sys.exit(1)
749            x = wt.SparseRepresentation(options)
750            data_array, wdec_rec = x.compress(data, self.sparse_info['th_mult'])
751            self.sparse_data.append(x)  # store the information
752            data_rec = x.reconstruct(wdec_rec)  # reconstruct the data
753            s = 'truedata_rec_' + str(vintage) + '.npz'
754            np.savez(s, data_rec)  # save reconstructed data
755            if self.sparse_info['use_ensemble']:
756                data_array = data  # just return the same as input
757
758        elif aug_coeff:
759
760            _, _ = self.sparse_data[vintage].compress(data, self.sparse_info['th_mult'])
761            data_array = data  # just return the same as input
762
763        return data_array
764
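    # --- Editor's sketch (illustrative only, not part of the module) ------
    # The three aug_coeff modes of compress(), with assumed call sites:
    #
    #   obs = ens.compress(obs_array, vintage, aug_coeff=False)
    #         # observed data: compute leading wavelet indices, estimate noise
    #   _ = ens.compress(ens_data, vintage, aug_coeff=True)
    #         # augment the leading indices with ensemble information
    #   sim = ens.compress(sim_array, vintage)  # aug_coeff=None
    #         # compress simulated data with the already-selected indices
    #   ens.compress()  # data=None: re-compress all true and simulated data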
765    def local_analysis_update(self):
766        """
767        Function for updates that can be used by all algorithms. Collected here to avoid duplicating the local
768        analysis code.
769        """
770        orig_list_data = deepcopy(self.list_datatypes)
771        orig_list_state = deepcopy(self.list_states)
772        orig_cd = deepcopy(self.cov_data)
773        orig_real_obs_data = deepcopy(self.real_obs_data)
774        orig_data_vector = deepcopy(self.obs_data_vector)
775        # loop over the states that we want to update. Assume that the state and data combinations have been
776        # determined by the initialization.
777        # TODO: augment parameters with identical mask.
778        for state in self.local_analysis['region_parameter']:
779            self.list_datatypes = [elem for elem in self.list_datatypes if
780                                   elem in self.local_analysis['update_mask'][state]]
781            self.list_states = [deepcopy(state)]
782            self._ext_state()  # scaling for this state
783            if 'localization' in self.keys_da:
784                self.localization.loc_info['field'] = self.state_scaling.shape
785            del self.cov_data
786            # reset the random state for consistency
787            np.random.set_state(self.data_random_state)
788            self._ext_obs()  # get the data that's in the list of data.
789            _, self.aug_pred_data = at.aug_obs_pred_data(self.obs_data, self.pred_data, self.assim_index,
790                                                         self.list_datatypes)
791            # Mean pred_data and perturbation matrix with scaling
792            if len(self.scale_data.shape) == 1:
793                self.pert_preddata = np.dot(np.expand_dims(self.scale_data ** (-1), axis=1),
794                                            np.ones((1, self.ne))) * np.dot(self.aug_pred_data, self.proj)
795            else:
796                self.pert_preddata = solve(
797                    self.scale_data, np.dot(self.aug_pred_data, self.proj))
798
799            aug_state = at.aug_state(self.current_state, self.list_states)
800            self.update()
801            if hasattr(self, 'step'):
802                aug_state_upd = aug_state + self.step
803            self.state = at.update_state(aug_state_upd, self.state, self.list_states)
804
805        for state in self.local_analysis['vector_region_parameter']:
806            current_list_datatypes = deepcopy(self.list_datatypes)
807            for state_indx in range(self.state[state].shape[0]): # loop over the elements in the region
808                self.list_datatypes = [elem for elem in self.list_datatypes if
809                                       elem in self.local_analysis['update_mask'][state][state_indx]]
810                if len(self.list_datatypes):
811                    self.list_states = [deepcopy(state)]
812                    self._ext_state()  # scaling for this state
813                    if 'localization' in self.keys_da:
814                        self.localization.loc_info['field'] = self.state_scaling.shape
815                    del self.cov_data
816                    # reset the random state for consistency
817                    np.random.set_state(self.data_random_state)
818                    self._ext_obs()  # get the data that's in the list of data.
819                    _, self.aug_pred_data = at.aug_obs_pred_data(self.obs_data, self.pred_data, self.assim_index,
820                                                                 self.list_datatypes)
821                    # Mean pred_data and perturbation matrix with scaling
822                    if len(self.scale_data.shape) == 1:
823                        self.pert_preddata = np.dot(np.expand_dims(self.scale_data ** (-1), axis=1),
824                                                    np.ones((1, self.ne))) * np.dot(self.aug_pred_data, self.proj)
825                    else:
826                        self.pert_preddata = solve(
827                            self.scale_data, np.dot(self.aug_pred_data, self.proj))
828
829                    aug_state = at.aug_state(self.current_state, self.list_states)[state_indx,:]
830                    self.update()
831                    if hasattr(self, 'step'):
832                        aug_state_upd = aug_state + self.step[state_indx,:]
833                    self.state[state][state_indx,:] = aug_state_upd
834
835                self.list_datatypes = deepcopy(current_list_datatypes)
836
837        for state in self.local_analysis['cell_parameter']:
838            self.list_states = [deepcopy(state)]
839            self._ext_state()  # scaling for this state
840            orig_state_scaling = deepcopy(self.state_scaling)
841            param_position = self.local_analysis['parameter_position'][state]
842            field_size = param_position.shape
843            for k in range(field_size[0]):
844                for j in range(field_size[1]):
845                    for i in range(field_size[2]):
846                        current_data_list = list(
847                            self.local_analysis['update_mask'][state][k][j][i])
848                        current_data_list.sort()  # ensure consistent ordering of data
849                        if len(current_data_list):
850                            # if non-unique data for assimilation index, get the relevant data.
851                            if not self.local_analysis['unique']:
852                                orig_assim_index = deepcopy(self.assim_index)
853                                assim_index_data_list = set(
854                                    [el.split('_')[0] for el in current_data_list])
855                                current_assim_index = [
856                                    int(el.split('_')[1]) for el in current_data_list]
857                                current_data_list = list(assim_index_data_list)
858                                self.assim_index[1] = current_assim_index
859                            self.list_datatypes = deepcopy(current_data_list)
860                            del self.cov_data
861                            # reset the random state for consistency
862                            np.random.set_state(self.data_random_state)
863                            self._ext_obs()
864                            _, self.aug_pred_data = at.aug_obs_pred_data(self.obs_data, self.pred_data,
865                                                                         self.assim_index,
866                                                                         self.list_datatypes)
867                            # get parameter indexes
868                            full_cell_index = np.ravel_multi_index(
869                                np.array([[k], [j], [i]]), tuple(field_size))
870                            # count active values
871                            self.cell_index = [sum(param_position.flatten()[:el])
872                                               for el in full_cell_index]
873                            if 'localization' in self.keys_da:
874                                self.localization.loc_info['field'] = (
875                                    len(self.cell_index),)
876                                self.localization.loc_info['distance'] = cov_regularization._calc_distance(
877                                    self.local_analysis['data_position'],
878                                    self.local_analysis['unique'],
879                                    current_data_list, self.assim_index,
880                                    self.obs_data, self.pred_data, [(k, j, i)])
881                            # Set relevant state scaling
882                            self.state_scaling = orig_state_scaling[self.cell_index]
883
884                            # Mean pred_data and perturbation matrix with scaling
885                            if len(self.scale_data.shape) == 1:
886                                self.pert_preddata = np.dot(np.expand_dims(self.scale_data ** (-1), axis=1),
887                                                            np.ones((1, self.ne))) * np.dot(self.aug_pred_data,
888                                                                                            self.proj)
889                            else:
890                                self.pert_preddata = solve(
891                                    self.scale_data, np.dot(self.aug_pred_data, self.proj))
892
893                            aug_state = at.aug_state(
894                                self.current_state, self.list_states, self.cell_index)
895                            self.update()
896                            if hasattr(self, 'step'):
897                                aug_state_upd = aug_state + self.step
898                            self.state = at.update_state(
899                                aug_state_upd, self.state, self.list_states, self.cell_index)
900
901                            if not self.local_analysis['unique']:
902                                # reset assim index
903                                self.assim_index = deepcopy(orig_assim_index)
904                            if hasattr(self, 'localization') and 'distance' in self.localization.loc_info:  # reset
905                                del self.localization.loc_info['distance']
906
907        self.list_datatypes = deepcopy(orig_list_data)  # reset to original list
908        self.list_states = deepcopy(orig_list_state)
909        self.cov_data = deepcopy(orig_cd)
910        self.real_obs_data = deepcopy(orig_real_obs_data)
911        self.obs_data_vector = deepcopy(orig_data_vector)
912        self.cell_index = None
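    # --- Editor's sketch (illustrative only, not part of the module) ------
    # The scaled anomaly matrix computed repeatedly above is S^(-1) D P with
    # D = aug_pred_data, P = self.proj and S = scale_data; in plain numpy:
    #
    #   if scale_data.ndim == 1:     # diagonal scaling
    #       pert = (aug_pred_data @ proj) / scale_data[:, None]
    #   else:                        # full (e.g. Cholesky) factor
    #       pert = solve(scale_data, aug_pred_data @ proj)
    #
    # where proj = (I - 11^T/ne) / sqrt(ne - 1) centres each ensemble row
    # and applies the 1/sqrt(ne - 1) sample-covariance scaling.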
class Ensemble(ensemble.ensemble.Ensemble):
 25class Ensemble(PETEnsemble):
 26    """
 27    Class for organizing/initializing misc. variables and simulator for an
 28    ensemble-based inversion run. Inherits the PET ensemble structure
 29    """
 30
 31    def __init__(self, keys_da, keys_en, sim):
 32        """
 33        Parameters
 34        ----------
 35        keys_da : dict
 36            Options for the data assimilation class
 37
 38            - daalg: spesification of the method, first the main type (e.g., "enrml"), then the solver (e.g., "gnenrml")
 39            - analysis: update flavour ("approx", "full" or "subspace")
 40            - energy: percent of singular values kept after SVD
 41            - obsvarsave: save the observations as a file (default false)
 42            - restart: restart optimization from a restart file (default false)
 43            - restartsave: save a restart file after each successful iteration (defalut false)
 44            - analysisdebug: specify which class variables to save to the result files
 45            - truedataindex: order of the simulated data (for timeseries this is points in time)
 46            - obsname: unit for truedataindex (for timeseries this is days or hours or seconds, etc.)
 47            - truedata: the data, e.g., provided as a .csv file
 48            - assimindex: index for the data that will be used for assimilation
 49            - datatype: list with the name of the datatypes
 50            - staticvar: name of the static variables
 51            - datavar: data variance, e.g., provided as a .csv file
 52
 53        keys_en : dict
 54            Options for the ensemble class
 55
 56            - ne: number of perturbations used to compute the gradient
 57            - state: name of state variables passed to the .mako file
 58            - prior_<name>: the prior information the state variables, including mean, variance and variable limits
 59
 60        sim : callable
 61            The forward simulator (e.g. flow)
 62        """
 63
 64
 65        # do the initiallization of the PETensemble
 66        super(Ensemble, self).__init__(keys_en, sim)
 67
 68        # set logger
 69        self.logger = logging.getLogger('PET.PIPT')
 70
 71        # write initial information
 72        self.logger.info(f'Starting a {keys_da["daalg"][0]} run with the {keys_da["daalg"][1]} algorithm applying the '
 73                         f'{keys_da["analysis"]} update scheme with {keys_da["energy"]} Energy.')
 74
 75        # Internalize PIPT dictionary
 76        if not hasattr(self, 'keys_da'):
 77            self.keys_da = keys_da
 78        if not hasattr(self, 'keys_en'):
 79            self.keys_en = keys_en
 80
 81        if self.restart is False:
 82            # Init in _init_prediction_output (used in run_prediction)
 83            self.prediction = None
 84            self.temp_state = None  # temporary state saving
 85            self.cov_prior = None  # Prior cov. matrix
 86            self.sparse_info = None  # Init in _org_sparse_representation
 87            self.sparse_data = []  # List of the compression info
 88            self.data_rec = []  # List of reconstructed data
 89            self.scale_val = None  # Use to scale data
 90
 91            # Prepare sparse representation
 92            if 'compress' in self.keys_da:
 93                self._org_sparse_representation()
 94
 95            self._org_obs_data()
 96            self._org_data_var()
 97
 98            # define projection for centring and scaling
 99            self.proj = (np.eye(self.ne) - (1 / self.ne) *
100                         np.ones((self.ne, self.ne))) / np.sqrt(self.ne - 1)
101
102            # If we have dynamic state variables, we allocate keys for them in 'state'. Since we do not know the size
103            #  of the arrays of the dynamic variables, we only allocate an NE list to be filled in later (in
104            # calc_forecast)
105            if 'dynamicvar' in self.keys_da:
106                dyn_var = self.keys_da['dynamicvar'] if isinstance(self.keys_da['dynamicvar'], list) else \
107                    [self.keys_da['dynamicvar']]
108                for name in dyn_var:
109                    self.state[name] = [None] * self.ne
110
111            # Option to store the dictionaries containing observed data and data variance
112            if 'obsvarsave' in self.keys_da and self.keys_da['obsvarsave'] == 'yes':
113                np.savez('obs_var', obs=self.obs_data, var=self.datavar)
114
115            # Initialize localization
116            if 'localization' in self.keys_da:
117                self.localization = cov_regularization.localization(self.keys_da['localization'],
118                                                                    self.keys_da['truedataindex'],
119                                                                    self.keys_da['datatype'],
120                                                                    self.keys_da['staticvar'],
121                                                                    self.ne)
122            # Initialize local analysis
123            if 'localanalysis' in self.keys_da:
124                self.local_analysis = at.init_local_analysis(
125                    init=self.keys_da['localanalysis'], state=self.state.keys())
126
127            self.pred_data = [{k: np.zeros((1, self.ne), dtype='float32') for k in self.keys_da['datatype']}
128                              for _ in self.obs_data]
129
130            self.cell_index = None  # default value for extracting states
131
132    def check_assimindex_sequential(self):
133        """
134        Check if the assim. indices are given as a 2D list, as needed in sequential updating. If not, make them a 2D list.
135        """
136        # Check if ASSIMINDEX is a list. If not, make it a 2D list
137        if not isinstance(self.keys_da['assimindex'], list):
138            self.keys_da['assimindex'] = [[self.keys_da['assimindex']]]
139
140        # If ASSIMINDEX is a 1D list (either given in as a single row or single column), we reshape to a 2D list
141        elif not isinstance(self.keys_da['assimindex'][0], list):
142            assimindex_temp = [None] * len(self.keys_da['assimindex'])
143
144            for i in range(len(self.keys_da['assimindex'])):
145                assimindex_temp[i] = [self.keys_da['assimindex'][i]]
146
147            self.keys_da['assimindex'] = assimindex_temp
148
149    def check_assimindex_simultaneous(self):
150        """
151        Check if the assim. indices are given in the 2D single-row form needed for simultaneous updating. If not,
152        reshape them into a 2D list with one row.
153        """
154        # Check if ASSIMINDEX is a list. If not, make it a 2D list with one row
155        if not isinstance(self.keys_da['assimindex'], list):
156            self.keys_da['assimindex'] = [[self.keys_da['assimindex']]]
157
158        # Check if ASSIMINDEX is a 1D list. If true, make it a 2D list with one row
159        elif not isinstance(self.keys_da['assimindex'][0], list):
160            self.keys_da['assimindex'] = [self.keys_da['assimindex']]
161
162        # If ASSIMINDEX is a 2D list, we reshape it to a 2D list with one row
163        elif isinstance(self.keys_da['assimindex'][0], list):
164            self.keys_da['assimindex'] = [
165                [item for sublist in self.keys_da['assimindex'] for item in sublist]]
166
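    # A minimal sketch (assumed input, not part of the class) of how the two
    # checks above reshape ASSIMINDEX:
    #   keys_da['assimindex'] = [0, 1, 2]
    #   check_assimindex_sequential()   -> [[0], [1], [2]]  (one row per assim. step)
    #   check_assimindex_simultaneous() -> [[0, 1, 2]]      (a single row)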
167    def _org_obs_data(self):
168        """
169        Organize the input true observed data. The obs_data will be a list of length equal length of "TRUEDATAINDEX",
170        and each entery in the list will be a dictionary with keys equal to the "DATATYPE".
171        Also, the pred_data variable (predicted data or forward simulation) will be initialized here with the same
172        structure as the obs_data variable.
173
174        .. warning:: An "N/A" entry in "TRUEDATA" is treated as a None-entry; that is, there is NOT an observed data at this
175        assimilation step.
176
177        .. warning:: The array associated with the first string inputted in "TRUEDATAINDEX" is assumed to be the "main"
178        index, that is, the length of this array will determine the length of the obs_data list! There arrays
179        associated with the subsequent strings in "TRUEDATAINDEX" are then assumed to be a subset of the first
180        string.  An example: the first string is SOURCE (e.g., sources in CSEM), where the array will be a list of numbering
181        for the sources; and the second string is FREQ, where the array associated will be a list of frequencies.
182
183        .. note:: It is assumed that the number of data associated with a subset is the same for each index in the subset.
184        For example: If two frequencies are inputted in FREQ, then the number of data for one SOURCE index and one
185        frequency is 1/2 of the total no. of data for that SOURCE index. If three frequencies are inputted, the number
186        of data for one SOURCE index and one frequencies is 1/3 of the total no of data for that SOURCE index,
187        and so on.
188        """
189
190        # # Check if keys_da['datatype'] is a string or list, and make it a list if single string is given
191        # if isinstance(self.keys_da['datatype'], str):
192        #     datatype = [self.keys_da['datatype']]
193        # else:
194        #     datatype = self.keys_da['datatype']
195        #
196        # # Extract primary indices from "TRUEDATAINDEX"
197        # if isinstance(self.keys_da['truedataindex'], list):  # List of prim. ind
198        #     true_prim = self.keys_da['truedataindex']
199        # else:  # Float
200        #     true_prim = [self.keys_da['truedataindex']]
201        #
202        # # Check if a csv file has been included as "TRUEDATAINDEX". If so, we read it and make a list,
203        # if isinstance(self.keys_da['truedataindex'], str) and self.keys_da['truedataindex'].endswith('.csv'):
204        #     with open(self.keys_da['truedataindex']) as csvfile:
205        #         reader = csv.reader(csvfile)  # get a reader object
206        #         true_prim = []  # Initialize the list of csv data
207        #         for rows in reader:  # Rows is a list of values in the csv file
208        #             csv_data = [None] * len(rows)
209        #             for ind, col in enumerate(rows):
210        #                 csv_data[ind] = int(col)
211        #             true_prim.extend(csv_data)
212        #     self.keys_da['truedataindex'] = true_prim
213        #
214        # # Check if a csv file has been included as "PREDICTION". If so, we read it and make a list,
215        # if 'prediction' in self.keys_da:
216        #     if isinstance(self.keys_da['prediction'], str) and self.keys_da['prediction'].endswith('.csv'):
217        #         with open(self.keys_da['prediction']) as csvfile:
218        #             reader = csv.reader(csvfile)  # get a reader object
219        #             pred_prim = []  # Initialize the list of csv data
220        #             for rows in reader:  # Rows is a list of values in the csv file
221        #                 csv_data = [None] * len(rows)
222        #                 for ind, col in enumerate(rows):
223        #                     csv_data[ind] = int(col)
224        #                 pred_prim.extend(csv_data)
225        #         self.keys_da['prediction'] = pred_prim
226
227        # Extract the observed data from "TRUEDATA"
228        if len(self.keys_da['truedataindex']) == 1:  # Only one assimilation step
229            if isinstance(self.keys_da['truedata'], list):
230                truedata = [self.keys_da['truedata']]
231            else:
232                truedata = [[self.keys_da['truedata']]]
233        else:  # More than one assim. step
234            if isinstance(self.keys_da['truedata'][0], list):  # 2D list
235                truedata = self.keys_da['truedata']
236            else:
237                truedata = [[x] for x in self.keys_da['truedata']]  # Make it a 2D list
238
239        # Initialize obs_data list. List length = len("TRUEDATAINDEX"); each entry is a dictionary keyed by datatype
240        self.obs_data = [None] * len(self.keys_da['truedataindex'])
241
242        # Check if a csv file has been included in TRUEDATA. If so, we read it and make a 2D list, which we can use
243        # in the below when assigning data to obs_data dictionary
244        if isinstance(self.keys_da['truedata'], str) and self.keys_da['truedata'].endswith('.csv'):
245            truedata = rcsv.read_data_csv(
246                self.keys_da['truedata'], self.keys_da['datatype'], self.keys_da['truedataindex'])
247
248        # # Check if assimindex is given as a csv file. If so, we read and make a potential 2D list (if sequential).
249        # if isinstance(self.keys_da['assimindex'], str) and self.keys_da['assimindex'].endswith('.csv'):
250        #     with open(self.keys_da['assimindex']) as csvfile:
251        #         reader = csv.reader(csvfile)  # get a reader object
252        #         assimindx = []  # Initialize the 2D list of csv data
253        #         for rows in reader:  # Rows is a list of values in the csv file
254        #             csv_data = [None] * len(rows)
255        #             for col in range(len(rows)):
256        #                 csv_data[col] = int(rows[col])
257        #             assimindx.append(csv_data)
258        #     self.keys_da['assimindex'] = assimindx
259
260        # Now we loop over all list entries in obs_data and fill in the observed data from "TRUEDATA".
261        # NOTE: Not all data types may have observed data at each "TRUEDATAINDEX"; in this case it will have a None
262        # entry.
263        # NOTE2: If "TRUEDATA" contains a .npz file, this will be loaded. BUT the array loaded MUST be a 1D numpy
264        # array! So resize BEFORE saving the .npz file!
265        # NOTE3: If CSV file has been included in TRUEDATA, we read the data from this file
266        vintage = 0
267        for i in range(len(self.obs_data)):  # TRUEDATAINDEX
268            # Init. dict. with datatypes (do inside loop to avoid copy of same entry)
269            self.obs_data[i] = {}
270            # Make unified inputs
271            if 'unif_in' in self.keys_da and self.keys_da['unif_in'] == 'yes':
272                if isinstance(truedata[i][0], str) and truedata[i][0].endswith('.npz'):
273                    load_data = np.load(truedata[i][0])  # Load the .npz file
274                    data_array = load_data[load_data.files[0]]
275
276                    # Perform compression if required (we compress exactly those signals whose size equals the number of active cells)
277                    if self.sparse_info is not None and \
278                            vintage < len(self.sparse_info['mask']) and \
279                            len(data_array) == int(np.sum(self.sparse_info['mask'][vintage])):
280                        data_array = self.compress(data_array, vintage, False)
281                        vintage = vintage + 1
282
283                    # Save array in obs_data. If it is an array with single value (not list), then we convert it to a
284                    # list with one entry.
285                    self.obs_data[i][self.keys_da['datatype'][0]] = np.array(
286                        [data_array[()]]) if data_array.shape == () else data_array
287
288                # Entry is N/A, i.e., no data given
289                elif isinstance(truedata[i][0], str) and not truedata[i][0].endswith('.npz') \
290                        and truedata[i][0].lower() == 'n/a':
291                    self.obs_data[i][self.keys_da['datatype'][0]] = None
292
293                # Unknown string entry
294                elif isinstance(truedata[i][0], str) and not truedata[i][0].endswith('.npz') \
295                        and not truedata[i][0].lower() == 'n/a':
296                    print(
297                        '\n\033[1;31mERROR: Cannot load observed data file! Maybe it is not a .npz file?\033[1;m')
298                    sys.exit(1)
299                # Entry is a numerical value
300                elif not isinstance(truedata[i][0], str):  # Some numerical value or None
301                    self.obs_data[i][self.keys_da['datatype'][0]] = np.array(
302                        truedata[i][:])  # no need to make this into a list
303            else:
304                for j in range(len(self.keys_da['datatype'])):  # DATATYPE
305                    # Load a Numpy npz file
306                    if isinstance(truedata[i][j], str) and truedata[i][j].endswith('.npz'):
307                        load_data = np.load(truedata[i][j])  # Load the .npz file
308                        data_array = load_data[load_data.files[0]]
309
310                        # Perform compression if required (we compress exactly those signals whose size equals the number of active cells)
311                        if self.sparse_info is not None and \
312                                vintage < len(self.sparse_info['mask']) and \
313                                len(data_array) == int(np.sum(self.sparse_info['mask'][vintage])):
314                            data_array = self.compress(data_array, vintage, False)
315                            vintage = vintage + 1
316
317                        # Save array in obs_data. If it is an array with single value (not list), then we convert it to a
318                        # list with one entry
319                        self.obs_data[i][self.keys_da['datatype'][j]] = np.array(
320                            [data_array[()]]) if data_array.shape == () else data_array
321
322                    # Entry is N/A, i.e., no data given
323                    elif isinstance(truedata[i][j], str) and not truedata[i][j].endswith('.npz') \
324                            and truedata[i][j].lower() == 'n/a':
325                        self.obs_data[i][self.keys_da['datatype'][j]] = None
326
327                    # Unknown string entry
328                    elif isinstance(truedata[i][j], str) and not truedata[i][j].endswith('.npz') \
329                            and not truedata[i][j].lower() == 'n/a':
330                        print(
331                            '\n\033[1;31mERROR: Cannot load observed data file! Maybe it is not a .npz file?\033[1;m')
332                        sys.exit(1)
333
334                    # Entry is a numerical value
335                    # Some numerical value or None
336                    elif not isinstance(truedata[i][j], str):
337                        if isinstance(truedata[i][j], np.ndarray):
338                            self.obs_data[i][self.keys_da['datatype'][j]] = truedata[i][j]
339                        else:
340                            self.obs_data[i][self.keys_da['datatype'][j]] = np.array([truedata[i][j]])
341
342                    # Scale data if required (currently only one group of data can be scaled)
343                    if 'scale' in self.keys_da and self.keys_da['scale'][0] in self.keys_da['datatype'][j] and \
344                            self.obs_data[i][self.keys_da['datatype'][j]] is not None:
345                        self.obs_data[i][self.keys_da['datatype'][j]] *= \
346                            self.keys_da['scale'][1]
347
348    def _org_data_var(self):
349        """
350        Organize the input data variance given by the keyword "DATAVAR" in the "DATAASSIM" part the init_file.
351
352        If a diagonal auto-covariance is to be used to generate data, there are two options for data variance: absolute
353        and relative variance. Absolute is a fixed value for the variance, and relative is a percentage of
354        the observed data as standard deviation which in turn is set as variance. If we want to use an empirical data
355        covariance matrix to generate data, the user must supply a Numpy save file with samples, which is loaded here.
356        If we want to specify the whole covariance matrix, this can also be done. The user must supply a Numpy save file
357        which is loaded here.
358
359        .. warning:: When relative variance is given as input, we set the variance as (true_obs_data*rel_perc*0.01)**2
360        BECAUSE we often want this alternative in cases where we "add some percentage of Gaussian noise to the
361        observed data". Hence, we actually want some percentage of the true observed data as STANDARD DEVIATION since
362        it ultimately is the standard deviation (through square-root decompostion of Cd) that is used when adding
363        noise to observed data.Note that this is ONLY a matter of definition, but we feel that this way of defining
364        relative variance is most common.
365        """
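        # Worked example of the relative-variance convention above (illustrative
        # numbers): with a DATAVAR entry ['rel', 10.0] and an observed value of
        # 200.0, the standard deviation is 200.0 * 10.0 * 0.01 = 20.0, so the
        # stored variance is 20.0 ** 2 = 400.0.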
366        # TODO: Change when sub-assim. indices have been re-implemented.
367
368        # Check if keys_da['datatype'] is a string or list, and make it a list if single string is given
369        if isinstance(self.keys_da['datatype'], str):
370            datatype = [self.keys_da['datatype']]
371        else:
372            datatype = self.keys_da['datatype']
373
374        # Extract primary indices from "TRUEDATAINDEX"
375        if isinstance(self.keys_da['truedataindex'], list):  # List of prim. ind
376            true_prim = self.keys_da['truedataindex']
377        else:  # Float
378            true_prim = [self.keys_da['truedataindex']]
379
380        #
381        # Extract the data variance from "DATAVAR"
382        #
383        # Only one assimilation step
384        if len(true_prim) == 1:
385            # More than one DATATYPE, but only one entry in DATAVAR
386            if len(self.keys_da['datavar']) == 2 and len(datatype) > 1:
387                # Copy list entry no. data type times
388                datavar = [self.keys_da['datavar'] * len(datatype)]
389
390            # One DATATYPE
391            else:
392                datavar = [self.keys_da['datavar']]
393
394        # More than one assim. step
395        else:
396            # More than one DATATYPE, but only one entry in DATAVAR
397            if not isinstance(self.keys_da['datavar'][0], list) and len(self.keys_da['datavar']) == 2 and \
398                    len(datatype) > 1:
399                # Need to make a list with entries equal to 2*no. data types (since there are 2 entries in DATAVAR
400                # for one data type). Then we copy this list as many times as we have TRUEDATAINDEX (i.e.,
401                # we get a 2D list)
402                # Copy list entry no. data types times
403                datavar_temp = self.keys_da['datavar'] * len(datatype)
404                datavar = [None] * len(true_prim)  # Init.
405                for i in range(len(true_prim)):
406                    datavar[i] = deepcopy(datavar_temp)
407
408            # Entry for each DATATYPE, but not for each TRUEDATAINDEX
409            elif (len(self.keys_da['datavar'])) / 2 == len(datatype) and \
410                    not isinstance(self.keys_da['datavar'][0], list):
411                # If we have entry for each DATATYPE but NOT for each TRUEDATAINDEX, then we just copy the list of
412                # entries to each TRUEDATAINDEX
413                datavar = [None] * len(true_prim)  # Init.
414                for i in range(len(true_prim)):
415                    datavar[i] = deepcopy(self.keys_da['datavar'])
416
417            else:
418                datavar = self.keys_da['datavar']
419
420        # Check if a csv file has been included in DATAVAR. If so datavar will be redefined and variance info will be
421        #  extracted from the csv file
422        if isinstance(self.keys_da['datavar'], str) and self.keys_da['datavar'].endswith('.csv'):
423            datavar = rcsv.read_var_csv(self.keys_da['datavar'], datatype, true_prim)
424
425        # Initialize datavar output
426        self.datavar = [None] * len(true_prim)
427
428        # Loop over all entries in datavar and fill in values from "DATAVAR" (use obs_data values in the REL variance
429        #  cases)
430        # TODO: Implement loading of data variance from .npz file
431        vintage = 0
432        for i in range(len(self.obs_data)):  # TRUEDATAINDEX
433            # Init. dict. with datatypes (do inside loop to avoid copy of same entry)
434            self.datavar[i] = {}
435            for j in range(len(datatype)):  # DATATYPE
436                # ABS
437                # Absolute var.
438                if datavar[i][2*j] == 'abs' and self.obs_data[i][datatype[j]] is not None:
439                    self.datavar[i][datatype[j]] = datavar[i][2*j+1] * \
440                        np.ones(len(self.obs_data[i][datatype[j]]))
441
442                # REL
443                # Rel. var.
444                elif datavar[i][2*j] == 'rel' and self.obs_data[i][datatype[j]] is not None:
445                    # Rel. var WITH a min. variance tolerance
446                    if isinstance(datavar[i][2*j+1], list):
447                        self.datavar[i][datatype[j]] = (datavar[i][2*j+1][0] * 0.01 *
448                                                        self.obs_data[i][datatype[j]]) ** 2
449                        ind_tol = self.datavar[i][datatype[j]] < datavar[i][2*j+1][1] ** 2
450                        self.datavar[i][datatype[j]][ind_tol] = datavar[i][2*j+1][1] ** 2
451
452                    else:  # Single. rel. var input
453                        var = (datavar[i][2*j+1] * 0.01 * self.obs_data[i][datatype[j]]) ** 2
454                        var = np.clip(var, 1.0e-9, None)  # avoid zero variance
455                        self.datavar[i][datatype[j]] = var
456                # EMP
457                elif datavar[i][2*j] == 'emp' and datavar[i][2*j+1].endswith('.npz') and \
458                        self.obs_data[i][datatype[j]] is not None:  # Empirical var.
459                    load_data = np.load(datavar[i][2*j+1])  # load the numpy savez file
460                    # store in datavar
461                    self.datavar[i][datatype[j]] = load_data[load_data.files[0]]
462
463                # LOAD
464                elif datavar[i][2*j] == 'load' and datavar[i][2*j+1].endswith('.npz') and \
465                        self.obs_data[i][datatype[j]] is not None:  # Load variance. (1d array)
466                    load_data = np.load(datavar[i][2*j+1])  # load the numpy savez file
467                    load_data = load_data[load_data.files[0]]
468                    self.datavar[i][datatype[j]] = load_data  # store in datavar
469
470                # CD: the full covariance matrix is given in its correct format. Hence, load once and set as CD
471                elif datavar[i][2 * j] == 'cd' and datavar[i][2 * j + 1].endswith('.npz') and \
472                        self.obs_data[i][datatype[j]] is not None:
473                    if not hasattr(self, 'cov_data'):  # check to populate once
474                        # load the numpy savez file
475                        load_data = np.load(datavar[i][2 * j + 1])
476                        self.cov_data = load_data[load_data.files[0]]
477                    # store the variance
478                    self.datavar[i][datatype[j]] = self.cov_data[i*j, i*j]
479
480                elif self.obs_data[i][datatype[j]] is None:  # No observed data
481                    self.datavar[i][datatype[j]] = None  # Set None type here also
482
483                # Handle case when noise is estimated using wavelets
484                if self.sparse_info is not None and self.datavar[i][datatype[j]] is not None and \
485                        vintage < len(self.sparse_info['mask']) and \
486                        len(self.datavar[i][datatype[j]]) == int(np.sum(self.sparse_info['mask'][vintage])):
487                    # compute var from sparse_data
488                    est_noise = np.power(self.sparse_data[vintage].est_noise, 2)
489                    self.datavar[i][datatype[j]] = est_noise  # override the given value
490                    vintage = vintage + 1
491
492    def _org_sparse_representation(self):
493        """
494        Function for reading input to wavelet sparse representation of data.
495        """
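        # Illustrative layout of the parsed COMPRESS keyword that the indexing
        # below assumes (all values here are hypothetical):
        #   keys_da['compress'] = [['dim', [64, 64, 8]],
        #                          ['mask', 'mask_0.npz', 'mask_1.npz'],
        #                          ['level', 3], ['wname', 'db4'],
        #                          ['colored_noise', 'yes'], ...]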
496        self.sparse_info = {}
497        parsed_info = self.keys_da['compress']
498        dim = [int(elem) for elem in parsed_info[0][1]]
499        # flip to align with flow / eclipse
500        self.sparse_info['dim'] = [dim[2], dim[1], dim[0]]
501        self.sparse_info['mask'] = []
502        for vint in range(1, len(parsed_info[1])):
503            if not os.path.exists(parsed_info[1][vint]):
504                mask = np.ones(self.sparse_info['dim'], dtype=bool)
505                np.savez(f'mask_{vint-1}.npz', mask=mask)
506            else:
507                mask = np.load(parsed_info[1][vint])['mask']
508            self.sparse_info['mask'].append(mask.flatten())
509        self.sparse_info['level'] = parsed_info[2][1]
510        self.sparse_info['wname'] = parsed_info[3][1]
511        self.sparse_info['colored_noise'] = parsed_info[4][1] == 'yes'
512        self.sparse_info['threshold_rule'] = parsed_info[5][1]
513        self.sparse_info['th_mult'] = parsed_info[6][1]
514        self.sparse_info['use_hard_th'] = parsed_info[7][1] == 'yes'
515        self.sparse_info['keep_ca'] = parsed_info[8][1] == 'yes'
516        self.sparse_info['inactive_value'] = parsed_info[9][1]
517        self.sparse_info['use_ensemble'] = True if parsed_info[10][1] == 'yes' else None
518        self.sparse_info['order'] = parsed_info[11][1]
519        self.sparse_info['min_noise'] = parsed_info[12][1]
520
521    def _ext_obs(self):
522        self.obs_data_vector, _ = at.aug_obs_pred_data(self.obs_data, self.pred_data, self.assim_index,
523                                                       self.list_datatypes)
524        # Generate the data auto-covariance matrix
525        if 'emp_cov' in self.keys_da and self.keys_da['emp_cov'] == 'yes':
526            if hasattr(self, 'cov_data'):  # cd matrix has been imported
527                tmp_E = np.dot(cholesky(self.cov_data).T,
528                               np.random.randn(self.cov_data.shape[0], self.ne))
529            else:
530                tmp_E = at.extract_tot_empirical_cov(
531                    self.datavar, self.assim_index, self.list_datatypes, self.ne)
532            # self.E = (tmp_E - tmp_E.mean(1)[:,np.newaxis])/np.sqrt(self.ne - 1)/
533            if 'screendata' in self.keys_da and self.keys_da['screendata'] == 'yes':
534                tmp_E = at.screen_data(tmp_E, self.aug_pred_data,
535                                       self.obs_data_vector, self.iteration)
536            self.E = tmp_E
537            self.real_obs_data = self.obs_data_vector[:, np.newaxis] - tmp_E
538
539            self.cov_data = np.var(self.E, ddof=1,
540                                   axis=1)  # calculate the variance, to be used for e.g. data misfit calc
541            # self.cov_data = ((self.E * self.E)/(self.ne-1)).sum(axis=1) # calculate the variance, to be used for e.g. data misfit calc
542            self.scale_data = np.sqrt(self.cov_data)
543        else:
544            if not hasattr(self, 'cov_data'):  # if cd is not loaded
545                self.cov_data = at.gen_covdata(
546                    self.datavar, self.assim_index, self.list_datatypes)
547            # data screening
548            if 'screendata' in self.keys_da and self.keys_da['screendata'] == 'yes':
549                self.cov_data = at.screen_data(
550                    self.cov_data, self.aug_pred_data, self.obs_data_vector, self.iteration)
551
552            init_en = Cholesky()  # Initialize GeoStat class for generating realizations
553            self.real_obs_data, self.scale_data = init_en.gen_real(self.obs_data_vector, self.cov_data, self.ne,
554                                                                   return_chol=True)
555
556    def _ext_state(self):
557        # get vector of scaling
558        self.state_scaling = at.calc_scaling(
559            self.prior_state, self.list_states, self.prior_info)
560
561        delta_scaled_prior = self.state_scaling[:, None] * \
562            np.dot(at.aug_state(self.prior_state, self.list_states), self.proj)
563
564        u_d, s_d, v_d = np.linalg.svd(delta_scaled_prior, full_matrices=False)
565
566        # remove the last singular value/vector. This is because numpy returns all ne values, while the last is
567        # actually zero. This part is a good place to include any additional truncation.
568        energy = 0
569        trunc_index = len(s_d) - 1  # initialize
570        for c, elem in enumerate(s_d):
571            energy += elem
572            if energy / sum(s_d) >= self.trunc_energy:
573                trunc_index = c  # take the index where all energy is preserved
574                break
575        u_d, s_d, v_d = (u_d[:, :trunc_index + 1], s_d[:trunc_index + 1],
576                         v_d[:trunc_index + 1, :])
577        self.Am = np.dot(u_d, np.eye(trunc_index+1) *
578                         ((s_d**(-1))[:, None]))  # notation from paper
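        # Worked example of the energy criterion above (illustrative numbers):
        # with s_d = [5.0, 3.0, 1.0] and trunc_energy = 0.9, the cumulative
        # fractions are [0.56, 0.89, 1.0], so trunc_index = 2 and all three
        # singular triplets are kept.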
579
580    def save_temp_state_assim(self, ind_save):
581        """
582        Method to save the state variable during the assimilation. It is stored in a list with length = tot. no.
583        assim. steps + 1 (for the init. ensemble). The list of temporary states are also stored as a .npz file.
584
585        Parameters
586        ----------
587        ind_save : int
588            Assim. step to save (0 = prior)
589        """
590        # Init. temp. save
591        if ind_save == 0:
592            # +1 due to init. ensemble
593            self.temp_state = [None]*(len(self.get_list_assim_steps()) + 1)
594
595        # Save the state
596        self.temp_state[ind_save] = deepcopy(self.state)
597        np.savez('temp_state_assim', self.temp_state)
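        # NOTE: arrays passed positionally to np.savez are stored under the key
        # 'arr_0', so the list can be recovered with
        # np.load('temp_state_assim.npz', allow_pickle=True)['arr_0']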
598
599    def save_temp_state_iter(self, ind_save, max_iter):
600        """
601        Save a snapshot of state at current iteration. It is stored in a list with length equal to max. iteration
602        length + 1 (due to prior state being 0). The list of temporary states are also stored as a .npz file.
603
604        .. warning:: Max. iterations must be defined before invoking this method.
605
606        Parameters
607        ----------
608        ind_save : int
609            Iteration step to save (0 = prior)
610        """
611        # Initial save
612        if ind_save == 0:
613            self.temp_state = [None] * (int(max_iter) + 1)  # +1 due to init. ensemble
614
615        # Save state
616        self.temp_state[ind_save] = deepcopy(self.state)
617        np.savez('temp_state_iter', self.temp_state)
618
619    def save_temp_state_mda(self, ind_save):
620        """
621        Save a snapshot of the state during a MDA loop. The temporary state will be stored as a list with length
622        equal to the tot. no. of assimilations + 1 (init. ensemble saved in 0 entry). The list of temporary states
623        are also stored as a .npz file.
624
625        .. warning:: Tot. no. of assimilations must be defined before invoking this method.
626
627        Parameters
628        ----------
629        ind_save : int
630            Assim. step to save (0 = prior)
631        """
632        # Initial save
633        if ind_save == 0:
634            # +1 due to init. ensemble
635            self.temp_state = [None] * (int(self.tot_assim) + 1)
636
637        # Save state
638        self.temp_state[ind_save] = deepcopy(self.state)
639        np.savez('temp_state_mda', self.temp_state)
640
641    def save_temp_state_ml(self, ind_save):
642        """
643        Save a snapshot of the state during a ML loop. The temporary state will be stored as a list with length
644        equal to the tot. no. of assimilations + 1 (init. ensemble saved in 0 entry). The list of temporary states
645        are also stored as a .npz file.
646
647        .. warning:: Tot. no. of assimilations must be defined before invoking this method.
648
649        Parameters
650        ----------
651        ind_save : int
652            Assim. step to save (0 = prior)
653        """
654        # Initial save
655        if ind_save == 0:
656            # +1 due to init. ensemble
657            self.temp_state = [None] * (int(self.tot_assim) + 1)
658
659        # Save state
660        self.temp_state[ind_save] = deepcopy(self.state)
661        np.savez('temp_state_ml', self.temp_state)
662
663    def compress(self, data=None, vintage=0, aug_coeff=None):
664        """
665        Compress the input data using wavelets.
666
667        Parameters
668        ----------
669        data : ndarray, optional
670            data to be compressed
671            If data is `None`, all data (true and simulated) is re-compressed (used if leading indices are updated)
672        vintage : int
673            the time index for the data
674        aug_coeff : bool, optional
675            - False: in this case the leading indices for wavelet coefficients are computed
676            - True: in this case the leading indices are augmented using information from the ensemble
677            - None: in this case simulated data is compressed
678        """
679
680        # If input data is None, we re-compress all data
681        data_array = None
682        if data is None:
683            vintage = 0
684            for i in range(len(self.obs_data)):  # TRUEDATAINDEX
685                for j in self.obs_data[i].keys():  # DATATYPE
686
687                    data_array = self.obs_data[i][j]
688
689                    # Perform compression if required
690                    if data_array is not None and \
691                            vintage < len(self.sparse_info['mask']) and \
692                            len(data_array) == int(np.sum(self.sparse_info['mask'][vintage])):
693                        data_array, wdec_rec = self.sparse_data[vintage].compress(
694                            data_array)  # compress
695                        self.obs_data[i][j] = data_array  # save array in obs_data
696                        rec = self.sparse_data[vintage].reconstruct(
697                            wdec_rec)  # reconstruct the data
698                        s = 'truedata_rec_' + str(vintage) + '.npz'
699                        np.savez(s, rec)  # save reconstructed data
700                        est_noise = np.power(self.sparse_data[vintage].est_noise, 2)
701                        self.datavar[i][j] = est_noise
702
703                        # Update the ensemble
704                        data_sim = self.pred_data[i][j]
705                        self.pred_data[i][j] = np.zeros((len(data_array), self.ne))
706                        self.data_rec.append([])
707                        for m in range(self.pred_data[i][j].shape[1]):
708                            data_array = data_sim[:, m]
709                            data_array, wdec_rec = self.sparse_data[vintage].compress(
710                                data_array)  # compress
711                            self.pred_data[i][j][:, m] = data_array
712                            rec = self.sparse_data[vintage].reconstruct(
713                                wdec_rec)  # reconstruct the data
714                            self.data_rec[vintage].append(rec)
715
716                        # Go to next vintage
717                        vintage = vintage + 1
718
719            # Option to store the dictionaries containing observed data and data variance
720            if 'obsvarsave' in self.keys_da and self.keys_da['obsvarsave'] == 'yes':
721                np.savez('obs_var', obs=self.obs_data, var=self.datavar)
722
723            if 'saveforecast' in self.keys_en:
724                s = 'prior_forecast_rec.npz'
725                np.savez(s, self.data_rec)
726
727            data_array = None
728
729        elif aug_coeff is None:
730
731            data_array, wdec_rec = self.sparse_data[vintage].compress(data)
732            rec = self.sparse_data[vintage].reconstruct(
733                wdec_rec)  # reconstruct the simulated data
734            if len(self.data_rec) == vintage:
735                self.data_rec.append([])
736            self.data_rec[vintage].append(rec)
737
738        elif not aug_coeff:
739
740            options = copy(self.sparse_info)
741            # find the correct mask for the vintage
742            options['mask'] = options['mask'][vintage]
743            if isinstance(options['min_noise'], list):
744                if 0 <= vintage < len(options['min_noise']):
745                    options['min_noise'] = options['min_noise'][vintage]
746                else:
747                    print(
748                        'Error: min_noise must either be scalar or list with one number for each vintage')
749                    sys.exit(1)
750            x = wt.SparseRepresentation(options)
751            data_array, wdec_rec = x.compress(data, self.sparse_info['th_mult'])
752            self.sparse_data.append(x)  # store the information
753            data_rec = x.reconstruct(wdec_rec)  # reconstruct the data
754            s = 'truedata_rec_' + str(vintage) + '.npz'
755            np.savez(s, data_rec)  # save reconstructed data
756            if self.sparse_info['use_ensemble']:
757                data_array = data  # just return the same as input
758
759        elif aug_coeff:
760
761            _, _ = self.sparse_data[vintage].compress(data, self.sparse_info['th_mult'])
762            data_array = data  # just return the same as input
763
764        return data_array
765
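    # Illustrative call pattern for compress() over one vintage (the arrays `d`
    # and `d_sim` are hypothetical):
    #   self.compress(d, vintage=0, aug_coeff=False)       # observed data: build sparse basis, estimate noise
    #   self.compress(d_sim[:, m], vintage=0, aug_coeff=None)  # simulated data: reuse the basis
    #   self.compress(d, vintage=0, aug_coeff=True)        # augment leading indices from the ensemble
    #   self.compress()                                    # re-compress all stored data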
766    def local_analysis_update(self):
767        """
768        Update routine shared by all algorithms. Implemented once here to avoid duplicating the local-analysis
769        code.
770        """
771        orig_list_data = deepcopy(self.list_datatypes)
772        orig_list_state = deepcopy(self.list_states)
773        orig_cd = deepcopy(self.cov_data)
774        orig_real_obs_data = deepcopy(self.real_obs_data)
775        orig_data_vector = deepcopy(self.obs_data_vector)
776        # loop over the states that we want to update. Assume that the state and data combinations have been
777        # determined by the initialization.
778        # TODO: augment parameters with identical mask.
779        for state in self.local_analysis['region_parameter']:
780            self.list_datatypes = [elem for elem in self.list_datatypes if
781                                   elem in self.local_analysis['update_mask'][state]]
782            self.list_states = [deepcopy(state)]
783            self._ext_state()  # scaling for this state
784            if 'localization' in self.keys_da:
785                self.localization.loc_info['field'] = self.state_scaling.shape
786            del self.cov_data
787            # reset the random state for consistency
788            np.random.set_state(self.data_random_state)
789            self._ext_obs()  # get the data that's in the list of data.
790            _, self.aug_pred_data = at.aug_obs_pred_data(self.obs_data, self.pred_data, self.assim_index,
791                                                         self.list_datatypes)
792            # Mean pred_data and perturbation matrix with scaling
793            if len(self.scale_data.shape) == 1:
794                self.pert_preddata = np.dot(np.expand_dims(self.scale_data ** (-1), axis=1),
795                                            np.ones((1, self.ne))) * np.dot(self.aug_pred_data, self.proj)
796            else:
797                self.pert_preddata = solve(
798                    self.scale_data, np.dot(self.aug_pred_data, self.proj))
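            # Both branches apply scale_data**(-1) to the centred predicted-data
            # anomalies: element-wise division when scale_data is a 1-D vector of
            # standard deviations, a linear solve when it is a full matrix
            # (e.g., a Cholesky factor of the data covariance).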
799
800            aug_state = at.aug_state(self.current_state, self.list_states)
801            self.update()
802            if hasattr(self, 'step'):
803                aug_state_upd = aug_state + self.step
804            self.state = at.update_state(aug_state_upd, self.state, self.list_states)
805
806        for state in self.local_analysis['vector_region_parameter']:
807            current_list_datatypes = deepcopy(self.list_datatypes)
808            for state_indx in range(self.state[state].shape[0]): # loop over the elements in the region
809                self.list_datatypes = [elem for elem in self.list_datatypes if
810                                       elem in self.local_analysis['update_mask'][state][state_indx]]
811                if len(self.list_datatypes):
812                    self.list_states = [deepcopy(state)]
813                    self._ext_state()  # scaling for this state
814                    if 'localization' in self.keys_da:
815                        self.localization.loc_info['field'] = self.state_scaling.shape
816                    del self.cov_data
817                    # reset the random state for consistency
818                    np.random.set_state(self.data_random_state)
819                    self._ext_obs()  # get the data that's in the list of data.
820                    _, self.aug_pred_data = at.aug_obs_pred_data(self.obs_data, self.pred_data, self.assim_index,
821                                                                 self.list_datatypes)
822                    # Mean pred_data and perturbation matrix with scaling
823                    if len(self.scale_data.shape) == 1:
824                        self.pert_preddata = np.dot(np.expand_dims(self.scale_data ** (-1), axis=1),
825                                                    np.ones((1, self.ne))) * np.dot(self.aug_pred_data, self.proj)
826                    else:
827                        self.pert_preddata = solve(
828                            self.scale_data, np.dot(self.aug_pred_data, self.proj))
829
830                    aug_state = at.aug_state(self.current_state, self.list_states)[state_indx,:]
831                    self.update()
832                    if hasattr(self, 'step'):
833                        aug_state_upd = aug_state + self.step[state_indx,:]
834                    self.state[state][state_indx,:] = aug_state_upd
835
836                self.list_datatypes = deepcopy(current_list_datatypes)
837
838        for state in self.local_analysis['cell_parameter']:
839            self.list_states = [deepcopy(state)]
840            self._ext_state()  # scaling for this state
841            orig_state_scaling = deepcopy(self.state_scaling)
842            param_position = self.local_analysis['parameter_position'][state]
843            field_size = param_position.shape
844            for k in range(field_size[0]):
845                for j in range(field_size[1]):
846                    for i in range(field_size[2]):
847                        current_data_list = list(
848                            self.local_analysis['update_mask'][state][k][j][i])
849                        current_data_list.sort()  # ensure consistent ordering of data
850                        if len(current_data_list):
851                            # if non-unique data for assimilation index, get the relevant data.
852                            if not self.local_analysis['unique']:
853                                orig_assim_index = deepcopy(self.assim_index)
854                                assim_index_data_list = set(
855                                    [el.split('_')[0] for el in current_data_list])
856                                current_assim_index = [
857                                    int(el.split('_')[1]) for el in current_data_list]
858                                current_data_list = list(assim_index_data_list)
859                                self.assim_index[1] = current_assim_index
860                            self.list_datatypes = deepcopy(current_data_list)
861                            del self.cov_data
862                            # reset the random state for consistency
863                            np.random.set_state(self.data_random_state)
864                            self._ext_obs()
865                            _, self.aug_pred_data = at.aug_obs_pred_data(self.obs_data, self.pred_data,
866                                                                         self.assim_index,
867                                                                         self.list_datatypes)
868                            # get parameter indexes
869                            full_cell_index = np.ravel_multi_index(
870                                np.array([[k], [j], [i]]), tuple(field_size))
871                            # count active values
872                            self.cell_index = [sum(param_position.flatten()[:el])
873                                               for el in full_cell_index]
874                            if 'localization' in self.keys_da:
875                                self.localization.loc_info['field'] = (
876                                    len(self.cell_index),)
877                                self.localization.loc_info['distance'] = cov_regularization._calc_distance(
878                                    self.local_analysis['data_position'],
879                                    self.local_analysis['unique'],
880                                    current_data_list, self.assim_index,
881                                    self.obs_data, self.pred_data, [(k, j, i)])
882                            # Set relevant state scaling
883                            self.state_scaling = orig_state_scaling[self.cell_index]
884
885                            # Mean pred_data and perturbation matrix with scaling
886                            if len(self.scale_data.shape) == 1:
887                                self.pert_preddata = np.dot(np.expand_dims(self.scale_data ** (-1), axis=1),
888                                                            np.ones((1, self.ne))) * np.dot(self.aug_pred_data,
889                                                                                            self.proj)
890                            else:
891                                self.pert_preddata = solve(
892                                    self.scale_data, np.dot(self.aug_pred_data, self.proj))
893
894                            aug_state = at.aug_state(
895                                self.current_state, self.list_states, self.cell_index)
896                            self.update()
897                            if hasattr(self, 'step'):
898                                aug_state_upd = aug_state + self.step
899                            self.state = at.update_state(
900                                aug_state_upd, self.state, self.list_states, self.cell_index)
901
902                            if not self.local_analysis['unique']:
903                                # reset assim index
904                                self.assim_index = deepcopy(orig_assim_index)
905                            if hasattr(self, 'localization') and 'distance' in self.localization.loc_info:  # reset
906                                del self.localization.loc_info['distance']
907
908        self.list_datatypes = deepcopy(orig_list_data)  # reset to original list
909        self.list_states = deepcopy(orig_list_state)
910        self.cov_data = deepcopy(orig_cd)
911        self.real_obs_data = deepcopy(orig_real_obs_data)
912        self.obs_data_vector = deepcopy(orig_data_vector)
913        self.cell_index = None

Class for organizing/initializing misc. variables and simulator for an ensemble-based inversion run. Inherits the PET ensemble structure

Ensemble(keys_da, keys_en, sim)
 31    def __init__(self, keys_da, keys_en, sim):
 32        """
 33        Parameters
 34        ----------
 35        keys_da : dict
 36            Options for the data assimilation class
 37
 38            - daalg: spesification of the method, first the main type (e.g., "enrml"), then the solver (e.g., "gnenrml")
 39            - analysis: update flavour ("approx", "full" or "subspace")
 40            - energy: percent of singular values kept after SVD
 41            - obsvarsave: save the observations as a file (default false)
 42            - restart: restart optimization from a restart file (default false)
 43            - restartsave: save a restart file after each successful iteration (defalut false)
 44            - analysisdebug: specify which class variables to save to the result files
 45            - truedataindex: order of the simulated data (for timeseries this is points in time)
 46            - obsname: unit for truedataindex (for timeseries this is days or hours or seconds, etc.)
 47            - truedata: the data, e.g., provided as a .csv file
 48            - assimindex: index for the data that will be used for assimilation
 49            - datatype: list with the name of the datatypes
 50            - staticvar: name of the static variables
 51            - datavar: data variance, e.g., provided as a .csv file
 52
 53        keys_en : dict
 54            Options for the ensemble class
 55
 56            - ne: number of perturbations used to compute the gradient
 57            - state: name of state variables passed to the .mako file
 58            - prior_<name>: the prior information the state variables, including mean, variance and variable limits
 59
 60        sim : callable
 61            The forward simulator (e.g. flow)
 62        """
 63
 64
 65        # do the initiallization of the PETensemble
 66        super(Ensemble, self).__init__(keys_en, sim)
 67
 68        # set logger
 69        self.logger = logging.getLogger('PET.PIPT')
 70
 71        # write initial information
 72        self.logger.info(f'Starting a {keys_da["daalg"][0]} run with the {keys_da["daalg"][1]} algorithm applying the '
 73                         f'{keys_da["analysis"]} update scheme with {keys_da["energy"]} Energy.')
 74
 75        # Internalize PIPT dictionary
 76        if not hasattr(self, 'keys_da'):
 77            self.keys_da = keys_da
 78        if not hasattr(self, 'keys_en'):
 79            self.keys_en = keys_en
 80
 81        if self.restart is False:
 82            # Init in _init_prediction_output (used in run_prediction)
 83            self.prediction = None
 84            self.temp_state = None  # temporary state saving
 85            self.cov_prior = None  # Prior cov. matrix
 86            self.sparse_info = None  # Init in _org_sparse_representation
 87            self.sparse_data = []  # List of the compression info
 88            self.data_rec = []  # List of reconstructed data
 89            self.scale_val = None  # Use to scale data
 90
 91            # Prepare sparse representation
 92            if 'compress' in self.keys_da:
 93                self._org_sparse_representation()
 94
 95            self._org_obs_data()
 96            self._org_data_var()
 97
 98            # define projection for centring and scaling
 99            self.proj = (np.eye(self.ne) - (1 / self.ne) *
100                         np.ones((self.ne, self.ne))) / np.sqrt(self.ne - 1)
101
102            # If we have dynamic state variables, we allocate keys for them in 'state'. Since we do not know the size
103            #  of the arrays of the dynamic variables, we only allocate an NE list to be filled in later (in
104            # calc_forecast)
105            if 'dynamicvar' in self.keys_da:
106                dyn_var = self.keys_da['dynamicvar'] if isinstance(self.keys_da['dynamicvar'], list) else \
107                    [self.keys_da['dynamicvar']]
108                for name in dyn_var:
109                    self.state[name] = [None] * self.ne
110
111            # Option to store the dictionaries containing observed data and data variance
112            if 'obsvarsave' in self.keys_da and self.keys_da['obsvarsave'] == 'yes':
113                np.savez('obs_var', obs=self.obs_data, var=self.datavar)
114
115            # Initialize localization
116            if 'localization' in self.keys_da:
117                self.localization = cov_regularization.localization(self.keys_da['localization'],
118                                                                    self.keys_da['truedataindex'],
119                                                                    self.keys_da['datatype'],
120                                                                    self.keys_da['staticvar'],
121                                                                    self.ne)
122            # Initialize local analysis
123            if 'localanalysis' in self.keys_da:
124                self.local_analysis = at.init_local_analysis(
125                    init=self.keys_da['localanalysis'], state=self.state.keys())
126
127            self.pred_data = [{k: np.zeros((1, self.ne), dtype='float32') for k in self.keys_da['datatype']}
128                              for _ in self.obs_data]
129
130            self.cell_index = None  # default value for extracting states
Parameters
  • keys_da (dict): Options for the data assimilation class

    • daalg: spesification of the method, first the main type (e.g., "enrml"), then the solver (e.g., "gnenrml")
    • analysis: update flavour ("approx", "full" or "subspace")
    • energy: percent of singular values kept after SVD
    • obsvarsave: save the observations as a file (default false)
    • restart: restart optimization from a restart file (default false)
    • restartsave: save a restart file after each successful iteration (defalut false)
    • analysisdebug: specify which class variables to save to the result files
    • truedataindex: order of the simulated data (for timeseries this is points in time)
    • obsname: unit for truedataindex (for timeseries this is days or hours or seconds, etc.)
    • truedata: the data, e.g., provided as a .csv file
    • assimindex: index for the data that will be used for assimilation
    • datatype: list with the name of the datatypes
    • staticvar: name of the static variables
    • datavar: data variance, e.g., provided as a .csv file
  • keys_en (dict): Options for the ensemble class

    • ne: number of perturbations used to compute the gradient
    • state: name of state variables passed to the .mako file
    • prior_: the prior information the state variables, including mean, variance and variable limits
  • sim (callable): The forward simulator (e.g. flow)
logger
def check_assimindex_sequential(self):
132    def check_assimindex_sequential(self):
133        """
134        Check if assim. indices is given as a 2D list as is needed in sequential updating. If not, make it a 2D list
135        """
136        # Check if ASSIMINDEX is a list. If not, make it a 2D list
137        if not isinstance(self.keys_da['assimindex'], list):
138            self.keys_da['assimindex'] = [[self.keys_da['assimindex']]]
139
140        # If ASSIMINDEX is a 1D list (either given in as a single row or single column), we reshape to a 2D list
141        elif not isinstance(self.keys_da['assimindex'][0], list):
142            assimindex_temp = [None] * len(self.keys_da['assimindex'])
143
144            for i in range(len(self.keys_da['assimindex'])):
145                assimindex_temp[i] = [self.keys_da['assimindex'][i]]
146
147            self.keys_da['assimindex'] = assimindex_temp

Check if assim. indices is given as a 2D list as is needed in sequential updating. If not, make it a 2D list

def check_assimindex_simultaneous(self):
149    def check_assimindex_simultaneous(self):
150        """
151        Check if assim. indices is given as a 1D list as is needed in simultaneous updating. If not, make it a 2D list
152        with one row.
153        """
154        # Check if ASSIMINDEX is a list. If not, make it a 2D list with one row
155        if not isinstance(self.keys_da['assimindex'], list):
156            self.keys_da['assimindex'] = [[self.keys_da['assimindex']]]
157
158        # Check if ASSIMINDEX is a 1D list. If true, make it a 2D list with one row
159        elif not isinstance(self.keys_da['assimindex'][0], list):
160            self.keys_da['assimindex'] = [self.keys_da['assimindex']]
161
162        # If ASSIMINDEX is a 2D list, we reshape it to a 2D list with one row
163        elif isinstance(self.keys_da['assimindex'][0], list):
164            self.keys_da['assimindex'] = [
165                [item for sublist in self.keys_da['assimindex'] for item in sublist]]

Check if assim. indices is given as a 1D list as is needed in simultaneous updating. If not, make it a 2D list with one row.

def save_temp_state_assim(self, ind_save):
580    def save_temp_state_assim(self, ind_save):
581        """
582        Method to save the state variable during the assimilation. It is stored in a list with length = tot. no.
583        assim. steps + 1 (for the init. ensemble). The list of temporary states are also stored as a .npz file.
584
585        Parameters
586        ----------
587        ind_save : int
588            Assim. step to save (0 = prior)
589        """
590        # Init. temp. save
591        if ind_save == 0:
592            # +1 due to init. ensemble
593            self.temp_state = [None]*(len(self.get_list_assim_steps()) + 1)
594
595        # Save the state
596        self.temp_state[ind_save] = deepcopy(self.state)
597        np.savez('temp_state_assim', self.temp_state)

Method to save the state variable during the assimilation. It is stored in a list with length = tot. no. assim. steps + 1 (for the init. ensemble). The list of temporary states are also stored as a .npz file.

Parameters
  • ind_save (int): Assim. step to save (0 = prior)
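
Since the snapshots are written with np.savez as a single positional array, they can be inspected afterwards with a sketch like the following (object arrays require allow_pickle=True; the file name is the one used by the method above):

    import numpy as np

    # Positional arrays passed to np.savez are stored under 'arr_0', 'arr_1', ...
    with np.load('temp_state_assim.npz', allow_pickle=True) as f:
        temp_state = f['arr_0']
    print(len(temp_state))  # tot. no. of assim. steps + 1 (entry 0 is the prior)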
def save_temp_state_iter(self, ind_save, max_iter):
    def save_temp_state_iter(self, ind_save, max_iter):
        """
        Save a snapshot of the state at the current iteration. It is stored in a list with length equal to the
        max. no. of iterations + 1 (entry 0 is the prior state). The list of temporary states is also stored as
        a .npz file.

        .. warning:: Max. iterations must be defined before invoking this method.

        Parameters
        ----------
        ind_save : int
            Iteration step to save (0 = prior)
        max_iter : int
            Maximum number of iterations
        """
        # Initial save
        if ind_save == 0:
            self.temp_state = [None] * (int(max_iter) + 1)  # +1 due to init. ensemble

        # Save state
        self.temp_state[ind_save] = deepcopy(self.state)
        np.savez('temp_state_iter', self.temp_state)

Save a snapshot of the state at the current iteration. It is stored in a list with length equal to the max. no. of iterations + 1 (entry 0 is the prior state). The list of temporary states is also stored as a .npz file.

Max. iterations must be defined before invoking this method.
Parameters
  • ind_save (int): Iteration step to save (0 = prior)
  • max_iter (int): Maximum number of iterations
def save_temp_state_mda(self, ind_save):
    def save_temp_state_mda(self, ind_save):
        """
        Save a snapshot of the state during an MDA loop. The temporary state will be stored as a list with
        length equal to the tot. no. of assimilations + 1 (init. ensemble saved in entry 0). The list of
        temporary states is also stored as a .npz file.

        .. warning:: Tot. no. of assimilations must be defined before invoking this method.

        Parameters
        ----------
        ind_save : int
            Assim. step to save (0 = prior)
        """
        # Initial save
        if ind_save == 0:
            # +1 due to init. ensemble
            self.temp_state = [None] * (int(self.tot_assim) + 1)

        # Save state
        self.temp_state[ind_save] = deepcopy(self.state)
        np.savez('temp_state_mda', self.temp_state)

Save a snapshot of the state during an MDA loop. The temporary state will be stored as a list with length equal to the tot. no. of assimilations + 1 (init. ensemble saved in entry 0). The list of temporary states is also stored as a .npz file.

Tot. no. of assimilations must be defined before invoking this method.
Parameters
  • ind_save (int): Assim. step to save (0 = prior)

def save_temp_state_ml(self, ind_save):
    def save_temp_state_ml(self, ind_save):
        """
        Save a snapshot of the state during an ML loop. The temporary state will be stored as a list with
        length equal to the tot. no. of assimilations + 1 (init. ensemble saved in entry 0). The list of
        temporary states is also stored as a .npz file.

        .. warning:: Tot. no. of assimilations must be defined before invoking this method.

        Parameters
        ----------
        ind_save : int
            Assim. step to save (0 = prior)
        """
        # Initial save
        if ind_save == 0:
            # +1 due to init. ensemble
            self.temp_state = [None] * (int(self.tot_assim) + 1)

        # Save state
        self.temp_state[ind_save] = deepcopy(self.state)
        np.savez('temp_state_ml', self.temp_state)

Save a snapshot of the state during an ML loop. The temporary state will be stored as a list with length equal to the tot. no. of assimilations + 1 (init. ensemble saved in entry 0). The list of temporary states is also stored as a .npz file.

Tot. no. of assimilations must be defined before invoking this method. A consolidated sketch of the four save_temp_state_* methods follows the parameter list below.
Parameters
  • ind_save (int): Assim. step to save (0 = prior)
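
The four save_temp_state_* methods differ only in the number of snapshots and the output file name, so a hypothetical consolidation (a sketch only, not part of the class) could look like:

    import numpy as np
    from copy import deepcopy

    def save_temp_state(self, ind_save, n_steps, fname):
        """Hypothetical shared helper: n_steps snapshots plus the prior, saved to fname.npz."""
        if ind_save == 0:
            self.temp_state = [None] * (n_steps + 1)  # +1 for the prior ensemble in entry 0
        self.temp_state[ind_save] = deepcopy(self.state)
        np.savez(fname, self.temp_state)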
def compress(self, data=None, vintage=0, aug_coeff=None):
    def compress(self, data=None, vintage=0, aug_coeff=None):
        """
        Compress the input data using wavelets.

        Parameters
        ----------
        data :
            Data to be compressed.
            If data is `None`, all data (true and simulated) is re-compressed (used if leading indices are updated)
        vintage : int
            The time index for the data
        aug_coeff : bool
            - False: in this case the leading indices for wavelet coefficients are computed
            - True: in this case the leading indices are augmented using information from the ensemble
            - None: in this case simulated data is compressed
        """

        # If input data is None, we re-compress all data
        data_array = None
        if data is None:
            vintage = 0
            for i in range(len(self.obs_data)):  # TRUEDATAINDEX
                for j in self.obs_data[i].keys():  # DATATYPE

                    data_array = self.obs_data[i][j]

                    # Perform compression if required
                    if data_array is not None and \
                            vintage < len(self.sparse_info['mask']) and \
                            len(data_array) == int(np.sum(self.sparse_info['mask'][vintage])):
                        data_array, wdec_rec = self.sparse_data[vintage].compress(
                            data_array)  # compress
                        self.obs_data[i][j] = data_array  # save array in obs_data
                        rec = self.sparse_data[vintage].reconstruct(
                            wdec_rec)  # reconstruct the data
                        s = 'truedata_rec_' + str(vintage) + '.npz'
                        np.savez(s, rec)  # save reconstructed data
                        est_noise = np.power(self.sparse_data[vintage].est_noise, 2)
                        self.datavar[i][j] = est_noise

                        # Update the ensemble
                        data_sim = self.pred_data[i][j]
                        self.pred_data[i][j] = np.zeros((len(data_array), self.ne))
                        self.data_rec.append([])
                        for m in range(self.pred_data[i][j].shape[1]):
                            data_array = data_sim[:, m]
                            data_array, wdec_rec = self.sparse_data[vintage].compress(
                                data_array)  # compress
                            self.pred_data[i][j][:, m] = data_array
                            rec = self.sparse_data[vintage].reconstruct(
                                wdec_rec)  # reconstruct the data
                            self.data_rec[vintage].append(rec)

                        # Go to next vintage
                        vintage += 1

            # Option to store the dictionaries containing observed data and data variance
            if 'obsvarsave' in self.keys_da and self.keys_da['obsvarsave'] == 'yes':
                np.savez('obs_var', obs=self.obs_data, var=self.datavar)

            if 'saveforecast' in self.keys_en:
                np.savez('prior_forecast_rec.npz', self.data_rec)

            data_array = None

        elif aug_coeff is None:

            data_array, wdec_rec = self.sparse_data[vintage].compress(data)
            rec = self.sparse_data[vintage].reconstruct(
                wdec_rec)  # reconstruct the simulated data
            if len(self.data_rec) == vintage:
                self.data_rec.append([])
            self.data_rec[vintage].append(rec)

        elif not aug_coeff:

            options = copy(self.sparse_info)
            # find the correct mask for the vintage
            options['mask'] = options['mask'][vintage]
            if isinstance(options['min_noise'], list):
                if 0 <= vintage < len(options['min_noise']):
                    options['min_noise'] = options['min_noise'][vintage]
                else:
                    print('Error: min_noise must either be a scalar or a list with one number for each vintage')
                    sys.exit(1)
            x = wt.SparseRepresentation(options)
            data_array, wdec_rec = x.compress(data, self.sparse_info['th_mult'])
            self.sparse_data.append(x)  # store the information
            data_rec = x.reconstruct(wdec_rec)  # reconstruct the data
            s = 'truedata_rec_' + str(vintage) + '.npz'
            np.savez(s, data_rec)  # save reconstructed data
            if self.sparse_info['use_ensemble']:
                data_array = data  # just return the same as input

        elif aug_coeff:

            _, _ = self.sparse_data[vintage].compress(data, self.sparse_info['th_mult'])
            data_array = data  # just return the same as input

        return data_array

Compress the input data using wavelets. An illustrative thresholding sketch follows the parameter list below.

Parameters
  • data: data to be compressed. If data is None, all data (true and simulated) is re-compressed (used if leading indices are updated)
  • vintage (int): the time index for the data
  • aug_coeff (bool):
    • False: in this case the leading indices for wavelet coefficients are computed
    • True: in this case the leading indices are augmented using information from the ensemble
    • None: in this case simulated data is compressed
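
The thresholding idea behind this compression can be illustrated with a standalone sketch that uses PyWavelets directly (an assumption made for illustration only; the actual implementation is pipt.misc_tools.wavelet_tools.SparseRepresentation):

    import numpy as np
    import pywt

    # Illustration only: compress a signal by zeroing small wavelet coefficients
    rng = np.random.default_rng(0)
    signal = np.sin(np.linspace(0, 8 * np.pi, 256)) + 0.1 * rng.standard_normal(256)
    coeffs = pywt.wavedec(signal, 'db4', level=4)  # multilevel wavelet decomposition
    arr, slices = pywt.coeffs_to_array(coeffs)     # flatten coefficients for thresholding
    th = 3.0 * np.median(np.abs(arr))              # simple threshold (plays the role of th_mult)
    arr[np.abs(arr) < th] = 0.0                    # keep only the leading coefficients
    rec = pywt.waverec(pywt.array_to_coeffs(arr, slices, output_format='wavedec'), 'db4')
    print(np.count_nonzero(arr), 'of', arr.size, 'coefficients kept')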
def local_analysis_update(self):
    def local_analysis_update(self):
        """
        Update routine shared by all algorithms, implemented once here to avoid duplicating the local-analysis
        code.
        """
        orig_list_data = deepcopy(self.list_datatypes)
        orig_list_state = deepcopy(self.list_states)
        orig_cd = deepcopy(self.cov_data)
        orig_real_obs_data = deepcopy(self.real_obs_data)
        orig_data_vector = deepcopy(self.obs_data_vector)
        # Loop over the states that we want to update. Assume that the state and data combinations have been
        # determined by the initialization.
        # TODO: augment parameters with identical mask.
        for state in self.local_analysis['region_parameter']:
            self.list_datatypes = [elem for elem in self.list_datatypes if
                                   elem in self.local_analysis['update_mask'][state]]
            self.list_states = [deepcopy(state)]
            self._ext_state()  # scaling for this state
            if 'localization' in self.keys_da:
                self.localization.loc_info['field'] = self.state_scaling.shape
            del self.cov_data
            # reset the random state for consistency
            np.random.set_state(self.data_random_state)
            self._ext_obs()  # get the data that's in the list of data
            _, self.aug_pred_data = at.aug_obs_pred_data(self.obs_data, self.pred_data, self.assim_index,
                                                         self.list_datatypes)
            # Mean pred_data and perturbation matrix with scaling
            if len(self.scale_data.shape) == 1:
                self.pert_preddata = np.dot(np.expand_dims(self.scale_data ** (-1), axis=1),
                                            np.ones((1, self.ne))) * np.dot(self.aug_pred_data, self.proj)
            else:
                self.pert_preddata = solve(
                    self.scale_data, np.dot(self.aug_pred_data, self.proj))

            aug_state = at.aug_state(self.current_state, self.list_states)
            self.update()
            if hasattr(self, 'step'):
                aug_state_upd = aug_state + self.step
            else:  # no step produced by the solver; keep the state unchanged
                aug_state_upd = aug_state
            self.state = at.update_state(aug_state_upd, self.state, self.list_states)

        for state in self.local_analysis['vector_region_parameter']:
            current_list_datatypes = deepcopy(self.list_datatypes)
            for state_indx in range(self.state[state].shape[0]):  # loop over the elements in the region
                self.list_datatypes = [elem for elem in self.list_datatypes if
                                       elem in self.local_analysis['update_mask'][state][state_indx]]
                if len(self.list_datatypes):
                    self.list_states = [deepcopy(state)]
                    self._ext_state()  # scaling for this state
                    if 'localization' in self.keys_da:
                        self.localization.loc_info['field'] = self.state_scaling.shape
                    del self.cov_data
                    # reset the random state for consistency
                    np.random.set_state(self.data_random_state)
                    self._ext_obs()  # get the data that's in the list of data
                    _, self.aug_pred_data = at.aug_obs_pred_data(self.obs_data, self.pred_data, self.assim_index,
                                                                 self.list_datatypes)
                    # Mean pred_data and perturbation matrix with scaling
                    if len(self.scale_data.shape) == 1:
                        self.pert_preddata = np.dot(np.expand_dims(self.scale_data ** (-1), axis=1),
                                                    np.ones((1, self.ne))) * np.dot(self.aug_pred_data, self.proj)
                    else:
                        self.pert_preddata = solve(
                            self.scale_data, np.dot(self.aug_pred_data, self.proj))

                    aug_state = at.aug_state(self.current_state, self.list_states)[state_indx, :]
                    self.update()
                    if hasattr(self, 'step'):
                        aug_state_upd = aug_state + self.step[state_indx, :]
                    else:  # no step produced by the solver; keep the state unchanged
                        aug_state_upd = aug_state
                    self.state[state][state_indx, :] = aug_state_upd

                self.list_datatypes = deepcopy(current_list_datatypes)

        for state in self.local_analysis['cell_parameter']:
            self.list_states = [deepcopy(state)]
            self._ext_state()  # scaling for this state
            orig_state_scaling = deepcopy(self.state_scaling)
            param_position = self.local_analysis['parameter_position'][state]
            field_size = param_position.shape
            for k in range(field_size[0]):
                for j in range(field_size[1]):
                    for i in range(field_size[2]):
                        current_data_list = list(
                            self.local_analysis['update_mask'][state][k][j][i])
                        current_data_list.sort()  # ensure consistent ordering of data
                        if len(current_data_list):
                            # if non-unique data for assimilation index, get the relevant data
                            if not self.local_analysis['unique']:
                                orig_assim_index = deepcopy(self.assim_index)
                                assim_index_data_list = set(
                                    [el.split('_')[0] for el in current_data_list])
                                current_assim_index = [
                                    int(el.split('_')[1]) for el in current_data_list]
                                current_data_list = list(assim_index_data_list)
                                self.assim_index[1] = current_assim_index
                            self.list_datatypes = deepcopy(current_data_list)
                            del self.cov_data
                            # reset the random state for consistency
                            np.random.set_state(self.data_random_state)
                            self._ext_obs()
                            _, self.aug_pred_data = at.aug_obs_pred_data(self.obs_data, self.pred_data,
                                                                         self.assim_index,
                                                                         self.list_datatypes)
                            # get parameter indexes
                            full_cell_index = np.ravel_multi_index(
                                np.array([[k], [j], [i]]), tuple(field_size))
                            # count active values
                            self.cell_index = [sum(param_position.flatten()[:el])
                                               for el in full_cell_index]
                            if 'localization' in self.keys_da:
                                self.localization.loc_info['field'] = (
                                    len(self.cell_index),)
                                self.localization.loc_info['distance'] = cov_regularization._calc_distance(
                                    self.local_analysis['data_position'],
                                    self.local_analysis['unique'],
                                    current_data_list, self.assim_index,
                                    self.obs_data, self.pred_data, [(k, j, i)])
                            # Set relevant state scaling
                            self.state_scaling = orig_state_scaling[self.cell_index]

                            # Mean pred_data and perturbation matrix with scaling
                            if len(self.scale_data.shape) == 1:
                                self.pert_preddata = np.dot(np.expand_dims(self.scale_data ** (-1), axis=1),
                                                            np.ones((1, self.ne))) * np.dot(self.aug_pred_data,
                                                                                            self.proj)
                            else:
                                self.pert_preddata = solve(
                                    self.scale_data, np.dot(self.aug_pred_data, self.proj))

                            aug_state = at.aug_state(
                                self.current_state, self.list_states, self.cell_index)
                            self.update()
                            if hasattr(self, 'step'):
                                aug_state_upd = aug_state + self.step
                            else:  # no step produced by the solver; keep the state unchanged
                                aug_state_upd = aug_state
                            self.state = at.update_state(
                                aug_state_upd, self.state, self.list_states, self.cell_index)

                            if not self.local_analysis['unique']:
                                # reset assim index
                                self.assim_index = deepcopy(orig_assim_index)
                            if hasattr(self, 'localization') and 'distance' in self.localization.loc_info:  # reset
                                del self.localization.loc_info['distance']

        self.list_datatypes = deepcopy(orig_list_data)  # reset to original list
        self.list_states = deepcopy(orig_list_state)
        self.cov_data = deepcopy(orig_cd)
        self.real_obs_data = deepcopy(orig_real_obs_data)
        self.obs_data_vector = deepcopy(orig_data_vector)
        self.cell_index = None

Update routine shared by all algorithms, implemented once here to avoid duplicating the local-analysis code.
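
The flat-to-active index bookkeeping in the cell-parameter loop above can be illustrated with a small standalone sketch (hypothetical 2x2x2 active-cell mask):

    import numpy as np

    param_position = np.array([[[True, False], [True, True]],
                               [[False, True], [True, False]]])  # active-cell mask
    field_size = param_position.shape
    k, j, i = 1, 0, 1  # grid cell being updated
    full_cell_index = np.ravel_multi_index(np.array([[k], [j], [i]]), field_size)
    # position among the active cells only: count active entries preceding this one
    cell_index = [int(np.sum(param_position.flatten()[:el])) for el in full_cell_index]
    print(cell_index)  # [3] -> this cell is the 4th active parameter in the flattened field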