### Experiment 2
### Nonlinear Environment

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import torch.optim as optim
import torch.nn.functional as F
from torch.nn.utils import clip_grad_norm_
import random
import gym
import math
from torch.utils.tensorboard import SummaryWriter
from collections import deque, namedtuple
import time

import warnings
warnings.filterwarnings("ignore")

import os
os.environ['KMP_DUPLICATE_LIB_OK']='TRUE'

from torch.autograd import Variable
import pandas as pd
from lifelines import CoxPHFitter
from lifelines import AalenAdditiveFitter
from lifelines import KaplanMeierFitter
import lifelines
from scipy.optimize import minimize
from sklearn import linear_model
from sklearn.linear_model import LinearRegression
import matplotlib.pyplot as plt

# Module-level default device: the first CUDA GPU when one is available,
# otherwise the CPU. Printed once at import time for visibility.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Selected device:", device)

from lifelines import CoxPHFitter
import numpy as np
import pandas as pd

###############################################################################################################
# Survival Regression / Analysis
###############################################################################################################

def Survival_T_Aalen(time, Y, X, A, Delta):
    """Predict survival probabilities at ``time`` with an Aalen additive model.

    Each row of ``X`` is scaled by the matching entry of ``A``; the scaled
    covariates, together with ``Y`` and ``Delta``, are used to fit the model,
    which is then evaluated on those same rows.

    Args:
        time: Time point(s) forwarded to ``predict_survival_function``.
        Y: Array of durations; flattened and used as the duration column.
        X: 2-D covariate array, one row per subject.
        A: Per-subject action array (must expose ``.shape``).
        Delta: Array flattened and used as the event column.

    Returns:
        The predicted survival values reshaped into a single row ``(1, -1)``.
    """
    action_column = np.reshape(A, (A.shape[0], 1))
    interaction = np.multiply(X, action_column.reshape(-1, 1))
    feature_names = [f'X{j}*A' for j in range(interaction.shape[1])]

    train_frame = pd.DataFrame(interaction, columns=feature_names)
    train_frame['Y'] = Y.flatten()
    train_frame['Delta'] = Delta.flatten()

    # The model is evaluated on exactly the rows it was fitted on.
    predict_frame = train_frame.copy()

    aalen_model = AalenAdditiveFitter(coef_penalizer=10000.0)
    aalen_model.fit(train_frame, duration_col='Y', event_col='Delta')
    predicted = aalen_model.predict_survival_function(predict_frame, times=time)

    return predicted.values.reshape(1, -1)


def Survival_T_Cox(time, Y, X, A, Delta):
    """Predict survival probabilities at ``time`` with a Cox PH model.

    Each row of ``X`` is scaled by the matching entry of ``A``; the scaled
    covariates, together with ``Y`` and ``Delta``, are used to fit a penalized
    Cox model, which is then evaluated on those same rows.

    Args:
        time: Time point(s) forwarded to ``predict_survival_function``.
        Y: Array of durations; flattened and used as the duration column.
        X: 2-D covariate array, one row per subject.
        A: Per-subject action array (must expose ``.shape``).
        Delta: Array flattened and used as the event column.

    Returns:
        The ``.values`` array of the predicted survival function.
    """
    action_column = np.reshape(A, (A.shape[0], 1))
    interaction = np.multiply(X, action_column.reshape(-1, 1))
    feature_names = [f'X{j}*A' for j in range(interaction.shape[1])]

    train_frame = pd.DataFrame(interaction, columns=feature_names)
    train_frame['Y'] = Y.flatten()
    train_frame['Delta'] = Delta.flatten()

    # The model is evaluated on exactly the rows it was fitted on.
    predict_frame = train_frame.copy()

    cox_model = CoxPHFitter(penalizer=0.1)
    cox_model.fit(
        train_frame,
        duration_col='Y',
        event_col='Delta',
        fit_options={"step_size": 0.1},
    )
    predicted = cox_model.predict_survival_function(predict_frame, times=time)

    return predicted.values


def Survival_C(time, Y, X, A, Delta):
    """Kaplan-Meier estimate evaluated at ``time`` using ``1 - Delta`` as events.

    Args:
        time: Time point(s) passed to ``survival_function_at_times``.
        Y: Array of durations (flattened before use).
        X: 2-D covariate array; only placed into the working frame, it does
            not enter the Kaplan-Meier fit.
        A: Per-subject action array; reshaped but otherwise unused.
        Delta: Indicator array; its complement ``1 - Delta`` is passed as
            ``event_observed``.

    Returns:
        The ``.values`` array of the estimated survival function at ``time``.
    """
    # Result is not used below; kept so malformed ``A`` still fails here.
    A = np.reshape(A, (A.shape[0], 1))

    column_names = [f'X{j}*A' for j in range(X.shape[1])]
    frame = pd.DataFrame(X, columns=column_names)
    frame['Y'] = Y.flatten()
    frame['1-Delta'] = 1 - Delta.flatten()

    km_model = KaplanMeierFitter()
    km_model.fit(durations=frame['Y'], event_observed=frame['1-Delta'])

    return km_model.survival_function_at_times(time).values

###############################################################################################################
# Actor-Critic
###############################################################################################################

class Actor(nn.Module):
    """Policy network: two ReLU hidden layers followed by a tanh output.

    Args:
        state_size: Shape tuple; ``state_size[1]`` is the input feature count.
        action_size: Number of output units.
        layer_size: Width of both hidden layers.
        seed: Passed to ``torch.manual_seed`` before the layers are created.
    """

    def __init__(self, state_size, action_size, layer_size, seed):
        super().__init__()
        # Seeding happens before layer construction so initial weights are
        # reproducible for a given seed.
        self.seed = torch.manual_seed(seed)
        feature_dim = state_size[1]
        self.fc1 = nn.Linear(feature_dim, layer_size)
        self.fc2 = nn.Linear(layer_size, layer_size)
        self.fc3 = nn.Linear(layer_size, action_size)

    def forward(self, x):
        """Map states ``x`` to actions squashed into [-1, 1] by tanh."""
        hidden = self.fc1(x).relu()
        hidden = self.fc2(hidden).relu()
        return self.fc3(hidden).tanh()


class Critic(nn.Module):
    """Q-network: the action is concatenated after the first state layer.

    Args:
        state_size: Shape tuple; ``state_size[1]`` is the input feature count.
        action_size: Width of the action vector joined to the state features.
        layer_size: Width of the hidden layers.
        seed: Passed to ``torch.manual_seed`` before the layers are created.
    """

    def __init__(self, state_size, action_size, layer_size, seed):
        super().__init__()
        # Seeding happens before layer construction so initial weights are
        # reproducible for a given seed.
        self.seed = torch.manual_seed(seed)
        feature_dim = state_size[1]
        self.fc1 = nn.Linear(feature_dim, layer_size)
        self.fc2 = nn.Linear(layer_size + action_size, layer_size)
        self.fc3 = nn.Linear(layer_size, 1)

    def forward(self, state, action):
        """Return one value per row for the given ``(state, action)`` pairs."""
        state_features = F.relu(self.fc1(state))
        joint = torch.cat((state_features, action), dim=1)
        return self.fc3(F.relu(self.fc2(joint)))

def normalization(data):
    """Min-max scale ``data`` into the range [0, 1].

    Args:
        data: Array-like of numbers; the minimum and maximum are taken over
            all elements.

    Returns:
        ``(data - min) / (max - min)``. When every element is equal (zero
        range) an all-zero float array of the same shape is returned instead
        of dividing by zero.
    """
    lowest = np.min(data)
    _range = np.max(data) - lowest
    if _range == 0:
        # Constant input: avoid 0/0, which would produce NaN everywhere.
        return np.zeros_like(data, dtype=float)
    return (data - lowest) / _range


###############################################################################################################
# CustomEnv
###############################################################################################################


class CustomEnv(gym.Env):
    """Multi-agent survival environment with nonlinear state transitions.

    All agents are stepped together. Each step draws a gamma-distributed
    event time per agent, compares it with the agent's remaining censoring
    budget, and turns the outcome into a reward through one of three
    estimators selected by ``mode``.

    Args:
        max_steps: Step index at which an episode is flagged as done.
        num_agents: Number of agents simulated in parallel.
        mode: ``'aalen'`` or ``'cox'`` select the matching survival model;
            any other value (e.g. ``'km'``) uses the Kaplan-Meier branch.
        state_dim: Number of state features per agent (default 50).
    """

    def __init__(self, max_steps, num_agents, mode, state_dim=50):
        super(CustomEnv, self).__init__()
        self.action_space = gym.spaces.Box(low=-1, high=1, shape=(1,), dtype=np.float32)
        self.observation_space = gym.spaces.Box(
            low=-np.inf, high=np.inf, shape=(num_agents, state_dim), dtype=np.float32
        )
        self.num_agents = num_agents
        self.state_dim = state_dim
        self.max_steps = max_steps
        self.mode = mode  # 'aalen', 'cox', 'km'
        self.current_step = 0
        # Defined up front so reset() works even if seed() was never called.
        self.seed_value = None

    def seed(self, seed=None):
        """Remember ``seed``; it is applied to NumPy's global RNG on reset."""
        self.seed_value = seed

    def reset(self):
        """Start a new episode.

        With a stored seed, NumPy's global RNG is re-seeded and the state is
        drawn from a standard normal. Without one, the state starts at zero.
        Censoring budgets are uniform on [0, 70) in both cases.

        Returns:
            ``(state, observe)`` where ``observe`` is all ones.
        """
        n_agents, n_features = self.num_agents, self.state_dim
        if self.seed_value is not None:
            np.random.seed(self.seed_value)
            self.state = np.random.randn(n_agents, n_features)
            self.censor = np.random.uniform(low=0, high=10 * 7, size=n_agents)
        else:
            self.censor = np.random.uniform(low=0, high=10 * 7, size=n_agents)
            self.state = np.zeros((n_agents, n_features))
        self.current_step = 0
        self.observe = np.ones(n_agents)
        return self.state, self.observe

    def step(self, action):
        """Advance every agent by one step.

        Returns:
            ``(next_state, reward, next_observe, Delta, survival_time1,
            survival_time2, done)``. ``done`` is a boolean array shaped like
            ``observe``; it is evaluated before ``current_step`` is
            incremented.
        """
        reward, Delta, survival_time1, survival_time2 = self.reward_function(action, self.state, self.observe)
        done = self.current_step >= self.max_steps
        done = np.full_like(self.observe, done, dtype=bool)
        state_next = self.state_function(self.state, action)
        next_observe = self.observe_function(self.observe, Delta)
        self.state = state_next
        self.observe = next_observe

        self.current_step += 1
        return state_next, reward, next_observe, Delta, survival_time1, survival_time2, done

    def observe_function(self, observe, Delta):
        """Return a copy of ``observe`` with newly censored agents zeroed.

        ``Delta`` has one entry per currently observed agent; entries equal
        to 0 switch the corresponding agent off.
        """
        next_observe = observe.copy()
        active = np.nonzero(next_observe)[0]
        next_observe[active[Delta == 0]] = 0
        return next_observe

    def state_function(self, state, action):
        """Compute the next state for every agent.

        Each row becomes ``0.75 * a * s + sin(s) + cos(s) + noise`` with
        Gaussian noise of standard deviation 0.25. The computation is
        vectorized over agents instead of looping row by row.
        """
        n_rows = state.shape[0]
        gain = 3 / 4 * np.asarray(action, dtype=np.float64).flatten()[:n_rows]
        non_linear_term = np.sin(state) + np.cos(state)
        noise = np.random.normal(0, 0.25, size=state.shape)
        return gain[:, None] * state + non_linear_term + noise

    def reward_function(self, action, state, observe):
        """Simulate event times and convert them into rewards.

        Returns:
            ``(reward, Delta, survival_time1, survival_time2)``. ``reward``
            has shape ``(1, k)`` and ``Delta`` has length ``k``, where ``k``
            is the number of agents with ``observe == 1``. The two survival
            summaries are means over all agents (``param`` and the clipped
            event times respectively).
        """
        action = np.array(action).flatten()

        # Fresh coefficients each call: only the intercept and the first
        # couple of features get a non-zero random weight.
        beta_1 = np.zeros(self.state_dim + 1)
        non_zero_count = min(3, beta_1.size)
        beta_1[:non_zero_count] = np.random.randn(non_zero_count)
        state_1 = np.insert(state, 0, 1, axis=1)  # prepend intercept column

        linear_predictor = action * np.dot(state_1, beta_1)
        baseline_hazard = 1
        param = 1 / (np.exp(linear_predictor) * baseline_hazard)
        T = np.random.gamma(shape=1 / param, scale=param)

        survival_time1 = np.mean(param)
        T = np.clip(T, None, 7)
        survival_time2 = np.mean(T)

        Y = np.minimum(T, self.censor)
        Delta = (T <= self.censor).astype(int)
        # Consume the elapsed time from each agent's censoring budget.
        self.censor = self.censor - Y

        # Only agents still under observation contribute to the estimate.
        indices = np.where(observe == 1)[0]
        Y = Y[indices]
        Delta = Delta[indices]
        state = state[indices]
        action = action[indices]

        if self.mode == 'aalen':
            estimate = Survival_T_Aalen(0.5, Y, np.array(state), np.array(action), Delta)
            return np.log(estimate + 1e-5), Delta, survival_time1, survival_time2

        if self.mode == 'cox':
            estimate = Survival_T_Cox(0.5, Y, np.array(state), np.array(action), Delta)
            return np.log(estimate + 1e-5), Delta, survival_time1, survival_time2

        # KM: weight uncensored durations by the inverse of the estimate
        # returned by Survival_C at each agent's own observed time.
        S = Survival_C(Y, Y, np.array(state), np.array(action), Delta)
        reward = np.zeros_like(Y)
        uncensored = Delta != 0
        reward[uncensored] = (Y[uncensored] * Delta[uncensored]) / S[uncensored]
        reward = reward.reshape(1, -1)
        return reward, Delta, survival_time1, survival_time2

###############################################################################################################
# ReplayBuffer
###############################################################################################################

class ReplayBuffer:
    """Fixed-size buffer to store experience tuples."""

    def __init__(self, buffer_size, batch_size, device, seed, gamma, n_step=1):
        """Initialize a ReplayBuffer object.
        Params
        ======
            buffer_size (int): maximum size of buffer
            batch_size (int): size of each training batch
            device: torch device every sampled tensor is moved to
            seed (int): random seed (applied to Python's ``random`` module)
            gamma (float): discount used when folding n-step rewards
            n_step (int): number of consecutive steps folded into one entry
        """
        self.device = device
        self.memory = deque(maxlen=buffer_size)
        self.batch_size = batch_size
        self.experience = namedtuple("Experience", field_names=["state", "action", "observe", "reward", "next_state", "done"])
        # random.seed() returns None, so keep the seed value itself.
        random.seed(seed)
        self.seed = seed
        self.gamma = gamma
        self.n_step = n_step
        self.n_step_buffer = deque(maxlen=self.n_step)

    def add(self, state, action, observe, reward, next_state, done):
        """Add a new experience to memory.

        Transitions are staged in the n-step window; once it is full, one
        aggregated experience is appended to the main memory.
        """
        self.n_step_buffer.append((state, action, observe, reward, next_state, done))
        if len(self.n_step_buffer) == self.n_step:
            state, action, observe, reward, next_state, done = self.calc_multistep_return()
            self.memory.append(self.experience(state, action, observe, reward, next_state, done))

    def calc_multistep_return(self):
        """Fold the staged transitions into one n-step transition.

        Returns the first transition's state/action/observe, the discounted
        sum of the staged rewards, and the last transition's next_state/done.
        """
        Return = 0
        for idx in range(self.n_step):
            Return += self.gamma ** idx * self.n_step_buffer[idx][3]
        first, last = self.n_step_buffer[0], self.n_step_buffer[-1]
        return first[0], first[1], first[2], Return, last[4], last[5]

    @staticmethod
    def _scatter_reward(experience):
        """Expand the per-observed-agent reward into a per-agent float32 vector.

        The stored reward holds one value per agent with ``observe == 1``
        (as a ``(1, k)`` row or a flat length-``k`` array); all other agents
        receive 0.
        """
        observed = np.where(experience.observe == 1)[0]
        dense = np.zeros_like(experience.observe, dtype=np.float32)
        compact = np.asarray(experience.reward).reshape(-1)
        if compact.size < observed.size:
            raise IndexError(
                f"reward has {compact.size} entries but {observed.size} agents are observed"
            )
        dense[observed] = compact[:observed.size]
        return dense

    def _as_tensor(self, array):
        """Wrap a NumPy array as a tensor (dtype is set by the caller)."""
        return torch.from_numpy(array)

    def sample(self):
        """Randomly sample a batch of experiences from memory.

        Returns:
            ``(states, actions, observes, rewards, next_states, dones)`` as
            tensors on ``self.device``. Everything is float except
            ``observes``, which is long.
        """
        batch = [e for e in random.sample(self.memory, k=self.batch_size) if e is not None]

        states = self._as_tensor(np.stack([e.state for e in batch])).float().to(self.device)
        observes = self._as_tensor(np.stack([e.observe for e in batch])).long().to(self.device)
        # Actions are continuous values in [-1, 1]; an integer cast would
        # truncate almost all of them to 0.
        actions = self._as_tensor(np.stack([e.action for e in batch])).float().to(self.device)
        rewards = self._as_tensor(np.stack([self._scatter_reward(e) for e in batch])).float().to(self.device)
        next_states = self._as_tensor(np.stack([e.next_state for e in batch])).float().to(self.device)
        dones = self._as_tensor(np.vstack([e.done for e in batch]).astype(np.uint8)).float().to(self.device)

        return (states, actions, observes, rewards, next_states, dones)

    def __len__(self):
        """Return the current size of internal memory."""
        return len(self.memory)

###############################################################################################################
# DDPG_Agent
###############################################################################################################

class DDPG_Agent:
    """DDPG learner: local/target actor and critic plus a replay buffer.

    Args:
        state_size: Observation shape forwarded to the networks.
        action_size: Number of action dimensions.
        layer_size: Hidden width of both networks.
        n_step: n-step setting forwarded to the replay buffer.
        BATCH_SIZE: Batch size for sampling; learning starts once the buffer
            holds more than this many entries.
        BUFFER_SIZE: Replay buffer capacity.
        LR_ACTOR: Adam learning rate for the actor.
        LR_CRITIC: Adam learning rate for the critic.
        TAU: Interpolation factor for the soft target update.
        GAMMA: Discount factor.
        UPDATE_EVERY: Learn once every this many calls to ``step``.
        device: torch device for networks and sampled batches.
        seed: Seed for Python's ``random`` and for network initialisation.
    """

    def __init__(self, state_size, action_size, layer_size, n_step, BATCH_SIZE, BUFFER_SIZE, LR_ACTOR, LR_CRITIC, TAU, GAMMA, UPDATE_EVERY, device, seed):
        self.state_size = state_size
        self.action_size = action_size
        # random.seed() returns None, so keep the seed value itself.
        random.seed(seed)
        self.seed = seed
        self.device = device
        self.TAU = TAU
        self.GAMMA = GAMMA
        self.UPDATE_EVERY = UPDATE_EVERY
        self.BATCH_SIZE = BATCH_SIZE
        self.Q_updates = 0
        self.n_step = n_step

        self.actor_local = Actor(state_size, action_size, layer_size, seed).to(self.device)
        self.actor_target = Actor(state_size, action_size, layer_size, seed).to(self.device)
        self.critic_local = Critic(state_size, action_size, layer_size, seed).to(self.device)
        self.critic_target = Critic(state_size, action_size, layer_size, seed).to(self.device)

        self.optimizer_actor = optim.Adam(self.actor_local.parameters(), lr=LR_ACTOR)
        self.optimizer_critic = optim.Adam(self.critic_local.parameters(), lr=LR_CRITIC)

        self.memory = ReplayBuffer(BUFFER_SIZE, BATCH_SIZE, self.device, seed, self.GAMMA, n_step)

        self.t_step = 0
        # Most recent losses; None until the first successful update.
        self.actor_loss = None
        self.critic_loss = None

    def step(self, state, action, observe, reward, next_state, done, writer):
        """Store one transition and, when due, learn from a sampled batch.

        Losses are logged to ``writer`` only when an update was actually
        performed.
        """
        self.memory.add(state, action, observe, reward, next_state, done)
        self.t_step = (self.t_step + 1) % self.UPDATE_EVERY
        if self.t_step != 0 or len(self.memory) <= self.BATCH_SIZE:
            return
        if not self.learn(self.memory.sample()):
            return
        self.Q_updates += 1
        writer.add_scalar("Actor_loss", self.actor_loss, self.Q_updates)
        writer.add_scalar("Critic_loss", self.critic_loss, self.Q_updates)

    def act(self, state, noise=0.):
        """Return clamped actions for ``state`` as a NumPy array.

        A leading batch axis is added before the forward pass, so the result
        keeps that extra axis.
        """
        state = np.array(state)
        state = torch.from_numpy(state).float().unsqueeze(0).to(self.device)
        self.actor_local.eval()
        with torch.no_grad():
            action = self.actor_local(state)
        self.actor_local.train()
        # NOTE(review): a single draw of shape (action_size,) is broadcast, so
        # every agent receives the same perturbation. Confirm whether
        # independent per-agent noise (torch.randn_like(action)) was intended.
        action = action + noise * torch.randn(self.action_size, device=self.device)
        action = torch.clamp(action, -1, 1)
        return action.cpu().numpy()

    def learn(self, experiences):
        """Run one critic update, one actor update and the soft target updates.

        Only transitions of agents flagged in ``observes`` are used.

        Returns:
            ``True`` if an update was performed, ``False`` if the batch held
            no observed agents (nothing is changed in that case).
        """
        states, actions, observes, rewards, next_states, dones = experiences

        mask = observes.bool()
        if not mask.any():
            # An empty selection would give a NaN loss and corrupt the weights.
            return False

        states = states[mask]
        next_states = next_states[mask]
        dones = dones[mask].unsqueeze(1)
        rewards = rewards[mask].unsqueeze(1)
        # Drop the singleton axes around the agent axis so the mask applies,
        # then restore a trailing action axis. The critic needs float input.
        actions = actions.squeeze(1).squeeze(-1)[mask].unsqueeze(1).float()

        # Update Critic
        with torch.no_grad():  # targets are constants; no graph needed
            actions_next = self.actor_target(next_states)
            Q_targets_next = self.critic_target(next_states, actions_next)
            Q_targets = rewards + (self.GAMMA * Q_targets_next * (1 - dones))
        self.optimizer_critic.zero_grad()
        Q_expected = self.critic_local(states, actions)
        self.critic_loss = F.mse_loss(Q_expected, Q_targets)
        self.critic_loss.backward()
        self.optimizer_critic.step()

        # Update Actor
        self.optimizer_actor.zero_grad()
        actions_pred = self.actor_local(states)
        self.actor_loss = -self.critic_local(states, actions_pred).mean()
        self.actor_loss.backward()
        self.optimizer_actor.step()

        # Soft update target networks
        self.soft_update(self.actor_local, self.actor_target)
        self.soft_update(self.critic_local, self.critic_target)
        return True

    def soft_update(self, local_model, target_model):
        """Blend parameters: ``target = TAU * local + (1 - TAU) * target``."""
        for target_param, local_param in zip(target_model.parameters(), local_model.parameters()):
            target_param.data.copy_(self.TAU * local_param.data + (1.0 - self.TAU) * target_param.data)

###############################################################################################################
# Training Loop
###############################################################################################################

def run(episodes=1000, eps_fixed=False, eps_frames=1e6, min_eps=0.01, num_agents=1000):
    """Train the global ``agent`` on the global ``env`` for ``episodes`` episodes.

    Relies on the module-level globals ``env``, ``agent`` and ``writer``
    being set before the call.

    Args:
        episodes: Number of episodes to run.
        eps_fixed: If true, exploration noise stays at 0 for the whole run;
            otherwise it decays linearly from 1 towards ``min_eps``.
        eps_frames: Number of frames over which the noise decays.
        min_eps: Lower bound of the exploration noise.
        num_agents: Number of agents in the environment.

    Returns:
        ``(scores, rewards, rates, rate, survival_times1, survival_times2)``:
        per-episode mean scores, per-agent reward histories, per-episode and
        per-step censoring rates, and the two per-step survival summaries
        reported by the environment.
    """
    scores = []
    rewards = [[] for _ in range(num_agents)]
    rate = []
    rates = []
    survival_times1 = []
    survival_times2 = []
    scores_window = deque(maxlen=100)
    survival_times1_window = deque(maxlen=20)
    survival_times2_window = deque(maxlen=20)
    frame = 0
    noise_start = 1
    noise = 0 if eps_fixed else noise_start

    state, observe = env.reset()
    score = np.zeros(num_agents)
    for i_episode in range(1, episodes + 1):
        while True:
            action = agent.act(state, noise)
            next_state, reward, next_observe, _delta, survival_time1, survival_time2, done = env.step(action)

            agent.step(state, action, observe, reward, next_state, done, writer)

            # reward holds one column per agent that was observed going into
            # this step, in index order.
            active = np.where(observe == 1)[0]
            for t, idx in enumerate(active):
                step_reward = reward[0, t]
                rewards[idx].append(step_reward)
                score[idx] += step_reward

            survival_times1.append(survival_time1)
            survival_times2.append(survival_time2)
            survival_times1_window.append(survival_time1)
            survival_times2_window.append(survival_time2)
            # Fraction of agents already censored before this step.
            censor_rate = 1 - np.sum(observe) / len(observe)
            rate.append(censor_rate)

            state = next_state
            observe = next_observe
            frame += 1

            if not eps_fixed:
                if frame < eps_frames:
                    noise = max(noise_start - (frame * (1 / eps_frames)), min_eps)
                else:
                    noise = min_eps

            if done.all():
                episode_score = np.mean(score)
                scores_window.append(episode_score)
                scores.append(episode_score)
                rates.append(censor_rate)
                writer.add_scalar("Average100", np.mean(scores_window), frame)
                aev = float(np.sum(survival_times1_window))
                ared = float(np.sum(survival_times2_window))
                print(f'\rEpisode {i_episode} \tAEV: {aev:.2f} \tARED: {ared:.2f} ', end="")
                state, observe = env.reset()
                score = np.zeros(num_agents)
                break

    return scores, rewards, rates, rate, survival_times1, survival_times2

###############################################################################################################
# Train All for Mode
###############################################################################################################

def train_all_for_mode(mode, seeds, save_prefix):
    """Train one agent per seed for the given reward ``mode`` and save results.

    For every seed this builds a fresh environment, TensorBoard writer and
    agent (published as the module-level globals ``env``, ``writer`` and
    ``agent`` that ``run`` reads), trains, and writes the result series as
    text files into ``f"{save_prefix}_seed_{seed}"``.

    Args:
        mode: Reward mode forwarded to ``CustomEnv``.
        seeds: Iterable of integer seeds.
        save_prefix: Prefix of the per-seed output directory.
    """
    global env, writer, agent

    # CustomEnv flags ``done`` before incrementing its step counter, so an
    # episode lasts max_steps + 1 transitions, i.e. 20 here.
    max_steps = 20 - 1
    episodes = 50
    num_peoples = [10000]
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("Selected device:", device)

    for seed in seeds:
        for num_people in num_peoples:
            print(f"Running num {num_people} with seed {seed}")
            np.random.seed(seed)
            torch.manual_seed(seed)
            env = CustomEnv(max_steps=max_steps, num_agents=num_people, mode=mode)
            env.seed(seed)
            action_size = env.action_space.shape[0]
            state_size = env.observation_space.shape

            # NOTE(review): the log directory does not include ``mode``, so
            # runs of different modes with the same size and seed share it.
            writer = SummaryWriter(f"DQNt_LL_new_1_num_{num_people}_seed_{seed}/")
            try:
                agent = DDPG_Agent(state_size=state_size,
                                   action_size=action_size,
                                   layer_size=64,
                                   n_step=1,
                                   BATCH_SIZE=32,
                                   BUFFER_SIZE=6400,
                                   LR_ACTOR=1e-4,
                                   LR_CRITIC=1e-3,
                                   TAU=1e-2,
                                   GAMMA=0.99,
                                   UPDATE_EVERY=1,
                                   device=device,
                                   seed=seed)

                t0 = time.time()
                scores, _rewards, rates, rate, survival_times1, survival_times2 = run(
                    episodes=episodes,
                    eps_fixed=False,
                    eps_frames=100,
                    min_eps=0.025,
                    num_agents=num_people,
                )
                t1 = time.time()
            finally:
                # Always release the event file, even if training fails.
                writer.close()
            print(f"Training time for num={num_people} with seed {seed}: {round((t1 - t0) / 60, 2)}min")

            subfolder_path = f"{save_prefix}_seed_{seed}"
            os.makedirs(subfolder_path, exist_ok=True)

            # One text file per series, each holding a Python-style float list.
            outputs = {
                'scores_results': scores,
                'rate_results': rate,
                'rates_results': rates,
                'survival_times1_results': survival_times1,
                'survival_times2_results': survival_times2,
            }
            for stem, series in outputs.items():
                with open(os.path.join(subfolder_path, f'{stem}_{num_people}.txt'), 'w') as file:
                    file.write(f'{[float(x) for x in series]}\n')
