# Source code for lair.inversion.core

"""
Core inversion classes and functions.

This module provides core classes and utilities for formulating and solving inverse problems, including:
    - Abstract base classes for inversion estimators.
    - Registry for estimator implementations.
    - Forward operator and covariance matrix wrappers with index-aware functionality.
    - Utilities for convolving state vectors with forward operators.
    - The `InverseProblem` class, which orchestrates the alignment of data, prior information, error covariances, and the solution process.
"""

from abc import ABC, abstractmethod
from functools import cached_property, partial
from pathlib import Path
from typing_extensions import \
    Self  # requires python 3.11 to import from typing

import numpy as np
from numpy.linalg import inv as invert
import pandas as pd
import xarray as xr


from lair.inversion.utils import dataframe_matrix_to_xarray, round_index


# TODO
# - Obs aggregation
# - file io


[docs] class Estimator(ABC): """ Base inversion estimator class. Attributes ---------- z : np.ndarray Observed data. x_0 : np.ndarray Prior model state estimate. H : np.ndarray Forward operator. S_0 : np.ndarray Prior error covariance. S_z : np.ndarray Model-data mismatch covariance. c : np.ndarray or float, optional Constant data, defaults to 0.0. n_z : int Number of observations. n_x : int Number of state variables. x_hat : np.ndarray Posterior mean model state estimate (solution). S_hat : np.ndarray Posterior error covariance. y_hat : np.ndarray Posterior modeled observations. y_0 : np.ndarray Prior modeled observations. K : np.ndarray Kalman gain. A : np.ndarray Averaging kernel. chi2 : float Chi-squared statistic. R2 : float Coefficient of determination. RMSE : float Root mean square error. U_red : np.ndarray Reduced uncertainty. Methods ------- cost(x: np.ndarray) -> float Cost/loss/misfit function. forward(x: np.ndarray) -> np.ndarray Forward model calculation. residual(x: np.ndarray) -> np.ndarray Forward model residual. leverage(x: np.ndarray) -> np.ndarray Calculate the leverage matrix. """
[docs] def __init__(self, z: np.ndarray, x_0: np.ndarray, H: np.ndarray, S_0: np.ndarray, S_z: np.ndarray, c: np.ndarray | float | None = None, ): """ Initialize the Estimator object. Parameters ---------- z : np.ndarray Observed data. x_0 : np.ndarray Prior model state estimate. H : np.ndarray Forward operator. S_0 : np.ndarray Prior error covariance. S_z : np.ndarray Model-data mismatch covariance. c : np.ndarray or float, optional Constant data, defaults to 0.0. """ self.z = z self.x_0 = x_0 self.H = H self.S_0 = S_0 self.S_z = S_z self.c = c if c is not None else 0.0 self.n_z = z.shape[0] self.n_x = x_0.shape[0]
[docs] def forward(self, x) -> np.ndarray: """ Forward model calculation. .. math:: y = Hx + c Parameters ---------- x : np.ndarray State vector. Returns ------- np.ndarray Model output (Hx + c). """ print('Performing forward calculation...') return self.H @ x + self.c
[docs] def residual(self, x) -> np.ndarray: """ Forward model residual. .. math:: r = z - (Hx + c) Parameters ---------- x : np.ndarray State vector. Returns ------- np.ndarray Residual (z - (Hx + c)). """ print('Performing residual calculation...') return self.z - self.forward(x)
[docs] def leverage(self, x) -> np.ndarray: """ Calculate the leverage matrix. Which observations are likely to have more impact on the solution. .. math:: L = Hx ((Hx)^T (H S_0 H^T + S_z)^{-1} Hx)^{-1} (Hx)^T (H S_0 H^T + S_z)^{-1} Parameters ---------- x : np.ndarray State vector. Returns ------- np.ndarray Leverage matrix. """ print('Calculating Leverage matrix...') Hx = self.forward(x) Hx_T = Hx.T HS_0H_Sz_inv = invert(self._HS_0H + self.S_z) return Hx @ invert(Hx_T @ HS_0H_Sz_inv @ Hx) @ Hx_T @ HS_0H_Sz_inv
[docs] @abstractmethod def cost(self, x) -> float: """ Cost/loss/misfit function. Parameters ---------- x : np.ndarray State vector. Returns ------- float Cost value. """ print('Performing cost calculation...') raise NotImplementedError
@property @abstractmethod def x_hat(self) -> np.ndarray: """ Posterior mean model state estimate (solution). Returns ------- np.ndarray Posterior state estimate. """ print('Calculating Posterior Mean Model State Estimate...') raise NotImplementedError @property @abstractmethod def S_hat(self) -> np.ndarray: """ Posterior error covariance matrix. Returns ------- np.ndarray Posterior error covariance matrix. """ print('Calculating Posterior Error Covariance Matrix...') raise NotImplementedError @cached_property def y_hat(self) -> np.ndarray: """ Posterior mean observation estimate. .. math:: \\hat{y} = H \\hat{x} + c Returns ------- np.ndarray Posterior observation estimate. """ print('Calculating Posterior Mean Observation Estimate...') return self.forward(self.x_hat) @cached_property def y_0(self) -> np.ndarray: """ Prior mean data estimate. .. math:: \\hat{y}_0 = H x_0 + c Returns ------- np.ndarray Prior data estimate. """ print('Calculating Prior Mean Data Estimate...') return self.forward(self.x_0) @cached_property def K(self): """ Kalman gain matrix. .. math:: K = (H S_0)^T (H S_0 H^T + S_z)^{-1} Returns ------- np.ndarray Kalman gain matrix. """ print('Calculating Kalman Gain Matrix...') return self._HS_0.T @ invert(self._HS_0H + self.S_z) @cached_property def A(self): """ Averaging kernel matrix. .. math:: A = KH = (H S_0)^T (H S_0 H^T + S_z)^{-1} H Returns ------- np.ndarray Averaging kernel matrix. """ print('Calculating Averaging Kernel Matrix...') return self.K @ self.H @cached_property def _H_T(self): """ Transpose of the forward operator """ return self.H.T @cached_property def _HS_0(self): """ ... math:: H S_0 """ return self.H @ self.S_0 @cached_property def _HS_0H(self): """ ... 
math:: H S_0 H^T """ return self._HS_0 @ self._H_T @cached_property def _S_0_inv(self): """ Inverse of prior error covariance matrix """ return invert(self.S_0) @cached_property def _S_z_inv(self): """ Inverse of model-data mismatch covariance matrix """ return invert(self.S_z) @cached_property def DOFS(self) -> float: """ Degrees Of Freedom for Signal (DOFS). .. math:: DOFS = Tr(A) Returns ------- float Degrees of Freedom value. """ return np.trace(self.A) @cached_property def chi2(self) -> float: """ Reduced Chi-squared statistic. .. math:: \\chi^2 = \\frac{1}{n_z} ((z - H\\hat{x})^T S_z^{-1} (z - H\\hat{x}) + (\\hat{x} - x_0)^T S_0^{-1} (\\hat{x} - x_0)) Returns ------- float Reduced Chi-squared value. """ # TBH im not 100% sure this is right return (self.chi2_obs + self.chi2_state) / self.n_z @cached_property def chi2_obs(self) -> float: """ Chi-squared statistic for observation params .. math:: \\chi^2 = (z - H\\hat{x})^T S_z^{-1} (z - H\\hat{x}) Returns ------- float Chi-squared value. """ r = self.residual(self.x_hat) return (r.T @ self._S_z_inv @ r) / self.n_z @cached_property def chi2_state(self) -> float: """ Chi-squared statistic for state params .. math:: \\chi^2 = (\\hat{x} - x_0)^T S_0^{-1} (\\hat{x} - x_0) """ r = self.x_hat - self.x_0 return (r.T @ self._S_0_inv @ r) / self.n_x @cached_property def R2(self) -> float: """ Coefficient of determination (R-squared). .. math:: R^2 = corr(z, H\\hat{x})^2 Returns ------- float R-squared value. """ print('Calculating Coefficient of determination (R-squared)...') return np.corrcoef(self.z, self.y_hat)[0, 1] ** 2 @cached_property def RMSE(self) -> float: """ Root mean square error (RMSE). .. math:: RMSE = \\sqrt{\\frac{(z - H\\hat{x})^2}{n_z}} Returns ------- float RMSE value. """ print('Calculating Root Mean Square Error (RMSE)...') r = self.residual(self.x_hat) return np.sqrt((r ** 2) / self.n_z) @cached_property def U_red(self): """ Uncertainty reduction metric. .. 
math:: U_{red} = 1 - \\frac{\\sqrt{trace(\\hat{S})}}{\\sqrt{trace(S_0)}} Returns ------- float Uncertainty reduction value. """ print('Calculating Uncertainty reduction metric...') return 1 - (np.sqrt(np.trace(self.S_hat)) / np.sqrt(np.trace(self.S_0)))
class EstimatorRegistry(dict):
    """
    Registry for estimator classes.

    A plain ``dict`` mapping a name to an estimator class, extended with a
    :meth:`register` decorator factory for populating it.
    """

    def register(self, name: str):
        """
        Register an estimator class under a given name.

        Parameters
        ----------
        name : str
            Name to register the estimator under.

        Returns
        -------
        decorator : function
            Class decorator that stores the decorated class under ``name``
            and returns the class unchanged.
        """
        def _store(cls: type[Estimator]) -> type[Estimator]:
            # An existing entry under the same name is overwritten.
            self[name] = cls
            return cls

        return _store
# Module-level registry instance; InverseProblem._init_estimator looks up
# estimator classes here when an estimator is given by name.
ESTIMATOR_REGISTRY = EstimatorRegistry()
def convolve(forward_operator: pd.DataFrame, state: pd.Series,
             coord_decimals: int = 6) -> pd.Series:
    """
    Convolve a forward_operator with a state field to get modeled observations.

    Both inputs are copied, so the caller's objects are left untouched.

    Parameters
    ----------
    forward_operator : pd.DataFrame
        DataFrame with columns corresponding to the state index and rows
        corresponding to the observation index.
    state : pd.Series
        Series with rows corresponding to the state index.
    coord_decimals : int, optional
        Number of decimal places to round coordinates to when matching
        indices, by default 6.

    Returns
    -------
    pd.Series
        Series with the same index as the forward_operator, containing the
        modeled observations. Its name is ``'<state.name>_obs'``.

    Raises
    ------
    ValueError
        If the forward operator columns are a MultiIndex but the state index
        is not, or if either input contains NaN after being reindexed onto
        the union of the two state indices.
    """
    operator = forward_operator.copy()
    vector = state.copy()

    # Round floating point coordinates to avoid precision issues
    operator.columns = round_index(operator.columns, decimals=coord_decimals)
    vector.index = round_index(vector.index, decimals=coord_decimals)

    # Put the state levels in the same order as the operator's column levels
    if isinstance(operator.columns, pd.MultiIndex):
        if not isinstance(vector.index, pd.MultiIndex):
            raise ValueError("If forward operator columns are a MultiIndex, state index must also be a MultiIndex.")
        vector.index = vector.index.reorder_levels(operator.columns.names)

    # Reindex both onto the union of labels; a label present in only one of
    # them shows up as NaN in the other and is rejected below.
    shared = operator.columns.union(vector.index)
    operator = operator.reindex(columns=shared)
    vector = vector.reindex(index=shared)

    if np.isnan(operator).any().any():
        raise ValueError("Forward operator contains NaN values after reindexing.")
    if np.isnan(vector).any():
        raise ValueError("state contains NaN values after reindexing.")

    # Matrix-vector product gives the modeled observations
    result = operator @ vector
    result.name = f'{state.name}_obs'
    return result
class ForwardOperator:
    """
    Forward operator class for modeling observations.

    Thin wrapper around a DataFrame whose rows are observations and whose
    columns are state elements.

    Parameters
    ----------
    data : pd.DataFrame
        Forward operator matrix.

    Attributes
    ----------
    data : pd.DataFrame
        Underlying forward operator matrix.
    obs_index : pd.Index
        Observation index (row index).
    state_index : pd.Index
        State index (column index).
    obs_dims : tuple
        Observation dimension names.
    state_dims : tuple
        State dimension names.

    Methods
    -------
    convolve(state: pd.Series, coord_decimals: int = 6) -> pd.Series
        Convolve the forward operator with a state vector.
    to_xarray() -> xr.DataArray
        Convert the forward operator to an xarray DataArray.
    """

    def __init__(self, data: pd.DataFrame):
        """
        Initialize the ForwardOperator.

        Parameters
        ----------
        data : pd.DataFrame
            Forward operator matrix. Stored by reference (not copied).

        Raises
        ------
        TypeError
            If ``data`` is not a pandas DataFrame.
        """
        is_frame = isinstance(data, pd.DataFrame)
        if not is_frame:
            raise TypeError("Input data must be a pandas DataFrame.")
        self._data = data

    @property
    def data(self) -> pd.DataFrame:
        """
        Get the underlying data of the forward operator.

        Returns
        -------
        pd.DataFrame
            Forward operator matrix.
        """
        return self._data

    @property
    def obs_index(self) -> pd.Index:
        """
        Get the observation index (row index) of the forward operator.

        Returns
        -------
        pd.Index
            Observation index.
        """
        frame = self._data
        return frame.index

    @property
    def state_index(self) -> pd.Index:
        """
        Get the state index (column index) of the forward operator.

        Returns
        -------
        pd.Index
            State index.
        """
        frame = self._data
        return frame.columns

    @property
    def obs_dims(self) -> tuple:
        """
        Get the observation dimensions (names of the row index).

        Returns
        -------
        tuple
            Observation dimension names.
        """
        names = self.obs_index.names
        return tuple(names)

    @property
    def state_dims(self) -> tuple:
        """
        Get the state dimensions (names of the column index).

        Returns
        -------
        tuple
            State dimension names.
        """
        names = self.state_index.names
        return tuple(names)

    def convolve(self, state: pd.Series, coord_decimals: int = 6
                 ) -> pd.Series:
        """
        Convolve the forward operator with a state vector.

        Delegates to the module-level ``convolve`` function.

        Parameters
        ----------
        state : pd.Series
            State vector.
        coord_decimals : int, optional
            Number of decimal places to round coordinates to when matching
            indices, by default 6.

        Returns
        -------
        pd.Series
            Result of convolution.
        """
        options = {
            'forward_operator': self._data,
            'state': state,
            'coord_decimals': coord_decimals,
        }
        return convolve(**options)

    def to_xarray(self) -> xr.DataArray:
        """
        Convert the forward operator to an xarray DataArray.

        Returns
        -------
        xr.DataArray
            Xarray representation of the forward operator.
        """
        frame = self._data
        return dataframe_matrix_to_xarray(frame)
[docs] class SymmetricMatrix: """ Symmetric matrix wrapper class for pandas DataFrames. Parameters ---------- data : pd.DataFrame Symmetric matrix with identical row and column indices. Attributes ---------- data : pd.DataFrame Symmetric matrix. index : pd.Index Index of the symmetric matrix. dims : tuple Dimension names of the symmetric matrix. values : np.ndarray Underlying data as a NumPy array. shape : tuple Dimensionality of the symmetric matrix. loc : SymmetricMatrix._Indexer Custom accessor for label-based selection and assignment. Methods ------- from_numpy(array: np.ndarray, index: pd.Index) -> SymmetricMatrix Create a SymmetricMatrix from a NumPy array and an index. reindex(index: pd.Index, **kwargs) -> SymmetricMatrix Reindex the symmetric matrix, filling new entries with 0. reorder_levels(order) -> SymmetricMatrix Reorder the levels of a MultiIndex symmetric matrix. """
[docs] def __init__(self, data: pd.DataFrame): """ Initialize the SymmetricMatrix with a square DataFrame. Parameters ---------- data : pd.DataFrame Square symmetric matrix. """ if not isinstance(data, pd.DataFrame): raise TypeError("Input data must be a pandas DataFrame.") if not data.index.equals(data.columns): raise ValueError("Symmetric matrix must have identical row and column indices.") self._data = data self.loc = self.__class__._Indexer(self)
[docs] @classmethod def from_numpy(cls, array: np.ndarray, index: pd.Index) -> Self: """ Create a SymmetricMatrix from a NumPy array. Parameters ---------- array : np.ndarray Symmetric matrix array. index : pd.Index Index for rows and columns. Returns ------- SymmetricMatrix SymmetricMatrix instance. """ return cls(pd.DataFrame(array, index=index, columns=index))
@property def data(self) -> pd.DataFrame: """ Returns the underlying data as a pandas DataFrame. Returns ------- pd.DataFrame Underlying symmetric matrix. """ return self._data @property def dims(self) -> tuple: """ Returns a tuple representing the dimension names of the matrix. Returns ------- tuple Dimension names of the symmetric matrix. """ return tuple(self.index.names) @property def index(self) -> pd.Index: """ Returns the pandas Index of the matrix. Returns ------- pd.Index Index of the symmetric matrix. """ return self.data.index @index.setter def index(self, index: pd.Index) -> None: """ Sets a new index for the symmetric matrix, ensuring it remains square. Parameters ---------- index : pd.Index New index for the symmetric matrix. Raises ------ TypeError If the index is not a pandas Index. ValueError If the index length does not match the number of rows/columns in the matrix. """ if not isinstance(index, pd.Index): raise TypeError("Index must be a pandas Index.") if len(index) != self.data.shape[0]: raise ValueError("Index length must match the number of rows/columns in the matrix.") self._data.index = index self._data.columns = index @property def values(self) -> np.ndarray: """ Returns the underlying data as a NumPy array. Returns ------- np.ndarray Underlying data array. """ return self.data.values @property def shape(self) -> tuple: """ Returns a tuple representing the dimensionality of the matrix. Returns ------- tuple Dimensionality of the symmetric matrix. """ return self.data.shape
[docs] def reindex(self, index: pd.Index, **kwargs) -> 'SymmetricMatrix': """ Reindex the symmetric matrix, filling new entries with 0. Parameters ---------- index : pd.Index New index for the symmetric matrix. **kwargs : additional keyword arguments Passed to pandas' reindex method. Returns ------- SymmetricMatrix Reindexed SymmetricMatrix instance. """ reindexed_data = self.data.reindex(index=index, columns=index, **kwargs).fillna(0.0) return self.__class__(data=reindexed_data)
[docs] def reorder_levels(self, order) -> 'SymmetricMatrix': """ Reorder the levels of a MultiIndex symmetric matrix. Parameters ---------- order : list New order for the levels. Returns ------- SymmetricMatrix SymmetricMatrix instance with reordered levels. Raises ------ TypeError If the index is not a MultiIndex. """ if not isinstance(self.index, pd.MultiIndex): raise TypeError("Index must be a MultiIndex to reorder levels.") data = self.data.copy() data = data.reorder_levels(order, axis='index') data = data.reorder_levels(order, axis='columns') return self.__class__(data=data)
class _Indexer: """ A custom accessor object for the SymmetricMatrix class, similar to pandas' .loc. It enables label-based selection and assignment while enforcing the symmetrical nature of a symmetric matrix. """ def __init__(self, matrix_obj: 'SymmetricMatrix'): self._obj = matrix_obj def __getitem__(self, key): """ Get data from the symmetric matrix. Parameters ---------- key : scalar or array-like Row and column labels. Returns ------- pd.DataFrame Selected data. """ return self._obj.data.loc[key, key] def __setitem__(self, key, value): """ Set data in the symmetric matrix, enforcing symmetry. Parameters ---------- key : scalar or array-like Row and column labels. value : scalar or array-like Value to set. Notes ----- This method automatically enforces symmetry and supports advanced indexing like slices and lists (e.g., cov.loc[:, 'a'] = some_values). """ rows = cols = key # Set the primary value self._obj.data.loc[rows, cols] = value # Determine the value for the symmetric assignment symmetric_value = value if hasattr(value, 'T'): # For DataFrames, Series, and numpy arrays, we need the transpose # for the symmetric assignment. Scalars do not have .T. symmetric_value = value.T # Set the symmetric value self._obj.data.loc[cols, rows] = symmetric_value
class CovarianceMatrix(SymmetricMatrix):
    """
    Covariance matrix class wrapping pandas DataFrames.

    Attributes
    ----------
    variance : pd.Series
        Series containing the variances (diagonal elements).
    """

    @property
    def variance(self) -> pd.Series:
        """
        Returns the diagonal of the covariance matrix (the variances).

        Returns
        -------
        pd.Series
            Series named ``'variance'`` holding the diagonal elements,
            indexed like the matrix.
        """
        diagonal = np.diag(self.data)
        return pd.Series(diagonal, index=self.index, name='variance')
class InverseProblem:
    """
    Inverse problem class for estimating model states from observations.

    Represents a statistical inverse problem for estimating model states from
    observed data using Bayesian inference and linear forward operators.

    An inverse problem seeks to infer unknown model parameters (the "state")
    from observed data, given prior knowledge and a mathematical relationship
    (the forward operator) that links the state to the observations. This
    class provides a flexible interface for formulating and solving such
    problems using various estimators.

    Parameters
    ----------
    estimator : str or type[Estimator]
        The estimator to use for solving the inverse problem. Can be the name
        of a registered estimator or an Estimator class.
    obs : pd.Series
        Observed data as a pandas Series, indexed by observation dimensions.
    prior : pd.Series
        Prior estimate of the model state as a pandas Series, indexed by
        state dimensions.
    forward_operator : ForwardOperator or pd.DataFrame
        Linear operator mapping model state to observations. Can be a
        ForwardOperator instance or a pandas DataFrame with appropriate
        indices and columns.
    prior_error : CovarianceMatrix
        Covariance matrix representing uncertainty in the prior state
        estimate.
    modeldata_mismatch : CovarianceMatrix
        Covariance matrix representing uncertainty in the observed data
        (model-data mismatch).
    constant : float or pd.Series or None, optional
        Optional constant term added to the forward model output. If not
        provided, defaults to zero.
    state_index : pd.Index or None, optional
        Index for the state variables. If None, uses the index from the prior.
    estimator_kwargs : dict, optional
        Additional keyword arguments to pass to the estimator.
    coord_decimals : int, optional
        Number of decimal places to round coordinate values for alignment
        (default is 6).

    Raises
    ------
    TypeError
        If input types are incorrect.
    ValueError
        If input data dimensions are incompatible or indices do not align.

    Attributes
    ----------
    obs_index : pd.Index
        Index of the observations used in the problem.
    state_index : pd.Index
        Index of the state variables used in the problem.
    obs_dims : tuple
        Names of the observation dimensions.
    state_dims : tuple
        Names of the state dimensions.
    n_obs : int
        Number of observations.
    n_state : int
        Number of state variables.
    posterior : pd.Series
        Posterior mean estimate of the model state.
    posterior_error : CovarianceMatrix
        Posterior error covariance matrix.
    posterior_obs : pd.Series
        Posterior mean estimate of the observations.
    prior_obs : pd.Series
        Prior mean estimate of the observations.
    xr : InverseProblem._XR
        Xarray interface for accessing inversion results as xarray DataArrays.

    Methods
    -------
    solve() -> dict[str, pd.Series | CovarianceMatrix | pd.Series]
        Solves the inverse problem and returns a dictionary with posterior
        state, posterior error covariance, and posterior observation
        estimates.

    Notes
    -----
    This class is designed for linear inverse problems with Gaussian error
    models, commonly encountered in geosciences, remote sensing, and other
    fields where model parameters are inferred from indirect measurements.
    It supports flexible input formats and provides robust alignment and
    validation of input data.
    """

    def __init__(self,
                 estimator: str | type[Estimator],
                 obs: pd.Series,
                 prior: pd.Series,
                 forward_operator: ForwardOperator | pd.DataFrame,
                 prior_error: SymmetricMatrix,
                 modeldata_mismatch: SymmetricMatrix,
                 constant: float | pd.Series | None = None,
                 state_index: pd.Index | None = None,
                 estimator_kwargs: dict | None = None,
                 coord_decimals: int = 6,
                 ) -> None:
        """
        Initialize the InverseProblem.

        The inputs are aligned on shared, rounded coordinates and handed to
        the estimator as plain arrays. The caller's objects are not modified.

        Parameters
        ----------
        estimator : str or type[Estimator]
            Estimator class or its name as a string.
        obs : pd.Series
            Observed data. Observations whose value is NaN are dropped.
        prior : pd.Series
            Prior model state estimate.
        forward_operator : ForwardOperator or pd.DataFrame
            Forward operator matrix.
        prior_error : CovarianceMatrix
            Prior error covariance matrix.
        modeldata_mismatch : CovarianceMatrix
            Model-data mismatch covariance matrix.
        constant : float or pd.Series, optional
            Constant data, defaults to 0.0.
        state_index : pd.Index, optional
            Index for the state variables.
        estimator_kwargs : dict, optional
            Additional keyword arguments for the estimator.
        coord_decimals : int, optional
            Number of decimal places for rounding coordinates.

        Raises
        ------
        TypeError
            If any of the inputs are of the wrong type.
        ValueError
            If there are issues with the input data (e.g., incompatible
            dimensions, no overlapping observations, or a prior that has no
            value for some entry of the state index).
        """
        # Avoid a shared mutable default
        if estimator_kwargs is None:
            estimator_kwargs = {}

        # Validate state_index
        if state_index is None:
            state_index = prior.index
        if not isinstance(state_index, pd.Index):
            raise TypeError("state_index must be a pandas Index.")

        # Set problem dimensions
        self.obs_dims = tuple(obs.index.names)
        self.state_dims = tuple(prior.index.names)

        # Handle forward operator
        if isinstance(forward_operator, ForwardOperator):
            forward_operator = forward_operator.data

        # Take shallow copies before any index is reassigned below: a shallow
        # copy shares the data but owns its axes, so rounding/reordering the
        # labels here never touches the objects the caller passed in.
        obs = obs.copy(deep=False)
        prior = prior.copy(deep=False)
        forward_operator = forward_operator.copy(deep=False)
        prior_error = type(prior_error)(prior_error.data.copy(deep=False))
        modeldata_mismatch = type(modeldata_mismatch)(
            modeldata_mismatch.data.copy(deep=False)
        )

        # Handle constant data: broadcast a scalar (or None -> 0.0) onto the
        # observation index so it can be aligned like every other input.
        if isinstance(constant, pd.Series):
            constant = constant.copy(deep=False)
        else:
            constant = pd.Series(0.0 if constant is None else constant,
                                 index=obs.index)

        # Assert dimensions are in indices
        if not all(dim in forward_operator.index.names for dim in self.obs_dims):
            raise ValueError("Observation dimensions must be in the forward operator index.")
        if not all(dim in constant.index.names for dim in self.obs_dims):
            raise ValueError("Observation dimensions must be in the constant index.")
        if not all(dim in forward_operator.columns.names for dim in self.state_dims):
            raise ValueError("State dimensions must be in the forward operator columns.")
        if not all(dim in state_index.names for dim in self.state_dims):
            raise ValueError("State dimensions must be in the state index.")

        # Order levels if indexes are MultiIndex
        if isinstance(forward_operator.index, pd.MultiIndex):
            forward_operator = forward_operator.reorder_levels(self.obs_dims, axis='index')
            obs = obs.reorder_levels(self.obs_dims)
            modeldata_mismatch = modeldata_mismatch.reorder_levels(self.obs_dims)
            constant = constant.reorder_levels(self.obs_dims)
        if isinstance(forward_operator.columns, pd.MultiIndex):
            forward_operator = forward_operator.reorder_levels(self.state_dims, axis='columns')
            prior = prior.reorder_levels(self.state_dims)
            prior_error = prior_error.reorder_levels(self.state_dims)

        # Round index coordinates to avoid floating point issues during alignment
        round_coords = partial(round_index, decimals=coord_decimals)
        state_index = round_coords(state_index)
        obs.index = round_coords(obs.index)
        prior.index = round_coords(prior.index)
        forward_operator.index = round_coords(forward_operator.index)
        forward_operator.columns = round_coords(forward_operator.columns)
        prior_error.index = round_coords(prior_error.index)
        modeldata_mismatch.index = round_coords(modeldata_mismatch.index)
        constant.index = round_coords(constant.index)

        # Drop missing observations *before* building the obs index so that
        # the forward operator, mismatch covariance and constant are all
        # aligned to exactly the observations that reach the estimator.
        obs = obs.dropna()

        # Define the obs index as the intersection of the observation and
        # forward operator obs indices
        obs_index = obs.index.intersection(forward_operator.index)
        if obs_index.empty:
            raise ValueError("No overlapping indices between observations and forward operator.")

        # The prior must cover every state element; silently dropping the
        # gaps would leave x_0 shorter than the forward operator is wide.
        aligned_prior = prior.reindex(state_index)
        if aligned_prior.isna().any():
            raise ValueError("Prior has missing values for one or more entries of the state index.")

        # Align inputs
        self.obs = obs.reindex(obs_index)
        self.prior = aligned_prior
        self.forward_operator = forward_operator.reindex(index=obs_index, columns=state_index).fillna(0.0)
        self.prior_error = prior_error.reindex(state_index)
        self.modeldata_mismatch = modeldata_mismatch.reindex(obs_index)
        self.constant = constant.reindex(obs_index).fillna(0.0)

        # Store the problem indices
        self.obs_index = obs_index
        self.state_index = state_index

        # Initialize the estimator
        estimator_input = {
            'z': self.obs.values,
            'x_0': self.prior.values,
            'H': self.forward_operator.values,
            'S_0': self.prior_error.values,
            'S_z': self.modeldata_mismatch.values,
            'c': self.constant.values
        }
        self.estimator = self._init_estimator(estimator, estimator_input=estimator_input,
                                              **estimator_kwargs)

        # Build xarray interface
        self.xr = self._XR(self)

    def _init_estimator(self, estimator: str | type[Estimator], estimator_input: dict,
                        **kwargs) -> Estimator:
        """
        Initialize the estimator.

        Parameters
        ----------
        estimator : str or type[Estimator]
            The estimator class or its name as a string.
        estimator_input : dict
            Input parameters for the estimator, including:
            - 'z': Observed data
            - 'x_0': Prior state estimate
            - 'H': Forward operator
            - 'S_0': Prior error covariance
            - 'S_z': Model-data mismatch covariance
            - 'c': Constant data (optional)
        kwargs : dict
            Additional keyword arguments to pass to the estimator constructor.

        Returns
        -------
        Estimator
            An instance of the specified estimator class.

        Raises
        ------
        ValueError
            If ``estimator`` is a string that is not in the registry.
        TypeError
            If ``estimator`` is neither a string nor an Estimator subclass.
        """
        if isinstance(estimator, str):
            if estimator not in ESTIMATOR_REGISTRY:
                raise ValueError(f"Estimator '{estimator}' is not registered.")
            estimator_cls = ESTIMATOR_REGISTRY[estimator]
        elif isinstance(estimator, type) and issubclass(estimator, Estimator):
            estimator_cls = estimator
        else:
            raise TypeError("Estimator must be a string or a subclass of Estimator.")

        return estimator_cls(
            z=estimator_input['z'],
            x_0=estimator_input['x_0'],
            H=estimator_input['H'],
            S_0=estimator_input['S_0'],
            S_z=estimator_input['S_z'],
            c=estimator_input.get('c'),
            **kwargs,
        )

    def solve(self) -> dict[str, pd.Series | SymmetricMatrix | pd.Series]:
        """
        Solve the inversion problem using the configured estimator.

        Returns
        -------
        dict[str, pd.Series | CovarianceMatrix]
            A dictionary containing the posterior estimates:
            - 'posterior': Pandas series with the posterior mean model estimate.
            - 'posterior_error': Covariance object with the posterior error covariance matrix.
            - 'posterior_obs': Pandas series with the posterior observation estimates.
        """
        return {
            'posterior': self.posterior,
            'posterior_error': self.posterior_error,
            'posterior_obs': self.posterior_obs,
        }

    @property
    def n_obs(self) -> int:
        """
        Number of observations.

        Returns
        -------
        int
            Number of observations.
        """
        return self.estimator.n_z

    @property
    def n_state(self) -> int:
        """
        Number of state variables.

        Returns
        -------
        int
            Number of state variables.
        """
        return self.estimator.n_x

    @cached_property
    def posterior(self) -> pd.Series:
        """
        Posterior state estimate.

        Returns
        -------
        pd.Series
            Pandas series with the posterior mean model estimate.
        """
        x_hat = self.estimator.x_hat
        return pd.Series(x_hat, index=self.state_index, name='posterior')

    @cached_property
    def posterior_obs(self) -> pd.Series:
        """
        Posterior observation estimates.

        Returns
        -------
        pd.Series
            Pandas series with the posterior observation estimates.
        """
        y_hat = self.estimator.y_hat
        return pd.Series(y_hat, index=self.obs_index, name='posterior_obs')

    @cached_property
    def posterior_error(self) -> SymmetricMatrix:
        """
        Posterior error covariance matrix.

        Returns
        -------
        CovarianceMatrix
            CovarianceMatrix instance with the posterior error covariance
            matrix.
        """
        S_hat = self.estimator.S_hat
        # CovarianceMatrix subclasses SymmetricMatrix, so this satisfies the
        # annotation while also exposing `.variance` as documented.
        return CovarianceMatrix(pd.DataFrame(S_hat,
                                             index=self.state_index,
                                             columns=self.state_index))

    @cached_property
    def prior_obs(self) -> pd.Series:
        """
        Prior observation estimates.

        Returns
        -------
        pd.Series
            Pandas series with the prior observation estimates.
        """
        y_0 = self.estimator.y_0
        return pd.Series(y_0, index=self.obs_index, name='prior_obs')

    class _XR:
        """
        Xarray interface for Inversion data.

        Attribute access is forwarded to the wrapped InverseProblem and the
        result is converted to xarray. The interface itself is read-only.
        """

        def __init__(self, inversion: 'InverseProblem'):
            self._inversion = inversion

        def __getattr__(self, attr):
            """
            Get an xarray representation of an attribute from the inversion
            object.

            Parameters
            ----------
            attr : str
                Attribute name.

            Returns
            -------
            xr.DataArray
                Xarray representation of the attribute.

            Raises
            ------
            AttributeError
                If the attribute does not exist.
            TypeError
                If the attribute type is not supported.
            """
            # __getattr__ only runs when normal lookup has failed, so reaching
            # here for '_inversion' means it was never set (e.g. while the
            # object is being copied or unpickled). Reading self._inversion
            # again would recurse without bound.
            if attr == '_inversion':
                raise AttributeError(attr)

            if not hasattr(self._inversion, attr):
                raise AttributeError(f"'{type(self._inversion).__name__}' object has no attribute '{attr}'")

            obj = getattr(self._inversion, attr)
            if isinstance(obj, pd.Series):
                return self._series_to_xarray(series=obj, attr=attr)
            if isinstance(obj, pd.DataFrame):
                return self._dataframe_to_xarray(df=obj, attr=attr)
            raise TypeError(f"Unable to represent {type(obj)} as Xarray.")

        def __setattr__(self, *args):
            """
            Prevent setting attributes on the Xarray interface.

            Parameters
            ----------
            *args : tuple
                Attribute name and value.

            Raises
            ------
            AttributeError
                If attempting to set any attribute other than the internal
                ``_inversion`` reference.
            """
            if args[0] == '_inversion':
                super().__setattr__(*args)
            else:
                raise AttributeError(f"Cannot set attribute '{args[0]}' on Xarray interface.")

        @staticmethod
        def _series_to_xarray(series: pd.Series, attr) -> xr.DataArray:
            """
            Convert a Pandas Series to an Xarray DataArray.

            Parameters
            ----------
            series : pd.Series
                Pandas Series to convert.
            attr : str
                Attribute name, used as the name of the result.

            Returns
            -------
            xr.DataArray
                Xarray DataArray representation of the series.
            """
            series = series.copy()
            series.name = attr
            return series.to_xarray()

        @staticmethod
        def _dataframe_to_xarray(df: pd.DataFrame, attr) -> xr.DataArray:
            """
            Convert a Pandas DataFrame to an Xarray DataArray.

            Parameters
            ----------
            df : pd.DataFrame
                Pandas DataFrame to convert.
            attr : str
                Attribute name.

            Returns
            -------
            xr.DataArray
                Xarray DataArray representation of the DataFrame.
            """
            df = df.copy()
            if isinstance(df.columns, pd.MultiIndex):
                # Stack all levels of the columns MultiIndex into the index
                n_levels = len(df.columns.levels)
                s = df.stack(list(range(n_levels)), future_stack=True)
            else:
                s = df.stack(future_stack=True)
            return InverseProblem._XR._series_to_xarray(series=s, attr=attr)