from numpy.typing import NDArray
from jax import numpy as jnp
import numpy as np

from typing import NamedTuple
from fairgym.envs.base_env import BaseEnv
from fairgym.envs.action import threshold_action


# W_1 - W_0 vs u
#
#           _____<___
#          /
# 0 ------s---------
#        /
#       /
# -->--'
#
# Required ordering of U[y, y_hat] (y = true qualification, y_hat = decision):
# U01 > U11 > U10 > U00

# Default utility received by an individual, written U[y, y_hat]:
#   y     = true qualification      (0 = unqualified, 1 = qualified)
#   y_hat = predicted qualification (0 = rejected,    1 = accepted)
default_agent_utility_matrix = jnp.asarray(
    [
        # y_hat:  rejected  accepted
        [1.00, 1.33],  # y = 0 (unqualified)
        [1.10, 1.21],  # y = 1 (qualified)
    ]
)


class GroupQualificationReplicatorEnv(BaseEnv):
    def __init__(
        self,
        num_groups=2,  # currently limited to 2
        num_feature_bins=32,
        agent_utility_matrix=default_agent_utility_matrix,
        **kwargs,
    ):
        """
        A FairThresholds gym environment modelling a population according to
        replicator dynamics in group-specific qualification rates.

        Initialization arguments:

        # TODO allow more than 2
        num_groups: number of protected groups
        num_feature_bins: number of bins of feature values observed by agent
        agent_utility_matrix: 2x2 utilities to individuals, indexed by
            (true qualification y, decision y_hat). Must satisfy
            U[0, 1] > U[1, 1] > U[1, 0] > U[0, 0].
        **kwargs: forwarded unchanged to BaseEnv.__init__

        Reset arguments:
        options = {
            pr_q: len 2 array # probability of Y=1 in each group
            pr_G: len 2 array # relative frequency of each group
            agent_utility_matrix: optional replacement utility matrix
        }
        """

        # Convert up front so nested lists / numpy arrays support the
        # [row, col] indexing used by the ordering check and the dynamics.
        self.agent_utility_matrix = jnp.asarray(agent_utility_matrix)
        self._check_utility_ordering(self.agent_utility_matrix)

        super().__init__(
            num_groups=num_groups,
            num_feature_bins=num_feature_bins,
            generate_init_state=_generate_init_state,
            generate_next_state=_generate_next_state,
            **kwargs,
        )

    @staticmethod
    def _check_utility_ordering(agent_utility_matrix):
        """
        Assert the utility ordering that guarantees an internal equilibrium:
        unqualified admitted > qualified admitted >
        qualified rejected > unqualified rejected.
        """
        assert (
            agent_utility_matrix[0, 1]
            > agent_utility_matrix[1, 1]
            > agent_utility_matrix[1, 0]
            > agent_utility_matrix[0, 0]
        ), (
            "agent_utility_matrix must satisfy U01 > U11 > U10 > U00, got "
            f"{agent_utility_matrix}"
        )

    def reset(self, seed=None, return_info=False, options=None):
        """
        Reset environment to new (randomly) generated initial state.
        Return initial observation.

        ``options`` is copied rather than mutated. If it contains
        "agent_utility_matrix", that matrix is validated and then replaces
        this environment's matrix for this and later resets.
        """

        # Copy so the caller's dict is not modified by the injection below.
        options = {} if options is None else dict(options)

        if "agent_utility_matrix" in options:
            new_matrix = jnp.array(options["agent_utility_matrix"])
            # Same invariant as __init__; checked before assignment so a bad
            # matrix does not overwrite the current one.
            self._check_utility_ordering(new_matrix)
            self.agent_utility_matrix = new_matrix

        options["agent_utility_matrix"] = self.agent_utility_matrix

        return super().reset(seed, return_info, options)


# TODO: inheritance of NamedTuples would be really nice
class AugState(NamedTuple):
    """
    Immutable snapshot of the population distribution plus quantities derived
    from it (see ``create_state``, which builds every instance).

    Arrays with a feature axis are indexed by [group, (x value)]; the marginals
    ``pr_Y1`` / ``pr_Y0`` are indexed by [group] only, and
    ``agent_utility_matrix`` is indexed by (y, y_hat).
    """

    # Naming key:
    # "X":  "X=x"
    # "lX": "X<=x"
    # "a": and
    # "g": given
    # Apart from pr_G and agent_utility_matrix, every probability below is
    # also conditioned on the group G=g (one row per group).
    pr_G: NDArray  # Pr(G=g), relative group sizes (shape depends on caller — verify)
    pr_X: NDArray  # Pr(X=x | G=g)
    pr_Y1gX: NDArray  # Pr(Y=1 | X=x, G=g)
    pr_Y1aX: NDArray  # Pr(Y=1, X=x | G=g)
    pr_Y0aX: NDArray  # Pr(Y=0, X=x | G=g)
    pr_lX: NDArray  # Pr(X<=x | G=g), cumulative over feature bins
    pr_Y1alX: NDArray  # Pr(Y=1, X<=x | G=g); not a pdf
    pr_Y0alX: NDArray  # Pr(Y=0, X<=x | G=g); not a pdf
    pr_Y1: NDArray  # qualified: Pr(Y=1 | G=g)
    pr_Y0: NDArray  # unqualified: Pr(Y=0 | G=g)
    agent_utility_matrix: NDArray  # utilities to individuals, indexed by (y, y_hat)
    pr_XgY1: NDArray  # Pr(X=x | Y=1, G=g)
    pr_XgY0: NDArray  # Pr(X=x | Y=0, G=g)


def create_state(agent_utility_matrix, pr_G, pr_X, pr_Y1gX) -> AugState:
    """
    Assemble an AugState from the primitive distributions.

    Arguments:
        agent_utility_matrix: utilities indexed by (y, y_hat); stored as-is.
        pr_G: per-group relative frequencies; ``pr_G.size`` is used as the
            number of groups.
        pr_X: Pr(X=x | G=g), one row per group.
        pr_Y1gX: Pr(Y=1 | X=x, G=g), same shape as ``pr_X``.

    Returns:
        AugState whose remaining fields are derived via Bayes' rule and
        cumulative sums over the last (feature-bin) axis.
    """
    # Joint Pr(Y=y, X=x | G=g).
    # TODO check broadcasting should work correctly
    pr_Y1aX = pr_X * pr_Y1gX
    pr_Y0aX = pr_X * (1 - pr_Y1gX)

    # Cumulative "X <= x" versions (the joint ones are not pdfs).
    # TODO fix cumulative rounding errors (is axis=-1 right?)
    pr_lX, pr_Y1alX, pr_Y0alX = (
        jnp.cumsum(dist, axis=-1) for dist in (pr_X, pr_Y1aX, pr_Y0aX)
    )

    # Last cumulative entry per group is the marginal Pr(Y=y | G=g).
    pr_Y1 = pr_Y1alX[:, -1]
    pr_Y0 = pr_Y0alX[:, -1]

    # Column shape so each group's row is divided by its own marginal.
    per_group_column = (pr_G.size, 1)

    return AugState(
        pr_G=pr_G,
        pr_X=pr_X,
        pr_Y1gX=pr_Y1gX,
        pr_Y1aX=pr_Y1aX,
        pr_Y0aX=pr_Y0aX,
        pr_lX=pr_lX,
        pr_Y1alX=pr_Y1alX,
        pr_Y0alX=pr_Y0alX,
        pr_Y1=pr_Y1,
        pr_Y0=pr_Y0,
        agent_utility_matrix=agent_utility_matrix,
        pr_XgY1=pr_Y1aX / pr_Y1.reshape(per_group_column),
        pr_XgY0=pr_Y0aX / pr_Y0.reshape(per_group_column),
    )


def _generate_init_state(num_groups, num_feature_bins, rng, options):
    """
    Build the initial AugState for the replicator-dynamics environment.

    Arguments:
        num_groups: number of protected groups G.
        num_feature_bins: number of discrete feature bins.
        rng: random source with a ``random()`` method, presumably a
            ``numpy.random.Generator`` — only used when "pr_q" / "pr_G"
            are absent from ``options``.
        options: dict with a required "agent_utility_matrix" and optional
            "pr_q" (per-group Pr(Y=1)) and "pr_G" (per-group relative size).

    Returns:
        AugState built by ``create_state``.
    """
    from jax.scipy.stats import norm

    assert options["agent_utility_matrix"] is not None

    # Qualification rate Pr(Y=1 | G=g) as a (G, 1) column.
    if "pr_q" in options:
        pr_q = jnp.array(options["pr_q"]).reshape(num_groups, 1)
    else:
        pr_q = jnp.array([rng.random() for _ in range(num_groups)]).reshape(
            num_groups, 1
        )
    # A rate of exactly 0 or 1 empties one class, making Pr(X | Y=y) a 0/0
    # (NaN) in create_state.
    assert jnp.all((pr_q > 0) & (pr_q < 1)), "pr_q must lie strictly in (0, 1)"
    pr_u = 1 - pr_q

    # Assume static Pr(G), always flattened to shape (G,) so both branches agree.
    # https://numpy.org/doc/stable/reference/random/generator.html#simple-random-data
    if "pr_G" in options:
        pr_G = jnp.array(options["pr_G"]).reshape(num_groups)
    else:
        # Uniform point on the simplex from G-1 sorted cut points; for two
        # groups this is exactly [u, 1 - u] from a single draw u.
        cuts = sorted(rng.random() for _ in range(num_groups - 1))
        pr_G = jnp.array(np.diff([0.0, *cuts, 1.0]))
    pr_G = pr_G / jnp.sum(pr_G)  # normalize

    x_grid = jnp.linspace(-0.9999, 0.9999, num_feature_bins)

    def _class_conditional(mean):
        # Discretized unit-variance Gaussian, normalized over the bins and
        # held constant across groups (one identical row per group).
        density = norm.pdf(x_grid, loc=mean, scale=1.0)
        density = density / jnp.sum(density)
        return jnp.tile(density, (num_groups, 1))

    d1_g = _class_conditional(1.0)  # Pr(X = x | Y = 1, G = g)
    d0_g = _class_conditional(-1.0)  # Pr(X = x | Y = 0, G = g)

    # Guards the division by pr_X below and the Bayes inversion in create_state.
    assert jnp.all(d0_g != 0)
    assert jnp.all(d1_g != 0)

    pr_X = d1_g * pr_q + d0_g * pr_u
    pr_Y1gX = d1_g * pr_q / pr_X

    return create_state(options["agent_utility_matrix"], pr_G, pr_X, pr_Y1gX)


def _generate_next_state(
    state,
    action,
):
    """
    DYNAMICS
    return new distribution from currently stored distribution and action
    using deterministic (replicator dynamics) model.

    Not using static_argnums as args aren't hashable.

    Arguments:
        state: current AugState.
        action: decision passed through to ``threshold_action``.

    Returns:
        New AugState with updated per-group qualification rates. Pr(G) and
        the class-conditional feature distributions Pr(X | Y, G) are kept fixed.
    """

    # compute how this action affects variables in current state
    results = threshold_action(state, action)

    # utility[qualified?, admitted?]
    utility = state.agent_utility_matrix

    pr_q = state.pr_Y1  # proportion qualified in each group
    pr_u = state.pr_Y0  # proportion unqualified in each group

    # Mean fitness of qualified (W1) and unqualified (W0) individuals per
    # group. Assumes the results.*_rate fields are per-group rates conditioned
    # on the true label — TODO confirm against threshold_action.
    W1_g = utility[1, 1] * results.tp_rate + utility[1, 0] * results.fn_rate
    W0_g = utility[0, 1] * results.fp_rate + utility[0, 0] * results.tn_rate

    # Population-average fitness of each group.
    W_g = pr_q * W1_g + pr_u * W0_g

    # Replicator update q' = q * W1 / W, as a (G, 1) column for any number of
    # groups. Clipped away from 0 and 1 so create_state never divides by a
    # zero marginal Pr(Y=y | G=g).
    new_pr_q = jnp.clip((pr_q * (W1_g / W_g)).reshape(-1, 1), 0.00001, 0.99999)
    new_pr_x = new_pr_q * state.pr_XgY1 + (1 - new_pr_q) * state.pr_XgY0

    # Pr(Y=1 | X, G) = Pr(X | Y=1) Pr(Y=1) / Pr(X)
    new_pr_Y1gX = state.pr_XgY1 * new_pr_q / new_pr_x

    return create_state(state.agent_utility_matrix, state.pr_G, new_pr_x, new_pr_Y1gX)


if __name__ == "__main__":
    # Demo: roll the dynamics forward under a fixed 0.5 threshold per group and
    # plot Pr(X | G) and Pr(Y=1 | X, G) at each step (later steps are fainter).

    from matplotlib import pyplot as plt

    env = GroupQualificationReplicatorEnv()
    observation, info = env.reset(seed=0, return_info=True)

    n_axs = 2
    scale = 4.5
    fig, axs = plt.subplots(1, n_axs, figsize=(n_axs * scale, scale))

    colors = ["red", "blue"]
    linestyles = ["-", "--"]
    N = 10
    for t in range(N):
        state = info["current_state"]

        # Bin positions sized from the state itself rather than a hard-coded
        # 32, so the plot follows whatever num_feature_bins the env uses.
        x_bins = np.linspace(0, 1, state.pr_X.shape[-1])

        # P(X | G) and P(Y=1 | X, G)
        for g in range(env.num_groups):
            line_kwargs = dict(
                label=f"g={g + 1}" if t == 0 else None,
                alpha=(1 - t / N),
                color=colors[g % len(colors)],
                linestyle=linestyles[g % len(linestyles)],
            )
            axs[0].plot(x_bins, state.pr_X[g], **line_kwargs)
            axs[1].plot(x_bins, state.pr_Y1gX[g], **line_kwargs)

        observation, reward, terminated, truncated, info = env.step(
            jnp.full(env.num_groups, 0.5)
        )

    axs[0].legend(loc="upper left")
    axs[1].legend(loc="upper left")

    axs[0].set_xlabel("X")
    axs[0].set_ylabel("Pr(X=x|G=g) (Density)")

    axs[1].set_xlabel("X")
    axs[1].set_ylabel("Pr(Y=1|X=x,G=g)")

    plt.savefig("example_dynamics.pdf")
