from numpy.typing import NDArray
from jax import numpy as jnp


from typing import NamedTuple
from fairgym.envs.base_env import BaseEnv
from fairgym.envs.action import threshold_action

from fairgym.envs.adult.preprocess import observe_distribution

# W_1 - W_0 vs u
#
#           _____<___
#          /
# 0 ------s---------
#        /
#       /
# -->--'
#
# U01 > U11 > U10 > U00

# Utility to individuals, indexed by
# (y, y_hat). That is,
# (true qualification, predicted qualification)
# The default values satisfy the ordering required by the environment:
# U[0, 1] > U[1, 1] > U[1, 0] > U[0, 0]  (1.33 > 1.21 > 1.10 > 1.00).
default_agent_utility_matrix = jnp.array(
    [  # AGENT UTILITIES: rows are y (true), columns are y_hat (predicted)
        [1.00, 1.33],  # unqualified (rejected, accepted)
        [1.10, 1.21],  # qualified   (rejected, accepted)
    ]
)


class AdultQualificationReplicatorEnv(BaseEnv):
    """
    FairThresholds gym environment whose group-specific qualification rates
    evolve according to replicator dynamics.
    """

    def __init__(
        self,
        num_groups=2,  # currently limited to 2
        num_feature_bins=32,
        agent_utility_matrix=default_agent_utility_matrix,
        **kwargs
    ):
        """
        A FairThresholds gym environment modelling a population according to
        replicator dynamics in group-specific qualification rates.

        Initialization arguments:

        # TODO allow more than 2
        num_groups: number of protected groups
        num_feature_bins: number of bins of feature values observed by agent
        agent_utility_matrix: 2x2 array-like of utilities to individuals,
            indexed by [true qualification, predicted qualification].
            Nested lists and NumPy arrays are accepted and converted to a
            jax array, the same way `reset` treats its option.
        **kwargs: forwarded unchanged to the base environment constructor.

        Reset arguments:
        options = {
            pr_q: len 2 array # probability of Y=1 in each group
            pr_G: len 2 array # relative frequency of each group
            agent_utility_matrix: 2x2 array-like # replaces the stored matrix
        }

        Raises:
            AssertionError: if the utility matrix violates the ordering
                U[0, 1] > U[1, 1] > U[1, 0] > U[0, 0].
        """

        # Coerce up front: tuple indexing such as `[0, 1]` below does not work
        # on a plain nested list, while `reset` already accepts array-likes.
        self.agent_utility_matrix = jnp.array(agent_utility_matrix)

        # To guarantee internal equilibrium:
        # unqualified, admitted >
        # qualified, admitted >
        # qualified, rejected >
        # unqualified, rejected
        assert (
            self.agent_utility_matrix[0, 1]
            > self.agent_utility_matrix[1, 1]
            > self.agent_utility_matrix[1, 0]
            > self.agent_utility_matrix[0, 0]
        ), "agent_utility_matrix must satisfy U[0,1] > U[1,1] > U[1,0] > U[0,0]"

        super().__init__(
            num_groups=num_groups,
            num_feature_bins=num_feature_bins,
            generate_init_state=_generate_init_state,
            generate_next_state=_generate_next_state,
            **kwargs
        )

    def reset(self, seed=None, return_info=False, options=None):
        """
        Reset environment to new (randomly) generated initial state.
        Return initial observation

        If `options` contains "agent_utility_matrix", it replaces the matrix
        stored on the environment (and persists for later resets). The
        stored matrix is always forwarded to the base class through the
        options it receives.

        The caller's `options` dict is left untouched; a shallow copy is
        what gets extended and passed on.
        """

        # Shallow-copy so that injecting the utility matrix below does not
        # leak into a dict the caller may reuse.
        options = {} if options is None else dict(options)

        if "agent_utility_matrix" in options:
            # NOTE(review): a matrix supplied here is not re-checked against
            # the ordering enforced in __init__ — confirm whether it should be.
            self.agent_utility_matrix = jnp.array(options["agent_utility_matrix"])

        options["agent_utility_matrix"] = self.agent_utility_matrix

        return super().reset(seed, return_info, options)


# TODO: inheritance of NamedTuples would be really nice
class AugState(NamedTuple):
    """
    Augmented environment state: a set of probability arrays plus the agent
    utility matrix, as assembled by `create_state`.

    Indexed by [group, (x value)].
    """

    # Field-name legend:
    # "X":  "X=x"
    # "lX": "X<=x"
    # "a": and
    # "g": given
    pr_G: NDArray  # relative frequency of each group
    pr_X: NDArray  # Pr(X=x)
    pr_Y1gX: NDArray  # Pr(Y=1 given X=x)
    pr_Y1aX: NDArray  # Pr(Y=1 and X=x); built as pr_X * pr_Y1gX
    pr_Y0aX: NDArray  # Pr(Y=0 and X=x); built as pr_X * (1 - pr_Y1gX)
    pr_lX: NDArray  # Pr(X<=x); cumulative sum of pr_X along the last axis
    pr_Y1alX: NDArray  # Pr(Y=1 and X<=x); cumulative sum of pr_Y1aX
    pr_Y0alX: NDArray  # Pr(Y=0 and X<=x); cumulative sum of pr_Y0aX
    pr_Y1: NDArray  # qualified; last column of pr_Y1alX
    pr_Y0: NDArray  # unqualified; last column of pr_Y0alX
    agent_utility_matrix: NDArray  # indexed [qualified?, admitted?]


def create_state(agent_utility_matrix, pr_G, pr_X, pr_Y1gX) -> AugState:
    """
    Assemble an `AugState` from the primitive distributions.

    Every derived field is computed from `pr_X` and `pr_Y1gX` with
    elementwise products and cumulative sums along the last axis;
    `agent_utility_matrix` and `pr_G` are stored unchanged.
    """

    # Joint probabilities of (Y, X=x) via Bayes' rule.
    # TODO check broadcasting should work correctly
    joint_qualified = pr_X * pr_Y1gX
    joint_unqualified = pr_X * (1 - pr_Y1gX)

    # Running totals over the feature axis; the two joint ones are not pdfs.
    # TODO fix cumulative rounding errors (is axis=-1 right?)
    cumulative_x = jnp.cumsum(pr_X, axis=-1)
    cumulative_qualified = jnp.cumsum(joint_qualified, axis=-1)
    cumulative_unqualified = jnp.cumsum(joint_unqualified, axis=-1)

    # The final cumulative entry of each row is that row's total mass.
    return AugState(
        pr_G=pr_G,
        pr_X=pr_X,
        pr_Y1gX=pr_Y1gX,
        pr_Y1aX=joint_qualified,
        pr_Y0aX=joint_unqualified,
        pr_lX=cumulative_x,
        pr_Y1alX=cumulative_qualified,
        pr_Y0alX=cumulative_unqualified,
        pr_Y1=cumulative_qualified[:, -1],
        pr_Y0=cumulative_unqualified[:, -1],
        agent_utility_matrix=agent_utility_matrix,
    )


def _generate_init_state(num_groups, num_feature_bins, rng, options):
    """
    Build an initial `AugState` for a two-group population.

    `options` must carry "agent_utility_matrix". It may also carry "pr_q"
    (probability of Y=1 in each group) and "pr_G" (relative frequency of
    each group); whichever is absent is drawn from `rng.random()`, so `rng`
    is only used when at least one of them is missing.
    """

    assert num_groups == 2
    assert options["agent_utility_matrix"] is not None

    # Qualification rate per group, stored as a (2, 1) column.
    if "pr_q" in options:
        qualification = options["pr_q"]
    else:
        qualification = [rng.random(), rng.random()]
    pr_q = jnp.array(qualification).reshape(2, 1)

    # Group frequencies are assumed static, stored as a (1, 2) row.
    # https://numpy.org/doc/stable/reference/random/generator.html#simple-random-data
    if "pr_G" in options:
        group_weights = options["pr_G"]
    else:
        group_1_size = rng.random()
        group_weights = [group_1_size, 1 - group_1_size]
    pr_G = jnp.array(group_weights).reshape(1, 2)
    pr_G = pr_G / jnp.sum(pr_G)  # normalize

    # Iterating the (2, 1) column yields one entry per group.
    q_first, q_second = pr_q
    observed = observe_distribution(pr_G[0, 0], q_first, q_second, num_feature_bins)
    pr_X = jnp.array(observed[0])
    pr_Y1gX = jnp.array(observed[1])

    return create_state(options["agent_utility_matrix"], pr_G, pr_X, pr_Y1gX)


def _generate_next_state(
    state,
    action,
):
    """
    DYNAMICS
    Deterministically advance the stored distribution by one step of a
    replicator-dynamics model, given the current state and the action.

    The updated qualification rates are clipped to [0.00001, 0.99999] and a
    fresh state is rebuilt from them via `_generate_init_state`, keeping
    the current utility matrix and group frequencies.

    Not using static_argnums as args aren't hashable.
    """

    # How the action plays out against the current state; the *_rate
    # attributes are presumably per-group classification rates — confirm
    # against threshold_action.
    outcome = threshold_action(state, action)

    # utilities[qualified?, admitted?]
    utilities = state.agent_utility_matrix
    share_qualified = state.pr_Y1
    share_unqualified = state.pr_Y0

    # Fitness of the qualified and unqualified sub-populations.
    fitness_qualified = (
        utilities[1, 1] * outcome.tp_rate + utilities[1, 0] * outcome.fn_rate
    )
    fitness_unqualified = (
        utilities[0, 1] * outcome.fp_rate + utilities[0, 0] * outcome.tn_rate
    )

    # Average fitness, weighted by the current qualified/unqualified shares.
    mean_fitness = (
        share_qualified * fitness_qualified
        + share_unqualified * fitness_unqualified
    )

    # Replicator update: scale the qualified share by its relative fitness,
    # then keep it strictly inside (0, 1).
    replicated = share_qualified * (fitness_qualified / mean_fitness)
    new_pr_q = jnp.clip(replicated.reshape((2, 1)), 0.00001, 0.99999)

    next_options = {
        "agent_utility_matrix": state.agent_utility_matrix,
        "pr_q": new_pr_q,
        "pr_G": state.pr_G,
    }
    # No rng is needed because both pr_q and pr_G are supplied explicitly.
    return _generate_init_state(2, state.pr_X.shape[1], None, next_options)
