import torch
import h5py
import numpy as np
from collections import deque

class HDF5Loader(torch.utils.data.Dataset):
    """Map-style dataset serving samples stored in an HDF5 file.

    Every top-level key of the file is treated as a group holding an ``"X"``
    dataset, a ``"channel_mask"`` dataset and -- when ``finetune`` is true --
    a ``"y"`` dataset, each indexed by sample along its first axis. All
    samples of all groups are exposed through one flat integer index.

    Args:
        hdf5_file: Path (or anything ``h5py.File`` accepts) of the file.
        transform: When truthy, each sample ``X`` is min-max rescaled to
            ``[-1, 1]``. This is a flag: a callable passed here is only
            tested for truthiness, never called.
        squeeze: When true, a leading singleton dimension is added to ``X``
            (despite the name, this *unsqueezes*).
        finetune: When true, the label ``"y"`` is read and returned.
        zero_pad_chans: If not ``None``, number of zero rows appended along
            the second-to-last dimension of ``X`` (presumably the channel
            axis of a ``(channels, time)`` sample -- confirm against data).
        zero_pad_toks: If not ``None``, number of zeros appended along the
            last dimension of ``X``.

    NOTE(review): the HDF5 handle is opened once here and shared by every
    copy of this dataset. h5py handles generally cannot be pickled or safely
    shared across forked processes -- confirm usage with ``num_workers > 0``.
    """

    def __init__(self, hdf5_file, transform=None, squeeze=False, finetune=True, zero_pad_chans=None, zero_pad_toks=None):
        self.hdf5_file = hdf5_file
        self.squeeze = squeeze
        self.finetune = finetune
        self.transform = transform
        self.zero_pad_chans = zero_pad_chans
        self.zero_pad_toks = zero_pad_toks
        self.data = h5py.File(self.hdf5_file, 'r')
        self.keys = list(self.data.keys())

        # Flat index -> (group key, row within that group).
        self.index_map = []
        for key in self.keys:
            group_size = len(self.data[key]['X'])  # Always assume 'X' is present
            self.index_map.extend([(key, i) for i in range(group_size)])

    def __len__(self):
        """Return the total number of samples across all groups."""
        return len(self.index_map)

    def __getitem__(self, index):
        """Load one sample.

        Returns:
            ``(X, Y, channel_mask)`` when ``finetune`` is true, otherwise
            ``(X, channel_mask)``. ``X`` and ``channel_mask`` are float32
            tensors; ``Y`` is an int64 tensor (0-d for a scalar label).
        """
        group_key, sample_idx = self.index_map[index]
        grp = self.data[group_key]
        # as_tensor instead of the legacy FloatTensor(...) constructor, which
        # interprets a bare numeric argument as a *size*, not a value.
        X = torch.as_tensor(grp["X"][sample_idx], dtype=torch.float32)
        channel_mask = torch.as_tensor(grp["channel_mask"][sample_idx], dtype=torch.float32)

        if self.finetune:
            Y = grp["y"][sample_idx]
            Y = torch.LongTensor([Y]).squeeze()

        if self.squeeze:
            X = X.unsqueeze(0)

        if self.transform:
            # Per-sample min-max to [-1, 1]; the epsilon keeps a constant
            # sample finite (it maps to -1 everywhere).
            max_X = X.max()
            min_X = X.min()
            X = (X - min_X) / (max_X - min_X + 1e-10)  # [0, 1]
            X = (X - 0.5) * 2  # [-0.5, 0.5] -> [-1, 1]

        # Padding runs after normalisation, so padded entries are 0.0 (the
        # midpoint of [-1, 1] when ``transform`` is on).
        if self.zero_pad_chans is not None:
            if X.dim() >= 2:
                # Pad the end of dim -2 directly. The former X.T round-trip
                # reversed *all* dims of a 3-D (squeeze=True) tensor and so
                # padded the singleton axis instead of the channel axis.
                chan_pad = (0, 0, 0, self.zero_pad_chans)
            else:
                chan_pad = (0, self.zero_pad_chans)
            X = torch.nn.functional.pad(X, chan_pad, "constant", 0.0)

        if self.zero_pad_toks is not None:
            X = torch.nn.functional.pad(X, (0, self.zero_pad_toks), "constant", 0.0)

        if self.finetune:
            return X, Y, channel_mask
        return X, channel_mask

    def __del__(self):
        """Close the HDF5 handle, tolerating a partially-built instance."""
        # __init__ may have raised before ``self.data`` was assigned (e.g. a
        # missing file); then there is nothing to close.
        data = getattr(self, "data", None)
        if data is not None:
            data.close()
        

        
