"""
[Treatment Effects with RNNs] cancer_simulation
Created on 2/4/2018 8:14 AM

Medically realistic data simulation for non-small cell lung cancer based on Geng et al. 2017.
URL: https://www.nature.com/articles/s41598-017-13646-z

Notes:
- Simulation time taken to be in days

@author: limsi
"""

import logging
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
from scipy.stats import truncnorm  
from scipy.stats import beta as scipy_beta
from scipy.interpolate import CubicSpline
import torch

# Apply seaborn's default plotting theme globally (module-level side effect on import).
sns.set()


def calc_volume(diameter):
    """Return the volume of a sphere with the given diameter.

    The arithmetic is element-wise, so NumPy arrays are accepted as well as scalars.
    """
    radius = diameter / 2
    return 4 / 3 * np.pi * radius ** 3


def calc_diameter(volume):
    """Return the diameter of a sphere with the given volume (inverse of calc_volume).

    Negative volumes are clamped to zero before the cube root is taken, so the
    result is always real and non-negative. The clamp is element-wise, so NumPy
    arrays are accepted as well as scalars.
    """
    # np.maximum instead of the builtin max(): the builtin raises on arrays
    volume = np.maximum(volume, 0.)
    return ((volume / (4 / 3 * np.pi)) ** (1 / 3)) * 2

# Control points of the dose curve `cs`: a clamped cubic spline through (0, 0) and (1, 1).
# The simulators evaluate it at an application point and scale the result by the
# maximum dose (2 for radiotherapy, 5 for chemotherapy).
t_control = np.array([0, 0.4, 0.7, 1])
f_control = np.array([0, 0.2, 0.85, 1])
cs = CubicSpline(t_control, f_control, bc_type='clamped')
# Multiplies the tumour volume in the recovery-probability exponent
# (presumably cells per cm^3 - TODO confirm units)
TUMOUR_CELL_DENSITY = 5.8 * 10 ** 8  
# Volume of a sphere of diameter 13; simulated volumes are capped at this value
TUMOUR_DEATH_THRESHOLD = calc_volume(13)  

# Multiplier on the standard-normal noise added to tumour growth; 0 disables the noise
noise_scale = 0
# Per stage: (mu, sigma, lower_bound, upper_bound). get_standard_params draws the log of the
# initial diameter from a normal(mu, sigma) truncated to [log(lower_bound), log(upper_bound)].
tumour_size_distributions = {'I': (1.72, 4.70, 0.3, 5.0),
                             'II': (1.96, 1.63, 0.3, 13.0),
                             'IIIA': (1.91, 9.40, 0.3, 13.0),
                             'IIIB': (2.76, 6.87, 0.3, 13.0),
                             'IV': (3.86, 8.82, 0.3, 13.0)}  
# Observation counts per stage; normalised into the stage sampling probabilities
cancer_stage_observations = {'I': 1432,
                             "II": 128,
                             "IIIA": 1306,
                             "IIIB": 7248,
                             "IV": 12840}


def generate_params(num_patients, chemo_coeff, radio_coeff, window_size, lag):
    """
    Build the patient-specific simulation parameters and attach the confounding controls

    :param num_patients: Number of patients to simulate
    :param chemo_coeff: Bias on action policy for chemotherapy assignments
    :param radio_coeff: Bias on action policy for radiotherapy assignments
    :param window_size: Stored unchanged under 'window_size' (history window read by the simulators)
    :param lag: Stored unchanged under 'lag' (history delay read by the simulators)
    :return: dict of parameters
    """
    params = get_standard_params(num_patients)
    n_patients = len(params['patient_types'])

    # Diameter at the death threshold; the policy sigmoids are centred on half of it
    # and their slopes are expressed relative to it.
    max_diameter = calc_diameter(TUMOUR_DEATH_THRESHOLD)
    half_max_diameter = max_diameter / 2.0

    params.update({
        'chemo_sigmoid_intercepts': np.array([half_max_diameter] * n_patients),
        'radio_sigmoid_intercepts': np.array([half_max_diameter] * n_patients),
        'chemo_sigmoid_betas': np.array([chemo_coeff / max_diameter] * n_patients),
        'radio_sigmoid_betas': np.array([radio_coeff / max_diameter] * n_patients),
        'window_size': window_size,
        'lag': lag,
    })

    return params


def get_standard_params(num_patients):
    """
    Draw the patient-level simulation parameters (Nature article values + static-variable adjustments)

    All randomness comes from the global NumPy random state.

    :param num_patients: Number of patients to simulate
    :return: dict of equally long arrays ('patient_types', 'initial_stages', 'initial_volumes',
        'alpha', 'rho', 'beta', 'beta_c', 'K'), jointly shuffled so that index i describes one patient
    """
    # ---- Initial cancer stage, sampled with the observed stage frequencies ----
    total_observations = sum(cancer_stage_observations.values())
    stage_names = sorted(tumour_size_distributions.keys())
    stage_probs = [cancer_stage_observations[name] / total_observations for name in stage_names]
    initial_stages = np.random.choice(stage_names, num_patients, p=stage_probs)

    # ---- Initial tumour diameter: truncated log-normal, one batch per stage ----
    initial_diameters = []
    stage_labels = []
    for stage in stage_names:
        n_in_stage = np.count_nonzero(initial_stages == stage)

        mu, sigma, diameter_min, diameter_max = tumour_size_distributions[stage]
        # Bounds expressed in standard-normal units, as truncnorm expects
        z_low = (np.log(diameter_min) - mu) / sigma
        z_high = (np.log(diameter_max) - mu) / sigma

        logging.info(("Simulating initial volumes for stage {} "
                      " with norm params: mu={}, sigma={}, lb={}, ub={}").format(
            stage, mu, sigma, z_low, z_high))

        z_samples = truncnorm.rvs(z_low, z_high, size=n_in_stage)

        initial_diameters.extend(np.exp((z_samples * sigma) + mu))
        stage_labels.extend([stage] * n_in_stage)

    # ---- Growth / treatment-response constants ----
    K = calc_volume(30)
    ALPHA_BETA_RATIO = 10
    ALPHA_RHO_CORR = 0.87
    parameter_lower_bound = 0.0
    parameter_upper_bound = np.inf
    # Each tuple is (mean, standard deviation)
    rho_params = (7 * 10 ** -5, 7.23 * 10 ** -3)
    alpha_params = (0.0398, 0.168)
    beta_c_params = (0.028, 0.0007)

    alpha_rho_covariance = ALPHA_RHO_CORR * alpha_params[1] * rho_params[1]
    alpha_rho_cov = np.array([[alpha_params[1] ** 2, alpha_rho_covariance],
                              [alpha_rho_covariance, rho_params[1] ** 2]])
    alpha_rho_mean = np.array([alpha_params[0], rho_params[0]])

    # ---- Correlated (alpha, rho): rejection-sample until enough strictly positive pairs ----
    accepted_pairs = []
    while len(accepted_pairs) < num_patients:
        draws = np.random.multivariate_normal(alpha_rho_mean, alpha_rho_cov, size=num_patients)
        both_positive = (draws[:, 0] > parameter_lower_bound) & (draws[:, 1] > parameter_lower_bound)
        accepted_pairs.extend(draws[both_positive])

        logging.info("Got correlated params for {} patients".format(len(accepted_pairs)))

    # ---- Patient types shift the mean treatment response ----
    patient_types = np.random.choice([1, 2, 3], num_patients)
    chemo_mean_adjustments = np.where(patient_types < 3, 0.0, 0.1)
    radio_mean_adjustments = np.where(patient_types > 1, 0.0, 0.1)

    alpha_rho = np.array(accepted_pairs)[:num_patients, :]
    alpha = alpha_rho[:, 0] + alpha_params[0] * radio_mean_adjustments
    rho = alpha_rho[:, 1]
    beta = alpha / ALPHA_BETA_RATIO

    logging.info("Simulating beta_c parameters")
    beta_c_mean, beta_c_std = beta_c_params
    beta_c_adjustments = beta_c_mean * chemo_mean_adjustments
    beta_c_z = truncnorm.rvs((parameter_lower_bound - beta_c_mean) / beta_c_std,
                             (parameter_upper_bound - beta_c_mean) / beta_c_std,
                             size=num_patients)
    beta_c = beta_c_mean + beta_c_std * beta_c_z + beta_c_adjustments

    unshuffled = {'patient_types': patient_types,
                  'initial_stages': np.array(stage_labels),
                  'initial_volumes': calc_volume(np.array(initial_diameters)),
                  'alpha': alpha,
                  'rho': rho,
                  'beta': beta,
                  'beta_c': beta_c,
                  'K': np.full(num_patients, K),
                  }

    # Stages and volumes were generated grouped by stage, so apply one common
    # random permutation to every array.
    logging.info("Randomising outputs")
    order = list(range(num_patients))
    np.random.shuffle(order)

    return {key: values[order] for key, values in unshuffled.items()}


def simulate_factual(simulation_params, seq_length, assigned_actions=None, random=False):
    """
    Simulation of factual patient trajectories (for train and validation subset)

    For every patient the tumour volume is advanced one step at a time. After each volume update
    the treatment application points for that step are sampled from Beta(2p, 2 - 2p), where p is
    either the assigned probability or a sigmoid of the mean tumour diameter over a lagged window.
    A trajectory stops early when the volume exceeds TUMOUR_DEATH_THRESHOLD (death) or when the
    recovery draw succeeds.

    :param simulation_params: Parameters of the simulation (as returned by generate_params)
    :param seq_length: Maximum trajectory length
    :param assigned_actions: Fixed non-random treatment assignment policy, read as
        assigned_actions[patient, t, 0] (chemo) and assigned_actions[patient, t, 1] (radio);
        if None - standard biased random assignment is applied
    :param random: If True, at each step and with probability 0.3, both sampled application points
        are replaced by draws from Beta(0.5, 2) that ignore the policy
    :return: simulated data dict; 'sequence_lengths' holds the last simulated time index per patient,
        'death_flags' / 'recovery_flags' mark that index when the trajectory ended early
    """

    drug_half_life = 1
    # Fraction of the previous chemo dose still present one step later
    chemo_decay = np.exp(-np.log(2) / drug_half_life)

    initial_stages = simulation_params['initial_stages']
    initial_volumes = simulation_params['initial_volumes']
    alphas = simulation_params['alpha']
    rhos = simulation_params['rho']
    betas = simulation_params['beta']
    beta_cs = simulation_params['beta_c']
    Ks = simulation_params['K']
    patient_types = simulation_params['patient_types']
    window_size = simulation_params['window_size']
    lag = simulation_params['lag']
    chemo_sigmoid_intercepts = simulation_params['chemo_sigmoid_intercepts']
    radio_sigmoid_intercepts = simulation_params['radio_sigmoid_intercepts']
    chemo_sigmoid_betas = simulation_params['chemo_sigmoid_betas']
    radio_sigmoid_betas = simulation_params['radio_sigmoid_betas']

    num_patients = initial_stages.shape[0]
    cancer_volume = np.zeros((num_patients, seq_length))
    chemo_dosage = np.zeros((num_patients, seq_length))
    radio_dosage = np.zeros((num_patients, seq_length))
    chemo_application_point = np.zeros((num_patients, seq_length))
    radio_application_point = np.zeros((num_patients, seq_length))
    sequence_lengths = np.zeros(num_patients)
    death_flags = np.zeros((num_patients, seq_length))
    recovery_flags = np.zeros((num_patients, seq_length))
    chemo_probabilities = np.zeros((num_patients, seq_length))
    radio_probabilities = np.zeros((num_patients, seq_length))

    noise_terms = noise_scale * np.random.randn(num_patients, seq_length)
    recovery_rvs = np.random.rand(num_patients, seq_length)

    # These two draws are never read. They are kept only so that seeded runs consume the
    # global random stream exactly as before.
    np.random.rand(num_patients, seq_length)
    np.random.rand(num_patients, seq_length)

    for i in tqdm(range(num_patients), total=num_patients):
        noise = noise_terms[i]
        cancer_volume[i, 0] = initial_volumes[i]
        alpha = alphas[i]
        beta = betas[i]
        beta_c = beta_cs[i]
        rho = rhos[i]
        K = Ks[i]
        b_death = False
        b_recover = False
        t = 0  # keeps the bookkeeping after the loop defined when seq_length < 2
        for t in range(1, seq_length):
            previous_volume = cancer_volume[i, t - 1]
            previous_chemo_dose = chemo_dosage[i, t - 1]
            previous_radio_dose = radio_dosage[i, t - 1]

            # Growth step driven by the doses of the previous time step; the 1e-07 terms
            # keep the logarithm finite when the volume is zero.
            cancer_volume[i, t] = previous_volume *\
                (1 + rho * np.log(K / (previous_volume + 1e-07) + 1e-07) - beta_c * previous_chemo_dose -
                    (alpha * previous_radio_dose + beta * previous_radio_dose ** 2) + noise[t])

            # History window that drives the assignment policy
            if t >= lag:
                cancer_volume_used = cancer_volume[i, max(t - window_size - lag, 0):max(t - lag, 0)]
            else:
                cancer_volume_used = np.zeros((1, ))
            if cancer_volume_used.size == 0:
                # The slice is empty at t == lag (and always when window_size == 0). Its mean
                # would be NaN and make the Beta sampling below fail, so use the same zero
                # placeholder as for t < lag.
                cancer_volume_used = np.zeros((1, ))
            cancer_metric_used = np.array(
                [calc_diameter(vol) for vol in cancer_volume_used]).mean()

            if assigned_actions is not None:
                chemo_prob = assigned_actions[i, t, 0]
                radio_prob = assigned_actions[i, t, 1]
            else:
                radio_prob = (1.0 / (1.0 + np.exp(-radio_sigmoid_betas[i] * (cancer_metric_used - radio_sigmoid_intercepts[i]))))
                chemo_prob = (1.0 / (1.0 + np.exp(- chemo_sigmoid_betas[i] * (cancer_metric_used - chemo_sigmoid_intercepts[i]))))
            chemo_probabilities[i, t] = chemo_prob
            radio_probabilities[i, t] = radio_prob

            # Beta(2p, 2 - 2p) has mean p, so the application point is centred on the policy value
            radio_alpha = 2 * radio_prob
            chemo_alpha = 2 * chemo_prob
            radio_beta = 2 - radio_alpha
            chemo_beta = 2 - chemo_alpha

            radio_application_point[i, t] = scipy_beta.rvs(radio_alpha, radio_beta)
            chemo_application_point[i, t] = scipy_beta.rvs(chemo_alpha, chemo_beta)

            if random:
                if np.random.rand() > 0.7:
                    radio_application_point[i, t] = np.random.beta(1/2, 2)
                    chemo_application_point[i, t] = np.random.beta(1/2, 2)

            # Map application points to doses through the spline (max dose 2 for radio, 5 for chemo)
            radio_dosage[i, t] = cs(radio_application_point[i, t]) * 2

            current_chemo_dose = 5 * cs(chemo_application_point[i, t])
            chemo_dosage[i, t] = previous_chemo_dose * chemo_decay + current_chemo_dose

            if cancer_volume[i, t] > TUMOUR_DEATH_THRESHOLD:
                cancer_volume[i, t] = TUMOUR_DEATH_THRESHOLD
                b_death = True
                break
            if recovery_rvs[i, t] < np.exp(-cancer_volume[i, t] * TUMOUR_CELL_DENSITY):
                cancer_volume[i, t] = 0
                b_recover = True
                break
        sequence_lengths[i] = int(t)
        death_flags[i, t] = 1 if b_death else 0
        recovery_flags[i, t] = 1 if b_recover else 0

    outputs = {'cancer_volume': cancer_volume,
               'chemo_dosage': chemo_dosage,
               'radio_dosage': radio_dosage,
               'chemo_application': chemo_application_point,
               'radio_application': radio_application_point,
               'chemo_probabilities': chemo_probabilities,
               'radio_probabilities': radio_probabilities,
               'sequence_lengths': sequence_lengths,
               'death_flags': death_flags,
               'recovery_flags': recovery_flags,
               'patient_types': patient_types,
               'alpha': alphas,
               'rho': rhos,
               'beta': betas,
               'beta_c': beta_cs,
               'K': Ks
               }

    return outputs


def simulate_counterfactual_1_step(simulation_params, seq_length):
    """
    Simulation of test trajectories to assess one-step ahead counterfactuals

    For every patient and every factual time step, one row holding the factual trajectory so far
    is emitted, followed by rows in which only the latest treatment is replaced by a uniformly
    sampled (chemo, radio) application point pair.

    :param simulation_params: Parameters of the simulation (as returned by generate_params)
    :param seq_length: Maximum trajectory length (number of factual time-steps)
    :return: simulated data dict; buffers are allocated for num_patients * seq_length * num_treatments
        rows and trimmed to the number of rows actually produced
    :raises ValueError: if the Beta distribution parameters derived from the policy are invalid
    """

    num_treatments = 4  # rows emitted per factual time step: 1 factual + 3 counterfactual
    num_counterfactuals = num_treatments - 1

    drug_half_life = 1
    # Fraction of the previous chemo dose still present one step later
    chemo_decay = np.exp(-np.log(2) / drug_half_life)

    initial_stages = simulation_params['initial_stages']
    initial_volumes = simulation_params['initial_volumes']
    alphas = simulation_params['alpha']
    rhos = simulation_params['rho']
    betas = simulation_params['beta']
    beta_cs = simulation_params['beta_c']
    Ks = simulation_params['K']
    patient_types = simulation_params['patient_types']
    window_size = simulation_params['window_size']
    lag = simulation_params['lag']
    chemo_sigmoid_intercepts = simulation_params['chemo_sigmoid_intercepts']
    radio_sigmoid_intercepts = simulation_params['radio_sigmoid_intercepts']
    chemo_sigmoid_betas = simulation_params['chemo_sigmoid_betas']
    radio_sigmoid_betas = simulation_params['radio_sigmoid_betas']

    num_patients = initial_stages.shape[0]

    num_test_points = num_patients * seq_length * num_treatments
    cancer_volume = np.zeros((num_test_points, seq_length))
    chemo_application_point = np.zeros((num_test_points, seq_length))
    radio_application_point = np.zeros((num_test_points, seq_length))
    sequence_lengths = np.zeros(num_test_points)
    patient_types_all_trajectories = np.zeros(num_test_points)

    test_idx = 0
    for i in tqdm(range(num_patients), total=num_patients):

        noise = noise_scale * np.random.randn(seq_length)
        recovery_rvs = np.random.rand(seq_length)
        factual_cancer_volume = np.zeros(seq_length)
        factual_chemo_dosage = np.zeros(seq_length)
        factual_radio_dosage = np.zeros(seq_length)
        factual_chemo_application_point = np.zeros(seq_length)
        factual_radio_application_point = np.zeros(seq_length)

        # These two draws are never read. They are kept only so that seeded runs consume the
        # global random stream exactly as before.
        np.random.rand(seq_length)
        np.random.rand(seq_length)

        factual_cancer_volume[0] = initial_volumes[i]

        alpha = alphas[i]
        beta = betas[i]
        beta_c = beta_cs[i]
        rho = rhos[i]
        K = Ks[i]

        for t in range(0, seq_length - 1):
            previous_chemo_dose = 0.0 if t == 0 else factual_chemo_dosage[t - 1]

            # History window that drives the assignment policy. It must come from this patient's
            # own factual trajectory: row i of the `cancer_volume` output buffer is an unrelated
            # test point (buffer rows are indexed by test_idx, not by patient).
            if t >= lag:
                cancer_volume_used = factual_cancer_volume[max(t - window_size - lag, 0):max(t - lag + 1, 0)]
            else:
                cancer_volume_used = np.zeros((1, ))
            cancer_metric_used = np.array(
                [calc_diameter(vol) for vol in cancer_volume_used]).mean()

            radio_prob = (1.0 / (1.0 + np.exp(-radio_sigmoid_betas[i] * (cancer_metric_used - radio_sigmoid_intercepts[i]))))
            chemo_prob = (1.0 / (1.0 + np.exp(- chemo_sigmoid_betas[i] * (cancer_metric_used - chemo_sigmoid_intercepts[i]))))

            # Beta(2p, 2 - 2p) has mean p, so the application point is centred on the policy value
            radio_alpha = 2 * radio_prob
            chemo_alpha = 2 * chemo_prob
            radio_beta = 2 - radio_alpha
            chemo_beta = 2 - chemo_alpha
            try:
                factual_radio_application_point[t] = scipy_beta.rvs(radio_alpha, radio_beta)
            except ValueError as err:
                # Report the offending values instead of terminating the interpreter
                raise ValueError(
                    ("Cannot sample radio application point: Radio alpha: {}, Radio beta: {}, "
                     "Radio prob: {}, Radio sigmoid beta: {}, cancer_volume_used: {}").format(
                        radio_alpha, radio_beta, radio_prob, radio_sigmoid_betas[i],
                        cancer_volume_used)) from err
            factual_radio_dosage[t] = cs(factual_radio_application_point[t]) * 2
            factual_chemo_application_point[t] = scipy_beta.rvs(chemo_alpha, chemo_beta)
            current_chemo_dose = 5 * cs(factual_chemo_application_point[t])
            factual_chemo_dosage[t] = previous_chemo_dose * chemo_decay + current_chemo_dose

            # Factual next-step volume, clipped to [0, death threshold]
            factual_cancer_volume[t + 1] = factual_cancer_volume[t] * \
                (1 + rho * np.log(K / factual_cancer_volume[t]) - beta_c * factual_chemo_dosage[t] -
                    (alpha * factual_radio_dosage[t] + beta * factual_radio_dosage[t] ** 2) + noise[t + 1])
            factual_cancer_volume[t + 1] = np.clip(factual_cancer_volume[t + 1], 0, TUMOUR_DEATH_THRESHOLD)

            # Row 1 of this time step: the factual trajectory so far
            cancer_volume[test_idx] = factual_cancer_volume
            chemo_application_point[test_idx] = factual_chemo_application_point
            radio_application_point[test_idx] = factual_radio_application_point
            patient_types_all_trajectories[test_idx] = patient_types[i]
            sequence_lengths[test_idx] = int(t) + 1
            test_idx = test_idx + 1

            # Remaining rows: same history, but the treatment at time t is replaced by a
            # uniformly sampled (chemo, radio) application point pair
            treatment_options = np.random.uniform(0, 1, (num_counterfactuals, 2))

            for treatment_option in treatment_options:
                counterfactual_chemo_application_point = treatment_option[0]
                counterfactual_radio_application_point = treatment_option[1]
                current_chemo_dose = 5 * cs(counterfactual_chemo_application_point)
                counterfactual_radio_dosage = cs(counterfactual_radio_application_point) * 2

                counterfactual_chemo_dosage = previous_chemo_dose * chemo_decay + current_chemo_dose

                counterfactual_cancer_volume = factual_cancer_volume[t] *\
                    (1 + rho * np.log(K / factual_cancer_volume[t]) - beta_c * counterfactual_chemo_dosage -
                        (alpha * counterfactual_radio_dosage + beta * counterfactual_radio_dosage ** 2) + noise[t + 1])

                cancer_volume[test_idx, :t + 1] = factual_cancer_volume[:t + 1]
                cancer_volume[test_idx, t + 1] = counterfactual_cancer_volume
                chemo_application_point[test_idx, :t] = factual_chemo_application_point[:t]
                chemo_application_point[test_idx, t] = counterfactual_chemo_application_point
                radio_application_point[test_idx, :t] = factual_radio_application_point[:t]
                radio_application_point[test_idx, t] = counterfactual_radio_application_point
                patient_types_all_trajectories[test_idx] = patient_types[i]
                sequence_lengths[test_idx] = int(t) + 1
                test_idx = test_idx + 1

            # Stop the factual trajectory on death or recovery
            if (factual_cancer_volume[t + 1] >= TUMOUR_DEATH_THRESHOLD) or \
                    recovery_rvs[t] <= np.exp(-factual_cancer_volume[t + 1] * TUMOUR_CELL_DENSITY):
                break

    outputs = {'cancer_volume': cancer_volume[:test_idx],
               'chemo_application': chemo_application_point[:test_idx],
               'radio_application': radio_application_point[:test_idx],
               'sequence_lengths': sequence_lengths[:test_idx],
               'patient_types': patient_types_all_trajectories[:test_idx]
               }

    print("Call to simulate counterfactuals data")

    return outputs


def simulate_counterfactuals_treatment_seq(simulation_params, seq_length, projection_horizon, cf_seq_mode='sliding_treatment'):
    """
    Simulation of test trajectories to assess a subset of multiple-step ahead counterfactuals

    For every patient and every factual time step, each treatment sequence in the current option
    set is rolled out for projection_horizon steps on top of the factual history.

    :param simulation_params: Parameters of the simulation (as returned by generate_params)
    :param seq_length: Maximum trajectory length (number of factual time-steps)
    :param projection_horizon: Number of counterfactual steps rolled out after each factual step;
        2 * projection_horizon treatment sequences are evaluated per factual step
    :param cf_seq_mode: Counterfactual sequence setting: sliding_treatment / random_trajectories
    :return: simulated data dict; buffers are allocated for num_patients * seq_length * 2 * projection_horizon
        rows and trimmed to the number of rows actually produced (roll-outs containing NaN are dropped)
    :raises NotImplementedError: if cf_seq_mode is not one of the two supported settings
    """

    options_shape = (projection_horizon * 2, projection_horizon, 2)
    if cf_seq_mode == 'sliding_treatment':
        treatment_options = np.random.uniform(0, 1, options_shape)

    elif cf_seq_mode == 'random_trajectories':
        # This binary draw is discarded. It is kept only so that seeded runs consume the
        # global random stream exactly as before.
        np.random.randint(0, 2, options_shape)
        treatment_options = np.random.uniform(0, 1, options_shape)

    else:
        raise NotImplementedError()

    drug_half_life = 1
    # Fraction of the previous chemo dose still present one step later
    chemo_decay = np.exp(-np.log(2) / drug_half_life)

    initial_stages = simulation_params['initial_stages']
    initial_volumes = simulation_params['initial_volumes']
    alphas = simulation_params['alpha']
    rhos = simulation_params['rho']
    betas = simulation_params['beta']
    beta_cs = simulation_params['beta_c']
    Ks = simulation_params['K']
    patient_types = simulation_params['patient_types']
    window_size = simulation_params['window_size']
    lag = simulation_params['lag']
    chemo_sigmoid_intercepts = simulation_params['chemo_sigmoid_intercepts']
    radio_sigmoid_intercepts = simulation_params['radio_sigmoid_intercepts']
    chemo_sigmoid_betas = simulation_params['chemo_sigmoid_betas']
    radio_sigmoid_betas = simulation_params['radio_sigmoid_betas']

    num_patients = initial_stages.shape[0]

    num_test_points = len(treatment_options) * num_patients * seq_length
    cancer_volume = np.zeros((num_test_points, seq_length + projection_horizon))
    chemo_application_point = np.zeros((num_test_points, seq_length + projection_horizon))
    radio_application_point = np.zeros((num_test_points, seq_length + projection_horizon))
    sequence_lengths = np.zeros(num_test_points)
    patient_types_all_trajectories = np.zeros(num_test_points)
    patient_ids_all_trajectories = np.zeros(num_test_points)
    patient_current_t = np.zeros(num_test_points)

    test_idx = 0
    for i in tqdm(range(num_patients), total=num_patients):

        noise = noise_scale * np.random.randn(seq_length + projection_horizon)
        recovery_rvs = np.random.rand(seq_length)
        factual_cancer_volume = np.zeros(seq_length)
        factual_chemo_dosage = np.zeros(seq_length)
        factual_radio_dosage = np.zeros(seq_length)
        factual_chemo_application_point = np.zeros(seq_length)
        factual_radio_application_point = np.zeros(seq_length)

        # These two draws are never read. They are kept only so that seeded runs consume the
        # global random stream exactly as before.
        np.random.rand(seq_length)
        np.random.rand(seq_length)

        factual_cancer_volume[0] = initial_volumes[i]

        alpha = alphas[i]
        beta = betas[i]
        beta_c = beta_cs[i]
        rho = rhos[i]
        K = Ks[i]

        for t in range(0, seq_length - 1):
            previous_chemo_dose = 0.0 if t == 0 else factual_chemo_dosage[t - 1]

            # History window that drives the assignment policy. It must come from this patient's
            # own factual trajectory: row i of the `cancer_volume` output buffer is an unrelated
            # test point (buffer rows are indexed by test_idx, not by patient).
            if t >= lag:
                cancer_volume_used = factual_cancer_volume[max(t - window_size - lag, 0):max(t - lag + 1, 0)]
            else:
                cancer_volume_used = np.zeros((1,))
            cancer_metric_used = np.array(
                [calc_diameter(vol) for vol in cancer_volume_used]).mean()

            radio_sigmoid = (1.0 / (1.0 + np.exp(- radio_sigmoid_betas[i] * (cancer_metric_used - radio_sigmoid_intercepts[i]))))
            chemo_sigmoid = (1.0 / (1.0 + np.exp(- chemo_sigmoid_betas[i] * (cancer_metric_used - chemo_sigmoid_intercepts[i]))))

            # Beta(2p, 2 - 2p) has mean p, so the application point is centred on the policy value
            radio_alpha = 2 * radio_sigmoid
            chemo_alpha = 2 * chemo_sigmoid
            radio_beta = 2 - radio_alpha
            chemo_beta = 2 - chemo_alpha

            factual_chemo_application_point[t] = scipy_beta.rvs(chemo_alpha, chemo_beta)
            factual_radio_application_point[t] = scipy_beta.rvs(radio_alpha, radio_beta)

            factual_radio_dosage[t] = cs(factual_radio_application_point[t]) * 2
            current_chemo_dose = 5 * cs(factual_chemo_application_point[t])
            factual_chemo_dosage[t] = previous_chemo_dose * chemo_decay + current_chemo_dose

            # Factual next-step volume, clipped to [0, death threshold]
            factual_cancer_volume[t + 1] = factual_cancer_volume[t] * \
                (1 + rho * np.log(K / factual_cancer_volume[t]) - beta_c * factual_chemo_dosage[t] -
                    (alpha * factual_radio_dosage[t] + beta * factual_radio_dosage[t] ** 2) + noise[t + 1])
            factual_cancer_volume[t + 1] = np.clip(factual_cancer_volume[t + 1], 0, TUMOUR_DEATH_THRESHOLD)

            if cf_seq_mode == 'random_trajectories':
                # NOTE(review): options are re-drawn here as binary {0, 1} application points,
                # whereas 'sliding_treatment' uses continuous uniform values - confirm this is intended.
                treatment_options = np.random.randint(0, 2, options_shape)

            for treatment_option in treatment_options:
                rollout_length = t + 1 + projection_horizon

                # Volumes carry one more entry than treatments: the outcome of the last treatment
                counterfactual_cancer_volume = np.zeros(shape=(rollout_length + 1))
                counterfactual_chemo_application_point = np.zeros(shape=(rollout_length))
                counterfactual_radio_application_point = np.zeros(shape=(rollout_length))
                counterfactual_chemo_dosage = np.zeros(shape=(rollout_length))
                counterfactual_radio_dosage = np.zeros(shape=(rollout_length))

                # Start from the factual history up to and including time t
                counterfactual_cancer_volume[:t + 2] = factual_cancer_volume[:t + 2]
                counterfactual_chemo_application_point[:t + 1] = factual_chemo_application_point[:t + 1]
                counterfactual_radio_application_point[:t + 1] = factual_radio_application_point[:t + 1]
                counterfactual_chemo_dosage[:t + 1] = factual_chemo_dosage[:t + 1]
                counterfactual_radio_dosage[:t + 1] = factual_radio_dosage[:t + 1]

                for projection_time in range(0, projection_horizon):

                    current_t = t + 1 + projection_time
                    previous_chemo_dose = counterfactual_chemo_dosage[current_t - 1]

                    counterfactual_chemo_application_point[current_t] = treatment_option[projection_time][0]
                    counterfactual_radio_application_point[current_t] = treatment_option[projection_time][1]

                    current_chemo_dose = 5 * cs(treatment_option[projection_time][0])
                    counterfactual_radio_dosage[current_t] = 2 * cs(treatment_option[projection_time][1])

                    counterfactual_chemo_dosage[current_t] = previous_chemo_dose * chemo_decay + current_chemo_dose

                    # The 1e-07 terms keep the logarithm finite when the volume is zero
                    counterfactual_cancer_volume[current_t + 1] = counterfactual_cancer_volume[current_t] *\
                        (1 + rho * np.log(K / (counterfactual_cancer_volume[current_t] + 1e-07) + 1e-07) -
                         beta_c * counterfactual_chemo_dosage[current_t] -
                         (alpha * counterfactual_radio_dosage[current_t] + beta * counterfactual_radio_dosage[current_t] ** 2) +
                         noise[current_t + 1])

                # Drop roll-outs that became numerically invalid
                if (np.isnan(counterfactual_cancer_volume).any()):
                    continue

                cancer_volume[test_idx][:rollout_length + 1] = counterfactual_cancer_volume
                chemo_application_point[test_idx][:rollout_length] = counterfactual_chemo_application_point
                radio_application_point[test_idx][:rollout_length] = counterfactual_radio_application_point
                patient_types_all_trajectories[test_idx] = patient_types[i]
                patient_ids_all_trajectories[test_idx] = i
                patient_current_t[test_idx] = t

                sequence_lengths[test_idx] = int(t) + projection_horizon + 1
                test_idx = test_idx + 1

            # Stop the factual trajectory on death or recovery
            if (factual_cancer_volume[t + 1] >= TUMOUR_DEATH_THRESHOLD) or \
                    recovery_rvs[t] <= np.exp(-factual_cancer_volume[t + 1] * TUMOUR_CELL_DENSITY):
                break

    outputs = {'cancer_volume': cancer_volume[:test_idx],
               'chemo_application': chemo_application_point[:test_idx],
               'radio_application': radio_application_point[:test_idx],
               'sequence_lengths': sequence_lengths[:test_idx],
               'patient_types': patient_types_all_trajectories[:test_idx],
               'patient_ids_all_trajectories': patient_ids_all_trajectories[:test_idx],
               'patient_current_t': patient_current_t[:test_idx],
               'alpha': alphas,
               'rho': rhos,
               'beta': betas,
               'beta_c': beta_cs,
               'K': Ks
            }

    return outputs


def get_scaling_params(sim):
    """
    Compute the normalisation statistics of a simulated dataset

    For 'cancer_volume', 'chemo_dosage' and 'radio_dosage' only the first
    sequence_lengths[i] entries of each row enter the statistics; 'patient_types'
    is summarised over all of its entries.

    :param sim: simulated data dict holding the arrays above plus 'sequence_lengths'
    :return: (means, stds) as two pandas Series indexed by variable name
    """
    lengths = sim['sequence_lengths']
    means, stds = {}, {}

    for name in ('cancer_volume', 'chemo_dosage', 'radio_dosage'):
        series = sim[name]
        # Keep only the active prefix of each trajectory
        active_rows = [series[row, :int(length)] for row, length in enumerate(lengths)]
        active_values = [value for active in active_rows for value in active]

        means[name] = np.mean(active_values)
        stds[name] = np.std(active_values)

    patient_types = sim['patient_types']
    means['patient_types'] = np.mean(patient_types)
    stds['patient_types'] = np.std(patient_types)

    return pd.Series(means), pd.Series(stds)


def _to_numpy(value):
    """Return ``value`` as a numpy array when it is a torch tensor, otherwise return it unchanged."""
    if isinstance(value, torch.Tensor):
        # detach() first: Tensor.numpy() raises on tensors that require grad.
        return value.detach().cpu().numpy()
    return value


def simulate_output_after_actions(Ht, actions, scaling_params):
    """
    Simulate counterfactual outcomes given historical data and one intervention sequence per patient.

    The observed history is replayed to rebuild the chemo dosage state, then the tumour
    volume is rolled forward for ``projection_horizon`` steps under the supplied actions.

    Args:
        Ht: Dictionary containing historical data; values may be numpy arrays or torch tensors.
            Keys read:
              - 'cancer_volume': 2-D, one row per patient; its second dimension defines the
                history length ``current_t`` and column 0 is the initial volume.
              - 'unscaled_outputs': per patient, the first ``current_t`` entries are used as
                the volumes at steps 1..current_t.
              - 'chemo_application', 'radio_application': per patient, the first
                ``current_t`` entries are the historical treatment applications.
              - 'alpha', 'beta', 'beta_c', 'rho', 'K': per-patient model parameters.
        actions: Intervention sequences of shape (num_patients, projection_horizon, 2), where
            ``actions[i][t][0]`` is the chemo application and ``actions[i][t][1]`` the radio
            application for patient ``i`` at projection step ``t``.
        scaling_params: ``(mean, std)`` pair, each indexable by 'cancer_volume'.

    Returns:
        Array of shape (num_patients, 1) with the cancer volume after the final projected
        step, standardised with the 'cancer_volume' mean and std. If a patient's trajectory
        contains NaN, the raw volume is left at 0 before standardisation.
    """
    projection_horizon = actions.shape[1]
    actions = _to_numpy(actions)
    factual_cancer_volume = _to_numpy(Ht['cancer_volume'])
    factual_unscaled_outputs = _to_numpy(Ht['unscaled_outputs'])
    factual_chemo_application = _to_numpy(Ht['chemo_application'])
    factual_radio_application = _to_numpy(Ht['radio_application'])
    alphas = _to_numpy(Ht['alpha'])
    betas = _to_numpy(Ht['beta'])
    beta_cs = _to_numpy(Ht['beta_c'])
    rhos = _to_numpy(Ht['rho'])
    Ks = _to_numpy(Ht['K'])

    num_patients = len(factual_cancer_volume)
    current_t = factual_cancer_volume.shape[1]
    total_steps = current_t + projection_horizon

    drug_half_life = 1
    # Per-step decay of the chemo dosage; loop-invariant, so computed once.
    chemo_decay = np.exp(-np.log(2) / drug_half_life)
    # Drawn even when noise_scale is 0 so the global RNG stream is consumed as before.
    noise = noise_scale * np.random.randn(num_patients, total_steps + 1)

    # Only the last volume of each trajectory is returned, so only that is stored.
    final_volumes = np.zeros(num_patients)

    for i in range(num_patients):
        alpha = alphas[i]
        beta = betas[i]
        beta_c = beta_cs[i]
        rho = rhos[i]
        K = Ks[i]
        action = actions[i]

        cf_cancer_volume = np.zeros(total_steps + 1)
        cf_chemo_application = np.zeros(total_steps)
        cf_radio_application = np.zeros(total_steps)
        cf_chemo_dosage = np.zeros(total_steps)
        cf_radio_dosage = np.zeros(total_steps)

        # Seed the trajectory with the observed history.
        cf_cancer_volume[0] = factual_cancer_volume[i, 0]
        cf_cancer_volume[1:current_t + 1] = factual_unscaled_outputs[i, :current_t].reshape(-1)
        cf_chemo_application[:current_t] = factual_chemo_application[i, :current_t]
        cf_radio_application[:current_t] = factual_radio_application[i, :current_t]

        # Replay the history so the chemo dosage carries the right residual into the projection.
        for t in range(current_t):
            if t == 0:
                cf_chemo_dosage[t] = 5 * cs(cf_chemo_application[t])
            else:
                cf_chemo_dosage[t] = cf_chemo_dosage[t - 1] * chemo_decay + \
                    5 * cs(cf_chemo_application[t])
            cf_radio_dosage[t] = 2 * cs(cf_radio_application[t])

        # Roll the tumour volume forward under the counterfactual actions.
        for t in range(projection_horizon):
            current_idx = current_t + t
            cf_chemo_application[current_idx] = action[t][0]
            cf_radio_application[current_idx] = action[t][1]
            if current_idx == 0:
                cf_chemo_dosage[current_idx] = 5 * cs(cf_chemo_application[current_idx])
            else:
                cf_chemo_dosage[current_idx] = cf_chemo_dosage[current_idx - 1] * chemo_decay + \
                    5 * cs(cf_chemo_application[current_idx])
            cf_radio_dosage[current_idx] = 2 * cs(cf_radio_application[current_idx])

            volume = cf_cancer_volume[current_idx]
            chemo_dosage = cf_chemo_dosage[current_idx]
            radio_dosage = cf_radio_dosage[current_idx]
            # The 1e-07 terms keep the division and the log finite when the volume reaches 0.
            next_volume = volume * \
                (1 + rho * np.log(K / (volume + 1e-07) + 1e-07) -
                 beta_c * chemo_dosage -
                 (alpha * radio_dosage + beta * radio_dosage ** 2) +
                 noise[i, current_idx + 1])

            cf_cancer_volume[current_idx + 1] = np.clip(next_volume, 0, TUMOUR_DEATH_THRESHOLD)

        # Trajectories containing NaN are dropped: their raw output stays at 0.
        if not np.isnan(cf_cancer_volume).any():
            final_volumes[i] = cf_cancer_volume[-1]

    outputs = final_volumes.reshape(-1, 1)
    mean, std = scaling_params
    outputs = (outputs - mean['cancer_volume']) / std['cancer_volume']

    return outputs


def plot_treatments(data: dict, patient: int):
    """
    Plot one patient's simulated trajectory.

    'N(t)' (cancer volume) is drawn on the primary y-axis; 'C(t)' (chemo dosage) and
    'd(t)' (radio dosage) share the secondary y-axis.

    :param data: dict holding 'cancer_volume', 'chemo_dosage' and 'radio_dosage', each indexable by patient
    :param patient: index of the patient to plot
    """
    column_order = ['N(t)', 'C(t)', 'd(t)']
    trajectories = {
        'N(t)': data['cancer_volume'][patient],
        'C(t)': data['chemo_dosage'][patient],
        'd(t)': data['radio_dosage'][patient],
    }
    frame = pd.DataFrame(trajectories).loc[:, column_order]
    frame.plot(secondary_y=['C(t)', 'd(t)'])
    plt.xlabel("$t$")
    plt.show()


def plot_sigmoid_function(data: dict):
    """
    Simple plots to visualise probabilities of treatment assignments

    For each integer coefficient 0..10, a sigmoid of tumour volume is evaluated at
    10%..90% of the death-threshold volume and stored in ``data`` under that
    coefficient (``data`` is modified in place). All entries of ``data`` are then plotted.

    :param data: dict that receives one pandas Series per coefficient
    :return: None
    """
    def sigmoid_fxn(volume, beta, intercept):
        return 1.0 / (1.0 + np.exp(-beta * (volume - intercept)))

    # Quantities that do not depend on the coefficient are computed once.
    tumour_death_threshold = calc_volume(13)
    midpoint_volume = tumour_death_threshold / 2
    fractions = np.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9])
    volumes = fractions * tumour_death_threshold

    for coeff in range(11):
        slope = coeff / tumour_death_threshold
        probabilities = sigmoid_fxn(volumes, slope, midpoint_volume)
        data[coeff] = pd.Series(probabilities, index=fractions)

    pd.DataFrame(data).plot()
    plt.show()


if __name__ == "__main__":
    # Demo run: simulate train / validation / test datasets and plot one training patient.
    logging.basicConfig(format='%(levelname)s:%(message)s', level=logging.INFO)

    # Fixed seed so the simulated datasets are reproducible.
    np.random.seed(100)

    seq_length = 60
    window_size = 15
    lag = 0
    num_patients = 10000
    chemo_coeff = radio_coeff = 10.0
    num_eval_patients = int(num_patients / 10)

    def _make_params(n):
        """Build simulation parameters for ``n`` patients with the settings above."""
        # generate_params declares window_size as a required argument; omitting it
        # raised TypeError before any data was simulated.
        new_params = generate_params(n, chemo_coeff=chemo_coeff, radio_coeff=radio_coeff,
                                     window_size=window_size, lag=lag)
        # Kept from the original script; presumably redundant if generate_params
        # already stores window_size - TODO confirm.
        new_params['window_size'] = window_size
        return new_params

    # Parameter sets are generated in the same order as before so the RNG stream is unchanged.
    params = _make_params(num_patients)
    training_data = simulate_factual(params, seq_length)

    params = _make_params(num_eval_patients)
    validation_data = simulate_factual(params, seq_length)

    # Factual and 1-step counterfactual test sets share the same patient parameters.
    params = _make_params(num_eval_patients)
    test_data_factuals = simulate_factual(params, seq_length)
    test_data_counterfactuals = simulate_counterfactual_1_step(params, seq_length)

    params = _make_params(num_eval_patients)
    test_data_seq = simulate_counterfactuals_treatment_seq(params, seq_length, 5)
    plot_treatments(training_data, 572)



