pipt.misc_tools.analysis_tools

Collection of tools that can be used in update/analysis schemes.

Only put tools here that are so general that they can be used by several update/analysis schemes. If some method is only applicable to the update scheme you are implementing, leave it in that class.
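For illustration, the following is a minimal sketch of how a few of the general tools below can be combined into a plain ensemble Kalman update: augment the state dictionary, form perturbation matrices and sample covariances, compute the Kalman gain, apply the update, and map the result back to the state dictionary. The ensemble size, state names ('permx', 'poro'), and data arrays are made up for the example; only the `analysis_tools` functions are real.

    import numpy as np
    from pipt.misc_tools import analysis_tools as at

    ne = 100                                   # ensemble size (example value)
    state = {'permx': np.random.randn(500, ne),  # hypothetical prior ensembles
             'poro': np.random.randn(500, ne)}
    list_state = list(state.keys())

    pred_data = np.random.randn(20, ne)        # ensemble of predicted data (example)
    var_data = 0.01 * np.ones(20)              # data variance (diagonal)
    # perturbed observations, one column per ensemble member
    obs_data = np.random.randn(20, 1) + np.sqrt(var_data)[:, None] * np.random.randn(20, ne)

    # Augment the state dictionary and form perturbation matrices
    aug = at.aug_state(state, list_state)
    pert_state = aug - aug.mean(axis=1, keepdims=True)
    pert_data = pred_data - pred_data.mean(axis=1, keepdims=True)

    # Sample covariances, Kalman gain, and the ensemble update
    cov_cross = at.calc_crosscov(pert_state, pert_data)
    cov_auto = at.calc_autocov(pert_data)
    kg = at.calc_kalmangain(cov_cross, cov_auto, var_data)
    aug_upd = at.calc_kalman_filter_eq(aug, kg, obs_data, pred_data)
    state = at.update_state(aug_upd, state, list_state)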

   1"""
   2Collection of tools that can be used in update/analysis schemes.
   3
   4Only put tools here that are so general that they can
   5be used by several update/analysis schemes. If some method is only applicable to the update scheme you are
   6implementing, leave it in that class.
   7"""
   8
   9# External imports
  10import numpy as np          # Numerical tools
  11from scipy import linalg    # Linear algebra tools
  12from misc.system_tools.environ_var import OpenBlasSingleThread  # only single thread
  13import multiprocessing as mp  # parallel updates
  14import time
  15import pickle
  16from importlib import import_module  # To import packages
  17
  18from scipy.spatial import cKDTree
  19
  20
  21def parallel_upd(list_state, prior_info, states_dict, X, local_mask_info, obs_data, pred_data, parallel, actnum=None,
  22                 field_dim=None, act_data_list=None, scale_data=None, num_states=1, emp_d_cov=False):
  23    """
  24    Script to initialize and control a parallel update of the ensemble state following [1].
  25
  26    Parameters
  27    ----------
  28    list_state: list
  29        List of state names
  30    prior_info: dict
  31        Nested dictionary containing prior information for the state variables
  32    states_dict: dict
  33        Dict. of state arrays
  34    X: ndarray
  35        Matrix such that the state update is the state perturbations multiplied by X and the (scaled) data mismatch
  36    local_mask_info: dict
  37        Dictionary with localization information (taper function, range, and data positions) for the data/parameter pairs
  38    obs_data: ndarray
  39        Observed data
  40    pred_data: ndarray
  41        Predicted data
  42    parallel: int
  43        Number of parallel runs
  44    actnum: ndarray, optional
  45        Active cells
  46    field_dim: list, optional
  47        Number of grid cells in each direction
  48    act_data_list: list, optional
  49        List of active data names
  50    scale_data: ndarray, optional
  51        Scaling array for data
  52    num_states: int, optional
  53        Number of states
  54    emp_d_cov: bool, optional
  55        If True, do not scale the perturbation matrices by sqrt(Ne - 1) (empirical covariance formulation)
  56
  57    Notes
  58    -----
  59    Since the localization matrix is too large to evaluate in full, we instead calculate it row by row.
  60
  61    References
  62    ----------
  63    [1] Emerick, Alexandre A. 2016. “Analysis of the Performance of Ensemble-Based Assimilation of Production and
  64    Seismic Data.” Journal of Petroleum Science and Engineering 139. Elsevier: 219-39. doi:10.1016/j.petrol.2016.01.029
  65    """
  66    if scale_data is None:
  67        scale_data = np.ones(obs_data.shape[0])
  68
  69    # Generate a list over the grid coordinates
  70    if field_dim is not None:
  71        k_coord, j_coord, i_coord = np.meshgrid(range(field_dim[0]), range(
  72            field_dim[1]), range(field_dim[2]), indexing='ij')
  73        tot_g = np.array([k_coord, j_coord, i_coord])
  74        if actnum is not None:
  75            act_g = tot_g[:, actnum.reshape(field_dim)]
  76        else:
  77            act_g = tot_g[:, np.ones(tuple(field_dim), dtype=bool)]
  78
  79    dat = [el for el in local_mask_info.keys()]
  80    # data coordinates to initialize search
  81    tot_completions = [tuple(el) for dat_mask in dat if type(
  82        dat_mask) == tuple for el in local_mask_info[dat_mask]['position']]
  83    uniq_completions = [el for el in set(tot_completions)]
  84    tot_w_name = [dat_mask for dat_mask in dat if type(
  85        dat_mask) == tuple for _ in local_mask_info[dat_mask]['position']]
  86    uniq_w_name = [tot_w_name[tot_completions.index(el)] for el in uniq_completions]
  87    # todo: limit to active data
  88    coord_search = cKDTree(data=uniq_completions)
  89
  90    try:
  91        act_w_name = [el[0].split()[1] for el in uniq_w_name]
  92
  93        tot_well_dict = {}
  94        for well in set(act_w_name):
  95            tot_well_dict[well] = [el for el in local_mask_info.keys() if type(el) == tuple and
  96                                   el[0].split()[1] == well]
  97    except:
  98        tot_well_dict = local_mask_info
  99
 100    if len(scale_data.shape) == 1:
 101        diff = np.dot(np.expand_dims(scale_data**(-1), axis=1),
 102                      np.ones((1, pred_data.shape[1])))*(obs_data - pred_data)
 103    else:
 104        diff = linalg.solve(scale_data, (obs_data - pred_data))
 105
 106    # initialize the update
 107    upd = {}
 108
 109    # Assume that we have three types of parameters: full 3D fields, layers (2D fields), or scalar values. These are
 110    # handled individually.
 111
 112    field_states = [state for state in list_state if states_dict[state].shape[0]
 113                    == act_g.shape[1]]  # field states
 114    layer_states = [state for state in list_state if 1 <
 115                    states_dict[state].shape[0] < act_g.shape[1]]  # layer states
 116    # scalar states
 117    scalar_states = [state for state in list_state if states_dict[state].shape[0] == 1]
 118
 119    # We handle the field states first. These are the most time-consuming, and require parallelization.
 120
 121    # since X must be passed to all processes we split the state into equal portions, and let the row updates loop over
 122    # the different portions
 123    # coordinates for active parameters
 124    split_coord = np.array_split(act_g, parallel, axis=1)
 125    # Assuming that all parameters are spatial fields
 126    split_state = [{} for _ in range(parallel)]
 127    tmp_loc = {}  # initialize for checking similar localization info
 128    # assume for now that everything is spatial, if not we require an extra loop or (if/else block)
 129    for state in field_states:
 130        # Augment the joint state variables (originally a dictionary) and the prior state variable
 131        aug_state = states_dict[state]
 132        # aug_prior_state = at.aug_state(self.prior_state, self.list_states)
 133
 134        # Mean state and perturbation matrix
 135        mean_state = np.mean(aug_state, 1)
 136        if emp_d_cov:
 137            pert_state = (aug_state - np.dot(np.resize(mean_state, (len(mean_state), 1)),
 138                                             np.ones((1, aug_state.shape[1]))))
 139        else:
 140            pert_state = (aug_state - np.dot(np.resize(mean_state, (len(mean_state), 1)),
 141                                             np.ones((1, aug_state.shape[1])))) / (np.sqrt(aug_state.shape[1] - 1))
 142
 143        tmp_state = np.array_split(pert_state, parallel)
 144        for i, elem in enumerate(tmp_state):
 145            split_state[i][state] = elem
 146        tmp_loc[state] = [el for el in local_mask_info if el[2] == state]
 147    # loc_info = [local_mask_info for _ in range(parallel)]
 148    # tot_X = [X for _ in range(parallel)]
 149    # tot_coord_seach = [coord_search for _ in range(parallel)] # might promt error if coord_search is to large
 150    # tot_uniq_name = [uniq_w_name for _ in range(parallel)]
 151    # tot_data_list = [act_data_list for _ in range(parallel)]
 152    # tot_well_dict_list = [tot_well_dict for _ in range(parallel)]
 153    non_similar = []
 154    for state in field_states[1:]:  # check localization
 155        non_shared = {k: ' ' for i, k in enumerate(
 156            tmp_loc[field_states[0]]) if local_mask_info[k] != local_mask_info[tmp_loc[state][i]]}
 157        non_similar.append(len(non_shared))
 158
 159    if sum(non_similar) == 0:
 160        identical_loc = True
 161    else:
 162        identical_loc = False
 163    # Due to memory issues a pickle file is written containing all "meta" data required for the update
 164    with open('meta_analysis.p', 'wb') as file:
 165        pickle.dump({'local_mask_info': local_mask_info, 'diff': diff, 'X': X, 'coord_search': coord_search,
 166                     'unique_w_name': uniq_w_name, 'act_data_list': act_data_list, 'tot_well_dict': tot_well_dict,
 167                     'actnum': actnum, 'unique_completions': uniq_completions, 'identical_loc': identical_loc}, file)
 168    tot_file_name = ['meta_analysis.p' for _ in range(parallel)]
 169    # to_workers = zip(split_state, loc_info, diff, tot_X, split_coord, tot_coord_seach,tot_uniq_name, tot_data_list,
 170    #                  tot_well_dict_list)
 171    to_workers = zip(split_state, split_coord, tot_file_name)
 172
 173    parallel = 1  # NOTE: forces serial execution; remove this override to allow parallel runs
 174    #
 175    with OpenBlasSingleThread():
 176        if parallel > 1:
 177            with mp.get_context('spawn').Pool(parallel) as pool:
 178                s = pool.map(_calc_row_upd, to_workers)
 179        else:
 180            tmp_s = map(_calc_row_upd, to_workers)
 181            s = [el for el in tmp_s]
 182
 183    for tmp_key in field_states:
 184        upd[tmp_key] = np.concatenate([el[tmp_key] for el in s], axis=0)
 185
 186    ####################################################################################################################
 187    # Now handle the layer states
 188
 189    for state in layer_states:
 190        # could add parallelization later
 191        aug_state = states_dict[state]
 192        mean_state = np.mean(aug_state, 1)
 193        if emp_d_cov:
 194            pert_state = {state: (aug_state - np.dot(np.resize(mean_state, (len(mean_state), 1)),
 195                                                     np.ones((1, aug_state.shape[1]))))}
 196        else:
 197            pert_state = {state: (aug_state - np.dot(np.resize(mean_state, (len(mean_state), 1)),
 198                                                     np.ones((1, aug_state.shape[1])))) / (np.sqrt(aug_state.shape[1] - 1))}
 199        # Layer
 200        # Convention: the parameter name must end with "_<layer number>", e.g. "multz_5"
 201        layer = int(state.split('_')[-1])
 202        l_act = np.full(field_dim, False)
 203        l_act[layer, :, :] = actnum.reshape(field_dim)[layer, :, :]
 204        act_g = tot_g[:, l_act]
 205
 206        to_workers = zip([pert_state], [act_g], ['meta_analysis.p'])
 207
 208        # with OpenBlasSingleThread():
 209        s = map(_calc_row_upd, to_workers)
 210        upd[state] = np.concatenate([el[state] for el in s], axis=0)
 211
 212    ####################################################################################################################
 213    # Finally the scalar states
 214    for state in scalar_states:
 215        # could add parallelization later
 216        aug_state = states_dict[state]
 217        mean_state = np.mean(aug_state, 1)
 218        if emp_d_cov:
 219            pert_state = {state: (aug_state - np.dot(np.resize(mean_state, (len(mean_state), 1)),
 220                                                     np.ones((1, aug_state.shape[1]))))}
 221        else:
 222            pert_state = {state: (aug_state - np.dot(np.resize(mean_state, (len(mean_state), 1)),
 223                                                     np.ones((1, aug_state.shape[1])))) / (np.sqrt(aug_state.shape[1] - 1))}
 224
 225        to_workers = zip([pert_state], [tot_g], ['meta_analysis.p'])
 226
 227        # with OpenBlasSingleThread():
 228        s = map(_calc_row_upd, to_workers)
 229
 230        upd[state] = np.concatenate([el[state] for el in s], axis=0)
 231
 232    return upd
 233
 234
 235def _calc_row_upd(inp):
 236    """
 237    Calculate the updates.
 238
 239    Parameters
 240    ----------
 241    inp: list    
 242        List of [state, param_coordinates, metadata file name]
 243    """
 244
 245    with open(inp[2], 'rb') as file:
 246        meta_data = pickle.load(file)
 247    states = [el for el in inp[0].keys()]
 248    Ne = inp[0][states[0]].shape[1]
 249    upd = {}
 250    for el in states:
 251        upd[el] = [np.zeros((1, Ne))]*(inp[0][el].shape[0])
 252
 253    # Check and define regions for wells
 254    regions = _calc_region(meta_data['local_mask_info'], states,
 255                           meta_data['local_mask_info']['field'], meta_data['actnum'])
 256    max_r = {}
 257    for state in states:
 258        tmp_r = [meta_data['local_mask_info'][el]['range'][0] for el in meta_data['local_mask_info'].keys() if
 259                 type(el) == tuple and state in el and type(meta_data['local_mask_info'][el]['range'][0]) == int]
 260        if len(tmp_r):
 261            max_r[state] = max(tmp_r)
 262        else:
 263            max_r[state] = 0
 264    for i in range(inp[0][states[0]].shape[0]):
 265        for el in states:
 266            uniq_well = []
 267            if len(regions[el]):
 268                for reg in regions[el]:
 269                    if max_r[el] == 0:  # only use wells in the region, no taper
 270                        tmp_unique = []
 271                        for ind, w in enumerate(meta_data['unique_w_name']):
 272                            for comp in reg.T:
 273                                if meta_data['unique_completions'][ind][2] == comp[0] and \
 274                                        meta_data['unique_completions'][ind][1] == comp[1] and \
 275                                        meta_data['unique_completions'][ind][0] == comp[2]:
 276                                    tmp_unique.append(w)
 277                                    break
 278                        uniq_well.extend(tmp_unique)
 279                    else:  # only wells in the region, with taper
 280                        uniq_well.extend([w for w in set([meta_data['unique_w_name'][el] for el in
 281                                                          meta_data['coord_search'].query_ball_point(x=(inp[1][2, i], inp[1][1, i], inp[1][0, i]), r=max_r[el])])])
 282            else:
 283                uniq_well.extend([w for w in set([meta_data['unique_w_name'][el] for el in meta_data['coord_search'].query_ball_point(
 284                    x=(inp[1][2, i], inp[1][1, i], inp[1][0, i]), r=max_r[el])])])
 285
 286            uniq_well = [(w[0], w[1], el) for w in set(uniq_well)]
 287            row_loc = np.zeros(meta_data['diff'].shape[0])
 288            for well in uniq_well:
 289                try:
 290                    tot_act_well = [elem for elem in meta_data['tot_well_dict']
 291                                    [well[0].split()[1]] if elem[2] == el]
 292                except:
 293                    tot_act_well = [elem for elem in meta_data['tot_well_dict'][well]]
 294                # curr_completions = frozenset((inp[1][tot_act_well[0]]['position']))
 295                tot_act_data_types = set([el[0].split()[0] for el in tot_act_well])
 296                for data_typ in tot_act_data_types:
 297                    for el_well in tot_act_well:
 298                        if el_well[0].split()[0] == data_typ:
 299                            tmp_loc_info = el_well
 300                            break
 301                    curr_rho = _calc_loc(grid_pos=(inp[1][2, i], inp[1][1, i], inp[1][0, i]),
 302                                         loc_info=meta_data['local_mask_info'][tmp_loc_info], ne=Ne)
 303                    index = meta_data['act_data_list'][tmp_loc_info[0]]
 304                    row_loc[index] = curr_rho
 305                # for act_well in tot_act_well:
 306                #     # if len(curr_completions.difference(inp[1][act_well]['position'])) > 0:
 307                #     #     curr_completions = frozenset((inp[1][act_well]['position']))
 308                #     #     curr_rho = _calc_loc(grid_pos=(inp[4][2,i], inp[4][1,i], inp[4][0,i]), loc_info=inp[1][act_well],
 309                #     #                          ne=Ne)
 310                #     loc_index = inp[7][(act_well[0], act_well[1])]
 311                #     row_loc[loc_index] = curr_rho
 312            if 'identical_loc' in meta_data and meta_data['identical_loc']:
 313                for el_upd in states:
 314                    upd[el_upd][i] = np.dot(np.expand_dims(row_loc * np.dot(inp[0][el_upd][i, :], meta_data['X']), axis=0),
 315                                            meta_data['diff'])
 316                break
 317            else:
 318                upd[el][i] = np.dot(np.expand_dims(
 319                    row_loc*np.dot(inp[0][el][i, :], meta_data['X']), axis=0), meta_data['diff'])
 320
 321    tot_upd = {}
 322    for el in states:
 323        tot_upd[el] = np.concatenate(upd[el], axis=0)
 324
 325    return tot_upd
 326
 327
 328def _calc_region(loc_info, states, field_dim, actnum):
 329    """
 330    Calculate the region-boxes where data can be available for the state.
 331
 332    Parameters
 333    ----------
 334    loc_info: dict
 335        Information for localization
 336    states: dict 
 337        State variables
 338    field_dim: list
 339        Dimension of grid
 340    actnum: ndarray
 341        Active cells
 342
 343    Returns
 344    -------
 345    regions: dict
 346        Region-box
 347    """
 348    regions = {}
 349    for state in states:
 350        tmp_reg = [loc_info[el]['range'] for el in loc_info.keys() if type(el) == tuple and 'region' in loc_info[el]['taper_func']
 351                   and state in el]
 352        unique_reg = [el for el in set(map(tuple, tmp_reg))]
 353        regions[state] = []
 354        for reg in unique_reg:
 355            upd_reg = []
 356            for el in reg:
 357                # convert region boundaries (x0:x1) into list of integers [x0,x1]
 358                if ':' in el:
 359                    upd_reg.extend([int(l) for l in el.split(':')])
 360                else:
 361                    upd_reg.append(el)
 362            regions[state].append(_get_region(upd_reg, field_dim, actnum))
 363
 364    return regions
 365
 366
 367def _get_region(reg, field_dim=None, actnum=None):
 368    """
 369    Calculate the coordinates of the region. Consider two formats.
 370    <ol>
 371        <li>k_min, k_max, j_min, j_max, i_min, i_max</li>
 372        <li>File (containing regions) regions</li>
 373    </ol>
 374
 375    Parameters
 376    ----------
 377    reg:
 378    field_dim: list
 379        Dimension of grid
 380    actnum: ndarray
 381        Active cells
 382
 383    Returns
 384    -------
 385    act_g: ndarray
 386    """
 387
 388    # Get the files
 389    if type(reg[0]) == str:
 390        flag_region = [int(el) for el in reg[1:]]
 391        with open(reg[0], 'r') as file:
 392            lines = file.readlines()
 393            # Extract all lines that start with a digit, and make a list of all digits
 394            tot_char = [el for l in lines if len(l.strip())
 395                        and l.strip()[0][0].isdigit() for el in l.split() if el[0].isdigit()]
 396        if field_dim is not None:
 397            # CHECK THIS AT SOME POINT!
 398            k_coord, j_coord, i_coord = np.meshgrid(range(field_dim[0]), range(
 399                field_dim[1]), range(field_dim[2]), indexing='ij')
 400            tot_g = np.array([k_coord, j_coord, i_coord])
 401            if actnum is not None:
 402                tot_f = np.zeros(field_dim).flatten()
 403                count = 0
 404                for l in tot_char:
 405                    if l.isdigit():
 406                        if int(l) in flag_region:
 407                            tot_f[count] = 1
 408                        count += 1
 409                    else:  # assume that we have input on the format num_cells*region_number
 410                        num_cell, tmp_region = l.split('*')
 411                        if int(tmp_region) in flag_region:
 412                            for i in range(int(num_cell)):
 413                                tot_f[count + i] = 1
 414                        count += int(num_cell)
 415                tot_f[~actnum] = 0
 416                act_g = tot_g[:, tot_f]
 417    else:
 418        # Get the domain
 419        if field_dim is not None:
 420            k_coord, j_coord, i_coord = np.meshgrid(range(field_dim[0]), range(
 421                field_dim[1]), range(field_dim[2]), indexing='ij')
 422            tot_g = np.array([k_coord, j_coord, i_coord])
 423            if actnum is not None:
 424                tot_f = np.zeros(field_dim, dtype=bool)
 425                tot_f[reg[4]:reg[5], reg[2]:reg[3], reg[0]:reg[1]] = actnum.reshape(
 426                    field_dim)[reg[4]:reg[5], reg[2]:reg[3], reg[0]:reg[1]]
 427                act_g = tot_g[:, tot_f]
 428            else:
 429                tot_f = np.zeros(field_dim, dtype=bool)
 430                tot_f[reg[4]:reg[5], reg[2]:reg[3], reg[0]:reg[1]] = np.ones(
 431                    field_dim, dtype=bool)[reg[4]:reg[5], reg[2]:reg[3], reg[0]:reg[1]]
 432                act_g = tot_g[:, tot_f]
 433
 434    return act_g
 435
 436
 437def _calc_loc(grid_pos=[0, 0, 0], loc_info=None, ne=1):
 438    """
 439    Calculate the localization (taper) factor for a given grid position.
 440
 441    Parameters
 442    ----------
 443    grid_pos: list, optional
 444        Grid coordinates. Defaults to [0, 0, 0].
 445    loc_info: dict, optional
 446        Localization info. Defaults to None.
 447    ne: int, optional
 448        Number of ensemble members. Defaults to 1.
 449
 450    Returns
 451    -------
 452    mask: ndarray
 453        Localization mask
 454    """
 455    # given the parameter type (to get the prior info) and the range to the data points we can calculate the
 456    # localization mask
 457
 458    if loc_info['taper_func'] == 'region':
 459        mask = 1
 460    else:
 461        # TODO: Add 3D anisotropy
 462        loc_range = []
 463        for el in loc_info['position']:
 464            loc_range.append(_calc_dist(grid_pos, el))
 465
 466        dist = min(loc_range)
 467        if loc_info['taper_func'] == 'fb':
 468            # assume that FB localization is utilized. Here we can add different localization functions
 469            if dist < loc_info['range'][0]:
 470                tmp = 1 - 1 * \
 471                    (1.5 * np.abs(dist) / loc_info['range']
 472                     [0] - .5 * (dist / loc_info['range'][0]) ** 3)
 473            else:
 474                tmp = 0
 475
 476            mask = (ne * tmp ** 2) / ((tmp ** 2) * (ne + 1) + 1 ** 2)
 477
 478    return mask
 479
 480
 481def _calc_dist(x1, x2):
 482    """
 483    Calculate distance between two points
 484
 485    Parameters
 486    ----------
 487    x1, x2: ndarray
 488        Coordinates
 489
 490    Returns
 491    -------
 492    dist: ndarray
 493        (Euclidean) distance between `x1` and `x2`
 494
 495    """
 496    if len(x1) == 1:
 497        return np.sqrt((x1-x2)**2)
 498    elif len(x1) == 2:
 499        return np.sqrt((x1[0]-x2[0])**2 + (x1[1]-x2[1])**2)
 500    elif len(x1) == 3:
 501        return np.sqrt((x1[0]-x2[0])**2 + (x1[1]-x2[1])**2 + (x1[2]-x2[2])**2)
 502
 503
 504def calc_autocov(pert):
 505    """
 506    Calculate sample auto-covariance matrix.
 507
 508    Parameters
 509    ----------
 510    pert: ndarray
 511        Perturbation matrix (matrix of variables perturbed with their mean)
 512
 513    Returns
 514    -------
 515    cov_auto: ndarray
 516        Sample auto-covariance matrix
 517    """
 518    # TODO: Implement sqrt-covariance matrices
 519
 520    # No of samples
 521    ne = pert.shape[1]
 522
 523    # Standard sample auto-covariance calculation
 524    cov_auto = (1 / (ne - 1)) * np.dot(pert, pert.T)
 525
 526    # Return the auto-covariance matrix
 527    return cov_auto
 528
 529
 530def calc_objectivefun(pert_obs, pred_data, Cd):
 531    """
 532    Calculate the objective function.
 533
 534    Parameters
 535    ----------
 536    pert_obs : array-like
 537        NdxNe array containing perturbed observations.
 538
 539    pred_data : array-like
 540        NdxNe array containing ensemble of predictions.
 541
 542    Cd : array-like
 543        NdxNd array containing data covariance, or Ndx1 array containing data variance.
 544
 545    Returns
 546    -------
 547    data_misfit : array-like
 548        Nex1 array containing objective function values.
 549    """
 550    ne = pred_data.shape[1]
 551    r = (pred_data - pert_obs)
 552    if len(Cd.shape) == 1:
 553        precision = Cd**(-1)
 554        data_misfit = np.diag(r.T.dot(r*precision[:, None]))
 555    else:
 556        data_misfit = np.diag(r.T.dot(linalg.solve(Cd, r)))
 557
 558    return data_misfit
 559
 560
 561def calc_crosscov(pert1, pert2):
 562    """
 563    Calculate sample cross-covariance matrix.
 564
 565    Parameters
 566    ----------
 567    pert1, pert2: ndarray
 568        Perturbation matrices (matrix of variables perturbed with their mean).
 569
 570    Returns
 571    -------
 572    cov_cross: ndarray
 573        Sample cross-covariance matrix
 574    """
 575    # TODO: Implement sqrt-covariance matrices
 576
 577    # No of samples
 578    ne = pert1.shape[1]
 579
 580    # Standard calc. of sample cross-covariance
 581    cov_cross = (1 / (ne - 1)) * np.dot(pert1, pert2.T)
 582
 583    # Return the cross-covariance matrix
 584    return cov_cross
 585
 586
 587def update_datavar(cov_data, datavar, assim_index, list_data):
 588    """
 589    Extract the separate variances from an augmented vector. It is assumed that the augmented variance
 590    is made by `gen_covdata`, hence this is the reverse method of `gen_covdata`.
 591
 592    Parameters
 593    ----------
 594    cov_data : array-like
 595        Augmented vector of variance.
 596
 597    datavar : dict
 598        Dictionary of separate variances.
 599
 600    assim_index : list
 601        Assimilation order as a list.
 602
 603    list_data : list
 604        List of data keys.
 605
 606    Returns
 607    -------
 608    datavar : dict
 609        Updated dictionary of separate variances."""
 610
 611    # Loop over all entries in list_data and extract a vector with the same number of elements as datavar[key]
 612    # from cov_data, and replace the values in datavar[key].
 613
 614    # Make sure assim_index is list
 615    if isinstance(assim_index[1], list):  # Check if prim. ind. is a list
 616        l_prim = [int(x) for x in assim_index[1]]
 617    else:
 618        l_prim = [int(assim_index[1])]
 619
 620    # Extract the diagonal if cov_data is a matrix
 621    if len(cov_data.shape) == 2:
 622        cov_data = np.diag(cov_data)
 623
 624    # Initialize a variable to keep track of which row in 'cov_data' we start from in each loop
 625    aug_row = 0
 626    # Loop over all primary indices
 627    for ix in range(len(l_prim)):
 628        # Loop over data types and augment the data variance
 629        for i in range(len(list_data)):
 630            if datavar[l_prim[ix]][list_data[i]] is not None:
 631
 632                # If there is an observed data here, update it
 633                no_rows = datavar[l_prim[ix]][list_data[i]].shape[0]
 634
 635                # Extract the rows from aug and update 'state[key]'
 636                datavar[l_prim[ix]][list_data[i]] = cov_data[aug_row:aug_row + no_rows]
 637
 638                # Update tracking variable for row in 'aug'
 639                aug_row += no_rows
 640
 641    # Return
 642    return datavar
 643
 644
 645def save_analysisdebug(ind_save, **kwargs):
 646    """
 647    Save variables in analysis step for debugging purpose
 648
 649    Parameters
 650    ----------
 651    ind_save: int
 652        Index of analysis step
 653    **kwargs: dict
 654        Variables that will be saved to npz file
 655
 656    Notes
 657    -----
 658    Use kwargs here because the input will be a dictionary with names equal to the variable names to store, and when this
 659    is passed to np.savez (as kwargs) the variables will be stored with their original names.
 660    """
 661    # Save input variables
 662    try:
 663        np.savez('debug_analysis_step_{0}'.format(str(ind_save)), **kwargs)
 664    except: # if npz save fails dump to a pickle file
 665        with open(f'debug_analysis_step_{ind_save}.p', 'wb') as file:
 666            pickle.dump(kwargs, file)
 667
 668
 669def get_list_data_types(obs_data, assim_index):
 670    """
 671    Extract the lists of all data types and of the active data types
 672
 673    Parameters
 674    ----------
 675    obs_data: dict
 676        Observed data
 677    assim_index: list
 678        Current assimilation index
 679
 680    Returns
 681    -------
 682    l_all: list
 683        List of all data types
 684    l_act: list
 685        List of the data types that are active (that are not `None`)
 686    """
 687    # List the primary indices
 688    if isinstance(assim_index[0], list):  # If True, then we have subset list
 689        if isinstance(assim_index[1][0], list):  # Check if prim. ind. is a list
 690            l_prim = [int(x) for x in assim_index[1][0]]
 691        else:
 692            l_prim = [int(assim_index[1][0])]
 693    else:  # Only prim. assim. ind.
 694        if isinstance(assim_index[1], list):  # Check if prim. ind. is a list
 695            l_prim = [int(x) for x in assim_index[1]]
 696        else:
 697            l_prim = [int(assim_index[1])]
 698
 699    # List the data types.
 700    l_all = list(obs_data[l_prim[0]].keys())
 701
 702    # Extract the data types that are active at current assimilation step
 703    l_act = []
 704    for ix in l_prim:
 705        for data_typ in l_all:
 706            if obs_data[ix][data_typ] is not None:
 707                l_act.extend([data_typ])
 708
 709    # Return the list
 710    return l_all, l_act
 711
 712
 713def gen_covdata(datavar, assim_index, list_data):
 714    """
 715    Generate the data covariance matrix at the current assimilation step. Note that the data covariance may be a
 716    diagonal matrix with only variance entries, an empirical covariance matrix, or a combination of both. For
 717    diagonal data covariance we only store the vector of variance values.
 718
 719    Parameters
 720    ----------
 721    datavar: list
 722        List of dictionaries containing variance for the observed data. The structure of this list is the same as for
 723        `obs_data`
 724    assim_index: list
 725        Current assimilation index
 726    list_data: list
 727        List of the data types
 728
 729    Returns
 730    -------
 731    cd: ndarray
 732        Data auto-covariance matrix
 733
 734    Notes
 735    -----
 736    For empirical covariance generation, the datavar entry must be a 2D array, arranged as a standard ensemble matrix (N
 737    x Ns, where Ns is the number of samples).
 738    """
 739    # TODO: Change if sub-assim. indices are implemented
 740    # TODO: Use something other than numpy hstack for this augmentation!
 741
 742    # Make sure assim_index is list
 743    if isinstance(assim_index[1], list):  # Check if prim. ind. is a list
 744        l_prim = [int(x) for x in assim_index[1]]
 745    else:
 746        l_prim = [int(assim_index[1])]
 747
 748    # Init. a logical variable to check if it is the first time in the loop below that we extract variance data.
 749    # Need this because we stack the remaining variance horizontally, and it is possible that we have "None"
 750    # input in the first instances of the loop (hence we cannot always say that
 751    # datavar[l_prim[0]][list_data[0]] will be the first variance we want to extract)
 752    first_time = True
 753
 754    # Initialize augmented array
 755    # Loop over all primary indices
 756    for ix in range(len(l_prim)):
 757        # Loop over data types and augment the data variance
 758        for i in range(len(list_data)):
 759            if datavar[l_prim[ix]][list_data[i]] is not None:
 760                # If there is an observed data here, augment it
 761                if first_time:  # Init. var output
 762                    # Switch off the first time logical variable
 763                    first_time = False
 764
 765                    # Calc. var.
 766                    var = datavar[l_prim[ix]][list_data[i]]
 767
 768                    # If var is 2D then it is either full covariance or realizations to generate a sample cov.
 769                    # If matrix is square assume it is full covariance, note this can go wrong!
 770                    if var.ndim == 2:
 771                        if var.shape[0] == var.shape[1]:  # full cov
 772                            c_var = var
 773                        else:
 774                            c_var = calc_autocov(var)
 775                    # else we make a diagonal matrix
 776                    else:  # diagonal, only store vector
 777                        c_var = var
 778
 779                else:  # Stack var output
 780                    # Calc. var.
 781                    var = datavar[l_prim[ix]][list_data[i]]
 782
 783                    # If var is 2D then we generate a sample cov., else we make a diagonal matrix
 784                    if var.ndim == 2:  # empirical
 785                        if var.shape[0] == var.shape[1]:  # full cov
 786                            c_var_temp = var
 787                        else:
 788                            c_var_temp = calc_autocov(var)
 789                        c_var = linalg.block_diag(c_var, c_var_temp)
 790                    else:  # diagonal, only store vector
 791                        c_var_temp = var
 792                        c_var = np.append(c_var, c_var_temp)
 793
 794    # Generate the covariance matrix
 795    cd = c_var
 796
 797    # Return data covariance matrix
 798    return cd
 799
 800
 801def screen_data(cov_data, pred_data, obs_data_vector, keys_da, iteration):
 802    """
 803    Screen the data covariance: inflate the variance of data points whose observed value falls outside the range of the predicted ensemble, and store/reload the result via 'cov_data.p'.
 804
 805    Parameters
 806    ----------
 807    cov_data: ndarray
 808        Data covariance matrix
 809    pred_data: ndarray
 810        Predicted data
 811    obs_data_vector: ndarray
 812        Observed data (1D array)
 813    keys_da: dict
 814        Dictionary with every input in `DATAASSIM`
 815    iteration: int
 816        Current iteration
 817
 818    Returns
 819    -------
 820    cov_data: ndarray
 821        Updated data covariance matrix
 822    """
 823
 824    if ('restart' in keys_da and keys_da['restart'] == 'yes') or (iteration != 0):
 825        with open('cov_data.p', 'rb') as f:
 826            cov_data = pickle.load(f)
 827    else:
 828        emp_cov = False
 829        if cov_data.ndim == 2:  # assume emp_cov
 830            emp_cov = True
 831            var = np.var(cov_data, ddof=1, axis=1)
 832            cov_data = cov_data - cov_data.mean(1)[:, np.newaxis]
 833        num_data = pred_data.shape[0]
 834        for i in range(num_data):
 835            v = 0
 836            if obs_data_vector[i] < np.min(pred_data[i, :]):
 837                v = np.abs(obs_data_vector[i] - np.min(pred_data[i, :]))
 838            elif obs_data_vector[i] > np.max(pred_data[i, :]):
 839                v = np.abs(obs_data_vector[i] - np.max(pred_data[i, :]))
 840            if not emp_cov:
 841                cov_data[i] = np.max((cov_data[i], v ** 2))
 842            else:
 843                v = np.max((v**2 / var[i], 1))
 844                cov_data[i, :] *= np.sqrt(v)
 845        with open('cov_data.p', 'wb') as f:
 846            pickle.dump(cov_data, f)
 847
 848    return cov_data
 849
 850
 851def store_ensemble_sim_information(saveinfo, member):
 852    """
 853    Here, we can either run a unique Python script or do some other post-processing routine. The function should
 854    not return anything, but provide a method for storing relevant information.
 855    The current ensemble member is passed in for easy storage.
 856    """
 857
 858    for el in saveinfo:
 859        if '.py' in el:  # This is a unique python file
 860            sim_info_func = import_module(el[:-3])  # remove .py ending
 861            # Note: the function must be named main, and we pass the current ensemble
 862            # member to it.
 863            sim_info_func.main(member)
 864
 865
 866def extract_tot_empirical_cov(data_var, assim_index, list_data, ne):
 867    """
 868    Extract realizations of noise from data_var (if imported), or generate realizations if only variance is specified
 869    (assume uncorrelated)
 870
 871    Parameters
 872    ----------
 873    data_var: list
 874        List of dictionaries containing the variance as read from the input
 875    assim_index: list
 876        Index of the assimilation
 877    list_data: list
 878        List of data types
 879    ne: int
 880        Ensemble size
 881
 882    Returns
 883    -------
 884    E: ndarray
 885        Sorted (according to assim_index and list_data) matrix of data realization noise.
 886    """
 887
 888    if isinstance(assim_index[1], list):  # Check if prim. ind. is a list
 889        l_prim = [int(x) for x in assim_index[1]]
 890    else:
 891        l_prim = [int(assim_index[1])]
 892
 893    tmp_E = []
 894    for el in l_prim:
 895        tmp_tmp_E = {}
 896        for dat in list_data:
 897            if data_var[el][dat] is not None:
 898                if len(data_var[el][dat].shape) == 1:
 899                    tmp_tmp_E[dat] = np.sqrt(
 900                        data_var[el][dat][:, np.newaxis])*np.random.randn(data_var[el][dat].shape[0], ne)
 901                else:
 902                    if data_var[el][dat].shape[0] == data_var[el][dat].shape[1]:
 903                        tmp_tmp_E[dat] = np.dot(linalg.cholesky(
 904                            data_var[el][dat]), np.random.randn(data_var[el][dat].shape[1], ne))
 905                    else:
 906                        tmp_tmp_E[dat] = data_var[el][dat]
 907        tmp_E.append(tmp_tmp_E)
 908    E = np.concatenate(tuple(tmp_E[i][dat] for i, el in enumerate(
 909        l_prim) for dat in list_data if data_var[el][dat] is not None))
 910
 911    return E
 912
 913
 914def aug_obs_pred_data(obs_data, pred_data, assim_index, list_data):
 915    """
 916    Augment the observed and predicted data to an array at an assimilation step. The observed data will be an augmented
 917    vector and the predicted data will be an ensemble matrix.
 918
 919    Parameters
 920    ----------
 921    obs_data: list
 922        List of dictionaries containing observed data
 923    pred_data: list
 924        List of dictionaries where each entry of the list is the forward simulation results at an assimilation step. The
 925        dictionary has keys equal to the data type (given in `OBSNAME`).
 926
 927    Returns
 928    -------
 929    obs: ndarray 
 930        Augmented vector of observed data
 931    pred: ndarray
 932        Ensemble matrix of predicted data
 933    """
 934    # TODO: Change if sub-assim. ind. are implemented.
 935    # TODO: Use something other than numpy hstack and vstack for these augmentations!
 936
 937    # Make sure assim_index is a list
 938    if isinstance(assim_index[1], list):  # Check if prim. ind. is a list
 939        l_prim = [int(x) for x in assim_index[1]]
 940    else:
 941        l_prim = [int(assim_index[1])]
 942
 943    # make this more efficient
 944
 945    tot_pred = tuple(pred_data[el][dat] for el in l_prim if pred_data[el]
 946                     is not None for dat in list_data if obs_data[el][dat] is not None)
 947    if len(tot_pred):  # if this is done during the initialization tot_pred contains nothing
 948        pred = np.concatenate(tot_pred)
 949    else:
 950        pred = None
 951    obs = np.concatenate(tuple(
 952        obs_data[el][dat] for el in l_prim for dat in list_data if obs_data[el][dat] is not None))
 953
 954    # Init. a logical variable to check if it is the first time in the loop below that we extract obs/pred data.
 955    # Need this because we stack the remaining data horizontally/vertically, and it is possible that we have "None"
 956    # input in the first instances of the loop (hence we cannot always say that
 957    # self.obs_data[l_prim[0]][list_data[0]] and self.pred_data[l_prim[0]][list_data[0]] will be the
 958    # first data we want to extract)
 959    # first_time = True
 960    #
 961    # #initialize obs and pred
 962    # obs = None
 963    # pred = None
 964    #
 965    # # Init the augmented arrays.
 966    # # Loop over all primary indices
 967    # for ix in range(len(l_prim)):
 968    #     # Loop over obs_data/pred_data keys
 969    #     for i in range(len(list_data)):
 970    #         # If there is an observed data here, augment obs and pred
 971    #         if obs_data[l_prim[ix]][list_data[i]] is not None:  # No obs/pred data
 972    #             if first_time:  # Init. the outputs obs and pred
 973    #                 # Switch off the first time logical variable
 974    #                 first_time = False
 975    #
 976    #                 # Observed data:
 977    #                 obs = obs_data[l_prim[ix]][list_data[i]]
 978    #
 979    #                 # Predicted data
 980    #                 pred = pred_data[l_prim[ix]][list_data[i]]
 981    #
 982    #             else:  # Stack the obs and pred outputs
 983    #                 # Observed data:
 984    #                 obs = np.hstack((obs, obs_data[l_prim[ix]][list_data[i]]))
 985    #
 986    #                 # Predicted data
 987    #                 pred = np.vstack((pred, pred_data[l_prim[ix]][list_data[i]]))
 988    #
 989    # # Return augmented arrays
 990    return obs, pred
 991
 992
 993def calc_kalmangain(cov_cross, cov_auto, cov_data, opt=None):
 994    r"""
 995    Calculate the Kalman gain
 996
 997    Parameters
 998    ----------
 999    cov_cross: ndarray
1000        Cross-covariance matrix between state and predicted data
1001    cov_auto: ndarray
1002        Auto-covariance matrix of predicted data
1003    cov_data: ndarray
1004        Variance on observed data (diagonal matrix)
1005    opt: str
1006        Which method should we use to calculate Kalman gain
1007        <ul>
1008            <li>'lu': LU decomposition (default)</li>
1009            <li>'chol': Cholesky decomposition</li>
1010        </ul>
1011
1012    Returns
1013    -------
1014    kalman_gain: ndarray
1015        Kalman gain
1016
1017    Notes
1018    -----
1019    In the following Kalman gain is $K$, cross-covariance is $C_{mg}$, predicted data auto-covariance is $C_{g}$,
1020    and data covariance is $C_{d}$.
1021
1022    With `'lu'` option, we solve the transposed linear system:
1023    $$
1024        K^T = (C_{g} + C_{d})^{-T}C_{mg}^T
1025    $$
1026
1027    With `'chol'` option we use Cholesky on auto-covariance matrix,
1028    $$
1029       L L^T = (C_{g} + C_{d})^T
1030    $$
1031    and solve linear system with the square-root matrix from Cholesky:
1032    $$
1033        L^T Y = C_{mg}^T\\
1034        LK = Y
1035    $$
1036    """
1037    # Default to LU decomposition unless another option ('chol') is given
1038    calc_opt = 'lu' if opt is None else opt
1039
1040    # Add data and predicted data auto-covariance matrices
1041    if len(cov_data.shape) == 1:
1042        cov_data = np.diag(cov_data)
1043    c_auto = cov_auto + cov_data
1044
1045    if calc_opt == 'lu':
1046        kg = linalg.solve(c_auto.T, cov_cross.T)
1047        kalman_gain = kg.T
1048
1049    elif calc_opt == 'chol':
1050        # Cholesky decomp (upper triangular matrix)
1051        u = linalg.cho_factor(c_auto.T, check_finite=False)
1052
1053        # Solve linear system with cholesky square-root
1054        kalman_gain = linalg.cho_solve(u, cov_cross.T, check_finite=False)
1055
1056    # Return Kalman gain
1057    return kalman_gain
1058
1059
1060def calc_subspace_kalmangain(cov_cross, data_pert, cov_data, energy):
1061    """
1062    Compute the Kalman gain in an efficient subspace determined by how much energy (i.e., percentage of singular values)
1063    to retain. For more info regarding the implementation, see Chapter 14 in [1].
1064
1065    Parameters
        ----------
1066    cov_cross: ndarray
1067        Cross-covariance matrix between state and predicted data
1068    data_pert: ndarray
1069        Predicted data perturbations (predicted data minus their ensemble mean)
1070    cov_data: ndarray
1071        Variance on observed data (diagonal matrix)
        energy: float
            Percentage of the singular values (energy) to retain in the truncated SVD
1072
1073    Returns
1074    -------
1075    k_g: ndarray
1076        Subspace Kalman gain
1077
1078    References
1079    ----------
1080    [1] G. Evensen (2009). Data Assimilation: The Ensemble Kalman Filter, Springer.
1081    """
1082    # No. ensemble members
1083    ne = data_pert.shape[1]
1084
1085    # Perform SVD on pred. data perturbations
1086    u_d, s_d, v_d = np.linalg.svd(np.sqrt(1 / (ne - 1)) * data_pert, full_matrices=False)
1087
1088    # If no. measurements is more than ne - 1, we only keep ne - 1 sing. val.
1089    if data_pert.shape[0] >= ne:
1090        u_d, s_d, v_d = u_d[:, :-1].copy(), s_d[:-1].copy(), v_d[:-1, :].copy()
1091
1092    # If energy is less than 100 we truncate the SVD matrices
1093    if energy < 100:
1094        ti = (np.cumsum(s_d) / sum(s_d)) * 100 <= energy
1095        u_d, s_d, v_d = u_d[:, ti].copy(), s_d[ti].copy(), v_d[ti, :].copy()
1096
1097    # Calculate x_0 and its eigenvalue decomp.
1098    if len(cov_data.shape) == 1:
1099        x_0 = np.dot(np.diag(s_d[:]**(-1)), np.dot(u_d[:, :].T, np.expand_dims(cov_data, axis=1)*np.dot(u_d[:, :],
1100                                                                                                        np.diag(s_d[:]**(-1)).T)))
1101    else:
1102        x_0 = np.dot(np.diag(s_d[:] ** (-1)), np.dot(u_d[:, :].T, np.dot(cov_data, np.dot(u_d[:, :],
1103                                                                                          np.diag(s_d[:] ** (-1)).T))))
1104    s, u = np.linalg.eig(x_0)
1105
1106    # Calculate x_1
1107    x_1 = np.dot(u_d[:, :], np.dot(np.diag(s_d[:]**(-1)).T, u))
1108
1109    # Calculate Kalman gain based on the subspace matrices we made above
1110    k_g = np.dot(cov_cross, np.dot(x_1, linalg.solve(
1111        (np.eye(s.shape[0]) + np.diag(s)), x_1.T)))
1112
1113    # Return subspace Kalman gain
1114    return k_g
1115
1116
1117def compute_x(pert_preddata, cov_data, keys_da, alfa=None):
1118    """
1119    Compute the matrix X used in the ensemble update, either via a truncated-SVD subspace inversion or by solving the full linear system with the predicted-data perturbations and the data covariance.
1120
1121    Parameters
1122    ----------
1123    pert_preddata: ndarray
1124        Perturbed predicted data
1125    cov_data: ndarray
1126        Data covariance matrix
1127    keys_da: dict
1128        Dictionary with every input in `DATAASSIM`
1129    alfa: None, optional
1130        Not used in the current implementation
1131
1132    Returns
        -------
1133    X: ndarray
1134        Matrix such that the ensemble update is the state perturbations multiplied by X and the data mismatch
1135    """
1136    X = []
1137    if 'kalmangain' in keys_da and keys_da['kalmangain'][0] == 'subspace':
1138
1139        # TSVD energy
1140        energy = keys_da['kalmangain'][1]
1141
1142        # No. ensemble members
1143        ne = pert_preddata.shape[1]
1144
1145        # Calculate x_0 and its eigenvalue decomp.
1146        if len(cov_data.shape) == 1:
1147            scale = np.expand_dims(np.sqrt(cov_data), axis=1)
1148        else:
1149            scale = np.expand_dims(np.sqrt(np.diag(cov_data)), axis=1)
1150
1151        # Perform SVD on pred. data perturbations
1152        u_d, s_d, v_d = np.linalg.svd(pert_preddata/scale, full_matrices=False)
1153
1154        # If no. measurements is more than ne - 1, we only keep ne - 1 sing. val.
1155        if pert_preddata.shape[0] >= ne:
1156            u_d, s_d, v_d = u_d[:, :-1].copy(), s_d[:-1].copy(), v_d[:-1, :].copy()
1157
1158        # If energy is less than 100 we truncate the SVD matrices
1159        if energy < 100:
1160            ti = (np.cumsum(s_d) / sum(s_d)) * 100 <= energy
1161            u_d, s_d, v_d = u_d[:, ti].copy(), s_d[ti].copy(), v_d[ti, :].copy()
1162
1163        # Calculate x_0 and its eigenvalue decomp.
1164        if len(cov_data.shape) == 1:
1165            x_0 = np.dot(np.diag(s_d[:] ** (-1)),
1166                         np.dot(u_d[:, :].T, np.expand_dims(cov_data, axis=1) * np.dot(u_d[:, :],
1167                                                                                       np.diag(s_d[:] ** (-1)).T)))
1168        else:
1169            x_0 = np.dot(np.diag(s_d[:] ** (-1)), np.dot(u_d[:, :].T, np.dot(cov_data, np.dot(u_d[:, :],
1170                                                                                              np.diag(s_d[:] ** (-1)).T))))
1171        s, u = np.linalg.eig(x_0)
1172
1173        # Calculate x_1
1174        x_1 = np.dot(u_d[:, :], np.dot(np.diag(s_d[:] ** (-1)).T, u))/scale
1175
1176        # Calculate X based on the subspace matrices we made above
1177        X = np.dot(np.dot(pert_preddata.T, x_1), linalg.solve(
1178            (np.eye(s.shape[0]) + np.diag(s)), x_1.T))
1179
1180    else:
1181        if len(cov_data.shape) == 1:
1182            X = linalg.solve(np.dot(pert_preddata, pert_preddata.T) +
1183                             np.diag(cov_data), pert_preddata)
1184        else:
1185            X = linalg.solve(np.dot(pert_preddata, pert_preddata.T) +
1186                             cov_data, pert_preddata)
1187        X = X.T
1188
1189    return X
1190
1191
1192def aug_state(state, list_state, cell_index=None):
1193    """
1194    Augment the state variables to an array.
1195
1196    Parameters
1197    ----------
1198    state: dict
1199        Dictionary of initial ensemble of (joint) state variables (static parameters and dynamic variables) to be
1200        assimilated.
1201    list_state: list
1202        Fixed list of keys in state dict.
1203    cell_index: list of vector indexes to be extracted
1204
1205    Returns
1206    -------
1207    aug: ndarray
1208        Ensemble matrix of augmented state variables
1209    """
1210    # TODO: In some rare cases, it may not be desirable to update every state variable at each assimilation step.
1211    # Change code to only augment states to be updated at the specific assimilation step
1212    # TODO: Use something other that numpy vstack for this augmentation!
1213
1214    if cell_index is not None:
1215        # Start with ensemble of first state variable
1216        aug = state[list_state[0]][cell_index]
1217
1218        # Loop over the next states (if exists)
1219        for i in range(1, len(list_state)):
1220            aug = np.vstack((aug, state[list_state[i]][cell_index]))
1221
1222        # Return the augmented array
1223
1224    else:
1225        # Start with ensemble of first state variable
1226        aug = state[list_state[0]]
1227
1228        # Loop over the next states (if exists)
1229        for i in range(1, len(list_state)):
1230            aug = np.vstack((aug, state[list_state[i]]))
1231
1232        # Return the augmented array
1233    return aug
1234
1235
1236def calc_scaling(state, list_state, prior_info):
1237    """
1238    Form the scaling to be used in SVD-related algorithms. The scaling consists of the standard deviation for each `STATICVAR`.
1239    It is important that this is formed in the same manner as the augmented state vector, hence with the same
1240    list of states.
1241
1242    Parameters
1243    ----------
1244    state: dict
1245        Dictionary containing the state
1246    list_state: list
1247        List of states for augmenting
1248    prior_info: dict
1249        Nested dictionary containing prior information
1250
1251    Returns
1252    -------
1253    scaling: ndarray
1254        Scaling array (standard deviations for the augmented state)
1255    """
1256
1257    scaling = []
1258    for elem in list_state:
1259        # more than single value. This is for multiple layers. Assume all values are active
1260        if len(prior_info[elem]['variance']) > 1:
1261            scaling.append(np.concatenate(tuple(np.sqrt(prior_info[elem]['variance'][z]) *
1262                                                np.ones(
1263                                                    prior_info[elem]['ny']*prior_info[elem]['nx'])
1264                                                for z in range(prior_info[elem]['nz']))))
1265        else:
1266            scaling.append(tuple(np.sqrt(prior_info[elem]['variance']) *
1267                                 np.ones(state[elem].shape[0])))
1268
1269    return np.concatenate(scaling)
1270
1271
1272def update_state(aug_state, state, list_state, cell_index=None):
1273    """
1274    Extract the separate state variables from an augmented state array. It is assumed that the augmented state
1275    array is made in `aug_state`, hence this is the reverse method of `aug_state`.
1276
1277    Parameters
1278    ----------
1279    aug_state: ndarray
1280        Augmented array of UPDATED state variables
1281    state: dict
1282        Dict. of state variables NOT updated.
1283    list_state: list
1284        List of state keys that have been updated
1285    cell_index: list
1286        List of indexes that gives where the augmented state should be placed
1287
1288    Returns
1289    -------
1290    state: dict
1291        Dict. of UPDATED state variables
1292    """
1293    if cell_index is None:
1294        # Loop over all entries in list_state and extract a matrix with same number of rows as the key in state
1295        # determines from aug and replace the values in state[key].
1296        # Init. a variable to keep track of which row in 'aug' we start from in each loop
1297        aug_row = 0
1298        for _, key in enumerate(list_state):
1299            # Find no. rows in state[key] to determine how many rows from aug to extract
1300            no_rows = state[key].shape[0]
1301
1302            # Extract the rows from aug and update 'state[key]'
1303            state[key] = aug_state[aug_row:aug_row + no_rows, :]
1304
1305            # Update tracking variable for row in 'aug'
1306            aug_row += no_rows
1307
1308    else:
1309        aug_row = 0
1310        for _, key in enumerate(list_state):
1311            # The number of rows to extract is given by the length of cell_index
1312            no_rows = len(cell_index)
1313
1314            # Extract the rows from aug and update 'state[key]'
1315            state[key][cell_index, :] = aug_state[aug_row:aug_row + no_rows, :]
1316
1317            # Update tracking variable for row in 'aug'
1318            aug_row += no_rows
1319    return state
1320
1321
1322def resample_state(aug_state, state, list_state, new_en_size):
1323    """
1324    Extract the separate state variables from an augmented state matrix. Calculate the mean and covariance, and resample
1325    this.
1326
1327    Parameters
1328    ----------
1329    aug_state: ndarray
1330        Augmented matrix of state variables
1331    state: dict
1332        Dict. of state variables
1333    list_state: list
1334        List of state variables
1335    new_en_size: int
1336        Size of the new ensemble
1337
1338    Returns
1339    -------
1340    state: dict
1341        Dict. of resampled members
1342    """
1343
1344    aug_row = 0
1345    curr_ne = state[list_state[0]].shape[1]
1346    new_state = {}
1347    for elem in list_state:
1348        # determine how many rows to extract
1349        no_rows = state[elem].shape[0]
1350        new_state[elem] = np.empty((no_rows, new_en_size))
1351
1352        mean_state = np.mean(aug_state[aug_row:aug_row + no_rows, :], 1)
1353        pert_state = np.sqrt(1/(curr_ne - 1)) * (aug_state[aug_row:aug_row + no_rows, :] - np.dot(np.resize(mean_state,
1354                                                                                                            (len(mean_state), 1)), np.ones((1, curr_ne))))
1355        for i in range(new_en_size):
1356            new_state[elem][:, i] = mean_state + \
1357                np.dot(pert_state, np.random.normal(0, 1, pert_state.shape[1]))
1358
1359        aug_row += no_rows
1360
1361    return new_state
1362
1363
1364def block_diag_cov(cov, list_state):
1365    """
1366    Block diagonalize a covariance matrix dictionary.
1367
1368    Parameters
1369    ----------
1370    cov: dict
1371        Dict. with cov. matrices
1372    list_state: list
1373        Fixed list of keys in state dict.
1374
1375    Returns
1376    -------
1377    cov_out: ndarray
1378        Block diag. matrix with prior covariance matrices for each state.
1379    """
1380    # TODO: Change if there are cross-correlation between different states
1381
1382    # Init. block in matrix
1383    cov_out = cov[list_state[0]]
1384
1385    # Test if scalar has been given in init. block
1386    if not hasattr(cov_out, '__len__'):
1387        cov_out = np.array([[cov_out]])
1388
1389    # Loop of rest of the state-names and add in block diag. matrix
1390    for i in range(1, len(list_state)):
1391        cov_out = linalg.block_diag(cov_out, cov[list_state[i]])
1392
1393    # Return
1394    return cov_out
1395
1396
1397def calc_kalman_filter_eq(aug_state, kalman_gain, obs_data, pred_data):
1398    """
1399    Calculate the updated augmented state using the Kalman filter equations
1400
1401    Parameters
1402    ----------
1403    aug_state: ndarray
1404        Augmented state variable (all the parameters defined in `STATICVAR` augmented in one array)
1405    kalman_gain: ndarray
1406        Kalman gain
1407    obs_data: ndarray
1408        Augmented observed data vector (all `OBSNAME` augmented in one array)
1409    pred_data: ndarray
1410        Augmented predicted data vector (all `OBSNAME` augmented in one array)
1411
1412    Returns
1413    -------
1414    aug_state_upd: ndarray
1415        Updated augmented state variable using the Kalman filter equations
1416    """
1417    # TODO: Implement svd updating algorithm
1418
1419    # Matrix version
1420    # aug_state_upd = aug_state + np.dot(kalman_gain, (obs_data - pred_data))
1421
1422    # For-loop version
1423    aug_state_upd = np.zeros(aug_state.shape)  # Init. updated state
1424
1425    for i in range(aug_state.shape[1]):  # Loop over ensemble members
1426        aug_state_upd[:, i] = aug_state[:, i] + \
1427            np.dot(kalman_gain, (obs_data[:, i] - pred_data[:, i]))
1428
1429    # Return the updated state
1430    return aug_state_upd
1431
1432
1433def limits(state, prior_info):
1434    """
1435    Check if any state variables overshoot the limits given by the prior info. If so, modify these values
1436
1437    Parameters
1438    ----------
1439    state: dict
1440        Dictionary containing the states
1441    prior_info: dict
1442        Dictionary containing prior information for all the states.
1443
1444    Returns
1445    -------
1446    state: dict
1447        Valid state
1448    """
1449    for var in state.keys():
1450        if 'limits' in prior_info[var]:
1451            state[var][state[var] < prior_info[var]['limits']
1452                       [0][0]] = prior_info[var]['limits'][0][0]
1453            state[var][state[var] > prior_info[var]['limits']
1454                       [0][1]] = prior_info[var]['limits'][0][1]
1455    return state
1456
1457
1458def subsample_state(index, aug_state, pert_state):
1459    """
1460    Draw a subsample from the original state, given by the index
1461
1462    Parameters
1463    ----------
1464    index: ndarray
1465        Index of parameters to draw.
1466    aug_state: ndarray
1467        Original augmented state.
1468    pert_state: ndarray
1469        Perturbed augmented state, for error covariance.
1470
1471    Returns
1472    -------
1473    new_state: dict
1474        Subsample of state.
1475    """
1476
1477    new_state = np.empty((aug_state.shape[0], len(index)))
1478    for i in range(len(index)):
1479        new_state[:, i] = aug_state[:, index[i]] + \
1480            np.dot(pert_state, np.random.normal(0, 1, pert_state.shape[1]))
1481        # select some elements
1482
1483    return new_state
1484
1485
1486def init_local_analysis(init, state):
1487    """Initialize local analysis.
1488
1489    Initialize the local analysis by reading the input variables, defining the parameter classes and search ranges. Build
1490    the map of data/parameter positions.
1491
1492    Args:
1493        init: dictionary containing the parsed information from the input file.
1494        state: list of states that will be updated
1495    Returns:
1496        local: dictionary of initialized values.
1497    """
1498
1499    local = {}
1500    local['cell_parameter'] = []
1501    local['region_parameter'] = []
1502    local['vector_region_parameter'] = []
1503    local['unique'] = True
1504
1505    for i, opt in enumerate(list(zip(*init))[0]):
1506        if opt.lower() == 'region_parameter':  # define scalar parameters valid in a region
1507            local['region_parameter'] = [
1508                elem for elem in init[i][1].split(' ') if elem in state]
1509        if opt.lower() == 'vector_region_parameter': # Sometimes it is useful to define the same parameter for multiple
1510                                                    # regions as a vector.
1511            local['vector_region_parameter'] = [
1512                elem for elem in init[i][1].split(' ') if elem in state]
1513        if opt.lower() == 'cell_parameter':  # define cell specific vector parameters
1514            local['cell_parameter'] = [
1515                elem for elem in init[i][1].split(' ') if elem in state]
1516        if opt.lower() == 'search_range':
1517            local['search_range'] = int(init[i][1])
1518        if opt.lower() == 'column_update':
1519            local['column_update'] = [elem for elem in init[i][1].split(',')]
1520        if opt.lower() == 'parameter_position_file':  # assume pickled format
1521            with open(init[i][1], 'rb') as file:
1522                local['parameter_position'] = pickle.load(file)
1523        if opt.lower() == 'data_position_file':  # assume pickled format
1524            with open(init[i][1], 'rb') as file:
1525                local['data_position'] = pickle.load(file)
1526        if opt.lower() == 'update_mask_file':
1527            with open(init[i][1], 'rb') as file:
1528                local['update_mask'] = pickle.load(file)
1529
1530    if 'update_mask' in local:
1531        return local
1532    else:
1533        assert 'parameter_position' in local, 'A pickle file containing the binary map of the parameters is MANDATORY'
1534        assert 'data_position' in local, 'A pickle file containing the position of the data is MANDATORY'
1535
1536        data_name = [elem for elem in local['data_position'].keys()]
1537        if type(local['data_position'][data_name[0]][0]) == list:  # assim index has specific position
1538            local['unique'] = False
1539            data_pos = [elem for data in data_name for assim_elem in local['data_position'][data]
1540                        for elem in assim_elem]
1541            data_ind = [f'{data}_{assim_indx}' for data in data_name for assim_indx, assim_elem in enumerate(local['data_position'][data])
1542                        for _ in assim_elem]
1543        else:
1544            data_pos = [elem for data in data_name for elem in local['data_position'][data]]
1545            # store the name for easy index
1546            data_ind = [data for data in data_name for _ in local['data_position'][data]]
1547        kde_search = cKDTree(data=data_pos)
1548
1549        local['update_mask'] = {}
1550        for param in local['cell_parameter']:  # find data in a distance from the parameter
1551            field_size = local['parameter_position'][param].shape
1552            local['update_mask'][param] = [[[[] for _ in range(field_size[2])] for _ in range(field_size[1])] for _
1553                                           in range(field_size[0])]
1554            for k in range(field_size[0]):
1555                for j in range(field_size[1]):
1556                    new_iter = [elem for elem, val in enumerate(
1557                        local['parameter_position'][param][k, j, :]) if val]
1558                    if len(new_iter):
1559                        for i in new_iter:
1560                            local['update_mask'][param][k][j][i] = set(
1561                                [data_ind[elem] for elem in kde_search.query_ball_point(x=(k, j, i),
1562                                                                                        r=local['search_range'], workers=-1)])
1563
1564        # see if data is inside the region. Note parameter_position is boolean map
1565        for param in local['region_parameter']:
1566            in_region = [local['parameter_position'][param][elem] for elem in data_pos]
1567            local['update_mask'][param] = set(
1568                [data_ind[count] for count, val in enumerate(in_region) if val])
1569
1570        return local
def parallel_upd(list_state, prior_info, states_dict, X, local_mask_info, obs_data, pred_data, parallel, actnum=None, field_dim=None, act_data_list=None, scale_data=None, num_states=1, emp_d_cov=False):
 22def parallel_upd(list_state, prior_info, states_dict, X, local_mask_info, obs_data, pred_data, parallel, actnum=None,
 23                 field_dim=None, act_data_list=None, scale_data=None, num_states=1, emp_d_cov=False):
 24    """
 25    Script to initialize and control a parallel update of the ensemble state following [1].
 26
 27    Parameters
 28    ----------
 29    list_state: list
 30        List of state names
 31    prior_info: dict
 32        INSERT DESCRIPTION
 33    states_dict: dict
 34        Dict. of state arrays
 35    X: ndarray
 36        INSERT DESCRIPTION
 37    local_mask_info: dict
 38        INSERT DESCRIPTION
 39    obs_data: ndarray
 40        Observed data
 41    pred_data: ndarray
 42        Predicted data
 43    parallel: int
 44        Number of parallel runs
 45    actnum: ndarray, optional
 46        Active cells
 47    field_dim: list, optional
 48        Number of grid cells in each direction
 49    act_data_list: list, optional
 50        List of active data names
 51    scale_data: ndarray, optional
 52        Scaling array for data
 53    num_states: int, optional
 54        Number of states
 55    emp_d_cov: bool
 56        INSERT DESCRIPTION
 57
 58    Notes
 59    -----
 60    Since the localization matrix is to large for evaluation, we instead calculate it row for row.
 61
 62    References
 63    ----------
 64    [1] Emerick, Alexandre A. 2016. “Analysis of the Performance of Ensemble-Based Assimilation of Production and
 65    Seismic Data.” Journal of Petroleum Science and Engineering 139. Elsevier: 219-39. doi:10.1016/j.petrol.2016.01.029
 66    """
 67    if scale_data is None:
 68        scale_data = np.ones(obs_data.shape[0])
 69
 70    # Generate a list over the grid coordinates
 71    if field_dim is not None:
 72        k_coord, j_coord, i_coord = np.meshgrid(range(field_dim[0]), range(
 73            field_dim[1]), range(field_dim[2]), indexing='ij')
 74        tot_g = np.array([k_coord, j_coord, i_coord])
 75        if actnum is not None:
 76            act_g = tot_g[:, actnum.reshape(field_dim)]
 77        else:
 78            act_g = tot_g[:, np.ones(tuple(field_dim), dtype=bool)]
 79
 80    dat = [el for el in local_mask_info.keys()]
 81    # data coordinates to initialize search
 82    tot_completions = [tuple(el) for dat_mask in dat if type(
 83        dat_mask) == tuple for el in local_mask_info[dat_mask]['position']]
 84    uniq_completions = [el for el in set(tot_completions)]
 85    tot_w_name = [dat_mask for dat_mask in dat if type(
 86        dat_mask) == tuple for _ in local_mask_info[dat_mask]['position']]
 87    uniq_w_name = [tot_w_name[tot_completions.index(el)] for el in uniq_completions]
 88    # todo: limit to active data
 89    coord_search = cKDTree(data=uniq_completions)
 90
 91    try:
 92        act_w_name = [el[0].split()[1] for el in uniq_w_name]
 93
 94        tot_well_dict = {}
 95        for well in set(act_w_name):
 96            tot_well_dict[well] = [el for el in local_mask_info.keys() if type(el) == tuple and
 97                                   el[0].split()[1] == well]
 98    except:
 99        tot_well_dict = local_mask_info
100
101    if len(scale_data.shape) == 1:
102        diff = np.dot(np.expand_dims(scale_data**(-1), axis=1),
103                      np.ones((1, pred_data.shape[1])))*(obs_data - pred_data)
104    else:
105        diff = linalg.solve(scale_data, (obs_data - pred_data))
106
107    # initialize the update
108    upd = {}
109
110    # Assume that we have three types of parameters: full 3D fields, layers (2D fields), or scalar values. These are
111    # handled individually.
112
113    field_states = [state for state in list_state if states_dict[state].shape[0]
114                    == act_g.shape[1]]  # field states
115    layer_states = [state for state in list_state if 1 <
116                    states_dict[state].shape[0] < act_g.shape[1]]  # layer states
117    # scalar states
118    scalar_states = [state for state in list_state if states_dict[state].shape[0] == 1]
119
120    # We handle the field states first. These are the most time-consuming, and require parallelization.
121
122    # since X must be passed to all processes I split the state into equal portions, and let the row updates loop over
123    # the different portions
124    # coordinates for active parameters
125    split_coord = np.array_split(act_g, parallel, axis=1)
126    # Assuming that all parameters are spatial fields
127    split_state = [{} for _ in range(parallel)]
128    tmp_loc = {}  # initialize for checking similar localization info
129    # assume for now that everything is spatial, if not we require an extra loop or (if/else block)
130    for state in field_states:
131        # Augment the joint state variables (originally a dictionary) and the prior state variable
132        aug_state = states_dict[state]
133        # aug_prior_state = at.aug_state(self.prior_state, self.list_states)
134
135        # Mean state and perturbation matrix
136        mean_state = np.mean(aug_state, 1)
137        if emp_d_cov:
138            pert_state = (aug_state - np.dot(np.resize(mean_state, (len(mean_state), 1)),
139                                             np.ones((1, aug_state.shape[1]))))
140        else:
141            pert_state = (aug_state - np.dot(np.resize(mean_state, (len(mean_state), 1)),
142                                             np.ones((1, aug_state.shape[1])))) / (np.sqrt(aug_state.shape[1] - 1))
143
144        tmp_state = np.array_split(pert_state, parallel)
145        for i, elem in enumerate(tmp_state):
146            split_state[i][state] = elem
147        tmp_loc[state] = [el for el in local_mask_info if el[2] == state]
148    # loc_info = [local_mask_info for _ in range(parallel)]
149    # tot_X = [X for _ in range(parallel)]
150    # tot_coord_seach = [coord_search for _ in range(parallel)] # might prompt an error if coord_search is too large
151    # tot_uniq_name = [uniq_w_name for _ in range(parallel)]
152    # tot_data_list = [act_data_list for _ in range(parallel)]
153    # tot_well_dict_list = [tot_well_dict for _ in range(parallel)]
154    non_similar = []
155    for state in field_states[1:]:  # check localization
156        non_shared = {k: ' ' for i, k in enumerate(
157            tmp_loc[field_states[0]]) if local_mask_info[k] != local_mask_info[tmp_loc[state][i]]}
158        non_similar.append(len(non_shared))
159
160    if sum(non_similar) == 0:
161        identical_loc = True
162    else:
163        identical_loc = False
164    # Due to memory issues a pickle file is written containing all "meta" data required for the update
165    with open('meta_analysis.p', 'wb') as file:
166        pickle.dump({'local_mask_info': local_mask_info, 'diff': diff, 'X': X, 'coord_search': coord_search,
167                     'unique_w_name': uniq_w_name, 'act_data_list': act_data_list, 'tot_well_dict': tot_well_dict,
168                     'actnum': actnum, 'unique_completions': uniq_completions, 'identical_loc': identical_loc}, file)
169    tot_file_name = ['meta_analysis.p' for _ in range(parallel)]
170    # to_workers = zip(split_state, loc_info, diff, tot_X, split_coord, tot_coord_seach,tot_uniq_name, tot_data_list,
171    #                  tot_well_dict_list)
172    to_workers = zip(split_state, split_coord, tot_file_name)
173
174    parallel = 1  # test
175    #
176    with OpenBlasSingleThread():
177        if parallel > 1:
178            with mp.get_context('spawn').Pool(parallel) as pool:
179                s = pool.map(_calc_row_upd, to_workers)
180        else:
181            tmp_s = map(_calc_row_upd, to_workers)
182            s = [el for el in tmp_s]
183
184    for tmp_key in field_states:
185        upd[tmp_key] = np.concatenate([el[tmp_key] for el in s], axis=0)
186
187    ####################################################################################################################
188    # Now handle the layer states
189
190    for state in layer_states:
191        # could add parallelization later
192        aug_state = states_dict[state]
193        mean_state = np.mean(aug_state, 1)
194        if emp_d_cov:
195            pert_state = {state: (aug_state - np.dot(np.resize(mean_state, (len(mean_state), 1)),
196                                                     np.ones((1, aug_state.shape[1]))))}
197        else:
198            pert_state = {state: (aug_state - np.dot(np.resize(mean_state, (len(mean_state), 1)),
199                                                     np.ones((1, aug_state.shape[1])))) / (np.sqrt(aug_state.shape[1] - 1))}
200        # Layer
201        # make a rule that requires the parameter name to end with the "_ + layer number". E.g. "multz_5"
202        layer = int(state.split('_')[-1])
203        l_act = np.full(field_dim, False)
204        l_act[layer, :, :] = actnum.reshape(field_dim)[layer, :, :]
205        act_g = tot_g[:, l_act]
206
207        to_workers = zip([pert_state], [act_g], ['meta_analysis.p'])
208
209        # with OpenBlasSingleThread():
210        s = map(_calc_row_upd, to_workers)
211        upd[state] = np.concatenate([el[state] for el in s], axis=0)
212
213    ####################################################################################################################
214    # Finally the scalar states
215    for state in scalar_states:
216        # could add parallelization later
217        aug_state = states_dict[state]
218        mean_state = np.mean(aug_state, 1)
219        if emp_d_cov:
220            pert_state = {state: (aug_state - np.dot(np.resize(mean_state, (len(mean_state), 1)),
221                                                     np.ones((1, aug_state.shape[1]))))}
222        else:
223            pert_state = {state: (aug_state - np.dot(np.resize(mean_state, (len(mean_state), 1)),
224                                                     np.ones((1, aug_state.shape[1])))) / (np.sqrt(aug_state.shape[1] - 1))}
225
226        to_workers = zip([pert_state], [tot_g], ['meta_analysis.p'])
227
228        # with OpenBlasSingleThread():
229        s = map(_calc_row_upd, to_workers)
230
231        upd[state] = np.concatenate([el[state] for el in s], axis=0)
232
233    return upd

Script to initialize and control a parallel update of the ensemble state following [1].

Parameters
  • list_state (list): List of state names
  • prior_info (dict): INSERT DESCRIPTION
  • states_dict (dict): Dict. of state arrays
  • X (ndarray): INSERT DESCRIPTION
  • local_mask_info (dict): INSERT DESCRIPTION
  • obs_data (ndarray): Observed data
  • pred_data (ndarray): Predicted data
  • parallel (int): Number of parallel runs
  • actnum (ndarray, optional): Active cells
  • field_dim (list, optional): Number of grid cells in each direction
  • act_data_list (list, optional): List of active data names
  • scale_data (ndarray, optional): Scaling array for data
  • num_states (int, optional): Number of states
  • emp_d_cov (bool): INSERT DESCRIPTION
Notes

Since the localization matrix is to large for evaluation, we instead calculate it row for row.

References

[1] Emerick, Alexandre A. 2016. “Analysis of the Performance of Ensemble-Based Assimilation of Production and Seismic Data.” Journal of Petroleum Science and Engineering 139. Elsevier: 219-39. doi:10.1016/j.petrol.2016.01.029

def calc_autocov(pert):
505def calc_autocov(pert):
506    """
507    Calculate sample auto-covariance matrix.
508
509    Parameters
510    ----------
511    pert: ndarray
512        Perturbation matrix (matrix of variables perturbed with their mean)
513
514    Returns
515    -------
516    cov_auto: ndarray
517        Sample auto-covariance matrix
518    """
519    # TODO: Implement sqrt-covariance matrices
520
521    # No of samples
522    ne = pert.shape[1]
523
524    # Standard sample auto-covariance calculation
525    cov_auto = (1 / (ne - 1)) * np.dot(pert, pert.T)
526
527    # Return the auto-covariance matrix
528    return cov_auto

Calculate sample auto-covariance matrix.

Parameters
  • pert (ndarray): Perturbation matrix (matrix of variables perturbed with their mean)
Returns
  • cov_auto (ndarray): Sample auto-covariance matrix
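
A minimal usage sketch (the ensemble below is purely illustrative): the perturbation matrix is formed by removing the ensemble mean before calling calc_autocov.

    import numpy as np
    from pipt.misc_tools import analysis_tools as at

    pred = np.random.randn(3, 100)                   # hypothetical ensemble: 3 data, 100 members
    pert = pred - pred.mean(axis=1, keepdims=True)   # remove the ensemble mean

    cov_auto = at.calc_autocov(pert)                 # 3 x 3 sample covariance, matches np.cov(pred)
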
def calc_objectivefun(pert_obs, pred_data, Cd):
531def calc_objectivefun(pert_obs, pred_data, Cd):
532    """
533    Calculate the objective function.
534
535    Parameters
536    ----------
537    pert_obs : array-like
538        NdxNe array containing perturbed observations.
539
540    pred_data : array-like
541        NdxNe array containing ensemble of predictions.
542
543    Cd : array-like
544        NdxNd array containing data covariance, or Ndx1 array containing data variance.
545
546    Returns
547    -------
548    data_misfit : array-like
549        Nex1 array containing objective function values.
550    """
551    ne = pred_data.shape[1]
552    r = (pred_data - pert_obs)
553    if len(Cd.shape) == 1:
554        precision = Cd**(-1)
555        data_misfit = np.diag(r.T.dot(r*precision[:, None]))
556    else:
557        data_misfit = np.diag(r.T.dot(linalg.solve(Cd, r)))
558
559    return data_misfit

Calculate the objective function.

Parameters
  • pert_obs (array-like): NdxNe array containing perturbed observations.
  • pred_data (array-like): NdxNe array containing ensemble of predictions.
  • Cd (array-like): NdxNd array containing data covariance, or Ndx1 array containing data variance.
Returns
  • data_misfit (array-like): Nex1 array containing objective function values.
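
A small example with made-up numbers for the diagonal case, where Cd is passed as a variance vector:

    import numpy as np
    from pipt.misc_tools import analysis_tools as at

    nd, ne = 5, 50
    pred_data = np.random.randn(nd, ne)      # hypothetical ensemble predictions
    pert_obs = np.random.randn(nd, ne)       # hypothetical perturbed observations
    cd = 0.1 * np.ones(nd)                   # data variance (diagonal C_d)

    data_misfit = at.calc_objectivefun(pert_obs, pred_data, cd)
    print(data_misfit.shape)                 # (50,): one misfit value per member
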
def calc_crosscov(pert1, pert2):
562def calc_crosscov(pert1, pert2):
563    """
564    Calculate sample cross-covariance matrix.
565
566    Parameters
567    ----------
568    pert1, pert2: ndarray
569        Perturbation matrices (matrix of variables perturbed with their mean).
570
571    Returns
572    -------
573    cov_cross: ndarray
574        Sample cross-covariance matrix
575    """
576    # TODO: Implement sqrt-covariance matrices
577
578    # No of samples
579    ne = pert1.shape[1]
580
581    # Standard calc. of sample cross-covariance
582    cov_cross = (1 / (ne - 1)) * np.dot(pert1, pert2.T)
583
584    # Return the cross-covariance matrix
585    return cov_cross

Calculate sample cross-covariance matrix.

Parameters
  • pert1, pert2 (ndarray): Perturbation matrices (matrix of variables perturbed with their mean).
Returns
  • cov_cross (ndarray): Sample cross-covariance matrix
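
A hypothetical usage sketch; the two perturbation matrices only need to share the ensemble dimension:

    import numpy as np
    from pipt.misc_tools import analysis_tools as at

    ne = 100
    state = np.random.randn(200, ne)         # hypothetical parameter ensemble
    pred = np.random.randn(10, ne)           # hypothetical predicted-data ensemble

    pert1 = state - state.mean(axis=1, keepdims=True)
    pert2 = pred - pred.mean(axis=1, keepdims=True)

    cov_cross = at.calc_crosscov(pert1, pert2)   # 200 x 10 parameter/data cross-covariance
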
def update_datavar(cov_data, datavar, assim_index, list_data):
588def update_datavar(cov_data, datavar, assim_index, list_data):
589    """
590    Extract the separate variance from an augmented vector. It is assumed that the augmented variance
591    is made gen_covdata, hence this is the reverse method of gen_covdata.
592
593    Parameters
594    ----------
595    cov_data : array-like
596        Augmented vector of variance.
597
598    datavar : dict
599        Dictionary of separate variances.
600
601    assim_index : list
602        Assimilation order as a list.
603
604    list_data : list
605        List of data keys.
606
607    Returns
608    -------
609    datavar : dict
610        Updated dictionary of separate variances."""
611
612    # Loop over all entries in list_data and extract a vector with the same number of elements as the corresponding
613    # entry in datavar, then replace the values in datavar[key].
614
615    # Make sure assim_index is list
616    if isinstance(assim_index[1], list):  # Check if prim. ind. is a list
617        l_prim = [int(x) for x in assim_index[1]]
618    else:
619        l_prim = [int(assim_index[1])]
620
621    # Extract the diagonal if cov_data is a matrix
622    if len(cov_data.shape) == 2:
623        cov_data = np.diag(cov_data)
624
625    # Initialize a variable to keep track of which row in 'cov_data' we start from in each loop
626    aug_row = 0
627    # Loop over all primary indices
628    for ix in range(len(l_prim)):
629        # Loop over data types and augment the data variance
630        for i in range(len(list_data)):
631            if datavar[l_prim[ix]][list_data[i]] is not None:
632
633                # If there is an observed data here, update it
634                no_rows = datavar[l_prim[ix]][list_data[i]].shape[0]
635
636                # Extract the rows from aug and update 'state[key]'
637                datavar[l_prim[ix]][list_data[i]] = cov_data[aug_row:aug_row + no_rows]
638
639                # Update tracking variable for row in 'aug'
640                aug_row += no_rows
641
642    # Return
643    return datavar

Extract the separate variance from an augmented vector. It is assumed that the augmented variance is made gen_covdata, hence this is the reverse method of gen_covdata.

Parameters
  • cov_data (array-like): Augmented vector of variance.
  • datavar (dict): Dictionary of separate variances.
  • assim_index (list): Assimilation order as a list.
  • list_data (list): List of data keys.
Returns
  • datavar (dict): Updated dictionary of separate variances.
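
A small sketch with made-up numbers; the datavar and assim_index layouts below are assumptions based on the code above:

    import numpy as np
    from pipt.misc_tools import analysis_tools as at

    # Hypothetical variances for one assimilation step and two data types
    datavar = [{'WOPR': np.array([4.0, 4.0]), 'WWCT': np.array([0.01])}]
    assim_index = ['PRIM', [0]]              # assumed layout: second entry holds the primary indices
    list_data = ['WOPR', 'WWCT']

    # Reverse of gen_covdata: scatter an augmented variance vector back into the dict
    new_var = np.array([9.0, 9.0, 0.04])
    datavar = at.update_datavar(new_var, datavar, assim_index, list_data)
    print(datavar[0]['WOPR'])                # [9. 9.]
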
def save_analysisdebug(ind_save, **kwargs):
646def save_analysisdebug(ind_save, **kwargs):
647    """
648    Save variables in analysis step for debugging purpose
649
650    Parameters
651    ----------
652    ind_save: int
653        Index of analysis step
654    **kwargs: dict
655        Variables that will be saved to npz file
656
657    Notes
658    -----
659    Use kwargs here because the input will be a dictionary with names equal to the variable names to store, and when this
660    is passed to np.savez (as kwargs) the variables will be stored with their original names.
661    """
662    # Save input variables
663    try:
664        np.savez('debug_analysis_step_{0}'.format(str(ind_save)), **kwargs)
665    except: # if npz save fails dump to a pickle file
666        with open(f'debug_analysis_step_{ind_save}.p', 'wb') as file:
667            pickle.dump(kwargs, file)

Save variables in analysis step for debugging purpose

Parameters
  • ind_save (int): Index of analysis step
  • **kwargs (dict): Variables that will be saved to npz file
Notes

Use kwargs here because the input will be a dictionary with names equal to the variable names to store, and when this is passed to np.savez (as kwargs) the variables will be stored with their original names.
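
A minimal usage sketch with hypothetical variable names:

    import numpy as np
    from pipt.misc_tools import analysis_tools as at

    # Hypothetical quantities from analysis step 3
    debug_vars = {'aug_state': np.random.randn(100, 20),
                  'data_misfit': np.random.rand(20)}

    # Stored as debug_analysis_step_3.npz (falls back to a pickle file if np.savez fails)
    at.save_analysisdebug(3, **debug_vars)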

def get_list_data_types(obs_data, assim_index):
670def get_list_data_types(obs_data, assim_index):
671    """
672    Extract the list of all data types and the list of active data types
673
674    Parameters
675    ----------
676    obs_data: dict
677        Observed data
678    assim_index: int
679        Current assimilation index
680
681    Returns
682    -------
683    l_all: list
684        List of all data types
685    l_act: list
686        List of the data types that are active (that are not `None`)
687    """
688    # List the primary indices
689    if isinstance(assim_index[0], list):  # If True, then we have subset list
690        if isinstance(assim_index[1][0], list):  # Check if prim. ind. is a list
691            l_prim = [int(x) for x in assim_index[1][0]]
692        else:
693            l_prim = [int(assim_index[1][0])]
694    else:  # Only prim. assim. ind.
695        if isinstance(assim_index[1], list):  # Check if prim. ind. is a list
696            l_prim = [int(x) for x in assim_index[1]]
697        else:
698            l_prim = [int(assim_index[1])]
699
700    # List the data types.
701    l_all = list(obs_data[l_prim[0]].keys())
702
703    # Extract the data types that are active at current assimilation step
704    l_act = []
705    for ix in l_prim:
706        for data_typ in l_all:
707            if obs_data[ix][data_typ] is not None:
708                l_act.extend([data_typ])
709
710    # Return the list
711    return l_all, l_act

Extract the list of all data types and the list of active data types

Parameters
  • obs_data (dict): Observed data
  • assim_index (int): Current assimilation index
Returns
  • l_all (list): List of all data types
  • l_act (list): List of the data types that are active (that are not None)
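
A hypothetical example; the assim_index layout below (second entry holding the primary indices) is an assumption based on the code above:

    import numpy as np
    from pipt.misc_tools import analysis_tools as at

    # Hypothetical observations for two assimilation steps; None marks inactive data
    obs_data = [{'WOPR': np.array([100.0]), 'WWCT': None},
                {'WOPR': np.array([95.0]),  'WWCT': np.array([0.3])}]
    assim_index = ['PRIM', [0, 1]]

    l_all, l_act = at.get_list_data_types(obs_data, assim_index)
    print(l_all)   # ['WOPR', 'WWCT']
    print(l_act)   # ['WOPR', 'WOPR', 'WWCT'] -- one entry per active (index, type) pair
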
def gen_covdata(datavar, assim_index, list_data):
714def gen_covdata(datavar, assim_index, list_data):
715    """
716    Generate the data covariance matrix at current assimilation step. Note here that the data covariance may be a
717    diagonal matrix with only variance entries, or an empirical covariance matrix, or a combination of both. For
718    diagonal data covariance we only store a vector of variance values.
719
720    Parameters
721    ----------
722    datavar: list
723        List of dictionaries containing variance for the observed data. The structure of this list is the same as for
724        `obs_data`
725    assim_index: int
726        Current assimilation index
727    list_data: list
728        List of the data types
729
730    Returns
731    -------
732    cd: ndarray
733        Data auto-covariance matrix
734
735    Notes
736    -----
737    For empirical covariance generation, the datavar entry must be a 2D array, arranged as a standard ensemble matrix (N
738    x Ns, where Ns is the number of samples).
739    """
740    # TODO: Change if sub-assim. indices are implemented
741    # TODO: Use something other than numpy hstack for this augmentation!
742
743    # Make sure assim_index is list
744    if isinstance(assim_index[1], list):  # Check if prim. ind. is a list
745        l_prim = [int(x) for x in assim_index[1]]
746    else:
747        l_prim = [int(assim_index[1])]
748
749    # Init. a logical variable to check if it is the first time in the loop below that we extract variance data.
750    # Need this because we stack the remaining variance horizontally, and it is possible that we have "None"
751    # input in the first instances of the loop (hence we cannot always say that
752    # self.datavar[l_prim[0]][list_data[0]] will be the first variance we want to extract)
753    first_time = True
754
755    # Initialize augmented array
756    # Loop over all primary indices
757    for ix in range(len(l_prim)):
758        # Loop over data types and augment the data variance
759        for i in range(len(list_data)):
760            if datavar[l_prim[ix]][list_data[i]] is not None:
761                # If there is an observed data here, augment it
762                if first_time:  # Init. var output
763                    # Switch off the first time logical variable
764                    first_time = False
765
766                    # Calc. var.
767                    var = datavar[l_prim[ix]][list_data[i]]
768
769                    # If var is 2D then it is either full covariance or realizations to generate a sample cov.
770                    # If matrix is square assume it is full covariance, note this can go wrong!
771                    if var.ndim == 2:
772                        if var.shape[0] == var.shape[1]:  # full cov
773                            c_var = var
774                        else:
775                            c_var = calc_autocov(var)
776                    # else we make a diagonal matrix
777                    else:  # diagonal, only store vector
778                        c_var = var
779
780                else:  # Stack var output
781                    # Calc. var.
782                    var = datavar[l_prim[ix]][list_data[i]]
783
784                    # If var is 2D then we generate a sample cov., else we make a diagonal matrix
785                    if var.ndim == 2:  # empirical
786                        if var.shape[0] == var.shape[1]:  # full cov
787                            c_var_temp = var
788                        else:
789                            c_var_temp = calc_autocov(var)
790                        c_var = linalg.block_diag(c_var, c_var_temp)
791                    else:  # diagonal, only store vector
792                        c_var_temp = var
793                        c_var = np.append(c_var, c_var_temp)
794
795    # Generate the covariance matrix
796    cd = c_var
797
798    # Return data covariance matrix
799    return cd

Generate the data covariance matrix at current assimilation step. Note here that the data covariance may be a diagonal matrix with only variance entries, or an empirical covariance matrix, or a combination of both. For diagonal data covariance we only store a vector of variance values.

Parameters
  • datavar (list): List of dictionaries containing variance for the observed data. The structure of this list is the same as for obs_data
  • assim_index (int): Current assimilation index
  • list_data (list): List of the data types
Returns
  • cd (ndarray): Data auto-covariance matrix
Notes

For empirical covariance generation, the datavar entry must be a 2D array, arranged as a standard ensemble matrix (N x Ns, where Ns is the number of samples).
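
A minimal sketch of the purely diagonal case, with made-up variances and the same assumed assim_index layout as in the examples above:

    import numpy as np
    from pipt.misc_tools import analysis_tools as at

    datavar = [{'WOPR': np.array([4.0, 4.0]), 'WWCT': None},
               {'WOPR': np.array([9.0, 9.0]), 'WWCT': np.array([0.01])}]
    assim_index = ['PRIM', [0, 1]]
    list_data = ['WOPR', 'WWCT']

    cd = at.gen_covdata(datavar, assim_index, list_data)
    # Every entry is a variance vector, so cd is the stacked vector [4, 4, 9, 9, 0.01]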

def screen_data(cov_data, pred_data, obs_data_vector, keys_da, iteration):
802def screen_data(cov_data, pred_data, obs_data_vector, keys_da, iteration):
803    """
804    Screen the data variance: if an observation falls outside the span of the predicted ensemble, the corresponding variance is inflated. The result is pickled to 'cov_data.p' and reused on restart or in later iterations.
805
806    Parameters
807    ----------
808    cov_data: ndarray
809        Data covariance matrix
810    pred_data: ndarray
811        Predicted data
812    obs_data_vector: ndarray
813        Observed data (1D array)
814    keys_da: dict
815        Dictionary with every input in `DATAASSIM`
816    iteration: int
817        Current iteration
818
819    Returns
820    -------
821    cov_data: ndarray
822        Updated data covariance matrix
823    """
824
825    if ('restart' in keys_da and keys_da['restart'] == 'yes') or (iteration != 0):
826        with open('cov_data.p', 'rb') as f:
827            cov_data = pickle.load(f)
828    else:
829        emp_cov = False
830        if cov_data.ndim == 2:  # assume emp_cov
831            emp_cov = True
832            var = np.var(cov_data, ddof=1, axis=1)
833            cov_data = cov_data - cov_data.mean(1)[:, np.newaxis]
834        num_data = pred_data.shape[0]
835        for i in range(num_data):
836            v = 0
837            if obs_data_vector[i] < np.min(pred_data[i, :]):
838                v = np.abs(obs_data_vector[i] - np.min(pred_data[i, :]))
839            elif obs_data_vector[i] > np.max(pred_data[i, :]):
840                v = np.abs(obs_data_vector[i] - np.max(pred_data[i, :]))
841            if not emp_cov:
842                cov_data[i] = np.max((cov_data[i], v ** 2))
843            else:
844                v = np.max((v**2 / var[i], 1))
845                cov_data[i, :] *= np.sqrt(v)
846        with open('cov_data.p', 'wb') as f:
847            pickle.dump(cov_data, f)
848
849    return cov_data

Screen the data variance: if an observation falls outside the span of the predicted ensemble, the corresponding variance is inflated. The result is pickled to 'cov_data.p' and reused on restart or in later iterations.

Parameters
  • cov_data (ndarray): Data covariance matrix
  • pred_data (ndarray): Predicted data
  • obs_data_vector (ndarray): Observed data (1D array)
  • keys_da (dict): Dictionary with every input in DATAASSIM
  • iteration (int): Current iteration
Returns
  • cov_data (ndarray): Updated data covariance matrix
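
A hedged usage sketch with made-up numbers; note that on the first iteration the function also pickles the screened covariance to 'cov_data.p':

    import numpy as np
    from pipt.misc_tools import analysis_tools as at

    np.random.seed(0)
    pred_data = np.random.normal(10.0, 1.0, size=(3, 50))   # hypothetical prediction ensemble
    obs = np.array([10.2, 14.0, 10.1])                       # the second datum lies outside the ensemble
    var = np.array([0.5, 0.5, 0.5])                          # original data variances

    new_var = at.screen_data(var.copy(), pred_data, obs, keys_da={}, iteration=0)
    # new_var[1] is inflated to at least (14.0 - pred_data[1].max())**2; entries whose
    # observations fall inside the ensemble spread are left unchanged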
def store_ensemble_sim_information(saveinfo, member):
852def store_ensemble_sim_information(saveinfo, member):
853    """
854    Here, we can either run a unique python script or do some other post-processing routines. The function should
855    not return anything, but provide a method for storing relevant information.
856    Input the current member for easy storage
857    """
858
859    for el in saveinfo:
860        if '.py' in el:  # This is a unique python file
861            sim_info_func = import_module(el[:-3])  # remove .py ending
862            # Note: the function must be named main, and we pass the full current instance of the object plus the
863            # current member.
864            sim_info_func.main(member)

Here, we can either run a unique python script or do some other post-processing routines. The function should not return anything, but provide a method for storing relevant information. Input the current member for easy storage

def extract_tot_empirical_cov(data_var, assim_index, list_data, ne):
867def extract_tot_empirical_cov(data_var, assim_index, list_data, ne):
868    """
869    Extract realizations of noise from data_var (if imported), or generate realizations if only variance is specified
870    (assume uncorrelated)
871
872    Parameters
873    ----------
874    data_var: list
874        List of dictionaries containing the variance as read from the input
876    assim_index: int
877        Index of the assimilation
878    list_data: list
879        List of data types
880    ne: int
881        Ensemble size
882
883    Returns
884    -------
885    E: ndarray
886        Sorted (according to assim_index and list_data) matrix of data realization noise.
887    """
888
889    if isinstance(assim_index[1], list):  # Check if prim. ind. is a list
890        l_prim = [int(x) for x in assim_index[1]]
891    else:
892        l_prim = [int(assim_index[1])]
893
894    tmp_E = []
895    for el in l_prim:
896        tmp_tmp_E = {}
897        for dat in list_data:
898            if data_var[el][dat] is not None:
899                if len(data_var[el][dat].shape) == 1:
900                    tmp_tmp_E[dat] = np.sqrt(
901                        data_var[el][dat][:, np.newaxis])*np.random.randn(data_var[el][dat].shape[0], ne)
902                else:
903                    if data_var[el][dat].shape[0] == data_var[el][dat].shape[1]:
904                        tmp_tmp_E[dat] = np.dot(linalg.cholesky(
905                            data_var[el][dat]), np.random.randn(data_var[el][dat].shape[1], ne))
906                    else:
907                        tmp_tmp_E[dat] = data_var[el][dat]
908        tmp_E.append(tmp_tmp_E)
909    E = np.concatenate(tuple(tmp_E[i][dat] for i, el in enumerate(
910        l_prim) for dat in list_data if data_var[el][dat] is not None))
911
912    return E

Extract realizations of noise from data_var (if imported), or generate realizations if only variance is specified (assume uncorrelated)

Parameters
  • data_var (list): List of dictionaries containing the variance as read from the input
  • assim_index (int): Index of the assimilation
  • list_data (list): List of data types
  • ne (int): Ensemble size
Returns
  • E (ndarray): Sorted (according to assim_index and list_data) matrix of data realization noise.
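
A hypothetical example mixing a variance vector (uncorrelated noise is drawn from it) with imported, non-square noise realizations (used as-is); the assim_index layout is assumed as in the earlier examples:

    import numpy as np
    from pipt.misc_tools import analysis_tools as at

    ne = 25
    data_var = [{'WOPR': np.array([4.0, 4.0]),
                 'SEISMIC': np.random.randn(10, ne)}]
    assim_index = ['PRIM', [0]]
    list_data = ['WOPR', 'SEISMIC']

    E = at.extract_tot_empirical_cov(data_var, assim_index, list_data, ne)
    print(E.shape)   # (12, 25)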
def aug_obs_pred_data(obs_data, pred_data, assim_index, list_data):
915def aug_obs_pred_data(obs_data, pred_data, assim_index, list_data):
916    """
917    Augment the observed and predicted data to an array at an assimilation step. The observed data will be an augmented
918    vector and the predicted data will be an ensemble matrix.
919
920    Parameters
921    ----------
922    obs_data: list
923        List of dictionaries containing observed data
924    pred_data: list
925        List of dictionaries where each entry of the list is the forward simulation results at an assimilation step. The
926        dictionary has keys equal to the data type (given in `OBSNAME`).
927
928    Returns
929    -------
930    obs: ndarray 
931        Augmented vector of observed data
932    pred: ndarray
933        Ensemble matrix of predicted data
934    """
935    # TODO: Change if sub-assim. ind. are implemented.
936    # TODO: Use something other than numpy hstack and vstack for these augmentations!
937
938    # Make sure assim_index is a list
939    if isinstance(assim_index[1], list):  # Check if prim. ind. is a list
940        l_prim = [int(x) for x in assim_index[1]]
941    else:
942        l_prim = [int(assim_index[1])]
943
944    # make this more efficient
945
946    tot_pred = tuple(pred_data[el][dat] for el in l_prim if pred_data[el]
947                     is not None for dat in list_data if obs_data[el][dat] is not None)
948    if len(tot_pred):  # if this is done during the initialization tot_pred contains nothing
949        pred = np.concatenate(tot_pred)
950    else:
951        pred = None
952    obs = np.concatenate(tuple(
953        obs_data[el][dat] for el in l_prim for dat in list_data if obs_data[el][dat] is not None))
954
955    # Init. a logical variable to check if it is the first time in the loop below that we extract obs/pred data.
956    # Need this because we stack the remaining data horizontally/vertically, and it is possible that we have "None"
957    # input in the first instances of the loop (hence we cannot always say that
958    # self.obs_data[l_prim[0]][list_data[0]] and self.pred_data[l_prim[0]][list_data[0]] will be the
959    # first data we want to extract)
960    # first_time = True
961    #
962    # #initialize obs and pred
963    # obs = None
964    # pred = None
965    #
966    # # Init the augmented arrays.
967    # # Loop over all primary indices
968    # for ix in range(len(l_prim)):
969    #     # Loop over obs_data/pred_data keys
970    #     for i in range(len(list_data)):
971    #         # If there is an observed data here, augment obs and pred
972    #         if obs_data[l_prim[ix]][list_data[i]] is not None:  # No obs/pred data
973    #             if first_time:  # Init. the outputs obs and pred
974    #                 # Switch off the first time logical variable
975    #                 first_time = False
976    #
977    #                 # Observed data:
978    #                 obs = obs_data[l_prim[ix]][list_data[i]]
979    #
980    #                 # Predicted data
981    #                 pred = pred_data[l_prim[ix]][list_data[i]]
982    #
983    #             else:  # Stack the obs and pred outputs
984    #                 # Observed data:
985    #                 obs = np.hstack((obs, obs_data[l_prim[ix]][list_data[i]]))
986    #
987    #                 # Predicted data
988    #                 pred = np.vstack((pred, pred_data[l_prim[ix]][list_data[i]]))
989    #
990    # # Return augmented arrays
991    return obs, pred

Augment the observed and predicted data to an array at an assimilation step. The observed data will be an augmented vector and the predicted data will be an ensemble matrix.

Parameters
  • obs_data (list): List of dictionaries containing observed data
  • pred_data (list): List of dictionaries where each entry of the list is the forward simulation results at an assimilation step. The dictionary has keys equal to the data type (given in OBSNAME).
Returns
  • obs (ndarray): Augmented vector of observed data
  • pred (ndarray): Ensemble matrix of predicted data
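
A hypothetical example; data names, values and the assim_index layout are made up for illustration:

    import numpy as np
    from pipt.misc_tools import analysis_tools as at

    ne = 20
    obs_data = [{'WOPR': np.array([250.0, 300.0]), 'WBHP': None},
                {'WOPR': np.array([240.0, 310.0]), 'WBHP': np.array([150.0])}]
    pred_data = [{'WOPR': 250 + np.random.randn(2, ne), 'WBHP': None},
                 {'WOPR': 250 + np.random.randn(2, ne), 'WBHP': 150 + np.random.randn(1, ne)}]
    assim_index = ['PRIM', [0, 1]]
    list_data = ['WOPR', 'WBHP']

    obs, pred = at.aug_obs_pred_data(obs_data, pred_data, assim_index, list_data)
    print(obs.shape, pred.shape)             # (5,) (5, 20)
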
def calc_kalmangain(cov_cross, cov_auto, cov_data, opt=None):
 994def calc_kalmangain(cov_cross, cov_auto, cov_data, opt=None):
 995    r"""
 996    Calculate the Kalman gain
 997
 998    Parameters
 999    ----------
1000    cov_cross: ndarray
1001        Cross-covariance matrix between state and predicted data
1002    cov_auto: ndarray
1003        Auto-covariance matrix of predicted data
1004    cov_data: ndarray
1005        Variance on observed data (diagonal matrix)
1006    opt: str
1007        Which method should we use to calculate Kalman gain
1008        <ul>
1009            <li>'lu': LU decomposition (default)</li>
1010            <li>'chol': Cholesky decomposition</li>
1011        </ul>
1012
1013    Returns
1014    -------
1015    kalman_gain: ndarray
1016        Kalman gain
1017
1018    Notes
1019    -----
1020    In the following Kalman gain is $K$, cross-covariance is $C_{mg}$, predicted data auto-covariance is $C_{g}$,
1021    and data covariance is $C_{d}$.
1022
1023    With `'lu'` option, we solve the transposed linear system:
1024    $$
1025        K^T = (C_{g} + C_{d})^{-T}C_{mg}^T
1026    $$
1027
1028    With `'chol'` option we use Cholesky on auto-covariance matrix,
1029    $$
1030       L L^T = (C_{g} + C_{d})^T
1031    $$
1032    and solve linear system with the square-root matrix from Cholesky:
1033    $$
1034        L^T Y = C_{mg}^T\\
1035        LK = Y
1036    $$
1037    """
1038    # Default to LU decomposition unless another option is given
1039    calc_opt = 'lu' if opt is None else opt
1040
1041    # Add data and predicted data auto-covariance matrices
1042    if len(cov_data.shape) == 1:
1043        cov_data = np.diag(cov_data)
1044    c_auto = cov_auto + cov_data
1045
1046    if calc_opt == 'lu':
1047        kg = linalg.solve(c_auto.T, cov_cross.T)
1048        kalman_gain = kg.T
1049
1050    elif calc_opt == 'chol':
1051        # Cholesky decomp (upper triangular matrix)
1052        u = linalg.cho_factor(c_auto.T, check_finite=False)
1053
1054        # Solve linear system with cholesky square-root
1055        kalman_gain = linalg.cho_solve(u, cov_cross.T, check_finite=False)
1056
1057    # Return Kalman gain
1058    return kalman_gain

Calculate the Kalman gain

Parameters
  • cov_cross (ndarray): Cross-covariance matrix between state and predicted data
  • cov_auto (ndarray): Auto-covariance matrix of predicted data
  • cov_data (ndarray): Variance on observed data (diagonal matrix)
  • opt (str): Which method should we use to calculate Kalman gain

    • 'lu': LU decomposition (default)
    • 'chol': Cholesky decomposition
Returns
  • kalman_gain (ndarray): Kalman gain
Notes

In the following Kalman gain is $K$, cross-covariance is $C_{mg}$, predicted data auto-covariance is $C_{g}$, and data covariance is $C_{d}$.

With 'lu' option, we solve the transposed linear system: $$ K^T = (C_{g} + C_{d})^{-T}C_{mg}^T $$

With 'chol' option we use Cholesky on auto-covariance matrix, $$ L L^T = (C_{g} + C_{d})^T $$ and solve linear system with the square-root matrix from Cholesky: $$ L^T Y = C_{mg}^T, \quad L K = Y $$
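
A small self-check sketch with made-up ensembles, comparing the result against the explicit formula K = C_mg (C_g + C_d)^{-1}:

    import numpy as np
    from pipt.misc_tools import analysis_tools as at

    ne = 100
    state_pert = np.random.randn(50, ne)     # hypothetical parameter perturbations
    data_pert = np.random.randn(8, ne)       # hypothetical predicted-data perturbations
    state_pert -= state_pert.mean(axis=1, keepdims=True)
    data_pert -= data_pert.mean(axis=1, keepdims=True)

    cov_cross = at.calc_crosscov(state_pert, data_pert)   # C_mg
    cov_auto = at.calc_autocov(data_pert)                 # C_g
    cov_data = 0.25 * np.ones(8)                          # C_d as a variance vector

    kg = at.calc_kalmangain(cov_cross, cov_auto, cov_data)

    # Agrees with the explicit formula K = C_mg (C_g + C_d)^{-1}
    kg_direct = cov_cross @ np.linalg.inv(cov_auto + np.diag(cov_data))
    assert np.allclose(kg, kg_direct)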

def calc_subspace_kalmangain(cov_cross, data_pert, cov_data, energy):
1061def calc_subspace_kalmangain(cov_cross, data_pert, cov_data, energy):
1062    """
1063    Compute the Kalman gain in an efficient subspace determined by how much energy (i.e. percentage of singular values)
1064    to retain. For more info regarding the implementation, see Chapter 14 in [1].
1065
1066    Parameters
1067    cov_cross: ndarray
1068        Cross-covariance matrix between state and predicted data
1069    data_pert: ndarray
1070        Predicted data - mean of predicted data
1071    cov_data: ndarray
1072        Variance on observed data (diagonal matrix)
1073
1074    Returns
1075    -------
1076    k_g: ndarray
1077        Subspace Kalman gain
1078
1079    References
1080    ----------
1081    [1] G. Evensen (2009). Data Assimilation: The Ensemble Kalman Filter, Springer.
1082    """
1083    # No. ensemble members
1084    ne = data_pert.shape[1]
1085
1086    # Perform SVD on pred. data perturbations
1087    u_d, s_d, v_d = np.linalg.svd(np.sqrt(1 / (ne - 1)) * data_pert, full_matrices=False)
1088
1089    # If no. measurements is more than ne - 1, we only keep ne - 1 sing. val.
1090    if data_pert.shape[0] >= ne:
1091        u_d, s_d, v_d = u_d[:, :-1].copy(), s_d[:-1].copy(), v_d[:-1, :].copy()
1092
1093    # If energy is less than 100 we truncate the SVD matrices
1094    if energy < 100:
1095        ti = (np.cumsum(s_d) / sum(s_d)) * 100 <= energy
1096        u_d, s_d, v_d = u_d[:, ti].copy(), s_d[ti].copy(), v_d[ti, :].copy()
1097
1098    # Calculate x_0 and its eigenvalue decomp.
1099    if len(cov_data.shape) == 1:
1100        x_0 = np.dot(np.diag(s_d[:]**(-1)), np.dot(u_d[:, :].T, np.expand_dims(cov_data, axis=1)*np.dot(u_d[:, :],
1101                                                                                                        np.diag(s_d[:]**(-1)).T)))
1102    else:
1103        x_0 = np.dot(np.diag(s_d[:] ** (-1)), np.dot(u_d[:, :].T, np.dot(cov_data, np.dot(u_d[:, :],
1104                                                                                          np.diag(s_d[:] ** (-1)).T))))
1105    s, u = np.linalg.eig(x_0)
1106
1107    # Calculate x_1
1108    x_1 = np.dot(u_d[:, :], np.dot(np.diag(s_d[:]**(-1)).T, u))
1109
1110    # Calculate Kalman gain based on the subspace matrices we made above
1111    k_g = np.dot(cov_cross, np.dot(x_1, linalg.solve(
1112        (np.eye(s.shape[0]) + np.diag(s)), x_1.T)))
1113
1114    # Return subspace Kalman gain
1115    return k_g

Compute the Kalman gain in an efficient subspace determined by how much energy (i.e. percentage of singular values) to retain. For more info regarding the implementation, see Chapter 14 in [1].

Parameters
  • cov_cross (ndarray): Cross-covariance matrix between state and predicted data
  • data_pert (ndarray): Predicted data - mean of predicted data
  • cov_data (ndarray): Variance on observed data (diagonal matrix)
  • energy (float): Percentage of the energy (singular values) to retain in the subspace

Returns
  • k_g (ndarray): Subspace Kalman gain
References

[1] G. Evensen (2009). Data Assimilation: The Ensemble Kalman Filter, Springer.
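
A hypothetical usage sketch with a diagonal data covariance; names and sizes are made up:

    import numpy as np
    from pipt.misc_tools import analysis_tools as at

    ne = 50
    state_pert = np.random.randn(200, ne)
    data_pert = np.random.randn(30, ne)
    state_pert -= state_pert.mean(axis=1, keepdims=True)
    data_pert -= data_pert.mean(axis=1, keepdims=True)

    cov_cross = at.calc_crosscov(state_pert, data_pert)
    cov_data = np.ones(30)                                # diagonal C_d as a variance vector

    # Retain 98 % of the energy (cumulative singular values) of the data perturbations
    k_g = at.calc_subspace_kalmangain(cov_cross, data_pert, cov_data, energy=98)
    print(k_g.shape)   # (200, 30)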

def compute_x(pert_preddata, cov_data, keys_da, alfa=None):
1118def compute_x(pert_preddata, cov_data, keys_da, alfa=None):
1119    """
1120    Compute the matrix X = G^T (G G^T + C_d)^{-1} (optionally in a truncated subspace), where G is the matrix of perturbed predicted data.
1121
1122    Parameters
1123    ----------
1124    pert_preddata: ndarray
1125        Perturbed predicted data
1126    cov_data: ndarray
1127        Data covariance matrix
1128    keys_da: dict
1129        Dictionary with every input in `DATAASSIM`
1130    alfa: None, optional
1131        INSERT DESCRIPTION
1132
1133    Returns:
1134    X: ndarray
1135        INSERT DESCRIPTION
1136    """
1137    X = []
1138    if 'kalmangain' in keys_da and keys_da['kalmangain'][0] == 'subspace':
1139
1140        # TSVD energy
1141        energy = keys_da['kalmangain'][1]
1142
1143        # No. ensemble members
1144        ne = pert_preddata.shape[1]
1145
1146        # Calculate x_0 and its eigenvalue decomp.
1147        if len(cov_data.shape) == 1:
1148            scale = np.expand_dims(np.sqrt(cov_data), axis=1)
1149        else:
1150            scale = np.expand_dims(np.sqrt(np.diag(cov_data)), axis=1)
1151
1152        # Perform SVD on pred. data perturbations
1153        u_d, s_d, v_d = np.linalg.svd(pert_preddata/scale, full_matrices=False)
1154
1155        # If no. measurements is more than ne - 1, we only keep ne - 1 sing. val.
1156        if pert_preddata.shape[0] >= ne:
1157            u_d, s_d, v_d = u_d[:, :-1].copy(), s_d[:-1].copy(), v_d[:-1, :].copy()
1158
1159        # If energy is less than 100 we truncate the SVD matrices
1160        if energy < 100:
1161            ti = (np.cumsum(s_d) / sum(s_d)) * 100 <= energy
1162            u_d, s_d, v_d = u_d[:, ti].copy(), s_d[ti].copy(), v_d[ti, :].copy()
1163
1164        # Calculate x_0 and its eigenvalue decomp.
1165        if len(cov_data.shape) == 1:
1166            x_0 = np.dot(np.diag(s_d[:] ** (-1)),
1167                         np.dot(u_d[:, :].T, np.expand_dims(cov_data, axis=1) * np.dot(u_d[:, :],
1168                                                                                       np.diag(s_d[:] ** (-1)).T)))
1169        else:
1170            x_0 = np.dot(np.diag(s_d[:] ** (-1)), np.dot(u_d[:, :].T, np.dot(cov_data, np.dot(u_d[:, :],
1171                                                                                              np.diag(s_d[:] ** (-1)).T))))
1172        s, u = np.linalg.eig(x_0)
1173
1174        # Calculate x_1
1175        x_1 = np.dot(u_d[:, :], np.dot(np.diag(s_d[:] ** (-1)).T, u))/scale
1176
1177        # Calculate X based on the subspace matrices we made above
1178        X = np.dot(np.dot(pert_preddata.T, x_1), linalg.solve(
1179            (np.eye(s.shape[0]) + np.diag(s)), x_1.T))
1180
1181    else:
1182        if len(cov_data.shape) == 1:
1183            X = linalg.solve(np.dot(pert_preddata, pert_preddata.T) +
1184                             np.diag(cov_data), pert_preddata)
1185        else:
1186            X = linalg.solve(np.dot(pert_preddata, pert_preddata.T) +
1187                             cov_data, pert_preddata)
1188        X = X.T
1189
1190    return X

Compute the matrix X = G^T (G G^T + C_d)^{-1} (optionally in a truncated subspace), where G is the matrix of perturbed predicted data.

Parameters
  • pert_preddata (ndarray): Perturbed predicted data
  • cov_data (ndarray): Data covariance matrix
  • keys_da (dict): Dictionary with every input in DATAASSIM
  • alfa (None, optional): INSERT DESCRIPTION
Returns
  • X (ndarray): INSERT DESCRIPTION
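
A minimal sketch of the default (full-rank) branch; the ensemble is made up and keys_da is left empty:

    import numpy as np
    from pipt.misc_tools import analysis_tools as at

    ne = 100
    pert_preddata = np.random.randn(20, ne)
    pert_preddata -= pert_preddata.mean(axis=1, keepdims=True)
    cov_data = 0.5 * np.ones(20)

    # Without a subspace 'kalmangain' entry in keys_da, the full-rank branch is used,
    # i.e. X = G^T (G G^T + C_d)^{-1} with G = pert_preddata
    X = at.compute_x(pert_preddata, cov_data, keys_da={})
    print(X.shape)     # (100, 20), i.e. Ne x Nd
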
def aug_state(state, list_state, cell_index=None):
1193def aug_state(state, list_state, cell_index=None):
1194    """
1195    Augment the state variables to an array.
1196
1197    Parameters
1198    ----------
1199    state: dict
1200        Dictionary of initial ensemble of (joint) state variables (static parameters and dynamic variables) to be
1201        assimilated.
1202    list_state: list
1203        Fixed list of keys in state dict.
1204    cell_index: list of vector indexes to be extracted
1205
1206    Returns
1207    -------
1208    aug: ndarray
1209        Ensemble matrix of augmented state variables
1210    """
1211    # TODO: In some rare cases, it may not be desirable to update every state variable at each assimilation step.
1212    # Change code to only augment states to be updated at the specific assimilation step
1213    # TODO: Use something other than numpy vstack for this augmentation!
1214
1215    if cell_index is not None:
1216        # Start with ensemble of first state variable
1217        aug = state[list_state[0]][cell_index]
1218
1219        # Loop over the next states (if exists)
1220        for i in range(1, len(list_state)):
1221            aug = np.vstack((aug, state[list_state[i]][cell_index]))
1222
1223        # Return the augmented array
1224
1225    else:
1226        # Start with ensemble of first state variable
1227        aug = state[list_state[0]]
1228
1229        # Loop over the next states (if exists)
1230        for i in range(1, len(list_state)):
1231            aug = np.vstack((aug, state[list_state[i]]))
1232
1233        # Return the augmented array
1234    return aug

Augment the state variables to an array.

Parameters
  • state (dict): Dictionary of initial ensemble of (joint) state variables (static parameters and dynamic variables) to be assimilated.
  • list_state (list): Fixed list of keys in state dict.
  • cell_index (list, optional): Vector indexes to be extracted
Returns
  • aug (ndarray): Ensemble matrix of augmented state variables
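
A small round-trip sketch with hypothetical state names, combining aug_state with update_state (documented further below):

    import numpy as np
    from pipt.misc_tools import analysis_tools as at

    ne = 10
    state = {'PERMX': np.random.randn(500, ne),
             'PORO': np.random.randn(500, ne)}
    list_state = list(state.keys())

    aug = at.aug_state(state, list_state)          # 1000 x 10 stacked ensemble matrix

    # ... run the analysis on aug, then scatter the result back with update_state
    state = at.update_state(aug, state, list_state)
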
def calc_scaling(state, list_state, prior_info):
1237def calc_scaling(state, list_state, prior_info):
1238    """
1239    Form the scaling to be used in SVD-related algorithms. The scaling consists of the standard deviation for each `STATICVAR`.
1240    It is important that this is formed in the same manner as the augmented state vector is formed, hence with the same
1241    list of states.
1242
1243    Parameters
1244    ----------
1245    state: dict
1246        Dictionary containing the state
1247    list_state: list
1248        List of states for augmenting
1249    prior_info: dict
1250        Nested dictionary containing prior information
1251
1252    Returns
1253    -------
1254    scaling: numpy array
1255        scaling
1256    """
1257
1258    scaling = []
1259    for elem in list_state:
1260        # more than single value. This is for multiple layers. Assume all values are active
1261        if len(prior_info[elem]['variance']) > 1:
1262            scaling.append(np.concatenate(tuple(np.sqrt(prior_info[elem]['variance'][z]) *
1263                                                np.ones(
1264                                                    prior_info[elem]['ny']*prior_info[elem]['nx'])
1265                                                for z in range(prior_info[elem]['nz']))))
1266        else:
1267            scaling.append(tuple(np.sqrt(prior_info[elem]['variance']) *
1268                                 np.ones(state[elem].shape[0])))
1269
1270    return np.concatenate(scaling)

Form the scaling to be used in SVD-related algorithms. The scaling consists of the standard deviation for each STATICVAR. It is important that this is formed in the same manner as the augmented state vector is formed, hence with the same list of states.

Parameters
  • state (dict): Dictionary containing the state
  • list_state (list): List of states for augmenting
  • prior_info (dict): Nested dictionary containing prior information
Returns
  • scaling (numpy array): scaling
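
A hedged sketch; the prior_info keys used below follow the code above, with a single-element variance list so the scaling falls back to the size of the state itself:

    import numpy as np
    from pipt.misc_tools import analysis_tools as at

    ne = 10
    state = {'PERMX': np.random.randn(500, ne)}
    prior_info = {'PERMX': {'variance': [2.0]}}    # assumed prior_info layout

    scaling = at.calc_scaling(state, ['PERMX'], prior_info)
    print(scaling.shape, scaling[0])   # (500,) 1.414... (the prior standard deviation)
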
def update_state(aug_state, state, list_state, cell_index=None):
1273def update_state(aug_state, state, list_state, cell_index=None):
1274    """
1275    Extract the separate state variables from an augmented state array. It is assumed that the augmented state
1276    array is made in `aug_state`, hence this is the reverse method of `aug_state`.
1277
1278    Parameters
1279    ----------
1280    aug_state: ndarray
1281        Augmented array of UPDATED state variables
1282    state: dict
1283        Dict. of state variables NOT updated.
1284    list_state: list
1285        List of state keys that have been updated
1286    cell_index: list
1287        List of indexes that gives where the aug state should be placed
1288
1289    Returns
1290    -------
1291    state: dict
1292        Dict. of UPDATED state variables
1293    """
1294    if cell_index is None:
1295        # Loop over all entries in list_state and extract a matrix with same number of rows as the key in state
1296        # determines from aug and replace the values in state[key].
1297        # Init. a variable to keep track of which row in 'aug' we start from in each loop
1298        aug_row = 0
1299        for _, key in enumerate(list_state):
1300            # Find no. rows in state[key] to determine how many rows from aug to extract
1301            no_rows = state[key].shape[0]
1302
1303            # Extract the rows from aug and update 'state[key]'
1304            state[key] = aug_state[aug_row:aug_row + no_rows, :]
1305
1306            # Update tracking variable for row in 'aug'
1307            aug_row += no_rows
1308
1309    else:
1310        aug_row = 0
1311        for _, key in enumerate(list_state):
1312            # No. rows to extract equals the number of cell indexes
1313            no_rows = len(cell_index)
1314
1315            # Extract the rows from aug and update 'state[key]'
1316            state[key][cell_index, :] = aug_state[aug_row:aug_row + no_rows, :]
1317
1318            # Update tracking variable for row in 'aug'
1319            aug_row += no_rows
1320    return state

Extract the separate state variables from an augmented state array. It is assumed that the augmented state array was made by aug_state, hence this is the reverse operation of aug_state.

Parameters
  • aug_state (ndarray): Augmented array of UPDATED state variables
  • state (dict): Dict. of state variables NOT updated.
  • list_state (list): List of state keys that have been updated
  • cell_index (list): List of indexes that gives where the augmented state should be placed
Returns
  • state (dict): Dict. of UPDATED state variables
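A short round-trip sketch (state names and the stand-in "analysis" increment are hypothetical):

import numpy as np
from pipt.misc_tools.analysis_tools import aug_state, update_state

state = {'permx': np.random.randn(100, 5), 'poro': np.random.randn(100, 5)}
list_state = ['permx', 'poro']

# Augment, (pretend to) update, and write the result back into the state dict
aug = aug_state(state, list_state)
aug_upd = aug + 0.1                      # stand-in for the analysis step
state = update_state(aug_upd, state, list_state)

# With cell_index given, only the listed rows of each state variable are overwritten;
# aug_upd must then contain len(cell_index) rows per state variable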
def resample_state(aug_state, state, list_state, new_en_size):
1323def resample_state(aug_state, state, list_state, new_en_size):
1324    """
1325    Extract the separate state variables from an augmented state matrix. Calculate the mean and covariance, and resample
1326    this.
1327
1328    Parameters
1329    ----------
1330    aug_state: ndarray
1331        Augmented matrix of state variables
1332    state: dict
1333        Dict. of state variables
1334    list_state: list
1335        List of state variables
1336    new_en_size: int
1337        Size of the new ensemble
1338
1339    Returns
1340    -------
1341    state: dict
1342        Dict. of resampled members
1343    """
1344
1345    aug_row = 0
1346    curr_ne = state[list_state[0]].shape[1]
1347    new_state = {}
1348    for elem in list_state:
1349        # determine how many rows to extract
1350        no_rows = state[elem].shape[0]
1351        new_state[elem] = np.empty((no_rows, new_en_size))
1352
1353        mean_state = np.mean(aug_state[aug_row:aug_row + no_rows, :], 1)
1354        pert_state = np.sqrt(1/(curr_ne - 1)) * (aug_state[aug_row:aug_row + no_rows, :] - np.dot(np.resize(mean_state,
1355                                                                                                            (len(mean_state), 1)), np.ones((1, curr_ne))))
1356        for i in range(new_en_size):
1357            new_state[elem][:, i] = mean_state + \
1358                np.dot(pert_state, np.random.normal(0, 1, pert_state.shape[1]))
1359
1360        aug_row += no_rows
1361
1362    return new_state

Extract the separate state variables from an augmented state matrix. Calculate the mean and covariance, and resample this.

Parameters
  • aug_state (ndarray): Augmented matrix of state variables
  • state (dict): Dict. of state variables
  • list_state (list): List of state variables
  • new_en_size (int): Size of the new ensemble
Returns
  • state (dict): Dict. of resampled members
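A minimal sketch of drawing a larger ensemble from the mean and square-root factor of the current one (state name and sizes are hypothetical):

import numpy as np
from pipt.misc_tools.analysis_tools import aug_state, resample_state

state = {'permx': np.random.randn(100, 10)}          # current ensemble of 10 members
list_state = ['permx']

aug = aug_state(state, list_state)
new_state = resample_state(aug, state, list_state, new_en_size=50)
print(new_state['permx'].shape)                      # (100, 50)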
def block_diag_cov(cov, list_state):
1365def block_diag_cov(cov, list_state):
1366    """
1367    Block diagonalize a covariance matrix dictionary.
1368
1369    Parameters
1370    ----------
1371    cov: dict
1372        Dict. with cov. matrices
1373    list_state: list
1374        Fixed list of keys in state dict.
1375
1376    Returns
1377    -------
1378    cov_out: ndarray
1379        Block diag. matrix with prior covariance matrices for each state.
1380    """
1381    # TODO: Change if there are cross-correlation between different states
1382
1383    # Init. block in matrix
1384    cov_out = cov[list_state[0]]
1385
1386    # Test if scalar has been given in init. block
1387    if not hasattr(cov_out, '__len__'):
1388        cov_out = np.array([[cov_out]])
1389
1390    # Loop of rest of the state-names and add in block diag. matrix
1391    for i in range(1, len(list_state)):
1392        cov_out = linalg.block_diag(cov_out, cov[list_state[i]])
1393
1394    # Return
1395    return cov_out

Block diagonalize a covariance matrix dictionary.

Parameters
  • cov (dict): Dict. with cov. matrices
  • list_state (list): Fixed list of keys in state dict.
Returns
  • cov_out (ndarray): Block diag. matrix with prior covariance matrices for each state.
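A small sketch (state names are hypothetical); note that scipy.linalg.block_diag treats a scalar entry as a 1 x 1 block:

import numpy as np
from pipt.misc_tools.analysis_tools import block_diag_cov

cov = {'permx': 2.0 * np.eye(3), 'poro': 0.01}       # a 3 x 3 block and a scalar variance
cov_out = block_diag_cov(cov, ['permx', 'poro'])
print(cov_out.shape)                                  # (4, 4): the scalar becomes a 1 x 1 block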
def calc_kalman_filter_eq(aug_state, kalman_gain, obs_data, pred_data):
1398def calc_kalman_filter_eq(aug_state, kalman_gain, obs_data, pred_data):
1399    """
1400    Calculate the updated augmented state using the Kalman filter equations
1401
1402    Parameters
1403    ----------
1404    aug_state: ndarray
1405        Augmented state variable (all the parameters defined in `STATICVAR` augmented in one array)
1406    kalman_gain: ndarray
1407        Kalman gain
1408    obs_data: ndarray
1409        Augmented observed data vector (all `OBSNAME` augmented in one array)
1410    pred_data: ndarray
1411        Augmented predicted data vector (all `OBSNAME` augmented in one array)
1412
1413    Returns
1414    -------
1415    aug_state_upd: ndarray
1416        Updated augmented state variable using the Kalman filter equations
1417    """
1418    # TODO: Implement svd updating algorithm
1419
1420    # Matrix version
1421    # aug_state_upd = aug_state + np.dot(kalman_gain, (obs_data - pred_data))
1422
1423    # For-loop version
1424    aug_state_upd = np.zeros(aug_state.shape)  # Init. updated state
1425
1426    for i in range(aug_state.shape[1]):  # Loop over ensemble members
1427        aug_state_upd[:, i] = aug_state[:, i] + \
1428            np.dot(kalman_gain, (obs_data[:, i] - pred_data[:, i]))
1429
1430    # Return the updated state
1431    return aug_state_upd

Calculate the updated augmented state using the Kalman filter equations

Parameters
  • aug_state (ndarray): Augmented state variable (all the parameters defined in STATICVAR augmented in one array)
  • kalman_gain (ndarray): Kalman gain
  • obs_data (ndarray): Augmented observed data vector (all OBSNAME augmented in one array)
  • pred_data (ndarray): Augmented predicted data vector (all OBSNAME augmented in one array)
Returns
  • aug_state_upd (ndarray): Updated augmented state variable using the Kalman filter equations
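A hedged sketch with random placeholder arrays; the member-wise loop in the source is equivalent to the commented-out matrix form:

import numpy as np
from pipt.misc_tools.analysis_tools import calc_kalman_filter_eq

nx, nd, ne = 100, 8, 20                     # state size, no. data, no. members (hypothetical)
aug_state = np.random.randn(nx, ne)
kalman_gain = np.random.randn(nx, nd)
obs_data = np.random.randn(nd, ne)          # perturbed observations, one column per member
pred_data = np.random.randn(nd, ne)

aug_state_upd = calc_kalman_filter_eq(aug_state, kalman_gain, obs_data, pred_data)
# Member-wise form of: aug_state + kalman_gain @ (obs_data - pred_data)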
def limits(state, prior_info):
1434def limits(state, prior_info):
1435    """
1436    Check if any state variables overshoot the limits given by the prior info. If so, modify these values
1437
1438    Parameters
1439    ----------
1440    state: dict
1441        Dictionary containing the states
1442    prior_info: dict
1443        Dictionary containing prior information for all the states.
1444
1445    Returns
1446    -------
1447    state: dict
1448        Valid state
1449    """
1450    for var in state.keys():
1451        if 'limits' in prior_info[var]:
1452            state[var][state[var] < prior_info[var]['limits']
1453                       [0][0]] = prior_info[var]['limits'][0][0]
1454            state[var][state[var] > prior_info[var]['limits']
1455                       [0][1]] = prior_info[var]['limits'][0][1]
1456    return state

Check if any state variables overshoot the limits given by the prior info. If so, modify these values.

Parameters
  • state (dict): Dictionary containing the states
  • prior_info (dict): Dictionary containing prior information for all the states.
Returns
  • state (dict): Valid state
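A small clipping sketch; the nested [[lower, upper]] layout of 'limits' is assumed from the indexing in the source, and the state name is hypothetical:

import numpy as np
from pipt.misc_tools.analysis_tools import limits

prior_info = {'poro': {'limits': [[0.05, 0.35]]}}      # [lower, upper] bounds (assumed layout)
state = {'poro': np.array([[0.01], [0.20], [0.50]])}   # 3 values, 1 ensemble member

state = limits(state, prior_info)
# 0.01 is raised to 0.05 and 0.50 is lowered to 0.35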
def subsample_state(index, aug_state, pert_state):
1459def subsample_state(index, aug_state, pert_state):
1460    """
1461    Draw a subsample from the original state, given by the index
1462
1463    Parameters
1464    ----------
1465    index: ndarray
1466        Index of parameters to draw.
1467    aug_state: ndarray
1468        Original augmented state.
1469    pert_state: ndarray
1470        Perturbed augmented state, for error covariance.
1471
1472    Returns
1473    -------
1474    new_state: dict
1475        Subsample of state.
1476    """
1477
1478    new_state = np.empty((aug_state.shape[0], len(index)))
1479    for i in range(len(index)):
1480        new_state[:, i] = aug_state[:, index[i]] + \
1481            np.dot(pert_state, np.random.normal(0, 1, pert_state.shape[1]))
1482        # select some elements
1483
1484    return new_state

Draw a subsample from the original state, given by the index

Parameters
  • index (ndarray): Index of parameters to draw.
  • aug_state (ndarray): Original augmented state.
  • pert_state (ndarray): Perturbed augmented state, for error covariance.
Returns
  • new_state (dict): Subsample of state.
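A hedged sketch; the construction of pert_state as a scaled anomaly matrix below mirrors the convention used in resample_state and is an assumption, not part of this function:

import numpy as np
from pipt.misc_tools.analysis_tools import subsample_state

aug = np.random.randn(100, 20)                                 # full augmented ensemble
pert = (aug - aug.mean(1, keepdims=True)) / np.sqrt(20 - 1)    # square-root factor of the ensemble covariance
index = np.array([0, 5, 7])                                    # members to perturb around

new = subsample_state(index, aug, pert)
print(new.shape)                                               # (100, 3)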
def init_local_analysis(init, state):
1487def init_local_analysis(init, state):
1488    """Initialize local analysis.
1489
1490    Initialize the local analysis by reading the input variables, defining the parameter classes and search ranges. Build
1491    the map of data/parameter positions.
1492
1493    Args:
1494        init: dictionary containing the parsed information from the input file.
1495        state: list of states that will be updated
1496    Returns:
1497        local: dictionary of initialized values.
1498    """
1499
1500    local = {}
1501    local['cell_parameter'] = []
1502    local['region_parameter'] = []
1503    local['vector_region_parameter'] = []
1504    local['unique'] = True
1505
1506    for i, opt in enumerate(list(zip(*init))[0]):
1507        if opt.lower() == 'region_parameter':  # define scalar parameters valid in a region
1508            local['region_parameter'] = [
1509                elem for elem in init[i][1].split(' ') if elem in state]
1510        if opt.lower() == 'vector_region_parameter':  # Sometimes it is useful to define the same parameter
1511                                                       # for multiple regions as a vector.
1512            local['vector_region_parameter'] = [
1513                elem for elem in init[i][1].split(' ') if elem in state]
1514        if opt.lower() == 'cell_parameter':  # define cell specific vector parameters
1515            local['cell_parameter'] = [
1516                elem for elem in init[i][1].split(' ') if elem in state]
1517        if opt.lower() == 'search_range':
1518            local['search_range'] = int(init[i][1])
1519        if opt.lower() == 'column_update':
1520            local['column_update'] = [elem for elem in init[i][1].split(',')]
1521        if opt.lower() == 'parameter_position_file':  # assume pickled format
1522            with open(init[i][1], 'rb') as file:
1523                local['parameter_position'] = pickle.load(file)
1524        if opt.lower() == 'data_position_file':  # assume pickled format
1525            with open(init[i][1], 'rb') as file:
1526                local['data_position'] = pickle.load(file)
1527        if opt.lower() == 'update_mask_file':
1528            with open(init[i][1], 'rb') as file:
1529                local['update_mask'] = pickle.load(file)
1530
1531    if 'update_mask' in local:
1532        return local
1533    else:
1534        assert 'parameter_position' in local, 'A pickle file containing the binary map of the parameters is MANDATORY'
1535        assert 'data_position' in local, 'A pickle file containing the position of the data is MANDATORY'
1536
1537        data_name = [elem for elem in local['data_position'].keys()]
1538        if type(local['data_position'][data_name[0]][0]) == list:  # assim index has specific position
1539            local['unique'] = False
1540            data_pos = [elem for data in data_name for assim_elem in local['data_position'][data]
1541                        for elem in assim_elem]
1542            data_ind = [f'{data}_{assim_indx}' for data in data_name for assim_indx, assim_elem in enumerate(local['data_position'][data])
1543                        for _ in assim_elem]
1544        else:
1545            data_pos = [elem for data in data_name for elem in local['data_position'][data]]
1546            # store the name for easy index
1547            data_ind = [data for data in data_name for _ in local['data_position'][data]]
1548        kde_search = cKDTree(data=data_pos)
1549
1550        local['update_mask'] = {}
1551        for param in local['cell_parameter']:  # find data in a distance from the parameter
1552            field_size = local['parameter_position'][param].shape
1553            local['update_mask'][param] = [[[[] for _ in range(field_size[2])] for _ in range(field_size[1])] for _
1554                                           in range(field_size[0])]
1555            for k in range(field_size[0]):
1556                for j in range(field_size[1]):
1557                    new_iter = [elem for elem, val in enumerate(
1558                        local['parameter_position'][param][k, j, :]) if val]
1559                    if len(new_iter):
1560                        for i in new_iter:
1561                            local['update_mask'][param][k][j][i] = set(
1562                                [data_ind[elem] for elem in kde_search.query_ball_point(x=(k, j, i),
1563                                                                                        r=local['search_range'], workers=-1)])
1564
1565        # see if data is inside the region. Note parameter_position is boolean map
1566        for param in local['region_parameter']:
1567            in_region = [local['parameter_position'][param][elem] for elem in data_pos]
1568            local['update_mask'][param] = set(
1569                [data_ind[count] for count, val in enumerate(in_region) if val])
1570
1571        return local

Initialize local analysis.

Initialize the local analysis by reading the input variables, defining the parameter classes and search ranges. Build the map of data/parameter positions.

Parameters
  • init: dictionary containing the parsed information from the input file.
  • state: list of states that will be updated.
Returns
  • local: dictionary of initialized values.
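A hedged sketch of how this might be called. Although the docstring calls init a dictionary, the source iterates over it as a sequence of (keyword, value) pairs, so the example below builds it that way; the keyword values, state names, grid size, and the contents of the two pickled position files are all hypothetical:

import pickle
import numpy as np
from pipt.misc_tools.analysis_tools import init_local_analysis

# Write minimal, hypothetical position files so the sketch is self-contained:
# parameter positions are boolean (k, j, i) grids, data positions are lists of (k, j, i) tuples
with open('param_pos.p', 'wb') as f:
    pickle.dump({'permx': np.ones((2, 3, 4), dtype=bool),
                 'poro': np.ones((2, 3, 4), dtype=bool),
                 'multflt': np.zeros((2, 3, 4), dtype=bool)}, f)
with open('data_pos.p', 'wb') as f:
    pickle.dump({'WOPR PRO-1': [(0, 1, 1)], 'WWCT PRO-1': [(1, 2, 3)]}, f)

# Parsed input as a list of (keyword, value) pairs
init = [
    ('cell_parameter', 'permx poro'),
    ('region_parameter', 'multflt'),
    ('search_range', '10'),
    ('parameter_position_file', 'param_pos.p'),
    ('data_position_file', 'data_pos.p'),
]
state = ['permx', 'poro', 'multflt']

local = init_local_analysis(init, state)
# local['update_mask'][param] maps each active cell (or region) to the names of the data
# found within 'search_range' grid cells by the cKDTree ball search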