"""
Script to train PPO on standard Gymnasium environments with discrete (`Discrete`) or continuous (`Box`) action spaces.
Docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/ppo/#ppopy
"""
import argparse
import os
import random
import time
from distutils.util import strtobool

import gymnasium as gym
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter

from agents import PPO
from utils.rollout_buffer import RolloutBuffer


def _str2bool(value):
    """Parse a command-line truth value into a ``bool``.

    Accepts the same spellings as ``distutils.util.strtobool``
    (``y/yes/t/true/on/1`` and ``n/no/f/false/off/0``, case-insensitive).
    Implemented locally so argument parsing does not depend on ``distutils``,
    which was removed from the standard library in Python 3.12.

    Raises:
        ValueError: if ``value`` is not a recognised truth value; argparse
            reports this as a usage error.
    """
    normalized = value.lower()
    if normalized in ("y", "yes", "t", "true", "on", "1"):
        return True
    if normalized in ("n", "no", "f", "false", "off", "0"):
        return False
    raise ValueError(f"invalid truth value {value!r}")


def parse_args():
    """Parse the command-line hyperparameters for a PPO training run.

    Boolean flags accept an optional truth value (``--cuda false``); given
    bare (``--cuda``) they evaluate to ``True``.

    Returns:
        argparse.Namespace: the parsed arguments. Attribute names are the
        option names with dashes replaced by underscores.
    """
    parser = argparse.ArgumentParser()
    # splitext removes exactly the ".py" suffix; str.rstrip(".py") would strip
    # any trailing '.', 'p' or 'y' characters (e.g. "happy.py" -> "ha").
    parser.add_argument("--exp-name", type=str, default=os.path.splitext(os.path.basename(__file__))[0],
        help="the name of this experiment")
    parser.add_argument("--seed", type=int, default=1,
        help="seed of the experiment")
    # "--eval-freq" matches the dash style of the other options; "--eval_freq"
    # is kept so existing command lines keep working. Both map to args.eval_freq.
    parser.add_argument("--eval-freq", "--eval_freq", dest="eval_freq", type=int, default=10,
        help="the interval between two consecutive evaluations")
    parser.add_argument("--torch-deterministic", type=_str2bool, default=True, nargs="?", const=True,
        help="if toggled, `torch.backends.cudnn.deterministic=False`")
    parser.add_argument("--cuda", type=_str2bool, default=True, nargs="?", const=True,
        help="if toggled, cuda will be enabled by default")

    # Algorithm specific arguments
    parser.add_argument("--env-id", type=str, default="CartPole-v1",
        help="the id of the environment")
    parser.add_argument("--total-timesteps", type=int, default=500000,
        help="total timesteps of the experiments")
    parser.add_argument("--hidden-width", type=int, default=256,
        help="the width of the hidden layers")
    parser.add_argument("--learning-rate", type=float, default=2.5e-4,
        help="the learning rate of the optimizer")
    # The training script passes args.final_learning_rate to the agent, so this
    # option must exist. NOTE(review): default 0.0 assumes the agent anneals
    # linearly down to this value (as CleanRL does) — confirm against PPO.
    parser.add_argument("--final-learning-rate", type=float, default=0.0,
        help="the final learning rate passed to the agent (target of learning rate annealing)")
    parser.add_argument("--anneal-lr", type=_str2bool, default=True, nargs="?", const=True,
        help="Toggle learning rate annealing for policy and value networks")
    parser.add_argument("--gamma", type=float, default=0.99,
        help="the discount factor gamma")
    parser.add_argument("--gae-lambda", type=float, default=0.95,
        help="the lambda for the general advantage estimation")
    parser.add_argument("--batch-size", type=int, default=128,
        help="the batch size of sample from the replay buffer")
    parser.add_argument("--mini-batch-size", type=int, default=32,
        help="the batch size for one gradient update")
    parser.add_argument("--update-epochs", type=int, default=5,
        help="the K epochs to update the policy")
    parser.add_argument("--norm-adv", type=_str2bool, default=True, nargs="?", const=True,
        help="Toggles advantages normalization")
    parser.add_argument("--clip-coef", type=float, default=0.2,
        help="the surrogate clipping coefficient")
    parser.add_argument("--clip-vloss", type=_str2bool, default=True, nargs="?", const=True,
        help="Toggles whether or not to use a clipped loss for the value function, as per the paper.")
    parser.add_argument("--ent-coef", type=float, default=0.01,
        help="coefficient of the entropy")
    parser.add_argument("--vf-coef", type=float, default=0.5,
        help="coefficient of the value function")
    parser.add_argument("--max-grad-norm", type=float, default=0.5,
        help="the maximum norm for the gradient clipping")
    parser.add_argument("--target-kl", type=float, default=None,
        help="the target KL divergence threshold")
    args = parser.parse_args()
    return args


def make_env(env_id, seed):
    """Create a Gymnasium environment with its randomness seeded from ``seed``.

    Args:
        env_id: environment id passed to ``gym.make``.
        seed: seed for the environment's own PRNG and for its action and
            observation spaces.

    Returns:
        The created environment, already reset once with ``seed``.
    """
    env = gym.make(env_id)
    env.action_space.seed(seed)
    env.observation_space.seed(seed)
    # Seeding the spaces only affects `space.sample()`. The environment's own
    # PRNG (initial states, stochastic dynamics) is seeded through reset().
    # Without this, an env that is later reset without a seed (such as the
    # evaluation env) takes its PRNG from an OS entropy source and is not
    # reproducible. Re-seeding later with the same value (as the training
    # loop does) yields the same stream, so this is safe to call here.
    env.reset(seed=seed)
    return env


def evaluate_policy(env, agent, eval_count=5):
    """Run ``eval_count`` full episodes with ``agent`` and average their returns.

    Each episode starts from ``env.reset()`` and steps with ``agent.act(obs)``
    until the environment reports termination or truncation.

    Args:
        env: a Gymnasium-style environment (``reset()`` -> ``(obs, info)``,
            ``step(a)`` -> ``(obs, reward, terminated, truncated, info)``).
        agent: object exposing ``act(obs)`` that returns an action for ``env``.
        eval_count: number of episodes to run.

    Returns:
        The mean undiscounted episode return.
    """

    def _play_one_episode():
        # Accumulate rewards until the episode ends (always at least one step).
        obs, _ = env.reset()
        ret = 0
        while True:
            obs, rew, is_terminal, is_truncated, _ = env.step(agent.act(obs))
            ret += rew
            if is_terminal or is_truncated:
                return ret

    episode_returns = [_play_one_episode() for _ in range(eval_count)]
    return sum(episode_returns) / eval_count


if __name__ == "__main__":
    args = parse_args()
    run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
    writer = SummaryWriter(f"runs/{run_name}")

    writer.add_text(
        "hyperparameters",
        "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
    )

    # seeding
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = args.torch_deterministic

    device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")

    # env setup: evaluation uses its own env (offset seed), so eval episodes
    # never interrupt the in-progress training episode.
    env = make_env(args.env_id, args.seed)
    eval_env = make_env(args.env_id, args.seed + 100)

    try:
        state_dim = env.observation_space.shape[0]

        # NOTE(review): the second RolloutBuffer argument appears to be the
        # per-step action width (1 for a Discrete index, the vector size for
        # Box) — confirm against RolloutBuffer.
        if isinstance(env.action_space, gym.spaces.Discrete):
            n_actions = env.action_space.n
            continuous_actions = False
            buffer = RolloutBuffer(state_dim, 1, args.batch_size, device)
        elif isinstance(env.action_space, gym.spaces.Box):
            n_actions = env.action_space.shape[0]
            continuous_actions = True
            buffer = RolloutBuffer(state_dim, n_actions, args.batch_size, device)
        else:
            raise NotImplementedError

        # initialize agent
        agent = PPO(
            state_dim=state_dim,
            n_actions=n_actions,
            continuous_actions=continuous_actions,
            hidden_width=args.hidden_width,
            learning_rate=args.learning_rate,
            final_learning_rate=args.final_learning_rate,
            batch_size=args.batch_size,
            mini_batch_size=args.mini_batch_size,
            update_epochs=args.update_epochs,
            gamma=args.gamma,
            gae_lambda=args.gae_lambda,
            clip_coef=args.clip_coef,
            norm_adv=args.norm_adv,
            clip_vloss=args.clip_vloss,
            ent_coef=args.ent_coef,
            vf_coef=args.vf_coef,
            max_grad_norm=args.max_grad_norm,
            target_kl=args.target_kl,
            use_anneal_lr=args.anneal_lr,
            device=device,
        )

        # Start the env
        global_step = 0
        episode_return, episode_steps = 0, 0
        start_time = time.time()
        state, _ = env.reset(seed=args.seed)
        state = torch.Tensor(state)
        next_done = torch.zeros(1)
        num_updates = args.total_timesteps // args.batch_size

        for update in range(1, num_updates + 1):
            # annealing the learning rate if instructed to do so
            agent.anneal_lr(update, num_updates)

            # reset the rollout buffer
            buffer.reset()

            for step in range(0, args.batch_size):
                global_step += 1
                episode_steps += 1

                # action logic
                with torch.no_grad():
                    action, logprob, _, value = agent.get_action_and_value(state)

                # execute the action and store data
                if continuous_actions:
                    next_state, reward, termination, truncation, _ = env.step(action.cpu().numpy().reshape(-1))
                else:
                    next_state, reward, termination, truncation, _ = env.step(action.cpu().numpy())
                # NOTE(review): truncation is folded into `done`, so time-limit
                # cut-offs are presumably treated as terminal (no bootstrapping)
                # by the agent's update — confirm in PPO.update.
                done = termination or truncation

                episode_return += reward
                reward = torch.tensor([reward]).view(-1)
                value = value.flatten()
                # `next_done` still holds the previous step's flag here, i.e.
                # whether `state` is the first observation after a reset.
                buffer.add(state, action, logprob, reward, next_done, value)

                state, next_done = torch.Tensor(next_state), torch.Tensor([done])

                if next_done:
                    print(f"Global_step: {global_step} \t Episodic return (train): {episode_return}")
                    writer.add_scalar("train/episodic_return", episode_return, global_step)
                    writer.add_scalar("train/episodic_length", episode_steps, global_step)

                    state, _ = env.reset()
                    state = torch.Tensor(state)
                    episode_return, episode_steps = 0, 0

            # update the agent
            metrics = agent.update(buffer, state, next_done)

            for k, v in metrics.items():
                writer.add_scalar(k, v, global_step)
            writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)

            # evaluate the agent
            if update % args.eval_freq == 0:
                average_return = evaluate_policy(eval_env, agent, eval_count=5)
                print(f"Global step: {global_step} \t Average return (eval): {average_return}")
                writer.add_scalar("eval/average_return", average_return, global_step)
    finally:
        # Release both environments (the eval env was previously never closed)
        # and flush TensorBoard events even if training is interrupted.
        env.close()
        eval_env.close()
        writer.close()
