import os
import h5py
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset


class EMGPretrainDataset(Dataset):
    """
    Dataset class to load a single .h5 file (must contain a "data" dataset with shape (N, C, T)).
    Supports optional transformations: z-score normalization and zero-padding of channels/time dimensions.

    The HDF5 file is opened read-only in ``__init__`` and kept open for fast random access.
    h5py objects cannot be pickled, so ``__getstate__`` drops the open handles and
    ``__setstate__`` reopens the file. This lets the dataset be sent to DataLoader worker
    processes started with the "spawn" method. Call ``close()`` to release the file early.
    """

    def __init__(
        self,
        file_path: str,        # Path to the .h5 file
        transform=True,
        squeeze=False,
        zero_pad_chans=None,
        zero_pad_toks=None,
        verbose=False          # Whether to print debug info
    ):
        """
        Args:
            file_path (str): Path to an .h5 file containing 'data' dataset (shape: (N, C, T)).
            transform (bool): If True, apply z-score normalization to each sample.
            squeeze (bool): If True, output shape will have an extra dimension (1, C, T).
                Note: despite its name, this *adds* a leading singleton dimension.
            zero_pad_chans (int): Number of zeros to pad on the channel dimension.
            zero_pad_toks (int): Number of zeros to pad on the time dimension.
            verbose (bool): If True, prints loading info and sample debugging details.

        Raises:
            FileNotFoundError: If ``file_path`` is not an existing regular file.
            KeyError: If the file has no top-level 'data' dataset.
        """
        super().__init__()
        self.file_path = file_path
        self.transform = transform
        self.squeeze = squeeze
        self.zero_pad_chans = zero_pad_chans
        self.zero_pad_toks = zero_pad_toks
        self.verbose = verbose

        # Handles start empty so close()/__del__ are safe even if opening fails below.
        self.h5f = None
        self.dataset = None

        if not os.path.isfile(self.file_path):
            raise FileNotFoundError(f"{self.file_path} not found or not a file.")

        # Open .h5 file (read-only); dataset shape: (N, C, T)
        self._open()
        self._num_samples = self.dataset.shape[0]

        # Print dataset information once during initialization
        if self.verbose:
            print(f"[EMGPretrainDataset] Loaded dataset from: {self.file_path}")
            print(f"[EMGPretrainDataset] Dataset shape: {self.dataset.shape} (expected: (N, C, T))")
            print(f"[EMGPretrainDataset] Total samples (N): {self._num_samples}")

    def _open(self):
        """Open ``self.file_path`` read-only and bind ``self.h5f`` / ``self.dataset``.

        Raises:
            KeyError: If the file has no top-level 'data' dataset (the file is closed first).
        """
        h5f = h5py.File(self.file_path, 'r')
        if "data" not in h5f:
            # Close explicitly instead of leaking the handle on the error path.
            h5f.close()
            raise KeyError(f"The .h5 file {self.file_path} doesn't contain 'data' dataset.")
        self.h5f = h5f
        self.dataset = h5f["data"]

    def __len__(self):
        return self._num_samples

    def __getitem__(self, idx: int):
        """
        Returns a dict ``{"input": x}`` where ``x`` is a float32 tensor with shape
        (C, T) or (1, C, T), optionally z-score normalized and zero-padded.
        """
        # Print loading info only for the first few samples
        if self.verbose and idx < 5:
            print(f"[EMGPretrainDataset] Loading sample idx={idx}")

        # Retrieve data from "data" at index idx (shape: (C, T))
        x = torch.tensor(self.dataset[idx], dtype=torch.float32)

        # Replace NaN or +/-Inf values with 0
        x = torch.nan_to_num(x, nan=0.0, posinf=0.0, neginf=0.0)

        # "Squeeze" option => (1, C, T)
        if self.squeeze:
            x = x.unsqueeze(0)

        # Apply z-score normalization over the whole sample
        if self.transform:
            mean_val = x.mean()
            # Tensor.std() is unbiased (N-1) and is NaN for a single element;
            # use 0 there so the sample becomes zeros, like any constant signal.
            if x.numel() > 1:
                std_val = x.std()
            else:
                std_val = torch.zeros((), dtype=x.dtype)
            x = (x - mean_val) / (std_val + 1e-10)

        # Zero-pad channels. The channel axis is dim -2 for both (C, T) and
        # (1, C, T). F.pad's tuple runs from the last dim backwards:
        # (T_left, T_right, C_left, C_right).
        if self.zero_pad_chans is not None:
            x = F.pad(x, (0, 0, 0, self.zero_pad_chans), value=0.0)

        # Zero-pad tokens (time dimension, the last axis)
        if self.zero_pad_toks is not None:
            x = F.pad(x, (0, self.zero_pad_toks), value=0.0)

        # Print shape info only for the first few samples
        if self.verbose and idx < 5:
            print(f"[EMGPretrainDataset] Sample idx={idx} shape after transforms/padding: {tuple(x.shape)}")

        return {"input": x}

    def __getstate__(self):
        """Return picklable state: the live h5py handles are dropped (they cannot be pickled)."""
        state = self.__dict__.copy()
        state["h5f"] = None
        state["dataset"] = None
        return state

    def __setstate__(self, state):
        """Restore state and reopen the .h5 file in this process."""
        self.__dict__.update(state)
        self._open()

    def close(self):
        """Close the underlying .h5 file if it is open. Safe to call more than once."""
        h5f = getattr(self, "h5f", None)
        if h5f is None:
            return
        if getattr(self, "verbose", False):
            print(f"[EMGPretrainDataset] Closing file: {self.file_path}")
        h5f.close()
        self.h5f = None
        self.dataset = None

    def __del__(self):
        # Close .h5 file upon destruction
        self.close()