# dataset for counterfactual inference planning
import torch
from torch.utils.data import Dataset
import numpy as np
from torch.utils.data import DataLoader

class CIPDataset(Dataset):
    """Dataset that splits each sample's trajectory into a history window and a future window.

    Every item is a pair ``(H_t, target)``. ``H_t`` holds the first ``history_length``
    time steps of one sample, and ``target`` holds the ``tau`` steps that follow. A sorted
    set of distinct history lengths is drawn once at construction.

    The flat index space consists of ``len(history_lengths)`` consecutive groups of
    ``samples_per_history`` items, one group per history length. Index ``i`` maps to
    group ``i // samples_per_history`` and sample ``i % samples_per_history``.

    Args:
        data: dict of array-likes. Every value must expose ``.shape`` (it is printed).
            ``data['unscaled_outputs']`` defines the number of samples, and
            ``data['static_features']`` must be present with one row per sample.
        config: config object read for ``exp.tau``, ``exp.repeats``, ``exp.seed``,
            ``dataset.min_seq_length``, ``dataset.min_history_length``,
            ``dataset.name`` and ``model.name``.
        train: when True, draw ``4 * repeats`` candidate history lengths instead of
            ``repeats`` (before de-duplication).
    """

    def __init__(self, data, config, train=False):
        self.data = data
        # Log every array's shape so mismatched inputs are visible at load time.
        for key in self.data:
            print(f"{key}: {self.data[key].shape}")
        self.train = train
        self.tau = config.exp.tau
        # Upper bound for a history+future window. It is read from min_seq_length:
        # presumably the shortest sequence in the data, so every window fits (confirm).
        self.max_history_length = config.dataset.min_seq_length
        self.min_h = config.dataset.min_history_length
        self.repeats = config.exp.repeats
        # NOTE: seeds numpy's *global* RNG, which also affects any later np.random use.
        np.random.seed(config.exp.seed)
        # Candidate history lengths lie in [min_h, max_history_length - tau), so the
        # tau-step future window always fits. Training draws 4x as many candidates.
        n_draws = self.repeats * 4 if train else self.repeats
        self.history_lengths = np.unique(
            np.random.randint(self.min_h, self.max_history_length - self.tau, n_draws)
        )
        # After de-duplication, `repeats` is the number of distinct history lengths.
        self.repeats = len(self.history_lengths)

        # Every history length is paired with every sample.
        self.samples_per_history = len(self.data['unscaled_outputs'])
        self.model = config.model.name
        self.config = config

    def __len__(self):
        return self.samples_per_history * self.repeats

    def __getitem__(self, index):
        """Return ``(H_t, target)`` for flat ``index``.

        ``H_t`` contains the history slice of every sequence-like entry of the sample.
        Scalars and entries of length <= 2 are kept whole. It also holds ``static_features``
        (unsliced when its rank differs from ``unscaled_outputs``), ``sample_indices`` when
        present, ``sequence_lengths`` set to the history length, and, for mimic3 datasets,
        ``future_vitals``. ``target`` holds the ``tau``-step slice after the history, plus
        the same ``static_features`` as ``H_t``.
        """
        history_group, data_index = divmod(index, self.samples_per_history)
        history_length = self.history_lengths[history_group]

        # Windows always start at t=0. The old code drew a random eval-time offset here
        # and then unconditionally overwrote it with 0. That dead draw only perturbed
        # numpy's global RNG, so it has been removed.
        start_idx = 0
        history_end = start_idx + history_length
        target_end = history_end + self.tau

        n_samples = self.samples_per_history
        # Pick this sample's entry from every per-sample array.
        sample = {k: v[data_index] for k, v in self.data.items()
                  if hasattr(v, '__len__') and len(v) == n_samples}

        # History window over every sequence-like entry.
        H_t = {k: v[start_idx:history_end] for k, v in sample.items() if hasattr(v, '__len__')}
        # Scalars and very short entries (len <= 2) are kept whole instead of time-sliced.
        for k, v in sample.items():
            if not hasattr(v, '__len__') or len(v) <= 2:
                H_t[k] = v

        # When static_features has a different rank than unscaled_outputs (presumably no
        # time axis), the time-slice above is meaningless, so use the unsliced value.
        if sample['static_features'].ndim != sample['unscaled_outputs'].ndim:
            H_t['static_features'] = sample['static_features']

        if 'sample_indices' in self.data:
            H_t['sample_indices'] = self.data['sample_indices'][data_index]

        # Future window: the tau steps immediately after the history.
        target = {k: v[history_end:target_end] for k, v in sample.items() if hasattr(v, '__len__')}
        target['static_features'] = H_t['static_features']

        # Report the history length instead of the sample's full sequence length.
        H_t['sequence_lengths'] = history_length
        if 'mimic3' in self.config.dataset.name:
            H_t['future_vitals'] = target['vitals']

        return H_t, target

class _GroupedBatchSampler:
    """Batch sampler whose batches never mix history-length groups.

    Assumes the dataset's index space is ``repeats`` consecutive groups of
    ``samples_per_history`` indices. Batches are regenerated on every iteration, so with
    ``shuffle=True`` each epoch gets a fresh order. A pre-built list of batches would
    replay the same order every epoch.
    """

    def __init__(self, repeats, samples_per_history, batch_size, shuffle, drop_last):
        self.repeats = repeats
        self.samples_per_history = samples_per_history
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.drop_last = drop_last

    def __iter__(self):
        for h_idx in range(self.repeats):
            start = h_idx * self.samples_per_history
            indices = list(range(start, start + self.samples_per_history))
            if self.shuffle:
                # Shuffle only within the group, using numpy's global RNG.
                np.random.shuffle(indices)
            for i in range(0, len(indices), self.batch_size):
                batch = indices[i:i + self.batch_size]
                # Only a group's trailing batch can be short; drop it if requested.
                if self.drop_last and len(batch) < self.batch_size:
                    continue
                yield batch

    def __len__(self):
        """Return the number of batches one pass over the sampler yields."""
        per_group, remainder = divmod(self.samples_per_history, self.batch_size)
        if remainder and not self.drop_last:
            per_group += 1
        return per_group * self.repeats


def get_dataloader(dataset, batch_size, shuffle=False, seed=10, drop_last=False):
    """Build a DataLoader whose batches never mix history-length groups.

    Groups are visited in order (group 0 first). Only the order of samples inside a
    group is shuffled.

    Args:
        dataset: map-style dataset exposing ``repeats`` (number of groups) and
            ``samples_per_history`` (group size), such as ``CIPDataset``.
        batch_size: maximum number of samples per batch; must be >= 1.
        shuffle: shuffle indices within each group, anew on every epoch, using numpy's
            global RNG.
        seed: unused; kept for backward compatibility. Seed numpy's global RNG instead.
        drop_last: drop each group's trailing batch when it is smaller than
            ``batch_size``.

    Returns:
        A ``torch.utils.data.DataLoader`` over ``dataset``.

    Raises:
        ValueError: if ``batch_size`` is less than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    sampler = _GroupedBatchSampler(
        dataset.repeats, dataset.samples_per_history, batch_size, shuffle, drop_last
    )
    return DataLoader(dataset, batch_sampler=sampler)

