import os
import shutil
import sys
import time
import datetime
import numpy as np
import torch
import torch.nn.functional as func
import gym
from tqdm import tqdm

# Raise gym's logging threshold to numeric level 40.
# NOTE(review): 40 presumably corresponds to gym's ERROR level (hiding warnings) — confirm against gym.logger.
gym.logger.set_level(40)

def evaluate_policy(policy, eval_episodes = 50, env = None):
    """Roll out ``policy`` for several episodes and report the episode returns.

    Args:
        policy: Object exposing ``select_action(obs)``; it is called with the
            current observation converted to a NumPy array.
        eval_episodes: Number of complete episodes to run.
        env: Environment exposing ``reset()`` and a ``step(action)`` returning
            an ``(obs, reward, done, info)`` 4-tuple. When omitted, the
            module-level ``env_eval`` is used, so existing callers are
            unaffected.

    Returns:
        Tuple ``(reward_mean, reward_std)`` computed over the cumulative
        reward of each episode. The same summary is also printed.
    """
    if env is None:
        # Looked up at call time: ``env_eval`` is only created further down the module.
        env = env_eval

    reward_list = []

    for _ in range(eval_episodes):
        cumulated_reward = 0.0
        obs = env.reset()
        done = False
        while not done:
            action = policy.select_action(np.array(obs))
            obs, reward, done, _ = env.step(action)
            cumulated_reward += reward
        reward_list.append(cumulated_reward)

    reward_mean = np.mean(reward_list, dtype = np.float64)
    reward_std = np.std(reward_list, dtype = np.float64)
    print("-----------------------------------------------------------")
    print("Average Reward over the Evaluation Step : {:.2f} ± {:.2f}".format(reward_mean, reward_std))
    print("-----------------------------------------------------------")

    return (reward_mean, reward_std)

# Directory receiving the saved network weights and the evaluation-reward matrix.
data_dir = "./data"
# exist_ok avoids the check-then-create race of a separate os.path.exists() test.
os.makedirs(data_dir, exist_ok = True)

# Gym environment id used for both the training and the evaluation environment.
env_name = "InvertedPendulum-v2"

seed = 1_000  # seed applied to the environments, their action spaces, torch and numpy
start_timesteps = 10_000  # initial timesteps during which actions are sampled from the action space
eval_freq = 5_000  # minimum number of timesteps between two periodic evaluations
max_timesteps = 1_000_000  # total number of environment steps in the training loop
expl_noise = 0.1  # std of the Gaussian noise added to the policy's action while training
batch_size = 100  # number of transitions sampled from the replay buffer per update iteration
discount = 0.99  # passed to update_params as gamma
tau = 0.005  # interpolation factor of the target-network soft update
policy_noise = 0.2  # std of the Gaussian noise added to the target policy's action
noise_clip = 0.5  # absolute bound applied to that target-policy noise
policy_freq = 2  # the actor and the target networks are updated every policy_freq-th iteration
num_eval_samples = 10  # episodes per evaluation at the start of and during training

class ReplayBuffer(object):
    """Bounded store of transitions with uniform random sampling.

    Transitions are appended until ``max_size`` entries are held; after that
    each new transition overwrites the oldest one (ring-buffer order).
    """

    def __init__(self, max_size = 1e6):
        """Create an empty buffer.

        Args:
            max_size: Capacity in transitions. Float values such as the
                default ``1e6`` are accepted and truncated to ``int``.

        Raises:
            ValueError: If the capacity is smaller than one transition.
        """
        # An integral capacity keeps the length test exact and ``ptr`` usable as a list index.
        capacity = int(max_size)
        if capacity < 1:
            raise ValueError("max_size must be at least 1, got {!r}".format(max_size))
        self.storage = []
        self.max_size = capacity
        self.ptr = 0  # index of the oldest entry, i.e. the next slot to overwrite once full

    def add(self, transition):
        """Store ``transition``, replacing the oldest entry when the buffer is full.

        Args:
            transition: A ``(state, next_state, action, reward, done)`` tuple,
                the layout that ``sample`` unpacks.
        """
        if len(self.storage) >= self.max_size:
            self.storage[self.ptr] = transition
            self.ptr = (self.ptr + 1) % self.max_size
        else:
            self.storage.append(transition)

    def sample(self, batch_size):
        """Draw ``batch_size`` stored transitions uniformly, with replacement.

        Args:
            batch_size: Number of transitions to draw.

        Returns:
            Tuple of NumPy arrays ``(states, next_states, actions, rewards,
            dones)``; ``rewards`` and ``dones`` are reshaped to column vectors
            of shape ``(batch_size, 1)``.

        Raises:
            ValueError: Raised by ``np.random.randint`` when the buffer is empty.
        """
        ind = np.random.randint(0, len(self.storage), size = batch_size)
        batch_states, batch_next_states, batch_actions, batch_rewards, batch_dones = [], [], [], [], []

        for i in ind:
            state, next_state, action, reward, done = self.storage[i]
            # np.asarray converts without forcing a copy; np.array(..., copy=False)
            # raises on NumPy 2.x whenever a copy is unavoidable (e.g. Python scalars).
            batch_states.append(np.asarray(state))
            batch_next_states.append(np.asarray(next_state))
            batch_actions.append(np.asarray(action))
            batch_rewards.append(np.asarray(reward))
            batch_dones.append(np.asarray(done))

        return (
            np.array(batch_states),
            np.array(batch_next_states),
            np.array(batch_actions),
            np.array(batch_rewards).reshape(-1, 1),
            np.array(batch_dones).reshape(-1, 1),
        )
               
class Actor(torch.nn.Module):
    """Policy network: three linear layers mapping a state batch to actions.

    The output passes through ``tanh`` and is multiplied by ``max_action``.
    """

    def __init__(self, state_dim, action_dim, max_action):
        """Build the 400/300 hidden-unit MLP.

        Args:
            state_dim: Size of the input feature dimension.
            action_dim: Size of the output feature dimension.
            max_action: Scale applied to the tanh-squashed output.
        """
        super().__init__()

        self.max_action = max_action
        # Layer names and creation order are part of the state_dict / seeded-init contract.
        self.layer_1 = torch.nn.Linear(state_dim, 400)
        self.layer_2 = torch.nn.Linear(400, 300)
        self.layer_3 = torch.nn.Linear(300, action_dim)

    def forward(self, x):
        """Return the scaled action for each state row in ``x``."""
        hidden = func.relu(self.layer_1(x))
        hidden = func.relu(self.layer_2(hidden))
        squashed = torch.tanh(self.layer_3(hidden))

        return self.max_action * squashed

class Critic(torch.nn.Module):
    """Two independent Q-value MLPs evaluated on the same (state, action) input.

    Layers 1-3 form the first estimator and layers 4-6 the second; each ends
    in a single scalar output per row.
    """

    def __init__(self, state_dim, action_dim):
        """Create both 400/300 hidden-unit estimators.

        Args:
            state_dim: Feature size of the state input.
            action_dim: Feature size of the action input.
        """
        super().__init__()

        joint_dim = state_dim + action_dim

        # First estimator (names and creation order fixed by the state_dict contract).
        self.layer_1 = torch.nn.Linear(joint_dim, 400)
        self.layer_2 = torch.nn.Linear(400, 300)
        self.layer_3 = torch.nn.Linear(300, 1)

        # Second estimator.
        self.layer_4 = torch.nn.Linear(joint_dim, 400)
        self.layer_5 = torch.nn.Linear(400, 300)
        self.layer_6 = torch.nn.Linear(300, 1)

    @staticmethod
    def _run_head(joint, first, second, last):
        """Apply ``first`` and ``second`` with ReLU, then ``last`` without activation."""
        hidden = func.relu(first(joint))
        hidden = func.relu(second(hidden))
        return last(hidden)

    def forward(self, x, u):
        """Return the pair of Q estimates for states ``x`` and actions ``u``."""
        joint = torch.cat([x, u], 1)

        q_first = self._run_head(joint, self.layer_1, self.layer_2, self.layer_3)
        q_second = self._run_head(joint, self.layer_4, self.layer_5, self.layer_6)

        return q_first, q_second

    def Q1(self, x, u):
        """Return only the first estimator's output for states ``x`` and actions ``u``."""
        joint = torch.cat([x, u], 1)

        return self._run_head(joint, self.layer_1, self.layer_2, self.layer_3)

class TD3(object):
    """Actor-critic agent with a twin-output critic and target copies of both networks.

    Each update iteration fits both critic outputs to a shared target built
    from the smaller of the two target-critic outputs; every
    ``policy_update_freq``-th iteration the actor is updated and both target
    networks are moved towards the online networks.
    """

    def __init__(self, state_dim, action_dim, max_action, device = None):
        """Build the online/target networks and their Adam optimizers.

        Args:
            state_dim: Feature size of an observation.
            action_dim: Feature size of an action.
            max_action: Bound used to scale the actor output and to clamp
                noisy target actions.
            device: Optional ``torch.device`` (or device string) for all
                networks and tensors. Defaults to CUDA when available,
                otherwise CPU.
        """
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.device = torch.device(device)

        self.actor = Actor(state_dim, action_dim, max_action).to(self.device)
        self.actor_target = Actor(state_dim, action_dim, max_action).to(self.device)
        self.actor_target.load_state_dict(self.actor.state_dict())
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters())

        self.critic = Critic(state_dim, action_dim).to(self.device)
        self.critic_target = Critic(state_dim, action_dim).to(self.device)
        self.critic_target.load_state_dict(self.critic.state_dict())
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters())
        self.max_action = max_action

    def _to_tensor(self, array):
        """Convert a NumPy batch to a float tensor on the agent's device."""
        return torch.Tensor(array).to(self.device)

    @staticmethod
    def _soft_update(source, target, tau):
        """Blend ``source`` parameters into ``target``: ``tau * source + (1 - tau) * target``."""
        with torch.no_grad():
            for param, target_param in zip(source.parameters(), target.parameters()):
                target_param.copy_(tau * param + (1 - tau) * target_param)

    def select_action(self, state):
        """Return the actor's action for a single observation.

        Args:
            state: NumPy array holding one observation; it is reshaped to a
                batch of one row.

        Returns:
            Flat NumPy array containing the action.
        """
        state = self._to_tensor(state.reshape(1, -1))

        # Inference only: skip autograd bookkeeping.
        with torch.no_grad():
            return self.actor(state).cpu().numpy().flatten()

    def update_params(self, replay_buffer, iterations, batch_size = 100, gamma = 0.99, tau = 0.005, policy_noise = 0.2, noise_clip = 0.5, policy_update_freq = 2):
        """Run ``iterations`` gradient updates from replayed transitions.

        Args:
            replay_buffer: Object whose ``sample(batch_size)`` returns
                ``(states, next_states, actions, rewards, dones)`` batches.
            iterations: Number of critic updates to perform.
            batch_size: Transitions sampled per iteration.
            gamma: Discount applied to the bootstrapped target value.
            tau: Interpolation factor of the target-network update.
            policy_noise: Std of the Gaussian noise added to target actions.
            noise_clip: Absolute bound applied to that noise.
            policy_update_freq: The actor and the target networks are updated
                on iterations whose index is a multiple of this value.
        """
        for it in range(iterations):
            batch_states, batch_next_states, batch_actions, batch_rewards, batch_dones = replay_buffer.sample(batch_size)
            state = self._to_tensor(batch_states)
            next_state = self._to_tensor(batch_next_states)
            action = self._to_tensor(batch_actions)
            reward = self._to_tensor(batch_rewards)
            done = self._to_tensor(batch_dones)

            # The target is a constant for the critic loss, so no graph is built for it.
            with torch.no_grad():
                next_action = self.actor_target(next_state)
                # Noise is drawn on the CPU (action-shaped) and then moved to the device.
                noise = torch.Tensor(batch_actions).normal_(0, policy_noise).to(self.device)
                noise = noise.clamp(-noise_clip, noise_clip)
                next_action = (next_action + noise).clamp(-self.max_action, self.max_action)

                target_Q1, target_Q2 = self.critic_target(next_state, next_action)
                target_Q = torch.min(target_Q1, target_Q2)
                # (1 - done) removes the bootstrap term for transitions flagged as terminal.
                target_Q = reward + (1 - done) * gamma * target_Q

            current_Q1, current_Q2 = self.critic(state, action)

            critic_loss = func.mse_loss(current_Q1, target_Q) + func.mse_loss(current_Q2, target_Q)
            self.critic_optimizer.zero_grad()
            critic_loss.backward()
            self.critic_optimizer.step()

            if it % policy_update_freq == 0:
                # The actor ascends the first critic output only.
                actor_loss = -self.critic.Q1(state, self.actor(state)).mean()
                self.actor_optimizer.zero_grad()
                actor_loss.backward()
                self.actor_optimizer.step()

                self._soft_update(self.actor, self.actor_target, tau)
                self._soft_update(self.critic, self.critic_target, tau)

    def save(self, filename, directory):
        """Write the actor and critic weights to ``<directory>/<filename>_{actor,critic}.pth``.

        Target networks and optimizer state are not saved.
        """
        torch.save(self.actor.state_dict(), '%s/%s_actor.pth' % (directory, filename))
        torch.save(self.critic.state_dict(), '%s/%s_critic.pth' % (directory, filename))

    def load(self, filename, directory):
        """Restore actor and critic weights written by ``save``.

        Tensors are mapped onto this agent's device, so a checkpoint saved on
        a GPU can be loaded on a CPU-only host. The target networks are reset
        to the loaded weights so that further training does not bootstrap from
        stale targets. Optimizer state is not restored.
        """
        self.actor.load_state_dict(torch.load('%s/%s_actor.pth' % (directory, filename), map_location = self.device))
        self.critic.load_state_dict(torch.load('%s/%s_critic.pth' % (directory, filename), map_location = self.device))
        self.actor_target.load_state_dict(self.actor.state_dict())
        self.critic_target.load_state_dict(self.critic.state_dict())

# Run on the GPU when PyTorch can see one, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Two instances of the same environment: one to train on, one for evaluate_policy.
env, env_eval = gym.make(env_name), gym.make(env_name)

# Seed each environment and its action-space sampler (training env first).
for _seeded_env in (env, env_eval):
    _seeded_env.seed(seed)
    _seeded_env.action_space.seed(seed)

# Seed the framework RNGs before any network is constructed.
torch.manual_seed(seed)
np.random.seed(seed)

# Problem dimensions are read from the training environment's spaces.
state_dim, action_dim = env.observation_space.shape[0], env.action_space.shape[0]
# Only the first component of the upper action bound is used as the action scale.
max_action = float(env.action_space.high[0])

td3 = TD3(state_dim, action_dim, max_action)

# Maps a timestep to the (mean, std) tuple returned by evaluate_policy.
reward_dict = dict()

def training_process(cur_model, replay_buffer, model_name):
    """Train ``cur_model`` on the module-level ``env`` and save the results.

    The loop runs ``max_timesteps`` environment steps. Whenever an episode
    ends, the model is updated for as many iterations as the episode had
    steps, and a periodic evaluation is run once at least ``eval_freq`` steps
    have passed since the previous one. Evaluation results are accumulated in
    the module-level ``reward_dict``; at the end the model weights and the
    reward matrix are written to ``data_dir``.

    Args:
        cur_model: Agent exposing ``select_action``, ``update_params`` and ``save``.
        replay_buffer: Buffer exposing ``add(transition)``; it is also handed
            to ``cur_model.update_params``.
        model_name: Label printed in the banner.
    """
    print("\t-----------------------------------")
    print("\t\tMODEL: {}".format(model_name.upper()))
    print("\t-----------------------------------")

    # Baseline evaluation of the untrained model.
    reward_dict.update({0: evaluate_policy(cur_model, num_eval_samples)})

    timesteps_since_eval = 0
    # Starting with done=True forces an environment reset on the first iteration.
    done = True

    for cur_timestep in tqdm(range(max_timesteps)):
        if done:
            if cur_timestep != 0:
                # One update iteration per step of the episode that just finished.
                cur_model.update_params(replay_buffer, episode_timesteps, batch_size, discount, tau, policy_noise, noise_clip, policy_freq)

            if timesteps_since_eval >= eval_freq:
                timesteps_since_eval %= eval_freq
                reward_dict.update({cur_timestep : evaluate_policy(cur_model, num_eval_samples)})

            obs = env.reset()
            done = False
            episode_timesteps = 0

        if cur_timestep < start_timesteps:
            # Warm-up phase: fill the buffer with actions sampled from the action space.
            action = env.action_space.sample()
        else:
            action = cur_model.select_action(np.array(obs))
            if expl_noise != 0:
                action = (action + np.random.normal(0, expl_noise, size = env.action_space.shape[0])).clip(env.action_space.low, env.action_space.high)

        new_obs, reward, done, _ = env.step(action)
        # An episode cut off at the environment's step limit is stored as non-terminal.
        done_bool = 0 if episode_timesteps + 1 == env._max_episode_steps else float(done)

        replay_buffer.add((obs, new_obs, action, reward, done_bool))

        obs = new_obs
        episode_timesteps += 1
        timesteps_since_eval += 1

    # Final evaluation uses evaluate_policy's default episode count.
    reward_dict.update({cur_timestep : evaluate_policy(cur_model)})
    reward_mat = np.array([(key, val[0], val[1]) for (key, val) in reward_dict.items()])

    # Save the model that was actually trained here rather than the module-level ``td3``.
    # NOTE(review): the file names still carry the literal "td3" tag regardless of model_name.
    cur_model.save(filename = "{}_td3_{}".format(env.spec.id, seed), directory = data_dir)
    np.save(os.path.join(data_dir, "{}_td3_{}_reward".format(env.spec.id, seed)), reward_mat)

# Script entry: create a replay buffer with its default capacity and train the
# module-level TD3 agent; "td3" is the label shown in the training banner.
replay_buffer = ReplayBuffer()
training_process(td3, replay_buffer, "td3")
