"""Evaluate a generated conformational ensemble against mdCATH reference MD.

For each domain, a generated ensemble (files in --pdbdir, or a subset of the
mdCATH replicas chosen with --gen_replicas) is compared with the reference
mdCATH replicas, and a dict of per-domain metrics is pickled.
"""

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--mdcath_dir",
    type=str,
    default="anonymise",
)
parser.add_argument("--seed", type=int, default=42)
parser.add_argument(
    "--mdcath_processed_dir",
    type=str,
    default="anonymise",
)
parser.add_argument("--split", type=str, default="splits/mdCATH_test.csv")
parser.add_argument("--pdbdir", type=str, required=True)
parser.add_argument("--pdb_id", nargs="*", default=[])
parser.add_argument("--bb_only", action="store_true")
parser.add_argument("--ca_only", action="store_true")
parser.add_argument("--num_workers", type=int, default=1)
parser.add_argument("--temp", type=int, default=320)
parser.add_argument("--xtc", action="store_true")
parser.add_argument("--truncate", type=int, default=None)
parser.add_argument(
    "--gen_replicas",
    type=int,
    nargs="+",
    default=None,
    help=(
        "Index(es) of MD‑CATH replicas (0–4) to treat as the *generated* ensemble. "
        "All remaining replicas become the reference MD ensemble. "
        "If omitted, the script behaves exactly as before (AF2 vs. all 5 replicas)."
    ),
)

parser.add_argument(
    "--gen_replicas_savedir",
    type=str,
    default="./workdir/MD_baselines_4",
)

parser.add_argument(
    "--no_distributional", action=argparse.BooleanOptionalAction, default=False
)
parser.add_argument(
    "--no_observables", action=argparse.BooleanOptionalAction, default=False
)
args = parser.parse_args()

# Validate --gen_replicas once, up front. Without this, an invalid selection is
# only detected per domain inside the worker, after the pool has been spun up.
if args.gen_replicas is not None:
    _gen_replicas = set(args.gen_replicas)
    if not _gen_replicas.issubset({0, 1, 2, 3, 4}):
        parser.error("--gen_replicas must be indices 0–4")
    # nargs="+" already guarantees at least one index; all five would leave
    # no reference ensemble.
    if not 1 <= len(_gen_replicas) <= 4:
        parser.error("--gen_replicas must specify 1–4 replicas")
from sklearn.decomposition import PCA
from multiprocessing import Manager, Pool
import functools
import operator

import mdtraj, pickle, tqdm, os
import numpy as np
from multiprocessing import Pool
from scipy.optimize import linear_sum_assignment
from scipy.spatial.distance import jensenshannon
import os
import pandas as pd
import torch
from mdgen.utils import atom14_to_pdb
from mdgen.residue_constants import restype_order
from contextlib import contextmanager
import time

# from memory_profiler import profile


# @contextmanager
# def timer(label: str):
#     start = time.perf_counter()  # more precise than time.time()
#     try:
#         yield
#     finally:
#         end = time.perf_counter()
#         print(f"[{label}] took {end - start:.4f} seconds")
from mdgen.utils import set_seed

# Seed the RNGs from --seed at import time (runs in every process that imports
# this module). set_seed is project code; presumably it seeds Python, NumPy and
# torch — confirm in mdgen.utils. Note that main() separately re-seeds NumPy
# with a fixed 137 before drawing its random frame indices.
set_seed(args.seed)


def get_pca(xyz):
    """Fit a full-rank PCA to a stack of frames.

    ``xyz`` is flattened to ``(n_frames, -1)`` (callers pass mdtraj ``.xyz``
    arrays), and the number of components is the smaller of the two flattened
    dimensions, i.e. the maximum rank.

    Returns:
        ``(pca, coords)``: the fitted ``sklearn.decomposition.PCA`` and the
        input frames projected onto its components.
    """
    flat = xyz.reshape(xyz.shape[0], -1)
    n_keep = min(flat.shape)
    model = PCA(n_components=n_keep)
    return model, model.fit_transform(flat)


def get_rmsds(traj1, traj2, broadcast=False):
    """RMSD between coordinate frames, with no superposition, scaled by 10.

    Both inputs are reshaped to ``(n_frames, 3 * n_atoms)`` with ``n_atoms``
    taken from ``traj1.shape[1]``, so they must have the same atom count.
    Callers pass mdtraj ``.xyz`` (presumably nm, making the result Å).

    Args:
        traj1, traj2: coordinate stacks of shape ``(frames, atoms, 3)``.
        broadcast: if False, compare frame ``i`` of ``traj1`` with frame ``i``
            of ``traj2`` (shapes must broadcast); if True, compare every
            frame of ``traj1`` with every frame of ``traj2``.

    Returns:
        Array of shape ``(frames,)``, or ``(frames1, frames2)`` when
        ``broadcast`` is True.
    """
    atom_count = traj1.shape[1]
    flat1 = traj1.reshape(traj1.shape[0], atom_count * 3)
    flat2 = traj2.reshape(traj2.shape[0], atom_count * 3)
    if broadcast:
        # (F1, 1, D) - (1, F2, D) -> all-pairs differences.
        flat1 = flat1[:, None]
        flat2 = flat2[None]
    sq_norm = np.square(flat1 - flat2).sum(-1)
    return sq_norm**0.5 / atom_count**0.5 * 10


def condense_sidechain_sasas(sasas, top):
    """Sum atom-level SASA values into one side-chain SASA per residue.

    Atoms named CA, C, N, O or OXT (backbone) are excluded; every other atom
    is summed into the residue it belongs to.

    Args:
        sasas: array of shape ``(n_frames, n_atoms)`` with one value per atom,
            such as ``mdtraj.shrake_rupley`` output in its atom mode.
        top: topology exposing ``n_residues``, ``n_atoms`` and ``atoms``, where
            each atom has ``.name`` and ``.residue.index``.

    Returns:
        float32 array of shape ``(n_frames, n_residues)``.

    Raises:
        ValueError: if the topology has fewer than two residues, or if its atom
            count does not match ``sasas.shape[1]``.
    """
    # Previously an ``assert`` (stripped under ``python -O``).
    if top.n_residues <= 1:
        raise ValueError(
            f"Expected a topology with more than one residue, got {top.n_residues}"
        )

    # Previously raised ``exception.DataInvalid``; ``exception`` is not defined
    # anywhere, so this path crashed with a NameError instead.
    if top.n_atoms != sasas.shape[1]:
        raise ValueError(
            f"The number of atoms in top ({top.n_atoms}) didn't match the "
            f"number of SASAs provided ({sasas.shape[1]}). Make sure you "
            f"computed atom-level SASAs (mode='atom') and that you've passed "
            "the correct topology file and array of SASAs"
        )

    atoms = list(top.atoms)  # iterate the topology only once
    sc_mask = np.array([a.name not in ("CA", "C", "N", "O", "OXT") for a in atoms])
    res_id = np.array([a.residue.index for a in atoms])

    rsd_sasas = np.zeros((sasas.shape[0], top.n_residues), dtype="float32")
    for i in range(top.n_residues):
        rsd_sasas[:, i] = sasas[:, sc_mask & (res_id == i)].sum(1)
    return rsd_sasas


def sasa_mi(sasa):
    """Pairwise mutual information (natural log, nats) between binary flags.

    Args:
        sasa: boolean array of shape ``(n_frames, n_residues)``; callers pass
            per-frame "exposed" flags.

    Returns:
        ``(n_residues, n_residues)`` matrix with a zero diagonal. Joint cells
        with zero probability yield NaN terms (0 * log 0), which ``np.nansum``
        drops, i.e. they contribute nothing.
    """
    _, n_res = sasa.shape
    row = sasa[:, :, None]
    col = sasa[:, None, :]

    # joint[i, j, a, b] = P(flag_i == a and flag_j == b)
    joint = np.zeros((n_res, n_res, 2, 2))
    for a_val, a_mask in ((1, row), (0, ~row)):
        for b_val, b_mask in ((1, col), (0, ~col)):
            joint[:, :, a_val, b_val] = (a_mask & b_mask).mean(0)

    p_true = sasa.mean(0)
    marginal = np.stack([1 - p_true, p_true], -1)
    # indep[i, j, a, b] = P(flag_i == a) * P(flag_j == b)
    indep = marginal[None, :, None, :] * marginal[:, None, :, None]

    mi = np.nansum(joint * np.log(joint / indep), (-1, -2))
    np.fill_diagonal(mi, 0)
    return mi


def get_mean_covar(xyz):
    """Mean and covariance over axis 0 (frames) of a coordinate stack.

    Callers pass mdtraj ``.xyz`` of shape ``(n_frames, n_atoms, 3)``, giving a
    mean of shape ``(n_atoms, 3)`` and one 3x3 covariance per atom
    ``(n_atoms, 3, 3)``. The covariance divides by ``n_frames``
    (population, not sample, covariance).
    """
    centroid = xyz.mean(0)
    deviation = xyz - centroid
    outer = deviation[..., :, None] * deviation[..., None, :]
    return centroid, outer.mean(0)


def sqrtm(M):
    """Matrix square root via eigendecomposition, for one or a batch of matrices.

    Uses ``M = P diag(D) P^-1`` and returns ``P diag(sqrt(D)) P^-1``. The input
    may be a single ``(n, n)`` matrix or a stack ``(..., n, n)``.

    A negative real eigenvalue gives NaN entries (``np.sqrt`` of a negative
    float), and complex eigenvalues give a complex result. Raises
    ``np.linalg.LinAlgError`` if the eigendecomposition fails or ``P`` is
    singular.
    """
    D, P = np.linalg.eig(M)
    # Scale each eigenvector *column* of P by sqrt of its eigenvalue.
    # ``D[..., None, :]`` broadcasts along columns for both a single (n, n)
    # matrix and a batch. The old ``D[:, None]`` scaled rows for a single
    # matrix, which is wrong.
    out = (P * np.sqrt(D[..., None, :])) @ np.linalg.inv(P)
    return out


def get_wasserstein(distmat, p=2):
    """p-Wasserstein distance between two equal-size, uniformly weighted sets.

    ``distmat[i, j]`` is the distance from point ``i`` of the first set to
    point ``j`` of the second. The optimal one-to-one matching under cost
    ``distmat**p`` is found with the Hungarian algorithm
    (``scipy.optimize.linear_sum_assignment``). The result is
    ``mean(matched cost) ** (1 / p)``.
    """
    assert distmat.shape[0] == distmat.shape[1]
    cost = distmat**p
    matched = cost[linear_sum_assignment(cost)]
    return matched.mean() ** (1 / p)


def get_jsd(a, b, bins=50):
    """Jensen-Shannon divergence (base 2, so in [0, 1]) between two 1-D samples.

    Both samples are histogrammed over ``bins`` equal-width bins spanning the
    joint min/max of the two samples. A tiny epsilon is added to every bin
    before renormalising so that no bin is empty.
    ``scipy.spatial.distance.jensenshannon`` returns the JS *distance* (the
    square root of the divergence), so its result is squared.
    """
    lo = float(min(a.min(), b.min()))
    hi = float(max(a.max(), b.max()))
    eps = 1e-12

    def _to_probs(sample):
        hist, _ = np.histogram(sample, bins=bins, range=(lo, hi), density=True)
        padded = hist + eps
        return padded / padded.sum()

    p, q = _to_probs(a), _to_probs(b)
    return jensenshannon(p, q, base=2.0) ** 2


def align_tops(top1, top2):
    """Index masks selecting the atoms common to two topologies.

    Atoms are matched by ``repr(atom)``. The shared atoms are listed in
    ``top1`` order, and for each one the index of its *first* occurrence in
    each topology is returned. (This matches the previous ``list.index``
    behaviour, but uses dict lookups instead of O(n^2) list scans.)

    Returns:
        ``(mask1, mask2)``: equal-length lists of atom indices into ``top1``
        and ``top2``.
    """

    def _first_index(names):
        # Map each name to the position of its first occurrence.
        first = {}
        for i, nam in enumerate(names):
            first.setdefault(nam, i)
        return first

    names1 = [repr(a) for a in top1.atoms]
    names2 = [repr(a) for a in top2.atoms]
    first1 = _first_index(names1)
    first2 = _first_index(names2)

    intersection = [nam for nam in names1 if nam in first2]

    mask1 = [first1[nam] for nam in intersection]
    mask2 = [first2[nam] for nam in intersection]
    return mask1, mask2


def main(name, seqres):
    """Compare one domain's generated ensemble against its mdCATH reference MD.

    The reference ensemble is the concatenation of the five mdCATH replicas at
    ``args.temp``. With ``--gen_replicas``, the reference is only the replicas
    *not* listed. The "generated" ensemble (``aftraj*``) is read from
    ``args.pdbdir`` (``{name}.pdb``, or ``{name}.xtc`` with ``{name}.pdb`` as
    topology when ``--xtc`` is set). With ``--gen_replicas``, it is instead the
    concatenation of the listed replicas.

    If any per-replica ``*_analysis.xtc`` file or the shared
    ``{name}_analysis.pdb`` topology is missing, all five are first rebuilt
    from the ``{name}_{temp}_{i}.npy`` atom14 arrays.

    Args:
        name: domain identifier used to build every file path.
        seqres: one-letter sequence; only used when the analysis files have to
            be rebuilt.

    Returns:
        ``(name, out)``, where ``out`` maps metric names to values, or
        ``(None, None)`` when the trajectories cannot be loaded.

    Raises:
        ValueError: if ``--gen_replicas`` is not 1–4 distinct indices in 0–4.
            This used to be raised inside the broad ``try`` below and silently
            swallowed, which skipped every domain.
    """
    gen_set = None
    if args.gen_replicas is not None:
        gen_set = set(args.gen_replicas)
        if not gen_set.issubset({0, 1, 2, 3, 4}):
            raise ValueError("--gen_replicas must be indices 0–4")
        if len(gen_set) == 0 or len(gen_set) == 5:
            raise ValueError("--gen_replicas must specify 1–4 replicas")

    out = {}
    prefix = f"{args.mdcath_processed_dir}/{name}"
    topfile = f"{prefix}_analysis.pdb"
    xtc_files = [f"{prefix}_{args.temp}_{i}_analysis.xtc" for i in range(5)]
    pdb_files = [f"{prefix}_{args.temp}_{i}_analysis.pdb" for i in range(5)]

    # (Re)build the per-replica analysis trajectories from the raw atom14 arrays.
    if not all(os.path.exists(f) for f in xtc_files + [topfile]):
        seqres_idx = torch.tensor([restype_order[c] for c in seqres]).numpy()
        for i in range(5):
            atom14 = np.load(f"{prefix}_{args.temp}_{i}.npy")
            atom14_to_pdb(atom14, seqres_idx, pdb_files[i])
            replica = mdtraj.load(pdb_files[i])
            replica.superpose(replica)
            replica.save(xtc_files[i])
            replica[0].save(topfile)

    try:
        replicate_trajs = [mdtraj.load(f, top=topfile) for f in xtc_files]
        if gen_set is not None:
            aftraj_aa = functools.reduce(
                operator.add, [replicate_trajs[i] for i in sorted(gen_set)]
            )
            traj_aa = functools.reduce(
                operator.add,
                [replicate_trajs[i] for i in range(5) if i not in gen_set],
            )
        else:
            traj_aa = functools.reduce(operator.add, replicate_trajs)
        # The first reference frame serves as the "seed"/crystal structure.
        ref_aa = traj_aa[0]
    except Exception:
        # Best-effort: a domain whose replicas cannot be loaded or joined is
        # skipped rather than aborting the whole run.
        return None, None
    use_af2 = gen_set is None

    if use_af2:
        if not args.xtc:
            aftraj_aa = mdtraj.load(f"{args.pdbdir}/{name}.pdb")
        else:
            try:
                aftraj_aa = mdtraj.load(
                    f"{args.pdbdir}/{name}.xtc", top=f"{args.pdbdir}/{name}.pdb"
                )
            except Exception:
                return None, None

    if args.truncate:
        aftraj_aa = aftraj_aa[: args.truncate]

    # Drop hydrogens from all three trajectories (atom_slice with inplace=True).
    for t in (traj_aa, ref_aa, aftraj_aa):
        t.atom_slice([a.index for a in t.top.atoms if a.element.symbol != "H"], True)

    if args.bb_only:
        aftraj_aa.atom_slice(
            [
                a.index
                for a in aftraj_aa.top.atoms
                if a.name in ["CA", "C", "N", "O", "OXT"]
            ],
            True,
        )
    elif args.ca_only:
        aftraj_aa.atom_slice(
            [a.index for a in aftraj_aa.top.atoms if a.name == "CA"], True
        )

    # Restrict reference and generated trajectories to their shared atoms.
    refmask, afmask = align_tops(traj_aa.top, aftraj_aa.top)
    traj_aa.atom_slice(refmask, True)
    ref_aa.atom_slice(refmask, True)
    aftraj_aa.atom_slice(afmask, True)

    # Fixed seed: the same reference frames are subsampled on every run.
    np.random.seed(137)
    RAND1 = np.random.randint(0, traj_aa.n_frames, aftraj_aa.n_frames)
    RAND2 = np.random.randint(0, traj_aa.n_frames, aftraj_aa.n_frames)
    RAND1K = np.random.randint(0, traj_aa.n_frames, 1000)

    traj_aa.superpose(ref_aa)
    aftraj_aa.superpose(ref_aa)

    out["ca_mask"] = ca_mask = [a.index for a in traj_aa.top.atoms if a.name == "CA"]
    traj = traj_aa.atom_slice(ca_mask, False)
    ref = ref_aa.atom_slice(ca_mask, False)
    aftraj = aftraj_aa.atom_slice(ca_mask, False)

    traj.superpose(ref)
    aftraj.superpose(ref)

    n_atoms = aftraj.n_atoms
    # Subsampled C-alpha reference frames, reused by several metrics below.
    traj_r1 = traj[RAND1].xyz
    traj_r2 = traj[RAND2].xyz

    if not args.no_distributional:
        ref_pca, ref_coords = get_pca(traj.xyz)
        af_coords_ref_pca = ref_pca.transform(aftraj.xyz.reshape(aftraj.n_frames, -1))
        seed_coords_ref_pca = ref_pca.transform(ref.xyz.reshape(1, -1))

        af_pca, af_coords = get_pca(aftraj.xyz)
        ref_coords_af_pca = af_pca.transform(traj.xyz.reshape(traj.n_frames, -1))
        seed_coords_af_pca = af_pca.transform(ref.xyz.reshape(1, -1))

        joint_pca, _ = get_pca(np.concatenate([traj_r1, aftraj.xyz]))
        af_coords_joint_pca = joint_pca.transform(
            aftraj.xyz.reshape(aftraj.n_frames, -1)
        )
        ref_coords_joint_pca = joint_pca.transform(traj.xyz.reshape(traj.n_frames, -1))
        seed_coords_joint_pca = joint_pca.transform(ref.xyz.reshape(1, -1))

        out["ref_variance"] = ref_pca.explained_variance_ / n_atoms * 100
        out["af_variance"] = af_pca.explained_variance_ / n_atoms * 100
        out["joint_variance"] = joint_pca.explained_variance_ / n_atoms * 100

    out["af_rmsf"] = mdtraj.rmsf(aftraj_aa, ref_aa) * 10
    out["ref_rmsf"] = mdtraj.rmsf(traj_aa, ref_aa) * 10
    out["jsd_rmsf"] = get_jsd(out["ref_rmsf"], out["af_rmsf"])

    # Per-atom Gaussian EMD: mean term and covariance (Bures) term.
    if not args.no_distributional:
        ref_mean, ref_covar = get_mean_covar(traj_aa[RAND1K].xyz)
        af_mean, af_covar = get_mean_covar(aftraj_aa.xyz)
        out["emd_mean"] = (np.square(ref_mean - af_mean).sum(-1) ** 0.5) * 10
        try:
            out["emd_var"] = (
                np.trace(
                    ref_covar + af_covar - 2 * sqrtm(ref_covar @ af_covar),
                    axis1=1,
                    axis2=2,
                )
                ** 0.5
            ) * 10
        except np.linalg.LinAlgError:
            # Per-atom trace, same shape as the primary branch. The old
            # np.trace(ref_covar) traced axes 0/1 of the (n_atoms, 3, 3)
            # array and returned a meaningless length-3 vector.
            out["emd_var"] = np.trace(ref_covar, axis1=1, axis2=2) ** 0.5 * 10

    if not args.no_observables:
        sasa_thresh = 0.02
        af_sasa = mdtraj.shrake_rupley(aftraj_aa, probe_radius=0.28)
        af_sasa = condense_sidechain_sasas(af_sasa, aftraj_aa.top)
        ref_sasa = mdtraj.shrake_rupley(traj_aa[RAND1K], probe_radius=0.28)
        ref_sasa = condense_sidechain_sasas(ref_sasa, traj_aa.top)
        crystal_sasa = mdtraj.shrake_rupley(ref_aa, probe_radius=0.28)
        out["crystal_sasa"] = condense_sidechain_sasas(crystal_sasa, ref_aa.top)

        out["ref_sa_prob"] = (ref_sasa > sasa_thresh).mean(0)
        out["af_sa_prob"] = (af_sasa > sasa_thresh).mean(0)
        out["ref_mi_mat"] = sasa_mi(ref_sasa > sasa_thresh)
        out["af_mi_mat"] = sasa_mi(af_sasa > sasa_thresh)

        ref_distmat = np.linalg.norm(
            traj_r1[:, None, :] - traj_r1[:, :, None], axis=-1
        )
        af_distmat = np.linalg.norm(
            aftraj.xyz[:, None, :] - aftraj.xyz[:, :, None], axis=-1
        )

        # 0.8 is in mdtraj length units (presumably nm, i.e. an 8 Å CA contact).
        out["ref_contact_prob"] = (ref_distmat < 0.8).mean(0)
        out["af_contact_prob"] = (af_distmat < 0.8).mean(0)
        out["crystal_distmat"] = np.linalg.norm(
            ref.xyz[0, None, :] - ref.xyz[0, :, None], axis=-1
        )

    # Each all-pairs RMSD matrix is computed once and reused. get_rmsds with
    # broadcast=True builds an (N, N, 3 * n_atoms) temporary, so recomputing
    # it for every statistic was needlessly expensive.
    ref_pw_mat = get_rmsds(traj_r1, traj_r2, broadcast=True)
    af_pw_mat = get_rmsds(aftraj.xyz, aftraj.xyz, broadcast=True)
    ref_self_mat = get_rmsds(traj_r1, traj_r1, broadcast=True)

    out["jsd_pairwise_rmsd"] = get_jsd(ref_pw_mat.ravel(), af_pw_mat.ravel())
    out["ref_mean_pairwise_rmsd"] = ref_pw_mat.mean()
    out["af_mean_pairwise_rmsd"] = af_pw_mat.mean()
    out["ref_rms_pairwise_rmsd"] = np.square(ref_pw_mat).mean() ** 0.5
    out["af_rms_pairwise_rmsd"] = np.square(af_pw_mat).mean() ** 0.5
    out["ref_self_mean_pairwise_rmsd"] = ref_self_mat.mean()
    out["ref_self_rms_pairwise_rmsd"] = np.square(ref_self_mat).mean() ** 0.5

    if not args.no_distributional:
        out["cosine_sim"] = (ref_pca.components_[0] * af_pca.components_[0]).sum()

    def get_emd(ref_coords1, ref_coords2, af_coords, seed_coords, K=None):
        """Distances between point clouds in (optionally truncated) PC space.

        All distances are divided by sqrt(n_atoms) and multiplied by 10, as in
        get_rmsds. ``K`` keeps only the first K coordinates (principal
        components). 3-D inputs are flattened to (n, -1) first.
        """
        if len(ref_coords1.shape) == 3:
            ref_coords1 = ref_coords1.reshape(ref_coords1.shape[0], -1)
            ref_coords2 = ref_coords2.reshape(ref_coords2.shape[0], -1)
            af_coords = af_coords.reshape(af_coords.shape[0], -1)
            seed_coords = seed_coords.reshape(seed_coords.shape[0], -1)
        if K is not None:
            ref_coords1 = ref_coords1[:, :K]
            ref_coords2 = ref_coords2[:, :K]
            af_coords = af_coords[:, :K]
            seed_coords = seed_coords[:, :K]
        emd = {}
        emd["ref|ref mean"] = (
            (np.square(ref_coords1 - ref_coords1.mean(0)).sum(-1)).mean() ** 0.5
            / n_atoms**0.5
            * 10
        )

        distmat = np.square(ref_coords1[:, None] - ref_coords2[None]).sum(-1)
        distmat = distmat**0.5 / n_atoms**0.5 * 10
        emd["ref|ref2"] = get_wasserstein(distmat)
        emd["ref mean|ref2 mean"] = (
            np.square(ref_coords1.mean(0) - ref_coords2.mean(0)).sum() ** 0.5
            / n_atoms**0.5
            * 10
        )

        distmat = np.square(ref_coords1[:, None] - af_coords[None]).sum(-1)
        distmat = distmat**0.5 / n_atoms**0.5 * 10
        emd["ref|af"] = get_wasserstein(distmat)
        emd["ref mean|af mean"] = (
            np.square(ref_coords1.mean(0) - af_coords.mean(0)).sum() ** 0.5
            / n_atoms**0.5
            * 10
        )

        emd["ref|seed"] = (
            (np.square(ref_coords1 - seed_coords).sum(-1)).mean() ** 0.5
            / n_atoms**0.5
            * 10
        )
        emd["ref mean|seed"] = (
            (np.square(ref_coords1.mean(0) - seed_coords).sum(-1)).mean() ** 0.5
            / n_atoms**0.5
            * 10
        )

        emd["af|seed"] = (
            (np.square(af_coords - seed_coords).sum(-1)).mean() ** 0.5
            / n_atoms**0.5
            * 10
        )
        emd["af|af mean"] = (
            (np.square(af_coords - af_coords.mean(0)).sum(-1)).mean() ** 0.5
            / n_atoms**0.5
            * 10
        )
        emd["af mean|seed"] = (
            (np.square(af_coords.mean(0) - seed_coords).sum(-1)).mean() ** 0.5
            / n_atoms**0.5
            * 10
        )
        return emd

    K = 2
    if not args.no_distributional:
        out["EMD,ref"] = get_emd(
            ref_coords[RAND1],
            ref_coords[RAND2],
            af_coords_ref_pca,
            seed_coords_ref_pca,
            K=K,
        )
        out["EMD,af2"] = get_emd(
            ref_coords_af_pca[RAND1],
            ref_coords_af_pca[RAND2],
            af_coords,
            seed_coords_af_pca,
            K=K,
        )
        out["EMD,joint"] = get_emd(
            ref_coords_joint_pca[RAND1],
            ref_coords_joint_pca[RAND2],
            af_coords_joint_pca,
            seed_coords_joint_pca,
            K=K,
        )
    return name, out


def wrapper_main(args, finished):
    """Pool-worker entry point: run ``main(*args)`` and mark the task finished.

    Args:
        args: ``(name, seqres)`` tuple forwarded to ``main``. The parameter
            shadows the module-level ``args`` only inside this function.
        finished: shared list (a Manager proxy in the caller) that the
            pending-task monitor reads.

    The *input* name is appended in a ``finally``. Previously the name
    returned by ``main`` was appended, which is ``None`` for skipped domains,
    and nothing was appended if ``main`` raised, so those tasks were reported
    as pending forever.
    """
    task_name = args[0]
    try:
        return main(*args)
    finally:
        finished.append(task_name)


def monitor_pending(all_names, finished, interval=120):
    """Report task names not yet present in ``finished``, forever.

    Prints immediately and then again every ``interval`` seconds. It never
    returns, so it is meant to run in a daemon thread.
    """
    while True:
        still_running = set(all_names).difference(finished)
        print("Pending tasks:", still_running)
        time.sleep(interval)


from multiprocessing import Manager, Pool
import threading


if __name__ == "__main__":
    # The split CSV supplies the domain list filter and each domain's seqres.
    df = pd.read_csv(args.split, index_col="name")

    if args.pdb_id:
        pdb_ids = args.pdb_id
    else:
        # Every file in --pdbdir whose name contains ".pdb" and whose stem
        # is a domain in the split.
        pdb_ids = [
            nam.split(".")[0] for nam in os.listdir(args.pdbdir) if ".pdb" in nam
        ]
        pdb_ids = [nam for nam in pdb_ids if nam in df.index]
    print("Number of PDBs:", len(pdb_ids))

    # Hard-coded exclusion (the reason is not recorded in this script).
    if "1gkgA02" in pdb_ids:
        pdb_ids.remove("1gkgA02")

    # Define the output file path.
    output_file = os.path.join(args.pdbdir, "out.pkl")
    if args.truncate and args.truncate != 500:
        output_file = os.path.join(args.pdbdir, f"out_{args.truncate}.pkl")

    if args.gen_replicas is not None:
        output_file = os.path.join(args.gen_replicas_savedir, "out.pkl")
        if args.truncate and args.truncate != 500:
            output_file = os.path.join(
                args.gen_replicas_savedir, f"out_trunc{args.truncate}.pkl"
            )

    # Create the output directory now, so a missing --gen_replicas_savedir
    # does not fail the final write after all the work is done.
    os.makedirs(os.path.dirname(output_file) or ".", exist_ok=True)

    # NOTE: resuming from an existing output file is disabled; every run
    # recomputes all tasks and overwrites output_file.

    tasks = [(id_, df.seqres.loc[id_]) for id_ in pdb_ids]

    # Shared list through which workers report completed task names.
    manager = Manager()
    finished = manager.list()

    # Print pending tasks every 60 seconds from a daemon thread.
    monitor_thread = threading.Thread(
        target=monitor_pending, args=(pdb_ids, finished, 60), daemon=True
    )
    monitor_thread.start()

    with Pool(args.num_workers) as pool:
        results = [
            pool.apply_async(wrapper_main, args=(task, finished)) for task in tasks
        ]
        for r in tqdm.tqdm(results):
            r.wait()

    # Collect results. A task that raised is reported and skipped, rather than
    # letting r.get() re-raise and discard every other task's results.
    final_out = {}
    for (task_name, _), r in zip(tasks, results):
        try:
            name, result = r.get()
        except Exception as err:
            print(f"Task {task_name} failed: {err!r}")
            continue
        if result is not None:
            final_out[name] = result

    with open(output_file, "wb") as f:
        pickle.dump(final_out, f)
    print(f"Results saved to {output_file}")
