"""PG-LDS-ID for learning latent-state models of discrete dynamics while enabling
prioritized identification of shared and disjoint dynamics between hybrid
discrete-continuous observations. """
import math_utils.hankel_utils as hankel_utils
import math_utils.matrix_utils as matrix_utils
import numpy as np
import PGLDSID.transformation_of_moments as xform

def PGLDSID(i, nx, Y, Z=None, n1=0, Z_horizon=None, A_method=0,
            xform_moments=True, correct_fano=True, correct_PSD=True,
            use_cov=False, input_cov_mats=None, debug_mode=False) -> dict:
  """PG-LDS-ID implementation.

  Without moment conversion, this implementation is equivalent to Ho-Kalman SSID
  with the *additional* capability of prioritization of shared dynamics between
  two modalities. (Ho-Kalman SSID algorithm, reference Katayama implementation
  chapter 7.7 or VODM Ch 3 Algorithm 2).

  Identification runs in two stages:
    1. The n1 shared latent states are identified from the SVD of the
       future-Z / past-Y cross-covariance Hankel matrix (equation (9)).
    2. If nx > n1, the remaining nx - n1 Y-only latent states are identified
       from the SVD of the Y Hankel matrix after the part explained by the
       shared states has been subtracted out (equation (10)).
  If n1 <= 0 or Z is None, stage 1 is skipped and the call is delegated to
  PLDSID().

  Args:
    i: int. Horizon.
    nx: int. Number of latent states (shared + disjoint).
    Y: np.ndarray of shape (features, samples). Primary time-series data.
    Z: np.ndarray of shape (features, samples). Secondary time-series data.
    n1: int. Number of shared latent states.
    Z_horizon: int. Horizon for the secondary time series. Falls back to i
      when not provided (or falsy).
    A_method: int. Choice [0, 1]. 0: extended controllability least-squares
      approach for learning state dynamics. 1: extended observability least-
      squares approach for learning state dynamics. Default 0.
    xform_moments: bool. If True (default), PLDSID. If False, standard Ho-Kalman SSID.
    correct_fano, correct_PSD, use_cov: keyword arguments for transformation_of_moments()
    input_cov_mats: dict or None. Optional dependency injection used to directly
      provide analytical first and second moments to decouple error for moment
      conversion and system identification. Unit testing purposes only.
    debug_mode: bool. If True, intermediate quantities (training data, Hankel
      matrices and their SVD factors, observability/controllability matrices,
      secondary-signal covariances prefixed with 'Z_') are also stored in the
      returned dictionary.

  Returns:
    A dictionary of parameters learned. When the shared stage runs it holds
    'A', 'C' (also stored as 'Cy'), 'Cz', 'G' and 'L0', plus 'd' when
    xform_moments is True. Otherwise see PLDSID().

  Raises:
    ValueError: if the shared stage runs and A_method is not 0 or 1.
  """
  # Avoid a shared mutable default; an empty dict means "no injected moments".
  input_cov_mats = input_cov_mats or {}

  if n1 <= 0 or Z is None: # Stage 2 only --> PLDSID.
    return PLDSID(i, nx, Y, xform_moments=xform_moments,
                  correct_fano=correct_fano, correct_PSD=correct_PSD,
                  use_cov=use_cov, input_cov_mats=input_cov_mats,
                  debug_mode=debug_mode)

  # Validate up front so an invalid choice is rejected regardless of whether
  # disjoint (n2 > 0) states are modeled.
  if A_method not in (0, 1):
    raise ValueError(f'A_method must be 0 or 1, got {A_method!r}')

  parameters = {}
  ny, num_measurements = Y.shape
  nz, num_z_measurements = Z.shape
  if not Z_horizon: Z_horizon = i
  # Max possible number of measurements.
  j = hankel_utils.compute_hankel_parameters(num_measurements, i,
    num_second_observations=num_z_measurements, second_horizon=Z_horizon)

  # Past/future block-Hankel matrices of Y. They are only built on the
  # data-driven Ho-Kalman path; default to None so that compute_covariances()
  # can build them itself instead of raising NameError on other paths.
  Yp = Yf = None

  if debug_mode:
    parameters['n1'] = n1
    parameters['nx'] = nx
    parameters['ny'] = ny
    parameters['nz'] = nz
    parameters['i'] = i
    parameters['Z_horizon'] = Z_horizon
    parameters['j'] = j
    parameters['Ytrain'] = Y
    parameters['Ztrain'] = Z

  if xform_moments:
    xform_inputs = {}
    if input_cov_mats: # Unit testing, dependency injection.
      xform_inputs = {
       'covS': input_cov_mats['covS'], 'corrS': input_cov_mats['corrS'],
       'meanS': np.tile(input_cov_mats['meanS'], (2*i, 1)).squeeze(),
       'meanZ': np.tile(input_cov_mats['meanZ'], (2*Z_horizon, 1)).squeeze(),
       'xcorrZS': input_cov_mats['xcorrZS'], 'xcovZS': input_cov_mats['xcovZS']}

    meanR, covR, num_min_moment = xform.transformation_of_moments(Y, i,
      correct_fano=correct_fano, correct_PSD=correct_PSD, use_cov=use_cov,
      input_stats=xform_inputs, debug_mode=debug_mode)

    # Future Z-past R Hankel matrix equation (9).
    Ci_zy, _ = xform.transformation_of_crosscovariate_moments(Y, Z, i,
        Z_horizon=Z_horizon, meanR=meanR, covR=covR, input_stats=xform_inputs,
        debug_mode=debug_mode)

    parameters['d'] = meanR[:ny, :]
    if debug_mode:
      parameters['meanR'] = meanR
      parameters['num_min_moment'] = num_min_moment
      parameters['covR_stacked_hankel'] = covR

  else: # not xform_moments -> Ho-Kalman SSID + prioritized identification.
    if input_cov_mats: # Unit testing, dependency injection.
      Ci_zy = input_cov_mats['xcovZR']
      meanR = input_cov_mats['meanR']
      covR = input_cov_mats['covR']

    else: # not input_cov_mats
      Yp = hankel_utils.make_hankel(Y, i, j)
      Zf = hankel_utils.make_hankel(Z, Z_horizon, j, start=Z_horizon)
      Ci_zy = (Zf @ Yp.T) / (j - 1) # Future Z-past R Hankel matrix equation (9).

  # Stage 1: shared latent states from the rank-n1 SVD of Ci_zy.
  full_U, full_S, full_V = np.linalg.svd(Ci_zy, full_matrices=False)
  S = np.diag(full_S[:n1])
  U = full_U[:,:n1]
  V = full_V[:n1,:]
  if debug_mode:
    parameters['Ci_zy'] = Ci_zy
    parameters['Ci_zy_S'] = full_S
    parameters['Ci_zy_V'] = full_V
    parameters['Ci_zy_U'] = full_U

  Oz = U @ S**(1/2.) # Z observability matrix.
  ctrlr_mat1 = S**(1/2.) @ V # Controllability matrix for shared latents.
  A11, Cz1 = extract_AC(Oz, nz)
  G1 = extract_G(ctrlr_mat1, ny)

  # Compute Hankel as per equation (2).
  if xform_moments or input_cov_mats:
    Ci = hankel_utils.extract_correlation(covR, ny, i, pair='fp')
  else: # not xform_moments and not input_cov_mats
    Yf = hankel_utils.make_hankel(Y, i, j, start=i)
    Ci = (Yf @ Yp.T) / (j - 1)

  # Extract Y observability matrix for shared dynamics.
  Oy1 = Ci @ matrix_utils.inverse(ctrlr_mat1, left=False)
  _, Cy1 = extract_AC(Oy1, ny)

  if debug_mode:
    parameters['Oy1'] = Oy1
    parameters['Oz'] = Oz
    parameters['ctrlr_mat1'] = ctrlr_mat1
    parameters['Ci'] = Ci
    parameters['Ci_S'] = np.linalg.svd(Ci, full_matrices=False)[1]
    parameters['Ci1_S'] = np.linalg.svd(Oy1 @ ctrlr_mat1, full_matrices=False)[1]

  # Estimate A11 by using least-squares and the extended controllability matrices,
  # section 3.2.1. With A_method == 1, A11 from extract_AC(Oz, nz) above is kept.
  if A_method == 0:
    A11 = ctrlr_mat1[:, :-ny] @ matrix_utils.inverse(ctrlr_mat1[:, ny:], left=False)

  n2 = nx - n1
  if n2 > 0: # Optionally model the disjoint dynamics in Y.
    # Z does not load on the Y-only latent states.
    Cz = np.concatenate((Cz1, np.zeros([nz, n2])), axis=1)

    # Subtract out the part of Y that is shared with Z.
    Ci2 = Ci - Oy1 @ ctrlr_mat1 # Equation (10) in manuscript.
    full_U, full_S, full_V = np.linalg.svd(Ci2, full_matrices=False)
    S = np.diag(full_S[:n2])
    U = full_U[:,:n2]
    V = full_V[:n2,:]
    if debug_mode:
      parameters['Ci2'] = Ci2
      parameters['Ci2_S'] = full_S
      parameters['Ci2_V'] = full_V
      parameters['Ci2_U'] = full_U

    Oy2 = U @ S**(1/2.)
    # Controllability matrix associated with unshared latent states.
    ctrlr_mat2 = S**(1/2.) @ V # Equivalent to matrix_utils.inverse(Oy2) @ Ci2.

    _, Cy2 = extract_AC(Oy2, ny)
    Cy = np.concatenate((Cy1, Cy2), axis=1)
    Oy = np.concatenate((Oy1, Oy2), axis=1)

    G2 = extract_G(ctrlr_mat2, ny)
    G = np.concatenate((G1, G2))
    ctrlr_mat = np.concatenate((ctrlr_mat1, ctrlr_mat2))

    if debug_mode:
      parameters['Oy2'] = Oy2
      parameters['Oy'] = Oy
      parameters['ctrlr_mat2'] = ctrlr_mat2
      parameters['ctrlr_mat'] = ctrlr_mat

    if A_method == 0: # Extract from concatenated controllability matrix, section 3.2.2.
      A21_22 = ctrlr_mat2[:, :-ny] @ matrix_utils.inverse(ctrlr_mat[:, ny:], left=False)
      # Block lower-triangular A: shared states are not driven by Y-only states.
      A = np.concatenate((np.concatenate((A11, np.zeros((n1, n2))), axis=1), A21_22))
    else: # A_method == 1 (validated above).
      A, _ = extract_AC(Oy, ny)

  else: # n2 == 0
    A, Cz, Cy, G = A11, Cz1, Cy1, G1

  parameters['Cz'] = Cz
  parameters['A'] = A
  parameters['Cy'] = parameters['C'] = Cy # 'Cy' is for extra bookkeeping.
  parameters['G'] = G

  # Add covariances to the parameters.
  if xform_moments:
    ff_mat = hankel_utils.extract_correlation(covR, ny, i, pair='ff')
    pp_mat = hankel_utils.extract_correlation(covR, ny, i, pair='pp')
    L0_ff = hankel_utils.compute_average_variance(ff_mat, ny, i)
    L0_pp = hankel_utils.compute_average_variance(pp_mat, ny, i)
    covariances = { 'L0': matrix_utils.make_symmetric((L0_ff + L0_pp) / 2) }
  else:
    # Yf/Yp may be None (injected moments); compute_covariances builds them if needed.
    covariances = compute_covariances(Y, i, j, Yf=Yf, Yp=Yp, debug_mode=debug_mode)

  for k, v in covariances.items(): parameters[k] = v
  if debug_mode: # Optionally add secondary signal covariances to the parameters.
    for k, v in compute_covariances(Z, Z_horizon, j, debug_mode=debug_mode).items():
      parameters['Z_'+k] = v
  return parameters

def PLDSID(i, nx, Y, xform_moments=True, correct_fano=True,
          correct_PSD=True, use_cov=False, input_cov_mats=None, debug_mode=False) -> dict:
  """PLDSID implementation, Buesing et al 2012.

  Moment conversion on top of Ho-Kalman SSID algorithm (Katayama implementation
  chapter 7.7 or VODM Ch 3 Algorithm 2).

  Args:
    i: int. Horizon.
    nx: int. Number of latent states.
    Y: np.ndarray of shape (features, samples). Time-series data.
    xform_moments: bool. If True (default), PLDSID. If False, standard Ho-Kalman SSID.
    correct_fano, correct_PSD, use_cov: keyword arguments for transformation_of_moments()
    input_cov_mats: dict or None. Optional dependency injection used to directly
      provide analytical first and second moments to decouple error for moment
      conversion and system identification. Unit testing purposes only.
    debug_mode: bool. If True, intermediate quantities (SVD factors of the
      Hankel matrix, the Hankel matrix itself, observability/controllability
      matrices) are also stored in the returned dictionary.

  Returns:
    A dictionary of parameters learned: 'A', 'C' (also stored as 'Cy'), 'G',
    'L0', plus 'd' when xform_moments is True.
  """
  # Avoid a shared mutable default; an empty dict means "no injected moments".
  input_cov_mats = input_cov_mats or {}

  ny, num_measurements = Y.shape
  j = hankel_utils.compute_hankel_parameters(num_measurements, i)

  # Past/future block-Hankel matrices of Y. They are only built on the
  # data-driven Ho-Kalman path; default to None so that compute_covariances()
  # can build them itself instead of raising NameError on other paths.
  Yp = Yf = None

  parameters = {}
  if xform_moments: # PLDSID moment conversion.
    xform_inputs = {}
    if input_cov_mats: # Unit testing, dependency injection.
      xform_inputs = {
        'covS': input_cov_mats['covS'], 'corrS': input_cov_mats['corrS'],
        'meanS': np.tile(input_cov_mats['meanS'], (2*i, 1)).squeeze()}

    # Moment conversion.
    meanR, covR, num_min_moment = xform.transformation_of_moments(Y, i,
                                                  correct_fano=correct_fano,
                                                  correct_PSD=correct_PSD,
                                                  use_cov=use_cov,
                                                  input_stats=xform_inputs)
    # Hankel matrix equation (2) manuscript.
    Ci = hankel_utils.extract_correlation(covR, ny, i, pair='fp')
    parameters['d'] = meanR[:ny, :]
    if debug_mode: parameters['num_min_moment'] = num_min_moment

  else: # not xform_moments --> Ho-Kalman SID
    if input_cov_mats: # Unit testing, dependency injection.
      meanR, covR = input_cov_mats['meanR'], input_cov_mats['covR']
      Ci = hankel_utils.extract_correlation(covR, ny, i, pair='fp')

    else: # no input_cov_mats
      Yp = hankel_utils.make_hankel(Y, i, j)
      Yf = hankel_utils.make_hankel(Y, i, j, start=i)
      Ci = (Yf @ Yp.T) / (j - 1) # Hankel equation (2).

  # Parameter identification: rank-nx factorization of the Hankel matrix into
  # observability and controllability matrices.
  full_U, full_S, full_V = np.linalg.svd(Ci)
  S = np.diag(full_S[:nx])
  U = full_U[:, :nx]
  V = full_V[:nx, :]
  O = U @ S**(1/2.) # Observability matrix
  Ctrlr = S**(1/2.) @ V # Controllability matrix
  A, C = extract_AC(O, ny)
  G = extract_G(Ctrlr, ny)

  parameters['C'] = C
  parameters['Cy'] = parameters['C'] # 'Cy' is for extra bookkeeping.
  parameters['A'] = A
  parameters['G'] = G
  if debug_mode:
    parameters['full_S'] = full_S
    parameters['full_U'] = full_U
    parameters['full_V'] = full_V
    parameters['Ci'] = Ci
    parameters['O'] = O
    parameters['Ctrlr'] = Ctrlr

  # Add covariances to the parameters.
  if xform_moments:
    ff_mat = hankel_utils.extract_correlation(covR, ny, i, pair='ff')
    pp_mat = hankel_utils.extract_correlation(covR, ny, i, pair='pp')
    L0_ff = hankel_utils.compute_average_variance(ff_mat, ny, i)
    L0_pp = hankel_utils.compute_average_variance(pp_mat, ny, i)
    covariances = { 'L0': matrix_utils.make_symmetric((L0_ff + L0_pp) / 2) }
  else: # not xform_moments
    # Yf/Yp may be None (injected moments); compute_covariances builds them if needed.
    covariances = compute_covariances(Y, i, j, Yf=Yf, Yp=Yp, debug_mode=debug_mode)

  for k, v in covariances.items(): parameters[k] = v
  return parameters

######## Utility functions. ########
def compute_covariances(Y, i, j, Yf=None, Yp=None, debug_mode=False):
  """Compute the lag-0 covariance of Y and, in debug mode, Hankel covariances.

  Args:
    Y: np.ndarray of size (features, samples). Data. The 'Ci' entry is a raw
      cross-product without mean removal, so Y should be demeaned before being
      passed in.
    i: int. Horizon.
    j: int. Number of samples per horizon (i.e., columns of Hankel matrix).
    Yf: np.ndarray of size (features*horizon, j). Optional future Hankel
      matrix; built from Y (starting at column i) when omitted.
    Yp: np.ndarray of size (features*horizon, j). Optional past Hankel
      matrix; built from Y when omitted.
    debug_mode: bool. If False, only 'L0' is computed.

  Returns:
    dict with 'L0' = np.cov(Y, ddof=1). In debug mode it also has
    'Li' = mean of the past and future Hankel covariances and
    'Ci' = (Yf @ Yp.T) / (j - 1), the future-past cross-covariance.
  """
  result = {'L0': np.cov(Y, ddof=1)}
  if not debug_mode:
    return result

  # Build whichever Hankel matrices were not supplied (past first, then future).
  past = hankel_utils.make_hankel(Y, i, j) if Yp is None else Yp
  future = hankel_utils.make_hankel(Y, i, j, i) if Yf is None else Yf

  # Average of the past and future block covariances; both estimate Lambda_i.
  result['Li'] = (np.cov(past, ddof=1) + np.cov(future, ddof=1)) / 2
  # Future-past cross-covariance block Hankel matrix C_i.
  result['Ci'] = (future @ past.T) / (j - 1)
  return result

def extract_AC(observability_matrix, ny):
  """Recover (A, C) from an extended observability matrix by least squares.

  Extended observability matrix approach (see section 2.2 of manuscript):
  C is the first ny rows, and A solves O_floor @ A = O_bar, where O_floor drops
  the last ny rows and O_bar drops the first ny rows.

  Args:
    observability_matrix: 2-D np.ndarray with ny*horizon rows, stacked as
      [C; CA; CA^2; ...].
    ny: int. Number of output features (rows per block).

  Returns:
    Tuple (A, C).
  """
  O = observability_matrix
  O_bar = O[ny:]      # Every block except the first (shifted by one step).
  O_floor = O[:-ny]   # Every block except the last.
  # NOTE(review): matrix_utils.inverse is assumed to be a (left) pseudo-inverse
  # here, since O_floor is generally non-square -- confirm in matrix_utils.
  return matrix_utils.inverse(O_floor) @ O_bar, O[:ny]

def extract_G(ctrlr_mat, ny):
  """Return G, the last ny columns of the extended controllability matrix.

  Refer to section 2.2., equation (4) of the manuscript.

  Args:
    ctrlr_mat: 2-D np.ndarray. Extended controllability matrix.
    ny: int. Number of output features (columns per block).

  Returns:
    np.ndarray holding the trailing ny columns of ctrlr_mat.
  """
  trailing_block = slice(-ny, None)
  return ctrlr_mat[:, trailing_block]
