"""
Script for training ELBERT-PO agents on the lending and college admission environments
"""

import argparse

import time
import datetime
import gymnasium as gym
import numpy as np
import torch
from collections import defaultdict

from torch.utils.tensorboard import SummaryWriter

from agents.elbert import ELBERT, get_supply_demand
from utils.rollout_buffer import ELBERTRolloutBuffer
from utils.evaluator import (
    evaluate_policy,
    record_lending_evaluation,
    record_college_admission_evaluation
)
from utils.plot_utils import plot_lending, plot_college_admission
from utils import run_utils
from utils.env_utils import (
    make_env_and_metrics,
    get_observation_dim,
    preprocess_lending_obs,
    preprocess_college_admission_obs,
)
import utils.env_consts as consts
from utils.setup import parse_args


def main():
    """Train an ELBERT-PO agent, periodically evaluating and checkpointing it.

    All settings come from ``parse_args()``. Only ``args.agent == "elbert"`` and
    ``args.env`` in ``{"lending", "college"}`` are supported.

    Outputs, all under ``runs/<env>_elbert_<variant>_<seed>_<timestamp>``:
      * TensorBoard scalars (training episodic return/length, the metrics
        returned by ``agent.update``, steps-per-second, and whatever the
        environment-specific ``record_evaluation_fn`` writes);
      * the parsed args (``run_utils.save_args``);
      * per-episode evaluation plots and JSON under ``metrics/``;
      * model checkpoints under ``models/``;
      * the accumulated evaluation log as CSV (``run_utils.save_csv``).

    Raises:
        ValueError: if the requested agent or environment is not supported.
    """
    args = parse_args()

    # Explicit checks rather than ``assert`` so validation survives ``python -O``.
    if args.agent != "elbert":
        raise ValueError("Only ELBERT-PO is supported for this script.")
    if args.env not in ("lending", "college"):
        raise ValueError("Only the lending and college admission environments are supported for this script.")

    # Logging
    date_time_str = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    log_dir = f"runs/{args.env}_elbert"

    # Environment-specific settings: observation keys, observation
    # preprocessing, plotting, and evaluation-recording functions.
    if args.env == "lending":
        log_dir += f"_{args.success_func}_{args.seed}_{date_time_str}"
        state_keys = consts.LENDING_STATE_KEYS
        group_key = consts.LENDING_GROUP_KEY

        preprocessor_fn = preprocess_lending_obs
        plotting_fn = plot_lending
        record_evaluation_fn = record_lending_evaluation
    elif args.env == "college":
        log_dir += f"_{args.college_eps}_{args.seed}_{date_time_str}"
        state_keys = consts.COLLEGE_ADMISSION_STATE_KEYS
        group_key = consts.COLLEGE_ADMISSION_GROUP_KEY

        preprocessor_fn = preprocess_college_admission_obs
        plotting_fn = plot_college_admission
        record_evaluation_fn = record_college_admission_evaluation
    else:
        # Defensive: unreachable after the validation above.
        raise ValueError(f"Environment {args.env} is not supported.")

    metrics_dir = f"{log_dir}/metrics"
    models_dir = f"{log_dir}/models"
    writer = SummaryWriter(log_dir)
    eval_logs = defaultdict(list)

    # Save the args
    run_utils.save_args(args, log_dir)

    # Set the random seed before creating the environments and the agent.
    run_utils.seed_everywhere(args.seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def to_state(raw_state):
        """Convert a preprocessed state into a float tensor on ``device``.

        Used for every state handed to the agent/buffer (initial, post-step and
        post-reset) so they all share the same type and device.
        """
        return torch.Tensor(raw_state).to(device)

    # Separate environments for training and evaluation; only the metrics of
    # the evaluation environment are used.
    env, _ = make_env_and_metrics(args)
    eval_env, metrics = make_env_and_metrics(args)

    try:
        assert isinstance(
            env.action_space, gym.spaces.Discrete
        ), "only discrete action space is supported"

        state_dim = get_observation_dim(env.observation_space, keys=state_keys)
        group_dim = get_observation_dim(env.observation_space, keys=group_key)
        n_actions = env.action_space.n
        continuous_actions = False
        max_episode_steps = args.max_episode_steps

        # Create the agent
        agent = ELBERT(
            state_dim=state_dim,
            n_actions=n_actions,
            group_dim=group_dim,
            continuous_actions=continuous_actions,
            hidden_width=args.hidden_width,
            learning_rate=args.learning_rate,
            final_learning_rate=args.final_learning_rate,
            batch_size=args.batch_size,
            mini_batch_size=args.mini_batch_size,
            update_epochs=args.update_epochs,
            gamma=args.gamma,
            gae_lambda=args.gae_lambda,
            clip_coef=args.clip_coef,
            norm_adv=args.norm_adv,
            ent_coef=args.ent_coef,
            vf_coef=args.vf_coef,
            max_grad_norm=args.max_grad_norm,
            target_kl=args.target_kl,
            use_anneal_lr=args.anneal_lr,
            bias_coef=args.bias_coef,
            beta_smooth=args.beta_smooth,
            device=device,
        )

        # RL-specific setup
        buffer = ELBERTRolloutBuffer(state_dim, 1, args.batch_size, group_dim, device)

        # RL agent training/evaluation
        obs, info = env.reset(seed=args.seed)
        eval_env.reset(seed=args.seed)

        # ``next_done`` flags whether ``state`` is the first state of a new
        # episode. It is stored with each transition and passed to
        # ``agent.update`` (presumably for GAE bootstrapping — confirm in ELBERT).
        next_done = torch.zeros(1)
        state = to_state(preprocessor_fn(obs)[0])

        start_time = time.time()
        global_step, episode_return, episode_step = 0, 0, 0
        num_updates = args.total_timesteps // args.batch_size

        for update in range(1, num_updates + 1):
            is_last_update = update == num_updates

            # Evaluate the agent
            if update % args.eval_every == 0 or update == 1 or is_last_update:
                average_return, results = evaluate_policy(
                    eval_env,
                    agent,
                    metrics,
                    max_episode_steps,
                    eval_count=args.eval_count,
                    state_keys=state_keys,
                )
                print(f"Global step: {global_step} \t Average return (eval): {average_return}")

                eval_logs = record_evaluation_fn(eval_logs, writer, results, global_step, average_return)

                # Plot the results
                if update % args.plot_every == 0 or update == 1 or is_last_update:
                    save_dir = f"{metrics_dir}/final" if is_last_update else f"{metrics_dir}/step_{global_step}"
                    for i, result in enumerate(results):
                        plotting_fn(result, f"{save_dir}/ep_{i}")
                        run_utils.save_json(result, f"{save_dir}/ep_{i}")

            # Save the agent
            if update % args.save_every == 0 or is_last_update:
                save_dir = f"{models_dir}/final" if is_last_update else f"{models_dir}/step_{global_step}"
                agent.save(save_dir)

            # Anneal the learning rate if instructed to do so
            agent.anneal_lr(update, num_updates)

            # Reset the rollout buffers
            buffer.reset()

            for _ in range(args.batch_size):
                global_step += 1
                episode_step += 1

                # Action logic
                with torch.no_grad():
                    action, logprob, _, value, supply_value, demand_value = agent.get_action_and_value(state)

                # Execute the action and store data
                next_obs, reward, termination, truncation, info = env.step(action.cpu().numpy())
                done = termination or truncation

                # Supply/demand are attributed to the group of the observation
                # the action was taken on, i.e. ``obs`` (advanced below).
                supply_reward, demand_reward = get_supply_demand(obs['group'], action.item(), info["success"])

                episode_return += reward
                next_state, _, _, _ = preprocessor_fn(next_obs)

                value = value.flatten()
                reward = torch.tensor([reward]).view(-1).to(device)
                supply_reward = torch.Tensor(supply_reward).to(device)
                demand_reward = torch.Tensor(demand_reward).to(device)

                buffer.add(state, action, logprob, reward, next_done, value,
                           supply_reward, demand_reward, supply_value, demand_value)

                # Advance the current observation/state to the post-step ones.
                obs = next_obs
                state = to_state(next_state)

                # The episode ends on env termination/truncation OR on this
                # script's own step cap. Either way the env is reset below, so
                # the boundary must be flagged in ``next_done``; otherwise the
                # next stored transition would be treated as a continuation of
                # the finished episode.
                episode_over = bool(done) or episode_step == max_episode_steps
                next_done = torch.Tensor([episode_over]).to(device)

                if episode_over:
                    print(f"Global_step: {global_step} \t Episodic return (train): {episode_return}")
                    writer.add_scalar("train/episodic_return", episode_return, global_step)
                    writer.add_scalar("train/episodic_length", episode_step, global_step)

                    obs, info = env.reset()
                    state = to_state(preprocessor_fn(obs)[0])
                    episode_return, episode_step = 0, 0

            # Update the agent
            agent_metrics = agent.update(buffer, state, next_done)

            for k, v in agent_metrics.items():
                writer.add_scalar(k, v, global_step)
            writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)

        # Save the metrics
        run_utils.save_csv(eval_logs, log_dir)
    finally:
        # Release resources even if training fails part-way.
        env.close()
        eval_env.close()
        writer.close()


if __name__ == "__main__":
    # Run training only when executed as a script, not when imported.
    main()
