from typing import Union, Any

import numpy as np
from scipy.stats import wasserstein_distance

import gymnasium as gym
from fair_gym.envs import LendingEnv
from fair_gym.utils.utils import compute_recall, compute_precision


class LendingMetrics:
    """
    A class to calculate lending metrics based on the environment.

    Every metric is read directly from the environment's current state, so the
    metrics reset when the environment is reset. Either a bare ``LendingEnv``
    or a ``gym.Wrapper`` around one may be passed; wrappers are unwrapped on
    every access.
    """

    def __init__(self, env: "Union[LendingEnv, gym.Wrapper]") -> None:
        """
        Initialize the LendingMetrics class with the given environment.

        Args:
            env (LendingEnv | gym.Wrapper): The lending environment, either bare
                or wrapped in gym wrappers.

        Returns:
            None
        """
        self.env = env

    @property
    def _base_env(self) -> "LendingEnv":
        """The underlying LendingEnv, with any gym wrappers stripped off."""
        # A bare LendingEnv is used as-is; anything else is treated as a
        # wrapper and the base environment is reached through `.unwrapped`
        # (needed because the private helpers below live on the base env).
        if isinstance(self.env, LendingEnv):
            return self.env
        return self.env.unwrapped

    def get_current_credit_distribution(
        self, latent_credit_score: bool = False
    ) -> tuple[tuple[float, ...], ...]:
        """
        Get the current credit score distribution for each group.

        Args:
            latent_credit_score (bool): Whether to get the latent credit score
                distribution instead of the observable one.

        Returns:
            tuple[tuple[float, ...], ...]: The credit score distribution for
            each group.
        """
        return self._base_env._get_current_credit_score_distribution(
            latent_credit_score=latent_credit_score
        )

    def get_initial_credit_distribution(
        self, latent_credit_score: bool = False
    ) -> tuple[tuple[float, ...], ...]:
        """
        Get the initial credit score distribution for each group.

        Args:
            latent_credit_score (bool): Whether to get the latent credit score
                distribution instead of the observable one.

        Returns:
            tuple[tuple[float, ...], ...]: The initial credit score
            distribution for each group.
        """
        base = self._base_env
        if latent_credit_score:
            return base.initial_latent_credit_score_distribution
        return base.initial_credit_score_distribution

    def get_cons_distribution(self) -> tuple[tuple[float, ...], ...]:
        """
        Get the conscientiousness distribution for each group.

        Returns:
            tuple[tuple[float, ...], ...]: The conscientiousness distribution
            for each group.
        """
        return self._base_env._get_distribution("cons")

    def get_loans_given_distribution(self) -> tuple[tuple[float, ...], ...]:
        """
        Get the loans given distribution for each group.

        Returns:
            tuple[tuple[float, ...], ...]: The loans given distribution for
            each group.
        """
        return self._base_env._get_distribution("loans_given")

    def get_loans_repaid_distribution(self) -> tuple[tuple[float, ...], ...]:
        """
        Get the loans repaid distribution for each group.

        Returns:
            tuple[tuple[float, ...], ...]: The loans repaid distribution for
            each group.
        """
        return self._base_env._get_distribution("loans_repaid")

    def get_loans_defaulted_distribution(self) -> tuple[tuple[float, ...], ...]:
        """
        Get the loans defaulted distribution for each group.

        Returns:
            tuple[tuple[float, ...], ...]: The loans defaulted distribution for
            each group.
        """
        return self._base_env._get_distribution("loans_defaulted")

    def get_profit_rate(self) -> float:
        """
        Get the bank's profit rate relative to its starting cash.

        Computed as ``(bank_cash - bank_starting_cash) / bank_starting_cash``.
        The division is not guarded, so a starting cash of 0 is not supported.

        Returns:
            float: The bank's profit rate (e.g. 0.1 means a 10% gain).
        """
        base = self._base_env
        return (base.bank_cash - base.bank_starting_cash) / base.bank_starting_cash

    def get_cumulative_loan(self) -> list[list[int]]:
        """
        Get the cumulative loan history for each group.

        Note: this returns the environment's own object, not a copy.

        Returns:
            list[list[int]]: The cumulative loan for each group.
        """
        return self._base_env.cumulative_loan_history

    def get_recall(self) -> list[float]:
        """
        Get the recall for each group, from the env's ``tp`` and ``fn`` counts.

        Returns:
            list[float]: The recall for each group.
        """
        base = self._base_env
        return compute_recall(base.tp, base.fn)

    def get_precision(self) -> list[float]:
        """
        Get the precision for each group, from the env's ``tp`` and ``fp`` counts.

        Returns:
            list[float]: The precision for each group.
        """
        base = self._base_env
        return compute_precision(base.tp, base.fp)

    def get_wasserstein_distance(self, distributions) -> dict[str, float]:
        """
        Get the pair-wise Wasserstein distances between the given distributions.

        Each distribution is treated as a histogram: its index positions are
        the support and its values are the (non-negative) weights. Distances
        are therefore expressed in units of index positions.

        Args:
            distributions (Sequence[Sequence[float]]): One histogram per group.

        Returns:
            dict[str, float]: Distances keyed ``"group_{i}_vs_group_{j}"``
            (1-based, i < j). Empty when fewer than two distributions are given.
        """
        w_distances = {}
        for i, dist_i in enumerate(distributions):
            for j in range(i + 1, len(distributions)):
                dist_j = distributions[j]
                w_distances[f"group_{i + 1}_vs_group_{j + 1}"] = wasserstein_distance(
                    np.arange(len(dist_i)),
                    np.arange(len(dist_j)),
                    dist_i,
                    dist_j,
                )
        return w_distances

    def get_all_metrics(self) -> dict[str, Any]:
        """
        Get all the metrics.

        The per-group loans_given, loans_repaid and loans_defaulted
        distributions are not included; use their dedicated getters instead.

        Returns:
            dict[str, Any]: The metrics.
        """
        # Observable credit score distributions and their pair-wise distances.
        initial_credit_distribution = self.get_initial_credit_distribution(
            latent_credit_score=False
        )
        final_credit_distribution = self.get_current_credit_distribution(
            latent_credit_score=False
        )
        initial_w_distances = self.get_wasserstein_distance(initial_credit_distribution)
        final_w_distances = self.get_wasserstein_distance(final_credit_distribution)

        # Latent credit score distributions and their pair-wise distances.
        initial_latent_credit_distribution = self.get_initial_credit_distribution(
            latent_credit_score=True
        )
        final_latent_credit_distribution = self.get_current_credit_distribution(
            latent_credit_score=True
        )
        initial_latent_w_distances = self.get_wasserstein_distance(
            initial_latent_credit_distribution
        )
        final_latent_w_distances = self.get_wasserstein_distance(
            final_latent_credit_distribution
        )

        return {
            "initial_credit_distribution": initial_credit_distribution,
            "final_credit_distribution": final_credit_distribution,
            "initial_w_distances": initial_w_distances,
            "final_w_distances": final_w_distances,
            "initial_latent_credit_distribution": initial_latent_credit_distribution,
            "final_latent_credit_distribution": final_latent_credit_distribution,
            "initial_latent_w_distances": initial_latent_w_distances,
            "final_latent_w_distances": final_latent_w_distances,
            "cons_distribution": self.get_cons_distribution(),
            "profit_rate": self.get_profit_rate(),
            "cumulative_loan": self.get_cumulative_loan(),
            "recall": self.get_recall(),
            "precision": self.get_precision(),
        }
