import numpy as np
from sklearn.model_selection import train_test_split
from copy import deepcopy
import logging


logger = logging.getLogger(__name__)


class SyntheticDatasetCollection:
    """
    Dataset collection (train_f, val_f, test_cf_one_step, test_cf_treatment_seq)

    Groups the factual train / validation subsets and the two counterfactual test subsets of a
    synthetic dataset and prepares them for the different model families (encoder, decoder,
    multi-step). The subsets and settings initialised to ``None`` below are expected to be
    assigned after construction (presumably by a subclass or factory - not visible here).
    """

    # Keys cut into per-row history windows of shape (len_past - diff,).
    _HISTORY_KEYS_2D = ('chemo_dosage', 'radio_dosage', 'chemo_application', 'radio_application',
                        'chemo_probabilities', 'radio_probabilities', 'death_flags', 'recovery_flags')
    # Keys cut into per-row history windows of shape (len_past - diff, features).
    _HISTORY_KEYS_3D = ('prev_treatments', 'current_covariates', 'prev_outputs', 'vitals', 'next_vitals')
    # Keys holding one value per sequence; the value is repeated for every produced row.
    _PER_SEQUENCE_KEYS = ('sequence_lengths', 'patient_types')

    def __init__(self, **kwargs):
        self.seed = None

        # Which preprocessing pipelines have already been run.
        self.processed_data_encoder = False
        self.processed_data_decoder = False
        self.processed_data_multi = False
        self.processed_data_msm = False

        # Data subsets and shared preprocessing settings.
        self.train_f = None
        self.val_f = None
        self.test_cf_one_step = None
        self.test_cf_treatment_seq = None
        self.train_scaling_params = None
        self.projection_horizon = None

        self.autoregressive = None
        self.has_vitals = None
        # NOTE(review): process_data_multi* also read self.len_past and self.flag_include_init,
        # which are not initialised here; they must be set before those methods are called.

    def process_data_encoder(self):
        """Scale train / val / one-step counterfactual test subsets with the train scaling params."""
        self.train_f.process_data(self.train_scaling_params)
        self.val_f.process_data(self.train_scaling_params)
        self.test_cf_one_step.process_data(self.train_scaling_params)
        self.processed_data_encoder = True

    def process_propensity_train_f(self, propensity_treatment, propensity_history):
        """
        Generate stabilized weights for RMSN for the train subset
        Args:
            propensity_treatment: Propensity treatment network
            propensity_history: Propensity history network
        """
        prop_treat_train_f = propensity_treatment.get_propensity_scores(self.train_f)
        prop_hist_train_f = propensity_history.get_propensity_scores(self.train_f)
        # Product over axis 2 of the per-entry propensity ratio.
        self.train_f.data['stabilized_weights'] = np.prod(prop_treat_train_f / prop_hist_train_f, axis=2)

    def process_data_decoder(self, encoder, save_encoder_r=False):
        """
        Used by CRN, RMSN, EDCT

        Scales train / val / counterfactual-sequence test subsets, computes encoder
        representations for them, and splits them wrt ``self.projection_horizon``.

        Args:
            encoder: fitted encoder exposing ``get_representations`` and ``get_predictions``
            save_encoder_r: forwarded to the subsets' sequential-processing methods
        """
        self.train_f.process_data(self.train_scaling_params)
        self.val_f.process_data(self.train_scaling_params)
        self.test_cf_treatment_seq.process_data(self.train_scaling_params)

        # Representation generation / One-step ahead prediction with encoder
        r_train_f = encoder.get_representations(self.train_f)
        r_val_f = encoder.get_representations(self.val_f)
        r_test_cf_treatment_seq = encoder.get_representations(self.test_cf_treatment_seq)
        outputs_test_cf_treatment_seq = encoder.get_predictions(self.test_cf_treatment_seq)

        # Splitting time series wrt specified projection horizon / Preparing test sequences
        self.train_f.process_sequential(r_train_f, self.projection_horizon, save_encoder_r=save_encoder_r)
        self.val_f.process_sequential(r_val_f, self.projection_horizon, save_encoder_r=save_encoder_r)
        # Considers only last timesteps according to projection horizon
        self.test_cf_treatment_seq.process_sequential_test(self.projection_horizon, r_test_cf_treatment_seq,
                                                           save_encoder_r=save_encoder_r)
        self.test_cf_treatment_seq.process_autoregressive_test(r_test_cf_treatment_seq, outputs_test_cf_treatment_seq,
                                                               self.projection_horizon, save_encoder_r=save_encoder_r)
        self.processed_data_decoder = True

    @staticmethod
    def _effective_ndim(arr):
        """
        Rank of ``arr`` after dropping singleton axes, always keeping the leading sample axis.

        Equals ``arr.squeeze().ndim`` whenever the sample axis does not have length 1, but unlike
        ``squeeze`` it never collapses the sample axis of a single-sample array.
        """
        return 1 + sum(1 for size in arr.shape[1:] if size != 1)

    @staticmethod
    def _indices_with_min_history(data, len_past):
        """Indices of the samples in ``data`` whose ``future_past_split`` is not below ``len_past``."""
        split = data['future_past_split']
        return [i for i in range(data['outputs'].shape[0]) if not split[i] < len_past]

    def post_process_data(self, dataset, len_past):
        """
        Keep only samples with enough history and cut a fixed-length history window per sample.

        Args:
            dataset: object whose ``data`` dict holds numpy arrays indexed by sample on axis 0;
                it must contain ``outputs`` and ``future_past_split``.
            len_past: number of past steps to keep (keys shorter than the reference key get
                correspondingly shorter windows).

        Returns:
            ``(dataset, list_idx)``. ``dataset.data`` is filtered in place to the kept samples,
            ``dataset.data_f`` holds the windowed arrays, ``list_idx`` lists the kept indices.
        """
        data = dataset.data

        # Reference key giving the full time length; if 'cancer_volume' is absent fall back to
        # the first array with three non-singleton axes.
        main_key = 'cancer_volume'
        if main_key not in data:
            for key in data:
                if self._effective_ndim(data[key]) == 3:
                    main_key = key
                    print(main_key)
                    break

        # Select samples whose split point leaves at least len_past past steps.
        list_idx = self._indices_with_min_history(data, len_past)
        for key in data:
            data[key] = data[key][list_idx]

        # Allocate output buffers. Per-sample keys and the two special keys pass through.
        dataset.data_f = {}
        windowed_keys = []
        for key, arr in data.items():
            if key in ('static_features', 'cancer_volume'):
                dataset.data_f[key] = arr
                continue
            ndim = self._effective_ndim(arr)
            if ndim == 1:
                dataset.data_f[key] = arr
                continue
            if ndim not in (2, 3):
                continue
            width = len_past - (data[main_key].shape[1] - arr.shape[1])
            if ndim == 3:
                shape = (arr.shape[0], width, arr.shape[2])
            elif arr.ndim == 2:
                shape = (arr.shape[0], width)
            elif arr.ndim == 3:
                shape = (arr.shape[0], width, 1)
            else:
                continue
            dataset.data_f[key] = np.zeros(shape)
            windowed_keys.append(key)

        # For every sample, copy the window that ends at its split point (inclusive).
        split = data['future_past_split']
        for key in windowed_keys:
            src, dst = data[key], dataset.data_f[key]
            offset = data[main_key].shape[1] - src.shape[1]
            for i in range(src.shape[0]):
                stop = int(split[i]) + 1
                dst[i] = src[i, stop - len_past + offset:stop]
        return dataset, list_idx

    def denormalize_data(self, data, mean, std):
        """Map standardised values back to the original scale: ``data * std + mean``."""
        return data*std + mean

    def normalize_data(self, data, mean, std):
        """Standardise values: ``(data - mean) / std``."""
        return (data-mean)/std

    @staticmethod
    def _history_window(seq, i, len_past, diff, pad):
        """
        History of one sequence that ends (exclusive) at column ``i - diff``.

        Without padding this is ``seq[i - len_past:i - diff]``. With padding the available prefix
        ``seq[:i - diff]`` is left-padded with zeros to the full width ``len_past - diff``, i.e.
        ``len_past - i`` zeros, the same width the non-padded slice has.
        """
        stop = i - diff
        if not pad:
            return seq[i - len_past:stop]
        stop = max(stop, 0)  # a negative stop would wrap around instead of meaning "nothing yet"
        head = seq[:stop]
        zeros = np.zeros((len_past - diff - stop,) + head.shape[1:])
        return np.concatenate((zeros, head), axis=0)

    def _allocate_transform_buffers(self, dataset, keys, new_data, n_rows, len_max_seq, len_past,
                                    projection_horizon):
        """
        Create the zero-filled output arrays of transform_data / transform_test_data in ``new_data``.

        Returns:
            ``(diffs, scaling)``: ``diffs`` maps every history key to its time offset
            ``len_max_seq - data[key].shape[1]``; ``scaling`` is ``(output_means, output_stds)``
            from ``dataset.scaling_params``, or ``None`` when there is no ``prev_outputs`` key.
        """
        data = dataset.data
        diffs = {}
        for key in keys:
            arr = data[key]
            if key == 'cancer_volume':
                new_data[key] = arr  # passed through as-is (only truncated to the row count)
            elif key == 'unscaled_outputs':
                future = (n_rows, projection_horizon)
                new_data[key] = np.zeros(future + (arr.shape[-1],))
                new_data['outputs'] = np.zeros(future + (arr.shape[-1],))
                new_data['active_entries'] = np.zeros(future + (data['active_entries'].shape[-1],))
                new_data['current_treatments'] = np.zeros(future + (data['current_treatments'].shape[-1],))
            elif key in self._HISTORY_KEYS_2D or key in self._HISTORY_KEYS_3D:
                diffs[key] = len_max_seq - arr.shape[1]
                width = len_past - diffs[key]
                if key in self._HISTORY_KEYS_2D:
                    new_data[key] = np.zeros((n_rows, width))
                else:
                    new_data[key] = np.zeros((n_rows, width, arr.shape[-1]))
            elif key in self._PER_SEQUENCE_KEYS:
                new_data[key] = np.zeros(n_rows)
            elif key == 'static_features':
                new_data[key] = np.zeros((n_rows, arr.shape[-1]))

        scaling = None
        if 'prev_outputs' in data:
            params = dataset.scaling_params
            scaling = (params['output_means'], params['output_stds'])
        return diffs, scaling

    def _fill_transform_row(self, data, keys, new_data, diffs, scaling, b, i, row, len_past,
                            projection_horizon, pad):
        """Write output row ``row`` for sequence ``b`` cut at time ``i`` (the caller's data is not modified)."""
        for key in keys:
            if key in diffs:
                seq = data[key][b]
                if key == 'prev_outputs':
                    # Cut/pad in the unscaled domain, then re-standardise, so that padding
                    # corresponds to an unscaled zero.
                    seq = self.denormalize_data(seq, *scaling)
                window = self._history_window(seq, i, len_past, diffs[key], pad)
                if key == 'prev_outputs':
                    window = self.normalize_data(window, *scaling)
                new_data[key][row] = window
            elif key == 'unscaled_outputs':
                future = slice(i - 1, i + projection_horizon - 1)
                new_data[key][row] = data[key][b, future]
                new_data['outputs'][row] = data['outputs'][b, future]
                # active_entries: all one due to the non-autoregressive way
                new_data['active_entries'][row] = 1.0
                new_data['current_treatments'][row] = data['current_treatments'][b, future, :]
            elif key in self._PER_SEQUENCE_KEYS or key == 'static_features':
                new_data[key][row] = data[key][b]

    @staticmethod
    def _truncate_rows(new_data, count):
        """Drop the unused tail of every preallocated buffer."""
        for key in new_data:
            new_data[key] = new_data[key][:count]
        return new_data

    def transform_data(self, dataset, projection_horizon, len_past, flag_include_init=False):
        """
        Cut every sequence into sliding (history, future) samples for non-autoregressive training.

        For each sequence ``b`` and every cut point ``i`` in ``1 .. len_seq - 1`` with
        ``len_seq = sequence_lengths[b] - projection_horizon``, one row is produced. It holds the
        ``len_past``-step history of each history key ending at ``i`` (shifted per key by its
        length difference to ``outputs``), and the ``projection_horizon`` future steps starting at
        ``i - 1`` of ``unscaled_outputs`` / ``outputs`` / ``current_treatments``. It also holds the
        per-sequence values. Cut points with ``i < len_past`` are skipped unless
        ``flag_include_init`` is set, in which case their history is left-padded with zeros.

        Args:
            dataset: object with a ``data`` dict of numpy arrays and (when ``prev_outputs`` is
                present) ``scaling_params``; ``dataset.data`` is not modified.
            projection_horizon: number of future steps per row.
            len_past: number of history steps per row.
            flag_include_init: also emit zero-padded rows for the first ``len_past - 1`` cut points.

        Returns:
            ``(new_data, list_idx_batch)`` where ``list_idx_batch[r]`` is the source sequence of row ``r``.
        """
        data = dataset.data
        keys = list(data.keys())
        len_max_seq = data['outputs'].shape[1]
        batch = data['outputs'].shape[0]
        n_rows = batch * len_max_seq  # upper bound on the number of produced rows

        new_data = {}
        list_idx_batch = []
        diffs = scaling = None
        for b in range(batch):
            len_seq = int(data['sequence_lengths'][b] - projection_horizon)  # possible cut points
            if len_seq <= 0:
                continue
            for i in range(1, len_seq):
                if diffs is None:
                    diffs, scaling = self._allocate_transform_buffers(
                        dataset, keys, new_data, n_rows, len_max_seq, len_past, projection_horizon)
                pad = i < len_past
                if pad and not flag_include_init:
                    continue
                self._fill_transform_row(data, keys, new_data, diffs, scaling, b, i,
                                         len(list_idx_batch), len_past, projection_horizon, pad)
                list_idx_batch.append(b)

        return self._truncate_rows(new_data, len(list_idx_batch)), list_idx_batch

    def remove_min_samples(self, dataset, len_past):
        """Keep only samples whose ``future_past_split`` is at least ``len_past`` (in place)."""
        list_idx = self._indices_with_min_history(dataset.data, len_past)
        for key in dataset.data:
            dataset.data[key] = dataset.data[key][list_idx]
        for key in dataset.data_processed_seq:
            dataset.data_processed_seq[key] = dataset.data_processed_seq[key][list_idx]
        return dataset

    def transform_test_data(self, dataset, projection_horizon, len_past):
        """
        Build one test row per sequence, cut at ``i = sequence_lengths[b] - projection_horizon + 1``.

        Rows have the same layout as in ``transform_data`` (history windows are never padded).
        Sequences no longer than ``projection_horizon`` are dropped and reported via ``print``.

        Returns:
            ``(new_data, list_idx_batch)`` where ``list_idx_batch[r]`` is the source sequence of row ``r``.
        """
        data = dataset.data
        keys = list(data.keys())
        len_max_seq = data['outputs'].shape[1]
        batch = data['outputs'].shape[0]
        n_rows = batch * len_max_seq

        new_data = {}
        list_idx_batch = []
        diffs = scaling = None
        for b in range(batch):
            len_seq = int(data['sequence_lengths'][b] - projection_horizon)
            if len_seq <= 0:
                print(b, " th seq dropped!!")
                continue
            if diffs is None:
                diffs, scaling = self._allocate_transform_buffers(
                    dataset, keys, new_data, n_rows, len_max_seq, len_past, projection_horizon)
            self._fill_transform_row(data, keys, new_data, diffs, scaling, b, len_seq + 1,
                                     len(list_idx_batch), len_past, projection_horizon, pad=False)
            list_idx_batch.append(b)

        return self._truncate_rows(new_data, len(list_idx_batch)), list_idx_batch

    def process_data_multi(self):
        """
        Used by CT

        Scales all subsets, builds windowed copies ``train_f_non`` / ``val_f_non`` with
        ``transform_data`` and ``test_cf_treatment_seq_non`` with ``transform_test_data``.
        Reads ``self.len_past`` and ``self.flag_include_init``.
        """
        self.train_f.process_data(self.train_scaling_params)
        self.train_f_non = deepcopy(self.train_f)
        print("start to transform train data")
        self.train_f_non.data, _ = self.transform_data(self.train_f, self.projection_horizon, self.len_past, self.flag_include_init)

        if getattr(self, 'val_f', None) is not None:
            self.val_f.process_data(self.train_scaling_params)
            self.val_f_non = deepcopy(self.val_f)
            print("start to transform valid data")
            self.val_f_non.data, _ = self.transform_data(self.val_f, self.projection_horizon, self.len_past, self.flag_include_init)

        self.test_cf_one_step.process_data(self.train_scaling_params)
        self.test_cf_treatment_seq.process_data(self.train_scaling_params)
        self.test_cf_treatment_seq.process_sequential_test(self.projection_horizon)
        self.test_cf_treatment_seq.process_sequential_multi(self.projection_horizon)
        self.test_cf_one_step_non = deepcopy(self.test_cf_one_step)
        self.test_cf_treatment_seq_non = deepcopy(self.test_cf_treatment_seq)
        print("start to transform test data")
        self.test_cf_treatment_seq_non = self.remove_min_samples(self.test_cf_treatment_seq_non, self.len_past)
        print("removed min samples!!")
        self.test_cf_treatment_seq_non.data, _ = self.transform_test_data(self.test_cf_treatment_seq_non, self.projection_horizon, self.len_past)
        print("transformed test data!!")

        self.processed_data_multi = True

    def process_data_multi_mimic_iii_semisynthetic(self):
        """
        Used by CT

        Like ``process_data_multi``, but (when not ``self.autoregressive``) transforms
        ``train_f`` / ``val_f`` in place instead of building ``*_non`` copies.
        """
        self.train_f.process_data(self.train_scaling_params)
        if not self.autoregressive:
            self.train_f.data, _ = self.transform_data(self.train_f, self.projection_horizon, self.len_past, self.flag_include_init)

        if getattr(self, 'val_f', None) is not None:
            self.val_f.process_data(self.train_scaling_params)
            if not self.autoregressive:
                self.val_f.data, _ = self.transform_data(self.val_f, self.projection_horizon, self.len_past, self.flag_include_init)

        self.test_cf_one_step.process_data(self.train_scaling_params)
        self.test_cf_treatment_seq.process_data(self.train_scaling_params)
        self.test_cf_treatment_seq.process_sequential_test(self.projection_horizon)
        self.test_cf_treatment_seq.process_sequential_multi(self.projection_horizon)
        self.test_cf_one_step_non = deepcopy(self.test_cf_one_step)
        self.test_cf_treatment_seq_non = deepcopy(self.test_cf_treatment_seq)
        if not self.autoregressive:
            print("start to transform test data")
            self.test_cf_treatment_seq_non = self.remove_min_samples(self.test_cf_treatment_seq_non, self.len_past)
            print("removed min samples!!")
            self.test_cf_treatment_seq_non.data, _ = self.transform_test_data(self.test_cf_treatment_seq_non, self.projection_horizon, self.len_past)
            print("transformed test data!!")

        self.processed_data_multi = True

    def split_train_f_holdout(self, holdout_ratio=0.1):
        """
        Used by G-Net

        Moves a ``holdout_ratio`` share of ``train_f`` into ``train_f_holdout`` (once). Every key
        is split with the same ``random_state``, so rows stay aligned across keys.
        """
        if not hasattr(self, 'train_f_holdout') and holdout_ratio > 0.0:
            self.train_f_holdout = deepcopy(self.train_f)
            for k, v in self.train_f.data.items():
                self.train_f.data[k], self.train_f_holdout.data[k] = train_test_split(v, test_size=holdout_ratio,
                                                                                      random_state=self.seed)
            logger.info(f'Splited train_f on train_f: {len(self.train_f)} and train_f_holdout: {len(self.train_f_holdout)}')

    def explode_cf_treatment_seq(self, mc_samples=1):
        """
        Producing mc_samples copies of test_cf_treatment_seq subset for further MC-Sampling (e.g. for G-Net)

        Each copy is a distinct (shallow-copied) dataset object with its own deep-copied ``data``,
        so writes into one copy affect neither the other copies nor ``test_cf_treatment_seq``.
        Does nothing if the copies already exist.

        :param mc_samples: Number of copies
        """
        from copy import copy as shallow_copy  # the module only imports deepcopy

        if hasattr(self, 'test_cf_treatment_seq_mc'):
            return
        logger.info(f'Exploding test_cf_treatment_seq {mc_samples} times')

        self.test_cf_treatment_seq_mc = []
        for _ in range(mc_samples):
            replica = shallow_copy(self.test_cf_treatment_seq)
            replica.data = deepcopy(self.test_cf_treatment_seq.data)
            self.test_cf_treatment_seq_mc.append(replica)


class RealDatasetCollection:
    """
    Dataset collection (train_f, val_f, test_f) for real-world data.

    Each subset is a project dataset object exposing a ``data`` dict of arrays whose first
    axis indexes samples (as the windowing helpers below assume), ``scaling_params`` and the
    ``process_*`` / ``explode_trajectories`` helpers called here.

    NOTE(review): ``process_data_multi`` reads ``self.len_past`` and ``self.flag_include_init``,
    which ``__init__`` does not set; presumably a subclass assigns them -- confirm.
    """

    # Keys windowed as 2-D (sample, time) arrays by transform_data / transform_test_data.
    _WINDOW_2D_KEYS = ('cancer_volume', 'chemo_dosage', 'radio_dosage', 'chemo_application', 'radio_application',
                       'chemo_probabilities', 'radio_probabilities', 'death_flags', 'recovery_flags')
    # Keys windowed as 3-D (sample, time, feature) arrays.
    _WINDOW_3D_KEYS = ('prev_treatments', 'current_covariates', 'unscaled_outputs', 'prev_outputs', 'vitals',
                       'next_vitals')
    # Per-sample scalars copied once for every emitted window.
    _PER_SAMPLE_KEYS = ('sequence_lengths', 'patient_types')

    def __init__(self, **kwargs):
        self.seed = None

        # Flags recording which processing steps have already run.
        self.processed_data_encoder = False
        self.processed_data_decoder = False
        self.processed_data_propensity = False
        self.processed_data_msm = False

        # Subsets and scaling parameters; expected to be filled in after construction.
        self.train_f = None
        self.val_f = None
        self.test_f = None
        self.train_scaling_params = None
        self.projection_horizon = None

        self.autoregressive = None
        self.has_vitals = None

    def process_data_encoder(self):
        """No-op for this collection; kept so it exposes the same interface as SyntheticDatasetCollection."""
        pass

    def process_propensity_train_f(self, propensity_treatment, propensity_history):
        """
        Generate stabilized weights for RMSN for the train subset.

        The weights are the product over axis 2 of the ratio of treatment to history
        propensity scores, stored in ``train_f.data['stabilized_weights']``.

        Args:
            propensity_treatment: Propensity treatment network
            propensity_history: Propensity history network
        """
        prop_treat_train_f = propensity_treatment.get_propensity_scores(self.train_f)
        prop_hist_train_f = propensity_history.get_propensity_scores(self.train_f)
        self.train_f.data['stabilized_weights'] = np.prod(prop_treat_train_f / prop_hist_train_f, axis=2)

    def process_data_decoder(self, encoder, save_encoder_r=False):
        """
        Used by CRN, RMSN, EDCT: prepare decoder inputs from the encoder's representations.

        Args:
            encoder: trained encoder exposing ``get_representations`` and ``get_predictions``.
            save_encoder_r: forwarded to the dataset ``process_*`` helpers.
        """
        # Multiplying test trajectories
        self.test_f.explode_trajectories(self.projection_horizon)

        # Representation generation / One-step ahead prediction with encoder
        r_train_f = encoder.get_representations(self.train_f)
        r_val_f = encoder.get_representations(self.val_f)
        r_test_f = encoder.get_representations(self.test_f)
        outputs_test_f = encoder.get_predictions(self.test_f)

        # Splitting time series wrt specified projection horizon / Preparing test sequences
        self.train_f.process_sequential(r_train_f, self.projection_horizon, save_encoder_r=save_encoder_r)
        self.val_f.process_sequential(r_val_f, self.projection_horizon, save_encoder_r=save_encoder_r)

        self.test_f.process_sequential_test(self.projection_horizon, r_test_f, save_encoder_r=save_encoder_r)
        self.test_f.process_autoregressive_test(r_test_f, outputs_test_f, self.projection_horizon, save_encoder_r=save_encoder_r)

        self.processed_data_decoder = True

    def process_data_multi(self):
        """
        Used by CT.

        Builds ``test_f_multi`` (exploded test trajectories prepared for multi-step evaluation).
        Also builds ``*_non`` copies of train/val/test (windowed with ``transform_data``) and of
        ``test_f_multi`` (filtered with ``remove_min_samples`` and windowed with
        ``transform_test_data``) for the non-autoregressive setting.
        """
        self.test_f_multi = deepcopy(self.test_f)

        # Multiplying test trajectories
        self.test_f_multi.explode_trajectories(self.projection_horizon)
        self.test_f_multi.process_sequential_test(self.projection_horizon)
        self.test_f_multi.process_sequential_multi(self.projection_horizon)

        # NOTE(review): unlike SyntheticDatasetCollection, the non-autoregressive copies are
        # built regardless of self.autoregressive -- confirm this is intended.
        self.train_f_non = deepcopy(self.train_f)
        self.val_f_non = deepcopy(self.val_f)
        self.test_f_non = deepcopy(self.test_f)
        self.test_f_multi_non = deepcopy(self.test_f_multi)
        self.train_f_non.data, _ = self.transform_data(self.train_f_non, self.projection_horizon, self.len_past, self.flag_include_init)
        self.val_f_non.data, _ = self.transform_data(self.val_f_non, self.projection_horizon, self.len_past, self.flag_include_init)
        self.test_f_non.data, _ = self.transform_data(self.test_f_non, self.projection_horizon, self.len_past, self.flag_include_init)
        self.test_f_multi_non = self.remove_min_samples(self.test_f_multi_non, self.len_past)
        self.test_f_multi_non.data, _ = self.transform_test_data(self.test_f_multi_non, self.projection_horizon, self.len_past)

        self.processed_data_multi = True

    def split_train_f_holdout(self, holdout_ratio=0.1):
        """
        Used by G-Net: move a random ``holdout_ratio`` share of train_f into ``self.train_f_holdout``.

        One train/holdout index split (per distinct array length) is drawn with ``self.seed``
        and applied to every array in ``train_f.data``, so rows stay aligned across keys even
        when ``self.seed`` is None. For an integer seed the result is the same as splitting
        each key separately. Does nothing if the holdout already exists or ``holdout_ratio``
        is not positive.
        """
        if hasattr(self, 'train_f_holdout') or not holdout_ratio > 0.0:
            return

        self.train_f_holdout = deepcopy(self.train_f)
        splits = {}  # array length -> (train indices, holdout indices)
        for k, v in self.train_f.data.items():
            n_samples = len(v)
            if n_samples not in splits:
                splits[n_samples] = train_test_split(np.arange(n_samples), test_size=holdout_ratio,
                                                     random_state=self.seed)
            idx_train, idx_holdout = splits[n_samples]
            # assumes array-like values supporting integer-array indexing (numpy arrays) -- TODO confirm
            self.train_f.data[k], self.train_f_holdout.data[k] = v[idx_train], v[idx_holdout]
        logger.info(f'Splited train_f on train_f: {len(self.train_f)} and train_f_holdout: {len(self.train_f_holdout)}')

    def explode_cf_treatment_seq(self, mc_samples=1):
        """
        Producing mc_samples independent copies of the test_f subset for further MC-Sampling (e.g. for G-Net)

        Each copy is a shallow copy of ``test_f`` with its own deep-copied ``data`` dict, so the
        copies do not alias each other or ``test_f``. Runs at most once.

        :param mc_samples: Number of copies
        """
        from copy import copy as shallow_copy

        if hasattr(self, 'test_f_mc'):
            return
        logger.info(f'Exploding test_f {mc_samples} times')

        self.test_f_mc = []
        for _ in range(mc_samples):
            sample = shallow_copy(self.test_f)
            sample.data = deepcopy(self.test_f.data)
            self.test_f_mc.append(sample)

    def denormalize_data(self, data, mean, std):
        """Invert ``normalize_data``: return ``data * std + mean``."""
        return data*std + mean

    def normalize_data(self, data, mean, std):
        """Standardize ``data``: return ``(data - mean) / std``."""
        return (data-mean)/std

    def transform_data(self, dataset, projection_horizon, len_past, flag_include_init=False):
        """
        Convert a dataset into fixed-length windows for non-autoregressive training.

        For every sequence ``b`` and every anchor ``i`` in ``1 .. len_seq - 1``, where
        ``len_seq = sequence_lengths[b] - projection_horizon``, one window is emitted:

        * keys in ``_WINDOW_2D_KEYS`` / ``_WINDOW_3D_KEYS`` (except ``unscaled_outputs``) get
          the past slice ``[i - len_past, i - diff)``. Here ``diff`` is how many time steps
          shorter that key is than ``outputs``.
        * ``unscaled_outputs``, ``outputs`` and ``current_treatments`` get the future slice
          ``[i - 1, i - 1 + projection_horizon)``. ``active_entries`` is set to all ones.
        * ``sequence_lengths``, ``patient_types`` and ``static_features`` are copied per window.

        Anchors with ``i < len_past`` lack a full past window. They are emitted only when
        ``flag_include_init`` is True, with zero padding on the left. For ``prev_outputs`` the
        padding is zero in unscaled units: the row is de-normalised, padded, then re-normalised.

        Args:
            dataset: object with ``data`` (dict of arrays, first axis = sequence) and
                ``scaling_params`` containing ``output_means`` and ``output_stds``.
            projection_horizon: number of future steps per window.
            len_past: number of past steps per window.
            flag_include_init: also emit the zero-padded early windows.

        Returns:
            (new_data, list_idx_batch): the windowed data dict, and for each emitted window
            the index of the source sequence it was cut from.

        NOTE: ``dataset.data['prev_outputs']`` is de-normalised and re-normalised in place
        (``data`` is only a shallow copy), so its values may drift by float rounding.
        NOTE(review): ``cancer_volume`` is copied whole and then truncated to ``count`` rows,
        so its rows do not line up with the emitted windows -- confirm intent.
        """
        data = dataset.data.copy()  # shallow copy: arrays are shared with dataset.data
        keys = data.copy().keys()
        last_key = list(keys)[-1]   # a window is counted once this key has been processed
        new_data = {}

        len_max_seq = data['outputs'].shape[1]  # maximum sequence length
        batch = data['outputs'].shape[0]        # number of input sequences
        n_rows = batch * len_max_seq            # pre-allocated rows (upper bound on windows)
        count = 0                               # number of windows emitted so far
        list_idx_batch = []
        for b in range(0, batch):
            len_seq = int(data['sequence_lengths'][b] - projection_horizon)  # usable length of sequence b
            if len_seq <= 0:
                continue  # too short to hold a full projection horizon

            for i in range(1, len_seq):
                is_init = i < len_past                   # not enough history for a full past window
                emit = flag_include_init or not is_init  # early windows only when requested

                for key in keys:
                    # diff: how many time steps shorter this key is than `outputs`
                    if len(data[key].shape) > 1:
                        diff = len_max_seq - data[key].shape[1]

                    if key in self._WINDOW_2D_KEYS:
                        if key == 'cancer_volume':
                            new_data[key] = data[key]
                            continue

                        if key not in new_data:
                            new_data[key] = np.zeros((n_rows, len_past - diff))

                        if emit:
                            if is_init:
                                # left-pad with zeros so the window has len_past - diff steps
                                new_data[key][count] = np.concatenate(
                                    (np.zeros((1, len_past - i)), data[key][b:b + 1, :i - diff]), axis=1)
                            else:
                                new_data[key][count] = data[key][b:b + 1, i - len_past:i - diff]

                    elif key in self._WINDOW_3D_KEYS:
                        # mean, std for de-/re-normalising prev_outputs
                        output_stds, output_means = dataset.scaling_params['output_stds'], dataset.scaling_params['output_means']

                        if key not in new_data:
                            if key == 'unscaled_outputs':
                                # future-window targets: (window, projection_horizon, dim)
                                new_data[key] = np.zeros((n_rows, projection_horizon, data[key].shape[-1]))
                                new_data['outputs'] = np.zeros((n_rows, projection_horizon, data[key].shape[-1]))
                                new_data['active_entries'] = np.zeros((n_rows, projection_horizon, data['active_entries'].shape[-1]))
                                new_data['current_treatments'] = np.zeros((n_rows, projection_horizon, data['current_treatments'].shape[-1]))
                            else:
                                new_data[key] = np.zeros((n_rows, len_past - diff, data[key].shape[-1]))

                        if emit:
                            if key == 'unscaled_outputs':
                                future = slice(i - 1, i + projection_horizon - 1)
                                new_data[key][count] = data[key][b:b + 1, future]
                                new_data['outputs'][count] = data['outputs'][b:b + 1, future]
                                # all ones: every future step is predicted at once (non-autoregressive)
                                new_data['active_entries'][count] = np.ones((1, projection_horizon, data['active_entries'].shape[-1]))
                                new_data['current_treatments'][count] = data['current_treatments'][b:b + 1, future, :]
                            else:
                                if key == 'prev_outputs':
                                    data[key][b:b + 1] = self.denormalize_data(data[key][b:b + 1].copy(), output_means, output_stds)

                                if is_init:
                                    # left-pad with len_past - i zeros so the window has len_past - diff
                                    # steps (same padding as the 2-D branch above)
                                    new_data[key][count] = np.concatenate(
                                        (np.zeros((1, len_past - i, data[key].shape[-1])), data[key][b:b + 1, :i - diff, :]),
                                        axis=1)
                                else:
                                    new_data[key][count] = data[key][b:b + 1, i - len_past:i - diff, :]

                                if key == 'prev_outputs':
                                    new_data[key][count:count + 1] = self.normalize_data(new_data[key][count:count + 1].copy(), output_means, output_stds)
                                    data[key][b:b + 1] = self.normalize_data(data[key][b:b + 1].copy(), output_means, output_stds)

                    elif key in self._PER_SAMPLE_KEYS:
                        if key not in new_data:
                            new_data[key] = np.zeros(n_rows)
                        if emit:
                            new_data[key][count] = data[key][b]

                    elif key == 'static_features':
                        if key not in new_data:
                            new_data[key] = np.zeros((n_rows, data[key].shape[-1]))
                        if emit:
                            new_data[key][count] = data[key][b]

                    if key == last_key and emit:
                        count += 1
                        list_idx_batch.append(b)

        # drop the unused pre-allocated rows
        for key in new_data:
            new_data[key] = new_data[key][:count]

        return new_data, list_idx_batch

    def remove_min_samples(self, dataset, len_past):
        """
        Keep only the samples whose ``future_past_split`` is at least ``len_past``.

        Every array in ``dataset.data`` is replaced by the same row selection, and ``dataset``
        is returned. ``process_data_multi`` calls this before ``transform_test_data``,
        presumably so that each kept sample has ``len_past`` history steps -- verify.
        """
        split = dataset.data['future_past_split']
        n_samples = dataset.data['outputs'].shape[0]
        keep = [i for i in range(n_samples) if not split[i] < len_past]

        for key in dataset.data.keys():
            dataset.data[key] = dataset.data[key][keep]
        return dataset

    def transform_test_data(self, dataset, projection_horizon, len_past):
        """
        Cut one evaluation window per sequence, with key handling matching ``transform_data``.

        The single anchor per sequence is ``i = sequence_lengths[b] - projection_horizon + 1``.
        Past keys get ``[i - len_past, i - diff)`` and the future targets get
        ``[i - 1, i - 1 + projection_horizon)``, i.e. the last ``projection_horizon`` steps.
        No padding is applied.
        NOTE(review): when ``i < len_past`` the past slice start is negative and wraps around;
        ``process_data_multi`` runs ``remove_min_samples`` first, presumably to avoid that.

        Returns:
            (new_data, list_idx_batch): windowed data dict and source sequence index per window.
        """
        data = dataset.data.copy()  # shallow copy: arrays are shared with dataset.data
        keys = data.copy().keys()
        last_key = list(keys)[-1]   # a window is counted once this key has been processed
        new_data = {}

        len_max_seq = data['outputs'].shape[1]  # maximum sequence length
        batch = data['outputs'].shape[0]        # number of input sequences
        n_rows = batch * len_max_seq            # pre-allocated rows (upper bound on windows)
        count = 0                               # number of windows emitted so far
        list_idx_batch = []
        for b in range(0, batch):
            len_seq = int(data['sequence_lengths'][b] - projection_horizon)  # usable length of sequence b
            if len_seq <= 0:
                logger.info('%d th seq dropped!!', b)
                continue

            i = len_seq + 1  # single anchor per sequence
            for key in keys:
                # diff: how many time steps shorter this key is than `outputs`
                if len(data[key].shape) > 1:
                    diff = len_max_seq - data[key].shape[1]

                if key in self._WINDOW_2D_KEYS:
                    if key == 'cancer_volume':
                        new_data[key] = data[key]
                        continue

                    if key not in new_data:
                        new_data[key] = np.zeros((n_rows, len_past - diff))

                    new_data[key][count] = data[key][b:b + 1, i - len_past:i - diff]

                elif key in self._WINDOW_3D_KEYS:
                    # mean, std for de-/re-normalising prev_outputs
                    output_stds, output_means = dataset.scaling_params['output_stds'], dataset.scaling_params['output_means']

                    if key not in new_data:
                        if key == 'unscaled_outputs':
                            new_data[key] = np.zeros((n_rows, projection_horizon, data[key].shape[-1]))
                            new_data['outputs'] = np.zeros((n_rows, projection_horizon, data[key].shape[-1]))
                            new_data['active_entries'] = np.zeros((n_rows, projection_horizon, data['active_entries'].shape[-1]))
                            new_data['current_treatments'] = np.zeros((n_rows, projection_horizon, data['current_treatments'].shape[-1]))
                        else:
                            new_data[key] = np.zeros((n_rows, len_past - diff, data[key].shape[-1]))

                    if key == 'unscaled_outputs':
                        future = slice(i - 1, i + projection_horizon - 1)
                        new_data[key][count] = data[key][b:b + 1, future]
                        new_data['outputs'][count] = data['outputs'][b:b + 1, future]
                        # all ones: every future step is predicted at once (non-autoregressive)
                        new_data['active_entries'][count] = np.ones((1, projection_horizon, data['active_entries'].shape[-1]))
                        new_data['current_treatments'][count] = data['current_treatments'][b:b + 1, future, :]
                    else:
                        if key == 'prev_outputs':
                            data[key][b:b + 1] = self.denormalize_data(data[key][b:b + 1].copy(), output_means, output_stds)

                        new_data[key][count] = data[key][b:b + 1, i - len_past:i - diff, :]

                        if key == 'prev_outputs':
                            new_data[key][count:count + 1] = self.normalize_data(new_data[key][count:count + 1].copy(), output_means, output_stds)
                            data[key][b:b + 1] = self.normalize_data(data[key][b:b + 1].copy(), output_means, output_stds)

                elif key in self._PER_SAMPLE_KEYS:
                    if key not in new_data:
                        new_data[key] = np.zeros(n_rows)
                    new_data[key][count] = data[key][b]

                elif key == 'static_features':
                    if key not in new_data:
                        new_data[key] = np.zeros((n_rows, data[key].shape[-1]))
                    new_data[key][count] = data[key][b]

                if key == last_key:
                    count += 1
                    list_idx_batch.append(b)

        # drop the unused pre-allocated rows
        for key in new_data:
            new_data[key] = new_data[key][:count]

        return new_data, list_idx_batch
