"""
Based on:
https://github.com/hercky/group-fairness-in-RL/tree/main
"""

import time
import datetime
import numpy as np
import gymnasium as gym
import torch
from collections import defaultdict
from itertools import combinations

from torch.utils.tensorboard import SummaryWriter

from agents import LagrangianPPO
from utils.rollout_buffer import RolloutBuffer
from utils.evaluator import (
    evaluate_policy,
    record_lending_evaluation,
    record_college_admission_evaluation
)
from utils.plot_utils import plot_lending, plot_college_admission
from utils import run_utils
from utils.env_utils import (
    make_env_and_metrics,
    get_observation_dim,
    preprocess_lending_obs,
    preprocess_college_admission_obs,
)
import utils.env_consts as consts
from utils.setup import parse_args


def _max_pairwise_gap(group_returns):
    """Return the largest absolute difference between any two group returns.

    Args:
        group_returns: Sequence with one return per group (floats or 0-d tensors).

    Returns:
        The maximum of ``abs(a - b)`` over all unordered pairs, or ``0.0`` when
        fewer than two groups are given (there is no pair to compare, and
        ``max`` on an empty sequence would raise ``ValueError``).
    """
    gaps = [abs(first - second) for first, second in combinations(group_returns, 2)]
    return max(gaps, default=0.0)


def _return_differences(group_returns, group):
    """Return ``group``'s return minus each other group's return, in group order.

    Args:
        group_returns: Sequence with one return per group.
        group: Index of the group whose return is the minuend.

    Returns:
        A list with ``len(group_returns) - 1`` entries (empty for a single group).
    """
    own_return = group_returns[group]
    return [
        own_return - other_return
        for other, other_return in enumerate(group_returns)
        if other != group
    ]


def main():
    """Train one Lagrangian PPO agent per group and log the return gap between groups.

    The script parses the command-line arguments, builds one training
    environment, agent and rollout buffer per group (plus a shared evaluation
    environment), and then alternates between:

    1. evaluating every agent on the first update, every ``args.eval_every``
       updates and on the last update (plots and JSON dumps are written on the
       first update, every ``args.plot_every`` updates and on the last one),
    2. collecting ``args.batch_size`` stored transitions per group,
    3. logging the per-group summed rewards, the largest pairwise gap and the
       cumulative number of updates whose gap exceeded ``args.epsilon``,
    4. updating each agent with its buffer and its return differences.

    Evaluation logs are written as CSV at the end; the TensorBoard writer is
    closed even if training raises.

    Raises:
        AssertionError: If ``args.agent`` is not ``"lagppo"`` or ``args.env`` is
            not one of ``"lending"`` / ``"college"``.
        ValueError: If ``args.env`` is unsupported (reachable only when asserts
            are disabled).
    """
    args = parse_args()

    assert args.agent == "lagppo", "Only Lagrangian PPO is supported for this script."
    assert args.env in ["lending", "college"], "Only the lending and college admission environments are supported for this script."

    # Logging
    date_time_str = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    log_dir = f"runs/{args.env}_lagppo"

    # Setup env related things
    if args.env == "lending":
        log_dir += f"_{args.success_func}_{args.seed}_{date_time_str}"
        state_keys = consts.LENDING_STATE_KEYS
        group_key = consts.LENDING_GROUP_KEY

        preprocessor_fn = preprocess_lending_obs
        plotting_fn = plot_lending
        record_evaluation_fn = record_lending_evaluation
    elif args.env == "college":
        log_dir += f"_{args.college_eps}_{args.seed}_{date_time_str}"
        state_keys = consts.COLLEGE_ADMISSION_STATE_KEYS
        group_key = consts.COLLEGE_ADMISSION_GROUP_KEY

        preprocessor_fn = preprocess_college_admission_obs
        plotting_fn = plot_college_admission
        record_evaluation_fn = record_college_admission_evaluation
    else:
        raise ValueError(f"Environment {args.env} is not supported.")

    metrics_dir = f"{log_dir}/metrics"
    writer = SummaryWriter(log_dir)

    # Everything after the writer is created runs under try/finally so the
    # event file is flushed and closed even when training fails midway.
    try:
        # Save the args
        run_utils.save_args(args, log_dir)

        # Set the random seed
        run_utils.seed_everywhere(args.seed)

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Create the eval environment and metrics
        eval_env, metrics = make_env_and_metrics(args)

        state_dim = get_observation_dim(eval_env.observation_space, keys=state_keys)
        group_dim = get_observation_dim(eval_env.observation_space, keys=group_key)
        n_actions = eval_env.action_space.n
        continuous_actions = False

        max_episode_steps = args.max_episode_steps

        # One environment, agent, rollout buffer and evaluation log per group
        envs, agents, buffers, eval_logs = [], [], [], []

        for z in range(group_dim):
            eval_logs.append(defaultdict(list))
            envs.append(
                make_env_and_metrics(args)[0]
            )
            agents.append(
                LagrangianPPO(
                    state_dim=state_dim,
                    n_actions=n_actions,
                    group_dim=group_dim,
                    group=z,
                    continuous_actions=continuous_actions,
                    hidden_width=args.hidden_width,
                    learning_rate=args.learning_rate,
                    final_learning_rate=args.final_learning_rate,
                    batch_size=args.batch_size,
                    mini_batch_size=args.mini_batch_size,
                    update_epochs=args.update_epochs,
                    gamma=args.gamma,
                    gae_lambda=args.gae_lambda,
                    clip_coef=args.clip_coef,
                    norm_adv=args.norm_adv,
                    clip_vloss=args.clip_vloss,
                    ent_coef=args.ent_coef,
                    vf_coef=args.vf_coef,
                    max_grad_norm=args.max_grad_norm,
                    target_kl=args.target_kl,
                    use_anneal_lr=args.anneal_lr,
                    nu_init=args.nu_init,
                    nu_max=args.nu_max,
                    nu_lr=args.nu_lr,
                    device=device,
                )
            )
            buffers.append(RolloutBuffer(state_dim, 1, args.batch_size, device))

        # RL agent training/evaluation
        eval_env.reset(seed=args.seed)

        # Each group keeps its own rollout state (current state, done flag and
        # episode counters). Sharing a single set of variables would make an
        # agent act on, and be updated with, a state left over from another
        # group's environment, and would mix episode lengths across envs.
        states = []
        for z in range(group_dim):
            obs, _ = envs[z].reset(seed=args.seed)
            state, _, _, _ = preprocessor_fn(obs)
            states.append(state)
        next_dones = [torch.zeros(1, device=device) for _ in range(group_dim)]
        episode_returns = [0] * group_dim
        episode_steps = [0] * group_dim

        start_time = time.time()
        global_step = 0
        num_updates = args.total_timesteps // args.batch_size // group_dim
        cum_fair_violations = 0.

        for update in range(1, num_updates + 1):
            # Evaluate each agent
            if update % args.eval_every == 0 or update == 1 or update == num_updates:
                for z in range(group_dim):
                    average_return, results = evaluate_policy(
                        eval_env,
                        agents[z],
                        metrics,
                        max_episode_steps,
                        eval_count=args.eval_count,
                        state_keys=state_keys,
                    )
                    print(f"Global step: {global_step} \t Average return (eval): {average_return}")

                    eval_logs[z] = record_evaluation_fn(eval_logs[z], writer, results, global_step, average_return, tag=f"eval/agent_{z + 1}")

                    # Plot the results
                    if update % args.plot_every == 0 or update == 1 or update == num_updates:
                        save_dir = f"{metrics_dir}/final" if update == num_updates else f"{metrics_dir}/step_{global_step}"
                        for i, result in enumerate(results):
                            plotting_fn(result, f"{save_dir}/agent_{z + 1}/ep_{i}")
                            run_utils.save_json(result, f"{save_dir}/agent_{z + 1}/ep_{i}")

            for z in range(group_dim):
                # Anneal the learning rate
                agents[z].anneal_lr(update, num_updates)

                # Reset the rollout buffers
                buffers[z].reset()
                step = 0

                # Only transitions that are kept count towards the batch, so
                # this loop may step the environment more than batch_size times.
                while step < args.batch_size:
                    # Action logic
                    with torch.no_grad():
                        action, logprob, _, value = agents[z].get_action_and_value(states[z])

                    # Execute the action and store data
                    next_obs, reward, termination, truncation, _ = envs[z].step(action.cpu().numpy())
                    done = termination or truncation
                    episode_steps[z] += 1

                    next_state, _, _, group = preprocessor_fn(next_obs)
                    # hack: if group is not the same as z, then we need to skip this step
                    # NOTE(review): on a skipped step the environment has advanced but
                    # states[z] is intentionally left untouched unless the episode ends,
                    # and the group is read from the *next* observation - confirm that
                    # both are the intended semantics of this hack.
                    if np.argmax(group) != z:
                        if done or episode_steps[z] == max_episode_steps:
                            print(f"Global_step: {global_step} \t Episodic return (train): {episode_returns[z]}")
                            writer.add_scalar("train/episodic_return", episode_returns[z], global_step)

                            obs, _ = envs[z].reset()
                            states[z], _, _, _ = preprocessor_fn(obs)
                            episode_returns[z], episode_steps[z] = 0, 0
                        continue

                    episode_returns[z] += reward
                    global_step += 1
                    step += 1

                    value = value.flatten()
                    reward = torch.tensor([reward]).view(-1)

                    # The done flag stored with a transition is the one observed
                    # before its state, hence next_dones[z] rather than `done`.
                    buffers[z].add(states[z], action, logprob, reward, next_dones[z], value)

                    states[z] = torch.Tensor(next_state).to(device)
                    next_dones[z] = torch.Tensor([done]).to(device)

                    if done or episode_steps[z] == max_episode_steps:
                        print(f"Global_step: {global_step} \t Episodic return (train): {episode_returns[z]}")
                        writer.add_scalar("train/episodic_return", episode_returns[z], global_step)

                        obs, _ = envs[z].reset()
                        states[z], _, _, _ = preprocessor_fn(obs)
                        episode_returns[z], episode_steps[z] = 0, 0

            # Get the difference in returns for each agent
            group_returns = []
            for z in range(group_dim):
                _, _, _, rewards, _, _ = buffers[z].get_data()
                group_returns.append(rewards.sum())

            # Largest gap between any two groups (0.0 when there is only one group)
            fair_gap = _max_pairwise_gap(group_returns)

            # Extra metrics
            fair_violation = float(fair_gap > args.epsilon)
            cum_fair_violations += fair_violation

            writer.add_scalar("fair/fair_gap", fair_gap, global_step)
            writer.add_scalar("fair/cum_violations", cum_fair_violations, global_step)
            for z in range(group_dim):
                writer.add_scalar(f"fair/return_group_{z + 1}", group_returns[z], global_step)

            # Update the agents, each with its own final state and done flag
            for z0 in range(group_dim):
                return_diff = _return_differences(group_returns, z0)
                agent_metrics = agents[z0].update(buffers[z0], states[z0], next_dones[z0], return_diff)

                for k, v in agent_metrics.items():
                    writer.add_scalar(k, v, global_step)

            # Guard against a zero elapsed time on very fast first iterations
            elapsed = max(time.time() - start_time, 1e-8)
            writer.add_scalar("charts/SPS", int(global_step / elapsed), global_step)

        # Save the metrics
        for z in range(group_dim):
            run_utils.save_csv(eval_logs[z], log_dir, tag=f"_agent_{z + 1}")
    finally:
        writer.close()


# Run the training script only when this file is executed directly, not on import.
if __name__ == "__main__":
    main()
