import os
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as colors
from deeptime.markov import TransitionCountEstimator
from deeptime.markov.msm import MaximumLikelihoodMSM
from deeptime.util.validation import implied_timescales
from deeptime.clustering import KMeans
from deeptime.clustering import BoxDiscretization
from deeptime.markov.tools.analysis import rdl_decomposition
from scipy.interpolate import make_interp_spline
from scipy.interpolate import griddata
import joblib
import time
from tqdm import tqdm

from .Girsanov import GirsanovReweightingEstimator

def stationary(x):
    """Return the unbiased density ``exp(-U(x))`` normalized over the given points.

    Parameters
    ----------
    x : np.ndarray
        Points of shape ``(..., dim)``; only ``dim`` of 1 or 2 is supported.

    Returns
    -------
    np.ndarray
        Flattened array of density values that sums to one over ``x``.

    Raises
    ------
    ValueError
        If ``x.shape[-1]`` is neither 1 nor 2.
    """
    dim = x.shape[-1]
    if dim == 1:
        # U(x) = 4 (x^2 - 1)^2, minima at x = +-1.
        pi = np.exp(-4.*(x**2-1)**2)
    elif dim == 2:
        x1, x2 = x[..., 0], x[..., 1]
        pi = np.exp(-(((x1**2-1)**2 + (x2**2-1)**2) + abs(x1 - x2)))
    else:
        # Previously fell through and crashed with UnboundLocalError on `pi`.
        raise ValueError(f'stationary() supports dim 1 or 2, got dim={dim}')
    # Normalization is over the supplied points only (a discrete distribution).
    pi = pi / pi.sum()
    return pi.reshape(-1)

def stationary_biased(x):
    """Return the biased density ``exp(-U(x) - V(x))`` normalized over the given points.

    Parameters
    ----------
    x : np.ndarray
        Points of shape ``(..., dim)``; only ``dim`` of 1 or 2 is supported.

    Returns
    -------
    np.ndarray
        Flattened array of density values that sums to one over ``x``.

    Raises
    ------
    ValueError
        If ``x.shape[-1]`` is neither 1 nor 2.
    """
    dim = x.shape[-1]
    if dim == 1:
        # Double well 4 (x^2 - 1)^2 plus a harmonic bias 4 (x + 0.3)^2.
        pi = np.exp(-4.*(x**2-1)**2-4.*(x+0.3)**2)
    elif dim == 2:
        x1, x2 = x[..., 0], x[..., 1]
        # NOTE(review): the quadratic term 0.75*((x1+1)^2 + (x2-1)^2) is
        # *subtracted* from the potential here, so it raises the density away
        # from (-1, 1), unlike the 1-D bias. Confirm this sign is intended.
        pi = np.exp(-(((x1**2-1)**2 + (x2**2-1)**2) + abs(x1 - x2)-0.5*1.5*((x1+1)**2+(x2-1)**2)))
    else:
        # Previously fell through and crashed with UnboundLocalError on `pi`.
        raise ValueError(f'stationary_biased() supports dim 1 or 2, got dim={dim}')
    # Normalization is over the supplied points only (a discrete distribution).
    pi = pi / pi.sum()
    return pi.reshape(-1)

class MSM:
    """Build a Markov state model from trajectories and cache its spectral data.

    Trajectories are discretized (box grid in ``__call__``), transition counts
    are estimated at a lag time (optionally through
    ``GirsanovReweightingEstimator``), a maximum-likelihood MSM is fitted, and
    eigenvalues, leading eigenvectors, the stationary distribution and implied
    timescales are saved to ``{save_path}_eigen_lag{lagtime}.npz``.
    """

    def __init__(self, n_states):
        # For the box method this is the number of boxes per axis; for
        # kmeans it is the number of cluster centers.
        self.n_states = n_states

    def __call__(self, sample_path, sample, save_path, lagtime,
                 reweighting_path=None, log_likeli_ratio=None,
                 kmeans=None,
                 ):
        """Discretize the trajectories, estimate an MSM at ``lagtime`` and save it.

        Parameters
        ----------
        sample_path : str
            Path prefix. ``{sample_path}.npz`` (key ``'sample'``) is read when
            ``sample`` is None and no cluster cache exists. Cluster results are
            cached in ``{sample_path}_assignments.npy`` and
            ``{sample_path}_clustercenters.npy`` and reused when present.
        sample : np.ndarray or None
            Trajectories; indexing below assumes shape
            ``(n_traj, n_steps, dim)``.
        save_path : str
            Output prefix. If ``{save_path}_eigen_lag{lagtime}.npz`` already
            exists, no estimation is done.
        lagtime : int
            Lag time (in steps) used for transition counting.
        reweighting_path : np.ndarray, optional
            Per-trajectory reweighting values, sliced along axis 1. Used only
            together with ``log_likeli_ratio``; ignored otherwise.
        log_likeli_ratio : np.ndarray, optional
            Log likelihood ratios. They are exponentiated and passed, with
            ``reweighting_path``, to ``GirsanovReweightingEstimator``.
        kmeans : object, optional
            Unused; kept for backward compatibility.

        Raises
        ------
        ValueError
            If ``log_likeli_ratio`` is given without ``reweighting_path``.
        """
        start_point = 0  # number of leading frames dropped from every trajectory
        if log_likeli_ratio is not None and reweighting_path is None:
            # Previously crashed later with NameError; fail before any clustering.
            raise ValueError('log_likeli_ratio requires reweighting_path')

        assignments, state_centers = self._load_or_cluster(sample_path, sample, start_point)

        reweighting_factors = None
        if log_likeli_ratio is not None:
            likeli_ratio = np.exp(log_likeli_ratio)[:, start_point:]
            reweighting_factors = (likeli_ratio, reweighting_path[:, start_point:])

        # Centers of the boxes/clusters actually visited by the trajectories.
        # NOTE(review): if the MSM estimator restricts itself to a connected
        # subset of states, these may not align with MSM states; confirm.
        state_centers_unique = state_centers[np.unique(assignments)]
        eigen_file = f'{save_path}_eigen_lag{lagtime}.npz'
        if not os.path.exists(eigen_file):
            eigen_results = self._estimate_eigen_results(
                assignments, lagtime, reweighting_factors, state_centers_unique)
            np.savez(eigen_file, **eigen_results)

    def _load_or_cluster(self, sample_path, sample, start_point):
        """Return ``(assignments, state_centers)``, clustering and caching them on first use."""
        assignments_file = f'{sample_path}_assignments.npy'
        centers_file = f'{sample_path}_clustercenters.npy'
        if os.path.exists(assignments_file):
            print('load cluster')
            assignments = np.load(assignments_file)[:, start_point:]
            state_centers = np.load(centers_file)
            return assignments, state_centers

        if sample is None:
            # Context manager closes the underlying .npz file handle.
            with np.load(f'{sample_path}.npz') as data:
                sample = data['sample'][:, start_point:]
        print('begin cluster')
        dim = sample.shape[-1]
        assignments, cluster_model = self.cluster(sample.reshape(-1, dim), method='box')
        # Restore the (n_traj, n_steps) layout of the flattened assignments.
        assignments = assignments.reshape(*sample.shape[:2])
        state_centers = cluster_model.cluster_centers
        np.save(assignments_file, assignments)
        np.save(centers_file, state_centers)
        print('cluster finish!')
        return assignments, state_centers

    @staticmethod
    def _estimate_eigen_results(assignments, lagtime, reweighting_factors, state_centers):
        """Count transitions, fit an ML MSM and collect the quantities saved to disk."""
        print('begin counting')
        if reweighting_factors is not None:
            count_estimator = GirsanovReweightingEstimator(lagtime=lagtime, count_mode='sliding')
            counts = count_estimator.fit(data=assignments, reweighting_factors=reweighting_factors).fetch_model()
        else:
            counts = TransitionCountEstimator(lagtime=lagtime, count_mode='sliding').fit_fetch(assignments)
        print('counting finishes!')
        msm = MaximumLikelihoodMSM(
            transition_matrix_tolerance=1e-8,
        ).fit_fetch(counts)

        its = implied_timescales(msm)
        # An n-state MSM has at most n - 1 relaxation processes; asking for a
        # fixed 30 would fail on small models.
        n_its = min(30, msm.n_states - 1)
        its_list = np.array([its.timescales_for_process(i).item() for i in range(n_its)])

        return {
            'eigenvalues': msm.eigenvalues(),
            # deeptime returns left eigenvectors row-wise: first 5 rows.
            'left_eigenvectors': msm.eigenvectors_left()[:5],
            # deeptime returns right eigenvectors column-wise, so take the first
            # 5 columns and transpose; rows are eigenvectors, as for the left ones.
            'right_eigenvectors': msm.eigenvectors_right()[:, :5].T,
            'state_centers': state_centers,
            'phi0': msm.stationary_distribution,
            'its': its_list,
        }

    def cluster(self, sample, method='kmeans'):
        """Discretize points into states.

        Parameters
        ----------
        sample : np.ndarray
            Points of shape ``(n_points, dim)``.
        method : str
            ``'kmeans'`` for k-means with ``n_states`` centers; any other value
            uses a regular box grid with ``n_states`` boxes per axis.

        Returns
        -------
        tuple
            ``(assignments, model)``: a 1-D array of state indices and the
            fitted deeptime model, which exposes ``cluster_centers``.
        """
        dim = sample.shape[-1]
        if method == 'kmeans':
            estimator = KMeans(n_clusters=self.n_states, max_iter=2000)
            # Return the fitted model (not the estimator) so callers can read
            # `cluster_centers`, matching the box branch.
            model = estimator.fit(sample.reshape(-1, dim)).fetch_model()
            assignments = model.transform(sample.reshape(-1, dim))
        else:
            grid_box = BoxDiscretization(
                dim=dim,  # the number of dimensions the data lives in
                n_boxes=self.n_states,  # number of boxes per axis
            )
            model = grid_box.fit(sample).fetch_model()
            assignments = model.transform(sample).reshape(-1,)
        return assignments, model


if __name__ == '__main__':
    n_states = 40
    lag_times = np.arange(50, 1001, 50)

    data_file = ''  # TODO: add data file (directory containing *_trajtotal.npz)
    output_file = ''  # TODO: add output file (directory containing model_weight_lag*.npy)
    if not data_file or not output_file:
        # Empty prefixes would resolve every path against the filesystem root.
        raise SystemExit('Set data_file and output_file before running this script.')

    # Loop-invariant inputs: open the biased trajectory archive once and close it.
    with np.load(f'{data_file}/biased_trajtotal.npz') as data:
        raw_gr_weights = data['gr_weights']
        # A log likelihood ratio of 0 gives a ratio of exp(0) = 1 everywhere,
        # i.e. no path-likelihood reweighting. (ones_like gave a ratio of e.)
        neutral_log_ratio = np.zeros_like(data['log_likeli_ratio'])
    gr_weight_cum_full = np.cumsum(raw_gr_weights, axis=1)

    msm = MSM(n_states=n_states)  # stateless apart from n_states; safe to reuse
    for lag_time in lag_times:
        msm(f'{data_file}/biased_trajtotal', None,
            f'{data_file}/msm_biased_trajtotal', lagtime=lag_time)

        msm(f'{data_file}/unbiased_trajtotal', None,
            f'{data_file}/msm_unbiased_trajtotal', lagtime=lag_time)

        # Element t is sum(gr_weights[:, t+1 : t+lag_time+1]);
        # shape (n_traj, n_steps - lag_time).
        window_sum = gr_weight_cum_full[:, lag_time:] - gr_weight_cum_full[:, :-lag_time]
        # exp(w) / mean(exp(w)) per trajectory (normalized to mean 1). The row
        # max is subtracted first so exp cannot overflow; the shift cancels.
        shifted = np.exp(window_sum - window_sum.max(axis=1, keepdims=True))
        gr_weights = shifted / shifted.mean(axis=1, keepdims=True)
        msm(f'{data_file}/biased_trajtotal', None,
            f'{data_file}/msm_biased_trajtotal_regr', lagtime=lag_time,
            reweighting_path=gr_weights,
            log_likeli_ratio=neutral_log_ratio)

        msm(f'{data_file}/biased_trajtotal', None,
            f'{output_file}/msm_biased_trajtotal_renn', lagtime=lag_time,
            reweighting_path=np.load(f'{output_file}/model_weight_lag{lag_time}.npy'),
            log_likeli_ratio=neutral_log_ratio)

        print(f'finish msm lagtime {lag_time}!')
