import numpy as np
from typing import Union
from scipy.stats import multivariate_normal
import sys
sys.path.append('./')
from est.models.Model import Model


class TransitionDensity:
    """
    Base class for the transition density of an SDE model.

    An instance is bound to a single model. Subclasses read that model (through
    ``self._model`` or the read-only ``model`` property) whenever the density is evaluated.
    """

    def __init__(self, model: Model):
        """
        :param model: the SDE model, referenced during calls to the transition density
        """
        self._model: Model = model

    @property
    def model(self) -> Model:
        """The SDE model this density is bound to (no setter is defined, so it is read-only)."""
        return self._model


class ExactDensity(TransitionDensity):
    """
    Exact transition density of an SDE model, for models that provide one.

    Calling an instance forwards all arguments to the bound model's ``exact_density`` method.
    """

    def __init__(self, model: Model):
        """
        :param model: the SDE model, referenced during calls to the transition density
        """
        super().__init__(model)

    def __call__(self,
                 xc: Union[float, np.ndarray],
                 xt: Union[float, np.ndarray],
                 tc: Union[float, np.ndarray],
                 dt: float) -> Union[float, np.ndarray]:
        """
        Evaluate the exact transition density from xc to xt over a time step dt.

        :param xc: Union[float, np.ndarray], the current value
        :param xt: Union[float, np.ndarray], the value to transition to (same dimension as xc)
        :param tc: Union[float, np.ndarray], the time at which to evaluate the coefficients
            (irrelevant for time-homogeneous models)
        :param dt: float, the time step between xc and xt
        :return: the density value(s) computed by the model's ``exact_density``
        """
        bound_model = self._model
        return bound_model.exact_density(xc=xc, xt=xt, tc=tc, dt=dt)
    

class EulerMaruyamaDensity(TransitionDensity):
    """
    Euler-Maruyama (Gaussian) approximation of the transition density of an SDE model.

    Over a step dt, the transition xc -> xt is approximated by a normal distribution with
    mean ``xc + drift(xc) * dt`` and covariance ``diffusion(xc) @ diffusion(xc).T * dt``.
    """

    # Covariances whose determinant is at or below this value are treated as (near-)singular.
    # The density is then delegated to scipy with ``allow_singular=True``.
    singular_det_tol: float = 1e-30

    def __init__(self, model: Model):
        """
        Class which represents the Euler-Maruyama approximation transition density for a model, and implements a
        __call__ method to evaluate the transition density (bound to the model)
        :param model: the SDE model, referenced during calls to the transition density
        """
        super().__init__(model=model)

    def __call__(self,
                 xc: Union[float, np.ndarray],
                 xt: Union[float, np.ndarray],
                 tc: Union[float, np.ndarray],
                 dt: float) -> Union[float, np.ndarray]:
        """
        Euler-Maruyama approximation of the transition density p(xt | xc) over a step dt.

        Scalar states and scalar diffusion coefficients are promoted to 1-element vectors and
        1x1 matrices, so one-dimensional models work as well as vector-valued ones.

        :param xc: Union[float, np.ndarray], the current value
        :param xt: Union[float, np.ndarray], the value to transition to (same dimension as xc).
            It may also be a stack of points of shape (n, d), all evaluated from the same xc.
        :param tc: Union[float, np.ndarray], the current time. NOTE: drift and diffusion are called
            with xc only, so tc does not affect the result of this approximation.
        :param dt: float, the time step between xc and xt
        :return: the approximate transition density evaluated at xt (one value per point for stacked xt)
        """
        drift = self._model.drift(xc)
        # Evaluate the diffusion once. atleast_2d lets a scalar coefficient act as a 1x1 matrix under `@`.
        sig = np.atleast_2d(self._model.diffusion(xc))

        mu = np.atleast_1d(xc + drift * dt)
        Sigma = sig @ sig.T * dt
        d = self._model.dim
        det = np.linalg.det(Sigma)

        if det <= self.singular_det_tol:
            return multivariate_normal.pdf(xt, mean=mu, cov=Sigma, allow_singular=True)

        # Sigma is invertible here. Solving the system is more accurate than forming a (pseudo-)inverse.
        # It also stays consistent with `det` above, whereas pinv may truncate small singular values.
        resid = np.atleast_1d(xt) - mu
        # Row-wise quadratic form: works for a single point (d,) and for a stack (n, d).
        quad = np.sum(resid * np.linalg.solve(Sigma, resid.T).T, axis=-1)
        return (2 * np.pi) ** (-d / 2) * det ** (-0.5) * np.exp(-0.5 * quad)
        