import torch
from torch.utils.data import Dataset
import numpy as np
from torch.utils.data import DataLoader

class CIPDataset(Dataset):
    """Map-style dataset of (history, target) windows, grouped by history length.

    A small set of history lengths is drawn once at construction; every
    sequence in ``data`` is paired with every drawn length, so the dataset has
    ``len(data['unscaled_outputs']) * repeats`` items. Flat index ``i`` maps to
    history group ``i // samples_per_history`` and sequence
    ``i % samples_per_history``, so all items of one group are contiguous.

    Args:
        data: Mapping of name -> array whose first axis indexes sequences.
            Must contain ``'unscaled_outputs'``; every value must expose
            ``.shape`` (it is printed). Only entries whose length equals
            ``len(data['unscaled_outputs'])`` are returned per item.
            ``'static_features'`` is optional.
        config: Reads ``exp.tau``, ``exp.repeats``, ``exp.seed``,
            ``dataset.min_seq_length``, ``dataset.min_history_length``,
            ``dataset.name`` and ``model.name``.
        train: If True, draw ``4 * repeats`` candidate history lengths
            instead of ``repeats`` (duplicates are removed either way).

    Raises:
        ValueError: If ``min_history_length >= min_seq_length - tau``, i.e.
            no history length can be drawn.

    Note:
        Construction seeds the *global* NumPy RNG with ``config.exp.seed``.
    """

    def __init__(self, data, config, train=False):
        self.data = data
        for key in self.data:
            print(f"{key}: {self.data[key].shape}")
        self.train = train
        self.tau = config.exp.tau
        # History lengths are bounded by min_seq_length so that history + tau
        # steps fit within it (upper bound below is min_seq_length - tau).
        self.max_history_length = config.dataset.min_seq_length
        self.min_h = config.dataset.min_history_length
        self.repeats = config.exp.repeats
        # Side effect: seeds the global NumPy RNG (affects later np.random use,
        # e.g. shuffling in get_dataloader).
        np.random.seed(config.exp.seed)

        high = self.max_history_length - self.tau  # exclusive upper bound
        if high <= self.min_h:
            raise ValueError(
                f"No valid history length: min_history_length ({self.min_h}) "
                f"must be < min_seq_length - tau ({high})."
            )
        n_draws = self.repeats * 4 if train else self.repeats
        self.history_lengths = np.unique(np.random.randint(self.min_h, high, n_draws))
        # Duplicate draws collapse, so the effective group count may be < repeats.
        self.repeats = len(self.history_lengths)
        self.samples_per_history = len(self.data['unscaled_outputs'])
        self.model = config.model.name
        self.config = config

    def __len__(self):
        """Number of items: sequences x distinct history lengths."""
        return len(self.data['unscaled_outputs']) * self.repeats

    def __getitem__(self, index):
        """Return ``(history, target)`` dicts for flat ``index``.

        ``history`` holds each per-sequence entry sliced to its first
        ``history_length`` positions along the first axis (presumably time),
        plus ``'sequence_lengths'`` and, for ``mimic3`` datasets,
        ``'future_vitals'``. ``target`` holds the following ``tau`` positions.
        """
        history_group, data_index = divmod(index, self.samples_per_history)
        history_length = self.history_lengths[history_group]

        # Windows always start at step 0. NOTE(review): an eval-only random start
        # offset was previously drawn here and then overwritten with 0, so it never
        # took effect (it only advanced the global RNG); re-add deliberately if wanted.
        start_idx = 0
        end_idx = start_idx + history_length

        n_items = len(self.data['unscaled_outputs'])
        sample = {k: v[data_index] for k, v in self.data.items()
                  if hasattr(v, '__len__') and len(v) == n_items}

        history = {k: v[start_idx:end_idx] for k, v in sample.items() if hasattr(v, '__len__')}
        for k, v in sample.items():
            # Scalars and very short entries (length <= 2) are passed through
            # unsliced — presumably per-sequence metadata; confirm.
            if not hasattr(v, '__len__') or len(v) <= 2:
                history[k] = v

        has_static = 'static_features' in sample
        # Static features without a time axis (fewer dims than the outputs) are
        # kept whole instead of being sliced.
        if has_static and sample['static_features'].ndim != sample['unscaled_outputs'].ndim:
            history['static_features'] = sample['static_features']

        if 'sample_indices' in self.data:
            history['sample_indices'] = self.data['sample_indices'][data_index]

        target = {k: v[end_idx:end_idx + self.tau] for k, v in sample.items() if hasattr(v, '__len__')}
        if has_static:
            target['static_features'] = history['static_features']
        history['sequence_lengths'] = history_length
        if 'mimic3' in self.config.dataset.name:
            history['future_vitals'] = target['vitals']

        return history, target

class _HistoryGroupBatchSampler:
    """Re-iterable batch sampler whose batches never mix history-length groups.

    Group ``g`` covers flat indices ``[g * group_size, (g + 1) * group_size)``.
    Groups are visited in order; when ``shuffle`` is set, indices within each
    group are reshuffled with the global NumPy RNG on every ``__iter__`` call
    (i.e. once per epoch).
    """

    def __init__(self, dataset, batch_size, shuffle, drop_last):
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        self.n_groups = dataset.repeats
        self.group_size = dataset.samples_per_history
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.drop_last = drop_last

    def __iter__(self):
        for group in range(self.n_groups):
            start = group * self.group_size
            indices = list(range(start, start + self.group_size))
            if self.shuffle:
                np.random.shuffle(indices)
            for i in range(0, len(indices), self.batch_size):
                batch = indices[i:i + self.batch_size]
                # Only a group's final chunk can be short.
                if self.drop_last and len(batch) < self.batch_size:
                    continue
                yield batch

    def __len__(self):
        full, remainder = divmod(self.group_size, self.batch_size)
        per_group = full if (self.drop_last or remainder == 0) else full + 1
        return self.n_groups * per_group


def get_dataloader(dataset, batch_size, shuffle=False, seed=10, drop_last=False):
    """Build a DataLoader whose batches each share a single history length.

    Args:
        dataset: Map-style dataset exposing ``repeats`` (number of history
            groups) and ``samples_per_history`` (items per group), with items
            of group ``g`` stored contiguously (as ``CIPDataset`` does).
        batch_size: Maximum batch size; must be >= 1.
        shuffle: Reshuffle indices within each group at the start of every
            epoch. Group order itself is not shuffled.
        seed: Unused; kept for backward compatibility. Shuffling draws from the
            global NumPy RNG.
        drop_last: Drop each group's final batch if it is smaller than
            ``batch_size``.

    Raises:
        ValueError: If ``batch_size < 1``.
    """
    sampler = _HistoryGroupBatchSampler(dataset, batch_size, shuffle, drop_last)
    return DataLoader(dataset, batch_sampler=sampler)

