import sys
from typing import Optional, Union

import numpy as np
from scipy.linalg import expm

sys.path.append('./')
from est.TransitionDensity import TransitionDensity
from est.fit.Minimizer import ScipyMinimizer
from est.fit.likelihoodestimator import LikelihoodEstimator


class MLE(LikelihoodEstimator):
    """Maximum likelihood estimator driven by a transition-density object."""

    def __init__(self,
                 sample: np.ndarray,
                 dt: float,
                 density: TransitionDensity,
                 minimizer: Optional[ScipyMinimizer] = None,
                 t0: Union[float, np.ndarray] = 0):
        """
        Maximum likelihood estimator based on some analytical representation of the transition density,
        e.g. ExactDensity, EulerDensity, ShojiOzakiDensity, etc.

        :param sample: array, N paths drawn from some theoretical model. Indexed as sample[path, step]; the
            exact-density branch of the likelihood also reads sample.shape[2] as the state dimension
            (presumably shape (n_paths, n_steps, dim) -- TODO confirm).
        :param dt: float, time step (time between diffusion steps)
            Either supply a constant dt for all time steps, or supply a set of dt's equal in length to the sample
        :param density: TransitionDensity, evaluated on each consecutive pair of observations; its ``model``
            attribute is the model whose parameters are fitted.
        :param minimizer: Minimizer, the minimizer that is used to maximize the likelihood function. If None
            (the default), a new ScipyMinimizer is created for this estimator.
        :param t0: Union[float, np.ndarray], optional parameter, if you are working with a time-homogeneous model,
            then this doesn't matter.
        """
        # Construct the default here rather than in the signature: a default of ScipyMinimizer() would be
        # evaluated once at class-definition time and silently shared by every MLE instance.
        if minimizer is None:
            minimizer = ScipyMinimizer()
        super().__init__(sample=sample, dt=dt, model=density.model,
                         minimizer=minimizer, t0=t0)
        self._density = density

    def log_likelihood_negative(self, params: np.ndarray) -> float:
        """
        Negative log-likelihood of the sample under ``params``, averaged over paths.

        Sets ``params`` on the model, then sums ``-log(max(min_prob, density(x_j -> x_{j+1})))`` over every
        consecutive pair of observations in every path and divides by the number of paths. The density is
        floored at ``self._min_prob`` (set by the base class; presumably a small positive value) so the log
        stays finite.

        :param params: np.ndarray, candidate model parameters
        :return: float, negative log-likelihood divided by the number of paths
        """
        model = self._model
        model.params = params
        if model._has_exact_density:
            # Per-parameter precomputation stored on the model for the density to use: the covariance from
            # Sigma_MFD(dt) and the matrix exponential of A*dt, where A is the leading d*d block of the
            # parameter vector.
            # NOTE(review): expm(A * dt) assumes a scalar dt -- confirm behaviour when per-step dt's are given.
            model.Sigma = model.Sigma_MFD(self._dt)
            d = self._sample.shape[2]
            A = model._params[:d ** 2].reshape(d, d)
            model.expAdt = expm(A * self._dt)

        log_neg = 0
        for path in self._sample:
            # NOTE(review): every transition is evaluated at tc=t0 rather than at its own start time. Per the
            # t0 docs this is fine for time-homogeneous models; confirm intent for time-inhomogeneous ones.
            for xc, xt in zip(path[:-1], path[1:]):
                log_neg += -np.log(np.maximum(self._min_prob,
                                              self._density(xc=xc,
                                                            xt=xt,
                                                            tc=self._t0,
                                                            dt=self._dt)))
        return log_neg / self._sample.shape[0]
    



