"""
Script to train DQN on standard Gymnasium environments with discrete actions.
Docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/ppo/#ppopy
"""
import os
import random
import time
import argparse

import gymnasium as gym
import numpy as np
import torch

from torch.utils.tensorboard import SummaryWriter

from agents.dqn import DQN, linear_schedule
from utils.rollout_buffer import DQNBuffer


def _str2bool(value):
    """Convert a command-line string such as ``"True"`` or ``"0"`` into a ``bool``.

    ``argparse`` with ``type=bool`` calls ``bool("False")``, which is ``True`` for
    every non-empty string, so boolean options need this explicit parser.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean spelling.
    """
    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in {"1", "true", "t", "yes", "y", "on"}:
        return True
    if normalized in {"0", "false", "f", "no", "n", "off"}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parse_args(argv=None):
    """Parse the command-line hyperparameters for a DQN training run.

    Boolean options accept ``True``/``False`` (also ``1``/``0``, ``yes``/``no``)
    or may be given as a bare flag (e.g. ``--save_model``), which means ``True``.

    Args:
        argv: optional list of argument strings. ``None`` (the default) makes
            ``argparse`` read ``sys.argv[1:]`` as usual.

    Returns:
        argparse.Namespace with one attribute per option. ``final_learning_rate``
        falls back to ``learning_rate`` when it is not given explicitly.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--exp_name", type=str, default=os.path.basename(__file__)[: -len(".py")],
                        help="the name of this experiment")
    parser.add_argument("--seed", type=int, default=1,
                        help="seed of the experiment")
    # Boolean flags: `nargs="?"` + `const=True` lets `--flag` alone mean True,
    # while `--flag False` is parsed correctly by `_str2bool`.
    parser.add_argument("--torch_deterministic", type=_str2bool, default=True, nargs="?", const=True,
                        help="if toggled, `torch.backends.cudnn.deterministic=True`")
    parser.add_argument("--cuda", type=_str2bool, default=True, nargs="?", const=True,
                        help="if toggled, cuda will be enabled by default")
    parser.add_argument("--eval_freq", type=int, default=1000,
                        help="the interval between two consecutive evaluations")
    parser.add_argument("--save_model", type=_str2bool, default=False, nargs="?", const=True,
                        help="whether to save model into the `runs/{run_name}` folder")
    parser.add_argument("--env_id", type=str, default="CartPole-v1",
                        help="the id of the environment")
    parser.add_argument("--total_timesteps", type=int, default=500000,
                        help="total timesteps of the experiments")
    parser.add_argument("--learning_rate", type=float, default=2.5e-4,
                        help="the learning rate of the optimizer")
    # The training script passes this to the DQN agent; it was previously read
    # but never defined, which raised AttributeError at startup.
    parser.add_argument("--final_learning_rate", type=float, default=None,
                        help="the final learning rate passed to the DQN agent (defaults to `--learning_rate`)")
    parser.add_argument("--gamma", type=float, default=0.99,
                        help="the discount factor gamma")
    parser.add_argument("--tau", type=float, default=1.0,
                        help="the target network update rate")
    parser.add_argument("--batch_size", type=int, default=256,
                        help="the batch size of sample from the replay memory")
    parser.add_argument("--start_e", type=float, default=1,
                        help="the starting epsilon for exploration")
    parser.add_argument("--end_e", type=float, default=0.05,
                        help="the ending epsilon for exploration")
    parser.add_argument("--exploration_fraction", type=float, default=0.5,
                        help="the fraction of `total_timesteps` it takes to anneal epsilon from `start_e` to `end_e`")
    parser.add_argument("--learning_starts", type=int, default=10000,
                        help="timestep to start learning")

    args = parser.parse_args(argv)
    if args.final_learning_rate is None:
        # No explicit final value: keep the learning rate constant.
        args.final_learning_rate = args.learning_rate
    return args


def make_env(env_id, seed):
    """Build a Gymnasium environment and seed its action and observation spaces.

    This function does not reset the environment, so the environment's own
    RNG is left unseeded here; callers seed it through ``env.reset(seed=...)``.

    Args:
        env_id: registered Gymnasium environment id (e.g. ``"CartPole-v1"``).
        seed: seed applied to both ``action_space`` and ``observation_space``.

    Returns:
        The environment returned by ``gym.make(env_id)``.
    """
    environment = gym.make(env_id)
    # Seed the spaces so that `action_space.sample()` is reproducible.
    for space in (environment.action_space, environment.observation_space):
        space.seed(seed)
    return environment


def evaluate_policy(env, agent, epsilon, eval_count=5):
    """Estimate the agent's mean undiscounted episode return.

    Each step is epsilon-greedy: with probability ``epsilon`` a random action
    is drawn from ``env.action_space``; otherwise ``agent.act(state)`` chooses.
    An episode ends when the environment reports termination or truncation.

    Args:
        env: environment exposing ``reset()`` and a 5-tuple ``step(action)``.
        agent: object exposing ``act(state)``.
        epsilon: per-step probability of taking a random action.
        eval_count: number of episodes to average over.

    Returns:
        Sum of the per-episode returns divided by ``eval_count``.
    """
    episode_returns = []
    for _ in range(eval_count):
        state, _ = env.reset()
        ret = 0
        finished = False
        while not finished:
            explore = random.random() < epsilon
            action = env.action_space.sample() if explore else agent.act(state)
            state, reward, terminated, truncated, _ = env.step(action)
            ret += reward
            finished = terminated or truncated
        episode_returns.append(ret)

    return sum(episode_returns) / eval_count


if __name__ == "__main__":
    import stable_baselines3 as sb3

    if sb3.__version__ < "2.0":
        raise ValueError(
            """Ongoing migration: run the following command to install the new dependencies:

poetry run pip install "stable_baselines3==2.0.0a1"
"""
        )
    args = parse_args()
    
    run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
    writer = SummaryWriter(f"runs/{run_name}")
    writer.add_text(
        "hyperparameters",
        "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
    )

    # seeding
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = args.torch_deterministic

    device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")

    # env setup
    env = make_env(args.env_id, args.seed)
    eval_env = make_env(args.env_id, args.seed + 100)
    
    assert isinstance(env.action_space, gym.spaces.Discrete), "only discrete action space is supported"
    
    state_dim = env.observation_space.shape[0]
    n_actions = env.action_space.n
    
    agent = DQN(
        state_dim=state_dim,
        n_actions=n_actions,        
        learning_rate=args.learning_rate,
        final_learning_rate=args.final_learning_rate,
        gamma=args.gamma,
        tau=args.tau,
        batch_size=args.batch_size,
        device=device,
    )
    
    buffer = DQNBuffer(state_dim, 1, args.batch_size, device)
    
    start_time = time.time()
    episode_return, episode_steps = 0, 0

    # start the game
    obs, _ = env.reset(seed=args.seed)
    for global_step in range(args.total_timesteps):
        episode_steps += 1
        
        # put action logic here
        epsilon = linear_schedule(args.start_e, args.end_e, args.exploration_fraction * args.total_timesteps, global_step)
        if random.random() < epsilon:
            action = env.action_space.sample()
        else:
            action = agent.act(obs)

        # execute the game and log data.
        next_obs, reward, termination, truncation, info = env.step(action)
        
        episode_return += reward
        done = termination or truncation

        # save data to reply buffer; handle `final_observation`
        obs = torch.Tensor(obs)
        next_obs = torch.Tensor(next_obs)
        action = torch.Tensor([action])
        reward = torch.Tensor([reward])
        termination = torch.Tensor([termination])
        
        buffer.add(obs, next_obs, action, reward, termination)

        # CRUCIAL step easy to overlook
        obs = next_obs
        
        # record rewards for plotting purposes
        if done: 
            print(f"Global_step: {global_step} \t Episodic return (train): {episode_return}")
            writer.add_scalar("train/episodic_return", episode_return, global_step)
            writer.add_scalar("train/episodic_length", episode_steps, global_step)
            
            obs, _ = env.reset()            
            episode_return, episode_steps = 0, 0

        # training
        if global_step % args.batch_size == 0:
            metrics = agent.update(buffer)
            buffer.reset()
            
            for k, v in metrics.items():
                writer.add_scalar(k, v, global_step)
            
            if global_step % 100 == 0:
                print("SPS:", int(global_step / (time.time() - start_time)))
                writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)

        
        if global_step % args.eval_freq == 0:
            average_return = evaluate_policy(
                eval_env,
                agent,            
                epsilon=epsilon,
                eval_count=5,
            )        
            print(f"Global step: {global_step} \t Average return (eval): {average_return}")
            writer.add_scalar("eval/average_return", average_return, global_step)

    if args.save_model:
        model_path = f"runs/{run_name}/{args.exp_name}.cleanrl_model"
        agent.save(model_path)
        print(f"model saved to {model_path}")

    env.close()
    writer.close()