import os
import shutil
import sys
import time
import datetime
import numpy as np
import torch
import torch.nn.functional as func
import gym
from tqdm import tqdm

# Only let gym's logger emit messages at numeric level 40 (ERROR) and above.
_GYM_LOG_LEVEL = 40
gym.logger.set_level(_GYM_LOG_LEVEL)

def evaluate_policy(policy, eval_episodes=50):
    """Roll out ``policy`` on the module-level ``env_eval`` and summarise returns.

    Runs ``eval_episodes`` complete episodes, choosing every action with
    ``policy.select_action``. Prints a banner with the mean and standard
    deviation of the undiscounted episode returns.

    Returns:
        Tuple ``(mean, std)`` of the episode returns, computed in float64.
    """
    returns = []

    for _ in range(eval_episodes):
        total = 0.0
        obs, done = env_eval.reset(), False
        while not done:
            chosen = policy.select_action(np.array(obs))
            obs, reward, done, _ = env_eval.step(chosen)
            total += reward
        returns.append(total)

    mean_return = np.mean(returns, dtype=np.float64)
    std_return = np.std(returns, dtype=np.float64)

    banner = "-----------------------------------------------------------"
    print(banner)
    print("Average Reward over the Evaluation Step : {:.2f} ± {:.2f}".format(mean_return, std_return))
    print(banner)

    return (mean_return, std_return)

def assign_weights(m):
    """Initialise a ``torch.nn.Linear`` layer in place: Xavier-uniform weights, zero bias.

    Meant to be passed to ``module.apply``. Modules that are not
    ``torch.nn.Linear`` are left untouched. Layers built with ``bias=False``
    (where ``m.bias`` is ``None``) only get their weights initialised.
    """
    if not isinstance(m, torch.nn.Linear):
        return
    torch.nn.init.xavier_uniform_(m.weight, gain=1)
    # ``bias`` is None for Linear(..., bias=False); constant_ would fail on it.
    if m.bias is not None:
        torch.nn.init.constant_(m.bias, 0)

def soft_update(target, source, tau):
    """Polyak-average ``source`` into ``target`` in place.

    Every target parameter becomes ``(1 - tau) * target + tau * source``.
    Parameters are paired positionally via ``parameters()``.
    """
    keep = 1.0 - tau
    for dst, src in zip(target.parameters(), source.parameters()):
        blended = dst.data * keep + src.data * tau
        dst.data.copy_(blended)

def hard_update(target, source):
    """Overwrite every parameter of ``target`` with the matching one from ``source``.

    Values are copied in place, so ``target`` keeps its own storage afterwards.
    """
    paired = zip(target.parameters(), source.parameters())
    for dst, src in paired:
        dst.data.copy_(src.data)

# Output directory for checkpoints and reward histories.
data_dir = "./data"
# makedirs(exist_ok=True) avoids the check-then-create race and also creates
# any missing parent directories.
os.makedirs(data_dir, exist_ok=True)

# Gym environment id used for both the training and the evaluation env.
env_name = "InvertedPendulum-v2"

seed = 1_000                 # seeds the envs, their action spaces, torch and numpy
start_timesteps = 10_000     # steps of uniformly random actions before using the policy
eval_freq = 5_000            # evaluate once at least this many steps have elapsed
max_timesteps = 1_000_000    # total environment steps for training
exploration_noise = 0.1      # std of Gaussian noise added to policy actions while training
batch_size = 100             # minibatch size sampled from the replay buffer
discount = 0.99              # reward discount factor (gamma)
tau = 0.005                  # Polyak coefficient for target-network updates
policy_noise = 0.2           # std of the noise added to target-policy actions
noise_clip = 0.5             # that noise is clipped to [-noise_clip, noise_clip]
policy_update_freq = 2       # policy/target updates happen every N critic updates
num_eval_samples = 50        # episodes per evaluation

alpha = 0.0003               # initial entropy temperature
learning_rate = 0.0003       # Adam learning rate for critic, policy and temperature
auto_entropy_tuning = True   # passed to the agent as its entropy-tuning switch
LOG_SIG_MAX = 2              # clamp bounds for the policy's log standard deviation
LOG_SIG_MIN = -20

class ReplayBuffer(object):
    """Fixed-capacity ring buffer of ``(state, next_state, action, reward, done)`` tuples.

    Transitions are appended until ``max_size`` is reached. After that, the
    oldest slots are overwritten in round-robin order.
    """

    def __init__(self, max_size=1e6):
        self.storage = []
        # Accept the float default (1e6) but keep an int so the capacity check
        # and the modular write pointer stay in integer arithmetic.
        self.max_size = int(max_size)
        self.ptr = 0  # next slot to overwrite once the buffer is full

    def add(self, transition):
        """Store one transition tuple, overwriting the oldest one when full."""
        if len(self.storage) == self.max_size:
            self.storage[self.ptr] = transition
            self.ptr = (self.ptr + 1) % self.max_size
        else:
            self.storage.append(transition)

    def sample(self, batch_size):
        """Sample ``batch_size`` transitions uniformly with replacement.

        Returns:
            ``(states, next_states, actions, rewards, dones)`` as numpy arrays.
            The first three are stacked along a new leading batch axis.
            ``rewards`` and ``dones`` are reshaped to ``(batch_size, 1)``.

        Raises:
            ValueError: if the buffer is empty.
        """
        if not self.storage:
            raise ValueError("cannot sample from an empty ReplayBuffer")

        ind = np.random.randint(0, len(self.storage), size=batch_size)
        batch_states, batch_next_states, batch_actions, batch_rewards, batch_dones = [], [], [], [], []

        for i in ind:
            state, next_state, action, reward, done = self.storage[i]
            # np.asarray avoids a copy for existing arrays; np.array(..., copy=False)
            # raises under NumPy >= 2 for Python scalars such as reward/done.
            batch_states.append(np.asarray(state))
            batch_next_states.append(np.asarray(next_state))
            batch_actions.append(np.asarray(action))
            batch_rewards.append(np.asarray(reward))
            batch_dones.append(np.asarray(done))

        return (
            np.array(batch_states),
            np.array(batch_next_states),
            np.array(batch_actions),
            np.array(batch_rewards).reshape(-1, 1),
            np.array(batch_dones).reshape(-1, 1),
        )

class GaussianPolicy(torch.nn.Module):
    """Tanh-squashed diagonal Gaussian policy.

    Two 512-unit ReLU layers feed two linear heads. One head gives the
    pre-squash mean. The other gives the log standard deviation, clamped to
    ``[LOG_SIG_MIN, LOG_SIG_MAX]``. Sampled actions are passed through ``tanh``
    and then mapped affinely onto the action-space bounds when an
    ``action_space`` is supplied (otherwise scale 1, offset 0).
    """

    def __init__(self, state_dim, action_dim, action_space=None):
        super(GaussianPolicy, self).__init__()

        self.layer1 = torch.nn.Linear(state_dim, 512)
        self.layer2 = torch.nn.Linear(512, 512)
        self.mean_linear = torch.nn.Linear(512, action_dim)
        # Despite its name this head outputs the *log* standard deviation; the
        # attribute name is kept so state_dict keys remain unchanged.
        self.variance_linear = torch.nn.Linear(512, action_dim)
        self.epsilon = 1e-6  # keeps the log in the tanh correction finite
        self.apply(assign_weights)

        if action_space is None:
            scale, offset = torch.tensor(1.), torch.tensor(0.)
        else:
            high, low = action_space.high, action_space.low
            scale = torch.FloatTensor((high - low) / 2.)
            offset = torch.FloatTensor((high + low) / 2.)
        self.action_scale = scale
        self.action_bias = offset

    def forward(self, state):
        """Return ``(mean, log_std)`` for a batch of states."""
        hidden = func.relu(self.layer2(func.relu(self.layer1(state))))
        mean = self.mean_linear(hidden)
        log_std = torch.clamp(self.variance_linear(hidden), min=LOG_SIG_MIN, max=LOG_SIG_MAX)
        return mean, log_std

    def sample(self, state):
        """Draw a reparameterised (``rsample``) action for each state.

        Returns:
            ``(action, log_prob, mean_action)``. ``log_prob`` is summed over
            action dimensions with shape ``(batch, 1)``, and includes the
            tanh/scale change-of-variables correction. ``mean_action`` is the
            squashed and scaled distribution mean.
        """
        mean, log_std = self.forward(state)
        normal = torch.distributions.Normal(mean, log_std.exp())
        pre_tanh = normal.rsample()
        squashed = torch.tanh(pre_tanh)
        action = squashed * self.action_scale + self.action_bias

        correction = torch.log(self.action_scale * (1 - squashed.pow(2)) + self.epsilon)
        log_prob = (normal.log_prob(pre_tanh) - correction).sum(1, keepdim=True)
        mean_action = torch.tanh(mean) * self.action_scale + self.action_bias

        return action, log_prob, mean_action

    def to(self, device):
        """Move the module and its plain-tensor action scaling to ``device``."""
        self.action_scale, self.action_bias = self.action_scale.to(device), self.action_bias.to(device)
        return super(GaussianPolicy, self).to(device)
  
class QNetwork(torch.nn.Module):
    """Three independent Q-value heads over the concatenated ``(state, action)``.

    Each head is a 512-512-1 MLP with ReLU activations. ``forward`` returns
    all three estimates; ``Q1`` evaluates only the first head.
    """

    def __init__(self, state_dim, action_dim):
        super(QNetwork, self).__init__()
        in_dim = state_dim + action_dim

        # Head 1 (also used alone by Q1).
        self.linear1 = torch.nn.Linear(in_dim, 512)
        self.linear2 = torch.nn.Linear(512, 512)
        self.linear3 = torch.nn.Linear(512, 1)

        # Head 2.
        self.linear4 = torch.nn.Linear(in_dim, 512)
        self.linear5 = torch.nn.Linear(512, 512)
        self.linear6 = torch.nn.Linear(512, 1)

        # Head 3.
        self.linear7 = torch.nn.Linear(in_dim, 512)
        self.linear8 = torch.nn.Linear(512, 512)
        self.linear9 = torch.nn.Linear(512, 1)

        self.apply(assign_weights)

    @staticmethod
    def _head(xu, first, second, out):
        """Apply one ReLU MLP head (first -> second -> out) to ``xu``."""
        return out(func.relu(second(func.relu(first(xu)))))

    def forward(self, state, action):
        """Return the three Q estimates, each shaped ``(batch, 1)``."""
        xu = torch.cat([state, action], 1)
        return (
            self._head(xu, self.linear1, self.linear2, self.linear3),
            self._head(xu, self.linear4, self.linear5, self.linear6),
            self._head(xu, self.linear7, self.linear8, self.linear9),
        )

    def Q1(self, state, action):
        """Return only the first head's Q estimate."""
        xu = torch.cat([state, action], 1)
        return self._head(xu, self.linear1, self.linear2, self.linear3)

class OPAC(object):
    """Entropy-regularised actor-critic agent with a three-head critic.

    Components:
    - A tanh-Gaussian policy with an entropy temperature ``alpha``, optionally
      learned towards ``target_entropy``.
    - Clipped Gaussian noise on the target action.
    - Policy and target-network updates delayed to every
      ``policy_update_freq`` critic steps.
    - A Bellman target built from the mean of the two smallest of the three
      target-critic estimates.

    Relies on the module-level ``device`` and ``noise_clip`` settings.
    """

    def __init__(self, state_dim, action_space, gamma, tau, alpha, policy_update_freq, auto_entropy_tuning, learning_rate, policy_noise):
        """
        Args:
            state_dim: length of the flat observation vector.
            action_space: action space exposing ``shape``, ``high`` and ``low``.
                It determines the action size, action scaling and the clamp
                range for target actions.
            gamma: discount factor.
            tau: Polyak coefficient for target-network updates.
            alpha: initial entropy temperature; it stays fixed when
                ``auto_entropy_tuning`` is False.
            policy_update_freq: update the policy and targets every this many critic steps.
            auto_entropy_tuning: when True, ``alpha`` is learned each critic step.
            learning_rate: Adam learning rate for critic, policy and temperature.
            policy_noise: std of the Gaussian noise added to target actions.
        """
        self.gamma = gamma
        self.tau = tau
        self.alpha = alpha
        self.policy_update_freq = policy_update_freq
        self.auto_entropy_tuning = auto_entropy_tuning
        self.learning_rate = learning_rate
        self.policy_noise = policy_noise

        # Derive sizes and bounds from the action space actually passed in,
        # not from the module-level ``action_dim`` / ``max_action`` globals.
        action_dim = action_space.shape[0]
        self.max_action = float(action_space.high[0])

        self.critic = QNetwork(state_dim, action_dim).to(device)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=learning_rate)

        self.critic_target = QNetwork(state_dim, action_dim).to(device)
        self.critic_target.load_state_dict(self.critic.state_dict())

        # Target entropy = -(number of action dimensions).
        self.target_entropy = -torch.prod(torch.Tensor(action_space.shape).to(device)).item()
        # NOTE(review): log_alpha starts at 0 (alpha = exp(0) = 1) regardless of
        # the ``alpha`` argument, so with auto-tuning the temperature jumps from
        # ``alpha`` to about 1 after the first update. Confirm this is intended.
        self.log_alpha = torch.zeros(1, requires_grad=True, device=device)
        self.alpha_optimizer = torch.optim.Adam([self.log_alpha], lr=learning_rate)

        self.policy = GaussianPolicy(state_dim, action_dim, action_space).to(device)
        self.policy_optimizer = torch.optim.Adam(self.policy.parameters(), lr=learning_rate)

        self.policy_target = GaussianPolicy(state_dim, action_dim, action_space).to(device)
        self.policy_target.load_state_dict(self.policy.state_dict())

        hard_update(self.critic_target, self.critic)
        hard_update(self.policy_target, self.policy)

    def select_action(self, state):
        """Sample one stochastic action for a single numpy ``state``; returns a flat numpy array."""
        state = torch.Tensor(state.reshape(1, -1)).to(device)
        # Acting needs no gradients; skip building an autograd graph.
        with torch.no_grad():
            action, _, _ = self.policy.sample(state)

        return action.cpu().data.numpy().flatten()

    def update_params(self, replay_buffer, iterations, batch_size):
        """Run ``iterations`` update steps, each on a fresh minibatch of ``batch_size``.

        Each step always updates the critic, and updates the temperature when
        ``auto_entropy_tuning`` is True. Every ``policy_update_freq``-th step it
        also updates the policy and soft-updates both target networks.
        """
        for i in range(iterations):
            batch_states, batch_next_states, batch_actions, batch_rewards, batch_dones = replay_buffer.sample(batch_size)
            state = torch.Tensor(batch_states).to(device)
            next_state = torch.Tensor(batch_next_states).to(device)
            action = torch.Tensor(batch_actions).to(device)
            reward = torch.Tensor(batch_rewards).to(device)
            done = torch.Tensor(batch_dones).to(device)

            # The Bellman target is a constant for the critic loss; computing it
            # without autograd gives the same values as the former ``.detach()``.
            with torch.no_grad():
                next_action, next_action_log_pi, _ = self.policy_target.sample(next_state)

                noise = torch.Tensor(batch_actions).data.normal_(0, self.policy_noise).to(device)
                noise = noise.clamp(-noise_clip, noise_clip)
                next_action = (next_action + noise).clamp(-self.max_action, self.max_action)

                target_Q1, target_Q2, target_Q3 = self.critic_target(next_state, next_action)

                # Mean of the two smallest of the three target estimates.
                group = torch.cat([target_Q1, target_Q2, target_Q3], dim=1)
                value, _ = torch.topk(input=group, k=2, dim=1, largest=False)
                value = torch.mean(input=value, dim=1, keepdim=True)
                target_Q = value - self.alpha * next_action_log_pi

                target_Q = reward + (1 - done) * self.gamma * target_Q

            Q1, Q2, Q3 = self.critic(state, action)
            critic_loss = func.mse_loss(Q1, target_Q) + func.mse_loss(Q2, target_Q) + func.mse_loss(Q3, target_Q)

            self.critic_optimizer.zero_grad()
            critic_loss.backward()
            self.critic_optimizer.step()

            pi, log_pi, _ = self.policy.sample(state)
            if (i % self.policy_update_freq) == 0:
                policy_loss = -(self.critic.Q1(state, pi) - (self.alpha * log_pi)).mean()
                self.policy_optimizer.zero_grad()
                policy_loss.backward()
                self.policy_optimizer.step()
                soft_update(self.policy_target, self.policy, self.tau)
                soft_update(self.critic_target, self.critic, self.tau)

            # Honour the flag: previously alpha was tuned even when it was False.
            if self.auto_entropy_tuning:
                alpha_loss = -(self.log_alpha * (log_pi + self.target_entropy).detach()).mean()
                self.alpha_optimizer.zero_grad()
                alpha_loss.backward()
                self.alpha_optimizer.step()
                # Detached: alpha is a constant inside the critic/policy losses.
                self.alpha = self.log_alpha.exp().detach()

    def save(self, filename, directory):
        """Write policy and critic state_dicts to ``<directory>/<filename>_{policy,critic}.pth``."""
        torch.save(self.policy.state_dict(), '%s/%s_policy.pth' % (directory, filename))
        torch.save(self.critic.state_dict(), '%s/%s_critic.pth' % (directory, filename))

    def load(self, filename, directory):
        """Load the files written by ``save`` and re-sync the target networks.

        ``map_location`` lets GPU-saved checkpoints load on a CPU-only host.
        """
        self.policy.load_state_dict(torch.load('%s/%s_policy.pth' % (directory, filename), map_location=device))
        self.critic.load_state_dict(torch.load('%s/%s_critic.pth' % (directory, filename), map_location=device))
        # Targets are not checkpointed; start them from the loaded weights.
        hard_update(self.policy_target, self.policy)
        hard_update(self.critic_target, self.critic)

# Train on the GPU when one is visible, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Separate environments for data collection and for evaluation.
env, env_eval = gym.make(env_name), gym.make(env_name)

# Seed each env and its action space (training env first, then eval env).
for _seeded_env in (env, env_eval):
    _seeded_env.seed(seed)
    _seeded_env.action_space.seed(seed)
del _seeded_env

torch.manual_seed(seed)
np.random.seed(seed)

action_space = env.action_space
state_dim = env.observation_space.shape[0]
action_dim = action_space.shape[0]
max_action = float(action_space.high[0])

opac = OPAC(
    state_dim,
    action_space,
    gamma=discount,
    tau=tau,
    alpha=alpha,
    policy_update_freq=policy_update_freq,
    auto_entropy_tuning=auto_entropy_tuning,
    learning_rate=learning_rate,
    policy_noise=policy_noise,
)

# timestep -> (mean, std) evaluation return; filled in by training_process.
reward_dict = {}

def training_process(cur_model, replay_buffer, model_name):
    """Train ``cur_model`` on the module-level ``env`` for ``max_timesteps`` steps.

    Behaviour:
    - Actions are uniformly random for the first ``start_timesteps`` steps.
      After that they come from ``cur_model.select_action`` plus Gaussian
      exploration noise.
    - After each finished episode the model runs one update per step taken.
    - At episode boundaries, once ``eval_freq`` steps have passed, an
      evaluation result is recorded in ``reward_dict``, keyed by the number of
      steps taken.
    - At the end, a final evaluation is recorded. The model and an array of
      ``(timestep, mean, std)`` rows are then written to ``data_dir``.

    Args:
        cur_model: agent exposing ``select_action``, ``update_params`` and ``save``.
        replay_buffer: buffer receiving ``(obs, new_obs, action, reward, done)``.
        model_name: label used in the banner and in the output file names.
    """
    print("\t-----------------------------------")
    print("\t\tMODEL: {}".format(model_name.upper()))
    print("\t-----------------------------------")

    reward_dict.update({0: evaluate_policy(cur_model, num_eval_samples)})

    timesteps_since_eval = 0
    done = True

    for cur_timestep in tqdm(range(max_timesteps)):
        if done:
            # One gradient step per environment step of the finished episode.
            if cur_timestep != 0:
                cur_model.update_params(replay_buffer, episode_timesteps, batch_size)

            if timesteps_since_eval >= eval_freq:
                timesteps_since_eval %= eval_freq
                reward_dict.update({cur_timestep: evaluate_policy(cur_model, num_eval_samples)})

            obs = env.reset()
            done = False
            episode_timesteps = 0

        if cur_timestep < start_timesteps:
            action = env.action_space.sample()
        else:
            action = cur_model.select_action(np.array(obs))
            if exploration_noise != 0:
                noise = np.random.normal(0, exploration_noise, size=env.action_space.shape[0])
                action = (action + noise).clip(env.action_space.low, env.action_space.high)

        new_obs, reward, done, _ = env.step(action)
        # Hitting the time limit is not a true terminal state: keep bootstrapping.
        done_bool = 0 if episode_timesteps + 1 == env._max_episode_steps else float(done)

        replay_buffer.add((obs, new_obs, action, reward, done_bool))

        obs = new_obs
        episode_timesteps += 1
        timesteps_since_eval += 1

    # Final evaluation, keyed by the total step count. The old key
    # (last loop index) could overwrite an in-loop evaluation at that index,
    # and was undefined when max_timesteps == 0.
    reward_dict.update({max_timesteps: evaluate_policy(cur_model, num_eval_samples)})
    reward_mat = np.array([(key, val[0], val[1]) for (key, val) in reward_dict.items()])

    # Save the model that was actually trained (previously the global ``opac``).
    cur_model.save(filename="{}_{}_{}".format(env.spec.id, model_name, seed), directory=data_dir)
    np.save(os.path.join(data_dir, "{}_{}_{}_reward".format(env.spec.id, model_name, seed)), reward_mat)

# Fresh replay buffer, then train and save the module-level agent.
replay_buffer = ReplayBuffer()
training_process(cur_model=opac, replay_buffer=replay_buffer, model_name="opac")
