import json
import os

import numpy as np
import pyemma
import mdtraj
from tqdm import tqdm
from statsmodels.tsa.stattools import acf
import pickle
import deeptime
from scipy.stats import entropy
from scipy.spatial.distance import jensenshannon


def _compute_feature(
    traj: mdtraj.Trajectory, feature: str, global_reference: mdtraj.Trajectory
):
    """Compute a single per-frame observable as a column vector.

    Parameters
    ----------
    traj : mdtraj.Trajectory
        Trajectory to featurize.
    feature : str
        ``"gr"`` -> radius of gyration (``mdtraj.compute_rg``);
        ``"secondary"`` -> per-frame fraction of residues whose DSSP code is
        one of H/G/I/E/B; ``"rmsd"`` -> RMSD to ``global_reference``.
    global_reference : mdtraj.Trajectory
        Reference structure, only used for ``"rmsd"``.

    Returns
    -------
    np.ndarray
        Array of shape ``(n_frames, 1)``.

    Raises
    ------
    ValueError
        If ``feature`` is not one of the names listed above.
    """
    if feature == "gr":
        per_frame = mdtraj.compute_rg(traj)
    elif feature == "secondary":
        structured_codes = ["H", "G", "I", "E", "B"]
        assignments = mdtraj.compute_dssp(traj)
        n_residues = assignments.shape[1]
        per_frame = np.isin(assignments, structured_codes).sum(axis=1) / n_residues
    elif feature == "rmsd":
        per_frame = mdtraj.rmsd(traj, global_reference)
    else:
        raise ValueError(f"Unknown feature {feature!r}")
    return per_frame.reshape(-1, 1)


def _fit_kmeans_on_reference(ref_trajs, features, n_states):
    """Fit k-means only on reference trajectories; return model and discrete refs.

    Every reference trajectory is featurized with ``_compute_feature`` for each
    name in ``features``. The resulting ``(n_frames, 1)`` columns are stacked
    horizontally into one feature matrix per trajectory.

    Parameters
    ----------
    ref_trajs : sequence of mdtraj.Trajectory
        Reference trajectories. The first frame of the first trajectory is the
        structural reference used by the ``"rmsd"`` feature.
    features : sequence of str
        Feature names understood by ``_compute_feature``.
    n_states : int
        Number of k-means clusters.

    Returns
    -------
    tuple
        ``(kmeans, disc_ref, global_ref)``: the fitted deeptime KMeans
        estimator, one discrete trajectory per reference trajectory, and the
        reference frame used for RMSD.
    """
    global_ref = ref_trajs[0][0]  # first frame of first reference trajectory

    ref_feat_mats = [
        np.hstack([_compute_feature(t, f, global_ref) for f in features])
        for t in ref_trajs
    ]
    # deeptime's KMeans.fit takes a single (n_samples, n_features) array, so
    # the per-trajectory matrices are stacked before fitting (same approach as
    # get_kmeans_replicas). Each trajectory is then discretized on its own so
    # trajectory boundaries are preserved in disc_ref.
    kmeans = deeptime.clustering.KMeans(
        n_clusters=n_states, max_iter=100, fixed_seed=2137, n_jobs=1
    ).fit(np.concatenate(ref_feat_mats))
    disc_ref = [kmeans.transform(feat) for feat in ref_feat_mats]
    return kmeans, disc_ref, global_ref


def _stationary_dist(discrete_trajs, lag):
    """Estimate a reversible MSM and return its stationary distribution.

    Parameters
    ----------
    discrete_trajs : array or list of arrays
        Integer-state trajectories.
    lag : int
        Lag time (in frames) for sliding-window transition counting.

    Returns
    -------
    tuple
        ``(stationary_distribution, state_symbols)`` of the fitted deeptime MSM.
    """
    count_estimator = deeptime.markov.TransitionCountEstimator(
        lag, count_mode="sliding"
    )
    count_model = count_estimator.fit_fetch(discrete_trajs)

    msm_estimator = deeptime.markov.msm.MaximumLikelihoodMSM(
        allow_disconnected=True, reversible=True
    )
    markov_model = msm_estimator.fit_fetch(count_model)
    # NOTE(review): confirm that deeptime's MarkovStateModel exposes
    # ``state_symbols()`` as a callable; the symbols are presumably also
    # available as ``markov_model.count_model.state_symbols``.
    return markov_model.stationary_distribution, markov_model.state_symbols()


def load_tica_model(name, temp, tica_models_path):
    """Load the pickled TICA model ``{name}_{temp}_tica_model.pkl``.

    Parameters
    ----------
    name : str
        System identifier used in the file name.
    temp : int or str
        Temperature used in the file name.
    tica_models_path : str
        Directory that holds the pickled models.

    Returns
    -------
    object
        The unpickled object.

    Raises
    ------
    FileNotFoundError
        If the expected pickle file does not exist.

    Notes
    -----
    ``pickle.load`` can execute arbitrary code, so only load trusted files.
    """
    model_path = os.path.join(tica_models_path, f"{name}_{temp}_tica_model.pkl")
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"TICA model file not found: {model_path}")

    print(f"Loading saved TICA model from {model_path}")
    with open(model_path, "rb") as handle:
        return pickle.load(handle)


def plot_triangular_flux(ref_flux, gen_flux, ax, corr=None, cmap="Reds"):
    """Draw reference and generated flux matrices as two triangles of one image.

    The strict upper triangle (``i < j``) shows ``sqrt(ref_flux)`` and the
    strict lower triangle (``i > j``) shows ``sqrt(gen_flux)``. The diagonal is
    left as NaN (blank) and overlaid with a black line.

    Parameters
    ----------
    ref_flux, gen_flux : array-like, shape (n, n)
        Flux matrices. ``n`` is taken from ``ref_flux``.
    ax : matplotlib Axes
        Axes to draw into.
    corr : float, optional
        If given, shown in the title as ``Flux Spearman R=<corr>``.
    cmap : str
        Colormap passed to ``ax.imshow``.

    Returns
    -------
    The object returned by ``ax.imshow``, e.g. for attaching a colorbar.
    """
    ref_flux = np.asarray(ref_flux)
    gen_flux = np.asarray(gen_flux)
    n = ref_flux.shape[0]

    combined = np.full((n, n), np.nan, dtype=float)
    upper = np.triu_indices(n, k=1)  # entries with i < j
    lower = np.tril_indices(n, k=-1)  # entries with i > j
    # The square root compresses the dynamic range of the fluxes for display.
    # Only the entries actually shown are transformed.
    combined[upper] = np.sqrt(ref_flux[upper])
    combined[lower] = np.sqrt(gen_flux[lower])

    im = ax.imshow(combined, origin="lower", cmap=cmap, interpolation="nearest")
    ax.plot([0, n - 1], [0, n - 1], color="black", lw=1.5)

    ax.set_xticks([])
    ax.set_yticks([])

    if corr is not None:
        ax.set_title(f"Flux Spearman R={corr:.2f}")

    return im


def compute_secondary_structure_fraction(
    traj, secondary_codes=("H", "G", "I", "E", "B")
):
    """
    Compute the fraction of residues adopting a secondary structure per frame in a trajectory.

    Parameters:
        traj (md.Trajectory): MDTraj trajectory.
        secondary_codes (sequence of str): DSSP codes counted as structured.
            Any list, tuple or set of codes is accepted. The default is an
            immutable tuple, so it cannot be mutated and shared across calls.

    Returns:
        np.ndarray: Array of fractions (one per frame). Each value is the
        number of residues whose code is in ``secondary_codes`` divided by the
        number of residues DSSP reports on (``dssp.shape[1]``).

    Note:
        ``mdtraj.compute_dssp`` is called with its default (simplified) code
        set, which reports only 'H', 'E' and 'C' (plus 'NA' for non-protein
        residues, which still count in the denominator). In that mode 'G', 'I'
        and 'B' never appear; they are folded into 'H' and 'E'.
    """
    dssp = mdtraj.compute_dssp(traj)
    # np.isin needs an array-like; list() also accepts sets and generators.
    structured = np.isin(dssp, list(secondary_codes))
    return structured.sum(axis=1) / dssp.shape[1]


def compare_secondary_structure_mdcath(
    data_dir,
    pdb_dir,
    name,
    temp=320,
    random_frames=None,
    truncate=None,
    gen_replicas=None,
):
    """
    Compare secondary-structure fractions between mdCATH reference replicas
    and a sampler trajectory.

    The reference is the concatenation of the available MD replicas
    ``{data_dir}/trajectory/{name}_{temp}_{i}.xtc`` (topology
    ``{data_dir}/topology/{name}.pdb``; missing replica files are skipped).
    The sampler is ``{pdb_dir}/{name}.xtc`` (topology ``{pdb_dir}/{name}.pdb``).
    Alternatively, when ``gen_replicas`` is given, the listed MD replicas act
    as the sampler and the remaining replicas form the reference. The function
    reports the mean fractions and their absolute difference. It also reports
    forward KL and Jensen–Shannon divergences between the two fraction
    distributions (see ``compute_divergences``).

    Parameters:
        data_dir (str): mdCATH root containing ``topology/`` and ``trajectory/``.
        pdb_dir (str): Directory for sampler trajectory files.
        name (str): Identifier for the trajectories.
        temp (int): Temperature value used in file naming.
        random_frames (int): Optional number of distinct frames to sample from
            each trajectory. Sampling uses seed 42, without replacement, and
            happens after truncation. It is ignored when None or not smaller
            than the trajectory length.
        truncate (int): Optional; keep only the first ``truncate`` sampler
            frames. A falsy value (None or 0) disables truncation.
        gen_replicas (iterable of int): Optional; 1–4 distinct replica indices
            from 0–4 to use as the sampler instead of the ``pdb_dir`` files.

    Returns:
        dict: Mean secondary-structure fraction for the reference and the
        sampler, their absolute difference, and the divergence metrics.

    Raises:
        ValueError: If ``gen_replicas`` is empty, covers all five replicas, or
            contains an index outside 0–4.
        RuntimeError: If no replica file exists for a requested index set, or
            the sampler ``.xtc`` file is missing.
    """
    secondary_codes = ["H", "G", "I", "E", "B"]

    def _load_replicas(indices):
        # Load and concatenate the replicas that exist on disk.
        trajs = []
        top = f"{data_dir}/topology/{name}.pdb"
        for i in indices:
            xtc = f"{data_dir}/trajectory/{name}_{temp}_{i}.xtc"
            if os.path.exists(xtc):
                trajs.append(mdtraj.load_xtc(xtc, top=top))
        if not trajs:
            raise RuntimeError(f"No MD replicas found for {name} indices={indices}")
        return mdtraj.join(trajs) if len(trajs) > 1 else trajs[0]

    def _subsample(traj):
        # Honour random_frames the same way compute_gyration_radius does
        # (seed 42, no replacement). A private RandomState draws the same
        # frames without reseeding NumPy's global RNG.
        if random_frames is None or random_frames >= traj.n_frames:
            return traj
        rng = np.random.RandomState(42)
        frame_indices = rng.choice(traj.n_frames, size=random_frames, replace=False)
        return traj[frame_indices]

    if gen_replicas is None:
        ref_traj = _load_replicas(range(5))
        samp_xtc = f"{pdb_dir}/{name}.xtc"
        samp_pdb = f"{pdb_dir}/{name}.pdb"
        if not os.path.exists(samp_xtc):
            raise RuntimeError(f"Sampler trajectory {samp_xtc} not found.")
        sampler_traj = mdtraj.load_xtc(samp_xtc, top=samp_pdb)

    else:
        gen_set = set(gen_replicas)
        if not gen_set.issubset({0, 1, 2, 3, 4}) or len(gen_set) in (0, 5):
            raise ValueError("gen_replicas must list 1‑4 indices from 0‑4")
        sampler_traj = _load_replicas(gen_set)
        ref_traj = _load_replicas([i for i in range(5) if i not in gen_set])

    ref_traj = _subsample(ref_traj)
    ref_fractions = compute_secondary_structure_fraction(ref_traj, secondary_codes)

    if truncate:
        sampler_traj = sampler_traj[:truncate]
    sampler_traj = _subsample(sampler_traj)

    sampler_fractions = compute_secondary_structure_fraction(
        sampler_traj, secondary_codes
    )

    ref_mean = ref_fractions.mean()
    sampler_mean = sampler_fractions.mean()
    mean_difference = abs(ref_mean - sampler_mean)

    divergences = compute_divergences(ref_fractions, sampler_fractions)

    return {
        "ref_secondary_structure_fraction_mean": ref_mean,
        "sampler_secondary_structure_fraction_mean": sampler_mean,
        "mean_difference": mean_difference,
        **divergences,
    }


def compute_divergences(ref_values, sampler_values, bins=100, epsilon=1e-5):
    """
    Compare two 1-D samples through histograms on a shared binning.

    Both samples are histogrammed with ``bins`` equal-width bins spanning the
    union of their ranges. The counts are then normalized to probabilities.

    Parameters:
        ref_values (np.ndarray): Values from the reference trajectory (P).
        sampler_values (np.ndarray): Values from the sampler trajectory (Q).
        bins (int): Number of histogram bins.
        epsilon (float): Probability substituted for empty bins of Q in the
            KL computation only, so that D_KL(P||Q) stays finite.

    Returns:
        dict: ``forward_kl_divergence`` is D_KL(P||Q) from
        ``scipy.stats.entropy``, which renormalizes Q after the epsilon
        substitution. ``jensen_shannon_divergence`` is the squared
        Jensen–Shannon distance from ``scipy.spatial.distance.jensenshannon``.
        Both use natural-log units.
    """
    lo = min(ref_values.min(), sampler_values.min())
    hi = max(ref_values.max(), sampler_values.max())
    edges = np.linspace(lo, hi, bins + 1)

    def _normalized_hist(values):
        counts = np.histogram(values, bins=edges, density=False)[0]
        return counts / counts.sum()

    p = _normalized_hist(ref_values)
    q = _normalized_hist(sampler_values)
    q_floored = np.where(q == 0, epsilon, q)

    return {
        "forward_kl_divergence": entropy(p, q_floored),
        "jensen_shannon_divergence": jensenshannon(p, q) ** 2,
    }


def compute_gyration_radius(traj, random_frames=None, seed=42):
    """
    Compute the radius of gyration (Rg) for selected frames in the trajectory.

    Parameters:
        traj (md.Trajectory): Trajectory object.
        random_frames (int): Number of distinct frames to sample without
            replacement. If None, or not smaller than ``traj.n_frames``,
            every frame is used.
        seed (int): Seed for frame sampling. The default (42) selects the same
            frames as ``np.random.seed(42); np.random.choice(...)``, but
            NumPy's global RNG is left untouched.

    Returns:
        np.ndarray: Rg values (``mdtraj.compute_rg``) for the selected frames.
    """
    if random_frames is not None and random_frames < traj.n_frames:
        # A private RandomState reproduces the legacy global-seed draws
        # without reseeding the caller's global RNG as a side effect.
        rng = np.random.RandomState(seed)
        frame_indices = rng.choice(traj.n_frames, size=random_frames, replace=False)
        traj = traj[frame_indices]

    return mdtraj.compute_rg(traj)


def compare_gyration_radius(ref_traj_name, sampler_traj_name, random_frames=None):
    """
    Load two trajectories and compare their radius-of-gyration (Rg) distributions.

    Parameters:
        ref_traj_name (str): Path prefix of the reference trajectory.
            Coordinates are read from ``<prefix>.xtc`` and the topology from
            ``<prefix>.pdb``.
        sampler_traj_name (str): Path prefix of the sampler trajectory, using
            the same convention.
        random_frames (int): If given, number of frames sampled from each
            trajectory (see ``compute_gyration_radius``). None uses all frames.

    Returns:
        dict: Mean Rg of each trajectory, their absolute difference, and the
        divergence metrics returned by ``compute_divergences``.
    """
    ref_traj, sampler_traj = (
        mdtraj.load_xtc(prefix + ".xtc", top=prefix + ".pdb")
        for prefix in (ref_traj_name, sampler_traj_name)
    )

    ref_rg = compute_gyration_radius(ref_traj, random_frames=random_frames)
    sampler_rg = compute_gyration_radius(sampler_traj, random_frames=random_frames)

    ref_mean, sampler_mean = ref_rg.mean(), sampler_rg.mean()

    return {
        "ref_gyration_radius_mean": ref_mean,
        "sampler_gyration_radius_mean": sampler_mean,
        "gyration_radius_difference": abs(ref_mean - sampler_mean),
        **compute_divergences(ref_rg, sampler_rg),
    }


def compare_gyration_radius_mdcath(
    data_dir,
    pdb_dir,
    name,
    temp=320,
    random_frames=None,
    truncate=None,
    gen_replicas=None,
):
    """
    Compare radius-of-gyration (Rg) distributions for an mdCATH system.

    Two modes are supported:

    * ``gen_replicas is None``: the reference is the concatenation of the MD
      replicas ``{data_dir}/trajectory/{name}_{temp}_{i}.xtc`` (i = 0..4,
      missing files skipped, topology ``{data_dir}/topology/{name}.pdb``). The
      sampler is ``{pdb_dir}/{name}.xtc`` with topology ``{pdb_dir}/{name}.pdb``.
    * ``gen_replicas`` given: those replica indices are used as the sampler
      and the remaining replicas form the reference.

    Parameters:
        data_dir (str): mdCATH root containing ``topology/`` and ``trajectory/``.
        pdb_dir (str): Directory with the sampler ``.xtc``/``.pdb`` files.
        name (str): System identifier used in all file names.
        temp (int): Temperature used in the replica file names.
        random_frames (int): Optional number of frames sampled from each
            trajectory (see ``compute_gyration_radius``).
        truncate (int): If not None, keep only the first ``truncate`` sampler frames.
        gen_replicas (iterable of int): Optional replica indices (1–4 distinct
            values from 0–4) to treat as the sampler.

    Returns:
        dict: Mean Rg of both trajectories, their absolute difference, and the
        divergence metrics returned by ``compute_divergences``.

    Raises:
        ValueError: If ``gen_replicas`` is empty, covers all five replicas, or
            contains an index outside 0–4.
        RuntimeError: If no replica file exists for a requested index set, or
            the sampler ``.xtc`` file is missing.
    """
    topology = f"{data_dir}/topology/{name}.pdb"

    def _load_replicas(indices):
        # Concatenate whichever requested replicas exist on disk.
        candidate_paths = (
            f"{data_dir}/trajectory/{name}_{temp}_{i}.xtc" for i in indices
        )
        found = [
            mdtraj.load_xtc(path, top=topology)
            for path in candidate_paths
            if os.path.exists(path)
        ]
        if not found:
            raise RuntimeError(f"No MD replicas found for {name} indices={indices}")
        return found[0] if len(found) == 1 else mdtraj.join(found)

    if gen_replicas is not None:
        gen_set = set(gen_replicas)
        if len(gen_set) in (0, 5) or not gen_set.issubset({0, 1, 2, 3, 4}):
            raise ValueError("gen_replicas must list 1‑4 indices from 0‑4")
        sampler_traj = _load_replicas(gen_set)
        ref_traj = _load_replicas([i for i in range(5) if i not in gen_set])
    else:
        ref_traj = _load_replicas(range(5))
        sampler_xtc = f"{pdb_dir}/{name}.xtc"
        if not os.path.exists(sampler_xtc):
            raise RuntimeError(f"Sampler trajectory {sampler_xtc} not found.")
        sampler_traj = mdtraj.load_xtc(sampler_xtc, top=f"{pdb_dir}/{name}.pdb")

    if truncate is not None:
        sampler_traj = sampler_traj[:truncate]

    ref_rg, sampler_rg = (
        compute_gyration_radius(t, random_frames=random_frames)
        for t in (ref_traj, sampler_traj)
    )
    ref_mean, sampler_mean = ref_rg.mean(), sampler_rg.mean()

    return {
        "ref_gyration_radius_mean": ref_mean,
        "sampler_gyration_radius_mean": sampler_mean,
        "gyration_radius_difference": abs(ref_mean - sampler_mean),
        **compute_divergences(ref_rg, sampler_rg),
    }


def effective_sample_size(data, max_lag=1000):
    """Estimate the effective sample size of a 1-D time series.

    Computes ``N / (1 + 2 * sum_k rho_k)``. Here ``rho_k`` are the sample
    autocorrelations for every lag >= 1 in the array returned by
    ``statsmodels.tsa.stattools.acf``, which is FFT-based with NaNs dropped
    and ``nlags=max_lag``. All returned lags are summed; there is no cut-off
    at the first negative autocorrelation.

    Parameters
    ----------
    data : array-like
        One-dimensional series.
    max_lag : int
        Passed to ``acf`` as ``nlags``.

    Returns
    -------
    float
        The estimated effective sample size. ``float(len(data))`` is returned
        when the denominator is not positive.
    """
    autocorr = acf(data, nlags=max_lag, fft=True, missing="drop")
    tau = 1.0 + 2.0 * autocorr[1:].sum()
    n_samples = len(data)
    if tau <= 0:
        return float(n_samples)
    return n_samples / tau


def get_featurized_traj(name, sidechains=False, cossin=True):
    """Featurize ``{name}.xtc`` with backbone (optionally side-chain) torsions.

    Parameters
    ----------
    name : str
        Path prefix. The topology is read from ``{name}.pdb`` and the
        coordinates from ``{name}.xtc``.
    sidechains : bool
        Also add side-chain torsions.
    cossin : bool
        Forwarded to PyEMMA: encode each torsion as a sin/cos pair.

    Returns
    -------
    tuple
        ``(featurizer, features)``, where ``features`` is the result of
        ``pyemma.coordinates.load``.
    """
    featurizer = pyemma.coordinates.featurizer(name + ".pdb")
    featurizer.add_backbone_torsions(cossin=cossin)
    if sidechains:
        featurizer.add_sidechain_torsions(cossin=cossin)
    features = pyemma.coordinates.load(name + ".xtc", features=featurizer)
    return featurizer, features


def get_featurized_traj_replicas(
    data_dir, name, temp=320, sidechains=False, cossin=True, replica_indices=None
):
    """Featurize several mdCATH replicas with backbone (optionally side-chain) torsions.

    Parameters
    ----------
    data_dir : str
        mdCATH root with ``topology/{name}.pdb`` and
        ``trajectory/{name}_{temp}_{i}.xtc``.
    name : str
        System identifier.
    temp : int
        Temperature used in the trajectory file names.
    sidechains : bool
        Also add side-chain torsions.
    cossin : bool
        Forwarded to PyEMMA: encode each torsion as a sin/cos pair.
    replica_indices : iterable of int, optional
        Replicas to load. Defaults to ``range(5)``.

    Returns
    -------
    tuple
        ``(featurizer, [features per replica])``.
    """
    indices = range(5) if replica_indices is None else replica_indices

    featurizer = pyemma.coordinates.featurizer(f"{data_dir}/topology/{name}.pdb")
    featurizer.add_backbone_torsions(cossin=cossin)
    if sidechains:
        featurizer.add_sidechain_torsions(cossin=cossin)

    replica_features = [
        pyemma.coordinates.load(
            f"{data_dir}/trajectory/{name}_{temp}_{i}.xtc", features=featurizer
        )
        for i in indices
    ]
    return featurizer, replica_features


def get_heavy_atom_coordinates(name):
    """Load coordinates of every non-hydrogen atom of ``{name}.xtc``.

    Parameters
    ----------
    name : str
        Path prefix. The topology comes from ``{name}.pdb`` and the
        coordinates from ``{name}.xtc``.

    Returns
    -------
    tuple
        ``(featurizer, data)``, where ``data`` is the result of
        ``pyemma.coordinates.load`` with an atom-selection feature.
    """
    featurizer = pyemma.coordinates.featurizer(name + ".pdb")

    # Keep every atom whose element symbol is not "H".
    heavy_atoms = []
    for index, atom in enumerate(featurizer.topology.atoms):
        if atom.element.symbol != "H":
            heavy_atoms.append(index)
    featurizer.add_selection(heavy_atoms)

    coordinates = pyemma.coordinates.load(name + ".xtc", features=featurizer)
    return featurizer, coordinates


def get_heavy_atom_coordinates_mdcath(data_dir, name, temp=320, replica_indices=None):
    """Load non-hydrogen atom coordinates for mdCATH replicas of ``name``.

    Parameters
    ----------
    data_dir : str
        mdCATH root with ``topology/{name}.pdb`` and
        ``trajectory/{name}_{temp}_{i}.xtc``.
    name : str
        System identifier.
    temp : int
        Temperature used in the trajectory file names.
    replica_indices : iterable of int, optional
        Replicas to load. Defaults to ``range(5)``, the previously hard-coded
        set, consistent with ``get_featurized_traj_replicas``.

    Returns
    -------
    tuple
        ``(featurizer, [data per replica])``.
    """
    if replica_indices is None:
        replica_indices = range(5)

    feat = pyemma.coordinates.featurizer(f"{data_dir}/topology/{name}.pdb")

    # Keep every atom whose element symbol is not "H".
    heavy_atom_indices = [
        i for i, atom in enumerate(feat.topology.atoms) if atom.element.symbol != "H"
    ]
    feat.add_selection(heavy_atom_indices)

    traj_list = [
        pyemma.coordinates.load(
            f"{data_dir}/trajectory/{name}_{temp}_{i}.xtc", features=feat
        )
        for i in replica_indices
    ]
    return feat, traj_list


def get_featurized_traj_mdcath(data_dir, name, temp=320, half_ca_threshold=400, step=2):
    """Featurize the five mdCATH replicas of ``name`` with Cα–Cα distances.

    If the topology has more than ``half_ca_threshold`` Cα atoms, only every
    ``step``-th Cα is kept and passed to ``add_distances``. Otherwise
    ``add_distances_ca`` is used.

    NOTE(review): the two branches use different PyEMMA calls and may not
    include the same set of atom pairs (e.g. neighbour exclusion). Confirm
    this is intended.

    Returns
    -------
    tuple
        ``(featurizer, [features for replicas 0..4])``.
    """
    featurizer = pyemma.coordinates.featurizer(f"{data_dir}/topology/{name}.pdb")

    ca_atoms = featurizer.select_Ca()
    if len(ca_atoms) <= half_ca_threshold:
        featurizer.add_distances_ca()
    else:
        featurizer.add_distances(indices=ca_atoms[::step])

    replicas = [
        pyemma.coordinates.load(
            f"{data_dir}/trajectory/{name}_{temp}_{replica}.xtc", features=featurizer
        )
        for replica in range(5)
    ]
    return featurizer, replicas


def get_gr_traj_mdcath(data_dir, name, temp=320, half_ca_threshold=400, step=2):
    """Featurize the five mdCATH replicas of ``name`` with Cα–Cα distances.

    NOTE(review): despite the ``gr`` in its name, this builds exactly the same
    Cα-distance features as ``get_featurized_traj_mdcath``. No
    radius-of-gyration feature is added. Confirm whether that is intended.

    If there are more than ``half_ca_threshold`` Cα atoms, every ``step``-th
    Cα is passed to ``add_distances``. Otherwise ``add_distances_ca`` is used.

    Returns
    -------
    tuple
        ``(featurizer, [features for replicas 0..4])``.
    """
    topology_path = f"{data_dir}/topology/{name}.pdb"
    featurizer = pyemma.coordinates.featurizer(topology_path)

    ca_atoms = featurizer.select_Ca()
    use_subset = len(ca_atoms) > half_ca_threshold
    if use_subset:
        featurizer.add_distances(indices=ca_atoms[::step])
    else:
        featurizer.add_distances_ca()

    replica_trajs = []
    for replica in range(5):
        xtc_path = f"{data_dir}/trajectory/{name}_{temp}_{replica}.xtc"
        replica_trajs.append(pyemma.coordinates.load(xtc_path, features=featurizer))
    return featurizer, replica_trajs


def get_featurized_traj_mdcath_sampled(pdbdir, name, half_ca_threshold=400, step=2):
    """Featurize a sampler trajectory ``{pdbdir}/{name}.xtc`` with Cα–Cα distances.

    Uses the same scheme as ``get_featurized_traj_mdcath``. If there are more
    than ``half_ca_threshold`` Cα atoms, only every ``step``-th Cα is passed
    to ``add_distances``; otherwise ``add_distances_ca`` is used. ``step`` was
    previously hard-coded to 2, which remains the default. Pass the same value
    used for the reference replicas so the feature spaces match.

    Returns
    -------
    tuple
        ``(featurizer, features)``.
    """
    feat = pyemma.coordinates.featurizer(f"{pdbdir}/{name}.pdb")

    ca_indices = feat.select_Ca()
    if len(ca_indices) > half_ca_threshold:
        ca_indices = ca_indices[::step]
        feat.add_distances(indices=ca_indices)
    else:
        feat.add_distances_ca()
    traj = pyemma.coordinates.load(f"{pdbdir}/{name}.xtc", features=feat)
    return feat, traj


def get_featurized_atlas_traj(name, sidechains=False, cossin=True):
    """Featurize ``{name}_prod_R1_fit.xtc`` with backbone (optionally side-chain) torsions.

    Parameters
    ----------
    name : str
        Path prefix. The topology is read from ``{name}.pdb``.
    sidechains : bool
        Also add side-chain torsions.
    cossin : bool
        Forwarded to PyEMMA: encode each torsion as a sin/cos pair.

    Returns
    -------
    tuple
        ``(featurizer, features)``.
    """
    featurizer = pyemma.coordinates.featurizer(name + ".pdb")
    featurizer.add_backbone_torsions(cossin=cossin)
    if sidechains:
        featurizer.add_sidechain_torsions(cossin=cossin)
    xtc_path = name + "_prod_R1_fit.xtc"
    return featurizer, pyemma.coordinates.load(xtc_path, features=featurizer)


def get_tica(traj_list, lag=1000, var_cutoff=0.95):
    """Fit a kinetic-map TICA model with PyEMMA and project the input onto it.

    Parameters
    ----------
    traj_list : array or list of arrays
        Featurized trajectories.
    lag : int
        TICA lag time in frames.
    var_cutoff : float
        Kinetic variance to retain; forwarded to ``pyemma.coordinates.tica``.

    Returns
    -------
    tuple
        ``(tica_model, projected)`` where ``projected = tica_model.transform(traj_list)``.
    """
    tica_model = pyemma.coordinates.tica(
        traj_list, lag=lag, kinetic_map=True, var_cutoff=var_cutoff
    )
    projected = tica_model.transform(traj_list)
    return tica_model, projected


def get_kmeans(traj):
    """Cluster ``traj`` into 100 k-means states with PyEMMA (fixed seed 137).

    Returns
    -------
    tuple
        ``(clustering, dtrajs)``: the PyEMMA clustering object and its
        ``dtrajs`` attribute.
    """
    clustering = pyemma.coordinates.cluster_kmeans(
        traj, k=100, max_iter=100, fixed_seed=137
    )
    dtrajs = clustering.dtrajs
    return clustering, dtrajs


def get_kmeans_replicas(traj):
    """Cluster a list of trajectories into 20 k-means states with deeptime.

    The k-means model is fit on the concatenation of all trajectories. Each
    trajectory is then discretized separately.

    Parameters
    ----------
    traj : list of np.ndarray
        Per-replica feature arrays (e.g. TICA projections).

    Returns
    -------
    tuple
        ``(kmeans, [discrete trajectory per input trajectory])``.
    """
    stacked = np.concatenate(traj)
    estimator = deeptime.clustering.KMeans(
        n_clusters=20, max_iter=100, fixed_seed=2137
    )
    kmeans = estimator.fit(stacked)
    dtrajs = list(map(kmeans.transform, traj))
    return kmeans, dtrajs


def get_msm(traj, lag=1000, nstates=10):
    """Estimate a fine MSM, coarse-grain it with PCCA+, and fit a coarse MSM.

    Parameters
    ----------
    traj : np.ndarray or list of np.ndarray
        Discrete trajectory or trajectories. A list such as ``kmeans.dtrajs``
        from ``get_kmeans`` is supported.
    lag : int
        MSM lag time in frames.
    nstates : int
        Number of PCCA+ metastable sets.

    Returns
    -------
    tuple
        ``(msm, pcca, cmsm)``: the fine MSM, the PCCA+ object, and the MSM
        estimated on the metastable-state trajectories.
    """
    msm = pyemma.msm.estimate_markov_model(traj, lag=lag)

    pcca = msm.pcca(nstates)
    # Sanity check tied to the 100 clusters produced by get_kmeans.
    # NOTE(review): with PyEMMA's default connectivity this only holds if all
    # 100 clusters fall in the active set.
    assert len(msm.metastable_assignments) == 100
    assignments = msm.metastable_assignments
    if isinstance(traj, (list, tuple)):
        # Map each discrete trajectory separately. Fancy-indexing with the
        # whole list would treat it as one (possibly ragged) index array.
        coarse = [assignments[t] for t in traj]
    else:
        coarse = assignments[traj]
    cmsm = pyemma.msm.estimate_markov_model(coarse, lag=lag)
    return msm, pcca, cmsm


def get_msm_mdcath(traj, lag=1000, nstates=10):
    """Fine MSM -> PCCA+ -> coarse MSM for a list of discrete trajectories.

    Both MSMs are estimated with ``connectivity="all"``.

    Parameters
    ----------
    traj : list of np.ndarray
        Discrete trajectories (100 cluster states are asserted).
    lag : int
        MSM lag time in frames.
    nstates : int
        Number of PCCA+ metastable sets.

    Returns
    -------
    tuple
        ``(msm, pcca, cmsm)``.
    """
    fine_msm = pyemma.msm.estimate_markov_model(traj, lag=lag, connectivity="all")
    pcca = fine_msm.pcca(nstates)
    assert len(fine_msm.metastable_assignments) == 100
    membership = fine_msm.metastable_assignments
    coarse_dtrajs = [membership[dtraj] for dtraj in traj]
    coarse_msm = pyemma.msm.estimate_markov_model(
        coarse_dtrajs, lag=lag, connectivity="all"
    )
    return fine_msm, pcca, coarse_msm


def get_msm_mdcath_deeptime(trajs, lag=1, nstates=5):
    """Fine MSM -> PCCA+ -> coarse MSM using deeptime.

    Both MSMs are reversible maximum-likelihood estimates on sliding-window
    counts with ``allow_disconnected=True``.

    Parameters
    ----------
    trajs : list of np.ndarray
        Discrete (cluster-index) trajectories.
    lag : int
        Lag time in frames.
    nstates : int
        Number of PCCA+ metastable sets.

    Returns
    -------
    tuple
        ``(metastable_assignments, transition_matrix, pi)``. The first maps
        each fine state to its metastable set; the other two belong to the
        coarse MSM.
    """

    def _fit_msm(dtrajs):
        count_model = deeptime.markov.TransitionCountEstimator(
            lag, "sliding"
        ).fit_fetch(dtrajs)
        return deeptime.markov.msm.MaximumLikelihoodMSM(
            allow_disconnected=True, reversible=True
        ).fit_fetch(count_model)

    fine_msm = _fit_msm(trajs)
    metastable_assignments = fine_msm.pcca(nstates).assignments
    coarse_dtrajs = [metastable_assignments[dtraj] for dtraj in trajs]

    coarse_msm = _fit_msm(coarse_dtrajs)
    return (
        metastable_assignments,
        coarse_msm.transition_matrix,
        coarse_msm.stationary_distribution,
    )


def discretize_deeptime(traj, kmeans, metastable_assignments):
    """Map each trajectory to metastable states via k-means cluster indices.

    Parameters
    ----------
    traj : iterable of np.ndarray
        Per-trajectory feature arrays accepted by ``kmeans.transform``.
    kmeans : object
        Anything with a ``transform`` returning integer cluster indices.
    metastable_assignments : np.ndarray
        Lookup table from cluster index to metastable state.

    Returns
    -------
    list of np.ndarray
        One metastable-state trajectory per input trajectory.
    """
    cluster_trajs = list(map(kmeans.transform, traj))
    return [metastable_assignments[ids] for ids in cluster_trajs]


def discretize(traj, kmeans, msm):
    """Map frames to metastable states using the first column of ``kmeans.transform``.

    Parameters
    ----------
    traj : array-like
        Features accepted by ``kmeans.transform``.
    kmeans : object
        Its ``transform`` must return a 2-D array; column 0 holds cluster indices.
    msm : object
        Provides ``metastable_assignments`` (cluster index -> metastable state).

    Returns
    -------
    np.ndarray
        Metastable state per frame.
    """
    cluster_ids = kmeans.transform(traj)[:, 0]
    return msm.metastable_assignments[cluster_ids]


def load_tps_ensemble(name, directory):
    """Load and featurize every trajectory listed in a TPS ensemble.

    ``{directory}/{name}_metadata.json`` is read only to determine how many
    trajectories exist. Trajectory ``i`` is featurized from the path prefix
    ``{directory}/{name}_{i}`` via ``get_featurized_traj`` with side-chain
    torsions enabled.

    Returns
    -------
    tuple
        ``(all_feats, all_traj)``: lists of featurizers and featurized data,
        in metadata order.
    """
    metadata_path = os.path.join(directory, f"{name}_metadata.json")
    # A context manager closes the handle promptly instead of leaking it.
    with open(metadata_path, "rb") as fh:
        metadata = json.load(fh)

    all_feats = []
    all_traj = []
    # Iterating over a range gives tqdm a known total for the progress bar.
    for i in tqdm(range(len(metadata))):
        feats, traj = get_featurized_traj(f"{directory}/{name}_{i}", sidechains=True)
        all_feats.append(feats)
        all_traj.append(traj)
    return all_feats, all_traj


def sample_tp(trans, start_state, end_state, traj_len, n_samples):
    """Sample discrete paths of a Markov chain conditioned on both endpoints.

    Each path has ``traj_len`` entries, starting at ``start_state`` and ending
    at ``end_state``. With ``T = trans`` and ``N = traj_len``, every
    intermediate state ``t = 1 .. N-2`` is drawn from the bridge distribution

        P(s_t = j | s_{t-1}) = T[s_{t-1}, j] * (T^{N-1-t})[j, end]
                               / (T^{N-t})[s_{t-1}, end].

    Parameters
    ----------
    trans : array-like, shape (n, n)
        Transition matrix with non-negative entries.
    start_state, end_state : int
        First and last state of every path.
    traj_len : int
        Path length including both endpoints.
    n_samples : int
        Number of independent paths.

    Returns
    -------
    np.ndarray of int, shape (n_samples, max(traj_len, 2))
        One sampled path per row.

    Notes
    -----
    Draws use NumPy's global RNG via ``np.random.choice``. If ``end_state``
    cannot be reached in the required number of steps, the probabilities are
    NaN and ``np.random.choice`` raises.
    """
    trans = np.asarray(trans, dtype=float)
    n_states = trans.shape[0]
    N = traj_len

    # reach[k] = (T^k)[:, end_state], built with one mat-vec per step instead
    # of two full matrix_power calls per sampled time step.
    reach = [np.zeros(n_states)]
    reach[0][end_state] = 1.0
    for _ in range(1, N):
        reach.append(trans @ reach[-1])

    choices = np.arange(n_states)
    s_t = np.full(n_samples, start_state, dtype=int)
    states = [s_t]
    for t in range(1, N - 1):
        numerator = reach[N - t - 1] * trans[s_t, :]
        probs = numerator / reach[N - t][s_t][:, None]
        # Scalar draws (no size argument) consume the RNG exactly like
        # size=1, but avoid storing a 1-element array into a scalar slot.
        s_t = np.array(
            [np.random.choice(choices, p=row) for row in probs], dtype=int
        )
        states.append(s_t)
    states.append(np.full(n_samples, end_state, dtype=int))
    return np.stack(states, axis=1)


def get_tp_likelihood(tp, trans):
    """Per-step bridge probabilities of given paths under ``trans``.

    For path ``k`` and step ``i`` the entry is the probability, under the
    endpoint-conditioned process used by ``sample_tp``, of moving from
    ``tp[k, i]`` to ``tp[k, i + 1]``:

        T[s_i, s_{i+1}] * (T^{N-2-i})[s_{i+1}, end] / (T^{N-1-i})[s_i, end]

    Parameters
    ----------
    tp : np.ndarray of int, shape (n_samples, N)
        Paths. The end state is taken from ``tp[0, -1]``, so all paths are
        assumed to share it.
    trans : array-like, shape (n, n)
        Transition matrix.

    Returns
    -------
    np.ndarray, shape (n_samples, N - 1)
        Step probabilities. NaNs (from 0/0 when the end state is unreachable)
        are replaced by 0.
    """
    trans = np.asarray(trans, dtype=float)
    tp = np.asarray(tp)
    n_samples, N = tp.shape
    s_N = tp[0, -1]

    # reach[k] = (T^k)[:, s_N], built with one mat-vec per step instead of
    # two full matrix_power calls per path step.
    reach = [np.zeros(trans.shape[0])]
    reach[0][s_N] = 1.0
    for _ in range(1, N):
        reach.append(trans @ reach[-1])

    rows = np.arange(n_samples)
    trans_probs = []
    for i in range(N - 1):
        t = i + 1
        s_t = tp[:, i]
        numerator = reach[N - t - 1] * trans[s_t, :]
        probs = numerator / reach[N - t][s_t][:, None]
        trans_probs.append(probs[rows, tp[:, i + 1]])
    probs = np.stack(trans_probs, axis=1)
    probs[np.isnan(probs)] = 0
    return probs


def get_state_probs(tp, num_states=10):
    """Empirical frequency of each state over all entries of ``tp``.

    Parameters
    ----------
    tp : np.ndarray of non-negative int
        Paths of any shape; all entries are pooled.
    num_states : int
        Minimum length of the result. It is longer if ``tp`` contains larger
        state indices.

    Returns
    -------
    np.ndarray
        Normalized counts that sum to 1.
    """
    flat = tp.reshape(-1)
    counts = np.bincount(flat, minlength=num_states)
    total = counts.sum()
    return counts / total
