from typing import Any, SupportsFloat, Union, Literal

import numpy as np

import gymnasium as gym
from gymnasium import spaces

from fair_gym.utils.utils import one_hot_encode, compute_recall, compute_precision
from fair_gym.utils.actions import AcceptRejectAction


# Default probability of an applicant belonging to each group (one entry per group)
GROUP_DISTRIBUTION = (0.5, 0.5)
# Default likelihoods of each credit score given group membership:
# one row per group, one entry per credit score level within a row
CREDIT_SCORE_DISTRIBUTION = (
    (0.0, 0.0, 0.0, 0.2, 0.3, 0.3, 0.2),
    (0.2, 0.2, 0.3, 0.2, 0.1, 0.0, 0.0),
)
# Default likelihoods of loan payment success, indexed by credit score
SUCCESS_PROBABILITY = (0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9)
# Names of the supported conscientiousness distributions
CONSCIENTIOUSNESS_DISTRIBUTIONS = ["uniform", "normal"]


class LendingEnv(gym.Env):
    """
    A simple lending environment where the agent must decide whether or not
    to accept a loan application. The agent's goal is to maximize the overall
    return on investment.

    Each step presents one applicant drawn from a fixed population. Accepting
    an applicant who repays yields a reward of 1 and raises their credit
    score; accepting an applicant who defaults yields a reward of -1 and
    lowers it; rejecting yields a reward of 0. The episode terminates once
    the bank has no cash left.

    All randomness is drawn from ``self.np_random`` so that
    ``reset(seed=...)`` makes the population and the episode reproducible.
    """

    def __init__(
        self,
        n_groups: int = 2,
        success_func: Literal[
            "credit", "cons", "credit_cons", "latent_credit", "latent_credit_cons"
        ] = "credit",
        group_distribution: tuple[float, ...] = GROUP_DISTRIBUTION,
        credit_score_distribution: tuple[
            tuple[float, ...], ...
        ] = CREDIT_SCORE_DISTRIBUTION,
        success_probability: tuple[float, ...] = SUCCESS_PROBABILITY,
        cons_distribution: Literal["uniform", "normal"] = "normal",
        cons_mean: float = 0.5,
        cons_std: float = 0.1,
        population_size: int = 1000,
        max_credit: int = 7,
        loan_amount: int = 1,
        interest_rate: float = 0.3,
        bank_starting_cash: int = 1000000,
    ) -> None:
        """
        Initialize the LendingEnv class with the given parameters.

        Args:
            n_groups (int): The number of groups.
            success_func (Literal["credit", "cons", "credit_cons", "latent_credit", "latent_credit_cons"]): The success function.
            group_distribution (tuple[float, ...]): The distribution of groups.
            credit_score_distribution (tuple[tuple[float, ...], ...]): The distribution of credit scores, one row per group.
            success_probability (tuple[float, ...]): The probability of loan payment success based on credit score.
            cons_distribution (Literal["uniform", "normal"]): The distribution of conscientiousness.
            cons_mean (float): The mean of the normal distribution of conscientiousness.
            cons_std (float): The standard deviation of the normal distribution of conscientiousness.
            population_size (int): The size of the population.
            max_credit (int): The number of credit score levels (scores range from 0 to max_credit - 1).
            loan_amount (int): The amount of the loan.
            interest_rate (float): The interest rate of the loan.
            bank_starting_cash (int): The starting cash of the bank.

        Returns:
            None

        Raises:
            AssertionError: If the distributions are inconsistent with
                ``n_groups`` / ``max_credit`` or do not sum to 1.
        """
        super().__init__()

        assert n_groups == len(
            group_distribution
        ), "Group distribution must have the same length as the number of groups."
        # Tolerant comparison: sums such as 0.7 + 0.2 + 0.1 are not exactly 1.0
        assert np.isclose(
            np.sum(group_distribution), 1
        ), "Group distribution must sum to 1."
        assert n_groups == len(
            credit_score_distribution
        ), "Credit score distribution must have the same length as the number of groups."
        assert all(
            np.isclose(np.sum(group), 1) for group in credit_score_distribution
        ), "Credit score distribution must sum to 1."
        assert max_credit == len(
            success_probability
        ), "Success probability must have the same length as the maximum credit score."
        assert (
            cons_distribution in CONSCIENTIOUSNESS_DISTRIBUTIONS
        ), "Invalid conscientiousness distribution."

        self.n_groups = n_groups
        self.success_func = success_func
        self.group_distribution = group_distribution
        self.credit_score_distribution = credit_score_distribution
        self.success_probability = success_probability
        self.cons_distribution = cons_distribution
        self.cons_mean = cons_mean
        self.cons_std = cons_std

        self.population_size = population_size
        self.max_credit = max_credit
        self.loan_amount = loan_amount
        self.interest_rate = interest_rate
        self.bank_starting_cash = bank_starting_cash
        self.bank_cash = bank_starting_cash

        # Credit changes are group-specific
        self.positive_credit_changes = np.ones(self.n_groups).astype(int)
        self.negative_credit_changes = np.ones(self.n_groups).astype(int)

        # Observations are one-hot encoded vectors
        self.observation_space = spaces.Dict(
            {
                "credit_score": spaces.Box(
                    low=0, high=1, shape=(max_credit,), dtype=int
                ),
                "group": spaces.Box(low=0, high=1, shape=(n_groups,), dtype=int),
                "prev_applicant_new_credit_score": spaces.Box(
                    low=0, high=1, shape=(max_credit,), dtype=int
                ),
                "loans_repaid": spaces.Box(low=0, high=1, shape=(1,), dtype=float),
                "loans_defaulted": spaces.Box(low=0, high=1, shape=(1,), dtype=float),
            }
        )

        # Reject or accept the loan application
        self.action_space = spaces.Discrete(2)

        self._create_population()
        self._reset_metrics()

    def _create_population(self) -> None:
        """
        Create a population of applicants by sampling the group, credit score
        and conscientiousness from the configured distributions.

        The applicant at position ``i`` of ``self.population`` always has
        ``applicant_id == i``; ``step`` relies on this to look applicants up.

        Args:
            None

        Returns:
            None
        """
        self.population = []

        def _sample_group() -> int:
            """
            Sample a group from the group distribution.

            Args:
                None

            Returns:
                int: The group.
            """
            return self.np_random.choice(self.n_groups, p=self.group_distribution)

        def _sample_credit_score(group: int) -> int:
            """
            Sample a credit score from the credit score distribution of the given group.

            Args:
                group (int): The group.

            Returns:
                int: The credit score.
            """
            return self.np_random.choice(
                self.max_credit, p=self.credit_score_distribution[group]
            )

        def _sample_conscientiousness() -> float:
            """
            Sample conscientiousness from the conscientiousness distribution.

            Args:
                None

            Returns:
                float: The conscientiousness, always within [0, 1].

            Raises:
                ValueError: If the configured distribution is not supported.
            """
            if self.cons_distribution == "uniform":
                return self.np_random.uniform(0, 1)
            elif self.cons_distribution == "normal":
                cons = self.np_random.normal(self.cons_mean, self.cons_std)
                # Clip so the value can be used directly as a probability
                return max(0.0, min(1.0, cons))
            else:
                raise ValueError("Invalid conscientiousness distribution.")

        for i in range(self.population_size):
            group = _sample_group()
            credit_score = _sample_credit_score(group)
            cons = _sample_conscientiousness()
            self.population.append(
                {
                    "applicant_id": i,
                    "group": group,
                    "credit_score": credit_score,
                    # Starts equal to the observed score; updated separately in step()
                    "latent_credit_score": credit_score,
                    "cons": cons,
                    "loans_given": np.zeros(1, dtype=float),
                    "loans_repaid": np.zeros(1, dtype=float),
                    "loans_defaulted": np.zeros(1, dtype=float),
                }
            )

    def _reset_metrics(self) -> None:
        """
        Reset the per-group metrics (loan counts, confusion-matrix counts,
        recall and precision) to zero.

        Args:
            None

        Returns:
            None
        """
        self.cumulative_loan = [0 for _ in range(self.n_groups)]
        self.cumulative_loan_history = [[] for _ in range(self.n_groups)]
        self.tp = [0 for _ in range(self.n_groups)]
        self.fp = [0 for _ in range(self.n_groups)]
        self.fn = [0 for _ in range(self.n_groups)]
        self.tn = [0 for _ in range(self.n_groups)]
        self.recall = [0 for _ in range(self.n_groups)]
        self.precision = [0 for _ in range(self.n_groups)]

    def _sample_applicant(self) -> dict[str, Any]:
        """
        Sample an applicant uniformly at random from the population.

        Args:
            None

        Returns:
            dict[str, Any]: The applicant (the same dict object stored in
                ``self.population``, not a copy).
        """
        # Draw an index instead of passing the list of dicts to ``choice``,
        # which would convert the whole population to an array on every call.
        index = self.np_random.integers(len(self.population))
        return self.population[index]

    def _get_obs(
        self,
        curr_applicant: dict[str, Any],
        prev_applicant: Union[dict[str, Any], None] = None,
    ) -> dict[str, Any]:
        """
        Get the observation for the given applicant.

        Args:
            curr_applicant (dict[str, Any]): The applicant currently applying.
            prev_applicant (Union[dict[str, Any], None]): The previous applicant,
                or None if there is none (e.g. right after a reset).

        Returns:
            dict[str, Any]: The observation. ``loans_repaid`` and
                ``loans_defaulted`` are fractions of the loans given to the
                current applicant, or zero if no loan was given yet.
        """
        return {
            "credit_score": one_hot_encode(
                curr_applicant["credit_score"], self.max_credit
            ),
            "group": one_hot_encode(curr_applicant["group"], self.n_groups),
            "prev_applicant_new_credit_score": (
                np.zeros(self.max_credit, dtype=int)
                if prev_applicant is None
                else one_hot_encode(prev_applicant["credit_score"], self.max_credit)
            ),
            "loans_repaid": (
                curr_applicant["loans_repaid"] / curr_applicant["loans_given"]
                if curr_applicant["loans_given"] > 0
                else np.zeros(1, dtype=float)
            ),
            "loans_defaulted": (
                curr_applicant["loans_defaulted"] / curr_applicant["loans_given"]
                if curr_applicant["loans_given"] > 0
                else np.zeros(1, dtype=float)
            ),
        }

    def _get_info(self, will_pay: Union[bool, None]) -> dict[str, Any]:
        """
        Get the info dictionary.

        Args:
            will_pay (Union[bool, None]): Whether the applicant will pay the loan.

        Returns:
            dict[str, Any]: The info dictionary with the current per-group
                recall and precision and the repayment outcome.
        """
        return {
            "recall": self.recall,
            "precision": self.precision,
            "success": will_pay,
        }

    def _is_terminated(self) -> bool:
        """
        Check if the bank is out of cash.

        Args:
            None

        Returns:
            bool: True if the bank is out of cash, False otherwise.
        """
        return self.bank_cash <= 0

    def _get_current_credit_score_distribution(
        self, latent_credit_score: bool = False
    ) -> tuple[np.ndarray, ...]:
        """
        Get the current credit score distribution for each group.

        Args:
            latent_credit_score (bool): Whether to use the latent credit score
                instead of the observed credit score.

        Returns:
            tuple[np.ndarray, ...]: One normalized histogram per group with
                ``max_credit`` unit-width bins.
        """
        credit_score_distribution = []
        key = "latent_credit_score" if latent_credit_score else "credit_score"
        for group in range(self.n_groups):
            credit_score_distribution.append(
                np.histogram(
                    [
                        applicant[key]
                        for applicant in self.population
                        if applicant["group"] == group
                    ],
                    bins=self.max_credit,
                    range=(0, self.max_credit),
                    density=True,
                )[0]
            )
        return tuple(credit_score_distribution)

    def _get_distribution(self, data_key: str) -> tuple[np.ndarray, ...]:
        """
        Get the distribution of the given data key for each group.

        Args:
            data_key (str): The data key; one of "cons", "loans_given",
                "loans_repaid" or "loans_defaulted".

        Returns:
            tuple[np.ndarray, ...]: One density histogram per group
                (20 bins over [0, 1] for "cons", 10 bins otherwise).

        Raises:
            AssertionError: If ``data_key`` is not supported.
        """
        assert data_key in [
            "cons",
            "loans_given",
            "loans_repaid",
            "loans_defaulted",
        ], "Invalid data key."

        distribution = []
        for group in range(self.n_groups):
            distribution.append(
                np.histogram(
                    [
                        applicant[data_key]
                        for applicant in self.population
                        if applicant["group"] == group
                    ],
                    bins=20 if data_key == "cons" else 10,
                    range=(0, 1) if data_key == "cons" else None,
                    density=True,
                )[0]
            )
        return tuple(distribution)

    def _get_success_prob(self) -> float:
        """
        Get the success probability of the current applicant.
        Success probability can be based on the (observed or latent) credit
        score, on conscientiousness, or on the average of both.

        Args:
            None

        Returns:
            float: The success probability.

        Raises:
            ValueError: If ``self.success_func`` is not supported.
        """
        if self.success_func == "credit":
            return self.success_probability[self.curr_applicant["credit_score"]]
        elif self.success_func == "cons":
            return self.curr_applicant["cons"]
        elif self.success_func == "credit_cons":
            cons_prob = self.curr_applicant["cons"]
            credit_prob = self.success_probability[self.curr_applicant["credit_score"]]
            return (cons_prob + credit_prob) / 2
        elif self.success_func == "latent_credit":
            return self.success_probability[self.curr_applicant["latent_credit_score"]]
        elif self.success_func == "latent_credit_cons":
            cons_prob = self.curr_applicant["cons"]
            credit_prob = self.success_probability[
                self.curr_applicant["latent_credit_score"]
            ]
            return (cons_prob + credit_prob) / 2
        else:
            raise ValueError("Invalid success function.")

    def reset(
        self,
        seed: Union[None, int] = None,
        options: Union[None, dict[str, Any]] = None,
    ) -> tuple[dict[str, Any], dict[str, Any]]:
        """
        Reset the environment to its initial state.

        Args:
            seed (Union[None, int]): The random seed. When given, it reseeds
                ``self.np_random``, which drives the new population and all
                subsequent sampling.
            options (Union[None, dict[str, Any]]): The options.

        Returns:
            tuple[dict[str, Any], dict[str, Any]]: The observation and the info.
        """
        # Must run first: it (re)seeds self.np_random used by everything below
        super().reset(seed=seed, options=options)

        # Reset the bank's cash, population, and metrics
        self.bank_cash = self.bank_starting_cash
        self._create_population()
        self._reset_metrics()

        # Save initial credit distribution
        self.initial_credit_score_distribution = (
            self._get_current_credit_score_distribution(latent_credit_score=False)
        )
        self.initial_latent_credit_score_distribution = (
            self._get_current_credit_score_distribution(latent_credit_score=True)
        )

        # Sample an applicant and return the observation
        self.curr_applicant = self._sample_applicant()
        obs = self._get_obs(self.curr_applicant, None)
        info = self._get_info(will_pay=False)
        return obs, info

    def step(
        self, action: int
    ) -> tuple[dict[str, Any], SupportsFloat, bool, bool, dict[str, Any]]:
        """
        Step the environment by taking the given action.

        Args:
            action (int): The action. It is treated as "accept" when it equals
                ``AcceptRejectAction.ACCEPT.value`` and as "reject" otherwise.

        Returns:
            tuple[dict[str, Any], SupportsFloat, bool, bool, dict[str, Any]]:
                The observation, reward, terminated flag, truncated flag, and info.
        """
        applicant_id = self.curr_applicant["applicant_id"]
        credit_score = self.curr_applicant["credit_score"]
        latent_credit_score = self.curr_applicant["latent_credit_score"]
        group = self.curr_applicant["group"]
        success_prob = self._get_success_prob()

        # The repayment outcome is drawn whether or not the loan is granted,
        # so that rejected applicants can be counted as fn / tn.
        will_pay = bool(self.np_random.random() < success_prob)
        record = self.population[applicant_id]

        if action == AcceptRejectAction.ACCEPT.value:
            self.cumulative_loan[group] += 1
            record["loans_given"] += 1

            if will_pay:
                # Applicant repays the loan
                reward = 1
                credit_score += self.positive_credit_changes[group]
                latent_credit_score += 1
                self.bank_cash += self.loan_amount * self.interest_rate
                # Scores are capped at the highest valid level
                record["credit_score"] = min(credit_score, self.max_credit - 1)
                record["latent_credit_score"] = min(
                    latent_credit_score, self.max_credit - 1
                )
                record["loans_repaid"] += 1
                self.tp[group] += 1
            else:
                # Applicant defaults on the loan
                reward = -1
                credit_score -= self.negative_credit_changes[group]
                latent_credit_score -= 1
                self.bank_cash -= self.loan_amount
                # Scores are floored at the lowest valid level
                record["credit_score"] = max(credit_score, 0)
                record["latent_credit_score"] = max(latent_credit_score, 0)
                record["loans_defaulted"] += 1
                self.fp[group] += 1
        else:
            reward = 0
            if will_pay:
                # Applicant would have repaid the loan
                self.fn[group] += 1
            else:
                # Applicant would have defaulted on the loan
                self.tn[group] += 1

        # Update the cumulative loan history for all groups
        for history, total in zip(self.cumulative_loan_history, self.cumulative_loan):
            history.append(total)

        self.curr_applicant = self._sample_applicant()

        # Update the recall and precision
        self.recall = compute_recall(self.tp, self.fn)
        self.precision = compute_precision(self.tp, self.fp)

        obs = self._get_obs(self.curr_applicant, record)
        info = self._get_info(will_pay=will_pay)
        terminated = self._is_terminated()

        return obs, reward, terminated, False, info

    def set_credit_changes(
        self,
        positive_credit_changes: Union[list[int], np.ndarray],
        negative_credit_changes: Union[list[int], np.ndarray],
    ) -> None:
        """
        Set the credit changes for each group.

        Args:
            positive_credit_changes (Union[list[int], np.ndarray]): The credit
                score increase applied per group when a loan is repaid.
            negative_credit_changes (Union[list[int], np.ndarray]): The credit
                score decrease applied per group when a loan is defaulted.
                Given as non-negative magnitudes.

        Returns:
            None

        Raises:
            AssertionError: If the lengths do not match the number of groups
                or any change is negative.
        """
        assert (
            len(positive_credit_changes)
            == len(negative_credit_changes)
            == self.n_groups
        ), "The number of positive and negative credit changes must be equal to the number of groups."

        positive = np.array(positive_credit_changes)
        negative = np.array(negative_credit_changes)
        # Both arrays are magnitudes; a negative entry would invert the update
        assert np.all(positive >= 0) and np.all(
            negative >= 0
        ), "The credit score changes must be non-negative."

        self.positive_credit_changes = positive.astype(int)
        self.negative_credit_changes = negative.astype(int)
