import argparse
import os
import json
from glob import glob
import math


import torch

import neural_mpm.util.interpolation as interp
from neural_mpm.util.metaparams import LOW, HIGH, DIM
import neural_mpm.util.simulate as simulate

from neural_mpm.data.data_manager import find_size, get_voxel_centers, list_to_padded
from neural_mpm.data.parser import load_monomaterial, load_WBC

import matplotlib.pyplot as plt

from neural_mpm.nn import UNet
from matplotlib.ticker import FuncFormatter

import h5py


def get_valid_sims(files, material, load_simulation):
    """Load every file in ``files`` and split the results into two lists.

    Args:
        files: Iterable of file paths, passed one at a time to
            ``load_simulation``.
        material: Forwarded unchanged as the second argument of
            ``load_simulation``.
        load_simulation: Callable ``(file, material)`` returning a 3-tuple.
            Only the first two elements are kept; the third is discarded.

    Returns:
        ``(sims, types)``: the first and second elements of each loaded
        tuple, in the same order as ``files``.
    """
    loaded = [
        (trajectory, particle_types)
        for trajectory, particle_types, _ in (
            load_simulation(path, material) for path in files
        )
    ]
    sims = [trajectory for trajectory, _ in loaded]
    types = [particle_types for _, particle_types in loaded]
    return sims, types


def compute_mse(pred, true, types=None):
    """Return the mean squared error of a rollout for every time step.

    Frame ``i`` of ``pred`` is compared with frame ``i + 1`` of ``true``, so
    the first ground-truth frame is never scored. Only the first ``DIM``
    entries of the last axis are compared.

    Args:
        pred: Predicted trajectory, indexed as (time, particle, feature).
        true: Ground-truth trajectory, indexed the same way.
        types: Optional tensor; when given, only the first
            ``count_nonzero(types)`` particles are scored. Presumably the
            zero entries are padding that sits at the end - confirm with
            the caller.

    Returns:
        Tensor with one error value per compared time step (the squared
        error averaged over the particle and feature axes).
    """
    num_particles = true.shape[1] if types is None else torch.count_nonzero(types)
    horizon = true.shape[0] - 1

    predicted = pred[:horizon, :num_particles, :DIM]
    target = true[1:, :num_particles, :DIM]

    squared_diff = (predicted - target) ** 2
    return squared_diff.mean(axis=(1, 2))


def compute_ae(pred, true, types=None):
    """Return the mean absolute error of a rollout for every time step.

    Frame ``i`` of ``pred`` is compared with frame ``i + 1`` of ``true``, so
    the first ground-truth frame is never scored. Only the first ``DIM``
    entries of the last axis are compared.

    Args:
        pred: Predicted trajectory, indexed as (time, particle, feature).
        true: Ground-truth trajectory, indexed the same way.
        types: Optional tensor; when given, only the first
            ``count_nonzero(types)`` particles are scored. Presumably the
            zero entries are padding that sits at the end - confirm with
            the caller.

    Returns:
        Tensor with one error value per compared time step (the absolute
        error averaged over the particle and feature axes).
    """
    num_particles = true.shape[1] if types is None else torch.count_nonzero(types)
    horizon = true.shape[0] - 1

    predicted = pred[:horizon, :num_particles, :DIM]
    target = true[1:, :num_particles, :DIM]

    abs_diff = torch.abs(predicted - target)
    return abs_diff.mean(axis=(1, 2))


def load_config_and_model(path: str, name: str = None):
    """Read ``config.json`` from ``path`` and build the model it describes.

    Args:
        path: Experiment directory containing ``config.json``.
        name: Checkpoint name forwarded to ``load_model`` as
            ``checkpoint_name``.

    Returns:
        ``(config_dict, model)``: the parsed JSON configuration and the
        model returned by ``load_model``.
    """
    config_path = os.path.join(path, "config.json")
    with open(config_path, "r") as handle:
        config_dict = json.load(handle)

    return config_dict, load_model(path, config_dict, checkpoint_name=name)


def load_model(path: str, config_dict: dict, checkpoint_name: str = None):
    """Build the network described by ``config_dict`` and load its weights.

    Args:
        path: Experiment directory. The checkpoint is read from
            ``<path>/models/<checkpoint_name>.ckpt``.
        config_dict: Parsed experiment configuration. Reads the keys
            ``"data"`` and ``"model"`` and, for the UNet,
            ``"architecture"["hidden"]`` and ``"steps_per_call"``.
        checkpoint_name: File stem of the checkpoint to load. If None, the
            checkpoint named ``"best"`` is loaded.

    Returns:
        The model with the checkpoint's state dict loaded.

    Raises:
        ValueError: If ``config_dict["model"]`` is not a supported model
            type (only ``"unet"`` is supported).
    """
    # Datasets whose path contains "WBC" use 6 input channels, all others 4.
    in_channels = 6 if "WBC" in config_dict["data"] else 4

    model_type = config_dict["model"]
    if model_type != "unet":
        # Fail early with a clear message instead of hitting unbound locals.
        raise ValueError(
            f"Unsupported model type {model_type!r}; only 'unet' is supported"
        )

    if checkpoint_name is None:
        checkpoint_name = "best"
    checkpoint_path = os.path.join(path, "models", f"{checkpoint_name}.ckpt")

    # The last layer width is the number of steps predicted per model call.
    architecture = config_dict["architecture"]["hidden"] + [
        config_dict["steps_per_call"]
    ]
    factors = [2] * (len(architecture) - 1)
    model = UNet(architecture, factors, in_channels=in_channels)

    state_dict = torch.load(checkpoint_path)

    # The model is compiled before loading the weights unless the full
    # checkpoint path contains "rb" or the dataset is WBC.
    # NOTE(review): presumably those checkpoints were saved from an
    # uncompiled model and therefore have unprefixed keys - confirm.
    if "rb" not in checkpoint_path and "WBC" not in config_dict["data"]:
        if torch.cuda.is_available():
            model = torch.compile(model)

    model.load_state_dict(state_dict)

    return model


def get_stats_and_load_fn(config, dataset_type):
    """Select normalisation statistics and the loader for a dataset.

    The dataset family is detected from substrings of ``config["data"]``,
    checked in the order "Water", "Goop", "Sand", "WBC".

    Args:
        config: Experiment configuration. Reads ``"data"`` and, for WBC
            datasets, ``"grid_size"`` (used in the stats file name).
        dataset_type: ``"mpm"`` for Water/Goop/Sand datasets, ``"sph"`` for
            WBC datasets.

    Returns:
        ``(gmean, gstd, load_simulation, material)``: the ``"mean"`` and
        ``"std"`` datasets of the stats file as float32 tensors, the loader
        function for this dataset family, and the material name.

    Raises:
        ValueError: If ``config["data"]`` matches no known dataset family,
            or ``dataset_type`` does not match the detected family.
        AssertionError: If the module constant ``LOW`` does not have the
            value expected for the detected family.
    """
    data = config["data"]

    if "Water" in data:
        material, required_type, stats_name, expected_low = (
            "water", "mpm", "waterramps", 0.08,
        )
    elif "Goop" in data:
        material, required_type, stats_name, expected_low = (
            "goop", "mpm", "goop", 0.08,
        )
    elif "Sand" in data:
        material, required_type, stats_name, expected_low = (
            "sand", "mpm", "sandramps", 0.08,
        )
    elif "WBC" in data:
        # The stats file name depends on the grid size; resolved below.
        material, required_type, stats_name, expected_low = (
            "water", "sph", None, 0.0025,
        )
    else:
        raise ValueError(f"Unknown material for dataset {data!r}")

    if dataset_type != required_type:
        raise ValueError(
            f"Dataset {data!r} requires dataset_type {required_type!r}, "
            f"got {dataset_type!r}"
        )

    is_wbc = stats_name is None
    if is_wbc:
        stats_name = f"wbc_{config['grid_size']}"

    with h5py.File(f"stats/{stats_name}_stats.h5", "r") as f:
        gmean = f["mean"][()]
        gstd = f["std"][()]

    gmean = torch.tensor(gmean, dtype=torch.float32)
    gstd = torch.tensor(gstd, dtype=torch.float32)

    load_simulation = load_WBC if is_wbc else load_monomaterial

    # The domain bounds are module-level constants; make sure the ones
    # currently configured are the ones this dataset family expects.
    assert LOW == expected_low, (
        f"LOW is {LOW}, but {data!r} expects {expected_low}"
    )

    return gmean, gstd, load_simulation, material


def plot_error_over_time(err_list):
    """Plot the mean error curve with a shaded band around it.

    The band spans from the single curve with the lowest mean value (lower
    edge) up to ``mean + std`` (upper edge). The y-axis tick labels show the
    values multiplied by 1e3.

    Args:
        err_list: Non-empty sequence of equally shaped 1-D error tensors,
            one per simulation.

    Raises:
        ValueError: If ``err_list`` is empty.
    """
    if not err_list:
        raise ValueError("err_list must contain at least one error curve")

    stacked = torch.stack(list(err_list))
    mean = stacked.mean(dim=0).cpu()
    # Population standard deviation across curves. Computed directly rather
    # than as sqrt(E[x^2] - mean^2), which can go slightly negative through
    # rounding and produce NaN.
    std_dev = stacked.std(dim=0, unbiased=False).cpu()

    # Lower edge of the band: the curve with the smallest mean value. Falls
    # back to the first curve if no mean compares below infinity.
    best_curve, best_score = err_list[0], torch.inf
    for curve in err_list:
        score = curve.mean()
        if score < best_score:
            best_curve, best_score = curve, score
    lower = best_curve.cpu()

    def format_yaxis(x, pos):
        return f"{x * 1e3:.0f}"

    plt.figure(figsize=(10, 6))
    plt.plot(mean, color="red", label="Mean")
    plt.fill_between(
        range(len(mean)),
        lower,
        mean + std_dev,
        color="orange",
        alpha=0.5,
        label="Standard Deviation",
    )
    plt.gca().yaxis.set_major_formatter(FuncFormatter(format_yaxis))

    plt.title("Plot of Mean with Standard Deviation Shadow")
    plt.xlabel("Time")
    plt.ylabel("Value")
    plt.legend()
    plt.grid(True)
    plt.show()


def plot_all_erros(err_list):
    """Draw every error curve in ``err_list`` as its own line in one figure.

    The y-axis tick labels show the values multiplied by 1e3.

    Args:
        err_list: Iterable of tensors; each one is moved to the CPU and
            plotted against its index.
    """
    plt.figure(figsize=(10, 6))

    for curve in err_list:
        plt.plot(curve.cpu())

    # Show tick labels in units of 1e-3.
    scaled_ticks = FuncFormatter(lambda x, pos: f"{x * 1e3:.0f}")
    axis = plt.gca()
    axis.yaxis.set_major_formatter(scaled_ticks)

    plt.xlabel("Time")
    plt.ylabel("Error")
    plt.grid(True)
    plt.show()


def valid_ablation(experiment, ablation_name, dataset_type):
    """Evaluate every experiment matching a glob pattern on its test set.

    For each matching experiment directory the "best" checkpoint is loaded,
    rolled out from the first frame of every test simulation, and scored
    with ``compute_ae`` and ``compute_mse``. All results are written to
    ``error_lists/<ablation_name>.pth``.

    Args:
        experiment: Glob pattern selecting experiment directories. When the
            pattern ends with ``"g*"``, matches containing ``"grid"`` are
            skipped.
        ablation_name: File stem of the saved results.
        dataset_type: Forwarded to ``get_stats_and_load_fn``.

    Returns:
        None. The saved object maps each experiment directory name to
        ``{"ae": [...], "mse": [...]}`` with one error tensor per test
        simulation.
    """
    matches = glob(experiment)
    if experiment[-2:] == "g*":
        matches = [match for match in matches if "grid" not in match]
    experiment_dirs = sorted(matches)

    results = {}

    for exp in experiment_dirs:
        experiment_name = os.path.split(exp)[-1]

        config, model = load_config_and_model(exp, None)

        start = torch.tensor([LOW, LOW])
        end = torch.tensor([HIGH, HIGH])
        SIZE = find_size(LOW, HIGH, config["grid_size"])

        size_tensor = torch.tensor([SIZE, SIZE])
        grid_coords = get_voxel_centers(size_tensor, start, end)

        gmean, gstd, load_simulation, material = get_stats_and_load_fn(
            config, dataset_type
        )

        files = sorted(glob(os.path.join(config["data"], "test", "*.h5")))

        sims, types = get_valid_sims(files, material, load_simulation)

        padded_s0 = list_to_padded([s[0] for s in sims])
        padded_types = list_to_padded(types)

        # Ground-truth grids for every simulation; only frame 0 of each is
        # used below, as the initial state of the rollout.
        grids = torch.stack(
            [
                interp.create_grid_cluster_batch(
                    grid_coords,
                    sim[..., :DIM],
                    sim[..., :, DIM:],
                    torch.tile(typ[None, :], (sim.shape[0], 1)),
                    interp.linear,
                    size=SIZE,
                    interaction_radius=0.015,
                )
                for sim, typ in zip(sims, types)
            ]
        )

        with torch.no_grad():
            batch = 10
            num_calls = sims[0].shape[0] // config["steps_per_call"] + 1

            trajectories = []

            for b in range(0, len(sims), batch):
                nxt = min(b + batch, len(sims))
                init_state = (grids[b:nxt, 0], padded_s0[b:nxt])

                _, traj, _ = simulate.unroll(
                    model,
                    init_state,
                    grid_coords,
                    num_calls,
                    gmean,
                    gstd,
                    padded_types[b:nxt],
                    SIZE,
                    interaction_radius=0.015,
                    interp_fn=interp.linear,
                )

                trajectories.append(traj)

            trajectories = torch.cat(trajectories, axis=0)

            ae = [
                compute_ae(traj, sim, typ)
                for traj, sim, typ in zip(trajectories, sims, types)
            ]
            mse = [
                compute_mse(traj, sim, typ)
                for traj, sim, typ in zip(trajectories, sims, types)
            ]

            results[experiment_name] = {"ae": ae, "mse": mse}

    # Create the output directory so the save cannot fail after all the
    # evaluation work has been done.
    os.makedirs("error_lists", exist_ok=True)
    torch.save(results, os.path.join("error_lists", ablation_name + ".pth"))


def error_over_training(model_path, save_name, dataset_type):
    """Evaluate every saved checkpoint of one experiment on its test set.

    The "best" checkpoint and every other ``models/*.ckpt`` checkpoint are
    rolled out from the first frame of each test simulation. For each
    checkpoint the per-time-step AE and MSE curves, averaged over all test
    simulations, are stored together with a time value and written to
    ``error_over_training/<folder>/<save_name>.pth``.

    Args:
        model_path: Experiment directory containing ``config.json`` and a
            ``models`` folder with ``*.ckpt`` files and ``best.json``.
        save_name: File stem of the saved results.
        dataset_type: Forwarded to ``get_stats_and_load_fn``.

    Returns:
        None. The saved object maps each checkpoint name to
        ``{"ae": tensor, "mse": tensor, "time": int}``. For "best" the time
        is ``best.json["elapsed_time"]`` ("H:M:S") in seconds; for the other
        checkpoints it is the checkpoint file stem parsed as an integer.
    """
    epochs = sorted(glob(os.path.join(model_path, "models/*.ckpt")))

    config, model = load_config_and_model(model_path, None)

    start = torch.tensor([LOW, LOW])
    end = torch.tensor([HIGH, HIGH])
    SIZE = find_size(LOW, HIGH, config["grid_size"])

    size_tensor = torch.tensor([SIZE, SIZE])
    grid_coords = get_voxel_centers(size_tensor, start, end)

    gmean, gstd, load_simulation, material = get_stats_and_load_fn(config, dataset_type)

    files = sorted(glob(os.path.join(config["data"], "test", "*.h5")))

    sims, types = get_valid_sims(files, material, load_simulation)

    padded_s0 = list_to_padded([s[0] for s in sims])
    padded_types = list_to_padded(types)

    # Ground-truth grids for every simulation; only frame 0 of each is used
    # below, as the initial state of the rollout.
    grids = torch.stack(
        [
            interp.create_grid_cluster_batch(
                grid_coords,
                sim[..., :DIM],
                sim[..., :, DIM:],
                torch.tile(typ[None, :], (sim.shape[0], 1)),
                interp.linear,
                size=SIZE,
                interaction_radius=0.015,
            )
            for sim, typ in zip(sims, types)
        ]
    )

    def _mean_errors(eval_model, eval_config):
        """Roll ``eval_model`` out on all test sims; return mean (AE, MSE) curves."""
        batch = 10
        num_calls = sims[0].shape[0] // eval_config["steps_per_call"] + 1

        with torch.no_grad():
            trajectories = []

            for b in range(0, len(sims), batch):
                nxt = min(b + batch, len(sims))
                init_state = (grids[b:nxt, 0], padded_s0[b:nxt])

                _, traj, _ = simulate.unroll(
                    eval_model,
                    init_state,
                    grid_coords,
                    num_calls,
                    gmean,
                    gstd,
                    padded_types[b:nxt],
                    SIZE,
                    interaction_radius=0.015,
                    interp_fn=interp.linear,
                )

                trajectories.append(traj)

            trajectories = torch.cat(trajectories, axis=0)

            ae = [
                compute_ae(traj, sim, typ)
                for traj, sim, typ in zip(trajectories, sims, types)
            ]
            mse = [
                compute_mse(traj, sim, typ)
                for traj, sim, typ in zip(trajectories, sims, types)
            ]

            return sum(ae) / len(ae), sum(mse) / len(mse)

    def get_sec(time_str):
        """Get seconds from time."""
        h, m, s = time_str.split(":")
        return int(h) * 3600 + int(m) * 60 + int(s)

    results = {}

    ae, mse = _mean_errors(model, config)

    with open(os.path.join(model_path, "models", "best.json"), "r") as file:
        best_json = json.load(file)

    results["best"] = {
        "ae": ae,
        "mse": mse,
        "time": round(get_sec(best_json["elapsed_time"])),
    }

    # Every checkpoint except "best", identified by its file stem.
    epoch_list = []
    for e in epochs:
        name = e.split("/")[-1].split(".")[0]
        if name == "best":
            continue
        epoch_list.append(name)

    for name in epoch_list:
        config, model = load_config_and_model(model_path, str(name))

        ae, mse = _mean_errors(model, config)

        results[name] = {
            "ae": ae,
            "mse": mse,
            "time": int(name),
        }

    # Output sub-folder: the second-to-last "/"-separated component of
    # ``model_path``.
    folder = model_path.split('/')[-2]
    os.makedirs(os.path.join("error_over_training", folder), exist_ok=True)
    torch.save(results, os.path.join("error_over_training", folder, save_name + ".pth"))


def tensor_list_mean(data):
    """Average each list of tensors in a two-level dictionary.

    Args:
        data: Mapping ``key -> {subkey -> list of equally shaped tensors}``.

    Returns:
        A new dict with the same keys and subkeys, where every list is
        replaced by the element-wise mean of its stacked tensors.
    """
    return {
        key: {
            subkey: torch.stack(tensor_list).mean(dim=0)
            for subkey, tensor_list in nested_dict.items()
        }
        for key, nested_dict in data.items()
    }


def reduce_results(data):
    """Collapse every list of error tensors in ``data`` to one scalar.

    Each list is first averaged element-wise with ``tensor_list_mean``; the
    resulting tensor is then averaged over all its elements and moved to
    the CPU.

    Args:
        data: Mapping ``key -> {subkey -> list of tensors}``.

    Returns:
        Dict with the same keys and subkeys holding 0-d CPU tensors.
    """
    averaged = tensor_list_mean(data)
    return {
        key: {subkey: curve.mean().cpu() for subkey, curve in inner.items()}
        for key, inner in averaged.items()
    }


def bar_plots(data, subkey, ylabel, titles):
    """Show one bar chart per entry of ``data``, side by side.

    All panels share the y-axis. Only the first panel shows it, labelled
    with ``ylabel`` and with tick labels multiplied by 1e3.

    Args:
        data: Sequence of dicts ``key -> {subkey -> value}``; one panel is
            drawn per dict, with one bar per key.
        subkey: Which entry of each inner dict gives the bar height.
        ylabel: Label of the shared y-axis.
        titles: Panel titles, paired with ``data`` in order.
    """
    num_panels = len(data)

    fig, axs = plt.subplots(1, num_panels, figsize=(10 * num_panels, 5), sharey=True)

    # With a single panel ``subplots`` returns one Axes, not a sequence.
    panels = [axs] if num_panels == 1 else axs

    for index, (ax, entries, title) in enumerate(zip(panels, data, titles)):
        labels = list(entries.keys())
        heights = [entries[label][subkey] for label in labels]

        ax.bar(labels, heights, color="skyblue")
        ax.set_title(title)

        # Every panel drops the top and right frame lines.
        for side in ("top", "right"):
            ax.spines[side].set_visible(False)

        if index == 0:
            ax.set_ylabel(ylabel)
            ax.yaxis.set_major_formatter(FuncFormatter(lambda x, pos: f"{x * 1e3:.0f}"))
            continue

        # Later panels reuse the first panel's y-axis.
        ax.yaxis.set_visible(False)
        ax.spines["left"].set_visible(False)

    plt.show()


if __name__ == "__main__":
    # Create all tensors on the GPU when one is available.
    if torch.cuda.is_available():
        torch.set_default_device("cuda")

    parser = argparse.ArgumentParser("Neural MPM")
    parser.add_argument("--experiment")
    parser.add_argument("--save-name")
    parser.add_argument("--dataset-type", type=str)

    parser.add_argument("-v", "--validation-type", type=str, required=True)

    args = parser.parse_args()

    if args.experiment is None:
        parser.error("--experiment is required")

    if args.validation_type == "ablation":
        # The ablation run expands a glob pattern of experiment folders.
        if not args.experiment.endswith("*"):
            parser.error("--experiment must end with '*' for ablation validation")
        if args.save_name is None:
            parser.error("--save-name is required for ablation validation")
        valid_ablation(args.experiment, args.save_name, args.dataset_type)

    elif args.validation_type == "epochs":
        if args.experiment.split('/')[-1] == 'single':
            # "<dir>/single" evaluates <dir>/ itself as one experiment.
            if args.save_name is None:
                parser.error("--save-name is required for a single experiment")
            error_over_training(args.experiment[:-len('single')], args.save_name, args.dataset_type)
        else:
            # Otherwise every entry inside the folder is its own experiment,
            # saved under a name derived from its directory name.
            experiment = glob(os.path.join(args.experiment, '*'))

            for exp in experiment:
                save_name = exp.split('/')[-1] + '_no_noise'
                error_over_training(exp, save_name, args.dataset_type)

    else:
        parser.error(
            f"unknown validation type {args.validation_type!r}; "
            "expected 'ablation' or 'epochs'"
        )


