import sys
sys.path.append("..")
import os
import random
import time

import gym
import pybulletgym
import numpy as np

import torch

from ddpg_ndqfn import DDPG
from utils.noise import OrnsteinUhlenbeckActionNoise
from utils.replay_memory import ReplayMemory, Transition
from wrappers.normalized_actions import NormalizedActions

import pickle


# ---------------------------------------------------------------------------
# Experiment configuration
# ---------------------------------------------------------------------------
# All settings are hard-coded here; this script reads no command-line args.

# Core hyperparameters. gamma, tau, the hidden layer sizes, the replay size and
# the batch size were chosen after the original paper (per the original note).
Gamma = 0.99
Tau = 0.001
Hidden_size = [400, 300]   # converted to a tuple before being handed to DDPG
Replay_size = 10 ** 5      # capacity passed to ReplayMemory
Batch_size = 64            # minibatch size sampled per update

# Passed to DDPG after the action space; presumably the number of quantiles
# used by the quantile-based critic — confirm in ddpg_ndqfn.
Num_quant = 32

# Scale (sigma) of the Ornstein-Uhlenbeck exploration noise.
Noise_stddev = 0.1

# Run bookkeeping.
Timesteps = 10 ** 6        # total environment steps to train for
Save_dir = "./saved_models/"
Log_dir = "./log/"
Env = "InvertedPendulumPyBulletEnv-v0"
Seed = 0

# Number of noise-free episodes per evaluation phase.
N_test_cycles = 10


# Prefer the GPU when CUDA is present, fall back to the CPU otherwise.
device = (torch.device("cuda") if torch.cuda.is_available()
          else torch.device("cpu"))
print("Using {}".format(device))

def _to_batch_tensor(observation):
    """Return ``observation`` as a float tensor with a leading batch dim of 1, on ``device``.

    Equivalent to ``torch.Tensor([observation]).to(device)``, but converts via a
    single NumPy array first. Building a tensor from a Python list of ndarrays
    is very slow in PyTorch (and warns about it in recent versions).
    """
    return torch.as_tensor(np.asarray([observation]),
                           dtype=torch.get_default_dtype()).to(device)


def _evaluate_policy(agent, env, n_episodes):
    """Run ``n_episodes`` episodes without exploration noise and collect their returns.

    Args:
        agent: the DDPG agent; ``calc_action(state)`` is called with no noise process.
        env: the environment; it is reset at the start of every episode.
        n_episodes: number of evaluation episodes to run.

    Returns:
        list with the undiscounted (summed) reward of each episode.
    """
    episode_returns = []
    for _ in range(n_episodes):
        state = _to_batch_tensor(env.reset())
        episode_return = 0
        done = False
        while not done:
            action = agent.calc_action(state)  # Selection without noise
            next_state, reward, done, _ = env.step(action.cpu().numpy()[0])
            episode_return += reward
            state = _to_batch_tensor(next_state)
        episode_returns.append(episode_return)
    return episode_returns


if __name__ == "__main__":
    # Train a DDPG agent on `Env` for `Timesteps` environment steps, evaluating
    # periodically, checkpointing the best model, and pickling episode returns.

    # Directory where the agent saves / loads its checkpoints.
    checkpoint_dir = Save_dir + Env

    # Create the env, wrapped in NormalizedActions (presumably rescales actions
    # into the env's action bounds — see wrappers.normalized_actions).
    env = gym.make(Env)
    env = NormalizedActions(env)

    # Reward threshold registered for the env, or np.inf when none is defined.
    # NOTE(review): not used below — checkpointing compares against the best
    # recent test reward instead of this threshold.
    reward_threshold = gym.spec(Env).reward_threshold
    if reward_threshold is None:
        reward_threshold = np.inf

    # Seed every RNG we can reach for reproducibility.
    env.seed(Seed)
    torch.manual_seed(Seed)
    np.random.seed(Seed)
    random.seed(Seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(Seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    # Define and build the DDPG agent.
    hidden_size = tuple(Hidden_size)
    agent = DDPG(Gamma,
                 Tau,
                 hidden_size,
                 env.observation_space.shape[0],
                 env.action_space,
                 Num_quant,
                 checkpoint_dir=checkpoint_dir
                 )

    # Experience replay buffer.
    memory = ReplayMemory(Replay_size)

    # Ornstein-Uhlenbeck exploration noise, one dimension per action component.
    nb_actions = env.action_space.shape[-1]
    ou_noise = OrnsteinUhlenbeckActionNoise(mu=np.zeros(nb_actions),
                                            sigma=float(Noise_stddev) * np.ones(nb_actions))

    # Counters and running statistics.
    eval_interval = 10000  # env steps between evaluation phases
    timestep = 1
    rewards, policy_losses, value_losses, mean_test_rewards = [], [], [], []
    episode = 0
    # Index of the next evaluation phase; it is due once
    # timestep >= next_eval * eval_interval. Starting at 0 means the first
    # evaluation runs right after the first training episode.
    next_eval = 0
    # -inf (not 0) so the first evaluation is always checkpointed, even on
    # environments whose returns are negative.
    best_test_reward = -np.inf

    print('Train agent on {} env'.format(env.unwrapped.spec.id))
    print('Doing {} timesteps'.format(Timesteps))

    while timestep <= Timesteps:
        ou_noise.reset()
        episode_return = 0
        episode_value_loss = 0
        episode_policy_loss = 0

        state = _to_batch_tensor(env.reset())
        while True:
            action = agent.calc_action(state, ou_noise)  # noisy action for exploration
            next_state, reward, done, _ = env.step(action.cpu().numpy()[0])
            timestep += 1
            episode_return += reward

            # Stored as "mask" but holds the done flag (1.0 when the episode
            # ended); presumably update_params uses it to stop bootstrapping
            # past terminal states — confirm in ddpg_ndqfn.
            mask = torch.Tensor([done]).to(device)
            reward = torch.Tensor([reward]).to(device)
            next_state = _to_batch_tensor(next_state)

            memory.push(state, action, mask, next_state, reward)

            state = next_state

            # Start learning once the buffer holds more than one minibatch.
            if len(memory) > Batch_size:
                transitions = memory.sample(Batch_size)
                # Transpose the sampled transitions into a single Transition of
                # per-field tuples (assumes sample() returns Transition tuples).
                batch = Transition(*zip(*transitions))

                # Update actor and critic according to the batch.
                value_loss, policy_loss = agent.update_params(batch)

                episode_value_loss += value_loss
                episode_policy_loss += policy_loss

            if done:
                break

        rewards.append(episode_return)
        value_losses.append(episode_value_loss)
        policy_losses.append(episode_policy_loss)

        # Evaluate without noise whenever training has passed the next multiple
        # of eval_interval steps.
        if timestep >= eval_interval * next_eval:
            # Schedule the next phase relative to the current step, so an
            # episode that crosses several boundaries does not trigger
            # back-to-back catch-up evaluations.
            next_eval = timestep // eval_interval + 1
            test_rewards = _evaluate_policy(agent, env, N_test_cycles)
            mean_test_rewards.append(np.mean(test_rewards))

            print("Episode: {}, current timestep: {}, last reward: {}, "
                  "mean reward: {}, mean test reward {}".format(episode,
                                                                timestep,
                                                                rewards[-1],
                                                                np.mean(rewards[-10:]),
                                                                np.mean(test_rewards)))

            # Checkpoint when the mean of the last (up to) three evaluation
            # means matches or beats the best value seen so far.
            recent_test_reward = np.mean(mean_test_rewards[-3:])
            if recent_test_reward >= best_test_reward:
                agent.save_checkpoint(timestep, memory)
                best_test_reward = recent_test_reward

        episode += 1

    # Final checkpoint, then persist the per-episode training returns.
    agent.save_checkpoint(timestep, memory)
    save_log_dir = os.path.join(Log_dir, Env)
    os.makedirs(save_log_dir, exist_ok=True)
    log_file_name = "rewards_actNoise_{}.pkl".format(Noise_stddev)
    with open(os.path.join(save_log_dir, log_file_name), 'wb') as f:
        pickle.dump(rewards, f, pickle.HIGHEST_PROTOCOL)
    env.close()
