import numpy as np
from tqdm import tqdm
from scipy.special import softmax

from util_sim import TUMOUR_DEATH_THRESHOLD, TUMOUR_CELL_DENSITY, OPTIMAL_CHEMO_DOSAGE, OPTIMAL_RADIO_DOSAGE
from util_sim import calc_volume, calc_diameter

def _logistic(beta, x, intercept):
    """Return the logistic probability ``1 / (1 + exp(-beta * (x - intercept)))``."""
    return 1.0 / (1.0 + np.exp(-beta * (x - intercept)))


def _next_volume(volume, rho, K, chemo_effect, radio_effect, noise):
    """Advance one step of the tumour-growth model and return the new volume."""
    growth_rate = 1 + rho * np.log(K / volume) - chemo_effect - radio_effect + noise
    return volume * growth_rate


def _apply_death_or_recovery(volume, recovery_rv):
    """Apply the terminal conditions to a freshly simulated volume.

    Returns ``(volume, terminated)``. The volume is clamped to
    TUMOUR_DEATH_THRESHOLD when it exceeds it (death). Otherwise it is set to 0
    when ``recovery_rv`` falls below ``exp(-volume * TUMOUR_CELL_DENSITY)``
    (recovery). In any other case it is returned unchanged with
    ``terminated=False``.
    """
    if volume > TUMOUR_DEATH_THRESHOLD:
        return TUMOUR_DEATH_THRESHOLD, True
    if recovery_rv < np.exp(-volume * TUMOUR_CELL_DENSITY):
        return 0, True
    return volume, False


def _sample_cf_options(cf_seq_mode, nsamples, projection_horizon):
    """Build the treatment/dosage plans explored by the counterfactual rollouts.

    Returns two int arrays of shape (nsamples, projection_horizon): treatment
    codes (0 none, 1 radio, 2 chemo, 3 both) and dose levels (0 or 1). The dose
    level is forced to 0 wherever the treatment code is 0.

    Raises:
        NotImplementedError: for an unknown ``cf_seq_mode``.
    """
    if cf_seq_mode == 'random_trajectories':
        treatment_options = np.random.randint(0, 4, (nsamples, projection_horizon))
        dosage_options    = np.random.randint(0, 2, (nsamples, projection_horizon))
    elif cf_seq_mode == 'sliding_treatment':
        # one treatment at a single step of the horizon: first all code-1 plans, then all code-2 plans
        treat_sequence    = np.eye(projection_horizon, dtype = int)
        treatment_options = np.concatenate([treat_sequence * 1, treat_sequence * 2])
        dosage_options    = np.concatenate([treat_sequence, treat_sequence])
    else:
        raise NotImplementedError()

    # no treatment => no dosage
    dosage_options[treatment_options == 0] = 0
    return treatment_options, dosage_options


def simulation_cancer_volume(simulation_params,
                             isCounterFactual   = False,
                             projection_horizon = 0,
                             cf_seq_mode        = 'random_trajectories'
                            ):
    """Simulate tumour-volume trajectories under a confounded radio/chemo policy.

    For each patient ``i`` and step ``t`` the volume evolves as::

        V[t+1] = V[t] * (1 + rho*log(K/V[t]) - beta_c*C[t]
                         - (alpha_r*d[t] + beta_r*d[t]**2) + eps[t+1])

    Here ``C[t]`` is the effective chemo dosage: the previous value decayed
    with a one-step half life, plus the newly applied chemo dose. ``d[t]`` is
    the radio dosage.

    Treatment and dose-level probabilities are logistic functions of the mean
    tumour diameter over the trailing ``window_size`` steps. The treatment and
    dose level are chosen by ``get_treat_dose_app_point`` and converted to
    dosages by ``get_radio_chemo_dosages``.

    A trajectory stops early in two cases:
    - death: the volume exceeds TUMOUR_DEATH_THRESHOLD, and is clamped to it;
    - recovery: a recovery draw falls below ``exp(-V * TUMOUR_CELL_DENSITY)``,
      and the volume is set to 0.

    With ``isCounterFactual=True``, at every factual step ``t`` that did not end
    the trajectory, ``2 * projection_horizon`` alternative treatment plans are
    rolled out for up to ``projection_horizon`` steps from the factual state.
    The rollouts reuse the factual noise and recovery draws. Note that this flag
    also inverts the factual assignment rule inside ``get_treat_dose_app_point``.

    Args:
        simulation_params: dict with per-patient arrays 'initial_volumes',
            'rho', 'alpha_r', 'beta_r', 'beta_c', 'K', and the logistic
            coefficients '{treat,dose}_{chemo,radio}_sigmoid_{intercepts,betas}'.
            It also holds 'patient_types' (passed through unchanged), the
            scalars 'window_size' and 'max_seq_length', and 'dict_wd' (the
            dosage lookup passed to ``get_radio_chemo_dosages``).
        isCounterFactual: also generate counterfactual rollouts (see above).
        projection_horizon: number of counterfactual steps per rollout.
        cf_seq_mode: 'random_trajectories' draws treatment codes in 0-3 and
            dose levels in 0-1 uniformly at random. 'sliding_treatment' applies
            a single level-1 treatment with code 1 (then code 2) at each step
            of the horizon.

    Returns:
        ``factual_outputs``: a dict of (num_patients, max_seq_length) arrays,
        plus 'sequence_lengths' and 'patient_types'. ``sequence_lengths[i]`` is
        ``t`` if the trajectory ended at step ``t`` and ``t + 1`` otherwise.
        When ``isCounterFactual`` is set, a tuple
        ``(factual_outputs, counterfactual_outputs)`` is returned instead. The
        counterfactual arrays have shape
        (num_patients, max_seq_length, 2 * projection_horizon, projection_horizon + 2).

    Raises:
        NotImplementedError: for an unknown ``cf_seq_mode``. It is raised at
            the first counterfactual rollout, not on entry.
    """
    # --------------------------------------------------------------------------------------------
    # simulation parameters
    # --------------------------------------------------------------------------------------------
    # one day half life for drugs
    drug_half_life = 1
    # per-step decay factor applied to the effective chemo dosage (0.5 for a one-step half life)
    chemo_decay = np.exp(-np.log(2) / drug_half_life)

    # patient-specific tumour dynamics
    initial_volumes = simulation_params['initial_volumes']
    rhos            = simulation_params['rho']
    alpha_rs        = simulation_params['alpha_r']
    beta_rs         = simulation_params['beta_r']
    beta_cs         = simulation_params['beta_c']
    Ks              = simulation_params['K']
    patient_types   = simulation_params['patient_types']

    # logistic coefficients for treatment assignment
    treat_chemo_sigmoid_intercepts = simulation_params['treat_chemo_sigmoid_intercepts']
    treat_radio_sigmoid_intercepts = simulation_params['treat_radio_sigmoid_intercepts']
    treat_chemo_sigmoid_betas      = simulation_params['treat_chemo_sigmoid_betas']
    treat_radio_sigmoid_betas      = simulation_params['treat_radio_sigmoid_betas']

    # logistic coefficients for dose-level assignment
    dose_chemo_sigmoid_intercepts  = simulation_params['dose_chemo_sigmoid_intercepts']
    dose_radio_sigmoid_intercepts  = simulation_params['dose_radio_sigmoid_intercepts']
    dose_chemo_sigmoid_betas       = simulation_params['dose_chemo_sigmoid_betas']
    dose_radio_sigmoid_betas       = simulation_params['dose_radio_sigmoid_betas']

    # basic parameters
    num_patients   = initial_volumes.shape[0]
    window_size    = simulation_params['window_size']
    max_seq_length = simulation_params['max_seq_length']
    dict_wd        = simulation_params['dict_wd']

    # --------------------------------------------------------------------------------------------
    # factual buffers
    # --------------------------------------------------------------------------------------------
    cancer_volume           = np.zeros((num_patients, max_seq_length))
    sequence_lengths        = np.zeros(num_patients, dtype = int)

    # chosen treatment code / dose level
    treat_app_point         = np.zeros((num_patients, max_seq_length), dtype = int)
    dose_app_point          = np.zeros((num_patients, max_seq_length), dtype = int)

    # dosages
    treat_dosages           = np.zeros((num_patients, max_seq_length))
    treat_chemo_dosages     = np.zeros((num_patients, max_seq_length))
    effect_chemo_dosages    = np.zeros((num_patients, max_seq_length))
    treat_radio_dosages     = np.zeros((num_patients, max_seq_length))

    # assignment probabilities
    treat_chemo_proba       = np.zeros((num_patients, max_seq_length))
    treat_radio_proba       = np.zeros((num_patients, max_seq_length))
    dose_chemo_proba        = np.zeros((num_patients, max_seq_length))
    dose_radio_proba        = np.zeros((num_patients, max_seq_length))

    # treatment effects
    chemo_treatment_effects = np.zeros((num_patients, max_seq_length))
    radio_treatment_effects = np.zeros((num_patients, max_seq_length))

    # --------------------------------------------------------------------------------------------
    # random variables (drawn up front; the order of draws is part of reproducibility)
    # --------------------------------------------------------------------------------------------
    # N(0, 0.01^2) noise added to the per-step growth rate
    noise_terms  = 0.01 * np.random.randn(num_patients, max_seq_length + projection_horizon)
    recovery_rvs = np.random.rand(num_patients, max_seq_length + projection_horizon)

    treat_chemo_app_rvs = np.random.rand(num_patients, max_seq_length)
    treat_radio_app_rvs = np.random.rand(num_patients, max_seq_length)
    dose_chemo_app_rvs  = np.random.rand(num_patients, max_seq_length)
    dose_radio_app_rvs  = np.random.rand(num_patients, max_seq_length)

    # --------------------------------------------------------------------------------------------
    # counterfactual buffers
    # --------------------------------------------------------------------------------------------
    if isCounterFactual:
        nsamples = projection_horizon * 2
        cf_shape = [num_patients, max_seq_length, nsamples, projection_horizon + 2]

        cfs_cancer_volume        = np.zeros(cf_shape)
        cfs_treat_app_point      = np.zeros(cf_shape, dtype = int)
        cfs_dose_app_point       = np.zeros(cf_shape, dtype = int)

        cfs_treat_dosage         = np.zeros(cf_shape)
        cfs_treat_radio_dosage   = np.zeros(cf_shape)
        cfs_treat_chemo_dosage   = np.zeros(cf_shape)
        cfs_effect_chemo_dosage  = np.zeros(cf_shape)

        cfs_prev_treat_app_point = np.zeros(cf_shape, dtype = int)
        # dosages are real-valued (like cfs_treat_dosage); an int buffer would truncate them
        cfs_prev_dosage          = np.zeros(cf_shape)

        cfs_active_mask          = np.zeros(cf_shape, dtype = int)

    # --------------------------------------------------------------------------------------------
    # run simulation
    # --------------------------------------------------------------------------------------------
    for i in tqdm(range(num_patients), total = num_patients):
        cancer_volume[i, 0] = initial_volumes[i]

        alpha_r = alpha_rs[i]
        beta_r  = beta_rs[i]
        beta_c  = beta_cs[i]
        rho     = rhos[i]
        K       = Ks[i]

        # Bound before the loop so the bookkeeping below is defined even when the
        # loop body never runs (max_seq_length < 2) and never reuses the previous
        # patient's values.
        t = -1
        isDeathOrRecover_f = False

        for t in range(0, max_seq_length - 1):
            previous_chemo_dose = 0.0 if t == 0 else effect_chemo_dosages[i, t - 1]

            # mean tumour diameter over steps max(t - window_size, 0) .. t
            cancer_volume_used = cancer_volume[i, max(t - window_size, 0):t + 1]
            cancer_metric_used = np.array([calc_diameter(vol) for vol in cancer_volume_used]).mean()

            # ---- treatment / dose-level assignment probabilities ----
            treat_radio_prob = _logistic(treat_radio_sigmoid_betas[i], cancer_metric_used, treat_radio_sigmoid_intercepts[i])
            dose_radio_prob  = _logistic(dose_radio_sigmoid_betas[i],  cancer_metric_used, dose_radio_sigmoid_intercepts[i])
            treat_chemo_prob = _logistic(treat_chemo_sigmoid_betas[i], cancer_metric_used, treat_chemo_sigmoid_intercepts[i])
            dose_chemo_prob  = _logistic(dose_chemo_sigmoid_betas[i],  cancer_metric_used, dose_chemo_sigmoid_intercepts[i])

            treat_radio_proba[i, t] = treat_radio_prob
            dose_radio_proba[i, t]  = dose_radio_prob
            treat_chemo_proba[i, t] = treat_chemo_prob
            dose_chemo_proba[i, t]  = dose_chemo_prob

            # ---- factual treatment and dosage ----
            treat_app_point[i, t], dose_app_point[i, t] = get_treat_dose_app_point(
                treat_radio_app_rvs[i, t], treat_chemo_app_rvs[i, t],
                dose_radio_app_rvs[i, t], dose_chemo_app_rvs[i, t],
                treat_radio_prob, treat_chemo_prob,
                dose_radio_prob, dose_chemo_prob,
                isCounterFactual = isCounterFactual)

            treat_radio_dosages[i, t], treat_chemo_dosages[i, t], treat_dosages[i, t] = get_radio_chemo_dosages(
                dict_wd, treat_app_point[i, t], dose_app_point[i, t])

            effect_chemo_dosages[i, t] = previous_chemo_dose * chemo_decay + treat_chemo_dosages[i, t]

            # ---- factual tumour step ----
            chemo_treatment_effects[i, t] = beta_c * effect_chemo_dosages[i, t]
            radio_treatment_effects[i, t] = alpha_r * treat_radio_dosages[i, t] + beta_r * treat_radio_dosages[i, t] ** 2
            cancer_volume[i, t + 1] = _next_volume(cancer_volume[i, t], rho, K,
                                                   chemo_treatment_effects[i, t],
                                                   radio_treatment_effects[i, t],
                                                   noise_terms[i, t + 1])

            cancer_volume[i, t + 1], isDeathOrRecover_f = _apply_death_or_recovery(
                cancer_volume[i, t + 1], recovery_rvs[i, t + 1])
            if isDeathOrRecover_f:
                break  # patient died or recovered: factual trajectory ends

            if not isCounterFactual:
                continue

            # ---- counterfactual rollouts branching from step t ----
            treatment_options, dosage_options = _sample_cf_options(cf_seq_mode, nsamples, projection_horizon)

            for nsample in range(nsamples):
                cf_cancer_volume        = np.zeros(projection_horizon + 2)
                cf_treat_app_point      = np.zeros(projection_horizon + 2)
                cf_dose_app_point       = np.zeros(projection_horizon + 2)
                cf_prev_treat_app_point = np.zeros(projection_horizon + 2)
                cf_treat_dosage         = np.zeros(projection_horizon + 2)
                cf_prev_dosage          = np.zeros(projection_horizon + 2)
                cf_treat_radio_dosage   = np.zeros(projection_horizon + 2)
                cf_treat_chemo_dosage   = np.zeros(projection_horizon + 2)
                cf_effect_chemo_dosage  = np.zeros(projection_horizon + 2)

                # slot 0 holds the factual step t; slots 1..projection_horizon hold the plan
                cf_cancer_volume[0:2]    = cancer_volume[i, t:t + 2]
                cf_treat_app_point[0]    = treat_app_point[i, t]
                cf_dose_app_point[0]     = dose_app_point[i, t]
                cf_treat_app_point[1:-1] = treatment_options[nsample, :]
                cf_dose_app_point[1:-1]  = dosage_options[nsample, :]

                cf_treat_dosage[0]        = treat_dosages[i, t]
                cf_effect_chemo_dosage[0] = effect_chemo_dosages[i, t]
                cf_treat_radio_dosage[0]  = treat_radio_dosages[i, t]
                cf_treat_chemo_dosage[0]  = treat_chemo_dosages[i, t]

                cf_prev_treat_app_point[0] = 0.0 if t == 0 else treat_app_point[i, t - 1]
                cf_prev_dosage[0]          = 0.0 if t == 0 else treat_dosages[i, t - 1]

                isDeathOrRecover_cf = False
                for ntau in range(1, projection_horizon + 1):
                    current_t = t + ntau

                    cf_treat_radio_dosage[ntau], cf_treat_chemo_dosage[ntau], cf_treat_dosage[ntau] = \
                        get_radio_chemo_dosages(dict_wd, cf_treat_app_point[ntau], cf_dose_app_point[ntau])
                    cf_effect_chemo_dosage[ntau] = cf_effect_chemo_dosage[ntau - 1] * chemo_decay + cf_treat_chemo_dosage[ntau]

                    cf_chemo_te = beta_c * cf_effect_chemo_dosage[ntau]
                    cf_radio_te = alpha_r * cf_treat_radio_dosage[ntau] + beta_r * cf_treat_radio_dosage[ntau] ** 2

                    # reuse the factual noise / recovery draws for the same calendar step
                    cf_cancer_volume[ntau + 1] = _next_volume(cf_cancer_volume[ntau], rho, K,
                                                              cf_chemo_te, cf_radio_te,
                                                              noise_terms[i, current_t + 1])
                    cf_cancer_volume[ntau + 1], isDeathOrRecover_cf = _apply_death_or_recovery(
                        cf_cancer_volume[ntau + 1], recovery_rvs[i, current_t + 1])
                    if isDeathOrRecover_cf:
                        break

                # nsamples > 0 implies projection_horizon >= 1, so ntau is bound here
                cf_seqlen = ntau + 1
                # the terminal (clamped/zeroed) volume is stored but left unmasked
                mask_len  = cf_seqlen if isDeathOrRecover_cf else cf_seqlen + 1

                cfs_cancer_volume[i, t, nsample, :cf_seqlen + 1] = cf_cancer_volume[:cf_seqlen + 1]
                cfs_active_mask[i, t, nsample, :mask_len]        = 1

                cfs_treat_app_point[i, t, nsample, :cf_seqlen]     = cf_treat_app_point[:cf_seqlen]
                cfs_dose_app_point[i, t, nsample, :cf_seqlen]      = cf_dose_app_point[:cf_seqlen]
                cfs_treat_dosage[i, t, nsample, :cf_seqlen]        = cf_treat_dosage[:cf_seqlen]
                cfs_treat_radio_dosage[i, t, nsample, :cf_seqlen]  = cf_treat_radio_dosage[:cf_seqlen]
                cfs_treat_chemo_dosage[i, t, nsample, :cf_seqlen]  = cf_treat_chemo_dosage[:cf_seqlen]
                cfs_effect_chemo_dosage[i, t, nsample, :cf_seqlen] = cf_effect_chemo_dosage[:cf_seqlen]

                # previous-step treatment/dosage: slot 0 is the factual t-1 value, the rest is the plan shifted by one
                cf_prev_treat_app_point[1:] = cf_treat_app_point[:-1]
                cf_prev_dosage[1:]          = cf_treat_dosage[:-1]
                cfs_prev_treat_app_point[i, t, nsample, :cf_seqlen] = cf_prev_treat_app_point[:cf_seqlen]
                cfs_prev_dosage[i, t, nsample, :cf_seqlen]          = cf_prev_dosage[:cf_seqlen]

        sequence_lengths[i] = int(t) if isDeathOrRecover_f else int(t + 1)

    factual_outputs = {
        'cancer_volume'      : cancer_volume,

        # application point
        'treat_app_point'    : treat_app_point,
        'dose_app_point'     : dose_app_point,

        # probability
        'treat_chemo_proba'  : treat_chemo_proba,
        'treat_radio_proba'  : treat_radio_proba,
        'dose_chemo_proba'   : dose_chemo_proba,
        'dose_radio_proba'   : dose_radio_proba,

        # dosage
        'treat_dosage'       : treat_dosages,
        'effect_chemo_dosage': effect_chemo_dosages,
        'treat_chemo_dosage' : treat_chemo_dosages,
        'treat_radio_dosage' : treat_radio_dosages,

        # treatment effect
        'chemo_treatment_effects': chemo_treatment_effects,
        'radio_treatment_effects': radio_treatment_effects,

        # flags
        'sequence_lengths'   : sequence_lengths,
        'patient_types'      : patient_types,
        }

    if not isCounterFactual:
        return factual_outputs

    counterfactual_outputs = {
        "cf_cancer_volume_nahead" : cfs_cancer_volume,
        "cf_active_mask"          : cfs_active_mask,
        #
        "cfs_prev_treat_app_point": cfs_prev_treat_app_point,
        "cfs_prev_dosage"         : cfs_prev_dosage,
        #
        "cf_treat_app"            : cfs_treat_app_point,
        "cf_dose_app"             : cfs_dose_app_point,
        "cf_treat_dosage"         : cfs_treat_dosage,
        #
        "cf_treat_radio_dosage"   : cfs_treat_radio_dosage,
        "cf_treat_chemo_dosage"   : cfs_treat_chemo_dosage,
        "cf_effect_chemo_dosage"  : cfs_effect_chemo_dosage,
    }
    return factual_outputs, counterfactual_outputs
    
# -------------------------------------------------------------------------------------------
# get treat/dose application point
# -------------------------------------------------------------------------------------------
def get_treat_dose_app_point(treat_radio_app_rvs, treat_chemo_app_rvs,
                             dose_radio_app_rvs, dose_chemo_app_rvs,
                             treat_radio_prob, treat_chemo_prob,
                             dose_radio_prob, dose_chemo_prob,
                             isCounterFactual = False):
    """Choose the treatment code and dose level from random draws and probabilities.

    A channel "fires" when its random draw is strictly below its probability.
    With ``isCounterFactual=True`` the rule is inverted: a channel fires when
    its draw is strictly above its probability.

    Returns:
        ``(treat_app_point, dose_app_point)``. The treatment code is:
        - 3 when both radio and chemo fire;
        - 1 when only radio fires;
        - 2 when only chemo fires;
        - 0 otherwise.
        The dose level is 1 when the matching dose draw fires, else 0. For
        code 3, the sum of both dose draws is compared with the sum of both
        dose probabilities.
    """
    if isCounterFactual:
        def fires(draw, prob):
            return draw > prob
    else:
        def fires(draw, prob):
            return draw < prob

    radio_fires = fires(treat_radio_app_rvs, treat_radio_prob)
    chemo_fires = fires(treat_chemo_app_rvs, treat_chemo_prob)

    if radio_fires and chemo_fires:
        high_dose = fires(dose_radio_app_rvs + dose_chemo_app_rvs,
                          dose_radio_prob + dose_chemo_prob)
        return 3, (1 if high_dose else 0)

    if radio_fires:
        return 1, (1 if fires(dose_radio_app_rvs, dose_radio_prob) else 0)

    if chemo_fires:
        return 2, (1 if fires(dose_chemo_app_rvs, dose_chemo_prob) else 0)

    return 0, 0
                
# -------------------------------------------------------------------------------------------
# get radio/chemo dosage
# -------------------------------------------------------------------------------------------
def get_radio_chemo_dosages(dict_wd, treat_app_point, dose_app_point):
    """Translate a treatment code and dose level into concrete dosages.

    Treatment codes: 0 = none, 1 = radio, 2 = chemo, 3 = both. For codes 1-3
    the total dosage is looked up as ``dict_wd[code][dose_level]``. Both
    arguments are cast to ``int`` first, so float-valued codes are accepted.
    For code 3 the total is split 2/7 to radio and 5/7 to chemo.

    Returns:
        ``(radio_dosage, chemo_dosage, total_dosage)``.

    Raises:
        NotImplementedError: if the treatment code is not 0, 1, 2 or 3.
    """
    treatment = int(treat_app_point)
    level = int(dose_app_point)

    if treatment not in (0, 1, 2, 3):
        raise NotImplementedError()

    if treatment == 0:
        # no treatment: the dosage table is not consulted
        return 0.0, 0.0, 0.0

    total = dict_wd[treatment][level]

    if treatment == 1:
        return total, 0.0, total
    if treatment == 2:
        return 0.0, total, total

    # combined treatment: radio receives 2/7 of the total, chemo the remaining 5/7
    return total * (2.0 / 7.0), total * (5.0 / 7.0), total
