import os
import h5py
import torch
from torch.utils.data import Dataset


class FinetuneDataset(Dataset):
    """
    Minimal Dataset class designed for fine-tuning.
    Reads data from an .h5 file containing:
        - "input": shape (N, C, T)
        - "label": shape (N,)

    Returns a dictionary {"input": x, "label": y}.

    Optional: apply z-score normalization (transform=True).
    Other features such as channel expansion or zero-padding are not included.

    The HDF5 handle is owned per process: it is opened (and validated) in
    ``__init__``, dropped when the dataset is pickled, and transparently
    reopened by ``__getitem__`` whenever the current process does not own an
    open handle. This lets the dataset be handed to DataLoader worker
    processes without sharing a single h5py handle between them.
    """

    def __init__(
        self,
        file_path: str,           # Path to the .h5 file
        transform: bool = False,  # Whether to apply z-score normalization
        verbose: bool = False     # Whether to print debug information
    ):
        """
        Args:
            file_path (str): Path to .h5 file (must contain 'input': (N, C, T), 'label': (N,)).
            transform (bool): If True, applies z-score normalization to each input sample.
            verbose (bool): If True, prints loading information for debugging purposes (only first few samples).

        Raises:
            FileNotFoundError: If ``file_path`` is not an existing file.
            KeyError: If the file lacks the 'input' or 'label' dataset.
            ValueError: If 'input' and 'label' hold a different number of samples.
        """
        super().__init__()
        self.file_path = file_path
        self.transform = transform
        self.verbose = verbose

        # Define the handle attributes up front so close()/__del__ stay safe
        # even when opening or validating the file fails below.
        self.h5f = None
        self.data_ds = None
        self.label_ds = None
        self._owner_pid = None

        self._open()
        self._num_samples = self.data_ds.shape[0]

        if self.verbose:
            print(f"[FinetuneDataset] Loaded file: {self.file_path}")
            print(f"[FinetuneDataset] Data shape: {self.data_ds.shape} (N, C, T)")
            print(f"[FinetuneDataset] Label shape: {self.label_ds.shape} (N, )")
            print(f"[FinetuneDataset] Total samples: {self._num_samples}")

    def _open(self) -> None:
        """Open the .h5 file read-only, validate its contents and bind the handles to this process."""
        if not os.path.isfile(self.file_path):
            raise FileNotFoundError(f"{self.file_path} not found.")

        handle = h5py.File(self.file_path, 'r')
        try:
            for key in ("input", "label"):
                if key not in handle:
                    raise KeyError(f"{self.file_path} does not contain '{key}' dataset.")

            data_ds = handle["input"]   # => shape (N, C, T)
            label_ds = handle["label"]  # => shape (N, )
            if label_ds.shape[0] != data_ds.shape[0]:
                raise ValueError("Mismatch between number of samples in 'input' and 'label' datasets.")
        except Exception:
            # Do not leave the file open when validation fails.
            handle.close()
            raise

        self.h5f = handle
        self.data_ds = data_ds
        self.label_ds = label_ds
        # Remember which process opened the handle so that a forked child
        # can detect an inherited handle and open its own instead.
        self._owner_pid = os.getpid()

    def _ensure_open(self) -> None:
        """Make sure the current process owns an open handle, reopening the file if necessary."""
        if self.h5f is not None and self._owner_pid == os.getpid():
            return
        # Either the handle was dropped (after close() or unpickling) or it was
        # inherited from another process; in both cases open a fresh one.
        self._open()

    def __getstate__(self) -> dict:
        """Return picklable state: everything except the process-local HDF5 handles."""
        state = self.__dict__.copy()
        for name in ("h5f", "data_ds", "label_ds", "_owner_pid"):
            state[name] = None
        return state

    def __len__(self) -> int:
        """Return the number of samples found in the file at construction time."""
        return self._num_samples

    def __getitem__(self, idx: int) -> dict:
        """
        Load one sample, reopening the .h5 file first if this process has no handle.

        Returns:
            dict:
                "input": Tensor of shape (C, T), dtype float32
                "label": Tensor scalar, dtype int64
        """
        self._ensure_open()

        # Restrict debug output to the first five samples; the lower bound
        # keeps negative indices from matching as well.
        debug = self.verbose and 0 <= idx < 5
        if debug:
            print(f"[FinetuneDataset] Loading sample idx={idx}")

        # Retrieve data from h5
        x = torch.tensor(self.data_ds[idx], dtype=torch.float32)   # => shape (C, T)
        y = torch.tensor(self.label_ds[idx], dtype=torch.long)     # => scalar

        # Replace NaNs or Infs with zero
        x[torch.isnan(x) | torch.isinf(x)] = 0.0

        # Optional z-score normalization
        if self.transform:
            mean_val = x.mean()
            if x.numel() > 1:
                scale = x.std() + 1e-10
            else:
                # The sample standard deviation of fewer than two values is
                # NaN; use only the epsilon so the output stays finite.
                scale = 1e-10
            x = (x - mean_val) / scale

        if debug:
            print(f"[FinetuneDataset] idx={idx}, input shape={tuple(x.shape)}, label={y.item()}")

        return {"input": x, "label": y}

    def close(self) -> None:
        """Close the .h5 file if it is open. Safe to call repeatedly; a later ``__getitem__`` reopens it."""
        handle = getattr(self, "h5f", None)
        self.h5f = None
        self.data_ds = None
        self.label_ds = None
        if handle is None:
            return
        if getattr(self, "verbose", False):
            print(f"[FinetuneDataset] Closing file: {self.file_path}")
        handle.close()

    def __del__(self):
        # Close h5 file upon object deletion
        try:
            self.close()
        except Exception:
            # Deliberate best-effort: a finalizer must not raise, and during
            # interpreter shutdown the underlying library may already be gone.
            pass
