"""
Script to render a trained NeuralMPM model on a set of simulations.
"""

import os

import json
import torch
import argparse
from glob import glob
from pathlib import Path

import h5py
from tqdm import tqdm

from neural_mpm.util.metaparams import LOW, HIGH, DIM
from neural_mpm.util import ModelLogger
from neural_mpm.nn import create_model, UNet
from neural_mpm.data.data_manager import find_size, get_voxel_centers, list_to_padded
from neural_mpm.data.data_stats import LOAD_FUNCS
import neural_mpm.util.interpolation as interp
import neural_mpm.util.simulate as simulate
import neural_mpm.util.viz as viz


def load_model(run, checkpoint):
    """
    Load a trained NeuralMPM model together with its run configuration.

    Args:
        run: Path to the run folder. It must contain a ``config.json`` file.
            A trailing slash is tolerated.
        checkpoint: Name of the checkpoint to load (forwarded to
            ``ModelLogger.load``).

    Returns:
        A ``(config, model)`` tuple: the dictionary read from ``config.json``
        (with an extra ``"run_id"`` entry) and the object returned by
        ``ModelLogger.load(checkpoint)``.
    """
    run_path = Path(run)

    with open(run_path / "config.json", "r") as f:
        config = json.load(f)

    # The run ID is not stored in the config, so it is recovered from the
    # folder name (the part after the last underscore). ``Path.name`` ignores
    # a trailing slash, which keeps this consistent with ``parent_dir`` below.
    config["run_id"] = run_path.name.split("_")[-1]

    model_logger = ModelLogger(
        dataset_name=None,
        run_config=config,
        create_wandb_json=False,
        # Two levels above the run folder.
        parent_dir=str(run_path.parent.parent),
    )

    return config, model_logger.load(checkpoint)


def ids_to_list(ids: str) -> "list[int] | str":
    """
    Convert a string of IDs to a list of integers.

    e.g.:
        '1,2,3'    -> [1, 2, 3]
        '1-3'      -> [1, 2, 3]
        '1,2,8-10' -> [1, 2, 8, 9, 10]
        '1,2,'     -> [1, 2]
        'all'      -> 'all'

    Args:
        ids: Formatted string of IDs: comma-separated integers and inclusive
            ``start-end`` ranges, or the literal string ``"all"``.

    Returns:
        The list of IDs (list of int), or the string ``"all"`` unchanged so
        that callers can treat it as a "select everything" sentinel.

    Raises:
        ValueError: If a non-empty token is not an integer or a valid range.
    """
    if ids == "all":
        return "all"

    id_list = []
    for token in ids.split(","):
        token = token.strip()
        if not token:
            # Tolerate stray separators such as a trailing comma.
            continue

        if "-" in token:
            start, end = token.split("-")
            # Ranges are inclusive on both ends.
            id_list.extend(range(int(start), int(end) + 1))
        else:
            id_list.append(int(token))
    return id_list


def load_simulations(data, split, id_list):
    """
    Load a set of simulations to render and interpolate them onto a grid.

    Args:
        data: Path to the dataset folder. Substrings of this path select the
            dataset type ("multi" / "wbc" / mono), the statistics file under
            ``stats/`` and, for mono datasets, the material.
        split: Name of the sub-folder of ``data`` to read the simulations from.
        id_list: Either the string ``"all"`` (every ``*.h5`` file of the split,
            sorted by path) or a list of integer IDs, each loaded from
            ``sim_<id>.h5``.

    Returns:
        A tuple ``(sims, types, grids, grid_coords, size, gmean, gstd)``:
        the per-simulation outputs of the dataset loader (``sims``, ``types``),
        the stacked grids computed from them, the voxel centers, the grid
        size, and the mean / std read from the statistics file.

    Raises:
        ValueError: If ``data`` matches no known dataset or material.
        FileNotFoundError: If no simulation file is selected.
    """
    start = torch.tensor([LOW, LOW])
    end = torch.tensor([HIGH, HIGH])
    size = find_size(LOW, HIGH, 64)
    grid_coords = get_voxel_centers(torch.tensor([size, size]), start, end)

    lowered = data.lower()
    if "multi" in lowered:
        dataset_type = "multi"
    elif "wbc" in lowered:
        dataset_type = "wbc"
    else:
        dataset_type = "mono"

    # Case-sensitive markers, checked in order: the first match wins.
    stats_by_marker = (
        ("WaterR", "stats/waterramps_stats.h5"),
        ("Goop", "stats/goop_stats.h5"),
        ("Sand", "stats/sandramps_stats.h5"),
        ("WBC", "stats/wbc_stats.h5"),
        ("XL", "stats/wdxl_stats.h5"),
        ("Multi", "stats/multi_stats.h5"),
    )
    mean_std_path = next(
        (path for marker, path in stats_by_marker if marker in data), None
    )
    if mean_std_path is None:
        raise ValueError(f"Invalid dataset: {data}")

    with h5py.File(mean_std_path, "r") as f:
        gmean = torch.tensor(f["mean"]).squeeze()
        gstd = torch.tensor(f["std"]).squeeze()

    if dataset_type == "wbc":
        # Two extra channels are appended to the grids for this dataset type
        # (see the gravity concatenation below); they are left un-normalised.
        gmean = torch.cat((gmean, torch.zeros(2, dtype=torch.float32)))
        gstd = torch.cat((gstd, torch.ones(2, dtype=torch.float32)))

    if id_list == "all":
        files = sorted(glob(os.path.join(data, split, "*.h5")))
    else:
        files = [os.path.join(data, split, f"sim_{i}.h5") for i in id_list]

    if not files:
        raise FileNotFoundError(
            f"No simulation selected in {os.path.join(data, split)}"
        )

    load_func = LOAD_FUNCS[dataset_type]

    if dataset_type == "multi":
        material = None
    elif dataset_type == "wbc":
        material = "water"
    elif "goop" in lowered:
        material = "goop"
    elif "sand" in lowered:
        material = "sand"
    elif "water" in lowered:
        material = "water"
    else:
        raise ValueError(f"Invalid mono dataset: {data}")

    sims = []
    types = []
    gravs = []
    for path in files:
        sim, type_, grav = load_func(path, material)
        sims.append(sim)
        types.append(type_)
        gravs.append(grav)

    grids = []

    print("Loading grids")
    pbar = tqdm(total=len(sims))
    for sim, typ, grav in zip(sims, types, gravs):
        grid = interp.create_grid_cluster_batch(
            grid_coords,
            sim[..., :DIM],
            sim[..., :, DIM:],
            # Repeat the per-particle types along the first axis of ``sim``.
            torch.tile(typ[None, :], (sim.shape[0], 1)),
            interp.linear,
            size=size,
            interaction_radius=0.015,
        )

        if grav is not None:
            # Broadcast the loader's third output over every grid cell and
            # append it as extra channels.
            grid = torch.cat(
                (grid, torch.tile(grav[None, None, None], (*grid.shape[:-1], 1))),
                dim=-1,
            )

        grids.append(grid)
        pbar.update(1)

    pbar.close()

    grids = torch.stack(grids)

    print("Computed grids:", grids.shape)

    return sims, types, grids, grid_coords, size, gmean, gstd


@torch.no_grad()
def unroll_model(model, config, grids, sims, types, grid_coords, size, gmean, gstd):
    """
    Roll a model out from the first frame of each simulation.

    Args:
        model: The model to unroll.
        config: Run configuration; only ``config["steps_per_call"]`` is read.
        grids: Grids of the simulations; ``grids[:, 0]`` is used as the
            initial grid (presumably [N, T, H, W, C] -- TODO confirm).
        sims: Sequence of per-simulation tensors; ``s[0]`` of each is used as
            the initial particle state.
        types: Sequence of per-simulation type tensors.
        grid_coords: Grid coordinates, forwarded to ``simulate.unroll``.
        size: Grid size, forwarded to ``simulate.unroll``.
        gmean: Normalisation mean, forwarded to ``simulate.unroll``.
        gstd: Normalisation std, forwarded to ``simulate.unroll``.

    Returns:
        The unrolled trajectories of all simulations, stacked in one tensor.
    """
    # One more call than strictly fits in the length of the first simulation.
    steps_per_call = config["steps_per_call"]
    num_calls = sims[0].shape[0] // steps_per_call + 1

    first_frames = [s[0] for s in sims]
    dataset = torch.utils.data.TensorDataset(
        grids[:, 0],
        list_to_padded(first_frames),
        list_to_padded(types),
    )
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=5,
        shuffle=False,
        num_workers=0,
        pin_memory=False,
    )

    rollouts = []
    progress = tqdm(total=len(loader), desc="Unrolling")
    for grid_0, particles_0, types_ in loader:
        _, batch_trajs, _ = simulate.unroll(
            model,
            (grid_0, particles_0),
            grid_coords,
            num_calls,
            gmean,
            gstd,
            types_,
            size,
            interaction_radius=0.015,
            interp_fn=interp.linear,
        )

        # Iterating the batch tensor yields one trajectory per simulation.
        rollouts.extend(batch_trajs.cpu())
        progress.update(1)

    progress.close()

    return torch.stack(rollouts)


def render_simulations(
    run, ids, split, checkpoint, engine, fps, type, steps, gt, img_type, cloud
):
    """
    Render a set of simulations, either unrolled by a trained NeuralMPM model
    or read from a file of precomputed trajectories.

    Args:
        run: Path to the run folder (used by the "video" / "frames" modes).
        ids: IDs of the simulations to render, in ``ids_to_list`` format.
            Required for "video" / "frames"; only used as an output label for
            "dmcf" (defaults to 0); ignored for "saved".
        split: Split to use for rendering; also prefixes the output names.
        checkpoint: Name of the checkpoint to load.
        engine: Engine to use for rendering.
        fps: Frame rate of rendered videos.
        type: Rendering mode: "video", "frames", "dmcf", "gns" or "saved".
        steps: Time steps to render as frames, in ``ids_to_list`` format
            ("all" renders every available step).
        gt: Whether to also render the ground truth frames.
        img_type: File extension of rendered frames.
        cloud: Path to the .h5 file read by the "dmcf" / "saved" modes.

    Raises:
        ValueError: If ``type`` is unknown, or ``ids`` is missing for a mode
            that needs it.
        NotImplementedError: For the "gns" mode.
    """

    steps = ids_to_list(steps)

    if type in ["video", "frames"]:
        if ids is None:
            raise ValueError(f"`ids` is required for render type '{type}'")
        ids = ids_to_list(ids)

        config, model = load_model(run, checkpoint)
        # normpath drops a trailing slash that would yield an empty name.
        run_name = os.path.basename(os.path.normpath(run))
        # Shapes as noted by the original author (not verified here):
        # sims: [N_sims, N_particles, 4]
        # types: [N_sims, N_particles]
        # grids: [N_sims, T, H, W, C]
        (sims, types, grids, grid_coords, size, gmean, gstd) = load_simulations(
            config["data"],
            split,
            ids,
        )

        if ids == "all":
            # Iterating the string "all" would label (and limit) the renders
            # to 'a', 'l', 'l'. Rebuild one label per simulation from the same
            # sorted listing that load_simulations uses.
            ids = [
                Path(f).stem.split("_")[-1]
                for f in sorted(glob(os.path.join(config["data"], split, "*.h5")))
            ]

        trajs = unroll_model(
            model, config, grids, sims, types, grid_coords, size, gmean, gstd
        )
    elif type == "dmcf":
        # A single simulation is read from the file: ``ids`` only labels it.
        ids = ids_to_list(ids) if ids is not None else "all"
        if ids == "all":
            ids = [0]

        run_name = os.path.basename(cloud)
        with h5py.File(cloud, "r") as f:
            gt_ = torch.tensor(f["SymNet"]["gt"])[..., :2]
            pred_ = torch.tensor(f["SymNet"]["pred"])[..., :2]
            bnd = torch.tensor(f["SymNet"]["bnd"])[..., :2]

            box_offset = torch.min(pred_.view(-1, 2), dim=0).values
            print(box_offset)
            print(box_offset.shape)

            num_particles = gt_.shape[1]

            # Repeat the boundary points for every time step.
            bnd = bnd[None, ...].expand(pred_.shape[0], -1, 2)

            # Append the boundary to both clouds and drop the first frame.
            gt_ = torch.cat((gt_, bnd), dim=1)[1:]
            pred_ = torch.cat((pred_, bnd), dim=1)[1:]

            # Type 5 for the simulated particles, 0 for the boundary points.
            typ = torch.cat(
                (
                    torch.ones(num_particles, dtype=torch.long) * 5,
                    torch.zeros(pred_.shape[1] - num_particles, dtype=torch.long),
                ),
            )

            # NOTE(review): subtracting -0.5 shifts both clouds by +0.5;
            # presumably this re-centres them in the render bounds -- confirm.
            gt_ -= -0.5
            pred_ -= -0.5

            sims = [gt_]
            trajs = [pred_]
            types = [typ]
    elif type == "gns":
        raise NotImplementedError("Rendering GNS outputs is not implemented.")

    elif type == "saved":
        run_name = os.path.basename(cloud)
        with h5py.File(cloud, "r") as f:
            # Every dataset except "types" is a trajectory. The "types" key is
            # excluded from the labels so that they stay aligned with the
            # trajectories.
            traj_keys = [key for key in f.keys() if key != "types"]
            trajs = torch.stack([torch.tensor(f[key][()]) for key in traj_keys])
            types = torch.tensor(f["types"][()])

        # No ground truth is stored: the trajectories stand in for it.
        sims = trajs
        types = torch.tile(types[None, :], (trajs.shape[0], 1))
        ids = traj_keys
    else:
        raise ValueError(f"Unknown render type: {type}")

    os.makedirs(os.path.join("renders", run_name), exist_ok=True)

    bounds = ((LOW, HIGH), (LOW, HIGH))

    pbar = tqdm(total=len(ids), desc="Rendering")
    for idx, gt_traj, pred_traj, types_ in zip(ids, sims, trajs, types):
        length = gt_traj.shape[0] - 1

        if type in ["saved", "frames", "dmcf", "gns"]:
            frame_dir = os.path.join("renders", run_name, f"{split}_{idx}")
            os.makedirs(frame_dir, exist_ok=True)

            if steps == "all":
                # Render every step available in the trajectories being drawn.
                n_frames = pred_traj.shape[0]
                if gt:
                    n_frames = min(n_frames, gt_traj.shape[0])
                frame_steps = range(n_frames)
            else:
                frame_steps = steps

            for i in frame_steps:
                if gt:
                    cloud_gt = gt_traj[i, ..., :2].cpu().numpy()
                    viz.render_cloud(
                        cloud_gt,
                        types=types_.cpu().numpy(),
                        save_path=os.path.join(frame_dir, f"gt_{i}.{img_type}"),
                        engine=engine,
                        bounds=bounds,
                    )

                cloud_pred = pred_traj[i, ..., :2].cpu().numpy()
                viz.render_cloud(
                    cloud_pred,
                    types=types_.cpu().numpy(),
                    save_path=os.path.join(frame_dir, f"pred_{i}.{img_type}"),
                    engine=engine,
                    bounds=bounds,
                )
        else:
            video_path = os.path.join("renders", run_name, f"{split}_{idx}.mp4")
            # The prediction is truncated to ``length`` frames and compared
            # with the ground truth from its second frame onwards.
            if engine == "v2":
                viz.animate_comparison_v2(
                    pred_traj[:length, ..., :2].cpu().numpy(),
                    truth=gt_traj[1:, ..., :2].cpu().numpy(),
                    types=types_.cpu().numpy(),
                    interval=1,
                    save_path=video_path,
                    return_ani=False,
                    as_array=True,
                    bounds=bounds,
                    fps=fps,
                )
            else:
                viz.animate_comparison(
                    pred_traj[:length].cpu().numpy(),
                    truth=gt_traj[1:, ..., :2].cpu().numpy(),
                    type_=types_.cpu().numpy(),
                    interval=1,
                    save_path=video_path,
                    return_ani=False,
                    as_array=True,
                    bounds=bounds,
                    fps=fps,
                )

        pbar.update(1)

    pbar.close()


def render_XL(
    run, ids, split, checkpoint, engine, fps, type, steps, gt, img_type, data_path
):
    """
    Render one WaterDrop-XL simulation using a trained NeuralMPM model.

    Args:
        run: Path to the run folder (contains ``config.json`` and ``models/``).
        ids: Index of the simulation to render, among the sorted ``*.h5``
            files of ``data/WaterDrop-XL/test/``.
        split: Only used to prefix the output names.
        checkpoint: Name of the checkpoint to load ("best" if None).
        engine: Engine to use for rendering.
        fps: Frame rate of the rendered video.
        type: "frames" renders individual frames, anything else a video.
        steps: Time steps to render as frames, in ``ids_to_list`` format
            ("all" renders every available step).
        gt: Whether to also render the ground truth frames.
        img_type: File extension of rendered frames.
        data_path: Currently unused; the dataset location is hard-coded.

    Raises:
        ValueError: If the run's config does not describe a "unet" model.
        IndexError: If ``ids`` is out of range for the files found.
    """

    def load_model(path: str, config_dict: dict, checkpoint_name: str = None):
        """
        Build the UNet described by ``config_dict`` and load its weights.

        Args:
            path: Run folder; checkpoints are read from ``<path>/models``.
            config_dict: Run configuration.
            checkpoint_name: Name of the checkpoint to load.
                If None, will load the "best" checkpoint.

        Returns:
            The model with the checkpoint's state dict loaded.
        """
        in_channels = 6 if "WBC" in config_dict["data"] else 4

        if checkpoint_name is None:
            checkpoint_name = "best"
        checkpoint_path = os.path.join(path, "models", f"{checkpoint_name}.ckpt")

        if config_dict["model"] != "unet":
            # Only the UNet is handled here; anything else used to fail later
            # with an unbound ``model``.
            raise ValueError(
                f"Unsupported model for XL rendering: {config_dict['model']}"
            )

        architecture = config_dict["architecture"]["hidden"] + [
            config_dict["steps_per_call"]
        ]
        factors = [2] * (len(architecture) - 1)
        model = UNet(architecture, factors, in_channels=in_channels)
        if torch.cuda.is_available():
            model = torch.compile(model)

        model.load_state_dict(torch.load(checkpoint_path))

        return model

    with open(os.path.join(run, "config.json"), "r") as file:
        config = json.load(file)

    model = load_model(run, config, checkpoint)

    with h5py.File("stats/waterramps_stats.h5", "r") as f:
        gmean = f["mean"][()]
        gstd = f["std"][()]

    gmean = torch.tensor(gmean)
    gstd = torch.tensor(gstd)
    # NOTE(review): the second-to-last statistic is scaled by 3; presumably
    # this adapts the WaterRamps statistics to the XL domain -- confirm.
    gmean[..., -2] = 3 * gmean[..., -2]
    gstd[..., -2] = 3 * gstd[..., -2]

    ids = int(ids)
    steps = ids_to_list(steps)

    dim = 2
    material = 5

    path = "data/WaterDrop-XL/test/"

    # Sorted so that a given index always selects the same file; glob itself
    # returns files in arbitrary order.
    files = sorted(glob(path + "*.h5"))
    try:
        file = files[ids]
    except IndexError:
        raise IndexError(
            f"Simulation index {ids} out of range: {len(files)} file(s) in {path}"
        ) from None

    with h5py.File(file, "r") as f:
        particles = f["particles"][()]
        boundary = f["boundary"][()]

    particles = torch.tensor(particles, dtype=torch.float32)
    boundary = torch.tensor(boundary, dtype=torch.float32)

    # Time step
    dt = 0.0025

    # Scale every channel after the first two by the time step, in place.
    particles[..., 2:].multiply_(dt)
    # Boundary points get zeros for those extra channels.
    boundary = torch.cat((boundary, torch.zeros_like(boundary)), dim=-1)

    gt_traj = torch.cat((particles, boundary), dim=1)[..., :2]

    start = torch.tensor([LOW, LOW])
    end = torch.tensor([HIGH, HIGH])
    size = find_size(LOW, HIGH, 64)

    grid_coords = get_voxel_centers(torch.tensor([size, size]), start, end)

    with torch.no_grad():
        s0 = torch.cat([particles[:1], boundary[:1]], dim=1)

        # ``material`` for the simulated particles, 0 for the boundary points.
        current_types = (
            torch.cat([torch.ones(particles.shape[1]), torch.zeros(boundary.shape[1])])
            * material
        )

        grid = interp.create_grid_cluster_batch(
            grid_coords,
            s0[..., :dim],
            s0[..., :, dim:],
            torch.tile(current_types[None, :], (s0.shape[0], 1)),
            interp.linear,
            size=size,
            interaction_radius=0.015,
        )

        num_calls = gt_traj.shape[0] // config["steps_per_call"] + 1

        # NOTE(review): ``requires_grad=True`` is passed although this runs
        # under ``torch.no_grad()``; kept as in the original -- confirm intent.
        _, pred_traj, _ = simulate.unroll(
            model,
            (grid, s0),
            grid_coords,
            num_calls,
            gmean,
            gstd,
            current_types[None],
            size,
            interaction_radius=0.015,
            interp_fn=interp.linear,
            requires_grad=True,
        )

    # normpath drops a trailing slash that would yield an empty name.
    run_name = os.path.basename(os.path.normpath(run))

    os.makedirs(os.path.join("renders", run_name), exist_ok=True)
    pred_traj = pred_traj[0]
    length = gt_traj.shape[0] - 1

    bounds = ((LOW, HIGH), (LOW, HIGH))
    types_np = current_types.cpu().numpy()

    if type == "frames":
        frame_dir = os.path.join("renders", run_name, f"{split}_{ids}")
        os.makedirs(frame_dir, exist_ok=True)

        if steps == "all":
            # Render every step available in the trajectories being drawn.
            n_frames = pred_traj.shape[0]
            if gt:
                n_frames = min(n_frames, gt_traj.shape[0])
            frame_steps = range(n_frames)
        else:
            frame_steps = steps

        for i in frame_steps:
            if gt:
                cloud_gt = gt_traj[i, ..., :2].detach().cpu().numpy()
                viz.render_cloud(
                    cloud_gt,
                    types=types_np,
                    save_path=os.path.join(frame_dir, f"gt_{i}.{img_type}"),
                    engine=engine,
                    bounds=bounds,
                )

            viz.render_cloud(
                pred_traj[i, ..., :2].detach().cpu().numpy(),
                types=types_np,
                save_path=os.path.join(frame_dir, f"pred_{i}.{img_type}"),
                engine=engine,
                bounds=bounds,
            )
    else:
        video_path = os.path.join("renders", run_name, f"{split}_{ids}.mp4")
        # The prediction is truncated to ``length`` frames and compared with
        # the ground truth from its second frame onwards.
        if engine == "v2":
            viz.animate_comparison_v2(
                pred_traj[:length, ..., :2].detach().cpu().numpy(),
                truth=gt_traj[1:, ..., :2].cpu().numpy(),
                types=types_np,
                interval=1,
                save_path=video_path,
                return_ani=False,
                as_array=True,
                bounds=bounds,
                fps=fps,
            )
        else:
            viz.animate_comparison(
                pred_traj[:length].detach().cpu().numpy(),
                truth=gt_traj[1:, ..., :2].cpu().numpy(),
                type_=types_np,
                interval=1,
                save_path=video_path,
                return_ani=False,
                as_array=True,
                bounds=bounds,
                fps=fps,
            )


"""

pred                gt                 types    
(120, 3397, 4) (120, 3046, 4) (3046,)
1762

pred                pred                 types
(120, 3397, 4) (120, 3397, 4) (3046,)
1762

gt               gt                 types
(120, 3046, 4) (120, 3046, 4) (3046,)
1762


"""


def main():
    """
    Parse the command line and dispatch to the matching rendering function.

    Without ``--data-path`` the simulations of the run's own dataset are
    rendered with ``render_simulations``; with a ``--data-path`` containing
    "XL", ``render_XL`` is used instead. Any other ``--data-path`` is
    rejected with a usage error.
    """

    def str_to_bool(value):
        """Parse a textual boolean; ``bool("False")`` would be True."""
        lowered = value.strip().lower()
        if lowered in ("1", "true", "t", "yes", "y"):
            return True
        if lowered in ("", "0", "false", "f", "no", "n"):
            return False
        raise argparse.ArgumentTypeError(f"Expected a boolean, got '{value}'")

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "type",
        type=str,
        help="What to render: video, frames, dmcf (h5), gns (h5) or saved (h5)",
    )

    parser.add_argument("--run", "-r", type=str, help="Path to the run folder.")
    parser.add_argument(
        "--ids", "-i", type=str, help="IDs of the simulations to render."
    )
    parser.add_argument(
        "--split", "-s", type=str, help="Split to use for rendering.", default="test"
    )
    parser.add_argument(
        "--checkpoint",
        "-c",
        type=str,
        help="Name of the checkpoint to load.",
        default="best",
    )
    parser.add_argument(
        "--engine", "-e", type=str, help="Engine to use for rendering", default="v2"
    )
    # For frames
    parser.add_argument(
        "--steps", type=str, help="Time steps to render", default="0,100,200,300"
    )
    # Accepts a bare `--gt` as well as `--gt True` / `--gt False`.
    parser.add_argument(
        "--gt",
        type=str_to_bool,
        nargs="?",
        const=True,
        default=False,
        help="Render ground truth",
    )
    parser.add_argument("--img-type", type=str, help="Image type", default="pdf")

    # For videos
    parser.add_argument("--fps", type=int, default=75)

    parser.add_argument(
        "--data-path",
        type=str,
        help="Which data to use, if None will use the data from the model",
        default=None,
    )

    # For DMCF/GNS
    parser.add_argument(
        "--cloud", type=str, help="Path to the sim .h5 file.", default=None
    )

    args = parser.parse_args()

    if torch.cuda.is_available():
        torch.set_default_device("cuda")

    if args.data_path is None:
        render_simulations(
            args.run,
            args.ids,
            args.split,
            args.checkpoint,
            args.engine,
            args.fps,
            args.type,
            args.steps,
            args.gt,
            args.img_type,
            args.cloud,
        )

    elif "XL" in args.data_path:
        render_XL(
            args.run,
            args.ids,
            args.split,
            args.checkpoint,
            args.engine,
            args.fps,
            args.type,
            args.steps,
            args.gt,
            args.img_type,
            args.data_path,
        )

    else:
        # Previously this case silently rendered nothing.
        parser.error(
            f"Unsupported --data-path '{args.data_path}': only XL data is handled."
        )


# Run the command-line interface only when this file is executed as a script.
if __name__ == "__main__":
    main()
