"""
Utility file for computing dataset statistics.
"""
import argparse
import re
import os

import h5py
from tqdm import tqdm
import torch

from neural_mpm.data.parser import load_multimaterial, load_monomaterial, load_WBC
from neural_mpm.data.data_manager import get_voxel_centers, find_size
from neural_mpm.util.metaparams import LOW, HIGH
import neural_mpm.util.interpolation as interp

# Number of grid channels for each dataset type; the trailing comments list
# the channel layout.
NUM_CHANNELS = dict(
    mono=4,  # [d_part, vx_part, vy_part, d_wall]
    multi=10,  # [vwater(2), vsand(2), vgoop(2), dwater, dsand, dgoop, dwall]
    wbc=6,  # [d_part, vx_part, vy_part, d_wall, grav_x, grav_y]
)

# Loader for each dataset type. compute_dataset_mstd calls each one as
# load_func(path, material) and unpacks the result as (sim, types, grav).
LOAD_FUNCS = dict(
    mono=load_monomaterial,
    multi=load_multimaterial,
    wbc=load_WBC,
)

def _infer_material(dataset_path, dataset_type):
    """Return the material name to pass to the loader for ``dataset_type``.

    Mono-material datasets encode their material in the path, so it is
    recovered by a case-insensitive substring match, checked in the order
    goop, sand, water.

    Raises:
        ValueError: If ``dataset_type`` is unknown, or a mono dataset path
            names none of the known materials.
    """
    if dataset_type == 'mono':
        lpath = dataset_path.lower()
        for material in ('goop', 'sand', 'water'):
            if material in lpath:
                return material
        raise ValueError(f"Invalid mono dataset: {dataset_path}")
    if dataset_type == 'multi':
        return None
    if dataset_type == 'wbc':
        return 'water'
    raise ValueError(f"Invalid dataset type: {dataset_type}")


def _list_h5_files(split_dir):
    """Return the sorted paths of every ``.h5`` file anywhere under ``split_dir``.

    Sorting makes the accumulation order, and hence the floating-point
    result, deterministic. A missing directory yields an empty list.
    """
    return sorted(
        os.path.join(root, name)
        for root, _, files in os.walk(split_dir)
        for name in files
        if name.endswith('.h5')
    )


def compute_dataset_mstd(
        dataset_path,
        dataset_type='mono',
        save_path=None,
        split='train',
        grid_size=64,
        dim=2
):
    """
    Compute the per-channel mean and standard deviation of the grids of the
    entire split of a dataset.

    Every ``.h5`` file found (recursively) under ``<dataset_path>/<split>``
    is loaded, rasterised onto a grid, and reduced over its time and spatial
    axes. Statistics are accumulated as per-channel float64 sums, so the
    simulations do not need to share the same number of time steps.

    Args:
        dataset_path (str): Path to the dataset.
        dataset_type (str): Type of dataset
                            - 'mono': Monomaterial dataset (material inferred
                              from ``dataset_path``)
                            - 'multi': Multimaterial dataset
                            - 'wbc': WBC-SPH dataset
        save_path (str): If given, path of an HDF5 file to which the computed
                         'mean' and 'std' datasets are written.
        split (str): Split to compute the mean and standard deviation for.
        grid_size (int): Grid resolution, passed to ``find_size``.
        dim (int): Number of leading components of each particle state that
                   are treated as positions (clamped to [LOW, HIGH]). The
                   remaining components are used as per-particle features.

    Returns:
        mstd (tuple): ``(mean, std)``, each of shape ``(1, 1, 1, C)`` for the
        ``[T, H, W, C]`` grids built here. ``std`` is the population standard
        deviation.

    Raises:
        ValueError: If ``dataset_type`` is unknown or the material of a mono
            dataset cannot be inferred from ``dataset_path``.
        FileNotFoundError: If no ``.h5`` file exists under the split directory.
    """
    material = _infer_material(dataset_path, dataset_type)

    # Discover inputs before any setup so that an empty or missing split fails
    # loudly instead of returning None or dividing by zero.
    split_dir = os.path.join(dataset_path, split)
    h5_files = _list_h5_files(split_dir)
    if not h5_files:
        raise FileNotFoundError(f"No .h5 files found under {split_dir}")

    load_func = LOAD_FUNCS[dataset_type]

    size = find_size(LOW, HIGH, grid_size)
    size_tensor = torch.tensor([size, size], dtype=torch.float32)
    grid_coords = get_voxel_centers(
        size_tensor,
        torch.tensor([LOW, LOW]),
        torch.tensor([HIGH, HIGH]),
    )

    # Reduce over the time and spatial axes of each [T, H, W, C] grid.
    reduce_dims = (0, 1, 2)
    total = 0.0      # running per-channel sum of x (becomes a float64 tensor)
    total_sq = 0.0   # running per-channel sum of x**2 (float64 tensor)
    count = 0        # number of (t, h, w) cells accumulated so far
    out_dtype = None

    for path in tqdm(h5_files, desc=f"Computing mean and std for {split}"):
        sim, types, grav = load_func(path, material)

        grid = interp.create_grid_cluster_batch(
            grid_coords,
            torch.clamp(sim[..., :dim], min=LOW, max=HIGH),
            sim[..., dim:],
            torch.tile(types[None, :], (sim.shape[0], 1)),
            interp.linear,
            size=size,
            interaction_radius=0.015,
        )  # Shape: [T, H, W, C]

        if grav is not None:
            # Replicate the gravity vector into every cell as extra channels.
            grav_grid = torch.tile(grav[None, None, None], (*grid.shape[:-1], 1))
            grid = torch.cat((grid, grav_grid), dim=-1)

        # Accumulate in float64 to avoid precision loss over many files.
        total = total + grid.sum(
            dim=reduce_dims, keepdim=True, dtype=torch.float64
        )
        total_sq = total_sq + torch.square(grid).sum(
            dim=reduce_dims, keepdim=True, dtype=torch.float64
        )
        count += grid.shape[0] * grid.shape[1] * grid.shape[2]
        out_dtype = grid.dtype

    mean = total / count
    # E[x^2] - E[x]^2 can dip slightly below zero through rounding (e.g. for a
    # constant channel); clamp so that sqrt never produces NaN.
    var = torch.clamp(total_sq / count - mean ** 2, min=0.0)
    std = torch.sqrt(var).to(out_dtype)
    mean = mean.to(out_dtype)

    print(mean, std)

    if save_path is not None:
        # Save as h5
        with h5py.File(save_path, 'w') as f:
            f.create_dataset('mean', data=mean.cpu().numpy())
            f.create_dataset('std', data=std.cpu().numpy())

    return mean, std

if __name__ == '__main__':
    # Run on the GPU when one is available.
    if torch.cuda.is_available():
        torch.set_default_device("cuda")

    parser = argparse.ArgumentParser(
        description="Compute the per-channel mean/std of a dataset split's grids."
    )
    parser.add_argument('--dataset_path', '-d', type=str, required=True,
                        help="Root directory of the dataset.")
    # Restricting choices turns an invalid type into a usage error up front.
    parser.add_argument('--dataset_type', '-t', type=str, default='mono',
                        choices=sorted(LOAD_FUNCS),
                        help="Dataset type (default: mono).")
    parser.add_argument('--save_path', '-s', type=str, default=None,
                        help="Optional HDF5 output path for 'mean' and 'std'.")
    parser.add_argument('--split', '-S', type=str, default='train',
                        help="Split subdirectory to process (default: train).")
    parser.add_argument('--grid_size', type=int, default=64,
                        help="Grid resolution (default: 64).")
    parser.add_argument('--dim', type=int, default=2,
                        help="Number of position components per particle "
                             "(default: 2).")
    args = parser.parse_args()

    compute_dataset_mstd(
        args.dataset_path,
        dataset_type=args.dataset_type,
        save_path=args.save_path,
        split=args.split,
        grid_size=args.grid_size,
        dim=args.dim,
    )