import numpy as np
import pandas as pd
import sys

from util_sim import TUMOUR_DEATH_THRESHOLD, TUMOUR_CELL_DENSITY
from util_sim import tumour_size_distributions, cancer_stage_observations
from util_sim import calc_volume, calc_diameter
from scipy.stats import truncnorm  # we need to sample from truncated normal distributions

from simulation_cancer_volume import simulation_cancer_volume

class SyntheticCancerDatasetCollection:
    """
    Collection of simulated tumour-growth datasets for training and evaluation.

    On construction the global NumPy RNG is seeded and three patient cohorts
    are simulated, in this order:

    * ``train_f``  -- factual training data,
    * ``val_f``    -- factual validation data,
    * ``test_cf`` / ``test_cf_multi`` -- counterfactual test data (one-step and
      multi-step ahead), unpacked from a single counterfactual simulation run.

    Attributes:
        seed: value passed to ``np.random.seed``.
        train_params, val_params, test_params: per-cohort parameter dicts built
            by :meth:`generate_params`.
        train_f, val_f: results of ``simulation_cancer_volume`` for the factual
            train / validation cohorts.
        test_cf, test_cf_multi: the two results of the counterfactual
            ``simulation_cancer_volume`` call.
        projection_horizon: horizon used for the counterfactual test data.
        autoregressive: always ``True``.
    """

    def __init__(self,
                 set_coeffs: dict,
                 num_patients: dict,
                 dict_wd: dict,
                 seed=100,
                 window_size=15,
                 max_seq_length=60,
                 projection_horizon=5,
                 cf_seq_mode='sliding_treatment',
                 **kwargs):
        """
        Simulate the train / validation / counterfactual-test cohorts.

        :param set_coeffs: sigmoid coefficients; must provide the keys
            ``'treat_chemo'``, ``'treat_radio'``, ``'dose_chemo'`` and
            ``'dose_radio'`` (see :meth:`generate_params`).
        :param num_patients: cohort sizes keyed by ``'train'``, ``'valid'``
            and ``'test'``.
        :param dict_wd: forwarded unchanged into every parameter dict under
            the key ``'dict_wd'``; its meaning is defined by
            ``simulation_cancer_volume``.
        :param seed: seed for the *global* NumPy RNG (``np.random.seed``).
            This is a process-wide side effect.
        :param window_size: stored in each parameter dict as ``'window_size'``.
        :param max_seq_length: stored in each parameter dict as
            ``'max_seq_length'``.
        :param projection_horizon: forwarded to the counterfactual test
            simulation and stored on the instance.
        :param cf_seq_mode: forwarded to the counterfactual test simulation.
        :param kwargs: accepted for call-site compatibility and ignored.
        """
        self.seed = seed
        np.random.seed(seed)

        def make_params(n_patients):
            # All cohorts share coefficients and settings; only the size differs.
            return self.generate_params(n_patients,
                                        set_coeffs,
                                        dict_wd,
                                        window_size=window_size,
                                        max_seq_length=max_seq_length)

        # NOTE: generate_params draws from the global RNG (and
        # simulation_cancer_volume presumably does too), so the statement order
        # below fixes the generated data for a given seed. Do not reorder.

        # Training data (factual)
        self.train_params = make_params(num_patients['train'])
        self.train_f = simulation_cancer_volume(self.train_params)

        # Validation data (factual)
        self.val_params = make_params(num_patients['valid'])
        self.val_f = simulation_cancer_volume(self.val_params)

        # Test data: counterfactual, one-step and multi-step ahead.
        # The test cohort continues the same RNG stream (no re-seeding here).
        self.test_params = make_params(num_patients['test'])
        self.test_cf, self.test_cf_multi = simulation_cancer_volume(
            self.test_params,
            isCounterFactual=True,
            projection_horizon=projection_horizon,
            cf_seq_mode=cf_seq_mode,
        )

        # Miscellaneous settings
        self.projection_horizon = projection_horizon
        self.autoregressive = True

    # ------------------------------------------------------------
    # generate params
    # ------------------------------------------------------------
    def generate_params(self,
                        num_patients,
                        set_coeffs,
                        dict_wd,
                        window_size,
                        max_seq_length,
                        ):
        """
        Build the full simulation-parameter dict for one cohort.

        Starts from :meth:`get_standard_params` (random per-patient dynamics)
        and adds per-patient sigmoid intercepts and betas for treatment
        assignment and dosage, plus the pass-through settings.

        Every intercept is ``D_MAX / 2`` and every beta is
        ``set_coeffs[<name>] / D_MAX``, where ``D_MAX`` is
        ``calc_diameter(TUMOUR_DEATH_THRESHOLD)``.

        :param num_patients: number of patients in the cohort.
        :param set_coeffs: dict with keys ``'treat_chemo'``, ``'treat_radio'``,
            ``'dose_chemo'`` and ``'dose_radio'``.
        :param dict_wd: stored as ``'dict_wd'``.
        :param window_size: stored as ``'window_size'``.
        :param max_seq_length: stored as ``'max_seq_length'``.
        :return: dict of per-patient arrays (length ``num_patients``) plus the
            three pass-through settings.
        """
        basic_params = self.get_standard_params(num_patients)

        # Diameter equivalent of TUMOUR_DEATH_THRESHOLD; scales all sigmoids.
        d_max = calc_diameter(TUMOUR_DEATH_THRESHOLD)
        intercept = d_max / 2.0

        # np.full creates a fresh array per key, so no two keys share storage.
        # Treatment-assignment sigmoid intercepts / betas
        basic_params['treat_chemo_sigmoid_intercepts'] = np.full(num_patients, intercept)
        basic_params['treat_radio_sigmoid_intercepts'] = np.full(num_patients, intercept)
        basic_params['treat_chemo_sigmoid_betas'] = np.full(num_patients, set_coeffs['treat_chemo'] / d_max)
        basic_params['treat_radio_sigmoid_betas'] = np.full(num_patients, set_coeffs['treat_radio'] / d_max)

        # Dosage sigmoid intercepts / betas
        basic_params['dose_chemo_sigmoid_intercepts'] = np.full(num_patients, intercept)
        basic_params['dose_radio_sigmoid_intercepts'] = np.full(num_patients, intercept)
        basic_params['dose_chemo_sigmoid_betas'] = np.full(num_patients, set_coeffs['dose_chemo'] / d_max)
        basic_params['dose_radio_sigmoid_betas'] = np.full(num_patients, set_coeffs['dose_radio'] / d_max)

        # Pass-through settings consumed by simulation_cancer_volume
        basic_params['dict_wd'] = dict_wd
        basic_params['window_size'] = window_size
        basic_params['max_seq_length'] = max_seq_length

        return basic_params

    # ------------------------------------------------------------
    # get standard params
    # ------------------------------------------------------------
    def get_standard_params(self, num_patients):
        """
        Sample initial volumes and per-patient static dynamics parameters.

        Simulation parameters follow the Nature article the simulator is based
        on, with adjustments for static patient types.

        Random draws happen in this fixed order, which keeps results
        reproducible for a given seed:

        1. initial stages;
        2. initial diameters (per stage);
        3. correlated (alpha, rho) pairs;
        4. patient types;
        5. beta_c;
        6. the output shuffle.

        :param num_patients: number of patients to simulate (0 is allowed).
        :return: dict with keys ``'patient_types'``, ``'initial_stages'``,
            ``'initial_volumes'``, ``'alpha_r'``, ``'rho'``, ``'beta_r'``,
            ``'beta_c'`` and ``'K'``. Each value has length ``num_patients``,
            and all are permuted with the same random index, so entry ``i``
            describes the same patient in every array.
        """
        # INITIAL VOLUMES SAMPLING (results are grouped by stage until the shuffle)
        initial_volumes, patient_sim_stages = self._sample_initial_volumes(num_patients)

        # STATIC VARIABLES SAMPLING
        # Fixed params
        K = calc_volume(30)  # carrying capacity given in cm, so convert to volume
        ALPHA_BETA_RATIO = 10
        ALPHA_RHO_CORR = 0.87

        # Distributional parameters (mean, std) for the dynamics
        parameter_lower_bound = 0.0
        parameter_upper_bound = np.inf
        rho_params = (7 * 10 ** -5, 7.23 * 10 ** -3)
        alpha_params = (0.0398, 0.168)
        beta_c_params = (0.028, 0.0007)

        # Columns: (alpha, rho); both strictly above the lower bound.
        simulated_params = self._sample_correlated_alpha_rho(
            num_patients, alpha_params, rho_params, ALPHA_RHO_CORR, parameter_lower_bound)

        # Adjustments for static variables:
        # - type 3 gets +10% of the mean beta_c (chemo sensitivity);
        # - type 1 gets +10% of the mean alpha (radio sensitivity);
        # - type 2 gets no adjustment.
        patient_types = np.random.choice([1, 2, 3], num_patients)
        chemo_mean_adjustments = np.where(patient_types >= 3, 0.1, 0.0)
        radio_mean_adjustments = np.where(patient_types <= 1, 0.1, 0.0)

        alpha = simulated_params[:, 0] + alpha_params[0] * radio_mean_adjustments
        rho = simulated_params[:, 1]
        beta = alpha / ALPHA_BETA_RATIO

        # beta_c is independent of (alpha, rho): a normal truncated to be
        # non-negative (the truncnorm bounds are given in standardized units).
        beta_c_adjustments = beta_c_params[0] * chemo_mean_adjustments
        beta_c = beta_c_params[0] + beta_c_params[1] * truncnorm.rvs(
            (parameter_lower_bound - beta_c_params[0]) / beta_c_params[1],
            (parameter_upper_bound - beta_c_params[0]) / beta_c_params[1],
            size=num_patients) + beta_c_adjustments

        output_holder = {'patient_types': patient_types,
                         'initial_stages': np.array(patient_sim_stages),
                         'initial_volumes': initial_volumes,  # assumed spherical with diam
                         'alpha_r': alpha,
                         'rho': rho,
                         'beta_r': beta,
                         'beta_c': beta_c,
                         'K': np.full(num_patients, K),
                         }

        # Shuffle all arrays with one shared permutation. This breaks the
        # stage-grouped ordering of the volumes/stages while keeping rows aligned.
        idx = list(range(num_patients))
        np.random.shuffle(idx)

        return {key: values[idx] for key, values in output_holder.items()}

    @staticmethod
    def _sample_initial_volumes(num_patients):
        """
        Draw each patient's initial cancer stage and tumour volume.

        Stages are sampled with probabilities proportional to
        ``cancer_stage_observations``. Per stage, the diameter is log-normal
        with parameters ``(mu, sigma, lower, upper)`` from
        ``tumour_size_distributions``, truncated to ``[lower, upper]``.

        :param num_patients: number of patients.
        :return: ``(initial_volumes, stage_labels)``. Both are ordered by stage
            (sorted stage keys), not yet shuffled. ``initial_volumes`` is
            ``calc_volume`` applied to the sampled diameters. ``stage_labels``
            is a list of the matching stage keys.
        """
        total_obs = sum(cancer_stage_observations.values())
        cancer_stage_proportions = {k: cancer_stage_observations[k] / total_obs
                                    for k in cancer_stage_observations}

        # Fixed (sorted) stage order, used both for the choice probabilities
        # and for the per-stage sampling loop below.
        possible_stages = sorted(tumour_size_distributions.keys())

        initial_stages = np.random.choice(possible_stages, num_patients,
                                          p=[cancer_stage_proportions[k] for k in possible_stages])

        initial_diams = []
        stage_labels = []
        for stg in possible_stages:
            count = np.count_nonzero(initial_stages == stg)

            mu, sigma, lower_bound, upper_bound = tumour_size_distributions[stg]

            # Express the log-normal truncation bounds in standard-normal
            # units of log(diameter).
            std_lower = (np.log(lower_bound) - mu) / sigma
            std_upper = (np.log(upper_bound) - mu) / sigma

            # truncated normal for realistic clinical outcome
            norm_rvs = truncnorm.rvs(std_lower, std_upper, size=count)

            initial_diams.extend(np.exp(norm_rvs * sigma + mu))
            stage_labels.extend([stg] * count)

        initial_volumes = calc_volume(np.array(initial_diams))
        return initial_volumes, stage_labels

    @staticmethod
    def _sample_correlated_alpha_rho(num_patients, alpha_params, rho_params,
                                     correlation, lower_bound):
        """
        Rejection-sample correlated ``(alpha, rho)`` pairs.

        Pairs come from a bivariate normal with (mean, std) ``alpha_params``
        and ``rho_params`` and the given correlation. Only pairs with both
        components strictly greater than ``lower_bound`` are kept.

        :param num_patients: number of pairs required.
        :param alpha_params: (mean, std) of alpha.
        :param rho_params: (mean, std) of rho.
        :param correlation: correlation coefficient between alpha and rho.
        :param lower_bound: exclusive lower bound applied to both components.
        :return: float array of shape ``(num_patients, 2)`` with columns
            (alpha, rho).
        """
        cross_cov = correlation * alpha_params[1] * rho_params[1]
        cov = np.array([[alpha_params[1] ** 2, cross_cov],
                        [cross_cov, rho_params[1] ** 2]])
        mean = np.array([alpha_params[0], rho_params[0]])

        accepted = []
        while len(accepted) < num_patients:  # keep simulating until enough pairs pass
            candidates = np.random.multivariate_normal(mean, cov, size=num_patients)
            accepted.extend(row for row in candidates
                            if row[0] > lower_bound and row[1] > lower_bound)

        # reshape keeps a (0, 2) shape when num_patients == 0 instead of
        # failing on a 1-D empty array; the slice drops surplus accepted pairs.
        return np.asarray(accepted, dtype=float).reshape(-1, 2)[:num_patients]
