import os

import numpy as np
import pandas as pd
import warnings
import torch
from copy import deepcopy
import copy
import matplotlib.pyplot as plt

import sys
warnings.simplefilter('ignore')

sys.path.append('../simulation')
from syntheticCancerDatasetCollection import SyntheticCancerDatasetCollection

# -------------------------------------------------------------------------------------------------------
#  make simulation data
# -------------------------------------------------------------------------------------------------------
class MakeSimData():
    """Run the synthetic cancer simulator and convert its output to model-ready arrays.

    ``__init__`` reads and validates the settings from a config object.
    ``make`` runs ``SyntheticCancerDatasetCollection`` and returns:

    * the processed arrays, standardised with statistics computed on the
      training split only;
    * the raw (unscaled) simulator outputs.
    """

    def __init__(self, config):
        """Read the simulation settings from ``config.dataset`` and validate them.

        Args:
            config: object whose ``dataset`` attribute provides ``seed``,
                ``window_size``, ``max_seq_length``, ``projection_horizon``,
                ``cf_seq_mode``, ``dict_wd``, ``num_samples``,
                ``num_patients.{train,val,test}`` and the four
                ``{treat,dose}_{chemo,radio}_coeff`` values.

        Raises:
            ValueError: if any setting is out of range (see ``error_check``).
        """
        dataset_cfg = config.dataset

        # Simulator settings; ``num_time_steps`` is forwarded as ``max_seq_length``.
        self.params = {
            "seed":               dataset_cfg.seed,
            "window_size":        dataset_cfg.window_size,
            "num_time_steps":     dataset_cfg.max_seq_length,
            "projection_horizon": dataset_cfg.projection_horizon,
            "cf_seq_mode":        dataset_cfg.cf_seq_mode,
            "dict_wd":            dataset_cfg.dict_wd,
            "num_samples":        dataset_cfg.num_samples,
        }

        # Number of patients per split (the config's ``val`` is stored as ``valid``).
        self.num_patients = {
            "train": dataset_cfg.num_patients.train,
            "valid": dataset_cfg.num_patients.val,
            "test":  dataset_cfg.num_patients.test,
        }

        # Coefficients forwarded unchanged to the simulator as ``set_coeffs``.
        self.set_coeffs = {
            "treat_chemo": dataset_cfg.treat_chemo_coeff,
            "treat_radio": dataset_cfg.treat_radio_coeff,
            "dose_chemo":  dataset_cfg.dose_chemo_coeff,
            "dose_radio":  dataset_cfg.dose_radio_coeff,
        }

        self.error_check()
        self._print_settings()

        # Widths of the one-hot encodings built by the preprocessing methods.
        self.num_treatments     = 4
        self.num_category_v     = 3
        self.num_dosage_samples = 2

    def _print_settings(self):
        """Print the settings stored in ``params``, ``num_patients`` and ``set_coeffs``."""
        params, num_patients, set_coeffs = self.params, self.num_patients, self.set_coeffs
        print("simulation parametes")
        print("Common setting")
        print("   num_patients       : ")
        print("                 train: ", num_patients['train'])
        print("                 valid: ", num_patients['valid'])
        print("                 test : ", num_patients['test'])
        print("   seed               : ", params["seed"])
        print("   window size        : ", params["window_size"])
        print("   max time steps     : ", params["num_time_steps"])
        print("Coeff:")
        print("   treat_chemo_coeff  : ", set_coeffs["treat_chemo"])
        print("   treat_radio_coeff  : ", set_coeffs["treat_radio"])
        print("    dose_chemo_coeff  : ", set_coeffs["dose_chemo"])
        print("    dose_radio_coeff  : ", set_coeffs["dose_radio"])
        print("For Decoder")
        print("   projection_horizon : ", params["projection_horizon"])
        print("   cf_seq_mode        : ", params["cf_seq_mode"])

    @staticmethod
    def _one_hot(indices, depth):
        """Return integer ``indices`` one-hot encoded (float64) along a new trailing axis of size ``depth``.

        Negative indices count from the end, and out-of-range indices raise
        ``IndexError``. This matches element-wise assignment into a zero array.
        """
        return np.eye(depth)[indices]

    # ------------------------------------------------------------
    # simulation
    # ------------------------------------------------------------
    def make(self):
        """Run the simulator and preprocess every split.

        Returns:
            ``(processed_dataset, row_dataset)``.

            ``processed_dataset`` maps:

            * ``'train_f'``, ``'valid_f'`` and ``'test_cf'`` to the dicts built
              by ``preprocessing``;
            * ``'test_cf_multi'`` to the dict built by
              ``preprocessing_sequential``;
            * ``'train_scaling_params'`` to the ``(means, stds)`` pair.

            ``row_dataset`` maps ``'train_f'``, ``'valid_f'`` and ``'test_cf'``
            to the unscaled arrays.

        Side effects:
            Sets ``self.pickle_map`` and ``self.train_scaling_params``.
        """
        sim_data = SyntheticCancerDatasetCollection(
            set_coeffs=self.set_coeffs,
            num_patients=self.num_patients,
            dict_wd=self.params["dict_wd"],
            seed=self.params["seed"],
            window_size=self.params["window_size"],
            max_seq_length=self.params["num_time_steps"],
            projection_horizon=self.params["projection_horizon"],
            cf_seq_mode=self.params["cf_seq_mode"],
        )

        self.pickle_map = {
            'train_f':       sim_data.train_f,
            'valid_f':       sim_data.val_f,
            'test_cf':       sim_data.test_cf,
            'test_cf_multi': sim_data.test_cf_multi,
        }

        # Statistics come from the training split only; every split is scaled with them.
        self.train_scaling_params = self.get_scaling_params(sim_data.train_f)

        processed_dataset, row_dataset = self.preprocessing(self.pickle_map)
        processed_dataset['test_cf_multi'] = self.preprocessing_sequential(sim_data.test_cf_multi)
        processed_dataset['train_scaling_params'] = self.train_scaling_params

        return processed_dataset, row_dataset

    # ------------------------------------------------------------
    # get scaling params
    # ------------------------------------------------------------
    def get_scaling_params(self, sim):
        """Compute the per-feature mean and std used for standardisation.

        For ``'cancer_volume'`` and ``'treat_dosage'``, only the first
        ``int(sequence_lengths[i])`` steps of each patient row are used. Entries
        past that point are ignored. ``'patient_types'`` uses every entry.

        Args:
            sim: mapping with ``'cancer_volume'`` and ``'treat_dosage'`` arrays
                indexed ``[patient, time]``, a per-patient ``'sequence_lengths'``
                array and ``'patient_types'``.

        Returns:
            ``(means, stds)``: two ``pd.Series`` indexed by feature name.
        """
        # astype truncates toward zero, like the int() used for slice ends.
        seq_lengths = np.asarray(sim['sequence_lengths']).astype(np.int64)

        means, stds = {}, {}
        for key in ('cancer_volume', 'treat_dosage'):
            values = np.asarray(sim[key])
            # True on each row's active prefix [0, sequence_length).
            active = np.arange(values.shape[1])[np.newaxis, :] < seq_lengths[:, np.newaxis]
            active_values = values[active]
            means[key] = np.mean(active_values)
            stds[key] = np.std(active_values)

        # Static variable: statistics over all patients.
        means['patient_types'] = np.mean(sim['patient_types'])
        stds['patient_types'] = np.std(sim['patient_types'])

        return pd.Series(means), pd.Series(stds)

    # -------------------------------------------------------------------------------------------------------
    # preprocessing
    # -------------------------------------------------------------------------------------------------------
    def preprocessing(self, pickle_map):
        """Standardise and one-hot encode the ``train_f``, ``valid_f`` and ``test_cf`` splits.

        Requires ``self.train_scaling_params`` to be set beforehand (``make``
        does this).

        Args:
            pickle_map: mapping from split name to a dict of simulator arrays.
                Per-step arrays are indexed ``[patient, time]``;
                ``patient_types`` holds one 1-based value per patient.

        Returns:
            ``(processed_dataset, row_dataset)``, each keyed by split name.

            * ``processed_dataset[split]`` holds the model inputs and targets
              (``inp_*``, ``out_x_next``), the one-hot encodings
              (``current_ow``, ``current_od``), the masks (``active_entries``,
              ``counterfactual_mask``) and unscaled copies of the volume.
            * ``row_dataset[split]`` holds the unscaled raw arrays.
        """
        train_means, train_stds = self.train_scaling_params
        x_mean, x_std = train_means["cancer_volume"], train_stds["cancer_volume"]
        d_mean, d_std = train_means["treat_dosage"], train_stds["treat_dosage"]

        processed_dataset = {}
        row_dataset = {}

        for data_name in ('train_f', 'valid_f', 'test_cf'):
            split = pickle_map[data_name]

            # ---- raw arrays ------------------------------------------------
            x = split['cancer_volume'][:, :, np.newaxis]        # (patients, time, 1)
            num_patients, max_seq_length, _ = x.shape
            v_cat = split['patient_types']                      # static feature
            w_cat = split['treat_app_point']                    # treatment index per step
            d_cat = split['dose_app_point']                     # dosage index per step
            d = split['treat_dosage']
            treat_chemo_dosage = split['treat_chemo_dosage']
            treat_radio_dosage = split['treat_radio_dosage']
            effect_chemo_dosage = split['effect_chemo_dosage']
            seq_lengths = split['sequence_lengths'].astype('int64')

            # ---- cancer volume: current step and next step -----------------
            x_scaled = (x - x_mean) / x_std
            # x_unscaled_next[:, t] == x[:, t + 1]; the last step is zero-padded.
            x_unscaled_next = np.hstack([x[:, 1:, :], np.zeros([num_patients, 1, 1])])
            x_scaled_next = (x_unscaled_next - x_mean) / x_std

            # ---- one-hot encodings -----------------------------------------
            # patient_types are 1-based (1, 2, 3) -> columns 0..2, repeated at every step.
            v_index = np.asarray(v_cat).reshape(num_patients) - 1
            v = np.repeat(self._one_hot(v_index, self.num_category_v)[:, np.newaxis, :],
                          max_seq_length, axis=1)
            w = self._one_hot(w_cat, self.num_treatments)
            od = self._one_hot(d_cat, self.num_dosage_samples)

            # ---- factual cell of the (treatment, dosage) grid per step -----
            # counterfactual_mask is 0 at the factual cell and 1 elsewhere.
            # masked_factual_dr_curve holds the scaled next-step volume at the
            # factual cell and 0 elsewhere.
            grid_shape = (num_patients, max_seq_length, self.num_treatments, self.num_dosage_samples)
            patient_idx, time_idx = np.indices((num_patients, max_seq_length))
            counterfactual_mask = np.ones(grid_shape)
            counterfactual_mask[patient_idx, time_idx, w_cat, d_cat] = 0
            masked_factual_dr_curve = np.zeros(grid_shape)
            # Drop the trailing size-1 axis explicitly: NumPy deprecates
            # converting a shape-(1,) array into a scalar slot.
            masked_factual_dr_curve[patient_idx, time_idx, w_cat, d_cat] = x_scaled_next[:, :, 0]

            # ---- dosage ----------------------------------------------------
            d_scaled = ((d - d_mean) / d_std)[:, :, np.newaxis]
            # d_prev[:, t] == d[:, t - 1], with 0 before the first step.
            d_prev = np.hstack([np.zeros([num_patients, 1]), d[:, :-1]])
            d_scaled_prev = ((d_prev - d_mean) / d_std)[:, :, np.newaxis]

            # ---- previous treatment: shifted right, index 0 before the first step
            w_first = np.zeros([num_patients, 1, self.num_treatments])
            w_first[:, :, 0] = 1
            w_prev = np.hstack([w_first, w[:, :-1, :]])

            # ---- active entries: 1 for t < sequence_length, else 0 ---------
            active = np.arange(max_seq_length)[np.newaxis, :] < seq_lengths[:, np.newaxis]
            active_entries = active.astype(np.float64)[:, :, np.newaxis]

            processed_dataset[data_name] = {
                # input data
                "inp_x":      x_scaled,
                "inp_v":      v,
                "inp_w_prev": w_prev,
                "inp_d_prev": d_scaled_prev,

                # output data
                "out_x_next": x_scaled_next,

                # for calculating weight
                "unscaled_x":      x,
                "unscaled_x_next": x_unscaled_next,

                # treatment
                "current_iw": w_cat[:, :, np.newaxis],
                "current_ow": w,

                # dosage
                "current_id": d_cat[:, :, np.newaxis],
                "current_od": od,
                "current_d":  d_scaled,

                # sequence bookkeeping
                "seq_lengths":    seq_lengths,
                "active_entries": active_entries,

                # factual / counterfactual grid
                "counterfactual_mask":     counterfactual_mask,
                "masked_factual_dr_curve": masked_factual_dr_curve,
            }

            row_dataset[data_name] = {
                "row_x":              x,
                "row_v":              v_cat,
                "row_w":              w_cat,
                "row_d":              d_cat,
                "row_d_treat_radio":  treat_radio_dosage,
                "row_d_effect_chemo": effect_chemo_dosage,
                "row_d_treat_chemo":  treat_chemo_dosage,
                'chemo_TEs':          split['chemo_treatment_effects'],
                'radio_TEs':          split['radio_treatment_effects'],
                'seq_lengths':        seq_lengths,
            }

        return processed_dataset, row_dataset

    # ----------------------------------------------------------------------------------------------
    # sequential test
    # ----------------------------------------------------------------------------------------------
    def preprocessing_sequential(self, test_data_multi):
        """Encode the multi-step counterfactual test split (``test_cf_multi``).

        ``cf_cancer_volume_nahead`` must be 4-D; its shape is unpacked as
        ``(patient, t, sample, tau)``. The other arrays are assumed to share
        that shape.

        Every returned array gains a trailing axis: either a size-1 axis or the
        one-hot axis. Index 0 of the ``tau`` axis is then dropped from every
        output.

        Requires ``self.train_scaling_params`` to be set beforehand.
        """
        x = test_data_multi["cf_cancer_volume_nahead"]
        cf_active_mask = test_data_multi["cf_active_mask"]

        # treatment
        prev_iw = test_data_multi["cfs_prev_treat_app_point"]
        current_iw = test_data_multi["cf_treat_app"]

        # dosage
        current_id = test_data_multi["cf_dose_app"]
        current_d = test_data_multi["cf_treat_dosage"]
        prev_d = test_data_multi["cfs_prev_dosage"]

        num_patients, max_seq_length, nsamples, _ = x.shape

        # One-hot encodings along a new trailing axis.
        current_ow = self._one_hot(current_iw, self.num_treatments)
        prev_ow = self._one_hot(prev_iw, self.num_treatments)
        current_od = self._one_hot(current_id, self.num_dosage_samples)

        # Dosages are scaled with the training statistics.
        train_means, train_stds = self.train_scaling_params
        d_mean, d_std = train_means["treat_dosage"], train_stds["treat_dosage"]
        current_d_scaled = (current_d - d_mean) / d_std
        prev_d_scaled = (prev_d - d_mean) / d_std

        # active_entries_next[..., tau] == active_entries[..., tau + 1]; the last tau is 0.
        active_entries = cf_active_mask.copy()
        active_entries_next = np.concatenate(
            [active_entries[:, :, :, 1:], np.zeros([num_patients, max_seq_length, nsamples, 1])],
            axis=-1,
        )

        processed_data = {
            # cancer volume
            "unscaled_x_nahead":   x[:, :, :, :, np.newaxis],

            # treatment
            "current_iw":          current_iw[:, :, :, :, np.newaxis],
            "current_ow":          current_ow,
            "prev_ow":             prev_ow,

            # dosage
            "current_id":          current_id[:, :, :, :, np.newaxis],
            "current_od":          current_od,
            "current_d_scaled":    current_d_scaled[:, :, :, :, np.newaxis],
            "prev_d_scaled":       prev_d_scaled[:, :, :, :, np.newaxis],

            # active
            "active_entries":      active_entries[:, :, :, :, np.newaxis],
            "active_entries_next": active_entries_next[:, :, :, :, np.newaxis],
        }

        # Drop tau index 0 from every array.
        return {key: value[:, :, :, 1:, :] for key, value in processed_data.items()}

    def error_check(self):
        """Validate the settings stored by ``__init__``.

        Raises:
            ValueError: naming every invalid setting. A setting is invalid if:

                * a coefficient lies outside ``[0, 10]`` or is NaN;
                * a split has fewer than one patient;
                * ``seed`` or ``projection_horizon`` is negative;
                * ``window_size`` is below 1;
                * ``cf_seq_mode`` is unknown.

                ``ValueError`` subclasses ``Exception``, which was raised
                previously, so existing handlers still catch it.
        """
        coeff_min, coeff_max = 0, 10
        # The chained comparison is False for NaN, so NaN coefficients are rejected too.
        invalid = [key for key, value in self.set_coeffs.items()
                   if not coeff_min <= value <= coeff_max]

        invalid += [f"num_patients[{key}]"
                    for key, value in self.num_patients.items() if value < 1]

        if self.params["seed"] < 0:
            invalid.append("seed")

        if self.params["window_size"] < 1:
            invalid.append("window_size")

        if self.params["projection_horizon"] < 0:
            invalid.append("projection_horizon")

        if self.params["cf_seq_mode"] not in ('random_trajectories', 'sliding_treatment'):
            invalid.append("cf_seq_mode")

        if invalid:
            raise ValueError(f"Invalid simulation parameter(s): {invalid}")
