from openmmtools.integrators import LangevinIntegrator
from openmm.app import *
from openmm import *
from openmm.unit import *
from openmmtools.integrators import LangevinIntegrator
import scipy
from scipy.sparse import issparse
from scipy.sparse import csr_matrix
from deeptime.markov.tools import estimation as msmest
from deeptime.markov._transition_counting import TransitionCountModel, TransitionCountEstimator
from deeptime.util.types import ensure_dtraj_list
from tqdm import tqdm

import numpy as np
from typing import Optional, List


class LangevinSplittingGirsanov(LangevinIntegrator):
    r"""Langevin splitting integrator that outputs Girsanov reweighting factors on-the-fly.

    In addition to the standard trajectory properties, the integrator stores
    the Gaussian random numbers it draws (``Eta{i}``), the bias-induced shift
    of those random numbers (``DeltaEta{i}``) and the accumulated exponent of
    the path reweighting factor (global variable ``M``).
    This custom integrator is based on :class:`openmmtools.integrators.LangevinIntegrator`;
    see the documentation there for the splitting rules.
    Not all functionalities of :class:`LangevinIntegrator` can be used,
    because the theoretical background for this method is so far only
    available for selected (ABO) splitting schemes (no multiple-timestep
    integrators, bias overlaps, etc.). The supported schemes, written in the
    R/V/O notation expected here (A -> R, B -> V), are:

        A B O       ->   "R V O"
        A B O B A   ->   "R V O V R"
        A O B O A   ->   "R O V O R"
        B O A O B   ->   "V O R O V"
        O B A B O   ->   "O V R V O"

    The bias force must be placed in force group 1: it is read through ``f1``
    when the random-number differences are computed.

    Parameters
    ----------
    nstxout : int
        Write-out frequency of simulation data. It is also used as the path
        length ``tau`` (in integration steps) over which the exponent ``M`` is
        accumulated before it is restarted.
    temperature : openmm.unit.Quantity compatible with kelvin, default: 298.0*unit.kelvin
        Fictitious "bath" temperature.
    collision_rate : openmm.unit.Quantity compatible with 1/picoseconds, default: 1.0/unit.picoseconds
        Collision rate.
    timestep : openmm.unit.Quantity compatible with femtoseconds, default: 1.0*unit.femtoseconds
        Integration timestep.
    splitting : str
        Splitting scheme; must be one of the supported schemes listed above
        ("V R O R V" and "V R O R" are explicitly rejected).
    constraint_tolerance : float, default: 1.0e-8
        Tolerance for the constraint solver.

    References
    ----------
    [Schaefer and Keller, 2024] Implementation of Girsanov reweighting in OpenMM and Deeptime
    [Kieninger, Ghysbrecht and Keller, 2023] Girsanov reweighting for simulations
    of underdamped Langevin dynamics. Theory

    Example
    -------
    Create a splitting integrator (``LangevinSplittingGirsanov`` imported from
    the module that defines it)::

        >>> integrator = LangevinSplittingGirsanov(nstxout=1,
        ...                                        temperature=298.0 * unit.kelvin,
        ...                                        collision_rate=1.0 / unit.picoseconds,
        ...                                        timestep=1.0 * unit.femtoseconds,
        ...                                        splitting="R O V O R",
        ...                                        constraint_tolerance=1e-8)

    Prepare the OpenMM simulation object (topology, system and platform not
    shown) and run it::

        >>> simulation = Simulation(top.topology, system, integrator, platform)
        >>> simulation.step(nsteps)
    """

    def __init__(self, nstxout, *args, **kwargs):
        """Validate the splitting scheme and build the integrator.

        Parameters
        ----------
        nstxout : int
            Write-out frequency; used as the path length ``tau`` of ``M``.
        *args, **kwargs
            Forwarded unchanged to :class:`LangevinIntegrator`. ``splitting``
            and ``timestep`` may be passed positionally or by keyword; when
            omitted, LangevinIntegrator's own defaults are used.

        Raises
        ------
        ValueError
            If the splitting uses the A/B notation, is not listed in
            ``_delta_eta_table``, or is listed there as unsuitable ('nan').
        """
        import inspect

        # Resolve `splitting` and `timestep` exactly as LangevinIntegrator.__init__
        # will receive them, so positional arguments and omitted (default) values
        # are validated instead of failing with a bare KeyError.
        call = inspect.signature(LangevinIntegrator.__init__).bind(self, *args, **kwargs)
        call.apply_defaults()
        splitting = call.arguments['splitting']
        timestep = call.arguments['timestep']

        # Check the splitting string
        if any(step in ('A', 'B') for step in splitting.split()):
            raise ValueError("Please use different notation for the splitting algorithm.\n A = R\n B = V")
        if splitting not in self._delta_eta_table:
            raise ValueError("Unfortunately, the splitting algorithm is not implemented.\n Modify '_delta_eta_table' according to\n \"string of splitting elements\" : (number of random variables, expression for delta eta, update rule)\n to include.")
        if self._delta_eta_table[splitting] == 'nan':
            raise ValueError("Unfortunately, the splitting algorithm is not suitable for Girsanov reweighting.")

        # Stored before the parent constructor runs: the overridden _add_* hooks
        # below read these attributes and are expected to be invoked from there.
        self._nstxout = nstxout
        self._splitting = splitting
        self._timestep = timestep

        # Initialize a new CustomIntegrator
        super(LangevinSplittingGirsanov, self).__init__(*args, **kwargs)

    # Extend inherited functions
    def _add_global_variables(self):
        """Add the parent's global variables plus those needed for reweighting."""
        super(LangevinSplittingGirsanov, self)._add_global_variables()
        self._init_rwght()

    def _add_integrator_steps(self):
        """Add the integration steps together with the reweighting bookkeeping.

        The update rule from ``_delta_eta_table`` is walked step by step: 'U'
        stores the bias force (force group 1) and the corresponding DeltaEta,
        'O' performs a stochastic velocity update with the pre-drawn Eta, and
        every other step is dispatched as in LangevinIntegrator (without mts).
        Afterwards the exponent ``M`` of the reweighting factor is updated.
        """
        self.addUpdateContextState()
        self.addComputeTemperatureDependentConstants({"sigma": "sqrt(kT/m)"})

        # Draw the Gaussian random numbers of this step up front so that the
        # O steps and the logM expression use the same values.
        self._get_eta()

        u_idx = 0  # index of the next DeltaEta / ff variable
        o_idx = 0  # index of the Eta variable used by the next O step
        for step in self._delta_eta_table[self._splitting][2].split():
            if step == 'U':
                # bias force is evaluated at the current position in the scheme
                self._get_delta_eta(u_idx)
                u_idx += 1
            elif step == 'O':
                add_step, _ = self._step_dispatch_table[step]
                # our _add_O_step needs the index of the Eta variable to use
                add_step(o_idx)
                o_idx += 1
            else:
                add_step, _ = self._step_dispatch_table[step]
                add_step()

        # onedelta is 0 whenever n is a multiple of tau and 1 otherwise, so M
        # restarts at the beginning of every path of length tau.
        self.addComputeGlobal("ndivtau", "n/tau")
        self.addComputeGlobal("onedelta", "1 - delta(ndivtau-floor(ndivtau))")

        # Random-number based reweighting exponent logM(\eta)
        self._get_logM()
        # Advance the step counter for the next integration step
        self.addComputeGlobal("n", "n + 1")

    def _add_O_step(self, eta_idx):
        """Add an O step (stochastic velocity update) using the stored ``Eta{eta_idx}``."""
        self.addComputePerDof("v", "(a * v) + (b * sigma * Eta{idx})".format(idx=eta_idx))
        self.addConstrainVelocities()

    # Additional functions needed to output reweighting factors
    @property
    def _delta_eta_table(self):
        """Per-splitting setup of the reweighting factors.

        Each entry is ``(number of random variables Eta_i, list of DeltaEta_i
        expressions, update rule)``; in the update rule 'U' marks where the
        bias force is stored and DeltaEta_i is computed. The value 'nan' marks
        splittings that are not suitable for Girsanov reweighting.
        The expressions for DeltaEta are derived in [Kieninger and Keller, 2023].
        The timestep-dependent factors a and b (d, f or d', f' in
        [Kieninger and Keller, 2023]) are global variables of LangevinIntegrator.
        b is dimensionless there, so it is multiplied by 'sigma*m'
        (sigma=sqrt(kT/m)) to obtain the correct units.
        """
        dispatch_table = {
            "R V O"     : (1, ['a/(b*sigma*m) * timestep * ff0'], "R U V O"),
            "R V O V R" : (1, ['1/(b*sigma*m) * (1 + a) * timestep/2 * ff0'], "R U V O V R"),
            "V R O R V" : 'nan',
            "V R O R"   : 'nan',
            "R O V O R" : (2, ['a/(b*sigma*m) * timestep * ff0'], "R U O V O R"),
            "V O R O V" : (2, ['a/(b*sigma*m) * timestep/2 * ff0', '1/(b*sigma*m) * timestep/2 * ff1'], "U V O R U O V"),
            "O V R V O" : (2, ['1/(b*sigma*m) * timestep/2 * ff0', 'a/(b*sigma*m) * timestep/2 * ff1'], "U O V R U V O")
        }
        return dispatch_table

    def _init_eta(self):
        """Add one per-DOF variable ``Eta{i}`` per random variable of the splitting."""
        for i in range(self._delta_eta_table[self._splitting][0]):
            self.addPerDofVariable("Eta{}".format(i), 0)

    def _init_delta_eta(self):
        """Add per-DOF variables ``DeltaEta{i}`` and ``ff{i}`` (stored bias force)
        for every DeltaEta expression of the splitting."""
        for i in range(len(self._delta_eta_table[self._splitting][1])):
            self.addPerDofVariable("DeltaEta{}".format(i), 0)
            self.addPerDofVariable("ff{}".format(i), 0)

    def _get_eta(self):
        """Draw every ``Eta{i}`` from a standard Gaussian distribution."""
        for i in range(self._delta_eta_table[self._splitting][0]):
            self.addComputePerDof("Eta{}".format(i), "gaussian")

    def _get_delta_eta(self, idx):
        """Store the current bias force (force group 1) in ``ff{idx}`` and
        evaluate ``DeltaEta{idx}`` from the splitting's expression."""
        self.addComputePerDof("ff{}".format(idx), "f1")
        self.addComputePerDof("DeltaEta{}".format(idx), self._delta_eta_table[self._splitting][1][idx])

    def _get_logM(self):
        """Accumulate the exponent ``M`` from Eta and DeltaEta.

        Per step, SumOverPath = sum over DOFs and i of
        ``Eta_i * DeltaEta_i + 0.5 * DeltaEta_i**2``; ``M`` adds it to its
        previous value unless a new path starts (onedelta == 0).
        """
        if self._splitting == 'R O V O R':
            # The two O half-steps combine into one Gaussian a*Eta0 + Eta1 with
            # variance a*a + 1.
            # TODO(review): derivation flagged "to check" by the original author.
            self.addComputePerDof("Eta0", "a*Eta0+Eta1")
            SOP = "(Eta0 * DeltaEta0)/(a*a+1) + 0.5 * (DeltaEta0 * DeltaEta0)/(a*a+1)"
        else:
            # Sum over the path (SOP) of length \tau
            n_delta = len(self._delta_eta_table[self._splitting][1])
            SOP = " + ".join(
                "Eta{idx} * DeltaEta{idx} + 0.5 * (DeltaEta{idx} * DeltaEta{idx})".format(idx=i)
                for i in range(n_delta)
            )
        self.addComputeSum("SumOverPath", SOP)
        self.addComputeGlobal('M', "M * onedelta + SumOverPath")

    def _init_rwght(self):
        """Add all global and per-DOF variables used for reweighting."""
        # Step counter n and the timestep value used in the DeltaEta expressions
        self.addGlobalVariable("n", 0)
        self.addGlobalVariable("timestep", self._timestep)

        # tau: length of a path omega, given by the write-out frequency nstxout
        self.addGlobalVariable("tau", self._nstxout)

        # Helpers that restart the path sum, cf. J. Chem. Phys. 146, 244112 (2017) Eq. (29)
        self.addGlobalVariable("ndivtau", 0)
        self.addGlobalVariable("onedelta", 0)

        # Sum over the path, cf. J. Chem. Phys. 146, 244112 (2017) Eq. (25)
        self.addGlobalVariable("SumOverPath", 0)
        self.addGlobalVariable("M", 0)

        # Eta and DeltaEta needed for M(eta), cf. J. Chem. Phys. 154, 094102 (2021) Eq. (10)
        self._init_eta()
        self._init_delta_eta()


class GirsanovReweightingEstimator(TransitionCountEstimator):
    r"""Transition count estimator with Girsanov path reweighting.

    Every transition counted at the given lag time is weighted with a
    state-space factor ``g`` and a path factor derived from ``M``,
    :footcite:`schaefer2024implementation`. Only the ``"sliding"`` count mode
    is supported.

    Parameters
    ----------
    lagtime : int
        Lag time in trajectory steps.
    count_mode : str
        Counting mode; only ``"sliding"`` is supported by :meth:`count`.
    n_states : int, optional
        If larger than the number of observed states, the count matrix and
        histogram are zero-padded to this size.
    sparse : bool, default=False
        Whether the count matrix is returned as a sparse (CSR) matrix.
    reweighting_method : {'nn', 'gr'}, default='nn'
        How the second reweighting factor ``M`` is interpreted, see :meth:`count`.
    """

    def __init__(self, lagtime: int, count_mode: str, n_states=None, sparse=False,
                 reweighting_method: str = 'nn'):
        super().__init__(lagtime=lagtime, count_mode=count_mode, n_states=n_states, sparse=sparse)
        self.lagtime = lagtime
        self.count_mode = count_mode
        self.sparse = sparse
        self.n_states = n_states
        self.reweighting_method = reweighting_method

    def fetch_model(self) -> Optional[TransitionCountModel]:
        r"""Yield the latest estimated :class:`TransitionCountModel`.

        Returns
        -------
        The latest :class:`TransitionCountModel`, or None if fetched before any
        data was fit.
        """
        return self._model

    def fit(self, data, reweighting_factors, *args, **kw):
        r"""Count reweighted transitions at the configured lag time.

        Parameters
        ----------
        data : array_like or list of array_like
            Discretized trajectories.
        reweighting_factors : tuple
            Tuple of reweighting factors :code:`(g, M)`, one entry per trajectory.
            :code:`g` is the likelihood ratio between the probability measures,
            with one value per frame (:code:`len(dtraj)` values).
            :code:`M` describes the path probability density ratio; its layout
            depends on ``reweighting_method``: for ``'gr'`` one log contribution
            per frame (:code:`len(dtraj)` values), for ``'nn'`` one multiplicative
            factor per transition (:code:`len(dtraj) - lagtime` values).
        *args, **kw
            Accepted for API compatibility and ignored.

        Returns
        -------
        self : GirsanovReweightingEstimator
            The estimator, with the new model available via :meth:`fetch_model`.
        """
        from deeptime.markov import count_states

        dtrajs = ensure_dtraj_list(data)

        count_matrix = GirsanovReweightingEstimator.count(
            self.count_mode, dtrajs, self.lagtime,
            reweighting_factors=reweighting_factors,
            sparse=self.sparse,
            reweighting_method=self.reweighting_method,
        )

        # basic count statistics like in deeptime._transition_counting.TransitionCountEstimator
        histogram = count_states(dtrajs, ignore_negative=True)

        if self.n_states is not None and self.n_states > count_matrix.shape[0]:
            n_pad = self.n_states - count_matrix.shape[0]
            histogram = np.pad(histogram, pad_width=[(0, n_pad)])
            if issparse(count_matrix):
                # Append n_pad empty rows by repeating the final row pointer.
                indptr = np.pad(count_matrix.indptr, pad_width=[(0, n_pad)],
                                constant_values=count_matrix.indptr[-1])
                count_matrix = csr_matrix((count_matrix.data, count_matrix.indices, indptr),
                                          shape=(self.n_states, self.n_states))
            else:
                count_matrix = np.pad(count_matrix, pad_width=[(0, n_pad), (0, n_pad)])

        self._model = TransitionCountModel(
            count_matrix=count_matrix, counting_mode=self.count_mode, lagtime=self.lagtime,
            state_histogram=histogram
        )
        return self

    @staticmethod
    def count(count_mode: str, dtrajs: List[np.ndarray], lagtime: int, reweighting_factors: tuple,
              sparse: bool = False, reweighting_method: str = 'nn'):
        r"""Compute a Girsanov-reweighted count matrix for Markov state models.

        Transitions are counted in sliding-window mode at the given lag time and
        weighted with the precomputed reweighting factors,
        :footcite:`schaefer2024implementation`.

        Parameters
        ----------
        count_mode : str
            Counting mode; so far only "sliding" is supported.
        dtrajs : array_like or list of array_like
            Discrete trajectories, i.e., a list of arrays which contain non-negative integer values. A single ndarray
            can also be passed, which is then treated as if it was a list with that one ndarray in it.
        lagtime : int
            Distance between two frames in the discretized trajectories under which their potential change of state
            is considered a transition.
        reweighting_factors : tuple
            :code:`(g, M)`, see :meth:`fit` for the expected layout.
        sparse : bool, default=False
            Whether to use sparse matrices or dense matrices. Sparse matrices can make sense when dealing with a lot of
            states.
        reweighting_method : {'nn', 'gr'}, default='nn'
            ``'nn'``: :code:`M` already holds one multiplicative factor per
            transition and is used as is. ``'gr'``: :code:`M` holds per-frame log
            contributions; the factor of transition :code:`t -> t + lagtime` is
            :code:`exp(M[t+1] + ... + M[t+lagtime])`.

        Returns
        -------
        count_matrix : (N, N) ndarray or sparse array
            The computed count matrix. Can be ndarray or sparse depending on whether sparse was set to true or false.
            N is the number of encountered states, i.e., :code:`np.max(dtrajs)+1`.

        Raises
        ------
        ValueError
            If ``count_mode`` is not "sliding" or ``reweighting_method`` is unknown.

        Example
        -------
        >>> import numpy as np
        >>> dtrajs = np.array([0, 0, 1, 1, 0, 1, 0, 1, 1])
        >>> g = np.ones(9)
        >>> log_m = np.zeros(9)  # per-frame log contributions for the 'gr' method
        >>> estimator = GirsanovReweightingEstimator(lagtime=2, count_mode='sliding',
        ...                                          reweighting_method='gr')
        >>> print(estimator.fit(dtrajs, reweighting_factors=(g, log_m)).fetch_model().count_matrix)
        [[1. 3.]
         [1. 2.]]

        With the default ``'nn'`` method, ``M`` holds one factor per transition:

        >>> m = np.ones(7)
        >>> estimator = GirsanovReweightingEstimator(lagtime=2, count_mode='sliding')
        >>> print(estimator.fit(dtrajs, reweighting_factors=(g, m)).fetch_model().count_matrix)
        [[1. 3.]
         [1. 2.]]
        """
        if count_mode != 'sliding':
            raise ValueError('Count mode {} is not compatible with the Girsanov reweighting estimator, only "sliding" supported.'.format(count_mode))
        if reweighting_method not in ('nn', 'gr'):
            raise ValueError("Unknown reweighting_method {!r}; expected 'nn' or 'gr'.".format(reweighting_method))
        return count_matrix_coo2_mult(dtrajs, lagtime, sparse=sparse,
                                      reweighting_factors=reweighting_factors,
                                      reweighting_method=reweighting_method)


'''
Reference (adapted from deeptime's sparse count-matrix implementation):
https://github.com/deeptime-ml/deeptime/blob/6280fdaf54f276011d0126f83ed4f8d2adf319d6/deeptime/markov/tools/estimation/sparse/count_matrix.py
'''
def count_matrix_coo2_mult(dtrajs, lag, reweighting_factors=None, reweighting_method=None,
                           sliding=True, sparse=True, nstates=None):
    """Generate an (optionally Girsanov-reweighted) count matrix from discrete trajectories.

    The generated count matrix is a sparse matrix in compressed sparse row
    (CSR) format or a dense numpy ndarray. Without reweighting, every observed
    transition ``dtraj[t] -> dtraj[t + lag]`` contributes 1. With reweighting,
    it contributes ``g[t] * m[t]`` (see ``reweighting_method``).

    Parameters
    ----------
    dtrajs : ndarray or list of ndarrays
        Discrete trajectories; a single ndarray is treated as one trajectory.
    lag : int
        Lagtime in trajectory steps.
    reweighting_factors : tuple (g, M), optional
        Girsanov reweighting factors, :footcite:`schaefer2024implementation`.
        ``g`` is the state-space probability density ratio, with one value per
        frame (same length as the corresponding dtraj). The layout of ``M``
        depends on ``reweighting_method``. Each of ``g`` and ``M`` may be one
        1D array (single trajectory), a 2D array (one row per trajectory) or a
        list of 1D arrays, in the same order as ``dtrajs``. Only supported
        together with ``sliding=True``.
    reweighting_method : {'gr', 'nn'}, optional
        Required when ``reweighting_factors`` is given.
        'gr': ``M`` holds one log contribution per frame. The factor of the
        transition starting at frame t is ``exp(M[t+1] + ... + M[t+lag])``.
        The code applies ``exp`` to the plain sum, so any sign convention must
        already be contained in ``M``.
        'nn': ``M`` already holds one multiplicative factor per transition,
        i.e. ``len(dtraj) - lag`` values, and is used as is.
    sliding : bool, optional
        If true the sliding window approach is used for transition counting.
    sparse : bool, optional
        Whether to return a sparse (CSR) or a dense matrix.
    nstates : int, optional
        Enforce a count-matrix with shape=(n_states, n_states). If there are
        more states in the data, this will lead to an exception.

    Returns
    -------
    C : scipy.sparse.csr_matrix or numpy.ndarray
        The count matrix at the given lag in scipy compressed sparse row
        or numpy ndarray format.

    Raises
    ------
    ValueError
        If no transitions are found at ``lag``, if ``reweighting_method`` is
        neither 'gr' nor 'nn' while reweighting factors are given, or if the
        reweighting factors do not match the trajectories.
    NotImplementedError
        If reweighting factors are combined with ``sliding=False`` or are not
        given as a tuple.
    """
    dtrajs = ensure_dtraj_list(dtrajs)
    # Determine number of states
    if nstates is None:
        from deeptime.markov import number_of_states
        nstates = number_of_states(dtrajs)

    # collect transition index pairs
    rows = []
    cols = []
    for dtraj in tqdm(dtrajs):
        if dtraj.size > lag:
            if sliding:
                rows.append(dtraj[0:-lag])
                cols.append(dtraj[lag:])
            else:
                rows.append(dtraj[0:-lag:lag])
                cols.append(dtraj[lag::lag])
    if len(rows) == 0:
        raise ValueError('No counts found - lag ' + str(lag) + ' may exceed all trajectory lengths.')
    row = np.concatenate(rows)
    col = np.concatenate(cols)

    # choose the weight of every transition (plain counting or reweighting with g and M)
    if reweighting_factors is None:
        data = np.ones(row.size)
    elif isinstance(reweighting_factors, tuple):
        if not sliding:
            raise NotImplementedError('Only the sliding scheme is implemented.')
        data = _girsanov_transition_weights(reweighting_factors, lag, reweighting_method)
        if data.size != row.size:
            raise ValueError('Got {} reweighting factors for {} transitions; g and M must match the '
                             'trajectories in number and length.'.format(data.size, row.size))
    else:
        raise NotImplementedError('An input format other than a tuple (g,M) for the reweighting factors is not implemented.')

    # feed into one COO matrix (duplicate index pairs are summed on conversion)
    C = scipy.sparse.coo_matrix((data, (row, col)), shape=(nstates, nstates))
    if sparse:
        return C.tocsr()
    return C.toarray()


def _as_trajectory_list(factors):
    """Return per-trajectory reweighting factors as a list of ndarrays.

    Accepts a single 1D array (one trajectory), a 2D array (one row per
    trajectory), an object array or list of 1D arrays (one entry per
    trajectory), or a flat list of numbers (one trajectory).
    """
    if isinstance(factors, np.ndarray):
        if factors.dtype == object:
            return [np.asarray(f) for f in factors]
        if factors.ndim == 1:
            return [factors]
        return list(factors)
    factors = list(factors)
    if all(np.ndim(f) == 0 for f in factors):
        # a flat sequence of numbers describes a single trajectory
        return [np.asarray(factors)]
    return [np.asarray(f) for f in factors]


def _girsanov_transition_weights(reweighting_factors, lag, reweighting_method):
    """Return the concatenated weights g[t] * m[t] of all sliding transitions t -> t + lag."""
    if reweighting_method not in ('gr', 'nn'):
        raise ValueError("Unknown reweighting_method {!r}; expected 'gr' or 'nn'.".format(reweighting_method))
    g_factors, M_factors = reweighting_factors
    g_list = _as_trajectory_list(g_factors)
    M_list = _as_trajectory_list(M_factors)
    if len(g_list) != len(M_list):
        raise ValueError('g and M describe a different number of trajectories: {} vs {}.'.format(
            len(g_list), len(M_list)))

    weights = []
    for g, M in zip(g_list, M_list):
        g_start = g[0:-lag]  # g at the start frame of every transition
        if reweighting_method == 'gr':
            # windowed sum of M over frames t+1 .. t+lag, then exponentiated
            log_m = M.cumsum()
            log_m[lag:] = log_m[lag:] - log_m[:len(log_m) - lag]
            m = np.exp(log_m[lag:])
        else:
            if M.shape != g_start.shape:
                raise ValueError('The shape of M is wrong: {} {}'.format(M.shape, g_start.shape))
            m = M
        weights.append(g_start * m)
    return np.concatenate(weights)