from typing import Any, SupportsFloat, Union

import numpy as np

import gymnasium as gym
from gymnasium import spaces

from fair_gym.utils.utils import one_hot_encode, compute_recall, compute_precision
from fair_gym.utils.actions import AcceptRejectAction


# Default environment parameters. In the per-group tuples, index i refers to group i.

# Probability of an applicant belonging to each group.
GROUP_DISTRIBUTION: tuple[float, ...] = (0.5, 0.5)

# Normal distribution of the applicants' scores, per group.
SCORE_DISTRIBUTION_MEAN: tuple[int, ...] = (8, 6)
SCORE_DISTRIBUTION_STD: tuple[int, ...] = (1, 1)

# Normal distribution of the applicants' manipulation budgets, per group.
BUDGET_DISTRIBUTION_MEAN: tuple[int, ...] = (4, 2)
BUDGET_DISTRIBUTION_STD: tuple[int, ...] = (1, 1)

# Probability of success, indexed by the applicant's score (one entry per score).
SUCCESS_PROB: tuple[float, ...] = (
    0.0,
    0.0,
    0.0,
    0.3,
    0.4,
    0.5,
    0.6,
    0.7,
    0.8,
    0.9,
)


class CollegeAdmissionEnv(gym.Env):
    """
    A simple college admission environment where the agent must decide whether or not
    to accept a college application. Applicants can manipulate their scores by paying
    a cost. The agent's goal is to maximize the overall accuracy.

    Observations are dictionaries with the one-hot encoded "group" and "score" of the
    current applicant and a "prev_applicant_acceptance" entry describing whether the
    previously processed applicant is marked as accepted. The action space is
    ``Discrete(2)`` (reject or accept).

    Accepting an applicant yields a reward of +1 if the applicant succeeds and -1
    otherwise; rejecting yields 0. Accepted applicants are removed from the pool,
    while rejected applicants stay in it and can be drawn again. The episode
    terminates once at most one applicant remains in the pool.

    All random draws go through NumPy's global random state (``np.random``).
    """

    def __init__(
        self,
        n_groups: int = 2,
        group_distribution: tuple[float, ...] = GROUP_DISTRIBUTION,
        score_distribution_mean: tuple[int, ...] = SCORE_DISTRIBUTION_MEAN,
        score_distribution_std: tuple[int, ...] = SCORE_DISTRIBUTION_STD,
        budget_distribution_mean: tuple[int, ...] = BUDGET_DISTRIBUTION_MEAN,
        budget_distribution_std: tuple[int, ...] = BUDGET_DISTRIBUTION_STD,
        success_prob: tuple[float, ...] = SUCCESS_PROB,
        population_size: int = 1000,
        max_score: int = 10,
        max_budget: int = 5,
        epsilon: float = 0.8,
    ) -> None:
        """
        Initialize the CollegeAdmissionEnv class with the given parameters.

        Args:
            n_groups (int): The number of groups.
            group_distribution (tuple[float, ...]): The distribution over the groups.
            score_distribution_mean (tuple[float, ...]): The mean of the score distribution for each group.
            score_distribution_std (tuple[float, ...]): The standard deviation of the score distribution for each group.
            budget_distribution_mean (tuple[float, ...]): The mean of the budget distribution for each group.
            budget_distribution_std (tuple[float, ...]): The standard deviation of the budget distribution for each group.
            success_prob (tuple[float, ...]): The success probability given the score
                (one entry per score, so its length must equal ``max_score``).
            population_size (int): The population size.
            max_score (int): The number of distinct scores; scores lie in [0, max_score - 1].
            max_budget (int): The maximum budget.
            epsilon (float): The probability that an applicant with remaining budget
                manipulates their score after having been processed.

        Returns:
            None

        Raises:
            AssertionError: If the group distribution does not sum to 1 or if a
                parameter tuple has an unexpected length.
        """
        super().__init__()

        # Compare with a tolerance instead of ``== 1``: valid distributions such as
        # ten groups of 0.1 do not sum to exactly 1.0 in floating point. The tolerance
        # is kept tight so that anything accepted here is also accepted by
        # ``np.random.choice(p=...)`` when groups are sampled.
        assert np.isclose(
            np.sum(group_distribution), 1.0, rtol=0.0, atol=1e-8
        ), "Group distribution must sum to 1."

        per_group_parameters = (
            ("Group distribution", group_distribution),
            ("Score distribution mean", score_distribution_mean),
            ("Score distribution std", score_distribution_std),
            ("Budget distribution mean", budget_distribution_mean),
            ("Budget distribution std", budget_distribution_std),
        )
        for label, values in per_group_parameters:
            assert n_groups == len(
                values
            ), f"{label} must have the same length as the number of groups."
        assert max_score == len(
            success_prob
        ), "Success probability must have the same length as the max score."

        self.n_groups = n_groups
        self.group_distribution = group_distribution
        self.population_size = population_size
        self.max_score = max_score
        self.max_budget = max_budget
        self.epsilon = epsilon
        self.success_prob = success_prob

        self.score_distribution_mean = score_distribution_mean
        self.score_distribution_std = score_distribution_std
        self.budget_distribution_mean = budget_distribution_mean
        self.budget_distribution_std = budget_distribution_std

        # Score changes are group-specific: score points gained per unit of cost paid
        self.score_changes_coef = np.ones(self.n_groups).astype(int)

        # Observations are one-hot encoded vectors
        self.observation_space = spaces.Dict(
            {
                "group": spaces.Box(low=0, high=1, shape=(n_groups,), dtype=int),
                "score": spaces.Box(low=0, high=1, shape=(max_score,), dtype=int),
                "prev_applicant_acceptance": spaces.Box(
                    low=0, high=1, shape=(2,), dtype=int
                ),
            }
        )

        # Reject or accept the college application
        self.action_space = spaces.Discrete(2)

        self._create_population()
        self._reset_metrics()

    def _create_population(self) -> None:
        """
        Create a population of applicants by sampling the group,
        score, and budget from the given distributions.

        Args:
            None

        Returns:
            None
        """
        self.population = []

        def _sample_group() -> int:
            """
            Sample a group from the group distribution.

            Args:
                None

            Returns:
                int: The group.
            """
            return np.random.choice(self.n_groups, p=self.group_distribution)

        def _sample_score(group: int) -> int:
            """
            Sample a score from the score distribution of the given group.

            Args:
                group (int): The group.

            Returns:
                int: The score.
            """
            score = int(
                np.random.normal(
                    self.score_distribution_mean[group],
                    self.score_distribution_std[group],
                )
            )
            return int(max(0, min(self.max_score - 1, score)))

        def _sample_budget(group) -> float:
            """
            Sample the budget from the budget distribution.

            Args:
                group (int): The group.

            Returns:
                float: The budget.
            """
            budget = int(
                np.random.normal(
                    self.budget_distribution_mean[group],
                    self.budget_distribution_std[group],
                )
            )

            return int(max(0.0, min(self.max_budget, budget)))

        for i in range(self.population_size):
            group = _sample_group()
            score = _sample_score(group)
            budget = _sample_budget(group)
            self.population.append(
                {
                    "applicant_id": i,
                    "group": group,
                    "score": score,
                    "true_score": score,
                    "budget": budget,
                    "cost_paid": 0,
                    "accepted": False,
                }
            )

    def _reset_metrics(self):
        """
        Reset the metrics.

        Args:
            None

        Returns:
            None
        """
        self.remaining_applicants = np.arange(self.population_size).tolist()
        self.cumulative_admissions = [0 for _ in range(self.n_groups)]
        self.cumulative_admissions_history = [[] for _ in range(self.n_groups)]
        self.tp = [0 for _ in range(self.n_groups)]
        self.fp = [0 for _ in range(self.n_groups)]
        self.fn = [0 for _ in range(self.n_groups)]
        self.tn = [0 for _ in range(self.n_groups)]
        self.recall = [0 for _ in range(self.n_groups)]
        self.precision = [0 for _ in range(self.n_groups)]

    def _sample_applicant(self) -> dict[str, Any]:
        """
        Sample an applicant from the population.

        Args:
            None

        Returns:
            dict[str, Any]: The applicant.
        """
        idx = np.random.choice(self.remaining_applicants)
        return self.population[idx]

    def _get_obs(
        self,
        curr_applicant_id: int,
        prev_applicant_id: Union[int, None] = None,
    ) -> dict[str, Any]:
        """
        Build the observation for the current applicant.

        Args:
            curr_applicant_id (int): The id of the applicant to be decided on.
            prev_applicant_id (Union[int, None]): The id of the previously
                processed applicant, or None if there is none.

        Returns:
            dict[str, Any]: The observation with the encoded "score" and "group"
                of the current applicant and "prev_applicant_acceptance", which
                is all zeros when there is no previous applicant.
        """
        current = self.population[curr_applicant_id]

        # The (possibly manipulated) score is observed, not the true score.
        encoded_score = one_hot_encode(current["score"], self.max_score)
        encoded_group = one_hot_encode(current["group"], self.n_groups)

        if prev_applicant_id is None:
            prev_acceptance = np.zeros(2, dtype=int)
        else:
            was_accepted = self.population[prev_applicant_id]["accepted"]
            prev_acceptance = one_hot_encode(int(was_accepted), 2)

        return {
            "score": encoded_score,
            "group": encoded_group,
            "prev_applicant_acceptance": prev_acceptance,
        }

    def _get_info(self, will_succeed: Union[bool, None]) -> dict[str, Any]:
        """
        Get the info dictionary.

        Args:
            will_succeed (Union[bool, None]): Whether the applicant will succeed.

        Returns:
            dict[str, Any]: The info dictionary.
        """
        return {
            "recall": self.recall,
            "precision": self.precision,
            "success": will_succeed,
        }

    def _is_terminated(self) -> bool:
        """
        Check if the episode is terminated.

        Args:
            None

        Returns:
            bool: True if the episode is ended, False otherwise.
        """
        return len(self.remaining_applicants) <= 1

    def _get_distribution(self, data_key) -> tuple[tuple[float]]:
        """
        Get the distribution of the given data key for each group.

        Args:
            data_key (str): The data key.

        Returns:
            tuple[tuple[float]]: The distribution for each group.
        """
        assert data_key in ["score", "budget", "cost_paid"], "Invalid data key."

        distribution = []
        for group in range(self.n_groups):
            distribution.append(
                np.histogram(
                    [
                        applicant[data_key]
                        for applicant in self.population
                        if applicant["group"] == group
                    ],
                    bins=(
                        int(self.max_score) if data_key == "score" else self.max_budget
                    ),
                    range=(
                        (0, self.max_score)
                        if data_key == "score"
                        else (0, self.max_budget)
                    ),
                    density=True,
                )[0]
            )
        return tuple(distribution)

    def _get_average_cost_paid_by_accepted(self) -> tuple[tuple[float]]:
        """
        Get the average cost paid by the accepted applicants for each group.

        Args:
            None

        Returns:
            tuple[float]: The cost paid for each group.
        """
        cost_paid = []
        for group in range(self.n_groups):
            cost_paid.append(
                np.mean(
                    [
                        applicant["cost_paid"]
                        for applicant in self.population
                        if applicant["group"] == group and applicant["accepted"]
                    ]
                )
            )
        return tuple(cost_paid)

    def _get_acceptance_ratio(self) -> tuple[tuple[float]]:
        """
        Get the ratio of the accepted applicants for each group.

        Args:
            None

        Returns:
            tuple[float]: The ratio for each group.
        """
        acceptance_ratio = []
        for group in range(self.n_groups):
            total_applicants = len(
                [
                    applicant
                    for applicant in self.population
                    if applicant["group"] == group
                ]
            )
            accepted_applicants = len(
                [
                    applicant
                    for applicant in self.population
                    if applicant["group"] == group and applicant["accepted"]
                ]
            )
            acceptance_ratio.append(accepted_applicants / total_applicants)
        return tuple(acceptance_ratio)

    def _manipulate_score(self, applicant_id) -> None:
        """
        Manipulate the score of the applicant by paying a cost.
        Leaves the true score unchanged.

        Args:
            applicant_id (int): The applicant.

        Returns:
            None
        """
        group = self.population[applicant_id]["group"]
        score = self.population[applicant_id]["score"]
        budget = self.population[applicant_id]["budget"]

        if np.random.rand() < self.epsilon and budget > 0:
            # Manipulate the score
            max_budget_to_use = min(budget, self.max_score - score - 1)
            cost = (
                np.random.randint(1, max_budget_to_use + 1)
                if max_budget_to_use > 1
                else 1
            )

            score += cost * self.score_changes_coef[group]
            score = max(0.0, min(self.max_score - 1, score))

            self.population[applicant_id]["score"] = score
            self.population[applicant_id]["budget"] -= cost
            self.population[applicant_id]["cost_paid"] += cost

    def reset(
        self,
        seed: Union[None, int] = None,
        options: Union[None, dict[str, Any]] = None,
    ) -> tuple[dict[str, Any], dict[str, Any]]:
        """
        Reset the environment to its initial state.

        A new population is sampled, all metrics are cleared and the initial
        score and budget distributions are stored on the instance.

        Args:
            seed (Union[None, int]): The random seed. When given, NumPy's global
                random state is seeded with it, so the sampled population and
                all subsequent draws are reproducible.
            options (Union[None, dict[str, Any]]): The options.

        Returns:
            tuple[dict[str, Any], dict[str, Any]]: The observation and the info.
        """
        super().reset(seed=seed, options=options)

        # This environment samples from NumPy's global random state, which the
        # base class does not seed; without this the seed would have no effect.
        # np.random.seed only accepts values below 2**32.
        if seed is not None:
            np.random.seed(seed % 2**32)

        # Reset the population and the metrics
        self._create_population()
        self._reset_metrics()

        # Save the initial score and budget distributions
        self.initial_score_distribution = self._get_distribution(data_key="score")
        self.initial_budget_distribution = self._get_distribution(data_key="budget")

        # Sample an applicant and return the observation
        self.curr_applicant = self._sample_applicant()
        obs = self._get_obs(self.curr_applicant["applicant_id"], None)
        info = self._get_info(will_succeed=False)
        return obs, info

    def step(
        self, action: spaces.Discrete
    ) -> tuple[spaces.Dict, SupportsFloat, bool, bool, dict[str, Any]]:
        """
        Step the environment by taking the given action.

        The action is applied to the current applicant. Accepting yields a
        reward of +1 if the applicant succeeds and -1 otherwise, and removes the
        applicant from the pool. Rejecting yields 0 and keeps the applicant in
        the pool. Success is drawn from the success probability of the
        applicant's true (unmanipulated) score.

        Args:
            action (spaces.Discrete): The action.

        Returns:
            tuple[spaces.Dict, SupportsFloat, bool, bool, dict[str, Any]]:
                The observation, reward, terminated flag, truncated flag
                (always False), and info.
        """
        applicant_id = self.curr_applicant["applicant_id"]
        true_score = self.curr_applicant["true_score"]
        group = self.curr_applicant["group"]
        success_prob = self.success_prob[true_score]

        if action == AcceptRejectAction.ACCEPT.value:
            self.cumulative_admissions[group] += 1
            self.population[applicant_id]["accepted"] = True
            self.remaining_applicants.remove(applicant_id)

            # Applicant succeeds or fails the exam
            will_succeed = bool(np.random.rand() < success_prob)
            if will_succeed:
                reward = 1
                self.tp[group] += 1
            else:
                reward = -1
                self.fp[group] += 1
        else:
            reward = 0
            # Counterfactual outcome: would the applicant have succeeded?
            will_succeed = bool(np.random.rand() < success_prob)
            if will_succeed:
                self.fn[group] += 1
            else:
                self.tn[group] += 1

        # Update the cumulative admissions history for all groups
        for history, admitted in zip(
            self.cumulative_admissions_history, self.cumulative_admissions
        ):
            history.append(admitted)

        # Accepting the last applicant empties the pool (e.g. population_size=1);
        # sampling from an empty pool would raise, so keep the current applicant
        # for the final observation. The episode is reported as terminated then.
        if self.remaining_applicants:
            self.curr_applicant = self._sample_applicant()

        # NOTE(review): the processed applicant manipulates their score even when
        # they were just accepted, which also increases their "cost_paid" —
        # confirm whether this should be limited to rejected applicants.
        self._manipulate_score(applicant_id)

        # Update the recall and precision
        self.recall = compute_recall(self.tp, self.fn)
        self.precision = compute_precision(self.tp, self.fp)

        obs = self._get_obs(self.curr_applicant["applicant_id"], applicant_id)
        info = self._get_info(will_succeed=will_succeed)
        terminated = self._is_terminated()

        return obs, reward, terminated, False, info

    def set_score_changes(
        self,
        score_changes_coef: Union[list[int], np.ndarray],
    ):
        """
        Set the credit changes for each group.

        Args:
            score_changes_coef (Union[list[int], np.ndarray]): The credit score changes.

        Returns:
            None
        """
        assert (
            len(score_changes_coef) == self.n_groups
        ), "The number of score changes must be equal to the number of groups."
        assert np.all(
            np.array(score_changes_coef) >= 0
        ), "The credit score changes must be non-negative."

        self.score_changes_coef = np.array(score_changes_coef).astype(int)
