"""Utilities for data handling and processing."""

import bisect
import os
import os.path as osp
import random
from typing import Callable, List, Optional, Tuple

import h5py
import numpy as np
import torch
from torch import Tensor
from torch.utils.data import ConcatDataset
from torch.utils.data import Dataset

# Target ranges narrower than this are treated as zero-width by the
# standardization helpers in this module, which then divide by 1.0 instead.
EPS: float = 1e-06
# Default widths that inputs (x) and targets (y) are zero-padded to along
# dimension 1 so samples of different dimensionality can be batched.
MAX_X_DIM: int = 4
MAX_Y_DIM: int = 3


def has_nan_or_inf(
    tensor: Tensor, name: str, log: Callable[[str], object] = print
) -> bool:
    """Report whether ``tensor`` contains any NaN or infinite values.

    NaNs are checked first; if any are found the Inf check is skipped, so at
    most one message is logged per call.

    Args:
        tensor: Tensor to inspect.
        name: Label used at the start of the logged message.
        log: Callable that receives the diagnostic message (default ``print``).

    Returns:
        True if a NaN or an Inf was found (and logged), otherwise False.
    """
    if torch.isnan(tensor).any():
        log(f"{name} contains NaNs:\n{tensor}")
        return True

    if torch.isinf(tensor).any():
        log(f"{name} contains Infs:\n{tensor}")
        return True

    return False


def tnp(x) -> np.ndarray:
    """Return ``x`` as a NumPy array if it is a tensor, else return it unchanged.

    Tensors are detached from the autograd graph and moved to CPU before
    conversion; any non-tensor input is passed straight through.
    """
    if not isinstance(x, Tensor):
        return x
    detached = x.detach()
    return detached.cpu().numpy()


def set_all_seeds(seed, deterministic: bool = False):
    """Seed every RNG used in training and configure PyTorch determinism.

    Args:
        seed: Seed for PyTorch (CPU and all CUDA devices), Python's ``random``
            module and NumPy's global RNG. NumPy only accepts seeds in
            ``[0, 2**32 - 1]``.
        deterministic: If True, require deterministic algorithms and disable
            cuDNN benchmarking. If False, allow non-deterministic kernels and
            let cuDNN benchmark for the fastest ones.
    """
    torch.use_deterministic_algorithms(deterministic)
    torch.backends.cudnn.deterministic = deterministic
    # Benchmark mode auto-tunes conv algorithms per input shape, and that choice
    # can differ between runs, so it must be off whenever determinism is wanted.
    torch.backends.cudnn.benchmark = not deterministic
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    random.seed(seed)
    np.random.seed(seed)


def get_random_seed(seed: Optional[int] = None) -> int:
    """Return ``seed`` unchanged, or draw a fresh one when it is ``None``.

    Fresh seeds come from Python's ``random`` module and lie in the inclusive
    range ``[0, 2**32 - 1]``.
    """
    if seed is not None:
        return seed
    upper_bound = 2**32 - 1
    return random.randint(0, upper_bound)


def group_ele_count(hdf5_path: str, grp_name: Optional[str] = None):
    """Return how many members an HDF5 group holds.

    Args:
        hdf5_path: Path to the HDF5 file (opened read-only).
        grp_name: Group to inspect. When None, the file's root-level members
            are counted instead.

    Returns:
        Number of keys directly under the group (or the root), or 0 when
        ``grp_name`` is given but not present in the file.
    """
    with h5py.File(hdf5_path, "r") as h5:
        if grp_name is None:
            target = h5
        elif grp_name in h5:
            target = h5.get(grp_name)
        else:
            return 0
        return len(list(target.keys()))


def save_dataset(hdf5_path: str, grp_name: str, inputs: Tensor, targets: Tensor):
    """Write ``inputs``/``targets`` into group ``grp_name`` of an HDF5 file.

    The file is opened in append mode, so it is created if missing and other
    groups are left untouched. The group is created if needed; if it already
    exists it is reused, but it must not already contain an ``inputs`` or
    ``targets`` dataset (existing data is never overwritten or extended).

    Both tensors are stored as float32. The group's ``input_dim`` /
    ``output_dim`` attributes record the size of dimension 1 of each array,
    or 1 for arrays with fewer than two dimensions.

    Raises:
        ValueError: If the group already holds an ``inputs`` or ``targets``
            dataset.
    """
    # Convert before touching the file so a bad argument does not leave a
    # freshly created file or group behind.
    inputs_np = inputs.float().detach().cpu().numpy()
    targets_np = targets.float().detach().cpu().numpy()

    with h5py.File(hdf5_path, "a") as f:  # "a": create file if needed, keep others
        grp = f.require_group(grp_name)

        # Check both names up front so a conflict never leaves a half-written
        # group (e.g. new "inputs" next to an old "targets").
        existing = [key for key in ("inputs", "targets") if key in grp]
        if existing:
            raise ValueError(
                f"Group '{grp_name}' in {hdf5_path} already contains {existing}; "
                "refusing to overwrite."
            )

        grp.create_dataset("inputs", data=inputs_np)
        grp.create_dataset("targets", data=targets_np)

        # Unpadded widths; loaders read these back to decide how much to pad.
        grp.attrs["input_dim"] = inputs_np.shape[1] if inputs_np.ndim > 1 else 1
        grp.attrs["output_dim"] = targets_np.shape[1] if targets_np.ndim > 1 else 1


def load_dataset(
    hdf5_path: str, grp_name: str, device: str
) -> Optional[Tuple[Tensor, Tensor, int, int]]:
    """Load the dataset stored under ``grp_name`` in an HDF5 file.

    Args:
        hdf5_path: Path to the HDF5 file (opened read-only).
        grp_name: Group holding ``inputs`` / ``targets`` datasets and
            ``input_dim`` / ``output_dim`` attributes.
        device: Device the returned tensors are created on.

    Returns:
        ``(inputs, targets, input_dim, output_dim)`` with the dims as Python
        ints, or None if ``grp_name`` is not present in the file.
    """
    with h5py.File(hdf5_path, "r") as f:
        if grp_name not in f:
            return None

        grp = f[grp_name]
        inputs = torch.tensor(grp["inputs"][:], device=device)
        targets = torch.tensor(grp["targets"][:], device=device)

        # h5py returns scalar attributes as NumPy scalars; normalise to int so
        # the values match the annotated return type.
        input_dim = int(grp.attrs["input_dim"])
        output_dim = int(grp.attrs["output_dim"])

    return inputs, targets, input_dim, output_dim


class MultiFileHDF5Dataset(Dataset):
    """Create a dataset from multiple HDF5 files.

    Every file must contain groups named ``dataset_0`` ... ``dataset_{n-1}``,
    each holding ``inputs`` / ``targets`` datasets plus ``input_dim`` /
    ``output_dim`` attributes. The files are concatenated in ``file_paths``
    order, so global index ``i`` addresses the ``i``-th sample across them.

    Args:
        file_paths: List of paths to HDF5 files.
        max_x_dim: Width inputs are zero-padded to along dimension 1.
        max_y_dim: Width targets are zero-padded to along dimension 1.
        zero_mean: Subtract the per-column mean from targets.
        standardize: Min-max scale each target column to ``range_scale``.
        range_scale: ``[low, high]`` range used when ``standardize`` is True.
            Defaults to ``[-1, 1]``.

    Raises:
        FileNotFoundError: If any path in ``file_paths`` does not exist.
    """

    def __init__(
        self,
        file_paths: List[str],
        max_x_dim: int = MAX_X_DIM,
        max_y_dim: int = MAX_Y_DIM,
        zero_mean: bool = True,
        standardize: bool = True,
        range_scale: Optional[List[float]] = None,
    ):
        self.file_paths = file_paths
        self.max_x_dim = max_x_dim
        self.max_y_dim = max_y_dim
        self.zero_mean = zero_mean
        self.standardize = standardize
        # Fresh list per instance rather than a shared mutable default.
        self.range_scale = [-1, 1] if range_scale is None else range_scale

        self.data_per_file = [self.get_data_size(path) for path in file_paths]
        # Running totals as plain Python ints: cumulative_sizes[k] is the global
        # index of the first sample of file k; the last entry is the total.
        self.cumulative_sizes = [0]
        for size in self.data_per_file:
            self.cumulative_sizes.append(self.cumulative_sizes[-1] + size)
        self.total_size = self.cumulative_sizes[-1]

    def __len__(self):
        return self.total_size

    def __getitem__(self, idx):
        """Get item from the dataset at global index ``idx``."""
        file_idx, local_idx = self._locate(idx)
        return self.load_from_file(self.file_paths[file_idx], local_idx)

    def _locate(self, idx):
        """Map a global index (negative allowed) to ``(file_idx, local_idx)``.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
        """
        pos = idx + self.total_size if idx < 0 else idx
        if not 0 <= pos < self.total_size:
            raise IndexError(
                f"index {idx} out of range for dataset of size {self.total_size}"
            )
        # bisect_right lands past repeated totals, so empty files are skipped.
        file_idx = bisect.bisect_right(self.cumulative_sizes, pos) - 1
        return file_idx, pos - self.cumulative_sizes[file_idx]

    def get_data_size(self, file_path):
        """Return the number of top-level groups in an HDF5 file.

        Raises:
            FileNotFoundError: If ``file_path`` does not exist.
        """
        if not osp.exists(file_path):
            raise FileNotFoundError(f"File not found: {file_path}")
        with h5py.File(file_path, "r") as f:
            return len(list(f.keys()))

    def _zero_mean(self, yvals):
        """Subtract the per-column (axis 0) mean from ``yvals``."""
        means = np.mean(yvals, axis=0, keepdims=True)  # [1, y_dims]
        return yvals - means

    def _standardize(self, yvals):
        """Min-max scale each column of ``yvals`` into ``self.range_scale``.

        Columns whose range is below ``EPS`` are divided by 1.0 instead, so a
        (near-)constant column maps to about ``range_scale[0]``.
        """
        mins = yvals.min(axis=0, keepdims=True)
        maxs = yvals.max(axis=0, keepdims=True)
        norm_term = np.where(maxs - mins < EPS, 1.0, maxs - mins)
        scale_factor = self.range_scale[1] - self.range_scale[0]
        yvals = (yvals - mins) / norm_term
        yvals = self.range_scale[0] + yvals * scale_factor
        return yvals

    def load_from_file(self, file_path, local_idx):
        """Load, normalize and pad sample ``local_idx`` from ``file_path``.

        Returns:
            ``(xvals, yvals, valid_x_counts, valid_y_counts)``: the arrays are
            zero-padded along dimension 1 to ``max_x_dim`` / ``max_y_dim``;
            the counts are the stored (unpadded) widths.

        Raises:
            ValueError: If a stored width exceeds ``max_x_dim`` / ``max_y_dim``.
            Errors from opening the file (e.g. a missing path) propagate.
        """
        with h5py.File(file_path, "r", swmr=True) as f:
            # NOTE retrieve with group name is much faster than retrieving with index!
            group = f[f"dataset_{local_idx}"]

            xvals = group["inputs"][:]
            yvals = group["targets"][:]

            valid_x_counts = group.attrs["input_dim"]
            valid_y_counts = group.attrs["output_dim"]

        if self.standardize:
            yvals = self._standardize(yvals)
        if self.zero_mean:
            yvals = self._zero_mean(yvals)

        # Pad xvals and yvals to max to enable batching
        pad_x = self.max_x_dim - valid_x_counts
        pad_y = self.max_y_dim - valid_y_counts
        if pad_x < 0:
            raise ValueError(f"max_x_dim {self.max_x_dim} < input_dim {valid_x_counts}")
        if pad_y < 0:
            raise ValueError(f"max_y_dim {self.max_y_dim} < output_dim {valid_y_counts}")

        xvals = np.pad(xvals, ((0, 0), (0, pad_x)), "constant")
        yvals = np.pad(yvals, ((0, 0), (0, pad_y)), "constant")

        return xvals, yvals, valid_x_counts, valid_y_counts


class HDF5Dataset(Dataset):
    """Create dataset from HDF5 file.

    The file must contain groups ``dataset_0`` ... ``dataset_{n-1}``, each
    with ``inputs`` / ``targets`` datasets and ``input_dim`` / ``output_dim``
    attributes. Only the group count is read at construction; samples are
    read lazily in ``__getitem__``.

    Args:
        hdf5_path (str): Path to the HDF5 file containing datasets
        max_x_dim (int): Width inputs are zero-padded to along dimension 1
        max_y_dim (int): Width targets are zero-padded to along dimension 1
        zero_mean (bool): If True, subtract the per-column mean from targets
        standardize (bool): If True, min-max scale each target column to
            ``range_scale``
        range_scale (list of float): ``[low, high]`` range used when
            ``standardize`` is True. Defaults to ``[-1, 1]``.
        device (str): Stored as ``self.device``; samples are returned as
            NumPy arrays regardless.

    Raises:
        FileNotFoundError: If ``hdf5_path`` does not exist.
    """

    def __init__(
        self,
        hdf5_path: str,
        max_x_dim: int = MAX_X_DIM,
        max_y_dim: int = MAX_Y_DIM,
        zero_mean: bool = True,
        standardize: bool = True,
        range_scale: Optional[List[float]] = None,
        device: str = "cpu",
    ):
        self.max_x_dim = max_x_dim
        self.max_y_dim = max_y_dim
        self.hdf5_path = hdf5_path
        self.zero_mean = zero_mean
        self.standardize = standardize
        # Fresh list per instance rather than a shared mutable default.
        self.range_scale = [-1, 1] if range_scale is None else range_scale
        self.device = device

        if not osp.exists(hdf5_path):
            raise FileNotFoundError(f"HDF5 file {hdf5_path} does not exist.")

        with h5py.File(hdf5_path, "r") as f:
            self.length = len(list(f.keys()))

    def _zero_mean(self, yvals):
        """Subtract the per-column (axis 0) mean from ``yvals``."""
        means = np.mean(yvals, axis=0, keepdims=True)  # [1, y_dims]
        return yvals - means

    def _standardize(self, yvals):
        """Min-max scale each column of ``yvals`` into ``self.range_scale``.

        Columns whose range is below ``EPS`` are divided by 1.0 instead, so a
        (near-)constant column maps to about ``range_scale[0]``.
        """
        mins = yvals.min(axis=0, keepdims=True)
        maxs = yvals.max(axis=0, keepdims=True)
        norm_term = np.where(maxs - mins < EPS, 1.0, maxs - mins)
        scale_factor = self.range_scale[1] - self.range_scale[0]
        yvals = (yvals - mins) / norm_term
        yvals = self.range_scale[0] + yvals * scale_factor
        return yvals

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        """Return ``(xvals, yvals, valid_x_counts, valid_y_counts)`` for ``idx``.

        Negative indices count from the end. Out-of-range indices raise
        IndexError (not KeyError), so plain iteration over the dataset stops.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
            ValueError: If the sample contains NaNs or a stored width exceeds
                ``max_x_dim`` / ``max_y_dim``.
        """
        pos = idx + self.length if idx < 0 else idx
        if not 0 <= pos < self.length:
            raise IndexError(
                f"index {idx} out of range for dataset of size {self.length}"
            )

        # Lazy loading
        with h5py.File(self.hdf5_path, "r", swmr=True) as f:
            # NOTE retrieve with group name is much faster than retrieving with index!
            group = f[f"dataset_{pos}"]

            xvals = group["inputs"][:]
            yvals = group["targets"][:]

            valid_x_counts = group.attrs["input_dim"]
            valid_y_counts = group.attrs["output_dim"]

        if self.standardize:
            yvals = self._standardize(yvals)
        if self.zero_mean:
            yvals = self._zero_mean(yvals)

        # Explicit raises instead of asserts so the checks survive `python -O`.
        if np.isnan(xvals).any():
            raise ValueError(f"xvals contains nan: {xvals.shape}\n{xvals}")
        if np.isnan(yvals).any():
            raise ValueError(f"yvals contains nan: {yvals.shape}\n{yvals}")

        # Pad xvals and yvals to max to enable batching
        pad_x = self.max_x_dim - valid_x_counts
        pad_y = self.max_y_dim - valid_y_counts

        if pad_x < 0:
            raise ValueError(f"{self.max_x_dim} < {valid_x_counts}")
        if pad_y < 0:
            raise ValueError(f"{self.max_y_dim} < {valid_y_counts}")

        xvals = np.pad(xvals, ((0, 0), (0, pad_x)), "constant")
        yvals = np.pad(yvals, ((0, 0), (0, pad_y)), "constant")

        return xvals, yvals, valid_x_counts, valid_y_counts


class RepeatDataset(Dataset):
    """Expose ``dataset`` repeated ``times`` times back to back.

    Index ``i`` maps to ``dataset[i % len(dataset)]``. The wrapped dataset's
    length is read once, at construction time.
    """

    def __init__(self, dataset, times):
        self.dataset = dataset
        self.times = times
        self.length = len(dataset)  # Avoid computing length multiple times

    def __len__(self):
        return self.length * self.times

    def __getitem__(self, index):
        """Return the wrapped item for ``index`` (negative indices allowed).

        Raises:
            IndexError: If ``index`` is outside ``[-len(self), len(self))``.
        """
        total = len(self)
        pos = index + total if index < 0 else index
        # Without this check `index % length` never fails, so iterating via the
        # __getitem__ protocol would never stop (and an empty dataset would
        # raise ZeroDivisionError).
        if not 0 <= pos < total:
            raise IndexError(
                f"index {index} out of range for RepeatDataset of size {total}"
            )
        return self.dataset[pos % self.length]


def get_dx_dy_datapaths(
    path: str, x_dim_list: int | list, y_dim_list: int | list
) -> List[str]:
    """Get datapaths for datasets organized by x_dim and y_dim.

    Returns:
        datapaths: `path/x_dim_dx/y_dim_dy/filename.ext`
        relative_paths_no_ext: `x_dim_dx/y_dim_dy/filename` without file extension.
    """
    datapaths = []
    relative_paths_no_ext = []

    for dx in x_dim_list:
        for dy in y_dim_list:
            relative_dir = f"x_dim_{dx}/y_dim_{dy}"
            data_dir = osp.join(path, relative_dir)

            if not osp.exists(data_dir):
                raise ValueError(f"Data directory {data_dir} does not exist.")

            for filename in os.listdir(data_dir):
                relative_path = osp.join(relative_dir, filename)
                rel_path_no_ext = osp.splitext(relative_path)[0]

                datapath = osp.join(data_dir, filename)

                if osp.isfile(datapath):
                    datapaths.append(datapath)
                    relative_paths_no_ext.append(rel_path_no_ext)

    return datapaths, relative_paths_no_ext


def get_datasets(
    datapaths: List[str],
    max_x_dim: int,
    max_y_dim: int,
    zero_mean: bool = True,
    standardize: bool = True,
    range_scale: Optional[List[float]] = None,
) -> List[HDF5Dataset]:
    """Create one ``HDF5Dataset`` per path in ``datapaths``.

    Args:
        datapaths: HDF5 file paths, one dataset each (in this order).
        max_x_dim: Padded input width passed to every dataset.
        max_y_dim: Padded target width passed to every dataset.
        zero_mean: Forwarded to every dataset.
        standardize: Forwarded to every dataset.
        range_scale: ``[low, high]`` target range used when ``standardize`` is
            True. Defaults to ``[-1, 1]``. Each dataset gets its own copy, so
            mutating one dataset's ``range_scale`` does not affect the others.

    Returns:
        The list of constructed datasets.
    """
    if range_scale is None:
        range_scale = [-1, 1]

    return [
        HDF5Dataset(
            hdf5_path=datapath,
            max_x_dim=max_x_dim,
            max_y_dim=max_y_dim,
            zero_mean=zero_mean,
            standardize=standardize,
            range_scale=list(range_scale),
        )
        for datapath in datapaths
    ]


def concat_repeat_datasets(
    dataset_list: List[HDF5Dataset], n_repeat: int = 1
) -> Dataset:
    """Wrap each dataset in a ``RepeatDataset`` and chain them together.

    Args:
        dataset_list: Datasets to combine, in order.
        n_repeat: Number of times each individual dataset is repeated.

    Returns:
        A ``ConcatDataset`` in which each dataset appears ``n_repeat`` times
        consecutively before the next one starts.
    """
    repeated = []
    for ds in dataset_list:
        repeated.append(RepeatDataset(ds, times=n_repeat))
    return ConcatDataset(repeated)
