import numpy as np
from jax import numpy as jnp
from jax import device_get, jit

from functools import partial

import gymnasium as gym
from gymnasium import spaces

from fairgym.envs.state import create_state
from fairgym.envs.action import threshold_action
from fairgym.envs.obs_encoding import encode_obs

# Use for debugging
# import jax
# jax.config.update("jax_check_tracer_leaks", True)


class BaseEnv(gym.Env):
    """
    A base class for fair-thresholds environments.

    The initial-state generator, the state dynamics and (optionally) the
    reward function are supplied as callables, typically by subclasses.

    Class-based objects don't play well with Jax, which is built around pure
    functions. The step logic therefore lives in top-level functions
    (`_step` / `_jit_step`). This class only holds the current state and the
    function objects that parameterize those pure functions.
    """

    metadata = {"render_modes": ["rgb_array"]}

    def __init__(
        self,
        num_groups,
        num_feature_bins,
        generate_init_state=None,
        generate_next_state=None,
        reward_fn=None,
        return_truncated=False,
        use_jit=True,
    ):
        """
        Initialize the environment.

        Args:
            `num_groups`: number of groups; also the length of the action vector.
            `num_feature_bins`: number of observable bins along the X (feature) dimension.
            `generate_init_state`: callable invoked as
                `(num_groups, num_feature_bins, rng, options=...)` by `reset`, returning
                the initial state. Defaults to a stub that raises NotImplementedError.
            `generate_next_state`: callable `(state, action)` returning the next state.
                Defaults to a stub that raises NotImplementedError.
            `reward_fn`: callable `(prev_state, action, prev_results, current_state)`
                returning the reward (converted with `float()` by `step`).
                Defaults to `_reward_fn`.
            `return_truncated`: NOTE: despite its name, when True `step` returns the
                4-tuple WITHOUT `truncated`. When False (the default), it returns the
                Gymnasium 5-tuple `(obs, reward, terminated, truncated, info)`.
            `use_jit`: if True, steps run through the jit-compiled `_jit_step`;
                otherwise through the plain `_step`.
        """

        self.num_groups = num_groups
        self.num_feature_bins = num_feature_bins

        # Action space: one threshold in [0, 1] for each group.
        self.action_space = spaces.Box(
            low=0, high=1, shape=(num_groups,), dtype=np.float32
        )

        # Observable state (the encoded current distribution, see encode_obs):
        #   red channel   - probability (density) of X, for each group
        #   green channel - probability of Y=1 conditioned on X, for each group
        self.observation_space = spaces.Box(
            low=0,
            high=255,
            shape=(num_groups, self.num_feature_bins, 4),
            dtype=np.uint8,
        )

        # Set by reset method ##################################################
        self.rng = None

        # Current distribution and associated state variables.
        self.state = None

        # The state/reward callables are stored as attributes and handed to
        # `_jit_step` as *static* arguments, so jit caches one compilation per
        # function object (keyed on its hash).
        # !!! Don't pass closures over variables that change between steps !!!
        # Captured values are baked in when the function is first traced. The
        # function object (and therefore its hash) doesn't change, so jit will
        # not retrace to pick up the new values.
        self.generate_next_state = (
            _generate_next_state_base
            if generate_next_state is None
            else generate_next_state
        )
        self.generate_init_state = (
            _generate_init_state_base
            if generate_init_state is None
            else generate_init_state
        )
        self.reward_fn = _reward_fn if reward_fn is None else reward_fn
        self.return_truncated = return_truncated

        self.use_jit = use_jit

    def _select_step_fn(self):
        """Return the jitted or the plain pure step function, per `self.use_jit`."""
        return _jit_step if self.use_jit else _step

    def reset(self, seed=None, return_info=False, options=None):
        """
        Reset the environment to a newly generated initial state.

        Args:
            seed: optional seed forwarded to `gym.Env.reset`, which manages `self.np_random`.
            return_info: if True, return `(observation, info)`; otherwise only the observation.
            options: forwarded to `generate_init_state`.

        Returns:
            The encoded initial observation, plus the info dict if `return_info` is True.
        """

        # Random number generation is handled by gym.Env.
        super().reset(seed=seed)
        self.rng = self.np_random

        self.state = self.generate_init_state(
            self.num_groups, self.num_feature_bins, self.rng, options=options
        )

        # device_get copies the encoded observation from the Jax device back to
        # host memory, so callers receive a concrete NumPy array.
        observation = device_get(encode_obs(self.state))
        info = {
            "prev_state": None,
            "prev_action": None,
            "prev_results": None,
            "current_state": self.state,
            "reward": None,
        }
        return (observation, info) if return_info else observation

    def step(self, action):
        """
        Advance the environment by one step and update `self.state`.

        This is a thin wrapper around the pure `_step` (or its jitted version).
        It converts the outputs to host-side NumPy / Python values.

        Args:
            action: per-group thresholds (clipped inside `_step`).

        Returns:
            `(observation, reward, terminated, truncated, info)` when
            `self.return_truncated` is False (the default), or
            `(observation, reward, terminated, info)` when it is True.
        """
        step_fn = self._select_step_fn()
        observation, reward, terminated, truncated, info = step_fn(
            self.generate_next_state, self.reward_fn, self.state, action
        )

        self.state = info["current_state"]

        observation = device_get(observation)
        reward = float(device_get(reward))
        terminated = bool(device_get(terminated))
        if self.return_truncated:
            # NOTE(review): this flag is inverted relative to its name. True
            # drops `truncated`, giving the legacy 4-tuple. Kept as-is for
            # backward compatibility with existing callers.
            return observation, reward, terminated, info
        return observation, reward, terminated, bool(device_get(truncated)), info

    def check_step_reward(self, state, action):
        """
        Compute the reward that `action` would earn from `state`.

        Does not update `self.state`.

        Returns:
            The reward exactly as produced by `reward_fn`. It is not converted
            to a Python float and not copied to the host.
        """
        step_fn = self._select_step_fn()
        _, reward, _, _, _ = step_fn(
            self.generate_next_state, self.reward_fn, state, action
        )
        return reward

    def render(self, *args, **kwargs):
        """
        Method required by Gym API; rendering is not implemented (returns None).
        """
        pass


def _step(generate_next_state_fn, reward_fn, prev_state, action):
    """
    Pure step: advance `prev_state` under the per-group thresholds `action`.

    Under `_jit_step`, `generate_next_state_fn` and `reward_fn` are static
    arguments. Only `prev_state` and `action` are traced as dynamic arguments.
    NOTE(review): the original note says a second compilation happens for JVP
    array arguments (as when differentiating through `check_step_reward`);
    confirm.

    Args:
        generate_next_state_fn: `(state, action) -> state` dynamics.
        reward_fn: `(prev_state, action, prev_results, current_state) -> reward`.
        prev_state: current environment state.
        action: per-group thresholds, used as `Y_hat = (X > threshold)`.

    Returns:
        `(observation, reward, terminated, truncated, info)`. `terminated` and
        `truncated` are always False.
    """
    # Clip thresholds into [0, 0.999999].
    # TODO(review): the reason for an upper bound strictly below 1 is
    # undocumented. Presumably it keeps a threshold of exactly 1 from falling
    # outside the last feature bin; confirm.
    action = jnp.clip(action, 0, 0.999999)

    # Evaluate the thresholds on the previous distribution, then move to the
    # next distribution.
    prev_results = threshold_action(prev_state, action)
    current_state = generate_next_state_fn(prev_state, action)
    observation = encode_obs(current_state)

    info = {
        "prev_state": prev_state,
        "prev_action": action,
        "prev_results": prev_results,
        "current_state": current_state,
    }
    # Pass the arguments explicitly rather than via `*info.values()`, so the
    # call does not silently depend on the dict's key insertion order.
    current_reward = reward_fn(prev_state, action, prev_results, current_state)
    info["reward"] = current_reward
    terminated = False
    truncated = False
    return observation, current_reward, terminated, truncated, info


# Jit-compiled `_step`; the callables are static, so jit caches one compilation per function object.
_jit_step = jit(_step, static_argnames=("generate_next_state_fn", "reward_fn"))


def _generate_init_state_base(num_groups, num_feature_bins, rng=None, options=None):
    """
    Placeholder initial-state generator.

    `BaseEnv` uses this when no `generate_init_state` callable is supplied.
    Concrete environments must provide their own generator.

    Raises:
        NotImplementedError: always.
    """
    message = "Implementation provided by a subclass of Env."
    raise NotImplementedError(message)


def _generate_next_state_base(state, action):
    """
    Placeholder state-transition function.

    `BaseEnv` uses this when no `generate_next_state` callable is supplied.
    Concrete environments must provide their own dynamics.

    Raises:
        NotImplementedError: always.
    """
    message = "Implementation provided by a subclass of Env."
    raise NotImplementedError(message)


def _reward_fn(prev_state, action, prev_results, current_state):
    """
    Default reward: group-weighted sum of `tp_rate + tn_rate`.

    Each group's `prev_results.tp_rate + prev_results.tn_rate` is weighted by
    `prev_state.pr_G`, and the weighted values are summed. Only `prev_state`
    and `prev_results` are read. `action` and `current_state` are accepted to
    match the reward-function signature used by `_step`.

    NOTE(review): if `tp_rate` / `tn_rate` are within-group joint rates and
    `pr_G` holds the group proportions, this is the overall classification
    accuracy; confirm against `threshold_action`.
    """
    per_group_rate = prev_results.tp_rate + prev_results.tn_rate
    weighted = prev_state.pr_G * per_group_rate
    return jnp.sum(weighted)
