import numpy as np
import math

import torch
from torch.utils.data import Dataset
from copy import deepcopy

class CancerDataset(Dataset):
    """Map-style torch ``Dataset`` backed by a dict of per-sample arrays.

    The number of samples is taken from ``processed_data['inp_x']``. Indexing
    returns a dict holding the ``idx``-th element of every entry whose length
    equals that sample count. Entries without a length, or with a different
    length (e.g. global scalars or scaling parameters), are left out of the
    returned sample.
    """

    def __init__(self, processed_data):
        # Stored as-is; no copy or conversion is performed.
        self.data = processed_data

    def __getitem__(self, idx) -> dict:
        sample = {}
        for key, column in self.data.items():
            # Only per-sample entries (same leading length as 'inp_x') are sliced.
            if hasattr(column, '__len__') and len(column) == len(self):
                sample[key] = column[idx]
        return sample

    def __len__(self):
        inputs = self.data['inp_x']
        return inputs.shape[0]
    
    
def numpy2torch(dataset_collection):
    """Convert every numpy array in a collection of datasets to a torch tensor.

    Args:
        dataset_collection: dict mapping a dataset name to either a tuple
            (copied through unchanged) or a dict of ``label -> numpy array``.

    Returns:
        A new dict with the same dataset names. Tuple entries are the original
        objects. Dict entries become new dicts whose float32/float64 arrays are
        converted to ``torch.float32`` tensors and whose int32/int64 arrays are
        converted to ``torch.int64`` tensors. The data is always copied.

    Raises:
        TypeError: if an array has any other dtype.
    """
    torch_datasets = {}
    for dataname, dataset in dataset_collection.items():
        if isinstance(dataset, tuple):
            # Tuples are passed through untouched, without per-array conversion.
            torch_datasets[dataname] = dataset
            continue

        torch_dataset = {}
        for label, value in dataset.items():
            if value.dtype in ("float64", "float32"):
                torch_dataset[label] = torch.tensor(value, dtype=torch.float32)
            elif value.dtype in ("int64", "int32"):
                torch_dataset[label] = torch.tensor(value, dtype=torch.long)
            else:
                # Fail loudly with context instead of printing "error" and
                # raising a bare, message-less Exception.
                raise TypeError(
                    f"numpy2torch: unsupported dtype {value.dtype} for "
                    f"'{label}' in dataset '{dataname}' "
                    "(expected float32, float64, int32 or int64)"
                )

        torch_datasets[dataname] = torch_dataset

    return torch_datasets


# -------------------------------------------------------------------------------
# dataset collection
# -------------------------------------------------------------------------------   
def set_multi_label(dataset_collection):
    """Add two-flag treatment encodings to the train/valid/test splits.

    ``set_data`` is applied to the 'train_f', 'valid_f' and 'test_cf' splits,
    in that order. ``set_data_multi`` is then applied to 'test_cf_multi'. Each
    result is stored back under its key.

    Args:
        dataset_collection: dict containing the four split keys above. It is
            modified in place.

    Returns:
        The same ``dataset_collection`` dict.
    """
    single_step_splits = ('train_f', 'valid_f', 'test_cf')
    for split_name in single_step_splits:
        split = dataset_collection[split_name]
        dataset_collection[split_name] = set_data(split)

    multi_split = dataset_collection['test_cf_multi']
    dataset_collection['test_cf_multi'] = set_data_multi(multi_split)

    return dataset_collection


def set_data(dataset):
    """Decode the integer treatment code ``current_iw`` into two binary flags.

    Codes map as 0 -> [0, 0], 1 -> [0, 1], 2 -> [1, 0], 3 -> [1, 1].

    ``dataset["current_treatments"]`` holds the decoded flags, with shape
    ``(num_patients, max_seq_length, 2)``. ``dataset["prev_treatments"]`` holds
    the same flags shifted one step forward along the time axis (axis 1), with
    zeros at ``t = 0``.

    Args:
        dataset: dict holding ``"current_iw"`` with shape
            ``(num_patients, max_seq_length, 1)``. It is modified in place.

    Returns:
        The same ``dataset`` dict, with the two new float64 arrays added.

    Raises:
        ValueError: if the trailing axis of ``current_iw`` is not of size 1.
        NotImplementedError: if any code is not one of 0, 1, 2, 3.
    """
    current_iw = np.asarray(dataset["current_iw"])
    num_patients, max_seq_length, num_features = current_iw.shape
    if num_features != 1:
        raise ValueError(
            "set_data expects current_iw with a trailing axis of size 1, "
            f"got shape {current_iw.shape}"
        )

    codes = current_iw[..., 0]
    if not np.isin(codes, (0, 1, 2, 3)).all():
        raise NotImplementedError(
            "current_iw contains treatment codes outside {0, 1, 2, 3}"
        )

    # Row k of the table is the (flag_0, flag_1) pair for code k. A single
    # fancy-indexing step replaces the per-element Python loop.
    lookup = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float64)
    current_treatments = lookup[codes.astype(np.int64)]

    # Slice assignment already copies the values, so no deepcopy is needed.
    prev_treatments = np.zeros_like(current_treatments)
    prev_treatments[:, 1:, :] = current_treatments[:, :-1, :]

    dataset["prev_treatments"] = prev_treatments
    dataset["current_treatments"] = current_treatments

    return dataset

def set_data_multi(dataset):
    """Decode multi-step treatment codes ``current_iw`` into two binary flags.

    Codes map as 0 -> [0, 0], 1 -> [0, 1], 2 -> [1, 0], 3 -> [1, 1].

    ``dataset["current_treatments"]`` holds the decoded flags, with shape
    ``(num_patients, max_seq_length, nsamples, ntaus, 2)``.
    ``dataset["prev_treatments"]`` holds the same flags shifted one step along
    the ``ntaus`` axis (axis 3), with zeros at ``tau = 0``.

    Args:
        dataset: dict holding ``"current_iw"`` with shape
            ``(num_patients, max_seq_length, nsamples, ntaus, 1)``. It is
            modified in place.

    Returns:
        The same ``dataset`` dict, with the two new float64 arrays added.

    Raises:
        ValueError: if the trailing axis of ``current_iw`` is not of size 1.
        NotImplementedError: if any code is not one of 0, 1, 2, 3.
    """
    current_iw = np.asarray(dataset["current_iw"])
    num_patients, max_seq_length, nsamples, ntaus, nfeatures = current_iw.shape
    if nfeatures != 1:
        raise ValueError(
            "set_data_multi expects current_iw with a trailing axis of size 1, "
            f"got shape {current_iw.shape}"
        )

    codes = current_iw[..., 0]
    if not np.isin(codes, (0, 1, 2, 3)).all():
        raise NotImplementedError(
            "current_iw contains treatment codes outside {0, 1, 2, 3}"
        )

    # Row k of the table is the (flag_0, flag_1) pair for code k. A single
    # fancy-indexing step replaces the four nested Python loops.
    lookup = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float64)
    current_treatments = lookup[codes.astype(np.int64)]

    prev_treatments = np.zeros_like(current_treatments)
    # NOTE(review): unlike set_data, the one-step shift here runs along the tau
    # axis (axis 3), not the time axis (axis 1). prev_treatments[:, :, :, 0] is
    # therefore always zero. Confirm this is the intended convention.
    prev_treatments[:, :, :, 1:] = current_treatments[:, :, :, :-1]

    dataset["prev_treatments"] = prev_treatments
    dataset["current_treatments"] = current_treatments

    return dataset