import argparse

# Command-line interface. Arguments are parsed at import time, before the
# remaining (heavier) imports further down the file are executed.
parser = argparse.ArgumentParser()

# Text file that is read line by line to obtain the names to process.
parser.add_argument("--split", default="splits/mdCATH.txt", type=str)

# Directory the simulation inputs are loaded from.
parser.add_argument("--sim_dir", default="anonymous", type=str)

# Directory the converted arrays are written to (created if missing).
parser.add_argument("--outdir", default="anonymous", type=str)

# Number of worker processes; values above 1 enable the multiprocessing pool.
parser.add_argument("--num_workers", default=1, type=int)

# Optional string used when building output file names.
parser.add_argument("--suffix", default="", type=str)

# Keep only every ``stride``-th frame of each trajectory in the saved array.
parser.add_argument("--stride", default=1, type=int)

args = parser.parse_args()

import mdtraj, os, tqdm
import pandas as pd
from multiprocessing import Pool
import numpy as np
from mdgen import residue_constants as rc

# Make sure the output directory exists before any job tries to write into it.
os.makedirs(args.outdir, exist_ok=True)

# The split file lists one name per line. Blank (or whitespace-only) lines are
# skipped so that they never turn into jobs with an empty name.
with open(args.split, "r") as file:
    names = [line.strip() for line in file if line.strip()]


def main():
    """Run ``do_job`` for every pending name from the split file.

    Names whose ``{outdir}/{name}{suffix}.npy`` file already exists are
    skipped. The remaining names are processed serially, or across
    ``args.num_workers`` worker processes when that value is greater than 1.
    Progress is reported through a tqdm bar.
    """
    # NOTE(review): do_job writes one file per temperature and replica rather
    # than a single per-name file, so this check may never find a match;
    # confirm the intended resume behaviour.
    jobs = [
        name
        for name in names
        if not os.path.exists(f"{args.outdir}/{name}{args.suffix}.npy")
    ]

    if args.num_workers > 1:
        # The context manager tears the pool down even if a job raises.
        with Pool(args.num_workers) as pool:
            for _ in tqdm.tqdm(pool.imap(do_job, jobs), total=len(jobs)):
                pass
    else:
        for _ in tqdm.tqdm(map(do_job, jobs), total=len(jobs)):
            pass


def traj_to_atom14(traj):
    """Pack trajectory coordinates into a per-residue atom14 array.

    Args:
        traj: Trajectory object exposing ``n_frames``, ``n_residues``,
            ``top.residues`` (each with ``name`` and ``atoms``) and ``xyz``
            indexed as ``xyz[:, atom_index]``.

    Returns:
        A ``float16`` array of shape ``(n_frames, n_residues, 14, 3)`` holding
        the coordinates multiplied by 10 (conversion to Ångström). The slot
        of each atom is its position in
        ``rc.restype_name_to_atom14_names[residue_name]``. Residues and atoms
        that are not present in that table are left as zeros.
    """
    # Non-standard residue names and the standard residue they are stored as.
    residue_aliases = {"HSD": "HIS", "HSE": "HIS", "HSP": "HIS", "MSE": "MET"}
    # Atom renames, keyed by (original residue name, original atom name).
    atom_aliases = {
        ("MSE", "SE"): "SD",  # Replace selenium atom with sulfur
        ("ILE", "CD"): "CD1",  # Handle mismatch in naming for ILE residue
    }

    arr = np.zeros((traj.n_frames, traj.n_residues, 14, 3), dtype=np.float16)
    for i, resi in enumerate(traj.top.residues):
        original_name = resi.name
        res_name = residue_aliases.get(original_name, original_name)

        # Residues without an atom14 mapping keep their all-zero rows.
        if res_name not in rc.restype_name_to_atom14_names:
            continue
        slot_names = rc.restype_name_to_atom14_names[res_name]

        for at in resi.atoms:
            atom_name = atom_aliases.get((original_name, at.name), at.name)
            # Atoms outside the 14-atom layout are ignored.
            if atom_name not in slot_names:
                continue
            j = slot_names.index(atom_name)
            arr[:, i, j] = traj.xyz[:, at.index] * 10.0  # Convert to Ångström scale
    return arr


def do_job(name):
    """Convert every trajectory of one name into atom14 ``.npy`` files.

    For each temperature and replica index, this loads
    ``{sim_dir}/trajectory/{name}_{temp}_{i}.xtc`` with the topology
    ``{sim_dir}/topology/{name}.pdb``, drops hydrogen atoms, superposes the
    trajectory onto itself, converts it with ``traj_to_atom14`` and saves every
    ``args.stride``-th frame to
    ``{outdir}/{name}_{temp}_{i}{suffix}.npy``.

    Args:
        name: Identifier used to build the trajectory, topology and output
            file names.
    """
    temperatures = [320, 348, 379, 413, 450]
    replicas = [0, 1, 2, 3, 4]
    topology_path = f"{args.sim_dir}/topology/{name}.pdb"

    for temp in temperatures:
        for i in replicas:
            traj = mdtraj.load(
                f"{args.sim_dir}/trajectory/{name}_{temp}_{i}.xtc",
                top=topology_path,
            )
            # Keep heavy atoms only; the slice is applied to ``traj`` itself.
            heavy_atoms = [a.index for a in traj.top.atoms if a.element.symbol != "H"]
            traj.atom_slice(heavy_atoms, inplace=True)
            traj.superpose(traj)
            arr = traj_to_atom14(traj)
            print("arr.shape", arr.shape)
            # ``args.suffix`` is part of the file name so that runs with
            # different settings do not overwrite each other; the default
            # empty suffix leaves the name unchanged.
            np.save(
                f"{args.outdir}/{name}_{temp}_{i}{args.suffix}.npy",
                arr[:: args.stride],
            )


# Run the conversion only when this file is executed as a script. main() may
# start a multiprocessing Pool, so the guard also keeps main() from running
# again in any process that imports this module instead of executing it.
if __name__ == "__main__":
    main()
