from typing import Union, Any

import numpy as np
from scipy.stats import wasserstein_distance

import gymnasium as gym
from fair_gym.envs import CollegeAdmissionEnv
from fair_gym.utils.utils import compute_recall, compute_precision


class CollegeAdmissionMetrics:
    """
    Calculate college admission metrics from a college admission environment.

    Nothing is cached on this object: every getter reads the environment's
    current state, so the reported metrics follow the environment, including
    when the environment is reset.
    """

    def __init__(self, env: "Union[CollegeAdmissionEnv, gym.Wrapper]") -> None:
        """
        Initialize the CollegeAdmissionMetrics class with the given environment.

        Args:
            env (CollegeAdmissionEnv | gym.Wrapper): The environment itself, or
                a wrapper whose ``unwrapped`` attribute gives access to it.

        Returns:
            None
        """
        self.env = env

    def _base_env(self):
        """
        Return the underlying environment that holds the metric state.

        Returns:
            The stored environment when it is a ``CollegeAdmissionEnv``,
            otherwise ``self.env.unwrapped``.
        """
        if isinstance(self.env, CollegeAdmissionEnv):
            return self.env
        return self.env.unwrapped

    def get_current_score_distribution(self) -> tuple[tuple[float, ...], ...]:
        """
        Get the current score distribution for each group.

        Returns:
            tuple[tuple[float, ...], ...]: The score distribution for each group.
        """
        return self._base_env()._get_distribution("score")

    def get_initial_score_distribution(self) -> tuple[tuple[float, ...], ...]:
        """
        Get the initial score distribution for each group.

        Returns:
            tuple[tuple[float, ...], ...]: The initial score distribution for
                each group.
        """
        return self._base_env().initial_score_distribution

    def get_budget_distribution(self) -> tuple[tuple[float, ...], ...]:
        """
        Get the current budget distribution for each group.

        Returns:
            tuple[tuple[float, ...], ...]: The budget distribution for each group.
        """
        return self._base_env()._get_distribution("budget")

    def get_initial_budget_distribution(self) -> tuple[tuple[float, ...], ...]:
        """
        Get the initial budget distribution for each group.

        Returns:
            tuple[tuple[float, ...], ...]: The initial budget distribution for
                each group.
        """
        return self._base_env().initial_budget_distribution

    def get_acceptance_ratio(self) -> tuple[float, ...]:
        """
        Get the acceptance ratio for each group.

        Returns:
            tuple[float, ...]: The acceptance ratio for each group.
        """
        return self._base_env()._get_acceptance_ratio()

    def get_cumulative_admissions(self) -> list[list[int]]:
        """
        Get the cumulative admissions history for each group.

        For a bare ``CollegeAdmissionEnv`` the environment's history is returned
        as-is. For a wrapped environment, each group's history is padded with
        its last value so that it is as long as the wrapper's
        ``_max_episode_steps``. The padding is applied to a copy; the
        environment's own history is never modified.

        Returns:
            list[list[int]]: The cumulative admissions for each group.
        """
        if isinstance(self.env, CollegeAdmissionEnv):
            return self.env.cumulative_admissions_history

        base_env = self.env.unwrapped
        # Copy each group's history: padding the environment's own lists in
        # place would corrupt its bookkeeping for the rest of the episode.
        padded_history = [
            list(group_history)
            for group_history in base_env.cumulative_admissions_history
        ]
        episode_steps = self.env.get_wrapper_attr("_max_episode_steps")

        # The number of padding entries is derived from the first group's
        # length and applied to every group.
        n_missing = episode_steps - len(padded_history[0])
        if n_missing > 0:
            for group in range(base_env.n_groups):
                last_value = padded_history[group][-1]
                padded_history[group].extend([last_value] * n_missing)
        return padded_history

    def get_average_cost_paid_by_accepted(self) -> list[float]:
        """
        Get the average cost for accepted students for each group.

        Returns:
            list[float]: The average cost for accepted students for each group.
        """
        return self._base_env()._get_average_cost_paid_by_accepted()

    def get_recall(self) -> list[float]:
        """
        Get the recall for each group.

        Returns:
            list[float]: The recall for each group, computed from the
                environment's ``tp`` and ``fn`` attributes.
        """
        base_env = self._base_env()
        return compute_recall(base_env.tp, base_env.fn)

    def get_precision(self) -> list[float]:
        """
        Get the precision for each group.

        Returns:
            list[float]: The precision for each group, computed from the
                environment's ``tp`` and ``fp`` attributes.
        """
        base_env = self._base_env()
        return compute_precision(base_env.tp, base_env.fp)

    def get_wasserstein_distance(
        self, distributions: list[list[float]]
    ) -> dict[str, float]:
        """
        Get the pair-wise Wasserstein distances between the given distributions.

        Each distribution is treated as a set of weights over the integer
        positions ``0 .. len(distribution) - 1``.

        Args:
            distributions (list[list[float]]): The distributions.

        Returns:
            dict[str, float]: The pair-wise Wasserstein distances, keyed as
                ``"group_<i>_vs_group_<j>"`` with 1-based indices and i < j.
        """
        n_distributions = len(distributions)
        w_distances = {}
        for i in range(n_distributions):
            first = distributions[i]
            for j in range(i + 1, n_distributions):
                second = distributions[j]
                w_distances[f"group_{i + 1}_vs_group_{j + 1}"] = wasserstein_distance(
                    np.arange(len(first)),
                    np.arange(len(second)),
                    first,
                    second,
                )
        return w_distances

    def get_all_metrics(self) -> dict[str, Any]:
        """
        Get all the metrics.

        Returns:
            dict[str, Any]: The metrics, keyed by metric name.
        """
        initial_score_distributions = self.get_initial_score_distribution()
        final_score_distributions = self.get_current_score_distribution()

        initial_budget_distributions = self.get_initial_budget_distribution()
        final_budget_distributions = self.get_budget_distribution()

        acceptance_ratio = self.get_acceptance_ratio()

        return {
            "initial_score_distribution": initial_score_distributions,
            "final_score_distribution": final_score_distributions,
            "initial_budget_distribution": initial_budget_distributions,
            "final_budget_distribution": final_budget_distributions,
            "average_cost_paid_by_accepted": self.get_average_cost_paid_by_accepted(),
            "acceptance_ratio": acceptance_ratio,
            "cumulative_admissions": self.get_cumulative_admissions(),
            "recall": self.get_recall(),
            "precision": self.get_precision(),
        }
