# Supplementary Experiment

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
from torch.nn.utils import clip_grad_norm_
import random
import gym
import math
from torch.utils.tensorboard import SummaryWriter
from collections import deque, namedtuple
import time

import warnings
warnings.filterwarnings("ignore")

import os
os.environ['KMP_DUPLICATE_LIB_OK']='TRUE'

from torch.autograd import Variable
import pandas as pd
from lifelines import CoxPHFitter
from lifelines import AalenAdditiveFitter
from lifelines import KaplanMeierFitter
import lifelines
from scipy.optimize import minimize
from sklearn import linear_model
from sklearn.linear_model import LinearRegression
import matplotlib.pyplot as plt

# Use the first CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print("Selected device:", device)

from lifelines import CoxPHFitter
import numpy as np
import pandas as pd

###############################################################################################################
# Survival Regression / Analysis
###############################################################################################################

def Survival_T_Aalen(time, Y, X, A, Delta):
    A = np.reshape(A, (A.shape[0], 1))
    A = 2 * A - 1
    X_adjusted = np.multiply(X, A.reshape(-1, 1)) 
    data = pd.DataFrame(X_adjusted, columns=['X1*A', 'X2*A'])
    pre_data = pd.DataFrame(X_adjusted, columns=['X1*A', 'X2*A'])

    data['Y'] = Y.flatten()
    data['Delta'] = Delta.flatten()

    pre_data['Y'] = Y.flatten()
    pre_data['Delta'] = Delta.flatten()

    cox = AalenAdditiveFitter()
    cox.fit(data, duration_col='Y', event_col='Delta')
    st_estimate = cox.predict_survival_function(pre_data, times=time)
    st_estimate_values = st_estimate.values.reshape(1, -1)
    return st_estimate_values


def Survival_T_Cox(time, Y, X, A, Delta):
    """Fit a Cox proportional-hazards model and return each subject's survival at ``time``.

    The two covariates are the interactions ``X * (2A - 1)`` (columns ``X1*A``
    and ``X2*A``). Y is the observed duration and Delta the event indicator.
    The frame used for prediction has the same content as the training frame.

    Returns:
        ``predict_survival_function(..., times=time).values``: rows are the
        requested times, columns are subjects.
    """
    signed_A = 2 * np.reshape(A, (A.shape[0], 1)) - 1
    interactions = np.multiply(X, signed_A)

    train_frame = pd.DataFrame(interactions, columns=['X1*A', 'X2*A'])
    train_frame['Y'] = Y.flatten()
    train_frame['Delta'] = Delta.flatten()
    # Snapshot taken before fitting so prediction sees exactly the original rows.
    predict_frame = train_frame.copy()

    model = CoxPHFitter()
    model.fit(train_frame, duration_col='Y', event_col='Delta')
    survival = model.predict_survival_function(predict_frame, times=time)
    return survival.values

def Survival_C(time, Y, X, A, Delta):
    """Kaplan–Meier estimate of the censoring survival function, evaluated at ``time``.

    Censoring is treated as the event of interest (indicator ``1 - Delta``).
    ``X`` and ``A`` are accepted only for signature parity with the other
    survival helpers. The Kaplan–Meier estimator is covariate-free, so they are
    not used and X may have any number of columns.

    Args:
        time: scalar or array of evaluation times.
        Y: observed durations.
        X, A: unused.
        Delta: event indicator (1 = event observed, 0 = censored).

    Returns:
        ndarray of S_C evaluated at each entry of ``time``.
    """
    durations = np.asarray(Y).flatten()
    censored = 1 - np.asarray(Delta).flatten()
    kmf = KaplanMeierFitter()
    kmf.fit(durations=durations, event_observed=censored)
    return kmf.survival_function_at_times(time).values

###############################################################################################################
# DQN Network
###############################################################################################################

class QR_DQN(nn.Module):
    """Two-layer MLP that maps each agent's state features to one value per action.

    The output passes through softplus, so it is strictly positive, and it is
    negated when ``neg_softplus`` is True.

    Args:
        state_size: shape tuple; ``state_size[1]`` is the per-agent feature count.
        action_size: number of discrete actions (output width).
        layer_size: hidden layer width.
        seed: value passed to ``torch.manual_seed`` before the layers are built.
        neg_softplus: if True, outputs are the negated softplus (strictly negative).
    """

    def __init__(self, state_size, action_size, layer_size, seed, neg_softplus=True):
        super().__init__()
        # Seed torch's global RNG first so the layer initialisation below is reproducible.
        self.seed = torch.manual_seed(seed)
        feature_dim = state_size[1]
        self.input_shape = feature_dim
        self.action_size = action_size
        self.head_1 = nn.Linear(feature_dim, layer_size)
        self.ff_2 = nn.Linear(layer_size, action_size)
        self.neg_softplus = neg_softplus

    def forward(self, x):
        hidden = F.relu(self.head_1(x))
        magnitude = F.softplus(self.ff_2(hidden))
        return -magnitude if self.neg_softplus else magnitude
def normalization(data):
    """Min-max scale ``data`` into [0, 1] using its global minimum and maximum.

    Constant input (zero range) maps to an all-zeros array instead of NaN.

    Args:
        data: array-like of numbers (any shape).

    Returns:
        float ndarray with the same shape as ``data``.
    """
    values = np.asarray(data, dtype=float)
    lowest = np.min(values)
    span = np.max(values) - lowest
    if span == 0:
        return np.zeros_like(values)
    return (values - lowest) / span


###############################################################################################################
# Agent Environment
###############################################################################################################

class CustomEnv(gym.Env):
    """Simulated multi-stage survival environment with ``num_agents`` subjects.

    Each subject has a 2-dimensional state and takes a binary action per step.
    Per step, a survival time is sampled for every subject and compared with
    the subject's remaining censoring time. Subjects that get censored stop
    being observed. The reward is computed by the survival model chosen with
    ``mode``.
    """

    def __init__(self, max_steps, num_agents, mode):
        super(CustomEnv, self).__init__()

        self.action_space = gym.spaces.Discrete(2)
        self.observation_space = gym.spaces.Box(low=-np.inf, high=np.inf, shape=(num_agents, 2), dtype=np.float32)
        self.num_agents = num_agents

        self.max_steps = max_steps
        self.current_step = 0
        # 'aalen' / 'cox' select a model-based log-survival reward; any other value
        # (e.g. 'km') uses the Kaplan-Meier inverse-censoring-weighted reward.
        self.mode = mode

        self.seed_value = None

    def seed(self, seed=None):
        """Remember ``seed``; it reseeds NumPy's global RNG on every ``reset``."""
        self.seed_value = seed

    def reset(self):
        """Start a new episode and return ``(state, observe)``.

        With a seed, NumPy's global RNG is reseeded, so every episode starts
        from the same random states and censoring times. Without one, the
        states start at zero. ``observe`` starts as all ones (every subject
        observed).
        """
        if self.seed_value is not None:
            np.random.seed(self.seed_value)
            self.state = np.random.randn(self.num_agents, 2)
            self.censor = np.random.uniform(low=0, high=10*7, size=(self.num_agents))
        else:
            self.censor = np.random.uniform(low=0, high=10*7, size=(self.num_agents))
            self.state = np.zeros((self.num_agents, 2))
        self.current_step = 0
        self.observe = np.ones(self.num_agents)

        return self.state, self.observe

    def step(self, action):
        """Advance one stage.

        Returns ``(next_state, reward, next_observe, Delta, survival_time1,
        survival_time2, done)``. ``reward`` and ``Delta`` cover only subjects
        that were observed before this step. ``done`` becomes True on the step
        taken when ``current_step == max_steps``.
        """
        reward, Delta, survival_time1, survival_time2 = self.reward_function(action, self.state, self.observe)
        state_next = self.state_function(self.state, action)
        next_observe = self.observe_function(self.observe, Delta)
        self.state = state_next
        self.observe = next_observe
        done = self.current_step >= self.max_steps
        self.current_step += 1

        return state_next, reward, next_observe, Delta, survival_time1, survival_time2, done

    def observe_function(self, observe, Delta):
        """Stop observing every currently-observed subject whose Delta is 0 (censored)."""
        next_observe = observe.copy()
        # Delta is aligned with the currently observed subjects, in index order.
        indices_matrix = np.nonzero(next_observe)[0]
        next_observe[indices_matrix[Delta == 0]] = 0
        return next_observe

    @staticmethod
    def _first_action_column(action):
        """Return column 0 of a 2-D ``action`` as a 1-D NumPy array (torch tensors on any device accepted)."""
        if isinstance(action, torch.Tensor):
            action = action.detach().cpu().numpy()
        return np.asarray(action)[:, 0]

    def state_function(self, state, action):
        """Treatment-dependent linear transition plus Gaussian noise.

        For subject i with action a_i and s_i = 2*a_i - 1:
        next_i = (0.75 * s_i * x_i1, -0.75 * s_i * x_i2) + N(0, 0.25^2) noise,
        with noise drawn independently per subject and coordinate.
        """
        sign = 2 * self._first_action_column(action) - 1
        # Row i equals the diagonal of diag(3/4 * s_i, 3/4 * (-s_i)), so the
        # elementwise product is state[i] @ that diagonal matrix for all i at once.
        scale = 0.75 * np.column_stack([sign, -sign])
        noise = np.random.normal(0, 0.25, size=state.shape)
        return state * scale + noise

    def reward_function(self, action, state, observe):
        """Sample one stage of survival outcomes and compute rewards for observed subjects.

        Returns ``(reward, Delta, survival_time1, survival_time2)``. ``reward``
        and ``Delta`` only cover subjects with ``observe == 1``, in index order.
        ``survival_time1`` is the mean gamma scale parameter and
        ``survival_time2`` the mean sampled time after clipping at 7, both over
        all subjects. ``self.censor`` is reduced by the observed time.
        """
        action = self._first_action_column(action)
        beta_1 = np.array([0, -2, 1])
        state_1 = np.insert(state, 0, 1, axis=1)  # prepend an intercept column

        linear_predictor = (2 * action - 1) * np.dot(state_1, beta_1)
        baseline_hazard = 1
        param = 1 / (np.exp(linear_predictor) * baseline_hazard)
        # Gamma(shape=1/param, scale=param) has mean 1 and variance param.
        T = np.random.gamma(shape=1 / param, scale=param)

        survival_time1 = np.mean(param)
        T = np.clip(T, None, 7)
        survival_time2 = np.mean(T)

        Y = np.minimum(T, self.censor)
        Delta = (T <= self.censor).astype(int)
        self.censor = self.censor - Y  # remaining censoring budget for later stages

        observed = np.where(observe == 1)[0]
        Y = Y[observed]
        Delta = Delta[observed]
        state = np.array(state[observed])
        action = np.array(action[observed])

        eval_time = 0.5
        model_based = {'aalen': Survival_T_Aalen, 'cox': Survival_T_Cox}
        if self.mode in model_based:
            survival = model_based[self.mode](eval_time, Y, state, action, Delta)
            reward = np.log(survival + 1e-5)
            return reward, Delta, survival_time1, survival_time2

        # KM mode: Y * Delta / S_C(Y) for uncensored subjects, 0 for censored ones.
        # NOTE(review): S_C(Y) could be 0 at the largest observed time; not guarded here.
        S = Survival_C(Y, Y, state, action, Delta)
        reward = np.zeros_like(Y)
        uncensored = Delta != 0
        reward[uncensored] = (Y[uncensored] * Delta[uncensored]) / S[uncensored]
        return reward, Delta, survival_time1, survival_time2

###############################################################################################################
# ReplayBuffer
###############################################################################################################

class ReplayBuffer:
    """Fixed-size FIFO buffer of (optionally n-step) experience tuples.

    Each stored experience covers one environment step for all agents.
    ``state``, ``action``, ``observe`` and ``next_state`` are per-agent arrays.
    ``reward`` only holds values for the agents that were observed
    (``observe == 1``), in index order.
    """

    def __init__(self, buffer_size, batch_size, device, seed, gamma, n_step=1):
        """Initialize a ReplayBuffer object.

        Params
        ======
            buffer_size (int): maximum size of buffer
            batch_size (int): size of each training batch
            device: torch device the sampled tensors are moved to
            seed (int): seed for Python's global ``random`` module
            gamma (float): discount used for n-step returns
            n_step (int): number of steps folded into one stored experience
        """
        self.device = device
        self.memory = deque(maxlen=buffer_size)
        self.batch_size = batch_size
        self.experience = namedtuple("Experience", field_names=["state", "action", "observe", "reward", "next_state", "done"])
        self.seed = random.seed(seed)  # random.seed returns None; kept as an attribute for compatibility
        self.gamma = gamma
        self.n_step = n_step
        self.n_step_buffer = deque(maxlen=self.n_step)

    def add(self, state, action, observe, reward, next_state, done):
        """Add a new experience to memory once ``n_step`` transitions have accumulated."""
        self.n_step_buffer.append((state, action, observe, reward, next_state, done))
        if len(self.n_step_buffer) == self.n_step:
            state, action, observe, reward, next_state, done = self.calc_multistep_return()
            e = self.experience(state, action, observe, reward, next_state, done)
            self.memory.append(e)

    def calc_multistep_return(self):
        """Fold the n-step window into (first state/action/observe, discounted return, last next_state/done).

        NOTE(review): reward vectors shrink as agents stop being observed, so
        summing them is only shape-consistent for n_step=1.
        """
        Return = 0
        for idx in range(self.n_step):
            Return += self.gamma ** idx * self.n_step_buffer[idx][3]
        first, last = self.n_step_buffer[0], self.n_step_buffer[-1]
        return first[0], first[1], first[2], Return, last[4], last[5]

    @staticmethod
    def _as_array(value):
        """Convert torch tensors (on any device) to NumPy; return other values unchanged."""
        if isinstance(value, torch.Tensor):
            return value.detach().cpu().numpy()
        return value

    @staticmethod
    def _dense_reward(e):
        """Expand the compact reward of ``e`` to one float32 entry per agent (0 for unobserved agents)."""
        observe = np.asarray(e.observe)
        dense = np.zeros(observe.shape, dtype=np.float32)
        observed = np.flatnonzero(observe == 1)
        # Model-based rewards are 2-D with a leading axis of 1, KM rewards are 1-D;
        # flattening handles both, and the first len(observed) entries are used.
        compact = np.asarray(e.reward, dtype=np.float32).reshape(-1)
        dense[observed] = compact[:observed.size]
        return dense

    def sample(self):
        """Randomly sample a batch and return it as tensors on ``self.device``.

        Returns ``(states, actions, observes, rewards, next_states, dones)``.
        ``rewards`` is dense with shape (batch, num_agents); ``dones`` has
        shape (batch, 1).
        """
        experiences = [e for e in random.sample(self.memory, k=self.batch_size) if e is not None]

        def stacked(field):
            return np.stack([self._as_array(getattr(e, field)) for e in experiences])

        states = torch.from_numpy(stacked("state")).float().to(self.device)
        observes = torch.from_numpy(stacked("observe")).long().to(self.device)
        actions = torch.from_numpy(stacked("action")).long().to(self.device)
        rewards = torch.from_numpy(np.stack([self._dense_reward(e) for e in experiences])).float().to(self.device)
        next_states = torch.from_numpy(stacked("next_state")).float().to(self.device)
        dones = torch.from_numpy(np.vstack([e.done for e in experiences]).astype(np.uint8)).float().to(self.device)

        return (states, actions, observes, rewards, next_states, dones)

    def __len__(self):
        """Return the current size of internal memory."""
        return len(self.memory)

###############################################################################################################
# DQN_Agent
###############################################################################################################

class DQN_Agent():
    """Interacts with and learns from the environment."""

    def __init__(self,
                 state_size,
                 action_size,
                 Network,
                 layer_size,
                 n_step,
                 BATCH_SIZE,
                 BUFFER_SIZE,
                 LR,
                 TAU,
                 GAMMA,
                 UPDATE_EVERY,
                 device,
                 seed,
                 neg_softplus=True):
        """Initialize an Agent object.

        Params
        ======
            state_size (tuple): observation shape; state_size[1] is the per-agent feature count
            action_size (int): number of discrete actions
            Network (str): accepted for compatibility but unused; the Q-network is always QR_DQN
            layer_size (int): size of the hidden layer
            n_step (int): n-step horizon for the replay buffer
            BATCH_SIZE (int): size of the training batch
            BUFFER_SIZE (int): size of the replay memory
            LR (float): learning rate
            TAU (float): tau for soft updating the network weights
            GAMMA (float): discount factor
            UPDATE_EVERY (int): update frequency
            device (str): device that is used for the compute
            seed (int): random seed
            neg_softplus (bool): forwarded to QR_DQN (non-positive Q outputs when True)
        """
        self.state_size = state_size
        self.action_size = action_size
        self.seed = random.seed(seed)
        self.device = device
        self.TAU = TAU
        self.GAMMA = GAMMA
        self.UPDATE_EVERY = UPDATE_EVERY
        self.BATCH_SIZE = BATCH_SIZE
        self.Q_updates = 0
        self.n_step = n_step

        self.qnetwork_local = QR_DQN(state_size, action_size, layer_size, seed, neg_softplus=neg_softplus).to(self.device)
        self.qnetwork_target = QR_DQN(state_size, action_size, layer_size, seed, neg_softplus=neg_softplus).to(self.device)

        self.optimizer = optim.Adam(self.qnetwork_local.parameters(), lr=LR)

        self.memory = ReplayBuffer(BUFFER_SIZE, BATCH_SIZE, self.device, seed, self.GAMMA, n_step)

        self.t_step = 0

    def step(self, state, action, observe, reward, next_state, done, writer):
        """Store the transition and, every UPDATE_EVERY calls, learn from a sampled batch."""
        self.memory.add(state, action, observe, reward, next_state, done)
        self.t_step = (self.t_step + 1) % self.UPDATE_EVERY
        if self.t_step == 0:
            if len(self.memory) > self.BATCH_SIZE:
                experiences = self.memory.sample()
                loss = self.learn(experiences)
                self.Q_updates += 1
                writer.add_scalar("Q_loss", loss, self.Q_updates)

    def act(self, state, eps=0.):
        """Return epsilon-greedy actions for every agent in ``state``.

        Params
        ======
            state (array_like): per-agent states, shape (num_agents, n_features)
            eps (float): probability of taking uniformly random actions

        Returns
        =======
            (action, value): ``action`` is a CPU long tensor of shape (num_agents, 1),
            so callers can convert it with NumPy directly. ``value`` holds the
            corresponding Q-values on ``self.device``.
        """
        state = torch.from_numpy(np.array(state)).float().unsqueeze(0).to(self.device)

        self.qnetwork_local.eval()
        with torch.no_grad():
            action_values = self.qnetwork_local(state)
        self.qnetwork_local.train()
        action_values = action_values.squeeze(0)

        if random.random() > eps:
            max_value, action = torch.max(action_values, dim=1)
            action = action.unsqueeze(1)
        else:
            action = torch.randint(low=0, high=action_values.shape[1], size=(action_values.shape[0], 1))
            max_value = action_values.gather(1, action.to(self.device))
        # Greedy actions live on self.device; downstream code calls np.array(action),
        # which fails for CUDA tensors, so always hand back a CPU tensor.
        return action.cpu(), max_value

    def learn(self, experiences):
        """Run one gradient step on the per-agent TD error.

        Params
        ======
            experiences (Tuple[torch.Tensor]): (states, actions, observes, rewards, next_states, dones)
                as produced by ReplayBuffer.sample; only agents with observe == 1 contribute.

        Returns the loss as a 0-d NumPy array.
        """
        self.optimizer.zero_grad()
        states, actions, observes, rewards, next_states, dones = experiences
        mask = observes.bool()
        states = states[mask]
        next_states = next_states[mask]
        actions = actions[mask]
        rewards = rewards[mask]
        # dones holds one flag per transition (batch, 1); broadcast it to every agent of that transition.
        not_done = 1.0 - dones.reshape(-1, 1).expand(mask.shape)[mask].float()

        with torch.no_grad():
            Q_targets_next = self.qnetwork_target(next_states).max(dim=1, keepdim=True)[0]
            # Terminal transitions must not bootstrap from the post-episode state.
            Q_targets = rewards.unsqueeze(1) + self.GAMMA * Q_targets_next * not_done.unsqueeze(1)

        Q_expected = self.qnetwork_local(states).gather(1, actions)
        loss = F.mse_loss(Q_expected, Q_targets)
        loss.backward()
        self.optimizer.step()

        # ------------------- update target network ------------------- #
        self.soft_update(self.qnetwork_local, self.qnetwork_target)
        return loss.detach().cpu().numpy()

    def soft_update(self, local_model, target_model):
        """Soft update model parameters.
        θ_target = τ*θ_local + (1 - τ)*θ_target, with τ = self.TAU
        Params
        ======
            local_model (PyTorch model): weights will be copied from
            target_model (PyTorch model): weights will be copied to
        """
        for target_param, local_param in zip(target_model.parameters(), local_model.parameters()):
            target_param.data.copy_(self.TAU * local_param.data + (1.0 - self.TAU) * target_param.data)

###############################################################################################################
# Training Loop
###############################################################################################################

def run(episodes=1000, eps_fixed=False, eps_frames=1e6, min_eps=0.01, num_agents=1000,
        env=None, agent=None, writer=None):
    """Deep Q-Learning training loop.

    Params
    ======
        episodes (int): number of training episodes
        eps_fixed (bool): if True, act greedily (eps=0) for the whole run
        eps_frames (float): number of frames for linear epsilon annealing from 1
        min_eps (float): minimum value of epsilon during annealing
        num_agents (int): number of simulated subjects
        env, agent, writer: environment, DQN agent and TensorBoard writer. When
            omitted, the module-level ``env`` / ``agent`` / ``writer`` globals are used.

    Returns
    =======
        (scores, rewards, rates, rate, survival_times1, survival_times2):
        per-episode mean score; per-agent list of every reward received;
        per-episode censoring rate at the last step; per-step censoring rate;
        per-step survival statistics reported by the environment.
    """
    if env is None:
        env = globals()["env"]
    if agent is None:
        agent = globals()["agent"]
    if writer is None:
        writer = globals()["writer"]

    scores = []  # mean score of each episode
    rewards = [[] for _ in range(num_agents)]
    rate = []
    rates = []
    survival_times1 = []
    survival_times2 = []
    scores_window = deque(maxlen=100)  # last 100 scores
    survival_times1_window = deque(maxlen=20)
    survival_times2_window = deque(maxlen=20)
    frame = 0
    eps_start = 1
    eps = 0 if eps_fixed else eps_start

    state, observe = env.reset()
    score = np.zeros(num_agents)
    for i_episode in range(1, episodes + 1):

        while True:
            action, _ = agent.act(state, eps)
            next_state, reward, next_observe, Delta, survival_time1, survival_time2, done = env.step(action)
            agent.step(state, action, observe, reward, next_state, done, writer)

            # reward only covers agents observed before this step, in index order.
            # Model-based rewards are 2-D with a leading axis of 1, KM rewards are 1-D;
            # flatten both and use the first len(indices) entries.
            indices = np.where(observe == 1)[0]
            step_rewards = np.asarray(reward).reshape(-1)[:len(indices)]
            for idx, value in zip(indices, step_rewards):
                rewards[idx].append(value)
            score[indices] += step_rewards

            survival_times1.append(survival_time1)
            survival_times2.append(survival_time2)
            survival_times1_window.append(survival_time1)
            survival_times2_window.append(survival_time2)
            censor_rate = 1 - np.sum(observe) / len(observe)
            rate.append(censor_rate)
            state = next_state
            observe = next_observe
            frame += 1

            if not eps_fixed:
                if frame < eps_frames:
                    eps = max(eps_start - (frame * (1 / eps_frames)), min_eps)
                else:
                    eps = -1  # random.random() > -1 always holds: fully greedy

            if done:
                mean_score = np.mean(score)
                scores_window.append(mean_score)
                scores.append(mean_score)
                rates.append(censor_rate)
                writer.add_scalar("Average100", np.mean(scores_window), frame)
                print('\rEpisode {} \tAEV: {:.2f} \tARED: {:.2f} '.format(i_episode, float(np.sum(survival_times1_window)), float(np.sum(survival_times2_window))), end="")
                state, observe = env.reset()
                score = np.zeros(num_agents)
                break

    return scores, rewards, rates, rate, survival_times1, survival_times2

###############################################################################################################
# Train All for Mode
###############################################################################################################

def train_all_for_mode(mode, seeds, save_prefix, stages, num_people=1000, episodes=50):
    """Train one DQN agent per (seed, stage) for survival model ``mode`` and save the curves.

    For every seed and every entry of ``stages`` (steps per episode), this builds
    a fresh ``CustomEnv``, TensorBoard ``SummaryWriter`` and ``DQN_Agent``. It then
    calls ``run`` and writes each recorded series to
    ``{save_prefix}_seed_{seed}_stage_{stage_steps}/{name}_results_stage_{stage_steps}.txt``.
    ``env``, ``writer`` and ``agent`` are published as module globals because
    ``run`` reads them from there by default.

    Args:
        mode: 'aalen', 'cox' or 'km'. Any mode other than 'km' makes the
            Q-network outputs non-positive.
        seeds: iterable of integer seeds.
        save_prefix: prefix of the output folders.
        stages: iterable of episode lengths in steps.
        num_people: number of simulated subjects.
        episodes: training episodes per stage.

    Returns:
        dict: seed -> {series name -> {stage index -> list}}, for the series
        'scores', 'rate', 'rates', 'survival_times1' and 'survival_times2'.
    """
    global env, writer, agent

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("Selected device:", device)

    # Aalen/Cox rewards are log survival probabilities, so those modes use a
    # negated-softplus head; the KM reward (Y * Delta / S_C) is non-negative.
    neg_softplus_flag = (mode != 'km')

    all_results = {}
    for seed in seeds:
        seed_results = {}

        for si, stage_steps in enumerate(stages):
            print(f"Running stage (max_steps) = {stage_steps} with seed {seed}")
            np.random.seed(seed)
            torch.manual_seed(seed)

            # CustomEnv.step reports done once current_step >= max_steps,
            # i.e. on the stage_steps-th step of an episode.
            env = CustomEnv(max_steps=stage_steps - 1, num_agents=num_people, mode=mode)
            env.seed(seed)
            action_size = env.action_space.n
            state_size = env.observation_space.shape

            writer = SummaryWriter(f"DQNt_LL_new_1_stage_{stage_steps}_seed_{seed}/")
            try:
                agent = DQN_Agent(state_size=state_size,
                                  action_size=action_size,
                                  Network="DDQN",
                                  layer_size=64,
                                  n_step=1,
                                  BATCH_SIZE=32,
                                  BUFFER_SIZE=6400,
                                  LR=1e-3,
                                  TAU=1e-2,
                                  GAMMA=0.99,
                                  UPDATE_EVERY=1,
                                  device=device,
                                  seed=seed,
                                  neg_softplus=neg_softplus_flag)

                t0 = time.time()
                scores, _, rates, rate, survival_times1, survival_times2 = run(
                    episodes=episodes, eps_fixed=False, eps_frames=100, min_eps=0.025, num_agents=num_people
                )
                t1 = time.time()
            finally:
                # Close the TensorBoard writer even if training raises.
                writer.close()
            print(f"Training time for stage={stage_steps} with seed {seed}: {round((t1 - t0) / 60, 2)}min")

            stage_results = {
                "scores": scores,
                "rate": rate,
                "rates": rates,
                "survival_times1": survival_times1,
                "survival_times2": survival_times2,
            }

            subfolder_path = f"{save_prefix}_seed_{seed}_stage_{stage_steps}"
            os.makedirs(subfolder_path, exist_ok=True)
            for name, values in stage_results.items():
                seed_results.setdefault(name, {})[si] = values
                out_path = os.path.join(subfolder_path, f'{name}_results_stage_{stage_steps}.txt')
                with open(out_path, 'w') as file:
                    file.write(f'{[float(x) for x in values]}\n')

        all_results[seed] = seed_results

    return all_results


