"""

This script measures the average inference time of the model for a single
step, as well as the time taken by a full rollout.

"""
import numpy as np
import torch
import argparse
import os
import json
import h5py
from pathlib import Path
from tqdm import tqdm

from neural_mpm.util import ModelLogger
from neural_mpm.nn import create_model
from neural_mpm.data.data_manager import find_size, get_voxel_centers, \
    list_to_padded
from neural_mpm.data.parser import load_monomaterial
import neural_mpm.util.interpolation as interp
import neural_mpm.util.simulate as simulate
import neural_mpm.util.viz as viz

import warnings

warnings.filterwarnings("ignore")

# Water Ramps dataset
# Lower and upper bounds of the grid domain. load_simulations() uses the same
# pair for both axes when computing the grid size and the voxel centers.
LOW = 0.08
HIGH = 0.92
# Index at which the last axis of a simulation tensor is split into two
# feature groups when the grid is built (presumably positions first, then
# velocities -- TODO confirm against the data parser).
DIM = 2

def load_model(run, checkpoint):
    """
    Load a trained NeuralMPM model together with its run configuration.

    Args:
        run: Path to the run folder. The folder must contain a
            ``config.json`` file, and the text after the last underscore of
            the folder name is used as the run id.
        checkpoint: Name of the checkpoint to load. NOTE(review): this value
            is currently not forwarded anywhere; ``ModelLogger.load()`` is
            called without arguments.

    Returns:
        A ``(config, model)`` tuple: the configuration read from
        ``config.json`` (with a ``run_id`` entry added) and the object
        returned by ``ModelLogger.load()``.
    """
    run_dir = Path(run)

    with open(run_dir / "config.json", "r") as f:
        config = json.load(f)

    # The run id is the suffix of the run folder name. Path.name ignores a
    # trailing slash, so "runs/exp_abc/" yields "abc" rather than "".
    config['run_id'] = run_dir.name.split('_')[-1]

    model_logger = ModelLogger(
        dataset_name=None,
        run_config=config,
        create_wandb_json=False,
        # The logger is rooted two levels above the run folder.
        parent_dir=str(run_dir.parent.parent)
    )

    return config, model_logger.load()


def load_simulations(data, sim_id=123, mean_std_path='waterramps.h5'):
    """
    Load one test simulation and build its grid representation.

    Args:
        data: Path to the dataset folder. The simulation is read from
            ``<data>/test/sim_<sim_id>.h5``.
        sim_id: Index of the simulation file to load.
        mean_std_path: Path to the HDF5 file holding the ``mean`` and
            ``std`` datasets. Defaults to ``waterramps.h5`` in the current
            working directory.

    Returns:
        A tuple ``(sim, types, grid, grid_coords, size, gmean, gstd)``:
        ``sim`` and ``types`` come from ``load_monomaterial``, ``grid`` is
        the output of ``interp.create_grid_cluster_batch`` with a leading
        axis of size 1 added, ``grid_coords`` are the voxel centers,
        ``size`` is the value returned by ``find_size`` and
        ``gmean``/``gstd`` are the squeezed statistics read from
        ``mean_std_path``.
    """

    # The same bounds and the same size are used for both axes of the grid.
    start = torch.tensor([LOW, LOW])
    end = torch.tensor([HIGH, HIGH])
    size = find_size(LOW, HIGH, 64)

    size_tensor = torch.tensor([size, size])
    grid_coords = get_voxel_centers(size_tensor, start, end)

    with h5py.File(mean_std_path, 'r') as f:
        gmean = torch.tensor(f['mean']).squeeze()
        gstd = torch.tensor(f['std']).squeeze()

    file = os.path.join(data, "test", f"sim_{sim_id}.h5")

    sim, types, _ = load_monomaterial(file, 'water')

    grid = interp.create_grid_cluster_batch(
        grid_coords,
        # The last axis is split at DIM into two feature groups
        # (presumably positions, then velocities -- TODO confirm).
        sim[..., :, :DIM],
        sim[..., :, DIM:],
        # Repeat the per-particle types once per entry of sim's first axis.
        torch.tile(types[None, :], (sim.shape[0], 1)),
        interp.linear,
        size=size,
        interaction_radius=0.015,
    ).unsqueeze(0)

    return sim, types, grid, grid_coords, size, gmean, gstd

@torch.no_grad()
def unroll(
    model,
    init_state,
    coords,
    num_calls,
    gmean,
    gstd,
    types,
    size,
    interaction_radius=0.015,
    interp_fn=interp.linear,
    dim=2,
):
    """
    Roll the model out autoregressively and time every individual call.

    Each call normalises the current grid state with ``gmean``/``gstd``,
    runs the model, de-normalises the prediction with the first ``dim``
    channels of the statistics and advances the particles with the step
    function obtained from ``simulate.get_step_fn``.

    Args:
        model: Callable applied to the normalised grid state.
        init_state: ``(grid_state, particles)`` pair used for the first call.
        coords: Grid coordinates forwarded to ``simulate.get_step_fn``.
        num_calls: Number of model calls to perform and time.
        gmean: Mean used to normalise the state and de-normalise the output.
        gstd: Standard deviation used the same way as ``gmean``.
        types: Particle types forwarded to the step function.
        size: Grid size forwarded to ``simulate.get_step_fn``.
        interaction_radius: Forwarded to ``simulate.get_step_fn``.
        interp_fn: Interpolation function forwarded to
            ``simulate.get_step_fn``.
        dim: Number of leading channels of ``gmean``/``gstd`` used to
            de-normalise the model output.

    Returns:
        A list with one entry per call: the time between the two CUDA events
        recorded around that call, as reported by
        ``torch.cuda.Event.elapsed_time`` (milliseconds).
    """
    euler_step = simulate.get_step_fn(
        coords, interp_fn, size, interaction_radius=interaction_radius
    )

    def step(carry):
        grid_state, old_particles = carry
        normalized = (grid_state - gmean) / gstd
        grid_preds = model(normalized)
        grid_velocities = grid_preds * gstd[None, ..., :dim] + gmean[None, ..., :dim]

        full_particles, next_input = euler_step(grid_velocities, old_particles, types)
        # Only the last entry along axis 1 is carried over to the next call.
        return next_input, full_particles[:, -1]

    carry = init_state
    times = []
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    for _ in range(num_calls):
        start.record()
        carry = step(carry)
        end.record()
        # Both events must have completed before elapsed_time can be read.
        torch.cuda.synchronize()
        times.append(start.elapsed_time(end))

    return times

def _summarize(values):
    """Return the mean, std, min and max of a sequence of timings."""
    arr = np.asarray(values)
    return {
        'mean': arr.mean(),
        'std': arr.std(),
        'min': arr.min(),
        'max': arr.max(),
    }


@torch.no_grad()
def measure(run, checkpoint, num_sims=100, output_path='times.json'):
    """
    Benchmark a trained model on the test simulations and save the timings.

    For every simulation the function records the time of each individual
    model call (via ``unroll``) and the time of one complete
    ``simulate.unroll`` rollout. All times are CUDA event timings as
    reported by ``torch.cuda.Event.elapsed_time`` (milliseconds).

    Args:
        run: Path to the run folder, forwarded to ``load_model``.
        checkpoint: Name of the checkpoint, forwarded to ``load_model``.
        num_sims: Number of test simulations to benchmark; simulations
            ``0`` to ``num_sims - 1`` are loaded.
        output_path: Path of the JSON file the results are written to.

    Raises:
        ValueError: If ``num_sims`` is smaller than 1.

    The JSON file holds one entry per simulation id (``num_particles``,
    ``single`` step statistics and the ``full`` rollout time) plus an
    ``all`` entry aggregating the statistics over every simulation.
    """
    if num_sims < 1:
        raise ValueError("num_sims must be at least 1.")

    torch.set_default_device('cuda')

    print('Loading model.')
    config, model = load_model(run, checkpoint)
    model.eval()
    print('Model loaded.')

    rollout_kwargs = {
        'interaction_radius': 0.015,
        'interp_fn': interp.linear,
    }

    def rollout_args(loaded, num_calls):
        # Positional arguments shared by unroll() and simulate.unroll().
        sim, types, grid, grid_coords, size, gmean, gstd = loaded
        return (
            model,
            (grid[:, 0], sim[None, 0]),
            grid_coords,
            num_calls,
            gmean,
            gstd,
            types[None],
            size,
        )

    print('Loading simulations.')
    warmup = load_simulations(config['data'], 0)
    print('Simulations loaded.')

    # NOTE(review): the number of calls is derived from simulation 0 only and
    # reused for every simulation -- confirm all simulations share a length.
    num_calls = warmup[0].shape[0] // config['steps_per_call'] + 1

    # Warmup
    print('Warming up...')
    simulate.unroll(*rollout_args(warmup, num_calls), **rollout_kwargs)

    times_dict = {}
    all_times = []
    unroll_times = []

    for sim_id in range(num_sims):
        print("\nSimulation", sim_id)
        print("---------------")
        loaded = load_simulations(config['data'], sim_id)
        args = rollout_args(loaded, num_calls)

        # Per-call timings.
        times = unroll(*args, **rollout_kwargs)
        all_times.extend(times)
        single = _summarize(times)
        print(f"Single step: {single['mean']}±{single['std']}ms")

        # Timing of one complete rollout.
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        simulate.unroll(*args, **rollout_kwargs)
        end.record()
        # Both events must have completed before elapsed_time can be read.
        torch.cuda.synchronize()
        unroll_time = start.elapsed_time(end)
        unroll_times.append(unroll_time)
        print('Time for full unroll:', unroll_time, 'ms')

        times_dict[sim_id] = {
            'num_particles': loaded[0].shape[1],
            'single': single,
            'full': unroll_time,
        }

    times_dict['all'] = {
        'single': _summarize(all_times),
        'full': _summarize(unroll_times),
    }

    with open(output_path, 'w') as f:
        json.dump(times_dict, f, indent=4)


def main():
    """Parse the command-line arguments and run the timing benchmark."""
    parser = argparse.ArgumentParser()
    # --run is mandatory: without it the benchmark would only fail later,
    # with an unhelpful error, when the run folder is opened.
    parser.add_argument("--run", "-r", type=str, required=True,
                        help="Path to the run folder.")
    parser.add_argument("--checkpoint", "-c", type=str, help="Name of the "
                                                             "checkpoint to load.",
                        default="best")
    args = parser.parse_args()

    measure(args.run, args.checkpoint)


if __name__ == '__main__':
    main()
