###Reinforcement Learning for Recurrent Event Data
###Code with proposed Q-function

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import torch.optim as optim
import torch.nn.functional as F
from torch.nn.utils import clip_grad_norm_
import random
import gym
import math
from torch.utils.tensorboard import SummaryWriter
from collections import deque, namedtuple
import time

import os

os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
from torch.autograd import Variable
import pandas as pd
import lifelines
from lifelines import CoxPHFitter
from lifelines import KaplanMeierFitter
from scipy.optimize import minimize
from sklearn import linear_model
from sklearn.linear_model import LinearRegression


# from sklearn.linear_model import LogisticRegression
######################################## Survival Analysis ########################################
def Survival_interval_T(time, X, U, V, A):
    """Fit a log-logistic AFT model on interval-censored data and predict survival.

    The design matrix is built from the covariates ``X`` (named X1..Xp), the
    treatment indicator ``A`` and the interactions ``X * A``. A lifelines
    ``LogLogisticAFTFitter`` is fitted with ``fit_interval_censoring`` using
    ``[U, V]`` as each subject's censoring interval, and the fitted survival
    function is evaluated for the same subjects at ``time``.

    Params
    ======
        time: time point(s) at which the survival function is evaluated.
        X (np.ndarray): covariates, one row per subject, shape (n, p).
        U (np.ndarray): lower bounds of the censoring intervals (n values).
        V (np.ndarray): upper bounds of the censoring intervals (n values;
            ``np.inf`` marks a right-censored subject).
        A (np.ndarray): treatment indicators (n values).

    Returns
    =======
        np.ndarray: ``predict_survival_function(...).values`` — one row per
        requested time, one column per subject.
    """
    X = np.asarray(X)
    A = np.reshape(A, (A.shape[0], 1))
    x_cols = [f'X{j + 1}' for j in range(X.shape[1])]

    if np.all(A == 0):
        print("A is constant 0!")
        # A and X * A are all-zero columns; only the covariates carry information.
        design, columns = X, x_cols
    elif np.all(A == 1):
        print("A is constant 1!")
        # Here X * A == X, so adding the interaction columns would duplicate the
        # covariates and make the design singular; the constant treatment effect
        # is absorbed by the model intercept.
        design, columns = X, x_cols
    else:
        design = np.hstack([X, A, np.multiply(X, A)])
        columns = x_cols + ['A'] + [f'{c}*A' for c in x_cols]

    # Covariates used for prediction; the fit frame additionally carries the bounds.
    pre_data = pd.DataFrame(design, columns=columns)
    data = pre_data.copy()
    data['U'] = np.asarray(U).flatten()
    data['V'] = np.asarray(V).flatten()

    cph = lifelines.fitters.log_logistic_aft_fitter.LogLogisticAFTFitter()
    cph.fit_interval_censoring(data, lower_bound_col='U', upper_bound_col='V')

    st_estimate = cph.predict_survival_function(pre_data, times=time)

    return st_estimate.values


class QR_DQN(nn.Module):
    """Two-layer Q-network producing non-positive action values.

    Each state vector is passed through a ReLU hidden layer and a linear
    output layer; the output is then mapped through ``-softplus`` so every
    action value is <= 0 (presumably to match the log-survival rewards used
    for training — confirm).
    """

    def __init__(self, state_size, action_size, layer_size, seed):
        super().__init__()
        # Seed torch before the layers are created so their initial weights are reproducible.
        self.seed = torch.manual_seed(seed)
        n_features = state_size[1]
        self.input_shape = n_features
        self.action_size = action_size
        self.head_1 = nn.Linear(n_features, layer_size)
        self.ff_2 = nn.Linear(layer_size, action_size)

    def forward(self, x):
        """Return one value per action for each input row; all values are <= 0."""
        hidden = torch.relu(self.head_1(x))
        # softplus(.) >= 0, so negating it yields non-positive action values.
        return -F.softplus(self.ff_2(hidden))


def normalization(data):
    """Min-max scale ``data`` to the interval [0, 1].

    Params
    ======
        data (array_like): values to rescale; lists and tuples are accepted.

    Returns
    =======
        np.ndarray: ``(data - min) / (max - min)`` as floats. When all values
        are equal (zero range) an all-zero array is returned instead of the
        NaNs a 0/0 division would produce.
    """
    values = np.asarray(data, dtype=float)
    lo = np.min(values)
    span = np.max(values) - lo
    if span == 0:
        # Constant input: avoid 0/0 and map everything to the lower bound.
        return np.zeros_like(values)
    return (values - lo) / span


class CustomEnv(gym.Env):
    """Simulated multi-agent environment with interval-censored event times.

    Each of ``num_agents`` subjects has a 2-dimensional state and receives a
    binary action per step. An exponential event time is drawn per subject,
    converted into a censoring interval [L, R] using random inspection times
    (U, V), and the log of the survival probability fitted by
    ``Survival_interval_T`` is returned as the reward for the subjects that
    are still observed. A subject whose event time exceeds V stops being
    observed for the rest of the episode.
    """

    def __init__(self, max_steps, num_agents):
        super(CustomEnv, self).__init__()

        self.action_space = gym.spaces.Discrete(2)
        self.observation_space = gym.spaces.Box(low=-np.inf, high=np.inf, shape=(num_agents, 2), dtype=np.float32)
        self.num_agents = num_agents

        self.max_steps = max_steps
        self.current_step = 0
        self.stage = 0
        # Initialised here so reset() works even when seed() was never called.
        self.seed_value = None

    def seed(self, seed=None):
        """Store ``seed``; reset() re-seeds NumPy's global RNG with it on every call."""
        self.seed_value = seed

    def reset(self):
        """Start a new episode and return ``(state, observe)``.

        With a seed set, NumPy's global RNG is re-seeded on every call, so each
        episode starts from the same N(0, 1) initial states and the same random
        stream. Without a seed the initial states are all zero. Every agent
        starts observed.
        """
        if self.seed_value is not None:
            np.random.seed(self.seed_value)
            self.state = np.random.randn(self.num_agents, 2)
        else:
            # The shape must be a tuple: np.zeros(n, 2) reads 2 as a dtype and raises.
            # NOTE(review): the seeded branch draws N(0, 1) states; confirm zeros are intended here.
            self.state = np.zeros((self.num_agents, 2))
        self.current_step = 0
        # Reset the stage before it is used to draw censorV below.
        self.stage = 0
        self.censor = np.random.uniform(low=0, high=10 * 7, size=(self.num_agents))
        self.censorU = np.random.uniform(0.5, 1.5, size=(self.num_agents))
        self.censorV = self.censorU + np.random.uniform(3, 4.5 - self.stage * 0.1, size=(self.num_agents))
        self.observe = np.ones(self.num_agents)

        return self.state, self.observe

    def step(self, action):
        """Apply one action per agent (shape (num_agents, 1)) and advance one step.

        Returns
        =======
            (next_state, reward, next_observe, real_actions, Delta, ratio, done):
            ``reward`` holds log survival probabilities for the agents observed at
            this step, ``real_actions`` the better action for each of them,
            ``Delta`` their event indicators (1 = event time <= V), ``ratio`` the
            counts of left/interval/right-censored outcomes, and ``done`` is True
            on the call made once ``current_step`` has reached ``max_steps``.
        """
        reward, real_actions, Delta, ratio = self.reward_function(action, self.state, self.observe)
        state_next = self.state_function(self.state, action)
        next_observe = self.observe_function(self.observe, Delta)
        self.state = state_next
        self.observe = next_observe
        # Checked before incrementing, so an episode lasts max_steps + 1 calls.
        done = self.current_step >= self.max_steps

        self.current_step += 1

        return state_next, reward, next_observe, real_actions, Delta, ratio, done

    def observe_function(self, observe, Delta):
        """Return the next observation mask: observed agents with Delta == 0 are dropped.

        ``Delta`` is aligned with the currently observed agents (the nonzero
        entries of ``observe``) in index order.
        """
        next_observe = observe.copy()
        indices_matrix = np.nonzero(next_observe)[0]
        next_observe[indices_matrix[Delta == 0]] = 0
        return next_observe

    def state_function(self, state, action):
        """Return the next states: ``state @ diag(c, -c)`` plus N(0, 0.25^2) noise.

        For agent i with action a_i, c_i = 3/4 * (2 * a_i - 1), so the first
        coordinate is scaled by c_i and the second by -c_i.
        """
        A = np.array(action)
        coef = 3 / 4 * (2 * A[:, 0] - 1)
        # Row-wise equivalent of state[i] @ [[c_i, 0], [0, -c_i]] for all agents at once.
        next_state = state * np.column_stack([coef, -coef])
        next_state = next_state + np.random.normal(0, 0.25, size=state.shape)
        return next_state

    def reward_function(self, action, state, observe):
        """Simulate event times and compute rewards for the currently observed agents.

        Returns
        =======
            reward: log of the survival probability at t = 2 predicted by
                ``Survival_interval_T`` for the observed agents.
            real_action: per observed agent, the action (0 or 1) whose
                exponential scale (mean event time before clipping) is larger.
            Delta: per observed agent, 1 where the event time is <= V.
            counts: [#(T <= U), #(U < T <= V), #(T > V)] among observed agents.
        """
        action = np.array(action)[:, 0]
        beta_1 = np.array([0, 2, -1])
        state_1 = np.insert(state, 0, 1, axis=1)  # prepend an intercept column

        # Exponential scale exp((2a - 1) * (2 * x1 - x2 - 0.25)); event times are capped at 4.
        lambda_param = np.exp((2 * action - 1) * (np.dot(state_1, beta_1) - 0.25))
        T = np.random.exponential(scale=lambda_param)
        T = np.clip(T, None, 4)

        # Fresh inspection times for this stage; the range of V - U shrinks as the stage grows.
        self.censorU = np.random.uniform(0.5, 1.5, size=(self.num_agents))
        self.censorV = self.censorU + np.random.uniform(3, 4.5 - self.stage * 0.1, size=(self.num_agents))
        self.stage += 1

        Delta = (T <= self.censorV).astype(int)
        indices = np.where(observe == 1)[0]
        Delta = Delta[indices]
        state = state[indices]
        action = action[indices]
        state_1 = state_1[indices]

        T_valid = T[indices]
        U_valid = self.censorU[indices]
        V_valid = self.censorV[indices]

        # Classify each event time as left-censored (T <= U), interval-censored
        # (U < T <= V) or right-censored (T > V), and build the [L, R] bounds.
        left = T_valid <= U_valid
        inside = ~left & (T_valid <= V_valid)
        right = ~(left | inside)
        # 0.1 instead of 0 as the lower bound of left-censored intervals
        # (presumably the fitter needs a strictly positive bound — confirm).
        L = np.where(left, 0.1, np.where(inside, U_valid, V_valid))
        R = np.where(left, U_valid, np.where(inside, V_valid, np.inf))
        counts = [int(np.count_nonzero(left)), int(np.count_nonzero(inside)), int(np.count_nonzero(right))]

        # Exponential scales under action 0 and action 1 respectively.
        reward_1 = np.exp(-np.dot(state_1, beta_1) + 0.25)
        reward_2 = np.exp(np.dot(state_1, beta_1) - 0.25)
        real_action = np.argmax([reward_1, reward_2], axis=0)

        eval_time = 2
        reward = Survival_interval_T(eval_time, np.array(state), L, R, np.array(action))

        # Rewards are log survival probabilities.
        reward = np.log(reward)

        return reward, real_action, Delta, counts

    def render(self, mode='human'):
        pass

    def close(self):
        pass


class ReplayBuffer:
    """Fixed-size buffer to store experience tuples (optionally as n-step returns)."""

    def __init__(self, buffer_size, batch_size, device, seed, gamma, n_step=1):
        """Initialize a ReplayBuffer object.
        Params
        ======
            buffer_size (int): maximum size of buffer
            batch_size (int): size of each training batch
            device: torch device that sampled tensors are moved to
            seed (int): random seed for Python's global ``random`` module
            gamma (float): discount factor for n-step returns
            n_step (int): number of transitions aggregated into one experience
        """
        self.device = device
        self.memory = deque(maxlen=buffer_size)
        self.batch_size = batch_size
        self.experience = namedtuple("Experience",
                                     field_names=["state", "action", "observe", "reward", "next_state", "done"])
        # random.seed returns None; the call's effect is seeding the global RNG.
        self.seed = random.seed(seed)
        self.gamma = gamma
        self.n_step = n_step
        self.n_step_buffer = deque(maxlen=self.n_step)

    def add(self, state, action, observe, reward, next_state, done):
        """Add a new experience to memory.

        Transitions accumulate in ``n_step_buffer``; once it holds ``n_step`` of
        them, their discounted return is stored as a single experience.
        NOTE(review): ``n_step_buffer`` is not cleared when ``done`` is True, so
        with n_step > 1 a stored return can span two episodes.
        """
        self.n_step_buffer.append((state, action, observe, reward, next_state, done))
        if len(self.n_step_buffer) == self.n_step:
            state, action, observe, reward, next_state, done = self.calc_multistep_return()
            e = self.experience(state, action, observe, reward, next_state, done)
            self.memory.append(e)

    def calc_multistep_return(self):
        """Return the oldest buffered transition with its reward replaced by the n-step return.

        The return is sum_k gamma**k * r_k over the buffered rewards; next_state
        and done come from the newest transition.
        NOTE(review): if rewards are per-observed-agent arrays whose length
        changes between steps, this sum cannot broadcast for n_step > 1 — verify.
        """
        Return = 0
        for idx in range(self.n_step):
            Return += self.gamma ** idx * self.n_step_buffer[idx][3]

        first, last = self.n_step_buffer[0], self.n_step_buffer[-1]
        return first[0], first[1], first[2], Return, last[4], last[5]

    def sample(self):
        """Randomly sample a batch of experiences from memory.

        Stored rewards only cover the agents whose ``observe`` flag is 1; they are
        scattered back into a dense per-agent vector with 0 for the other agents.

        Returns
        =======
            (states, actions, observes, rewards, next_states, dones) as tensors on
            ``self.device``; ``dones`` has shape (batch_size, 1).
        """
        experiences = random.sample(self.memory, k=self.batch_size)
        experiences = [e for e in experiences if e is not None]

        states = torch.from_numpy(np.stack([e.state for e in experiences])).float().to(self.device)
        observes = torch.from_numpy(np.stack([e.observe for e in experiences])).long().to(self.device)
        actions = torch.from_numpy(np.stack([e.action for e in experiences])).long().to(self.device)

        rewards = []
        for e in experiences:
            observe_indices = np.where(e.observe == 1)[0]
            reward = np.zeros_like(e.observe, dtype=np.float32)
            # Assumes e.reward[0] holds one value per observed agent, in index order.
            reward[observe_indices] = e.reward[0, :len(observe_indices)]
            rewards.append(reward)
        # Use the buffer's own device (not a module-level global).
        rewards = torch.from_numpy(np.stack(rewards)).float().to(self.device)

        next_states = torch.from_numpy(np.stack([e.next_state for e in experiences])).float().to(self.device)
        dones = torch.from_numpy(np.vstack([e.done for e in experiences]).astype(np.uint8)).float().to(self.device)

        return (states, actions, observes, rewards, next_states, dones)

    def __len__(self):
        """Return the current size of internal memory."""
        return len(self.memory)


class DQN_Agent():
    """Deep Q-learning agent that acts for many agents at once with a shared Q-network.

    Holds a local and a target ``QR_DQN``, an Adam optimizer on the local
    network and a ``ReplayBuffer``; the target network follows the local one
    through soft updates.
    """

    def __init__(self,
                 state_size,
                 action_size,
                 Network,
                 layer_size,
                 n_step,
                 BATCH_SIZE,
                 BUFFER_SIZE,
                 LR,
                 TAU,
                 GAMMA,
                 UPDATE_EVERY,
                 device,
                 seed):
        """Initialize an Agent object.

        Params
        ======
            state_size (tuple): observation shape; QR_DQN reads ``state_size[1]``
                as the per-agent input width
            action_size (int): number of discrete actions
            Network (str): network type label; currently unused (a QR_DQN is always built)
            layer_size (int): size of the hidden layer
            n_step (int): number of transitions per stored (n-step) experience
            BATCH_SIZE (int): size of the training batch
            BUFFER_SIZE (int): size of the replay memory
            LR (float): learning rate
            TAU (float): tau for soft updating the network weights
            GAMMA (float): discount factor
            UPDATE_EVERY (int): learn once every UPDATE_EVERY calls to step()
            device (str or torch.device): device used for the computation
            seed (int): random seed
        """
        self.state_size = state_size
        self.action_size = action_size
        # random.seed returns None; the call seeds Python's global RNG.
        self.seed = random.seed(seed)
        self.device = device
        self.TAU = TAU
        self.GAMMA = GAMMA
        self.UPDATE_EVERY = UPDATE_EVERY
        self.BATCH_SIZE = BATCH_SIZE
        self.Q_updates = 0
        self.n_step = n_step

        # Same seed for both networks, so they start with identical weights.
        self.qnetwork_local = QR_DQN(state_size, action_size, layer_size, seed).to(self.device)
        self.qnetwork_target = QR_DQN(state_size, action_size, layer_size, seed).to(self.device)

        self.optimizer = optim.Adam(self.qnetwork_local.parameters(), lr=LR)

        # Replay memory
        self.memory = ReplayBuffer(BUFFER_SIZE, BATCH_SIZE, self.device, seed, self.GAMMA, n_step)

        # Initialize time step (for updating every UPDATE_EVERY steps)
        self.t_step = 0

    def step(self, state, action, observe, reward, next_state, done, writer):
        """Store a transition and, every UPDATE_EVERY calls, learn from a sampled batch.

        Learning starts once the memory holds more than BATCH_SIZE experiences;
        each loss is logged to ``writer`` under "Q_loss".
        """
        self.memory.add(state, action, observe, reward, next_state, done)

        self.t_step = (self.t_step + 1) % self.UPDATE_EVERY
        if self.t_step == 0:
            # If enough samples are available in memory, get random subset and learn
            if len(self.memory) > self.BATCH_SIZE:
                experiences = self.memory.sample()
                loss = self.learn(experiences)
                self.Q_updates += 1
                writer.add_scalar("Q_loss", loss, self.Q_updates)

    def act(self, state, eps=0.):
        """Return epsilon-greedy actions for every agent in ``state``.

        Params
        ======
            state (array_like): current states, one row per agent
            eps (float): probability of taking uniformly random actions for the
                whole batch instead of the greedy ones

        Returns
        =======
            action (torch.Tensor): shape (num_agents, 1), always on the CPU so the
                environment can convert it with ``np.array``
            max_value (torch.Tensor): shape (num_agents,), Q-value of the chosen
                action, on ``self.device``
        """
        state = torch.from_numpy(np.array(state)).float().unsqueeze(0).to(self.device)

        self.qnetwork_local.eval()
        with torch.no_grad():
            action_values = self.qnetwork_local(state)
        self.qnetwork_local.train()
        action_values = action_values.squeeze(0)

        # A single random draw decides between greedy and random for all agents.
        if random.random() > eps:
            max_value, action = torch.max(action_values, dim=1)
            return action.unsqueeze(1).cpu(), max_value

        action = torch.randint(low=0, high=action_values.shape[1], size=(action_values.shape[0], 1))
        print("eps action")
        # Squeeze to (num_agents,) so both branches return the same shape.
        max_value = action_values.gather(1, action.to(self.device)).squeeze(1)
        return action, max_value

    def learn(self, experiences):
        """Update value parameters using a sampled batch, then soft-update the target.

        Only entries whose ``observes`` flag is set contribute. The target is
        r + GAMMA * max_a Q_target(s', a).

        Params
        ======
            experiences (Tuple[torch.Tensor]): (states, actions, observes,
                rewards, next_states, dones) as returned by ReplayBuffer.sample

        Returns
        =======
            np.ndarray: the scalar MSE loss.
        """
        self.optimizer.zero_grad()
        states, actions, observes, rewards, next_states, dones = experiences
        mask = observes.bool()
        states = states[mask]
        next_states = next_states[mask]
        actions = actions[mask]
        rewards = rewards[mask]

        # NOTE(review): ``dones`` is unused, so the last transition of an episode
        # still bootstraps from next_states — fine for a pure time limit; confirm.
        with torch.no_grad():
            Q_targets_next = self.qnetwork_target(next_states).max(dim=1, keepdim=True)[0]
        Q_targets = rewards.unsqueeze(1) + self.GAMMA * Q_targets_next
        Q_expected = self.qnetwork_local(states).gather(1, actions)
        loss = F.mse_loss(Q_expected, Q_targets)
        loss.backward()
        self.optimizer.step()

        # ------------------- update target network ------------------- #

        self.soft_update(self.qnetwork_local, self.qnetwork_target)
        return loss.detach().cpu().numpy()

    def soft_update(self, local_model, target_model):
        """Soft update model parameters.
        θ_target = τ*θ_local + (1 - τ)*θ_target
        Params
        ======
            local_model (PyTorch model): weights will be copied from
            target_model (PyTorch model): weights will be copied to
        τ is taken from ``self.TAU``.
        """
        for target_param, local_param in zip(target_model.parameters(), local_model.parameters()):
            target_param.data.copy_(self.TAU * local_param.data + (1.0 - self.TAU) * target_param.data)




def run(episodes=1000, eps_fixed=False, eps_frames=1e6, min_eps=0.01, num_agents=1000):
    """Train the module-level ``agent`` on the module-level ``env`` with deep Q-learning.

    The globals ``env``, ``agent`` and ``writer`` must be defined before calling.

    Params
    ======
        episodes (int): number of training episodes
        eps_fixed (bool): if True, act greedily (epsilon = 0) throughout;
            otherwise epsilon is annealed
        eps_frames (float): number of frames over which epsilon decreases
            linearly from 1 to ``min_eps``; afterwards epsilon is -1 (always greedy)
        min_eps (float): floor of epsilon during annealing
        num_agents (int): number of agents in the environment

    Returns
    =======
        (scores, accuracys, rewards, rates, rate, ratios): per-episode mean
        scores, per-step action accuracies, per-agent reward histories, the
        censoring rate at the last step of each episode, per-step censoring
        rates and per-step interval-type counts.
    """
    scores = []
    accuracys = []
    rewards = [[] for _ in range(num_agents)]
    rate = []
    rates = []
    ratios = []
    scores_window = deque(maxlen=100)  # mean scores of the last 100 episodes
    output_history = []
    frame = 0
    eps = 0 if eps_fixed else 1
    eps_start = 1
    state, observe = env.reset()
    score = np.zeros(num_agents)

    for i_episode in range(1, episodes + 1):
        done = False
        while not done:
            action, _ = agent.act(state, eps)
            next_state, reward, next_observe, real_actions, Delta, ratio, done = env.step(action)
            agent.step(state, action, observe, reward, next_state, done, writer)

            # Agents observed during this step; reward / real_actions are aligned with them.
            observed = np.where(observe == 1)[0]
            chosen = np.array(action).squeeze(1)[observed]
            accuracy = 1 - np.count_nonzero(chosen - real_actions) / len(real_actions)
            step_rewards = [reward[0, t] for t in range(len(observed))]
            for idx, value in zip(observed, step_rewards):
                rewards[idx].append(value)
            accuracys.append(accuracy)
            ratios.append(ratio)
            print("accuracy:", accuracy)
            censor_rate = 1 - np.sum(observe) / len(observe)
            rate.append(censor_rate)
            state, observe = next_state, next_observe
            for idx, value in zip(observed, step_rewards):
                score[idx] = value + score[idx]
            frame += 1

            # Linear annealing down to min_eps until eps_frames, then greedy (eps = -1).
            if eps_fixed == False:
                eps = max(eps_start - (frame * (1 / eps_frames)), min_eps) if frame < eps_frames else -1

            if done:
                episode_score = np.mean(score)
                scores_window.append(episode_score)
                scores.append(episode_score)
                rates.append(censor_rate)
                writer.add_scalar("Average100", np.mean(scores_window), frame)
                output_history.append(np.mean(scores_window))
                print('\rEpisode {} \tScores: {:.2f} \tCensor Rate: {:.2f} '.format(i_episode, float(episode_score),
                                                                                    censor_rate), end="")
                state, observe = env.reset()
                score = np.zeros(num_agents)

    return scores, accuracys, rewards, rates, rate, ratios

import warnings

if __name__ == "__main__":

    warnings.filterwarnings('ignore')

    seeds = [1]
    # CustomEnv.step reports done once current_step >= max_steps (checked before
    # incrementing), so 10 - 1 = 9 yields episodes of 10 steps.
    max_steps = 10 - 1
    episodes = 50
    num_peoples = [50, 100, 150, 200, 500, 1000, 5000]

    # device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    device = torch.device("cpu")
    print("Selected device:", device)

    for seed in seeds:
        print("Seed:", seed)
        for num_people in num_peoples:
            print(f"Running num {num_people} with seed {seed}")
            # Re-seed for every configuration so each run is reproducible on its own.
            np.random.seed(seed)
            torch.manual_seed(seed)
            env = CustomEnv(max_steps=max_steps, num_agents=num_people)
            env.seed(seed)
            action_size = env.action_space.n
            state_size = env.observation_space.shape

            # run() reads env, agent and writer as module-level globals.
            writer = SummaryWriter(f"runs/DQN_LL_new_1_num_{num_people}_seed_{seed}/")
            agent = DQN_Agent(state_size=state_size,
                              action_size=action_size,
                              Network="DDQN",
                              layer_size=128,
                              n_step=1,
                              BATCH_SIZE=64,
                              BUFFER_SIZE=6400,
                              LR=1e-3,
                              TAU=1e-4,
                              GAMMA=0.99,
                              UPDATE_EVERY=1,
                              device=device,
                              seed=seed)

            eps_fixed = False
            t0 = time.time()
            scores, accuracys, rewards, rates, rate, ratios = run(episodes=episodes, eps_fixed=eps_fixed,
                                                                  eps_frames=100, min_eps=0.025,
                                                                  num_agents=num_people)
            t1 = time.time()
            print(f"Training time for num={num_people} with seed={seed}: {round((t1 - t0) / 60, 2)}min")

            # Create the per-seed subfolder that stores the results.
            subfolder_path = os.path.join("results", f"seed_{seed}")
            os.makedirs(subfolder_path, exist_ok=True)

            # NOTE(review): the 'rates_results' file stores `ratios` (interval-type
            # counts), not `rates` (censoring rates) — confirm this is intended.
            result_files = (
                (f'scores_results_{num_people}.txt', scores),
                (f'accuracys_results_{num_people}.txt', accuracys),
                (f'rates_results_{num_people}.txt', ratios),
            )
            for file_name, values in result_files:
                with open(os.path.join(subfolder_path, file_name), 'w') as file:
                    file.write(f'{values}\n')

            writer.close()
