import argparse
import pandas as pd

# Command-line interface. Arguments are parsed at import time, before the
# heavier scientific imports further down the file.
parser = argparse.ArgumentParser()
# Root of the reference MD data. main() reads replicas from
# `{mddir}/trajectory/{name}_{temp}_{i}.xtc` with topology `{mddir}/topology/{name}.pdb`.
# NOTE(review): the default looks like a redacted placeholder — pass a real path.
parser.add_argument(
    "--mddir",
    type=str,
    default="anonymise",
)
parser.add_argument(
    "--seed", type=int, default=42, help="Random seed for reproducibility"
)
# Directory holding the generated ensemble as `{name}.pdb` / `{name}.xtc`.
# It is also the default output directory for the result pickle.
parser.add_argument("--pdbdir", type=str, required=True)
# The split CSV is read with pandas using its `name` column as the index;
# only domains listed there are evaluated.
parser.add_argument(
    "--split",
    type=str,
    default="splits/mdCATH_test.csv",
    help="Path to the split file containing the test set.",
)
parser.add_argument(
    "--gen_replicas",
    type=int,
    nargs="+",
    default=None,
    help=(
        "Indices of MD‑CATH replicas (0–4) that should be treated as the "
        "generated ensemble. All remaining replicas become the reference. "
        "If omitted, the external trajectory in --pdbdir is the generated "
        "ensemble (original behaviour)."
    ),
)
# Output directory used instead of --pdbdir when --gen_replicas is given.
parser.add_argument(
    "--gen_replicas_savedir",
    type=str,
    default="./workdir/MD_baseline_320/MD_baselines",
)

# NOTE(review): --save and --plot are not consulted anywhere in the code
# visible here; results are written regardless of --save.
parser.add_argument("--save", action="store_true")
parser.add_argument("--plot", action="store_true")
# File name of the output pickle (a `_trunc{N}` tag may be inserted, see the
# save-path logic near the end of the file).
parser.add_argument("--save_name", type=str, default="out_basic.pkl")
# Explicit domain names to evaluate; when empty, domains are discovered by
# listing --pdbdir.
parser.add_argument("--pdb_id", nargs="*", default=[])
# NOTE(review): --no_msm is not referenced in the visible code.
parser.add_argument("--no_msm", action="store_true")
# Keep only the first N frames of the generated trajectory; also forwarded to
# the gyration-radius / secondary-structure comparisons.
parser.add_argument("--truncate", type=int, default=None)
# The misspelling ("seperation") is part of the public CLI. Only referenced by
# the currently commented-out analysis in main().
parser.add_argument("--seperation_steps", type=int, default=None)
# Lag passed to the stationary-distribution estimation in main().
parser.add_argument("--msm_lag", type=int, default=1)
# Values > 1 run the per-domain tasks in a multiprocessing Pool.
parser.add_argument("--num_workers", type=int, default=1)
# Comma-separated domain names that replace the computed domain list outright.
parser.add_argument("--overfit_peptide", type=str, default=None)
# Only referenced by the currently commented-out TICA analysis in main().
parser.add_argument("--notica", action="store_true")
# Sequence-length bounds for the domain filter; a value of 0 is falsy and
# disables that bound.
parser.add_argument("--minimum_length", type=int, default=0)
parser.add_argument("--maximum_length", type=int, default=10000)
# Only referenced by the currently commented-out MSM analysis in main().
parser.add_argument(
    "--numstates_tica",
    type=int,
    nargs="+",
    default=[5, 10, 9, 8, 7, 6, 4, 3, 2],
    help="List of candidate values for the number of TICA states.",
)
# NOTE(review): --pdf_name is not referenced in the visible code.
parser.add_argument("--pdf_name", type=str, default=None)
# Only referenced by the currently commented-out TICA analysis in main().
# NOTE(review): the default looks like a redacted placeholder.
parser.add_argument(
    "--tica_models_path",
    type=str,
    default="anonymise",
)
# Simulation temperature; part of the replica file name and forwarded to the
# mdgen.analysis comparisons.
parser.add_argument("--temp", type=int, default=320)
parser.add_argument(
    "--feature_states",
    type=int,
    default=10,
    help="Number of k‑means clusters for the feature‑based MSM metric.",
)

# Module-level namespace read by main() and by the driver code below.
args = parser.parse_args()

import mdgen.analysis
import functools, operator
import pyemma, tqdm, os, pickle
from scipy.spatial.distance import jensenshannon
from multiprocessing import Pool
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import spearmanr
from scipy.stats import entropy

# Colour list taken from Matplotlib's active property cycle.
# NOTE(review): not referenced anywhere in the code visible here — confirm
# whether it is still needed before removing.
colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]


def main(name):
    """Compute ensemble-comparison metrics for a single domain.

    Reads the module-level ``args`` namespace. Two modes are supported:

    * ``--gen_replicas`` given: the listed MD replicas form the "generated"
      ensemble and the remaining replicas form the reference.
    * otherwise: the generated ensemble is ``{pdbdir}/{name}.xtc`` (topology
      ``{pdbdir}/{name}.pdb``) and all MD replicas found form the reference.

    Args:
        name: Domain identifier used to build trajectory / topology paths.

    Returns:
        ``(name, out)`` where ``out`` is a dict of metrics. On success it holds
        ``gyration_radius_{difference,KL,JSD}``, ``ss_{difference,KL,JSD}`` and
        ``feature_MSM`` (a dict with ``"<features>_JSD"`` / ``"<features>_KL"``
        entries). If one of the guarded analysis steps raises, the error is
        printed and ``out`` is returned with whatever was computed so far.

    Raises:
        ValueError: ``--gen_replicas`` is not 1-4 distinct indices from 0-4.
        RuntimeError: the generated trajectory or all requested MD replicas
            are missing on disk.
    """
    import os

    # Ask native math libraries to stay single-threaded inside each worker.
    # NOTE(review): numpy/pyemma are already imported at module level when this
    # runs, so these variables may be set too late to affect thread pools that
    # were initialised at import — confirm, or export them before launching.
    os.environ["OMP_NUM_THREADS"] = "1"
    os.environ["OPENBLAS_NUM_THREADS"] = "1"
    os.environ["MKL_NUM_THREADS"] = "1"
    os.environ["NUMEXPR_NUM_THREADS"] = "1"
    import mdtraj

    def load_md_replicas(mddir, name, temp, replica_indices=None):
        """Load the MD replicas that exist on disk for ``name``.

        Missing replica files are skipped silently; an error is raised only
        when none of the requested replicas could be found.
        """
        if replica_indices is None:
            replica_indices = range(5)
        pdb_path = f"{mddir}/topology/{name}.pdb"  # same topology for every replica
        trajs = []
        for i in replica_indices:
            xtc_path = f"{mddir}/trajectory/{name}_{temp}_{i}.xtc"
            if os.path.exists(xtc_path):
                trajs.append(mdtraj.load_xtc(xtc_path, top=pdb_path))
        if not trajs:
            raise RuntimeError(
                f"No MD replicas found for {name} with indices={replica_indices}"
            )

        return trajs

    # Validate the replica selection up front so that a bad --gen_replicas fails
    # immediately instead of after (or hidden behind) the analyses below.
    gen_set = None
    if args.gen_replicas is not None:
        gen_set = set(args.gen_replicas)
        if not gen_set.issubset({0, 1, 2, 3, 4}) or len(gen_set) in (0, 5):
            raise ValueError("--gen_replicas must contain 1‑4 indices from 0‑4")

    out = {}

    from mdgen.utils import set_seed

    set_seed(args.seed)

    try:
        gr_stats = mdgen.analysis.compare_gyration_radius_mdcath(
            args.mddir,
            args.pdbdir,
            name,
            temp=args.temp,
            truncate=args.truncate,
            gen_replicas=args.gen_replicas,
        )
        out["gyration_radius_difference"] = gr_stats["gyration_radius_difference"]
        out["gyration_radius_KL"] = gr_stats["forward_kl_divergence"]
        out["gyration_radius_JSD"] = gr_stats["jensen_shannon_divergence"]

        secondary_structure_stats = mdgen.analysis.compare_secondary_structure_mdcath(
            args.mddir,
            args.pdbdir,
            name,
            temp=args.temp,
            truncate=args.truncate,
            gen_replicas=args.gen_replicas,
        )
        out["ss_difference"] = secondary_structure_stats["mean_difference"]
        out["ss_KL"] = secondary_structure_stats["forward_kl_divergence"]
        out["ss_JSD"] = secondary_structure_stats["jensen_shannon_divergence"]
    except Exception as e:
        # Best effort: report the failure and return the partial result so the
        # remaining domains are still processed.
        print("ERROR", e, name)
        return name, out

    if gen_set is not None:
        # Sorted so the concatenation order of the generated replicas does not
        # depend on set iteration order.
        samp_trajs = load_md_replicas(args.mddir, name, args.temp, sorted(gen_set))
        ref_trajs = load_md_replicas(
            args.mddir, name, args.temp, [i for i in range(5) if i not in gen_set]
        )

        traj_samp = functools.reduce(operator.add, samp_trajs)  # concatenate
    else:
        # original path: generated ensemble lives in --pdbdir
        samp_xtc = f"{args.pdbdir}/{name}.xtc"
        samp_pdb = f"{args.pdbdir}/{name}.pdb"
        if not os.path.exists(samp_xtc):
            raise RuntimeError(f"Sampler trajectory {samp_xtc} not found.")
        traj_samp = mdtraj.load_xtc(samp_xtc, top=samp_pdb)

        # reference = all five MD replicas (those present on disk)
        ref_trajs = load_md_replicas(args.mddir, name, args.temp)

    if args.truncate:
        traj_samp = traj_samp[: args.truncate]

    # Each entry is a group of features that are stacked column-wise and
    # clustered jointly.
    FEATURES = [
        ["gr", "secondary"],
    ]
    out["feature_MSM"] = {}

    for feat in FEATURES:
        feat_key = ",".join(feat)

        # k-means is fitted on the reference only; the generated ensemble is
        # then assigned to those reference-derived clusters.
        kmeans, disc_ref, global_ref = mdgen.analysis._fit_kmeans_on_reference(
            ref_trajs, feat, args.feature_states
        )
        samp_features = np.hstack(
            [mdgen.analysis._compute_feature(traj_samp, f, traj_samp[0]) for f in feat]
        )
        disc_samp = kmeans.transform(samp_features)
        try:
            pi_ref, state_symbols_ref = mdgen.analysis._stationary_dist(
                disc_ref, lag=args.msm_lag
            )
            pi_samp, state_symbols_samp = mdgen.analysis._stationary_dist(
                [disc_samp], lag=args.msm_lag
            )
        except Exception as e:
            print("ERROR", e, name)
            return name, out

        # Scatter the returned probabilities into fixed-length vectors indexed
        # by state so both ensembles share one support; states without an
        # entry keep probability 0.
        pi_full_ref = np.zeros(args.feature_states)
        pi_full_ref[state_symbols_ref] = pi_ref
        pi_full_samp = np.zeros(args.feature_states)
        pi_full_samp[state_symbols_samp] = pi_samp

        # scipy returns the Jensen-Shannon *distance*; squaring it gives the
        # divergence.
        out["feature_MSM"][f"{feat_key}_JSD"] = (
            jensenshannon(pi_full_ref, pi_full_samp) ** 2
        )

        # KL(ref || samp). The sampled distribution is floored so that states
        # it never visits yield a large but finite penalty instead of inf.
        p = pi_full_ref / pi_full_ref.sum()
        q = np.maximum(pi_full_samp, 1e-5)
        q = q / q.sum()

        out["feature_MSM"][f"{feat_key}_KL"] = entropy(p, q)

    # NOTE: this torsion featurisation used to run unconditionally, but its
    # results (`feats`, `traj`) are consumed only by the disabled legacy
    # analysis below. It is disabled together with that code so that every
    # trajectory is no longer loaded and featurised for nothing; re-enable
    # both together.
    # feats, traj = mdgen.analysis.get_featurized_traj(
    #     f"{args.pdbdir}/{name}", sidechains=True, cossin=False
    # )

    # if args.seperation_steps:
    #     traj = traj[:: args.seperation_steps]
    # if args.truncate:
    #     traj = traj[: args.truncate]
    # feats, ref = mdgen.analysis.get_featurized_traj_replicas(
    #     args.mddir, name, temp=args.temp, sidechains=True, cossin=False
    # )

    # ref = np.concatenate(ref)
    # out["features"] = feats.describe()

    # out["JSD"] = {}
    # for i, feat in enumerate(feats.describe()):
    #     ref_p = np.histogram(ref[:, i], range=(-np.pi, np.pi), bins=100)[0]
    #     traj_p = np.histogram(traj[:, i], range=(-np.pi, np.pi), bins=100)[0]
    #     out["JSD"][feat] = jensenshannon(ref_p, traj_p)

    # for i in [1, 3]:
    #     ref_p = np.histogram2d(
    #         *ref[:, i : i + 2].T, range=((-np.pi, np.pi), (-np.pi, np.pi)), bins=50
    #     )[0]
    #     traj_p = np.histogram2d(
    #         *traj[:, i : i + 2].T, range=((-np.pi, np.pi), (-np.pi, np.pi)), bins=50
    #     )[0]
    #     out["JSD"]["|".join(feats.describe()[i : i + 2])] = jensenshannon(
    #         ref_p.flatten(), traj_p.flatten()
    #     )

    ####### TICA #############
    # feats, traj = mdgen.analysis.get_featurized_traj(
    #     f"{args.pdbdir}/{name}", sidechains=True, cossin=True
    # )
    # if not args.notica:
    #     if args.temp == 320:
    #         try:
    #             ca_threshold = 388

    #             feats, traj = mdgen.analysis.get_featurized_traj_mdcath_sampled(
    #                 args.pdbdir, name, ca_threshold
    #             )
    #             if args.seperation_steps:
    #                 traj = traj[:: args.seperation_steps]
    #             if args.truncate:
    #                 traj = traj[: args.truncate]
    #             _, ref = mdgen.analysis.get_featurized_traj_mdcath(
    #                 args.mddir, name, args.temp, ca_threshold
    #             )
    #             tica = mdgen.analysis.load_tica_model(
    #                 name, args.temp, args.tica_models_path
    #             )
    #             traj_tica = tica.transform(traj)
    #             ref_tica = [tica.transform(r) for r in ref]
    #             ref_tica = np.concatenate(ref_tica)
    #         except:
    #             ca_threshold = 250
    #             feats, traj = mdgen.analysis.get_featurized_traj_mdcath_sampled(
    #                 args.pdbdir, name, ca_threshold
    #             )
    #             if args.seperation_steps:
    #                 traj = traj[:: args.seperation_steps]
    #             if args.truncate:
    #                 traj = traj[: args.truncate]
    #             _, ref = mdgen.analysis.get_featurized_traj_mdcath(
    #                 args.mddir, name, args.temp, ca_threshold
    #             )
    #             tica = mdgen.analysis.load_tica_model(
    #                 name, args.temp, args.tica_models_path
    #             )
    #             traj_tica = tica.transform(traj)
    #             ref_tica = [tica.transform(r) for r in ref]
    #             ref_tica = np.concatenate(ref_tica)
    #     else:
    #         ca_threshold = 2
    #         feats, traj = mdgen.analysis.get_featurized_traj_mdcath_sampled(
    #             args.pdbdir, name, ca_threshold
    #         )
    #         if args.seperation_steps:
    #             traj = traj[:: args.seperation_steps]
    #         if args.truncate:
    #             traj = traj[: args.truncate]
    #         _, ref = mdgen.analysis.get_featurized_traj_mdcath(
    #             args.mddir, name, args.temp, ca_threshold
    #         )
    #         tica = mdgen.analysis.load_tica_model(
    #             name, args.temp, args.tica_models_path
    #         )
    #         traj_tica = tica.transform(traj)
    #         ref_tica = [tica.transform(r) for r in ref]
    #         ref_tica = np.concatenate(ref_tica)

    #     tica_0_min = min(ref_tica[:, 0].min(), traj_tica[:, 0].min())
    #     tica_0_max = max(ref_tica[:, 0].max(), traj_tica[:, 0].max())

    #     tica_1_min = min(ref_tica[:, 1].min(), traj_tica[:, 1].min())
    #     tica_1_max = max(ref_tica[:, 1].max(), traj_tica[:, 1].max())

    #     ref_p = np.histogram(ref_tica[:, 0], range=(tica_0_min, tica_0_max), bins=100)[
    #         0
    #     ]
    #     traj_p = np.histogram(
    #         traj_tica[:, 0], range=(tica_0_min, tica_0_max), bins=100
    #     )[0]
    #     out["JSD"]["TICA-0"] = jensenshannon(ref_p, traj_p)

    #     ref_p = np.histogram2d(
    #         *ref_tica[:, :2].T,
    #         bins=50,
    #         range=((tica_0_min, tica_0_max), (tica_1_min, tica_1_max)),
    #     )[0]
    #     traj_p = np.histogram2d(
    #         *traj_tica[:, :2].T,
    #         bins=50,
    #         range=((tica_0_min, tica_0_max), (tica_1_min, tica_1_max)),
    #     )[0]
    #     out["JSD"]["TICA-0,1"] = jensenshannon(ref_p.flatten(), traj_p.flatten())

    #     ###### Markov state model stuff #################
    #     feats, traj = mdgen.analysis.get_featurized_traj_mdcath_sampled(
    #         args.pdbdir, name, ca_threshold
    #     )
    #     if args.seperation_steps:
    #         traj = traj[:: args.seperation_steps]
    #     if args.truncate:
    #         traj = traj[: args.truncate]
    #     _, ref = mdgen.analysis.get_featurized_traj_mdcath(
    #         args.mddir,
    #         name,
    #         args.temp,
    #         ca_threshold,
    #     )

    #     kmeans, ref_kmeans = mdgen.analysis.get_kmeans_replicas(
    #         [tica.transform(r) for r in ref]
    #     )
    #     try:
    #         msm_success = False
    #         for candidate in args.numstates_tica:
    #             try:
    #                 metastable_assignments, msm_transition_matrix, msm_pi = (
    #                     mdgen.analysis.get_msm_mdcath_deeptime(
    #                         ref_kmeans, lag=args.msm_lag, nstates=candidate
    #                     )
    #                 )
    #                 print(f"numstates_tica candidate {candidate} succeeded. {name}")
    #                 successful_candidates = candidate
    #                 msm_success = True
    #                 break
    #             except Exception as e:
    #                 print(
    #                     f"ERROR with numstates_tica candidate {candidate}: {e}",
    #                     flush=True,
    #                 )

    #         if not msm_success:
    #             raise ValueError(
    #                 "None of the provided numstates_tica values succeeded!"
    #             )

    #         out["kmeans"] = kmeans

    #         traj_discrete = mdgen.analysis.discretize_deeptime(
    #             tica.transform([traj]), kmeans, metastable_assignments
    #         )
    #         ref_discrete = mdgen.analysis.discretize_deeptime(
    #             [tica.transform(r) for r in ref], kmeans, metastable_assignments
    #         )
    #         out["traj_metastable_probs"] = (
    #             traj_discrete == np.arange(successful_candidates)[:, None]
    #         ).mean(1)
    #         out["ref_metastable_probs"] = (
    #             np.concatenate(ref_discrete)
    #             == np.arange(successful_candidates)[:, None]
    #         ).mean(1)
    #         #########

    #         # set constants
    #         kB = 0.0019872041
    #         T = args.temp
    #         kBT = kB * T

    #         ref_pi = np.maximum(out["ref_metastable_probs"], 1e-4)
    #         traj_pi = np.maximum(out["traj_metastable_probs"], 1e-4)

    #         G_ref = -kBT * np.log(ref_pi)
    #         G_traj = -kBT * np.log(traj_pi)

    #         mMAE = np.mean(np.abs(G_traj - G_ref))

    #         out["mMAE"] = mMAE
    #         out["G_ref"] = G_ref
    #         out["G_traj"] = G_traj

    #         out["msm_transition_matrix"] = msm_transition_matrix
    #         out["msm_pi"] = msm_pi

    #         traj_msm = pyemma.msm.estimate_markov_model(traj_discrete, lag=args.msm_lag)
    #         out["traj_msm"] = traj_msm

    #         traj_transition_matrix = np.eye(successful_candidates)
    #         for a, i in enumerate(traj_msm.active_set):
    #             for b, j in enumerate(traj_msm.active_set):
    #                 traj_transition_matrix[i, j] = traj_msm.transition_matrix[a, b]
    #         out["traj_transition_matrix"] = traj_transition_matrix

    #         traj_pi = np.zeros(successful_candidates)
    #         traj_pi[traj_msm.active_set] = traj_msm.pi
    #         out["traj_pi"] = traj_pi

    #         ref_flux = msm_transition_matrix * msm_pi[:, None]
    #         traj_flux = traj_transition_matrix * traj_pi[:, None]
    #         corr, pval = spearmanr(ref_flux.flatten(), traj_flux.flatten())
    #         out["flux_spearman_corr"] = corr
    #         out["flux_spearman_pval"] = pval

    #     except Exception as e:
    #         print("ERROR", e, name)

    return name, out


# Candidate domains: an explicit --pdb_id list wins; otherwise every entry of
# --pdbdir whose file name contains ".pdb" (and is not a "_traj" file) is used,
# keeping the part of the file name before the first dot.
pdb_id = args.pdb_id or [
    entry.split(".")[0]
    for entry in os.listdir(args.pdbdir)
    if ".pdb" in entry and "_traj" not in entry
]

# Keep only the domains that actually have a generated trajectory on disk.
pdb_id = list(
    filter(lambda domain: os.path.exists(f"{args.pdbdir}/{domain}.xtc"), pdb_id)
)

# Restrict to the domains listed in the split file (indexed by its `name` column).
df = pd.read_csv(args.split, index_col="name")
pdb_id = [domain for domain in pdb_id if domain in df.index]

import csv


def load_domain_sequences(csv_files=None):
    """Map each domain name to the length of its sequence.

    Despite the function's name, the returned values are sequence *lengths*
    (``len(seqres)``), not the sequences themselves.

    Args:
        csv_files: Iterable of CSV paths to read, each with ``name`` and
            ``seqres`` columns. Defaults to the mdCATH train / val / test
            split files under ``splits/`` (the original hard-coded list).

    Returns:
        Dict ``{name: len(seqres)}``. Rows with an empty ``seqres`` are
        skipped; when a name occurs more than once, the last occurrence wins.
        Files that do not exist are reported on stdout and skipped.

    Raises:
        KeyError: an existing CSV file lacks a ``name`` or ``seqres`` column.
    """
    if csv_files is None:
        csv_files = (
            "splits/mdCATH_train.csv",
            "splits/mdCATH_val.csv",
            "splits/mdCATH_test.csv",
        )

    domain_to_length = {}
    for csv_file in csv_files:
        if not os.path.exists(csv_file):
            print(f"File {csv_file} not found. Skipping.")
            continue
        # newline="" is what the csv module expects for correct quoting.
        with open(csv_file, "r", newline="") as f:
            for row in csv.DictReader(f):
                seq = row["seqres"]
                if seq:
                    domain_to_length[row["name"]] = len(seq)
    return domain_to_length


# Maps domain name -> sequence length (see load_domain_sequences).
domain_to_seq = load_domain_sequences()

# Filter domains based on sequence length. A bound of 0 is falsy and disables
# that side of the filter; domains with unknown length are treated as length 0.
if args.minimum_length:
    pdb_id = [
        name for name in pdb_id if domain_to_seq.get(name, 0) >= args.minimum_length
    ]
if args.maximum_length:
    pdb_id = [
        name for name in pdb_id if domain_to_seq.get(name, 0) <= args.maximum_length
    ]


# pdb_id = pdb_id[12:16]
# pdb_id = ["2faoA01"]

# Domains that are always skipped.
# NOTE(review): the reason for excluding these two is not recorded here — TODO document.
excluded_domains = {"1ia5A00", "3amrA00"}
pdb_id = [name for name in pdb_id if name not in excluded_domains]

print("number of trajectories", len(pdb_id))
# Domains without a known sequence length pass the maximum-length filter above
# (as length 0), so they are left out of the average instead of raising KeyError.
known_lengths = [domain_to_seq[name] for name in pdb_id if name in domain_to_seq]
print(
    "average length of trajectories",
    np.mean(known_lengths) if known_lengths else float("nan"),
)

# An explicit comma-separated override replaces the computed list entirely
# (note: after the statistics above were printed).
if args.overfit_peptide:
    pdb_id = args.overfit_peptide.split(",")


# Define the path for the output file. Results go to --gen_replicas_savedir in
# replica-baseline mode and to --pdbdir otherwise. A truncation other than 500
# is encoded in the file name; 500 is deliberately left untagged.
save_dir = args.gen_replicas_savedir if args.gen_replicas is not None else args.pdbdir
save_file = args.save_name
if args.truncate and args.truncate != 500:
    save_file = f"{args.save_name.split('.')[0]}_trunc{args.truncate}.pkl"
save_path = f"{save_dir}/{save_file}"

# Create the output directory *before* the expensive computation so that a
# missing directory cannot discard finished results at save time.
save_parent = os.path.dirname(save_path)
if save_parent:
    os.makedirs(save_parent, exist_ok=True)

# Load existing results if the file exists and we're not overwriting
# existing_out = {}
# if os.path.exists(save_path):
#     with open(save_path, "rb") as f:
#         existing_out = pickle.load(f)
#     print(
#         f"Found existing output with {len(existing_out)} entries. Skipping these tasks."
#     )
#     # Remove tasks that already have results
#     pdb_id = [id_ for id_ in pdb_id if id_ not in existing_out]


def _collect(results):
    """Drain an iterable of ``(name, out)`` pairs into a dict with a progress bar."""
    return dict(tqdm.tqdm(results, total=len(pdb_id)))


# Process the tasks (each call to main returns a (name, out) tuple), in
# parallel when more than one worker is requested. The `with` block guarantees
# the pool is torn down even if a task raises.
if args.num_workers > 1:
    with Pool(args.num_workers) as pool:
        out_new = _collect(pool.imap(main, pdb_id))
else:
    out_new = _collect(map(main, pdb_id))

# Merge existing and new results
# final_out = {**existing_out, **out_new}
final_out = {**out_new}

# Save the combined results.
# NOTE(review): results are written regardless of the --save flag, which is
# not consulted here — confirm whether that is intended.
with open(save_path, "wb") as f:
    print("saving", save_path)
    pickle.dump(final_out, f)
