import argparse

parser = argparse.ArgumentParser()

# --- Models and reproducibility -----------------------------------------------
parser.add_argument(
    "--sim_ckpt_msm",
    type=str,
    default=None,
    help="Checkpoint of the single-step (MSM) wrapper.",
)
parser.add_argument(
    "--sim_ckpt_mdgen",
    type=str,
    required=True,
    help="Checkpoint of the multi-frame MDGen wrapper.",
)
parser.add_argument(
    "--seed", type=int, default=42, help="Seed passed to mdgen.utils.set_seed."
)

# --- Input data, peptide selection and output ---------------------------------
parser.add_argument(
    "--data_dir",
    type=str,
    required=True,
    help="Directory holding the per-peptide .npy trajectories.",
)
parser.add_argument(
    "--suffix",
    type=str,
    default="_i100",
    help="Suffix appended to the peptide name in the .npy filename.",
)
parser.add_argument(
    "--pdb_id",
    nargs="*",
    default=[],
    help="If given, only these peptide names are processed.",
)
parser.add_argument(
    "--num_rollouts",
    type=int,
    default=10,
    help="Number of rollout cycles per peptide.",
)
parser.add_argument(
    "--out_dir",
    type=str,
    default=".",
    help="Directory where <name>.pdb and <name>.xtc are written.",
)
parser.add_argument(
    "--split",
    type=str,
    default="splits/4AA_test.csv",
    help="CSV indexed by 'name' with a 'seqres' column listing the peptides.",
)
parser.add_argument(
    "--overfit_peptide",
    type=str,
    default=None,
    help="If set, process only this peptide (a single time).",
)
# NOTE(review): the two flags below are not read in the code visible here;
# confirm whether they are still needed.
parser.add_argument(
    "--offsets_from_noise", action=argparse.BooleanOptionalAction, default=True
)
parser.add_argument(
    "--torsions_from_noise", action=argparse.BooleanOptionalAction, default=True
)
# NOTE(review): --replace_with_first is also not read in the code visible here.
parser.add_argument(
    "--replace_with_first",
    action="store_true",
    help="If set, replaces arr[0:lag_time] with arr[0].",
)
parser.add_argument(
    "--from_boltz",
    action="store_true",
    help="If set together with --mdcath, loads <data_dir>/<name>_boltz.npy.",
)

# --- Rollout schedule ------------------------------------------------------------
parser.add_argument(
    "--calls_mdgen",
    type=int,
    default=1,
    help='How many times to call the "1000-at-once" rollout in a row, targeting short length dynamics.',
)
parser.add_argument(
    "--calls_msm",
    type=int,
    default=5,
    help='How many times to call the "single-at-a-time" rollout in a row, targeting long length dynamics.',
)
parser.add_argument(
    "--sequential_msm",
    action="store_true",
    help="If set, calls to the MSM model are sequential (each call updates the starting point). "
    "Otherwise (fixed mode), the starting input is used for each call so that noise generates different outputs.",
)
parser.add_argument(
    "--initial_calls_mdgen",
    type=int,
    default=0,
    help="If > 0, call the MDGen model this many times on the very first "
    "batch (derived from the starting structure) *before* any MSM calls "
    "in every rollout cycle.",
)
parser.add_argument(
    "--truncate_mdgen",
    type=int,
    default=None,
    help="If set, keep only the first N frames of every MDGen rollout.",
)

# --- mdCATH-style inputs ---------------------------------------------------------
parser.add_argument(
    "--mdcath",
    action="store_true",
    help="Load files named <name><suffix>_<temp>_<replica>.npy.",
)
parser.add_argument(
    "--temp", type=int, default=320, help="Temperature field of the --mdcath filename."
)
parser.add_argument(
    "--replica", type=int, default=0, help="Replica field of the --mdcath filename."
)
parser.add_argument(
    "--do_not_overwrite",
    action="store_true",
    help="Skip peptides whose <name>.pdb already exists in --out_dir "
    "(ignored with --overfit_peptide).",
)

# --- Tree sampler ----------------------------------------------------------------
parser.add_argument(
    "--num_tree_rollouts",
    type=int,
    default=0,
    help="Depth of the MSM tree sampler, used only when --calls_mdgen is 0. "
    "0 disables the tree sampler.",
)
parser.add_argument(
    "--max_msm_samples",
    type=int,
    default=None,
    help="Hard cap on the total number of MSM trajectories produced by the tree.",
)
parser.add_argument(
    "--tree_parallel_chunk",
    type=int,
    default=1000,
    help="Max number of trajectories fed to the MSM model in a single GPU call.",
)

args = parser.parse_args()

import os, torch, mdtraj, tqdm
import numpy as np
from mdgen.geometry import atom14_to_frames, atom14_to_atom37, atom37_to_torsions
from mdgen.residue_constants import restype_order
from mdgen.tensor_utils import tensor_tree_map
from mdgen.utils_ti import atom14_to_pdb
import pandas as pd

# Import the two wrappers using different names
from mdgen.wrapper import NewMDGenWrapper as NewMDGenWrapperMDGen
from mdgen.wrapper_st import NewMDGenWrapper as NewMDGenWrapperMSM


# Create the output directory once, before any per-peptide file is written.
os.makedirs(name=args.out_dir, exist_ok=True)

from mdgen.utils import set_seed

# Seed once at import time, i.e. before main() loads or samples from any model.
set_seed(args.seed)


def get_batch_mdgen(name, seqres, num_frames):
    """Build a single-frame starting item for ``name`` from its stored trajectory.

    Only frame 0 of the ``.npy`` trajectory is used. The file read is:

    - ``--mdcath --from_boltz``: ``{data_dir}/{name}_boltz.npy``
    - ``--mdcath``:              ``{data_dir}/{name}{suffix}_{temp}_{replica}.npy``
    - otherwise:                 ``{data_dir}/{name}{suffix}.npy``

    ``--from_boltz`` has no effect without ``--mdcath``.

    Args:
        name: peptide identifier used to build the filename.
        seqres: sequence string; each character is looked up in
            ``restype_order``.
        num_frames: not used by this function; kept for call compatibility.

    Returns:
        dict with ``torsions``, ``torsion_mask``, ``trans``, ``rots``,
        ``seqres`` (L,) and ``mask`` (L,) of ones. ``torsions``/``trans``/
        ``rots`` keep the leading length-1 frame axis of ``arr[0:1]``; callers
        add the batch dimension by collating the item.
    """
    if args.mdcath:
        if args.from_boltz:
            path = f"{args.data_dir}/{name}_boltz.npy"
        else:
            path = f"{args.data_dir}/{name}{args.suffix}_{args.temp}_{args.replica}.npy"
    else:
        path = f"{args.data_dir}/{name}{args.suffix}.npy"

    # Memory-map the trajectory so only frame 0 is actually read, then make a
    # single writable float32 copy (torch.from_numpy needs a writable array).
    traj = np.load(path, mmap_mode="r")
    arr = np.array(traj[0:1], dtype=np.float32)

    frames = atom14_to_frames(torch.from_numpy(arr))
    seqres = torch.tensor([restype_order[c] for c in seqres])
    atom37 = torch.from_numpy(atom14_to_atom37(arr, seqres[None])).float()
    L = len(seqres)
    mask = torch.ones(L)

    torsions, torsion_mask = atom37_to_torsions(atom37, seqres[None])
    return {
        "torsions": torsions,
        "torsion_mask": torsion_mask[0],
        "trans": frames._trans,
        "rots": frames._rots._rot_mats,
        "seqres": seqres,
        "mask": mask,  # (L,)
    }


def rollout_msm(model, batch):
    """Run one MSM inference call and derive the starting batch for the next one.

    Returns ``(atom14, next_batch)``. ``atom14`` is the raw output of
    ``model.inference``. ``next_batch`` is a shallow copy of ``batch`` in
    which ``trans``, ``rots`` and ``torsions`` are replaced by the geometry
    of the last generated frame (``atom14[:, -1]``) of every sample, with a
    length-1 axis re-inserted at dim 1. All other entries are shared with
    ``batch``.
    """
    atom14, _ = model.inference(batch)
    last_frame = atom14[:, -1]
    seqres_cpu = batch["seqres"].cpu()

    rigid = atom14_to_frames(last_frame)
    # Torsions are computed on the CPU and moved back to the GPU afterwards.
    last_torsions, _ = atom37_to_torsions(
        atom14_to_atom37(last_frame.cpu(), seqres_cpu), seqres_cpu
    )

    next_batch = dict(batch)
    next_batch.update(
        trans=rigid._trans.unsqueeze(1),
        rots=rigid._rots._rot_mats.unsqueeze(1),
        torsions=last_torsions.unsqueeze(1).cuda(),
    )
    return atom14, next_batch


def rollout_mdgen(model, batch, truncate_mdgen=None):
    """Run one MDGen inference call starting from a single-frame batch.

    The frame stored in ``batch`` is broadcast to ``model.args.num_frames``
    frames before calling ``model.inference``. ``Tensor.expand`` requires the
    frame axis (dim 1) of ``torsions``/``trans``/``rots`` to have size 1 (or
    already equal ``num_frames``).

    Args:
        model: MDGen wrapper exposing ``inference`` and ``args.num_frames``.
        batch: dict with ``torsions``, ``torsion_mask``, ``trans``, ``rots``,
            ``seqres`` and ``mask``. Every sample along the leading batch dim
            is rolled out.
        truncate_mdgen: if not None, keep only ``atom14[:, :truncate_mdgen]``.

    Returns:
        ``(atom14, new_batch)``. ``new_batch`` is a shallow copy of ``batch``
        whose ``trans``/``rots``/``torsions`` hold the last kept frame of
        *each* sample, laid out like the input (batch dim first, length-1
        frame axis at dim 1). This is the same layout ``rollout_msm`` returns.
    """
    num_frames = model.args.num_frames
    expanded_batch = {
        "torsions": batch["torsions"].expand(-1, num_frames, -1, -1, -1),
        "torsion_mask": batch["torsion_mask"],
        "trans": batch["trans"].expand(-1, num_frames, -1, -1),
        "rots": batch["rots"].expand(-1, num_frames, -1, -1, -1),
        "seqres": batch["seqres"],
        "mask": batch["mask"],
    }
    atom14, _ = model.inference(expanded_batch)
    new_batch = {**batch}

    if truncate_mdgen is not None:
        atom14 = atom14[:, :truncate_mdgen]

    # Rebuild the next starting point per sample, keeping the batch dim first
    # so the result is valid for any batch size, not only B == 1.
    last_frame = atom14[:, -1]
    frames = atom14_to_frames(last_frame)
    new_batch["trans"] = frames._trans.unsqueeze(1)
    new_batch["rots"] = frames._rots._rot_mats.unsqueeze(1)
    seqres_cpu = batch["seqres"].cpu()
    atom37 = atom14_to_atom37(last_frame.cpu(), seqres_cpu)
    torsions, _ = atom37_to_torsions(atom37, seqres_cpu)
    new_batch["torsions"] = torsions.unsqueeze(1).cuda()

    return atom14, new_batch


def tree_rollout_msm(
    model,
    root_batch,
    calls_msm: int,
    depth: int,
    max_samples: int | None = None,
    chunk: int = 256,
):
    """Expand a tree of MSM rollouts breadth-first.

    At each level, every parent is rolled out ``calls_msm`` times and every
    resulting child becomes a parent for the next level. Level ``k``
    therefore holds ``B * calls_msm ** (k + 1)`` samples (``B`` = batch size
    of ``root_batch``) unless the cap is reached first.

    Args:
        model: MSM wrapper passed to ``rollout_msm``.
        root_batch: starting batch (dict of tensors sharing a leading batch dim).
        calls_msm: number of children generated per parent.
        depth: number of tree levels to expand.
        max_samples: if not None, a hard cap on the total number of
            trajectories. The last chunk is trimmed so the cap is never
            exceeded.
        chunk: maximum number of trajectories per ``rollout_msm`` call.

    Returns:
        List of ``atom14`` tensors, one per ``rollout_msm`` call, in
        generation order.
    """
    # Each frontier entry is a batch dict. Concatenating the entries along
    # dim 0 yields all parents of the current level, in order.
    frontier = [root_batch]
    collected = []
    total = 0

    for level in range(depth):
        print("Level", level)
        if not frontier:
            break

        # Tile all parents calls_msm times. The rows are
        # [p0..pM-1, p0..pM-1, ...], so each GPU chunk mixes different
        # parents instead of holding copies of a single one.
        mega_batch = {}
        for key in frontier[0]:
            parents = torch.cat([parent[key] for parent in frontier], dim=0)
            mega_batch[key] = parents.repeat(calls_msm, *([1] * (parents.ndim - 1)))

        n = mega_batch["torsions"].shape[0]
        next_frontier = []
        for start in range(0, n, chunk):
            end = min(start + chunk, n)
            if max_samples is not None:
                # Trim this chunk so the documented hard cap is never exceeded.
                end = min(end, start + max_samples - total)
                if end <= start:
                    return collected
            print("Start", start)
            sub_batch = {k: v[start:end] for k, v in mega_batch.items()}
            atom14, children = rollout_msm(model, sub_batch)
            collected.append(atom14)
            total += atom14.shape[0]

            # Keep the children batched; they are concatenated at the next level.
            next_frontier.append(children)

            if max_samples is not None and total >= max_samples:
                return collected  # hard cut-off

        frontier = next_frontier

    return collected


@torch.no_grad()
def main():
    """Generate and save rollout trajectories for every selected peptide.

    Peptides come from ``args.split``. The list is optionally filtered by
    ``--do_not_overwrite`` and ``--pdb_id``, or replaced by
    ``--overfit_peptide`` (processed once). For each peptide this:

    1. builds a one-frame starting batch from frame 0 of the stored
       trajectory (``get_batch_mdgen``);
    2. runs ``args.num_rollouts`` cycles. Each cycle is either:
       - MDGen-only (``--calls_msm 0`` with ``--calls_mdgen > 0``):
         ``calls_mdgen`` samples drawn in one batched call; or
       - stage 0: ``--initial_calls_mdgen`` MDGen calls from the start batch;
         stage 1: MSM sampling, via the tree sampler (when
         ``--num_tree_rollouts > 0`` and ``--calls_mdgen == 0``), sequentially
         (``--sequential_msm``: each call starts where the previous ended), or
         in fixed mode (``calls_msm`` copies of the start batch in one call);
         stage 2: ``--calls_mdgen`` chained MDGen calls from each MSM sample.
       Every cycle starts from the same initial batch;
    3. concatenates all generated frames, writes ``{name}.xtc``, and
       overwrites ``{name}.pdb`` with the first frame, all in ``args.out_dir``.
    """
    model_msm = None
    model_mdgen = None
    if args.sim_ckpt_msm:
        model_msm = NewMDGenWrapperMSM.load_from_checkpoint(args.sim_ckpt_msm)
        model_msm.eval().to("cuda")
    if args.sim_ckpt_mdgen:
        model_mdgen = NewMDGenWrapperMDGen.load_from_checkpoint(args.sim_ckpt_mdgen)
        model_mdgen.eval().to("cuda")

    # The MDGen model is always needed (it also sizes the start batch). The
    # MSM model is needed by every mode except MDGen-only. Fail fast with a
    # CLI error instead of an UnboundLocalError in the middle of a run.
    mdgen_only = args.calls_msm == 0 and args.calls_mdgen > 0
    if model_mdgen is None:
        parser.error("--sim_ckpt_mdgen must name a checkpoint")
    if model_msm is None and args.num_rollouts > 0 and not mdgen_only:
        parser.error(
            "--sim_ckpt_msm is required unless --calls_msm 0 is combined "
            "with --calls_mdgen > 0"
        )

    df = pd.read_csv(args.split, index_col="name")

    if args.do_not_overwrite and not args.overfit_peptide:
        all_names = [
            n
            for n in df.index
            if not os.path.exists(os.path.join(args.out_dir, f"{n}.pdb"))
        ]
    else:
        all_names = df.index

    print("Number of peptides to process:", len(all_names))
    for name in all_names:
        if args.pdb_id and name not in args.pdb_id:
            continue
        if args.overfit_peptide:
            name = args.overfit_peptide
        print("Processing peptide:", name)
        seqres = df.seqres[name]

        # Start batch: frame 0 of the stored trajectory, collated to batch size 1.
        item = get_batch_mdgen(name, seqres, num_frames=model_mdgen.args.num_frames)
        batch = next(iter(torch.utils.data.DataLoader([item])))
        batch = tensor_tree_map(lambda x: x.cuda(), batch)
        all_atom14 = []

        for _ in range(args.num_rollouts):
            if mdgen_only:
                # MDGen-only: calls_mdgen independent samples in one batched call.
                repeated_batch = {
                    k: v.repeat(args.calls_mdgen, *([1] * (v.ndim - 1)))
                    for k, v in batch.items()
                }
                atom14, _ = rollout_mdgen(
                    model_mdgen, repeated_batch, args.truncate_mdgen
                )
                all_atom14.append(atom14)
                continue

            #### STAGE 0: optional MDGen calls, each from the start batch.
            for _ in range(args.initial_calls_mdgen):
                atom14, _ = rollout_mdgen(model_mdgen, batch, args.truncate_mdgen)
                all_atom14.append(atom14)

            #### STAGE 1
            if args.num_tree_rollouts > 0 and args.calls_mdgen == 0:
                tree_samples = tree_rollout_msm(
                    model_msm,
                    batch,
                    args.calls_msm,
                    args.num_tree_rollouts,
                    max_samples=args.max_msm_samples,
                    chunk=args.tree_parallel_chunk,
                )
                all_atom14.extend(tree_samples)
                continue

            msm_batches = []
            if args.sequential_msm:
                # Sequential mode: each call starts from the previous call's
                # last frame, as documented for --sequential_msm.
                current = batch
                for _ in tqdm.trange(args.calls_msm):
                    atom14, current = rollout_msm(model_msm, current)
                    all_atom14.append(atom14)
                    msm_batches.append(current)
            else:
                # Fixed mode: calls_msm copies of the start batch in one call,
                # so only the sampling noise differs between outputs.
                repeated_batch = {
                    key: value.repeat(args.calls_msm, *([1] * (value.ndim - 1)))
                    for key, value in batch.items()
                }
                atom14, stacked_batch = rollout_msm(model_msm, repeated_batch)
                all_atom14.append(atom14)
                msm_batches = [
                    {k: v[i : i + 1] for k, v in stacked_batch.items()}
                    for i in range(args.calls_msm)
                ]

            #### STAGE 2: chain calls_mdgen MDGen rollouts from every MSM sample.
            for m_batch in msm_batches:
                current_batch = m_batch
                for _ in tqdm.trange(args.calls_mdgen):
                    atom14, current_batch = rollout_mdgen(
                        model_mdgen, current_batch, args.truncate_mdgen
                    )
                    all_atom14.append(atom14)

        if not all_atom14:
            # torch.cat would fail on an empty list (e.g. --num_rollouts 0).
            print(f"No frames generated for {name}; nothing written.")
        else:
            # Flatten every (B, T, ...) chunk into one (1, B*T, ...) trajectory.
            all_atom14_cat = torch.cat(
                [
                    t.reshape((1, t.shape[0] * t.shape[1]) + t.shape[2:])
                    for t in all_atom14
                ],
                1,
            )

            path = os.path.join(args.out_dir, f"{name}.pdb")
            atom14_to_pdb(
                all_atom14_cat[0].cpu().numpy(), batch["seqres"][0].cpu().numpy(), path
            )
            traj = mdtraj.load(path)
            # Align all frames onto frame 0 (mdtraj's default reference frame).
            traj.superpose(traj)
            traj.save(os.path.join(args.out_dir, f"{name}.xtc"))
            # Replace the multi-model PDB with a single-frame topology file.
            traj[0].save(path)

        if args.overfit_peptide:
            break


# Run only when executed as a script, not when this module is imported.
if __name__ == "__main__":
    main()
