"""
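Utilities for building the fair_gym environments (Lending and College Admission) and their metrics, and for preprocessing their observations.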
"""

from typing import Any, List, Tuple

import torch
import numpy as np
import gymnasium as gym

import fair_gym
import utils.env_consts as consts
from fair_gym import (
    LendingEnv,
    LendingMetrics,
    CollegeAdmissionEnv,
    CollegeAdmissionMetrics,
)


def get_observation_dim(observation_space: gym.Space, keys: List) -> int:
    """
    Get the flattened dimension of the selected entries of a Dict observation space.

    Array-valued spaces (``Box``, ``MultiBinary``, ``MultiDiscrete``) contribute
    the product of their shape, a ``Discrete`` space contributes 1, and a nested
    ``Dict`` space contributes the dimension of all of its sub-spaces.

    Args:
        observation_space (gym.Space): The observation space object; it must
            support ``observation_space[k]`` for every key in ``keys``.
        keys (List): The list of keys to consider.

    Returns:
        int: The dimension of the observation space.

    Raises:
        NotImplementedError: If a selected sub-space has an unsupported type.
    """
    obs_dim = 0
    for k in keys:
        v = observation_space[k]
        if isinstance(
            v, (gym.spaces.Box, gym.spaces.MultiBinary, gym.spaces.MultiDiscrete)
        ):
            # int(): np.prod returns a NumPy scalar, and 1.0 (a float) for the
            # empty shape () of a scalar Box.
            obs_dim += int(np.prod(v.shape))
        elif isinstance(v, gym.spaces.Discrete):
            obs_dim += 1
        elif isinstance(v, gym.spaces.Dict):
            # A nested Dict has no key list of its own: count every sub-space.
            obs_dim += get_observation_dim(v, list(v.spaces.keys()))
        else:
            raise NotImplementedError(
                f"Unsupported observation space for key {k!r}: {type(v).__name__}"
            )
    return obs_dim


def preprocess_obs(obs: dict, keys: List) -> List:
    """
    Preprocess the observation based on the keys provided.

    Args:
        observation (dict): The observation dictionary.
        keys (List): The list of keys to consider.

    Returns:
        List: The flattened observation.
    """
    flattened_obs = []
    for k in keys:
        v = obs[k]
        if isinstance(v, np.ndarray):
            flattened_obs.extend(v.reshape(-1))
        elif isinstance(v, dict):
            flattened_obs.extend(preprocess_obs(v))
        else:
            raise NotImplementedError
    return flattened_obs


def preprocess_lending_obs(
    obs: dict[str, Any],
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Preprocess the Lending environment observation into tensors.

    Each tensor is the ``preprocess_obs`` flattening of one key group defined
    in ``utils.env_consts``, converted with the ``torch.Tensor`` constructor
    (default floating-point dtype).

    Args:
        obs (dict[str, Any]): The observation dictionary.

    Returns:
        Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
            ``(state, state_without_group, actual_next_state, group)``. Note
            that ``group`` is returned last.
    """
    state = preprocess_obs(obs, consts.LENDING_STATE_KEYS)
    state_without_group = preprocess_obs(obs, consts.LENDING_STATE_WITHOUT_GROUP_KEY)
    group = preprocess_obs(obs, consts.LENDING_GROUP_KEY)
    actual_next_state = preprocess_obs(obs, consts.LENDING_PREV_APPLICANT_NEXT_STATE_KEY)

    return (
        torch.Tensor(state),
        torch.Tensor(state_without_group),
        torch.Tensor(actual_next_state),
        torch.Tensor(group),
    )


def preprocess_college_admission_obs(
    obs: dict[str, Any],
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Preprocess the College Admission environment observation into tensors.

    Each tensor is the ``preprocess_obs`` flattening of one key group defined
    in ``utils.env_consts``, converted with the ``torch.Tensor`` constructor
    (default floating-point dtype).

    Args:
        obs (dict[str, Any]): The observation dictionary.

    Returns:
        Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
            ``(state, state_without_group, actual_next_state, group)``. Note
            that ``group`` is returned last.
    """
    state = preprocess_obs(obs, consts.COLLEGE_ADMISSION_STATE_KEYS)
    state_without_group = preprocess_obs(obs, consts.COLLEGE_ADMISSION_STATE_WITHOUT_GROUP_KEY)
    group = preprocess_obs(obs, consts.COLLEGE_ADMISSION_GROUP_KEY)
    actual_next_state = preprocess_obs(obs, consts.COLLEGE_PREV_APPLICANT_NEXT_STATE_KEY)

    return (
        torch.Tensor(state),
        torch.Tensor(state_without_group),
        torch.Tensor(actual_next_state),
        torch.Tensor(group),
    )


def make_env_and_metrics(args):
    """
    Create the environment and the metrics based on the arguments provided.

    Args:
        args (Namespace): The arguments. ``args.env`` selects the environment
            (``"lending"`` or ``"college"``). The remaining attributes read depend
            on that choice:

            * lending: ``success_func``, ``cons_mean``, ``population_size``,
              ``max_episode_steps``.
            * college: ``epsilon``, ``population_size``, ``max_episode_steps``.

    Returns:
        Tuple[gym.Env, Any]: The environment and the metrics.

    Raises:
        NotImplementedError: If ``args.env`` is not a supported environment.
    """
    if args.env == "lending":
        # Keyword arguments so a reordering of the factory signature cannot
        # silently mis-assign values.
        env = make_lending_env(
            group_distribution=consts.GROUP_DISTRIBUTION,
            credit_score_distribution=consts.CREDIT_SCORE_DISTRIBUTION,
            success_probability=consts.LENDING_SUCCESS_PROB,
            success_func=args.success_func,
            cons_mean=args.cons_mean,
            population_size=args.population_size,
            max_episode_steps=args.max_episode_steps,
        )
        metrics = make_lending_metrics(env)
    elif args.env == "college":
        env = make_college_admission_env(
            group_distribution=consts.GROUP_DISTRIBUTION,
            epsilon=args.epsilon,
            population_size=args.population_size,
            max_episode_steps=args.max_episode_steps,
        )
        metrics = make_college_admission_metrics(env)
    else:
        raise NotImplementedError(
            f"Unknown environment {args.env!r}; expected 'lending' or 'college'."
        )

    return env, metrics


def make_lending_env(
    group_distribution: tuple[float, ...],
    credit_score_distribution: tuple[tuple[float, ...], ...],
    success_probability: tuple[float, ...],
    success_func: Any = None,
    cons_mean: float = 0.5,
    population_size: int = 1000,
    max_episode_steps: int = 1000,
) -> LendingEnv:
    """
    Create the Lending environment with the specified parameters.

    Args:
        group_distribution (tuple[float, ...]): The distribution of groups; its
            length sets ``n_groups``.
        credit_score_distribution (tuple[tuple[float, ...], ...]): The
            distribution of credit scores.
        success_probability (tuple[float, ...]): The success probabilities; its
            length sets ``max_credit``.
        success_func (Any): The success function to use, forwarded to the
            environment as ``success_func``. When ``None`` (the default), the
            kwarg is omitted so the environment's own default or validation
            applies. Presumably a name string from the CLI; TODO confirm the
            accepted values against LendingEnv.
        cons_mean (float): The mean of the conscientiousness distribution.
        population_size (int): The size of the population.
        max_episode_steps (int): The maximum number of steps in an episode.

    Returns:
        LendingEnv: The environment created by ``gym.make``. It may be wrapped
            by Gymnasium, e.g. in a ``TimeLimit`` enforcing ``max_episode_steps``.
    """
    env_kwargs = {
        "n_groups": len(group_distribution),
        "group_distribution": group_distribution,
        "credit_score_distribution": credit_score_distribution,
        "success_probability": success_probability,
        "cons_mean": cons_mean,
        "population_size": population_size,
        "max_credit": len(success_probability),
    }
    # The old default was the builtin ``str`` type object, which was forwarded
    # to the env verbatim. Only pass a value the caller actually supplied.
    if success_func is not None:
        env_kwargs["success_func"] = success_func

    env = gym.make(
        "fair_gym/LendingEnv", max_episode_steps=max_episode_steps, **env_kwargs
    )
    return env


def make_lending_metrics(env: LendingEnv) -> LendingMetrics:
    """
    Build a ``LendingMetrics`` tracker bound to the given Lending environment.

    Args:
        env (LendingEnv): The environment the metrics are constructed from.

    Returns:
        LendingMetrics: A new metrics object for ``env``.
    """
    return LendingMetrics(env)


def make_college_admission_env(
    group_distribution: tuple[float, ...],
    epsilon: float = 0.7,
    population_size: int = 1000,
    max_episode_steps: int = 1000,
) -> CollegeAdmissionEnv:
    """
    Create the College Admission environment with the specified parameters.

    The score and budget distributions, the success probability and the
    maximum budget are read from ``utils.env_consts``. Only the arguments
    below are configurable here.

    Args:
        group_distribution (tuple[float, ...]): The distribution of groups; its
            length sets ``n_groups``.
        epsilon (float): Forwarded unchanged to the environment as ``epsilon``.
            NOTE(review): its exact meaning is defined by CollegeAdmissionEnv.
            Confirm there before documenting it further.
        population_size (int): The size of the population.
        max_episode_steps (int): The maximum number of steps in an episode.

    Returns:
        CollegeAdmissionEnv: The environment created by ``gym.make``. It may be
            wrapped by Gymnasium, e.g. in a ``TimeLimit`` enforcing
            ``max_episode_steps``.
    """
    env_kwargs = {
        "n_groups": len(group_distribution),
        "group_distribution": group_distribution,
        "score_distribution_mean": consts.SCORE_DISTRIBUTION_MEAN,
        "score_distribution_std": consts.SCORE_DISTRIBUTION_STD,
        "budget_distribution_mean": consts.BUDGET_DISTRIBUTION_MEAN,
        "budget_distribution_std": consts.BUDGET_DISTRIBUTION_STD,
        "success_prob": consts.COLLEGE_SUCCESS_PROB,
        "max_budget": consts.MAX_BUDGET,
        "epsilon": epsilon,
        "population_size": population_size,
    }

    env = gym.make(
        "fair_gym/CollegeAdmissionEnv", max_episode_steps=max_episode_steps, **env_kwargs
    )
    return env


def make_college_admission_metrics(env: CollegeAdmissionEnv) -> CollegeAdmissionMetrics:
    """
    Build a ``CollegeAdmissionMetrics`` tracker bound to the given environment.

    Args:
        env (CollegeAdmissionEnv): The environment the metrics are constructed from.

    Returns:
        CollegeAdmissionMetrics: A new metrics object for ``env``.
    """
    return CollegeAdmissionMetrics(env)
