"""
Lending agent with access to ground truth distributions of applicants.

Code based on: 
https://github.com/google/ml-fairness-gym
"""

import numpy as np

from agents import classifier_agents
from fair_gym import AcceptRejectAction


def one_hot_encode(value: int, size: int) -> np.ndarray:
    """
    Build a one-hot vector with a single 1 at position ``value``.

    Args:
        value (int): Index of the hot entry; converted with ``int()``
            before it is used as an index.
        size (int): Length of the returned vector.

    Returns:
        np.ndarray: Integer array of zeros with a 1 at index ``value``.
    """
    encoding = np.zeros(shape=size, dtype=int)
    hot_index = int(value)
    encoding[hot_index] = 1
    return encoding


class OracleThresholdAgent(classifier_agents.ThresholdAgent):
    """Threshold agent with oracle access to distributional data.

    The training corpus is rebuilt from the credit score distributions and
    success probabilities exposed by the unwrapped environment, rather than
    from the individual examples passed to ``_record_training_example``.
    """

    def __init__(self, action_space, reward_fn, observation_space, params, env):
        """
        Initialize the base threshold agent and keep a handle on ``env``.

        Args:
            action_space: Forwarded unchanged to ``ThresholdAgent``.
            reward_fn: Forwarded unchanged to ``ThresholdAgent``.
            observation_space: Forwarded unchanged to ``ThresholdAgent``.
            params: Forwarded unchanged to ``ThresholdAgent``.
            env: Environment whose ``unwrapped`` attribute provides
                ``_get_current_credit_score_distribution()``, ``max_credit``
                and ``success_probability``.
        """
        super().__init__(
            action_space=action_space,
            reward_fn=reward_fn,
            observation_space=observation_space,
            params=params,
        )
        self.env = env

    def _record_training_example(self, observation, action, reward):
        """
        Replace the training corpus with one built from the environment.

        For every (group, credit score) cell of the current credit score
        distribution, two weighted examples are added, both with the ACCEPT
        action: one with label 0 weighted by the probability of default, and
        one with label 1 weighted by the probability of repayment. Each
        weight is additionally scaled by the cell's probability mass.

        Args:
            observation: Unused; the corpus is derived from ``self.env``.
            action: Unused.
            reward: Unused.
        """
        # The corpus is discarded and rebuilt from scratch on every call.
        self._training_corpus = classifier_agents.TrainingCorpus()

        env = self.env.unwrapped
        applicant_distributions = env._get_current_credit_score_distribution()
        n_groups = len(applicant_distributions)
        max_credit = env.max_credit
        success_probability = env.success_probability
        # Store the same representation of the action in both examples
        # (the original mixed the enum member with its ``.value``).
        accept_action = AcceptRejectAction.ACCEPT.value

        for group, distribution in enumerate(applicant_distributions):
            for credit, cluster_prob in enumerate(distribution):
                prob_default = 1 - success_probability[credit]
                oracle_observation = {
                    "credit_score": one_hot_encode(credit, max_credit),
                    "group": one_hot_encode(group, n_groups),
                }
                features = self._get_features(oracle_observation)

                # NOTE(review): the 1e-6 offset presumably keeps every weight
                # strictly positive for the downstream classifier -- confirm.
                self._training_corpus.add(
                    classifier_agents.TrainingExample(
                        observation=oracle_observation,
                        action=accept_action,
                        label=0,
                        weight=prob_default * cluster_prob + 1e-6,
                        features=features,
                    )
                )
                self._training_corpus.add(
                    classifier_agents.TrainingExample(
                        observation=oracle_observation,
                        action=accept_action,
                        label=1,
                        weight=(1 - prob_default) * cluster_prob + 1e-6,
                        features=features,
                    )
                )
