"""
Utilities for creating lending agents.
"""

from argparse import Namespace

import gymnasium as gym

from agents import RandomAgent, PPO, DQN, AbstractAgent, CollegeClassifier
from utils.params import CostMatrix
from agents import (
    classifier_agents,
    oracle_lending_agent,
    threshold_policies,
)

from fair_gym import AcceptRejectAction

# Agent names accepted by make_agent, grouped by the branch that builds them.
PPO_AGENTS = [
    "ppo",
    "bisim_rew",
    "bisim_rew_dyn",
]
DQN_AGENTS = [
    "dqn",
    "dqn_bisim_rew",
    "dqn_bisim_rew_dyn",
]
LENDING_NON_RL_AGENTS = ["max_util", "eo"]
COLLEGE_ADMISSION_NON_RL_AGENTS = ["classifier"]
NON_RL_AGENTS = ["random"]

# Constructors for the remaining agents; make_agent looks the name up here
# when none of the PPO / DQN / classifier / lending groups match.
MISC_AGENTS = dict(random=RandomAgent)

# Threshold policy selected for each name in LENDING_NON_RL_AGENTS; read by
# make_ml_fairness_lending_agent.
THRESHOLD_POLICIES = dict(
    max_util=threshold_policies.ThresholdPolicy.MAXIMIZE_REWARD,
    eo=threshold_policies.ThresholdPolicy.EQUALIZE_OPPORTUNITY,
)


def make_agent(
    agent: str,
    state_dim: int,
    n_actions: int,
    continuous_actions: bool = False,
    env=None,
    device: str = "cpu",
    args: Namespace = None,
) -> AbstractAgent:
    """
    Create the agent registered under the given name.

    The name selects one of five construction paths: PPO, DQN, the college
    admission classifier, a threshold-policy lending agent, or an entry of
    ``MISC_AGENTS`` (currently only ``"random"``).

    Args:
        agent (str): The name of the agent to create. Must appear in one of
            PPO_AGENTS, DQN_AGENTS, LENDING_NON_RL_AGENTS,
            COLLEGE_ADMISSION_NON_RL_AGENTS or NON_RL_AGENTS.
        state_dim (int): The dimension of the state space. Not used by the
            lending agents.
        n_actions (int): The number of actions in the action space. Only
            forwarded to the PPO, DQN and ``MISC_AGENTS`` constructors.
        continuous_actions (bool): Whether the action space is continuous.
            Only forwarded to the PPO and ``MISC_AGENTS`` constructors.
        env (gym.Env): The environment. Only forwarded to the college
            classifier and the lending agents.
        device (str): The device passed to the PPO, DQN and college
            classifier constructors.
        args (Namespace): The command line arguments supplying the
            hyperparameters. Required for every agent except those built
            from ``MISC_AGENTS``; passing ``None`` for the others fails with
            an ``AttributeError`` on the first attribute read.

    Returns:
        AbstractAgent: The newly constructed agent.

    Raises:
        AssertionError: If ``agent`` is not a known agent name.
    """
    # NOTE: this check is an ``assert`` and is therefore skipped under
    # ``python -O``; an unknown name then surfaces as a ``KeyError`` from the
    # ``MISC_AGENTS`` lookup in the final branch instead.
    assert (
        agent
        in PPO_AGENTS
        + DQN_AGENTS
        + LENDING_NON_RL_AGENTS
        + COLLEGE_ADMISSION_NON_RL_AGENTS
        + NON_RL_AGENTS
    ), f"Unknown agent: {agent}"

    if agent in PPO_AGENTS:
        return PPO(
            state_dim=state_dim,
            n_actions=n_actions,
            continuous_actions=continuous_actions,
            hidden_width=args.hidden_width,
            learning_rate=args.learning_rate,
            final_learning_rate=args.final_learning_rate,
            batch_size=args.batch_size,
            mini_batch_size=args.mini_batch_size,
            update_epochs=args.update_epochs,
            gamma=args.gamma,
            gae_lambda=args.gae_lambda,
            clip_coef=args.clip_coef,
            norm_adv=args.norm_adv,
            clip_vloss=args.clip_vloss,
            ent_coef=args.ent_coef,
            vf_coef=args.vf_coef,
            max_grad_norm=args.max_grad_norm,
            target_kl=args.target_kl,
            use_anneal_lr=args.anneal_lr,
            device=device,
        )

    if agent in DQN_AGENTS:
        # DQN reads its epoch count from a dedicated ``dqn_update_epochs``
        # argument rather than the ``update_epochs`` used by PPO.
        return DQN(
            state_dim=state_dim,
            n_actions=n_actions,
            hidden_width=args.hidden_width,
            learning_rate=args.learning_rate,
            final_learning_rate=args.final_learning_rate,
            batch_size=args.batch_size,
            update_epochs=args.dqn_update_epochs,
            gamma=args.gamma,
            tau=args.tau,
            use_anneal_lr=args.anneal_lr,
            device=device,
        )

    if agent in COLLEGE_ADMISSION_NON_RL_AGENTS:
        return CollegeClassifier(
            state_dim=state_dim,
            env=env,
            hidden_width=args.hidden_width,
            device=device,
        )

    if agent in LENDING_NON_RL_AGENTS:
        # Forward the validated ``agent`` name rather than ``args.agent`` so
        # the threshold policy always matches the branch selected above, even
        # when the caller's Namespace carries a different (or no) agent name.
        return make_ml_fairness_lending_agent(
            env=env,
            agent=agent,
            burnin=args.burnin,
        )

    # Remaining names are resolved through the MISC_AGENTS registry.
    return MISC_AGENTS[agent](
        state_dim=state_dim,
        n_actions=n_actions,
        continuous_actions=continuous_actions,
    )


def make_ml_fairness_lending_agent(
    env: gym.Env,
    agent: str,
    burnin: int = 50,
) -> AbstractAgent:
    """
    Create an oracle threshold lending agent for the given policy name.

    Args:
        env (gym.Env): The environment; its action and observation spaces are
            handed to the agent, together with the environment itself.
        agent (str): A key of ``THRESHOLD_POLICIES`` selecting the threshold
            policy.
        burnin (int): Forwarded unchanged as the ``burnin`` field of the
            scoring agent parameters.

    Returns:
        AbstractAgent: The constructed ``OracleThresholdAgent``.

    Raises:
        KeyError: If ``agent`` is not a key of ``THRESHOLD_POLICIES``.
    """
    # Resolve the policy first so an unknown name fails before anything else
    # is constructed.
    selected_policy = THRESHOLD_POLICIES[agent]

    def _default_action():
        # Action handed to the agent as its default: always accept.
        return AcceptRejectAction.ACCEPT.value

    def _skip_retraining(action, observation):
        # True exactly when the action taken was a rejection.
        return action == AcceptRejectAction.REJECT.value

    scoring_params = classifier_agents.ScoringAgentParams(
        feature_keys=["credit_score"],
        group_key="group",
        default_action_fn=_default_action,
        burnin=burnin,
        convert_one_hot_to_integer=True,
        threshold_policy=selected_policy,
        skip_retraining_fn=_skip_retraining,
        cost_matrix=CostMatrix(fn=0, fp=-1, tp=1, tn=0),
    )

    return oracle_lending_agent.OracleThresholdAgent(
        action_space=env.action_space,
        reward_fn=None,
        observation_space=env.observation_space,
        params=scoring_params,
        env=env,
    )
