import argparse

# Command-line options. ``args`` is read as a module-level global by the
# analysis functions below.
parser = argparse.ArgumentParser(
    description="Compare generated peptide trajectories against MD references."
)
parser.add_argument(
    "--mddir",
    type=str,
    default="anonymise",
    help="root of the MD reference trajectories, read as {mddir}/{name}/{name}",
)
parser.add_argument(
    "--pdbdir",
    type=str,
    required=True,
    help="directory holding the evaluated trajectories ({name}.pdb / {name}.xtc); "
    "figures and the results pickle are written here too",
)
parser.add_argument(
    "--seed", type=int, default=42, help="seed passed to set_seed for each peptide"
)
parser.add_argument(
    "--save",
    action="store_true",
    help="pickle the per-peptide results to {pdbdir}/{save_name}",
)
parser.add_argument(
    "--plot",
    action="store_true",
    help="draw the optional diagnostic panels and save the figure as a PDF in --pdbdir",
)
parser.add_argument(
    "--save_name",
    type=str,
    default="out.pkl",
    help="file name of the results pickle inside --pdbdir",
)
parser.add_argument(
    "--pdb_id",
    nargs="*",
    default=[],
    help="peptide names to evaluate; by default every file in --pdbdir whose name "
    "contains '.pdb' but not '_traj'. Names without a matching .xtc are skipped",
)
parser.add_argument(
    "--no_msm", action="store_true", help="skip the Markov state model analysis"
)
parser.add_argument(
    "--no_decorr",
    action="store_true",
    help="skip the autocorrelation (decorrelation) analysis",
)
parser.add_argument(
    "--no_traj_msm",
    action="store_true",
    help="do not fit an MSM to the evaluated trajectory (also skips the flux comparison)",
)
parser.add_argument(
    "--truncate",
    type=int,
    default=None,
    help="keep only the first N frames of the evaluated trajectory "
    "(the MD reference is not truncated)",
)
# "--separation_steps" is a correctly spelled alias; both spellings store
# into args.seperation_steps so existing callers keep working.
parser.add_argument(
    "--seperation_steps",
    "--separation_steps",
    dest="seperation_steps",
    type=int,
    default=None,
    help="keep every k-th frame of the evaluated trajectory",
)
parser.add_argument(
    "--msm_lag",
    type=int,
    default=10,
    help="MSM lag for the evaluated trajectory; the reference MSM uses 100 times this lag",
)
parser.add_argument(
    "--ito", action="store_true", help="not read by the current analysis code"
)
parser.add_argument(
    "--n_lag_traj",
    type=int,
    default=1000,
    help="maximum lag for autocovariances of the evaluated trajectory",
)
parser.add_argument(
    "--num_workers",
    type=int,
    default=1,
    help="number of worker processes (1 runs serially)",
)
parser.add_argument(
    "--overfit_peptide",
    type=str,
    default=None,
    help="comma-separated peptide names that replace the selected list",
)
parser.add_argument(
    "--numstates_tica",
    type=int,
    default=10,
    help="number of metastable states used in the MSM analysis",
)
parser.add_argument(
    "--md_as_traj",
    action="store_true",
    help="evaluate the MD reference against itself (load the trajectory from --mddir)",
)
parser.add_argument(
    "--pdf_name",
    type=str,
    default=None,
    help="suffix of the saved figure name; defaults to the last path component of --pdbdir",
)

args = parser.parse_args()

import mdgen.analysis
import pyemma, tqdm, os, pickle
from scipy.spatial.distance import jensenshannon
from multiprocessing import Pool
import numpy as np
import matplotlib.pyplot as plt
from statsmodels.tsa.stattools import acovf, acf
from scipy.stats import spearmanr
from scipy.stats import entropy

# Probability floor substituted for empty bins of the evaluated-trajectory
# histogram before the forward KL divergence is computed.
EPS = 1.0e-5

# Matplotlib's default colour cycle; curves are coloured by feature index.
colors = list(plt.rcParams["axes.prop_cycle"].by_key()["color"])


def plot_triangular_flux(ref_flux, gen_flux, ax, corr=None, cmap="Reds"):
    """Draw reference and generated fluxes as the two triangles of one image.

    With ``size = ref_flux.shape[0]``, cell ``(row, col)`` of a ``size x size``
    image shows:

    * ``sqrt(ref_flux[row, col])`` when ``row < col``;
    * ``sqrt(gen_flux[row, col])`` when ``row > col``;
    * NaN (left blank) on the diagonal.

    A black line is drawn along the diagonal and axis ticks are removed.

    Args:
        ref_flux: reference flux matrix.
        gen_flux: generated flux matrix.
        ax: axes object to draw on.
        corr: optional value shown in the title as ``Flux Spearman R=...``.
        cmap: colormap passed to ``ax.imshow``.

    Returns:
        Whatever ``ax.imshow`` returns (used by the caller for a colorbar).
    """
    size = ref_flux.shape[0]
    row, col = np.indices((size, size))
    upper = row < col
    lower = row > col

    # Only the leading size x size block of each input is ever read.
    combined = np.full((size, size), np.nan, dtype=float)
    combined[upper] = np.sqrt(ref_flux[:size, :size])[upper]
    combined[lower] = np.sqrt(gen_flux[:size, :size])[lower]

    image = ax.imshow(combined, origin="lower", cmap=cmap, interpolation="nearest")
    ax.plot([0, size - 1], [0, size - 1], color="black", lw=1.5)
    ax.set_xticks([])
    ax.set_yticks([])

    if corr is not None:
        ax.set_title(f"Flux Spearman R={corr:.2f}")
    return image


def _load_traj_and_ref(name, sidechains, cossin):
    """Featurize the evaluated trajectory and the MD reference for ``name``.

    Where the evaluated trajectory comes from:

    * ``{pdbdir}/{name}`` by default;
    * ``{mddir}/{name}/{name}`` when ``--md_as_traj`` is set.

    It is then subsampled every ``--seperation_steps`` frames and cut to the
    first ``--truncate`` frames when those options are given. The reference at
    ``{mddir}/{name}/{name}`` is returned without subsampling or truncation.

    Returns:
        ``(feats, traj, ref)``. ``feats`` is the featurizer returned for the
        reference. ``traj`` and ``ref`` are the featurized arrays, assumed to
        be indexed ``[frame, feature]`` (that is how they are used below).
    """
    if args.md_as_traj:
        traj_path = f"{args.mddir}/{name}/{name}"
    else:
        traj_path = f"{args.pdbdir}/{name}"
    _, traj = mdgen.analysis.get_featurized_traj(
        traj_path, sidechains=sidechains, cossin=cossin
    )
    if args.seperation_steps:
        traj = traj[:: args.seperation_steps]
    if args.truncate:
        traj = traj[: args.truncate]
    feats, ref = mdgen.analysis.get_featurized_traj(
        f"{args.mddir}/{name}/{name}", sidechains=sidechains, cossin=cossin
    )
    return feats, traj, ref


def _divergences(ref_counts, traj_counts):
    """Return ``(jsd, fwd_kl)`` comparing two histograms of the same shape.

    Both histograms are flattened first, so 1-D and 2-D counts are handled
    alike and both results are scalars.

    * ``jsd`` is the Jensen-Shannon distance.
    * ``fwd_kl`` is KL(P_ref || Q_traj). Empty trajectory bins are floored to
      ``EPS`` and the distribution is renormalised, so the value stays finite.
    """
    ref_counts = np.asarray(ref_counts, dtype=float).ravel()
    traj_counts = np.asarray(traj_counts, dtype=float).ravel()
    jsd = jensenshannon(ref_counts, traj_counts)
    ref_probs = ref_counts / ref_counts.sum()  # P
    traj_probs = traj_counts / traj_counts.sum()  # Q
    traj_probs = np.where(traj_probs == 0, EPS, traj_probs)
    traj_probs = traj_probs / traj_probs.sum()
    return jsd, entropy(ref_probs, traj_probs)


def _angle_autocorr(angles, nlag):
    """Autocovariance of the (sin, cos) embedding of a 1-D series of angles.

    Returns ``(autocorr, baseline)``:

    * ``autocorr`` is the sum of the non-demeaned, bias-adjusted
      autocovariances of ``sin(angles)`` and ``cos(angles)`` up to ``nlag``.
    * ``baseline`` is ``mean(sin)**2 + mean(cos)**2``.

    Since sin**2 + cos**2 == 1, the lag-0 value is 1 up to rounding. The
    normalised curve ``(autocorr - baseline) / (1 - baseline)`` therefore
    starts at 1.
    """
    sin_x = np.sin(angles)
    cos_x = np.cos(angles)
    autocorr = acovf(sin_x, demean=False, adjusted=True, nlag=nlag) + acovf(
        cos_x, demean=False, adjusted=True, nlag=nlag
    )
    # E[(X(t) - E[X]) (X(t+dt) - E[X])] = E[X(t) X(t+dt)] - E[X]**2
    baseline = sin_x.mean() ** 2 + cos_x.mean() ** 2
    return autocorr, baseline


def _plot_torsion_decorrelation(feat_names, angles, nlag, axs_row, store):
    """Plot and store normalised decorrelation curves for every torsion.

    Backbone features (names containing "PHI" or "PSI") are plotted on
    ``axs_row[1]`` and all other features on ``axs_row[2]``. ``store[feat]``
    receives the normalised curve, computed from a float16 copy of the
    autocovariance.
    """
    for i, feat in enumerate(feat_names):
        autocorr, baseline = _angle_autocorr(angles[:, i], nlag)
        lags = 1 + np.arange(len(autocorr))
        is_backbone = "PHI" in feat or "PSI" in feat
        ax = axs_row[1] if is_backbone else axs_row[2]
        ax.plot(
            lags,
            (autocorr - baseline) / (1 - baseline),
            color=colors[i % len(colors)],
        )
        store[feat] = (autocorr.astype(np.float16) - baseline) / (1 - baseline)

    axs_row[1].set_title("Backbone decorrelation")
    axs_row[2].set_title("Sidechain decorrelation")
    axs_row[1].set_xscale("log")
    axs_row[2].set_xscale("log")


def _dense_transition_matrix(model, nstates):
    """Embed ``model.transition_matrix`` into an ``nstates x nstates`` matrix.

    The estimated matrix only covers ``model.active_set``. Each entry is
    scattered to its global state indices. States outside the active set
    keep the identity row (a self-loop).
    """
    matrix = np.eye(nstates)
    for a, i in enumerate(model.active_set):
        for b, j in enumerate(model.active_set):
            matrix[i, j] = model.transition_matrix[a, b]
    return matrix


def _dense_stationary_distribution(model, nstates):
    """Scatter ``model.pi`` into a length-``nstates`` vector.

    ``model.pi`` only covers ``model.active_set``; states outside it get
    probability 0.
    """
    pi = np.zeros(nstates)
    pi[model.active_set] = model.pi
    return pi


def _msm_analysis(name, ref_tica, traj_tica, out, fig, axs):
    """Fit the reference MSM and compare metastable populations and fluxes.

    Results are written into ``out``. Any exception raised after k-means
    clustering is reported (message plus traceback) and otherwise ignored,
    so a single peptide cannot abort a batch run.
    """
    import traceback

    nstates = args.numstates_tica
    kmeans, ref_kmeans = mdgen.analysis.get_kmeans(ref_tica)
    try:
        ref_kmeans = ref_kmeans[0]
        msm, pcca, cmsm = mdgen.analysis.get_msm(
            ref_kmeans, nstates=nstates, lag=100 * args.msm_lag
        )

        out["kmeans"] = kmeans
        out["msm"] = msm
        out["pcca"] = pcca
        out["cmsm"] = cmsm

        traj_discrete = mdgen.analysis.discretize(traj_tica, kmeans, msm)
        ref_discrete = mdgen.analysis.discretize(ref_tica, kmeans, msm)
        state_ids = np.arange(nstates)[:, None]
        out["traj_metastable_probs"] = (traj_discrete == state_ids).mean(1)
        out["ref_metastable_probs"] = (ref_discrete == state_ids).mean(1)

        # Free energies of the metastable states, G = -kBT * ln(p).
        kB = 0.0019872041  # Boltzmann constant in kcal/(mol*K)
        T = 350.0  # temperature in Kelvin
        kBT = kB * T  # ~0.696 kcal/mol at 350 K

        # Floor populations so unvisited states still get a finite free energy.
        ref_pop = np.maximum(out["ref_metastable_probs"], 1e-4)
        traj_pop = np.maximum(out["traj_metastable_probs"], 1e-4)
        G_ref = -kBT * np.log(ref_pop)
        G_traj = -kBT * np.log(traj_pop)

        out["mMAE"] = np.mean(np.abs(G_traj - G_ref))
        out["G_ref"] = G_ref
        out["G_traj"] = G_traj

        msm_transition_matrix = _dense_transition_matrix(cmsm, nstates)
        out["msm_transition_matrix"] = msm_transition_matrix
        out["pcca_pi"] = pcca._pi_coarse

        msm_pi = _dense_stationary_distribution(cmsm, nstates)
        out["msm_pi"] = msm_pi

        if not args.no_traj_msm:
            traj_msm = pyemma.msm.estimate_markov_model(
                traj_discrete, lag=args.msm_lag
            )
            out["traj_msm"] = traj_msm

            traj_transition_matrix = _dense_transition_matrix(traj_msm, nstates)
            out["traj_transition_matrix"] = traj_transition_matrix

            traj_pi = _dense_stationary_distribution(traj_msm, nstates)
            out["traj_pi"] = traj_pi

            # Flux F_ij = pi_i * T_ij, compared by rank correlation.
            ref_flux = msm_transition_matrix * msm_pi[:, None]
            traj_flux = traj_transition_matrix * traj_pi[:, None]
            corr, pval = spearmanr(ref_flux.flatten(), traj_flux.flatten())
            out["flux_spearman_corr"] = corr
            out["flux_spearman_pval"] = pval

            if args.plot:
                im = plot_triangular_flux(
                    ref_flux, traj_flux, ax=axs[3, 3], corr=corr, cmap="Reds"
                )
                fig.colorbar(im, ax=axs[3, 3], fraction=0.046, pad=0.04)

    except Exception as e:
        # Best-effort per peptide: report and continue with the next one.
        print("ERROR", e, name)
        traceback.print_exc()


def _evaluate_peptide(name, fig, axs):
    """Compute every metric for one peptide, drawing panels onto ``axs``."""
    from mdgen.utils import set_seed

    out = {}
    set_seed(args.seed)

    if args.plot:
        feats, traj, ref = _load_traj_and_ref(name, sidechains=False, cossin=False)
        pyemma.plots.plot_feature_histograms(
            ref, feature_labels=feats, ax=axs[0, 0], color=colors[0]
        )
        pyemma.plots.plot_feature_histograms(traj, ax=axs[0, 0], color=colors[1])
        axs[0, 0].set_title("BB torsions")

    ### JENSEN SHANNON DISTANCES ON ALL TORSIONS
    feats, traj, ref = _load_traj_and_ref(name, sidechains=True, cossin=False)

    ### gyrations
    out["gyration_radius_difference"] = mdgen.analysis.compare_gyration_radius(
        f"{args.pdbdir}/{name}", f"{args.mddir}/{name}/{name}"
    )["gyration_radius_difference"]

    feat_names = feats.describe()
    out["features"] = feat_names

    out["JSD"] = {}
    out["FWD_KL"] = {}
    for i, feat in enumerate(feat_names):
        ref_p = np.histogram(ref[:, i], range=(-np.pi, np.pi), bins=100)[0]
        traj_p = np.histogram(traj[:, i], range=(-np.pi, np.pi), bins=100)[0]
        out["JSD"][feat], out["FWD_KL"][feat] = _divergences(ref_p, traj_p)

    # Joint 2-D histograms of feature pairs (1, 2) and (3, 4). The counts are
    # flattened inside _divergences, so FWD_KL is a scalar here as well.
    angle_range = ((-np.pi, np.pi), (-np.pi, np.pi))
    for i in [1, 3]:
        pair = "|".join(feat_names[i : i + 2])
        ref_p = np.histogram2d(*ref[:, i : i + 2].T, range=angle_range, bins=50)[0]
        traj_p = np.histogram2d(*traj[:, i : i + 2].T, range=angle_range, bins=50)[0]
        out["JSD"][pair], out["FWD_KL"][pair] = _divergences(ref_p, traj_p)

    ############ Torsion decorrelations
    if not args.no_decorr:
        out["md_decorrelation"] = {}
        out["md_Neff"] = {}  # never filled here; kept so the output keys are unchanged
        _plot_torsion_decorrelation(
            feat_names, ref, 100000, axs[0], out["md_decorrelation"]
        )

        out["our_decorrelation"] = {}
        out["our_Neff"] = {
            feat: (
                mdgen.analysis.effective_sample_size(np.sin(traj[:, i]), max_lag=10)
                + mdgen.analysis.effective_sample_size(np.cos(traj[:, i]), max_lag=10)
            )
            / 2
            for i, feat in enumerate(feat_names)
        }
        _plot_torsion_decorrelation(
            feat_names, traj, args.n_lag_traj, axs[1], out["our_decorrelation"]
        )

    ####### TICA #############
    _, traj, ref = _load_traj_and_ref(name, sidechains=True, cossin=True)
    tica, _ = mdgen.analysis.get_tica(ref)
    ref_tica = tica.transform(ref)
    traj_tica = tica.transform(traj)

    # Shared bin ranges so both histograms use the same bins.
    tica_0_range = (
        min(ref_tica[:, 0].min(), traj_tica[:, 0].min()),
        max(ref_tica[:, 0].max(), traj_tica[:, 0].max()),
    )
    tica_1_range = (
        min(ref_tica[:, 1].min(), traj_tica[:, 1].min()),
        max(ref_tica[:, 1].max(), traj_tica[:, 1].max()),
    )

    ref_p = np.histogram(ref_tica[:, 0], range=tica_0_range, bins=100)[0]
    traj_p = np.histogram(traj_tica[:, 0], range=tica_0_range, bins=100)[0]
    out["JSD"]["TICA-0"], out["FWD_KL"]["TICA-0"] = _divergences(ref_p, traj_p)

    tica_range = (tica_0_range, tica_1_range)
    ref_p = np.histogram2d(*ref_tica[:, :2].T, range=tica_range, bins=50)[0]
    traj_p = np.histogram2d(*traj_tica[:, :2].T, range=tica_range, bins=50)[0]
    out["JSD"]["TICA-0,1"], out["FWD_KL"]["TICA-0,1"] = _divergences(ref_p, traj_p)

    #### 1,0, 1,1 TICA FES
    if args.plot:
        pyemma.plots.plot_free_energy(*ref_tica[::100, :2].T, ax=axs[2, 0], cbar=False)
        pyemma.plots.plot_free_energy(*traj_tica[:, :2].T, ax=axs[2, 1], cbar=False)
        axs[2, 0].set_title("TICA FES (MD)")
        axs[2, 1].set_title("TICA FES (ours)")

    ####### TICA decorrelation ########
    if not args.no_decorr:
        # statsmodels validates nlag as an integer, so a true-division result
        # such as 100.5 would be rejected; use floor division instead.
        md_nlag = args.truncate // 10 if args.truncate else 100000
        autocorr = acovf(ref_tica[:, 0], nlag=md_nlag, adjusted=True, demean=False)
        out["md_decorrelation"]["tica"] = autocorr.astype(np.float16)
        if args.plot:
            axs[0, 3].plot(autocorr)
            axs[0, 3].set_title("MD TICA")

        autocorr = acovf(
            traj_tica[:, 0], nlag=args.n_lag_traj, adjusted=True, demean=False
        )
        out["our_decorrelation"]["tica"] = autocorr.astype(np.float16)
        out["our_Neff"]["tica_0"] = mdgen.analysis.effective_sample_size(
            traj_tica[:, 0], max_lag=10
        )
        if args.plot:
            axs[1, 3].plot(autocorr)
            axs[1, 3].set_title("Traj TICA")

    ###### Markov state model stuff #################
    if not args.no_msm:
        _msm_analysis(name, ref_tica, traj_tica, out, fig, axs)

    return out


def main(name):
    """Evaluate one peptide against its MD reference.

    Returns ``(name, out)``, where ``out`` maps metric names to results. When
    ``--plot`` is set, the 4x4 diagnostic figure is saved as a PDF in
    ``args.pdbdir``. The figure is always closed before returning, so a long
    run does not accumulate open matplotlib figures.
    """
    fig, axs = plt.subplots(4, 4, figsize=(20, 20))
    try:
        out = _evaluate_peptide(name, fig, axs)
        if args.plot:
            if args.pdf_name:
                fig.savefig(f"{args.pdbdir}/{name}_{args.pdf_name}.pdf")
            else:
                fig.savefig(f'{args.pdbdir}/{name}_{args.pdbdir.split("/")[-1]}.pdf')
    finally:
        plt.close(fig)
    return name, out


# Driver. The guard keeps multiprocessing worker processes (spawn start
# method) from re-running the whole batch when they re-import this module.
if __name__ == "__main__":
    if args.pdb_id:
        pdb_id = args.pdb_id
    else:
        # Every file whose name contains ".pdb" but not "_traj"; the stem
        # before the first "." is the peptide name.
        pdb_id = [
            nam.split(".")[0]
            for nam in os.listdir(args.pdbdir)
            if ".pdb" in nam and "_traj" not in nam
        ]
    # Only peptides with an evaluated trajectory file can be analysed.
    pdb_id = [nam for nam in pdb_id if os.path.exists(f"{args.pdbdir}/{nam}.xtc")]
    print("number of trajectories", len(pdb_id))
    if args.overfit_peptide:
        pdb_id = args.overfit_peptide.split(",")

    if args.num_workers > 1:
        # The context manager shuts the pool down even if a worker raises;
        # dict() consumes every result before the pool exits.
        with Pool(args.num_workers) as pool:
            out = dict(tqdm.tqdm(pool.imap(main, pdb_id), total=len(pdb_id)))
    else:
        out = dict(tqdm.tqdm(map(main, pdb_id), total=len(pdb_id)))

    if args.save:
        with open(f"{args.pdbdir}/{args.save_name}", "wb") as f:
            pickle.dump(out, f)
