import os
import polars as pl
import numpy as np
import matplotlib.pyplot as plt
import mdtraj as md

from .Girsanov import GirsanovReweightingEstimator
from deeptime.markov import TransitionCountEstimator
from deeptime.markov.msm import MaximumLikelihoodMSM
from deeptime.util.validation import implied_timescales
from deeptime.clustering import BoxDiscretization
from .internal_coordinates import InternalCoordinates

def stationary(x):
    """Return normalized Boltzmann weights of an unbiased two-angle torsion potential.

    ``x[..., 0]`` and ``x[..., 1]`` are the two coordinates. The weights
    ``exp(-U)`` are normalized so that they sum to 1 over every point in ``x``.
    The result is returned as a flat 1-D array.
    """
    phi, psi = x[..., 0], x[..., 1]
    u_phi = 0.27 * np.cos(2 * phi) + 0.42 * np.cos(3 * phi)
    u_psi = (
        0.45 * np.cos(psi - np.pi)
        + 1.58 * np.cos(2 * psi - np.pi)
        + 0.44 * np.cos(3 * psi - np.pi)
    )
    boltzmann = np.exp(-u_phi - u_psi)
    normalized = boltzmann / boltzmann.sum()
    return normalized.reshape(-1)

def stationary_biased(x):
    """Return normalized Boltzmann weights of the two-angle torsion potential plus a harmonic bias.

    The torsion terms are the same as in ``stationary``. Each coordinate also
    gets an added ``0.5 * angle**2`` term. ``x[..., 0]`` and ``x[..., 1]`` are
    the two coordinates. The weights are normalized over every point in ``x``
    and returned as a flat 1-D array.

    NOTE(review): a second ``def stationary_biased`` appears later in this
    module and rebinds the name at import time. This definition is therefore
    unreachable through the module namespace; rename or remove it.
    """
    phi, psi = x[..., 0], x[..., 1]
    u_phi = 0.27 * np.cos(2 * phi) + 0.42 * np.cos(3 * phi) + 0.5 * phi ** 2
    u_psi = (
        0.45 * np.cos(psi - np.pi)
        + 1.58 * np.cos(2 * psi - np.pi)
        + 0.44 * np.cos(3 * psi - np.pi)
        + 0.5 * psi ** 2
    )
    boltzmann = np.exp(-u_phi - u_psi)
    normalized = boltzmann / boltzmann.sum()
    return normalized.reshape(-1)

def stationary_biased(x):
    """Return normalized Boltzmann weights of a biased 1-D or 2-D double-well potential.

    Parameters
    ----------
    x : array_like
        Points to evaluate. The length of the last axis selects the potential:
        1 selects the 1-D double well, 2 selects the 2-D double well, with
        ``x[..., 0]`` and ``x[..., 1]`` as the coordinates.

    Returns
    -------
    np.ndarray
        Flat 1-D array of weights ``exp(-U)``, normalized to sum to 1 over
        all points of ``x``.

    Raises
    ------
    ValueError
        If the last axis of ``x`` has a length other than 1 or 2.
    """
    x = np.asarray(x)
    dim = x.shape[-1]
    if dim == 1:
        pi = np.exp(-4. * (x**2 - 1)**2 - 4. * (x + 0.3)**2)
    elif dim == 2:
        x1, x2 = x[..., 0], x[..., 1]
        # NOTE(review): the 0.75 * ((x1+1)^2 + (x2-1)^2) term enters the
        # exponent with a positive sign (repulsive from (-1, 1)). The 1-D bias
        # term instead has a negative sign. Confirm this sign is intended.
        pi = np.exp(-(((x1**2 - 1)**2 + (x2**2 - 1)**2) + abs(x1 - x2)
                      - 0.5 * 1.5 * ((x1 + 1)**2 + (x2 - 1)**2)))
    else:
        # Without this branch `pi` would be unbound and the caller would get
        # an opaque UnboundLocalError.
        raise ValueError(
            f'stationary_biased supports 1 or 2 dimensions, got last axis of size {dim}'
        )
    pi = pi / pi.sum()
    return pi.reshape(-1)

def _msm_dir(data_dir, angles_num):
    """Return the directory that holds cached clustering/MSM files for ``angles_num`` angles."""
    return f'{data_dir}/{angles_num}angles_msm'


def _eigen_output_path(data_dir, output_dir, lagtime, weight_type, angles_num):
    """Return the ``.npz`` path where ``main`` writes results for this weight type and lag."""
    if weight_type == 'model':
        return f'{output_dir}/model_eigen_lag{lagtime}.npz'
    if weight_type == 'gr':
        return f'{_msm_dir(data_dir, angles_num)}/gr_eigen_lag{lagtime}.npz'
    return f'{_msm_dir(data_dir, angles_num)}/eigen_lag{lagtime}.npz'


def _load_angles(data_dir, angles_num, drop):
    """Load cached angles, or compute them from the trajectory and cache the result."""
    cache_path = f'{data_dir}/{angles_num}angles.npy'
    if os.path.exists(cache_path):
        return np.load(cache_path)

    pdb_path = f'{data_dir}/alanine-dipeptide.pdb'
    data = md.load_dcd(f'{data_dir}/trajectory_total.dcd', top=pdb_path)
    sample = data.xyz[None, ...]  # (1, n_frames, n_atoms, 3)
    sample = sample[:, drop:]

    ic = InternalCoordinates(pdb_path)
    # presumably get_all_angles returns (n_frames, n_angles) — TODO confirm
    angles = ic.get_all_angles(sample)[None, :, :angles_num]  # (1, n_frames, angles_num)
    np.save(cache_path, angles)
    return angles


def _load_or_cluster(data_dir, angles, angles_num, drop):
    """Return ``(assignments, cluster_centers)``, loading from the cache or box-discretizing ``angles``."""
    msm_dir = _msm_dir(data_dir, angles_num)
    assignments_path = f'{msm_dir}/{angles_num}angles_assignments.npy'
    centers_path = f'{msm_dir}/{angles_num}angles_clustercenters.npy'

    if os.path.exists(assignments_path):
        assignments = np.load(assignments_path)[:, drop:]
        # Centres are indexed (state, coordinate), so `drop` (a frame offset)
        # must not be applied to them.
        cluster_centers = np.load(centers_path)
        return assignments, cluster_centers

    n_dims = angles.shape[-1]
    flat_angles = angles.reshape(-1, n_dims)
    grid_box = BoxDiscretization(
        dim=n_dims,  # the number of dimensions the data lives in
        n_boxes=36,  # number of boxes per axis
    )
    model = grid_box.fit(flat_angles).fetch_model()
    assignments = model.transform(flat_angles).reshape(*angles.shape[:2])
    cluster_centers = model.cluster_centers

    os.makedirs(msm_dir, exist_ok=True)
    np.save(assignments_path, assignments.reshape(1, -1))
    np.save(centers_path, cluster_centers)
    return assignments, cluster_centers


def _load_reweighting_factors(data_dir, output_dir, lagtime, weight_type):
    """Build the ``(likelihood_ratio, weights)`` pair passed to the Girsanov count estimator."""
    gr_info = pl.read_csv(f'{data_dir}/gr_total.csv')
    bias_energy = gr_info['BiasEnergy_kJmol'].to_numpy().reshape([1, -1])

    if weight_type == 'gr':
        neg_log_m = -gr_info['logM'].to_numpy().reshape([1, -1])
        cum = np.cumsum(neg_log_m, axis=1)
        # Difference of cumulative sums `lagtime` frames apart, i.e. -logM
        # summed over each sliding window of length `lagtime`.
        window = cum[:, lagtime:] - cum[:, :-lagtime]
        window_weights = np.exp(window)
        # Normalize so that the weights average to 1.
        weights = window_weights / np.mean(window_weights, axis=1, keepdims=True)
    else:  # 'model'
        weights = np.load(f'{output_dir}/model_weight_lag{lagtime}.npy').reshape(1, -1)

    # The bias-energy column is replaced by ones. Only its shape (and dtype)
    # survive, so the whole reweighting comes from `weights`.
    likeli_ratio = np.ones_like(bias_energy)
    return likeli_ratio, weights


def main(data_dir, output_dir, lagtime, weight_type, angles_num=2):
    """Estimate an MSM on box-discretized angles and save its spectral data.

    Parameters
    ----------
    data_dir : str
        Directory containing ``trajectory_total.dcd``, ``alanine-dipeptide.pdb``
        and ``gr_total.csv``. Cached angle and MSM files are also written here.
    output_dir : str
        Directory containing ``model_weight_lag{lagtime}.npy``. Results for
        ``weight_type='model'`` are written here.
    lagtime : int
        MSM lag time in frames. Must be >= 1.
    weight_type : str
        The accepted values are:

        * ``'gr'``: Girsanov weights built from the ``logM`` column.
        * ``'model'``: weights loaded from ``output_dir``.
        * anything else: plain, unweighted transition counts.
    angles_num : int
        Number of leading angles used as MSM features.

    Writes an ``.npz`` file with the keys ``eigenvalues``,
    ``left_eigenvectors`` / ``right_eigenvectors`` (first 5, one per row),
    ``state_centers``, ``phi0`` (stationary distribution) and ``its``
    (timescales of the first 50 processes). For ``'model'`` and the unweighted
    case, the function returns early if that file already exists.

    Raises
    ------
    ValueError
        If ``lagtime < 1``.
    """
    if lagtime < 1:
        raise ValueError(f'lagtime must be >= 1, got {lagtime}')
    drop = 0

    out_path = _eigen_output_path(data_dir, output_dir, lagtime, weight_type, angles_num)
    # 'model' and unweighted results are cached. 'gr' results are always
    # recomputed. The check runs before any loading, to avoid wasted work.
    if weight_type != 'gr' and os.path.exists(out_path):
        return

    angles = _load_angles(data_dir, angles_num, drop)
    print(angles.shape)

    assignments, cluster_centers = _load_or_cluster(data_dir, angles, angles_num, drop)
    # NOTE(review): the MSM may be restricted to a connected subset of the
    # visited states. In that case its state order may not align with these
    # centres — confirm.
    cluster_centers_unique = cluster_centers[np.unique(assignments)]
    assignments = assignments.reshape(1, -1)

    if weight_type in ('gr', 'model'):
        reweighting_factors = _load_reweighting_factors(data_dir, output_dir, lagtime, weight_type)
        count_estimator = GirsanovReweightingEstimator(lagtime=lagtime, count_mode='sliding')
        counts = count_estimator.fit(
            data=assignments, reweighting_factors=reweighting_factors
        ).fetch_model()
    else:
        counts = TransitionCountEstimator(lagtime=lagtime, count_mode='sliding').fit_fetch(assignments)

    msm = MaximumLikelihoodMSM(transition_matrix_tolerance=1e-6).fit_fetch(counts)
    its = implied_timescales(msm)
    its_list = np.array([its.timescales_for_process(i).item() for i in range(50)])

    # deeptime returns left eigenvectors as rows but right eigenvectors as
    # columns. Take the first 5 columns and transpose them, so both arrays
    # store one eigenvector per row.
    left_eigenvectors = msm.eigenvectors_left()[:5]
    right_eigenvectors = msm.eigenvectors_right()[:, :5].T

    eigen_results = {
        'eigenvalues': msm.eigenvalues(),
        'left_eigenvectors': left_eigenvectors,
        'right_eigenvectors': right_eigenvectors,
        'state_centers': cluster_centers_unique,
        'phi0': msm.stationary_distribution,
        'its': its_list,
    }
    os.makedirs(os.path.dirname(out_path) or '.', exist_ok=True)
    np.savez(out_path, **eigen_results)


if __name__ == '__main__':
    data_dir = ''  # TODO: add data dir

    output_dir = ''  # TODO: add output dir

    # Paths are built as f'{data_dir}/...'. Empty placeholders would therefore
    # silently read from and write to the filesystem root.
    if not data_dir or not output_dir:
        raise SystemExit('Set data_dir and output_dir before running this script.')

    lagtimes = np.arange(1, 51, 1)
    for lagtime in lagtimes:
        main(data_dir, output_dir, lagtime, 'model', angles_num=2)
        print(f'eigen_lag{lagtime} done!')

