"""
Main script for running lending and college admission experiments.
"""

import time
import datetime
import gymnasium as gym
import torch
import random
from collections import defaultdict

from torch.utils.tensorboard import SummaryWriter

from agents import DQN, Bisimulator
from agents.dqn import linear_schedule
from utils.rollout_buffer import DQNBuffer, BisimulatorBuffer
from utils.evaluator import (
    evaluate_policy,
    record_lending_evaluation,
    record_college_admission_evaluation
)
from utils.agent_utils import (
    make_agent, 
    DQN_AGENTS, 
)
from utils.plot_utils import plot_lending, plot_college_admission
from utils import run_utils
from utils.env_utils import (
    make_env_and_metrics,
    get_observation_dim,
    preprocess_lending_obs,
    preprocess_college_admission_obs
)
import utils.env_consts as consts
from utils.setup import parse_args


def main():
    """Train a DQN-family agent alongside a bisimulator and log the results.

    Parses the command-line arguments, creates three environments through
    ``make_env_and_metrics`` (training, evaluation, and one handed to the
    bisimulator's dynamics update), then repeats for ``num_updates`` rounds:
    evaluate/plot/checkpoint when due, collect ``args.batch_size``
    epsilon-greedy transitions, update the bisimulator reward and dynamics
    models (depending on ``args.agent``), and update the agent. TensorBoard
    scalars, plots, checkpoints and the evaluation CSV are written under a
    run-specific ``runs/...`` directory.

    Raises:
        ValueError: If ``args.agent`` is not in ``DQN_AGENTS``, if
            ``args.env`` is neither ``"lending"`` nor ``"college"``, or if
            the environment's action space is not discrete.
    """
    args = parse_args()

    # Validate with explicit raises rather than ``assert`` so the checks are
    # not stripped when Python runs with ``-O``.
    if args.agent not in DQN_AGENTS:
        raise ValueError(f"Only the following agents are supported: {DQN_AGENTS}")
    if args.env not in ["lending", "college"]:
        raise ValueError(
            "Only the lending and college admission environments are supported for this script."
        )

    # Logging
    date_time_str = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    log_dir = f"runs/{args.env}_{args.agent}_rew_{args.rew_coef}"

    # Setup env related things: observation keys plus the env-specific
    # preprocessing, plotting and evaluation-recording functions.
    if args.env == "lending":
        log_dir += f"_{args.success_func}_{args.seed}_{date_time_str}"
        state_keys = consts.LENDING_STATE_KEYS
        state_without_group_key = consts.LENDING_STATE_WITHOUT_GROUP_KEY
        group_key = consts.LENDING_GROUP_KEY
        prev_applicant_next_state_key = consts.LENDING_PREV_APPLICANT_NEXT_STATE_KEY

        preprocessor_fn = preprocess_lending_obs
        plotting_fn = plot_lending
        record_evaluation_fn = record_lending_evaluation
    elif args.env == "college":
        log_dir += f"_{args.college_eps}_{args.seed}_{date_time_str}"
        state_keys = consts.COLLEGE_ADMISSION_STATE_KEYS
        state_without_group_key = consts.COLLEGE_ADMISSION_STATE_WITHOUT_GROUP_KEY
        group_key = consts.COLLEGE_ADMISSION_GROUP_KEY
        prev_applicant_next_state_key = consts.COLLEGE_PREV_APPLICANT_NEXT_STATE_KEY

        preprocessor_fn = preprocess_college_admission_obs
        plotting_fn = plot_college_admission
        record_evaluation_fn = record_college_admission_evaluation
    else:
        # Defensive: already excluded by the validation above.
        raise ValueError(f"Unknown environment: {args.env}")

    metrics_dir = f"{log_dir}/metrics"
    models_dir = f"{log_dir}/models"
    writer = SummaryWriter(log_dir)

    # Everything after the writer is opened runs inside try/finally so the
    # writer is closed even when setup or training raises.
    try:
        eval_logs = defaultdict(list)

        # Save the args
        run_utils.save_args(args, log_dir)

        # Set the random seed
        run_utils.seed_everywhere(args.seed)

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Create the environments and the metrics
        env, _ = make_env_and_metrics(args)
        eval_env, metrics = make_env_and_metrics(args)
        bisimulator_env, _ = make_env_and_metrics(args)

        if not isinstance(env.action_space, gym.spaces.Discrete):
            raise ValueError("only discrete action space is supported")

        state_dim = get_observation_dim(env.observation_space, keys=state_keys)
        state_without_group_dim = get_observation_dim(
            env.observation_space, keys=state_without_group_key
        )
        actual_next_state_dim = get_observation_dim(
            env.observation_space,
            keys=prev_applicant_next_state_key,
        )
        group_dim = get_observation_dim(env.observation_space, keys=group_key)
        n_actions = env.action_space.n
        max_episode_steps = args.max_episode_steps

        # Create the agent
        agent = make_agent(
            agent=args.agent,
            state_dim=state_dim,
            n_actions=n_actions,
            env=env,
            device=device,
            args=args,
        )

        # RL-specific setup
        if args.agent in DQN_AGENTS:
            buffer = DQNBuffer(state_dim, 1, args.batch_size, device)
            bisimulator_buffer = BisimulatorBuffer(
                state_without_group_dim,
                actual_next_state_dim,
                group_dim,
                args.batch_size,
                device,
            )

            bisimulator = Bisimulator(
                state_without_group_dim=state_without_group_dim,
                actual_next_state_dim=actual_next_state_dim,
                group_dim=group_dim,
                n_actions=n_actions,
                hidden_width=args.hidden_width,
                env_name=args.env,
                learning_rate=args.learning_rate,
                final_learning_rate=args.final_learning_rate,
                use_anneal_lr=args.anneal_bisim_lr,
                batch_size=args.batch_size,
                gamma=args.gamma,
                dyn_model_epochs=args.dyn_model_epochs,
                start_dyn_opt_step=args.start_dyn_opt,
                dyn_opt_iters=args.dyn_opt_iters,
                max_episode_steps=max_episode_steps,
                dyn_rollout_steps=args.dyn_rollout_steps,
                rew_coef=args.rew_coef,
                device=device,
            )

        # Loop-invariant agent-type switches, computed once.
        uses_bisim_reward = args.agent in ("dqn_bisim_rew", "dqn_bisim_rew_dyn")
        uses_bisim_dynamics = args.agent in ("dqn_bisim_rew_dyn",)

        # RL agent training/evaluation
        obs, _ = env.reset(seed=args.seed)
        eval_env.reset(seed=args.seed)
        bisimulator_env.reset(seed=args.seed)

        state, state_without_group, _, group = preprocessor_fn(obs)

        start_time = time.time()
        global_step, episode_return, episode_bisimulator_return, episode_total_return, episode_step = 0, 0, 0, 0., 0
        num_updates = args.total_timesteps // args.batch_size

        for update in range(1, num_updates + 1):
            # Evaluate the agent (always on the first and last update)
            if update % args.eval_every == 0 or update == 1 or update == num_updates:
                average_return, results = evaluate_policy(
                    eval_env,
                    agent,
                    metrics,
                    max_episode_steps,
                    eval_count=args.eval_count,
                    state_keys=state_keys,
                )
                print(f"Global step: {global_step} \t Average return (eval): {average_return}")

                eval_logs = record_evaluation_fn(eval_logs, writer, results, global_step, average_return)

                # Plot the results
                if update % args.plot_every == 0 or update == 1 or update == num_updates:
                    save_dir = f"{metrics_dir}/final" if update == num_updates else f"{metrics_dir}/step_{global_step}"
                    for i, result in enumerate(results):
                        plotting_fn(result, f"{save_dir}/ep_{i}")
                        run_utils.save_json(result, f"{save_dir}/ep_{i}")

            # Save the agent
            if update % args.save_every == 0 or update == num_updates:
                save_dir = f"{models_dir}/final" if update == num_updates else f"{models_dir}/step_{global_step}"
                agent.save(save_dir)
                bisimulator.save(save_dir)

            # Anneal the learning rate if instructed to do so
            agent.anneal_lr(update, num_updates)
            bisimulator.anneal_lr(update, num_updates)

            # Reset the rollout buffers
            buffer.reset()
            bisimulator_buffer.reset()

            for _ in range(args.batch_size):
                global_step += 1
                episode_step += 1
                epsilon = linear_schedule(args.start_e, args.end_e, args.exploration_fraction * args.total_timesteps, global_step)

                # Action logic: epsilon-greedy over the agent's choice
                with torch.no_grad():
                    if random.random() < epsilon:
                        action = env.action_space.sample()
                    else:
                        action = agent.act(state)

                    if uses_bisim_reward:
                        action = torch.tensor([action]).to(device)
                        bisimulator_reward = bisimulator.get_reward(state_without_group, group, action).cpu().numpy()
                    else:
                        bisimulator_reward = 0.0

                # Execute the action and store data
                # NOTE(review): for the bisimulator-reward agents `action` is a
                # tensor at this point - confirm the environment accepts that.
                next_obs, env_reward, termination, truncation, _ = env.step(action)
                done = termination or truncation

                episode_return += env_reward
                episode_bisimulator_return += bisimulator_reward
                next_state, next_state_without_group, actual_next_state, next_group = preprocessor_fn(next_obs)

                action = torch.Tensor([action]).to(device)
                termination = torch.Tensor([termination]).to(device)
                env_reward = torch.tensor([env_reward]).view(-1)

                if uses_bisim_reward:
                    # Add the bisimulator reward to the environment reward,
                    # optionally decaying its weight linearly over training.
                    frac = 1.0 - (global_step / args.total_timesteps) if args.decay_rew_coef else 1.0
                    bisimulator_reward = torch.tensor(bisimulator_reward).view(-1)
                    total_reward = env_reward + frac * args.rew_coef * bisimulator_reward
                else:
                    total_reward = env_reward
                episode_total_return += total_reward

                # The agent learns from the shaped reward; the bisimulator
                # buffer receives the raw environment reward.
                buffer.add(state, next_state, action, total_reward, termination)
                bisimulator_buffer.add(state_without_group, action, group, env_reward, actual_next_state)

                group = torch.Tensor(next_group).to(device)
                state = torch.Tensor(next_state).to(device)
                state_without_group = torch.Tensor(next_state_without_group).to(device)

                # `done` is used directly; wrapping it in a tensor only to
                # test its truthiness is unnecessary.
                if done or episode_step == max_episode_steps:
                    print(f"Global_step: {global_step} \t Episodic return (train): {episode_return}")
                    writer.add_scalar("train/episodic_return", episode_return, global_step)
                    writer.add_scalar("train/episodic_length", episode_step, global_step)
                    writer.add_scalar("train/episodic_bisimulator_return", episode_bisimulator_return, global_step)
                    writer.add_scalar("train/episodic_total_return", episode_total_return, global_step)

                    # NOTE(review): unlike the per-step update above, the
                    # values returned by the preprocessor are used here
                    # without torch.Tensor(...).to(device) - confirm intended.
                    obs, _ = env.reset()
                    state, state_without_group, _, group = preprocessor_fn(obs)
                    episode_return, episode_bisimulator_return, episode_total_return, episode_step = 0, 0, 0., 0

            # Update the reward. Start from an empty dict so that a zero
            # `args.rew_steps` neither raises NameError nor re-logs the
            # previous update's metrics.
            reward_metrics = {}
            if uses_bisim_reward and update % args.rew_update_freq == 0:
                for _ in range(args.rew_steps):
                    reward_metrics = bisimulator.update_reward(bisimulator_buffer)

            # Update the dynamics
            if uses_bisim_dynamics:
                if args.env == "lending":
                    dynamics_metrics, pos_credit_changes, neg_credit_changes = bisimulator.update_dynamics(bisimulator_env, agent, global_step)
                    if pos_credit_changes is not None and neg_credit_changes is not None:
                        env.unwrapped.set_credit_changes(pos_credit_changes, neg_credit_changes)
                        eval_env.unwrapped.set_credit_changes(pos_credit_changes, neg_credit_changes)

                        for g in range(group_dim):
                            dynamics_metrics[f"credit_changes/positive_group_{g}"] = pos_credit_changes[g]
                            dynamics_metrics[f"credit_changes/negative_group_{g}"] = neg_credit_changes[g]

                elif args.env == "college":
                    dynamics_metrics, score_changes_coef = bisimulator.update_dynamics(bisimulator_env, agent, global_step)
                    if score_changes_coef is not None:
                        env.unwrapped.set_score_changes(score_changes_coef)
                        eval_env.unwrapped.set_score_changes(score_changes_coef)

                        for g in range(group_dim):
                            dynamics_metrics[f"score_changes/coef_group_{g}"] = score_changes_coef[g]
                else:
                    raise ValueError(f"Unknown environment: {args.env}")

            else:
                dynamics_metrics = {}

            # Update the agent
            agent_metrics = agent.update(buffer, update)

            for d in [agent_metrics, reward_metrics, dynamics_metrics]:
                for k, v in d.items():
                    writer.add_scalar(k, v, global_step)
            writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)

        # Save the metrics
        run_utils.save_csv(eval_logs, log_dir)
    finally:
        writer.close()


# Run the experiment only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()
