import argparse
import json
import os

from glob import glob


import torch

import neural_mpm.util.interpolation as interp
from neural_mpm.util.metaparams import LOW, HIGH, DIM
import neural_mpm.util.simulate as simulate

from neural_mpm.data.data_manager import find_size, get_voxel_centers, list_to_padded


import h5py

from neural_mpm.valid import get_stats_and_load_fn, load_config_and_model


def get_rollout_sims(files, material, load_simulation):
    """Load each file with ``load_simulation`` and split the results.

    Args:
        files: Iterable of file paths, loaded in iteration order.
        material: Forwarded unchanged as the second argument of
            ``load_simulation``.
        load_simulation: Callable ``(path, material)`` returning a 3-tuple
            ``(sim, type, gravity)``.

    Returns:
        Three parallel lists ``(sims, types, gravity)``, one entry per file.
    """
    records = []
    for path in files:
        # Unpack eagerly so a loader returning the wrong arity fails right here.
        sim, type_, grav = load_simulation(path, material)
        records.append((sim, type_, grav))

    sims = [record[0] for record in records]
    types = [record[1] for record in records]
    gravity = [record[2] for record in records]
    return sims, types, gravity

def _build_initial_grids(sims, types, gravity, grid_coords, size, interaction_radius):
    """Rasterize frame 0 of every simulation onto the voxel grid.

    Args:
        sims: Per-simulation tensors. Only ``sim[:1]`` is used. The first
            ``DIM`` channels of the last axis and the remaining channels are
            passed separately to the interpolator.
        types: Per-simulation particle-type tensors.
        gravity: Per-simulation gravity tensors, or ``None`` entries.
        grid_coords: Voxel centers from ``get_voxel_centers``.
        size: Grid resolution passed to the interpolator.
        interaction_radius: Radius passed to the interpolator.

    Returns:
        The per-simulation grids stacked along a new leading axis. When a
        simulation has gravity, it is broadcast over all but the last axis of
        its grid and appended as extra channels on that last axis.
    """
    grids = []
    for sim, typ, grav in zip(sims, types, gravity):
        grid = interp.create_grid_cluster_batch(
            grid_coords,
            sim[:1, ..., :DIM],
            sim[:1, ..., :, DIM:],
            torch.tile(typ[None, :], (1, 1)),
            interp.linear,
            size=size,
            interaction_radius=interaction_radius,
        )

        if grav is not None:
            # Follow the grid's device rather than hard-coding CUDA, so this
            # also works on CPU-only machines and non-default-device runs.
            grav = grav.to(grid.device)
            # NOTE(review): assumes grav is a 1-D vector -- confirm against the loader.
            grav_channels = torch.tile(grav[None, None, None], (*grid.shape[:-1], 1))
            grid = torch.cat((grid, grav_channels), dim=-1)

        grids.append(grid)

    return torch.stack(grids)


def _write_rollouts(save_path, config, trajectories, sims, types):
    """Write ``config.json`` and one ``rollout_<i>.h5`` per simulation to ``save_path``.

    Each HDF5 file stores ``ground_truth_rollout`` (frames 1.. of the true
    simulation), ``predicted_rollout`` (the prediction truncated to the same
    number of frames and to this simulation's particle count) and ``types``.
    """
    os.makedirs(save_path, exist_ok=True)

    with open(os.path.join(save_path, "config.json"), "w") as fp:
        json.dump(config, fp)

    for i, (pred, true, typ) in enumerate(zip(trajectories, sims, types)):
        # Frame 0 was the model's input state, so it is not stored.
        num_frames = true.shape[0] - 1
        with h5py.File(os.path.join(save_path, f"rollout_{i}.h5"), "w") as h5:
            h5.create_dataset('ground_truth_rollout', data=true[1:].cpu())
            # Slicing the second axis to typ.shape[0] drops entries added by
            # list_to_padded (presumably padded particles -- confirm axis).
            h5.create_dataset('predicted_rollout', data=pred[:num_frames, :typ.shape[0]])
            h5.create_dataset('types', data=typ.cpu())


def rollout(model_path, save_path, data_path=None, batch=10, ckpt=None):
    """Unroll a trained model on every test simulation and save the results.

    Args:
        model_path: Passed to ``load_config_and_model``.
        save_path: Output directory. It is created if missing.
        data_path: Dataset root whose ``test/`` folder holds the ``*.h5``
            simulations. Defaults to ``config['data']``.
        batch: Number of simulations unrolled per ``simulate.unroll`` call.
        ckpt: Optional checkpoint id forwarded to ``load_config_and_model``.

    Raises:
        ValueError: If ``batch`` is smaller than 1.
        FileNotFoundError: If no ``*.h5`` files exist under ``<data_path>/test``.
    """
    if batch < 1:
        raise ValueError(f"batch must be a positive integer, got {batch}")

    interaction_radius = 0.015  # shared by grid construction and unrolling

    config, model = load_config_and_model(model_path, ckpt)

    if data_path is None:
        data_path = config['data']

    # Paths containing "WBC" select the 'sph' stats/loader; anything else uses 'mpm'.
    dataset_type = 'sph' if "WBC" in data_path else 'mpm'
    gmean, gstd, load_simulation, material = get_stats_and_load_fn(config, dataset_type)

    test_dir = os.path.join(data_path, "test")
    files = sorted(glob(os.path.join(test_dir, "*.h5")))
    if not files:
        raise FileNotFoundError(f"No test simulations (*.h5) found in {test_dir}")

    sims, types, gravity = get_rollout_sims(files, material, load_simulation)

    size = find_size(LOW, HIGH, config["grid_size"])
    grid_coords = get_voxel_centers(
        torch.tensor([size, size]),
        torch.tensor([LOW, LOW]),
        torch.tensor([HIGH, HIGH]),
    )

    padded_s0 = list_to_padded([s[0] for s in sims])
    padded_types = list_to_padded(types)

    grids = _build_initial_grids(sims, types, gravity, grid_coords, size, interaction_radius)

    # NOTE(review): every simulation is unrolled for a length derived from
    # sims[0], which assumes all test simulations have the same length.
    num_calls = sims[0].shape[0] // config["steps_per_call"] + 1

    trajectories = []
    with torch.no_grad():
        for lo in range(0, len(sims), batch):
            hi = min(lo + batch, len(sims))
            init_state = (grids[lo:hi, 0], padded_s0[lo:hi])

            _, traj, _ = simulate.unroll(
                model,
                init_state,
                grid_coords,
                num_calls,
                gmean,
                gstd,
                padded_types[lo:hi],
                size,
                interaction_radius=interaction_radius,
                interp_fn=interp.linear,
            )
            # Split the batch into one CPU tensor per simulation.
            trajectories.extend(traj[k].cpu() for k in range(traj.shape[0]))

    _write_rollouts(save_path, config, trajectories, sims, types)



if __name__ == '__main__':
    # Create tensors on the GPU by default whenever one is present.
    if torch.cuda.is_available():
        torch.set_default_device("cuda")

    cli = argparse.ArgumentParser("Neural MPM")
    cli.add_argument("--model-path", required=True)
    cli.add_argument("--save-path", required=True)
    cli.add_argument("--data-path", type=str, default=None)
    cli.add_argument("--batch-size", type=int, default=8)
    cli.add_argument("--ckpt", type=int, default=None)
    opts = cli.parse_args()

    rollout(
        opts.model_path,
        opts.save_path,
        data_path=opts.data_path,
        batch=opts.batch_size,
        ckpt=opts.ckpt,
    )

