"""
Utilities for evaluating agents and recording evaluation metrics for the
Lending and College Admission environments.
"""

from typing import Any, Union
import numpy as np
import random
from collections import defaultdict

from torch.utils.tensorboard import SummaryWriter

from fair_gym import (
    LendingEnv,
    LendingMetrics,
)

from agents import AbstractAgent, OracleThresholdAgent
from utils.env_utils import preprocess_obs


def evaluate_policy(
    env: Union[LendingEnv],
    agent: AbstractAgent,
    metrics: Union[LendingMetrics],
    episode_length: int,
    state_keys: list[str],
    eval_count: int = 1,
) -> tuple[float, list[dict]]:
    """
    Evaluate the agent's policy in the Lending environment.

    Runs ``eval_count`` episodes of at most ``episode_length`` steps each.
    After every episode the output of ``metrics.get_all_metrics()`` is
    flattened and stored, so the returned list has one entry per episode.

    Args:
        env (Union[LendingEnv]): The Lending environment.
        agent (AbstractAgent): The agent to evaluate.
        metrics (Union[LendingMetrics]): The metric tracker.
        episode_length (int): The maximum number of timesteps per episode.
        state_keys (list[str]): The keys of the state dictionary.
        eval_count (int): The number of episodes to evaluate (at least 1).

    Returns:
        tuple[float, list[dict]]: The average return over all episodes and
            the flattened metrics collected after each episode.

    Raises:
        ValueError: If ``eval_count`` is smaller than 1.
    """
    # Fail early with a clear message instead of dividing by zero (or by a
    # negative count) when computing the average below.
    if eval_count < 1:
        raise ValueError(f"eval_count must be at least 1, got {eval_count}")

    # The oracle agent receives the raw observation; every other agent gets
    # the observation preprocessed with `state_keys`. The agent type cannot
    # change during evaluation, so the check is done once.
    needs_preprocessing = not isinstance(agent, OracleThresholdAgent)

    total_reward = 0.0
    results = []

    for _ in range(eval_count):
        s, _info = env.reset()
        episode_reward = 0

        for _ in range(episode_length):
            if needs_preprocessing:
                s = preprocess_obs(s, keys=state_keys)

            a = agent.act(s)
            s, r, terminated, truncated, _info = env.step(a)
            episode_reward += r
            if terminated or truncated:
                break
        total_reward += episode_reward

        # NOTE(review): `metrics` is not reset in this function; whether the
        # values are per-episode or cumulative depends on the tracker/env
        # implementation - confirm.
        # Flatten the nested dictionary
        results.append(flatten_dict(metrics.get_all_metrics()))

    average_return = total_reward / eval_count
    return average_return, results


def record_lending_evaluation(
    eval_logs: dict[str, Any],
    writer: SummaryWriter,
    results: list[dict],
    global_step: int,
    average_return: float,
    tag: str = "eval",
) -> dict[str, Any]:
    """
    Record evaluation logs to TensorBoard and a dictionary.
    Averages the list of dict along each key.

    Only the "recall" and "precision" entries (logged once per group) and
    entries whose key contains "w_distances" are recorded; all other averaged
    keys are ignored.

    Args:
        eval_logs (dict[str,Any]): The evaluation logs, mapping a series name
            to the list of values recorded so far. Missing series are created
            on first use, so a plain ``dict`` works as well as a
            ``defaultdict(list)``.
        writer (SummaryWriter): The TensorBoard writer.
        results (list[dict]): The evaluation results.
        global_step (int): The global step.
        average_return (float): The average return.
        tag (str): The prefix of every scalar name written to TensorBoard.

    Returns:
        dict[str,Any]: The updated evaluation logs (the same object that was
            passed in, modified in place).
    """
    # Average the results
    avg_results = average_dicts(results)

    writer.add_scalar(f"{tag}/average_return", average_return, global_step)
    eval_logs.setdefault("average_return", []).append(average_return)
    eval_logs.setdefault("global_step", []).append(global_step)

    per_group_keys = ("recall", "precision")
    for key, value in avg_results.items():
        if key in per_group_keys:
            # One scalar per group; groups are numbered starting at 1.
            for group, group_value in enumerate(value, start=1):
                name = f"{key}_group_{group}"
                writer.add_scalar(f"{tag}/{name}", group_value, global_step)
                eval_logs.setdefault(name, []).append(group_value)
        elif "w_distances" in key:
            writer.add_scalar(f"{tag}/{key}", value, global_step)
            eval_logs.setdefault(key, []).append(value)

    return eval_logs


def record_college_admission_evaluation(
    eval_logs: dict[str, Any],
    writer: SummaryWriter,
    results: list[dict],
    global_step: int,
    average_return: float,
    tag: str = "eval",
) -> dict[str, Any]:
    """
    Record evaluation logs to TensorBoard and a dictionary.
    Averages the list of dict along each key.

    Only the "recall", "precision", "acceptance_ratio" and
    "average_cost_paid_by_accepted" entries are recorded, each once per
    group; all other averaged keys are ignored.

    Args:
        eval_logs (dict[str,Any]): The evaluation logs, mapping a series name
            to the list of values recorded so far. Missing series are created
            on first use, so a plain ``dict`` works as well as a
            ``defaultdict(list)``.
        writer (SummaryWriter): The TensorBoard writer.
        results (list[dict]): The evaluation results.
        global_step (int): The global step.
        average_return (float): The average return.
        tag (str): The prefix of every scalar name written to TensorBoard.

    Returns:
        dict[str,Any]: The updated evaluation logs (the same object that was
            passed in, modified in place).
    """
    # Average the results
    avg_results = average_dicts(results)

    writer.add_scalar(f"{tag}/average_return", average_return, global_step)
    eval_logs.setdefault("average_return", []).append(average_return)
    eval_logs.setdefault("global_step", []).append(global_step)

    per_group_keys = (
        "recall",
        "precision",
        "acceptance_ratio",
        "average_cost_paid_by_accepted",
    )
    for key, value in avg_results.items():
        if key not in per_group_keys:
            continue
        # One scalar per group; groups are numbered starting at 1.
        for group, group_value in enumerate(value, start=1):
            name = f"{key}_group_{group}"
            writer.add_scalar(f"{tag}/{name}", group_value, global_step)
            eval_logs.setdefault(name, []).append(group_value)

    return eval_logs


def average_dicts(dicts: list[dict]) -> dict:
    """
    Average a list of dictionaries key by key.

    The values found under each key are gathered across all dictionaries and
    reduced with ``np.mean(..., axis=0)``, so scalar values yield a scalar
    mean and list values yield an element-wise mean. A key only needs to be
    present in some of the dictionaries; it is averaged over those that
    contain it.

    Args:
        dicts (list[dict]): The list of dictionaries.

    Returns:
        dict: The dictionary with the average values, keyed in order of
            first appearance.
    """
    collected = {}
    for entry in dicts:
        for key, value in entry.items():
            collected.setdefault(key, []).append(value)
    return {key: np.mean(values, axis=0) for key, values in collected.items()}


def flatten_dict(d: dict, parent_key: str = "", sep: str = "_"):
    """
    Collapse a nested dictionary into a single-level dictionary.

    Keys of nested dictionaries are joined to their ancestors' keys with
    ``sep``. Values that are not dictionaries are copied unchanged. A nested
    dictionary that is empty contributes no entries.

    Args:
        d (dict): The dictionary to flatten.
        parent_key (str): The prefix prepended to every key of ``d``; an
            empty prefix leaves top-level keys untouched.
        sep (str): The separator placed between key segments.

    Returns:
        dict: The flattened dictionary.
    """
    flat = {}
    for key, value in d.items():
        full_key = key if not parent_key else f"{parent_key}{sep}{key}"
        if not isinstance(value, dict):
            flat[full_key] = value
            continue
        # Recurse into the nested mapping, carrying the joined key as prefix.
        flat.update(flatten_dict(value, full_key, sep=sep))
    return flat
