import re
from glob import glob
import os
from random import shuffle
from typing import Union, List, Tuple, Sequence, Callable
import h5py
import torch

from neural_mpm.data.parser import load_monomaterial, load_WBC
import neural_mpm.util.interpolation as interp

from neural_mpm.util.metaparams import LOW, HIGH
from neural_mpm.util.wbc_filters import MEDIUM_FILTER


def find_size(a, b, N):
    """Pick a cell size so that the span between ``a`` and ``b`` needs ``N`` cells.

    The returned value is the midpoint of ``|a - b| / N`` and
    ``|a - b| / (N - 1)``. For ``N > 1`` this means ``|a - b| / size`` lies
    strictly between ``N - 1`` and ``N``.

    Args:
      a: one end of the interval.
      b: the other end of the interval (order does not matter).
      N: number of cells. When ``N == 1`` the upper bound is infinite, so
        the result is ``inf``.

    Returns:
      The chosen cell size.

    Raises:
      ValueError: if ``a == b``.
    """
    if a == b:
        raise ValueError(
            "a and b cannot be equal, as it would make size infinite or undefined."
        )

    extent = abs(a - b)

    smallest = extent / N
    # N == 1 would divide by zero below, so the upper bound is unbounded.
    largest = float("inf") if N == 1 else extent / (N - 1)

    # Any value between the two bounds works; take the midpoint.
    return (smallest + largest) / 2


def get_voxel_centers(size, start, end):
    """Return the centers of a regular 2-D voxel grid.

    Along each axis the centers are ``start + size / 2 + k * size`` for every
    ``k >= 0`` such that the center is strictly below ``end + size / 2``.

    The grid dimensions are taken from the generated center vectors rather
    than from a separately computed ``(end - start) / size + 1``. The two
    disagree when the extent is an exact multiple of ``size``, which used to
    make the final ``torch.stack`` fail with a shape mismatch.

    Args:
      size: tensor with the voxel edge length for x and y (indices 0 and 1).
      start: tensor with the lower corner of the domain for x and y.
      end: tensor with the upper corner of the domain for x and y.

    Returns:
      Tensor of shape ``(num_y, num_x, 2)`` where ``[i, j]`` holds
      ``(x_center[j], y_center[i])``.
    """
    x_center = torch.arange(start[0] + size[0] / 2, end[0] + size[0] / 2, size[0])
    y_center = torch.arange(start[1] + size[1] / 2, end[1] + size[1] / 2, size[1])

    num_x = x_center.shape[0]
    num_y = y_center.shape[0]

    # Broadcast x along rows and y along columns; stack materialises the result.
    x_grid = x_center.view(1, -1).expand(num_y, num_x)
    y_grid = y_center.view(-1, 1).expand(num_y, num_x)

    return torch.stack([x_grid, y_grid], dim=-1)


def list_to_padded(
    x: Union[List[torch.Tensor], Tuple[torch.Tensor]],
    pad_size: Union[Sequence[int], None] = None,
    pad_value: float = 0.0,
    equisized: bool = False,
) -> torch.Tensor:
    r"""
    Packs N tensors of possibly different shapes (Si_0, ..., Si_D) into one
    batched tensor, filling the unused entries with `pad_value`.

    The output shape is
    - (N, pad_size(0), ..., pad_size(D)) when `pad_size` is given,
    - (N, max(Si_0), ..., max(Si_D)) otherwise.

    Args:
      x: list or tuple of tensors, all with the same number of dimensions
        (empty 1D tensors are accepted and treated as empty tensors of the
        common dimensionality).
      pad_size: target size for every dimension of an element. If `None`
        (default), the per-dimension maximum over the non-empty elements
        is used.
      pad_value: value written to positions not covered by an element.
      equisized: set to True when all elements are known to share one shape;
        the elements are then simply stacked and no checks are run.

    Returns:
      x_padded: newly allocated tensor holding each element in the leading
        corner of its slot.

    Raises:
      ValueError: if an element is not a tensor, if the elements differ in
        number of dimensions, or if `pad_size` has the wrong length.
    """
    if equisized:
        return torch.stack(x, 0)

    for item in x:
        if not torch.is_tensor(item):
            raise ValueError("All items have to be instances of a torch.Tensor.")

    # The common dimensionality is the largest one found in the list.
    common_ndim = max(item.ndim for item in x)

    # An empty 1D tensor stands in for "no data"; give it the common rank.
    tensors = []
    for item in x:
        if item.ndim == 1 and item.nelement() == 0:
            item = item.new_zeros([0] * common_ndim)
        tensors.append(item)

    first_ndim = tensors[0].ndim
    for item in tensors:
        if item.ndim != first_ndim:
            raise ValueError("All items have to have the same number of dimensions!")

    if pad_size is not None:
        for item in tensors:
            if len(pad_size) != item.ndim:
                raise ValueError(
                    "Pad size must contain target size for all dimensions."
                )
        pad_dims = pad_size
    else:
        pad_dims = []
        for axis in range(first_ndim):
            pad_dims.append(
                max(item.shape[axis] for item in tensors if len(item) > 0)
            )

    x_padded = tensors[0].new_full((len(tensors), *pad_dims), pad_value)
    for row, item in enumerate(tensors):
        if len(item) == 0:
            continue
        region = (row,) + tuple(slice(0, extent) for extent in item.shape)
        x_padded[region] = item
    return x_padded


def padded_to_list(
    x: torch.Tensor,
    split_size: Union[Sequence[int], Sequence[Sequence[int]], None] = None,
):
    r"""
    Splits a padded tensor of shape (N, S_1, S_2, ..., S_D) along its first
    dimension into a list of N tensors, optionally cropping each of them.

    The i-th output has shape
    - (S_1, S_2, ..., S_D) if split_size is None,
    - (Si_1, S_2, ..., S_D) if split_size(i) is an integer Si_1,
    - (Si_1, Si_2, ..., Si_D) if split_size(i) is a sequence of sizes.

    Args:
      x: padded tensor.
      split_size: optional list/tuple with one entry per batch element; each
        entry is either an int or a sequence of ints giving the number of
        leading items to keep.

    Returns:
      x_list: a list of tensors sharing the memory with the input.

    Raises:
      ValueError: if `split_size` does not have one entry per batch element.
    """
    pieces = list(x.unbind(0))

    if split_size is None:
        return pieces

    if len(split_size) != x.shape[0]:
        raise ValueError("Split size must be of same length as inputs first dimension")

    for idx, spec in enumerate(split_size):
        if isinstance(spec, int):
            # A bare int crops the first dimension only.
            pieces[idx] = pieces[idx][:spec]
            continue
        crop = tuple(slice(0, s) for s in spec)  # pyre-ignore
        pieces[idx] = pieces[idx][crop]
    return pieces


class SmallDataset(torch.utils.data.Dataset):
    """Flat per-timestep view over a batch of simulations.

    Inputs are indexed as ``(simulation, timestep, ...)``. The first two axes
    are merged so that one dataset item is one (simulation, timestep) pair.

    Args:
      grids: tensor of shape (num_sims, timesteps, ...).
      states: tensor whose first two axes are (num_sims, timesteps).
      targets: tensor whose first two axes are (num_sims, timesteps).
      types: 2-D tensor of shape (num_sims, P); each row is repeated for
        every timestep of its simulation.
      grid_targets: optional tensor whose first two axes are
        (num_sims, timesteps). When omitted, items do not contain a
        grid-target entry.
    """

    def __init__(self, grids, states, targets, types, grid_targets=None):
        # First timestep of every simulation, kept for rollouts.
        self.init_state = (grids[:, 0], states[:, 0])
        self.init_types = types
        self.timesteps = grids.shape[1]
        self.num = grids.shape[0]
        # Repeat each simulation's types once per timestep, then flatten.
        self.types = (
            types[:, None, :]
            .expand(-1, self.timesteps, -1)
            .reshape(-1, types.shape[-1])
        )

        self.grids = grids.reshape(-1, *grids.shape[2:])
        self.states = states.reshape(-1, *states.shape[2:])
        self.targets = targets.reshape(-1, *targets.shape[2:])

        # Always define the attribute so __getitem__ works with the default
        # ``grid_targets=None`` instead of raising AttributeError.
        self.grid_targets = None
        if grid_targets is not None:
            self.grid_targets = grid_targets.reshape(-1, *grid_targets.shape[2:])

    def __len__(self):
        """Return the number of (simulation, timestep) items."""
        return self.grids.shape[0]

    def __getitem__(self, idx):
        """Return ``(grid, state, target, types[, grid_target])`` for item ``idx``.

        The trailing ``grid_target`` entry is present only when
        ``grid_targets`` was supplied at construction.
        """
        sample = (
            self.grids[idx],
            self.states[idx],
            self.targets[idx],
            self.types[idx],
        )
        if self.grid_targets is None:
            return sample
        return (*sample, self.grid_targets[idx])

    def get_sims_for_unroll(self):
        """Return ``((initial_grids, initial_states), types)`` per simulation."""
        return self.init_state, self.init_types


class DataManager:
    """Load simulation files and serve them as grid/particle training batches.

    Training files are read from ``<path>/train/*.h5`` and validation files
    from ``<path>/valid/*.h5``. The dataset flavour is selected by a substring
    of ``path`` ("Water", "Goop", "Sand" or "WBC"), which fixes the material
    name, the simulation length and the loader function.

    Args:
      path: dataset root directory; must contain one of the substrings above.
      batch_size: batch size of the DataLoader built by `get_dataloader`.
      dim: number of leading entries on the last axis of a simulation tensor
        that are passed as coordinates to the grid interpolation; the
        remaining entries are passed as features.
      grid_size: number of grid cells per axis (used to derive the cell size
        and to select the statistics file).
      steps_per_call: stride between consecutive input timesteps; together
        with `autoregressive_steps` it sets how many future timesteps are
        collected as targets for every input timestep.
      autoregressive_steps: see `steps_per_call`.
      sims_in_memory: number of training files loaded per buffer.
      interaction_radius: forwarded to `interp.create_grid_cluster_batch`.
      interp_fn: interpolation function forwarded to
        `interp.create_grid_cluster_batch`.
      num_valid_sims: maximum number of validation files used. Overridden
        to 4 for "WBC" datasets.

    Raises:
      ValueError: if `path` matches none of the known dataset flavours.
    """

    def __init__(
        self,
        path: str,
        batch_size: int = 64,
        dim: int = 2,
        grid_size: int = 64,
        steps_per_call: int = 1,
        autoregressive_steps: int = 1,
        sims_in_memory: int = 2,
        interaction_radius: float = 0.015,
        interp_fn: Callable = interp.linear,
        num_valid_sims: int = 8,
    ):
        self.dim = dim
        self.batch_size = batch_size

        self.files = sorted(glob(os.path.join(path, "train", "*.h5")))

        if "Water" in path:
            material = "water"
            sim_length = 600
            self.load_simulation = load_monomaterial
            assert LOW == 0.08

        elif "Goop" in path:
            material = "goop"
            sim_length = 400
            self.load_simulation = load_monomaterial
            assert LOW == 0.08

        elif "Sand" in path:
            material = "sand"
            sim_length = 400
            self.load_simulation = load_monomaterial
            assert LOW == 0.08

        elif "WBC" in path:
            material = "water"
            sim_length = 400
            self.load_simulation = load_WBC
            shuffle(self.files)
            # The caller-supplied num_valid_sims is ignored for this flavour.
            num_valid_sims = 4
            # Drop simulations whose numeric id is listed in MEDIUM_FILTER.
            self.files = [
                f
                for f in self.files
                if int(re.search(r'sim_(\d+)\.h5', f).group(1)) not in MEDIUM_FILTER
            ]

            assert LOW == 0.0025

        else:
            # Raising a plain string is itself a TypeError in Python 3.
            raise ValueError("Unknown material")

        valid_files = glob(os.path.join(path, "valid", "*.h5"))
        self.valid_files = sorted(valid_files)[:num_valid_sims]

        self.material = material
        self.sim_length = sim_length
        self.true_sim_length = 3000 if "WBC" in path else sim_length
        self.current_sim_idx = 0
        self.sims_in_memory = sims_in_memory
        self.steps_per_call = steps_per_call
        self.autoregressive_steps = autoregressive_steps
        # Last usable input timestep: leaves room for all target offsets.
        self.max_t = (
            self.sim_length - self.steps_per_call * self.autoregressive_steps - 1
        )

        # Input timesteps, one every `steps_per_call` frames.
        self.indices = torch.arange(0, self.max_t, self.steps_per_call)

        # For every input timestep, the following
        # steps_per_call * autoregressive_steps timesteps are the targets.
        self.target_indices = self.indices[:, None] + torch.arange(
            1, self.steps_per_call * self.autoregressive_steps + 1
        )

        self.sims = None
        self.types = None
        self.gravities = None
        self.grids = None
        self.gmean = None
        self.gstd = None

        # The grid is always built over a 2-D square [LOW, HIGH] x [LOW, HIGH].
        start = torch.tensor([LOW, LOW])
        end = torch.tensor([HIGH, HIGH])
        self.size = find_size(LOW, HIGH, grid_size)
        self.grid_size = grid_size

        self.interp_fn = interp_fn
        self.interaction_radius = interaction_radius

        size_tensor = torch.tensor([self.size, self.size])
        self.grid_coords = get_voxel_centers(size_tensor, start, end).to(
            self._device()
        )

    @staticmethod
    def _device():
        """Return "cuda" when a GPU is available, otherwise "cpu"."""
        return "cuda" if torch.cuda.is_available() else "cpu"

    @staticmethod
    def _append_gravity(grid, grav):
        """Concatenate `grav` to the last axis of `grid` at every position.

        NOTE(review): assumes `grid` has three leading axes before the
        channel axis and `grav` is 1-D; confirm against the loaders.
        """
        tiled = torch.tile(grav[None, None, None], (*grid.shape[:-1], 1))
        return torch.cat((grid, tiled), dim=-1)

    def get_valid_sims(self):
        """Load the validation simulations and build their initial grids.

        Returns:
          A tuple ``((grids, states), types, sims)`` where `grids` stacks the
          first-timestep grid of every validation simulation, `states` and
          `types` are the padded first frames and particle types, and `sims`
          is the list of full simulation tensors.
        """
        device = self._device()
        sims = []
        types = []
        grids = []

        for f in self.valid_files:
            sim, type_, grav = self.load_simulation(f, self.material)
            sim, type_ = sim.to(device), type_.to(device)

            num_particles = torch.count_nonzero(type_)

            sims.append(sim)
            types.append(type_)
            # Only the first timestep is rasterised for validation.
            grid = interp.create_grid_cluster_batch(
                self.grid_coords,
                sim[:1, ..., : self.dim],
                sim[:1, ..., :num_particles, self.dim :],
                torch.tile(type_[None, :], (sim.shape[0], 1)),
                self.interp_fn,
                size=self.size,
                interaction_radius=self.interaction_radius,
            )

            if grav is not None:
                grid = self._append_gravity(grid, grav.to(device))

            grids.append(grid[0])

        grids = torch.stack(grids)
        states = list_to_padded([s[0] for s in sims])
        types = list_to_padded(types)

        return (grids, states), types, sims

    def load_sim_buffer(self):
        """Load the next `sims_in_memory` training files into memory.

        The window advances by `sims_in_memory` on every call and restarts
        from the first file once it would run past the end of the file list.
        """
        # Wrap only when the window would overrun the list; a window ending
        # exactly on the last file is valid and must not be skipped.
        if self.current_sim_idx + self.sims_in_memory > len(self.files):
            self.current_sim_idx = 0

        sims = []
        types = []
        gravities = []

        for idx in range(
            self.current_sim_idx, self.current_sim_idx + self.sims_in_memory
        ):
            sim, type_, grav = self.load_simulation(self.files[idx], self.material)

            sims.append(sim)
            types.append(type_)
            gravities.append(grav)

        self.sims = sims
        self.types = types
        self.gravities = gravities
        self.current_sim_idx += self.sims_in_memory

    def build_grids(self):
        """Rasterise the buffered simulations and compute grid statistics.

        Fills `self.grids` (one CPU tensor per simulation) and sets
        `self.gmean` / `self.gstd`, the mean and standard deviation over the
        first three axes of the concatenated grids. When the grids have six
        channels the statistics are instead read from
        ``stats/wbc_<grid_size>_stats_mega_redux.h5``.
        """
        device = self._device()
        grids = []

        for sim, typ, grav in zip(self.sims, self.types, self.gravities):
            sim, typ = sim.to(device), typ.to(device)

            num_particles = torch.count_nonzero(typ)
            # TODO: clamp only for WBC
            grid = interp.create_grid_cluster_batch(
                self.grid_coords,
                torch.clamp(sim[..., : self.dim], min=LOW, max=HIGH),
                sim[..., :num_particles, self.dim :],
                torch.tile(typ[None, :], (sim.shape[0], 1)),
                self.interp_fn,
                size=self.size,
                interaction_radius=self.interaction_radius,
            )
            if grav is not None:
                grid = self._append_gravity(grid, grav.to(device))

            # Keep the full grids on the CPU; only batches go to the GPU.
            grids.append(grid.to("cpu"))

        self.grids = grids

        # Concatenate once and reuse for both statistics.
        stacked = torch.cat(grids, dim=0)
        self.gmean = torch.mean(stacked, dim=(0, 1, 2), keepdim=True).to(device)
        self.gstd = torch.std(stacked, dim=(0, 1, 2), keepdim=True).to(device)

        if self.gmean.shape[-1] == 6:
            # Six-channel grids use precomputed statistics (presumably the WBC
            # case with gravity channels; confirm). Load onto the available
            # device rather than assuming CUDA.
            with h5py.File(f"stats/wbc_{self.grid_size}_stats_mega_redux.h5", "r") as f:
                self.gmean = torch.tensor(f["mean"][()], device=device)
                self.gstd = torch.tensor(f["std"][()], device=device)

        # Avoid division by zero when normalising constant channels.
        self.gstd[self.gstd == 0] = 1.0

    def get_dataloader(self):
        """Load the next buffer of simulations and wrap it in a DataLoader.

        Returns:
          A shuffling `torch.utils.data.DataLoader` over a `SmallDataset`
          built from the buffered simulations, with batch size
          `self.batch_size`.
        """
        self.load_sim_buffer()
        self.build_grids()

        state_list = []
        target_list = []
        grid_list = []
        grid_target_list = []

        for grid, sim in zip(self.grids, self.sims):
            state_list.append(sim[self.indices])
            target_list.append(sim[self.target_indices])
            grid_list.append(grid[self.indices])
            grid_target_list.append(grid[self.target_indices])

        # Simulations may differ in size, so particle tensors are padded.
        states = list_to_padded(state_list)
        targets = list_to_padded(target_list)

        grids = torch.stack(grid_list)
        grid_targets = torch.stack(grid_target_list)

        types = list_to_padded(self.types)

        dataset = SmallDataset(
            grids=grids,
            states=states,
            targets=targets,
            grid_targets=grid_targets,
            types=types,
        )

        return torch.utils.data.DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=True,
        )
