import torch
import h5py
import numpy as np
from collections import deque

class MobiLoader(torch.utils.data.Dataset):
    '''
    Map-style dataset reading samples from an HDF5 file.

    Every top-level group of the file must contain an `X` dataset; when
    `finetune` is set, a `y` dataset of matching length is read as well.
    Samples of all groups are exposed through one flat index.

    Each item is a dict with:
    `"input"`: the preprocessed sample as a `float32` tensor
    `"label"`: the label divided by 90 as a `float32` tensor (only if `finetune`)
    `"nr_padded_channels"`: number of zero channels appended (only if `pad_up_to_max_chans` is not `None`)
    '''

    def __init__(self, hdf5_file, transform=None, squeeze=False, finetune=True, minmax=True, cache_size=0, pad_up_to_max_chans=None, use_cache=False, zero_pad_chans=None, zero_pad_toks=None):
        '''
        `hdf5_file`: path of the HDF5 file, opened read-only and kept open
        `transform`: stored on the instance; this class itself never applies it
        `squeeze`: if set, a leading dimension is *added* to every sample (`unsqueeze(0)`)
        `finetune`: if set, labels are read from `y` and returned
        `minmax`: if set, samples are min-max scaled to `[-1, 1]`
        `cache_size`: maximum number of cached samples; `0` (default) means no limit
        `pad_up_to_max_chans`: pad dimension 0 with zeros up to this size (`int` or `None`)
        `use_cache`: keep raw (not yet preprocessed) samples in memory
        `zero_pad_chans`: number of zero rows appended along dimension 0 (`int` or `None`)
        `zero_pad_toks`: number of zeros appended along the last dimension (`int` or `None`)
        '''
        self.hdf5_file = hdf5_file
        self.squeeze = squeeze
        self.cache_size = cache_size
        self.finetune = finetune
        self.use_cache = use_cache
        self.transform = transform
        self.zero_pad_chans = zero_pad_chans
        self.zero_pad_toks = zero_pad_toks
        self.pad_up_to_max_chans = pad_up_to_max_chans
        self.data = h5py.File(self.hdf5_file, 'r')
        self.keys = list(self.data.keys())
        self.minmax = minmax

        # Flat index -> (group key, position inside the group).
        self.index_map = []
        for key in self.keys:
            group_size = len(self.data[key]['X'])  # Always assume 'X' is present
            self.index_map.extend((key, i) for i in range(group_size))

        # Cache to store recently accessed samples; `cache_queue` records the
        # insertion order so the oldest entry can be evicted first.
        if self.use_cache:
            self.cache = {}
            self.cache_queue = deque(maxlen=self.cache_size)

    @staticmethod
    def _pad_dim0(X, amount):
        ''' Append `amount` zeros at the end of dimension 0 of `X` (a negative `amount` crops instead). '''
        # `pad` takes (left, right) pairs starting from the last dimension, so
        # every trailing dimension gets (0, 0) and dimension 0 gets (0, amount).
        # This avoids `X.T`, which is deprecated for tensors that are not 2-D.
        pad_spec = (0, 0) * (X.dim() - 1) + (0, amount)
        return torch.nn.functional.pad(X, pad_spec, "constant", 0.0)

    def _store_in_cache(self, index, entry):
        ''' Insert `entry` under `index`, evicting the oldest entry once `cache_size` is reached. '''
        bounded = bool(self.cache_size) and self.cache_size > 0
        if bounded and len(self.cache_queue) >= self.cache_size:
            # Drop the oldest index from the dict as well; otherwise the deque
            # forgets it while the dict keeps the tensors alive forever.
            oldest = self.cache_queue.popleft()
            self.cache.pop(oldest, None)
        self.cache[index] = entry
        self.cache_queue.append(index)

    def min_max(self, X):
        ''' min-max transform of `X` to range `[-1, 1]`, with min and max taken along dimension 0 '''
        max_X = X.max(axis=0)[0]  # `max`/`min` along an axis return (values, indices)
        min_X = X.min(axis=0)[0]
        X = (X - min_X) / (max_X - min_X + 1e-10)  # [0, 1]; epsilon guards a zero range
        X = (X - 0.5) * 2  # [-0.5, 0.5] -> [-1, 1]
        return X

    def preprocessing(self, X):
        '''
        Preprocess X according to given settings. Considering the following:
        `self.squeeze`: unsqueeze 0-th dimension (`bool`)
        `self.minmax`: apply min-max transform (`bool`)
        `self.zero_pad_chans`: zero pad channels by `self.zero_pad_chans` amount (`int` or `None`, the latter to avoid this step)
        `self.zero_pad_toks`: zero pad patch size per each channel by `self.zero_pad_toks` amount (`int` or `None`, the latter to avoid this step)

        Channels are dimension 0 and the patch size is the last dimension.
        Returns the preprocessed tensor; the input tensor is not modified in place.
        '''
        if self.squeeze:
            X = X.unsqueeze(0)

        if self.minmax:
            X = self.min_max(X)

        if self.zero_pad_chans is not None:
            X = self._pad_dim0(X, self.zero_pad_chans)

        if self.zero_pad_toks is not None:
            X = torch.nn.functional.pad(X, (0, self.zero_pad_toks), "constant", 0.0)

        return X

    def __len__(self):
        ''' Total number of samples over all groups. '''
        return len(self.index_map)

    def __getitem__(self, index):
        '''
        Return the sample at flat position `index` as a dict (see the class docstring for its keys).

        Raw samples come from the cache when enabled and present, otherwise
        from the HDF5 file; preprocessing is applied on every access.
        '''
        Y = None
        if self.use_cache and index in self.cache:
            cached_data = self.cache[index]
            X = cached_data[0]
            if self.finetune:
                # Hand out a copy so callers cannot alter the cached label.
                Y = cached_data[1].clone()
        else:
            group_key, sample_idx = self.index_map[index]
            grp = self.data[group_key]

            X = torch.tensor(np.array(grp["X"][sample_idx])).to(torch.float32)

            cache_tuple = [X]
            if self.finetune:
                # Labels are scaled by 1/90.
                # NOTE(review): presumably normalises an angle-like target - confirm.
                Y = torch.tensor(np.array(grp["y"][sample_idx])).to(torch.float32) / 90
                cache_tuple.append(Y)

            if self.use_cache:
                self._store_in_cache(index, tuple(cache_tuple))

        # preprocessing data (either data from cache or new data)
        X = self.preprocessing(X)
        to_pad = 0
        if self.pad_up_to_max_chans is not None:
            num_real_chans = X.shape[0]
            # A negative difference is passed through unchanged, in which case
            # `pad` crops the surplus channels instead of padding.
            to_pad = self.pad_up_to_max_chans - num_real_chans
            X = self._pad_dim0(X, to_pad)

        return_dict = {"input": X}
        if self.finetune:
            return_dict["label"] = Y
        # Same `is not None` test as above, so a target of 0 still reports its padding.
        if self.pad_up_to_max_chans is not None:
            return_dict["nr_padded_channels"] = to_pad
        return return_dict

    def __del__(self):
        ''' Close the HDF5 file if it was opened. '''
        # `data` is missing when `__init__` failed before the file was opened.
        close = getattr(getattr(self, "data", None), "close", None)
        if callable(close):
            close()
        

        
