import numpy as np
from typing import Union, Optional
from scipy.linalg import expm
from scipy.stats import multivariate_normal
import time

import sys
sys.path.append('./')
from est.models.Model import Model


class Model1(Model):
    """
    Model1: d-dimensional linear time-invariant SDE

        dX_t = A X_t dt + G dW_t

    where W_t is m-dimensional (G is d by m).

    Parameters, stored flattened (row-major) in ``self._params``:
        A (d by d): drift matrix,     params[:d**2]
        G (d by m): diffusion matrix, params[d**2:]

    The transition density is Gaussian and known in closed form, hence
    ``has_exact_density=True``.
    """

    def __init__(self,
                 dim: int,
                 m: int):
        super().__init__(dim=dim, m=m, has_exact_density=True)
        # Optional precomputed transition covariance Sigma(dt) and exp(A*dt).
        # When left as None, exact_density computes them from the current
        # params and the requested dt.
        self._Sigma: Optional[np.ndarray] = None
        self._expAdt: Optional[np.ndarray] = None

    def _drift_matrix(self) -> np.ndarray:
        """Return the drift matrix A (d by d) unpacked from the flat params."""
        d = self._dim
        return self._params[:d ** 2].reshape(d, d)

    def _diffusion_matrix(self) -> np.ndarray:
        """Return the diffusion matrix G (d by m) unpacked from the flat params."""
        d = self._dim
        return self._params[d ** 2:].reshape(d, self._m)

    def drift(self, x: Union[float, np.ndarray]) -> Union[float, np.ndarray]:
        """Drift coefficient A @ x."""
        return self._drift_matrix() @ x

    def diffusion(self, x: Union[float, np.ndarray]) -> Union[float, np.ndarray]:
        """Diffusion coefficient G; it is state-independent, so ``x`` is unused."""
        return self._diffusion_matrix()

    @property
    def Sigma(self) -> np.ndarray:
        """Precomputed transition covariance used by exact_density (None if unset)."""
        return self._Sigma

    @Sigma.setter
    def Sigma(self, vals: np.ndarray):
        """Store a precomputed transition covariance for a fixed time step.

        NOTE(review): presumably set by the fitter alongside ``expAdt`` for the
        data's dt — confirm against the caller.
        """
        self._Sigma = vals

    @property
    def expAdt(self) -> np.ndarray:
        """Precomputed matrix exponential exp(A*dt) used by exact_density (None if unset)."""
        return self._expAdt

    @expAdt.setter
    def expAdt(self, vals: np.ndarray):
        """Store a precomputed exp(A*dt) for a fixed time step.

        NOTE(review): presumably set by the fitter alongside ``Sigma`` for the
        data's dt — confirm against the caller.
        """
        self._expAdt = vals

    def exact_density(self, xc: Union[float, np.ndarray],
                      xt: Union[float, np.ndarray], tc: float, dt: float) -> float:
        """
        Exact Gaussian transition density of X at time tc + dt given X_tc = xc:

            X_{tc+dt} | X_tc = xc  ~  N(exp(A dt) @ xc, Sigma(dt))

        If the ``expAdt`` / ``Sigma`` attributes have been set, they are used
        as-is and are assumed to correspond to this ``dt``. If either is None,
        it is computed here from the current params and ``dt`` (not cached, so
        later parameter changes are always picked up).

        :param xc: Union[float, np.ndarray], the current value
        :param xt: Union[float, np.ndarray], the value to transition to
        :param tc: float, the time of observing xc (unused: the model is
            time-homogeneous)
        :param dt: float, the time step between xc and xt
        :return: float, the transition density evaluated at xt (a density,
            not a probability)
        """
        exp_a_dt = self._expAdt
        if exp_a_dt is None:
            exp_a_dt = expm(self._drift_matrix() * dt)
        cov = self._Sigma
        if cov is None:
            cov = self.Sigma_MFD(dt)

        # atleast_1d lets a scalar xc work in the d == 1 case (matmul rejects 0-d).
        mu = exp_a_dt @ np.atleast_1d(xc)
        # allow_singular: G G^T may be rank-deficient when m < d.
        return multivariate_normal.pdf(xt, mean=mu, cov=cov, allow_singular=True)

    def Sigma_MFD(self, dt: float) -> np.ndarray:
        """
        Transition covariance over a step ``dt`` via matrix fraction decomposition.

        With Q = G G^T, form M = [[A, Q], [0, -A^T]] and Phi = expm(M dt).
        Its right block column is [C; D] = Phi @ [0; I], and

            Sigma(dt) = C D^{-1} = int_0^dt exp(A s) Q exp(A^T s) ds.

        :param dt: float, the time step
        :return: np.ndarray (d by d), the symmetric transition covariance
        """
        d = self._dim
        A = self._drift_matrix()
        G = self._diffusion_matrix()
        M = np.block([[A, G @ G.T],
                      [np.zeros((d, d)), -A.T]])
        phi = expm(M * dt)
        C = phi[:d, d:]
        # M is block upper-triangular, so D = expm(-A^T dt): always invertible.
        D = phi[d:, d:]

        # Sigma = C @ inv(D), done as a linear solve (D^T Sigma^T = C^T) rather
        # than a pseudo-inverse, which could truncate small singular values.
        Sigma = np.linalg.solve(D.T, C.T).T
        # Exactly symmetric in theory; remove round-off asymmetry.
        return 0.5 * (Sigma + Sigma.T)

    def check_condition(self, x0: Union[float, np.ndarray]) -> bool:
        """
        Rank condition on the initial state x0 and the pair (A, H), H = G G^T.

        Builds the d by d*(d+1) matrix

            [x0, H, A x0, A H, ..., A^{d-1} x0, A^{d-1} H]

        and reports whether it has full row rank d.
        NOTE(review): presumably a non-degeneracy/identifiability requirement
        of the estimator — confirm against the caller.

        :param x0: Union[float, np.ndarray], initial state with d entries
        :return: bool, True iff the stacked matrix has rank d
        """
        d = self._dim
        A = self._drift_matrix()
        G = self._diffusion_matrix()

        v = np.asarray(x0).reshape(d, 1)
        h = G @ G.T
        blocks = []
        # Collect A^k x0 and A^k H for k = 0 .. d-1, then stack once.
        for _ in range(d):
            blocks.extend((v, h))
            v = A @ v
            h = A @ h
        return bool(np.linalg.matrix_rank(np.hstack(blocks)) == d)
        