import jax
import jax.numpy as jnp
import jax.scipy as jsp

from nais.gflownet import GFlowNet
from nais.gym.base import EnvironmentConfig, EnvState, LogRewardBase, LogRewardConfig

# @struct.dataclass
# class HypergridState(struct.PyTreeNode):
#     state: jax.Array
#     forward_mask: jax.Array
#     backward_mask: jax.Array
#     stopped: jax.Array
#     is_initial: jax.Array


def move(
    env_state: EnvState,
    actions: jax.Array,
    active_mask: jax.Array,
    size: int,
    shift: int,
) -> EnvState:
    """Advance a batch of hypergrid states by one forward or backward step.

    An action is either a coordinate index in ``[0, dims)``, in which case
    ``shift`` is added to that coordinate, or the value ``dims``, which is the
    stop action. With ``shift > 0`` a stop action sets the ``stopped`` flag;
    with ``shift < 0`` it clears it (the state is moving backwards).

    Rows where ``active_mask`` is false keep their coordinates, forward mask,
    ``stopped`` and ``is_initial`` entries. The last column of the backward
    mask is rewritten from the ``stopped`` flag for every row.

    Args:
        env_state: Batched environment state. This function reads ``state``
            (indexed as ``[batch, coordinate]``), ``forward_mask`` and
            ``backward_mask`` (one column per coordinate plus a last column
            for the stop action), ``batch_ids``, ``stopped`` and
            ``is_initial``.
        actions: One action index per batch row.
        active_mask: Boolean array, true for the rows that should be updated.
        size: Grid size; coordinates below ``size - abs(shift)`` may still be
            incremented.
        shift: Amount added to the selected coordinate; its sign selects the
            forward (``> 0``) or backward (``< 0``) direction.

    Returns:
        A copy of ``env_state`` (via ``replace``) with updated ``state``,
        ``forward_mask``, ``backward_mask``, ``stopped`` and ``is_initial``.
    """
    dims = env_state.state.shape[-1]
    batch_ids = env_state.batch_ids

    is_stop_action = actions == dims

    # When shift < 0, the state is moving backwards (i.e., the stop action
    # removes the stopped label from the state), so every active row ends up
    # with stopped == 0 in that direction.
    stopped = jnp.where(
        active_mask,
        (is_stop_action & (shift > 0)).astype(env_state.stopped.dtype),
        env_state.stopped,
    )

    has_stopped = stopped > 0
    is_active_and_has_stopped = active_mask & has_stopped
    is_active_and_has_not_stopped = active_mask & ~has_stopped

    # Only genuine coordinate moves touch the state. A backward stop action is
    # active and not stopped, but its index (== dims) is one past the last
    # coordinate, so it is routed to column 0 with an unchanged value instead
    # of relying on out-of-bounds indexing being clamped/dropped.
    moves_coordinate = is_active_and_has_not_stopped & ~is_stop_action
    safe_actions = jnp.where(moves_coordinate, actions, 0)

    # We first update the state for the non-stop actions.
    current_values = env_state.state[batch_ids, safe_actions]
    updates = jnp.where(moves_coordinate, current_values + shift, current_values)
    state = env_state.state.at[batch_ids, safe_actions].set(updates)

    # Then we refresh the coordinate columns of both masks for every active,
    # non-stopped row, masking the actions that would leave the grid.
    refresh_masks = jnp.expand_dims(is_active_and_has_not_stopped, axis=1)
    updates = jnp.where(
        refresh_masks,
        (state < size - jnp.abs(shift)),
        env_state.forward_mask[:, :-1],
    )
    forward_mask = env_state.forward_mask.at[:, :-1].set(updates)

    updates = jnp.where(
        refresh_masks,
        (state > jnp.abs(shift) - 1),
        env_state.backward_mask[:, :-1],
    )
    backward_mask = env_state.backward_mask.at[:, :-1].set(updates)

    # When the state stopped, all actions are masked, except for the stop action.
    mask_with_only_stop_action = jnp.zeros_like(forward_mask).at[:, -1].set(1.0)
    just_stopped = jnp.expand_dims(is_active_and_has_stopped, axis=1)
    forward_mask = jnp.where(just_stopped, mask_with_only_stop_action, forward_mask)
    backward_mask = jnp.where(just_stopped, mask_with_only_stop_action, backward_mask)

    # If state is not stopped, backward_mask[:, -1] = 0
    backward_mask = backward_mask.at[batch_ids, -1].set(stopped)

    # An active row is initial when it sits at the origin and is not stopped.
    is_initial = jnp.where(
        active_mask,
        ((state == 0).all(axis=1) & ~has_stopped).astype(env_state.is_initial.dtype),
        env_state.is_initial,
    )

    return env_state.replace(
        state=state,
        forward_mask=forward_mask,
        backward_mask=backward_mask,
        stopped=stopped,
        is_initial=is_initial,
    )


def apply_fn(gflownet_key: tuple[GFlowNet, jax.Array], _):
    """Sample one forward action per trajectory and record it on the GFlowNet.

    Takes a ``(gflownet, key)`` pair plus an ignored second argument and
    returns ``(new_pair, per_step_outputs)`` -- presumably so it can be used as
    a ``jax.lax.scan`` body; confirm against the caller.

    Rows whose ``stopped`` flag is already non-zero before the step are
    inactive: their environment state is left to ``move`` and their recorded
    ``log_pf`` / ``log_F`` entries for this step are set to 0.

    Returns:
        ``((gflownet, next_key), (actions, is_active, env_state))`` where
        ``env_state`` is the state after the step and ``gflownet.state`` has
        been replaced with the updated buffers and ``idx + 1``.
    """
    gflownet, key = gflownet_key
    previous = gflownet.state.env_state
    # Activity is decided before the step is taken.
    is_active = previous.stopped == 0.0

    out_pf = gflownet.pf.sample_actions(previous, key=key)
    env_state = move(
        env_state=previous,
        actions=out_pf.actions,
        active_mask=is_active,
        size=gflownet.log_reward.size,  # = grid size
        shift=1,
    )

    # The flow is evaluated on the state reached after the step.
    flow = gflownet.flow_function(env_state).squeeze(axis=1)

    rollout = gflownet.state
    step = rollout.idx
    masked_log_pf = jnp.where(is_active, out_pf.log_pf, 0.0)
    masked_log_F = jnp.where(is_active, flow, 0.0)

    gflownet.state = rollout.replace(
        env_state=env_state,
        log_pf=rollout.log_pf.at[:, step].set(masked_log_pf),
        log_F=rollout.log_F.at[:, step].set(masked_log_F),
        idx=step + 1,
    )

    carry = (gflownet, out_pf.key)
    outputs = (out_pf.actions, is_active, env_state)
    return carry, outputs


def backward_fn(gflownet_key: tuple[GFlowNet, jax.Array], _):
    """Sample one backward action per trajectory and record its log-probability.

    Takes a ``(gflownet, key)`` pair plus an ignored second argument and
    returns ``(new_pair, per_step_outputs)`` -- presumably so it can be used as
    a ``jax.lax.scan`` body; confirm against the caller.

    Rows whose ``is_initial`` flag is set before the step are inactive; their
    ``log_pb`` entry for this step is recorded as 0.

    Returns:
        ``((gflownet, next_key), (actions, is_active, env_state))`` where
        ``env_state`` is the state after the backward step and
        ``gflownet.state`` has been replaced with the updated ``log_pb`` buffer
        and ``idx - 1``.
    """
    gflownet, key = gflownet_key
    env_state = gflownet.state.env_state
    # Trajectories that are already back at the initial state take no step.
    is_active = ~env_state.is_initial.astype(bool)

    out_pb = gflownet.pb.sample_actions(env_state, key=key)

    env_state = move(
        env_state=env_state,
        actions=out_pb.actions,
        active_mask=is_active,
        size=gflownet.log_reward.size,
        # Integer step, matching `move`'s `shift: int` signature and the
        # forward call in `apply_fn`; only its sign and magnitude are used.
        shift=-1,
    )

    # The backward pass writes into the slot preceding the current index.
    idx = gflownet.state.idx - 1

    log_pb = gflownet.state.log_pb.at[:, idx].set(jnp.where(is_active, out_pb.log_pb, 0.0))

    gflownet.state = gflownet.state.replace(
        env_state=env_state,
        log_pb=log_pb,
        idx=idx,
    )

    return (gflownet, out_pb.key), (out_pb.actions, is_active, env_state)


def factory(size: int, dims: int, config: EnvironmentConfig):
    """Build the initial batched ``EnvState`` for a ``dims``-dimensional grid.

    All ``config.batch_size`` rows start at the origin with every forward
    action allowed, no backward action allowed, ``stopped`` cleared and
    ``is_initial`` set.

    Args:
        size: Grid size; only used for ``max_trajectory_length``.
        dims: Number of grid coordinates; the masks get one extra column.
        config: Only ``config.batch_size`` is read.
    """
    batch_size = config.batch_size
    mask_shape = (batch_size, dims + 1)
    per_row = (batch_size,)

    return EnvState(
        state=jnp.zeros((batch_size, dims)),
        forward_mask=jnp.ones(mask_shape),
        backward_mask=jnp.zeros(mask_shape),
        batch_ids=jnp.arange(batch_size),
        stopped=jnp.zeros(per_row),
        is_initial=jnp.ones(per_row),
        max_trajectory_length=dims * size + 1,
        batch_size=batch_size,
    )


def get_grid(size: int, dims: int, config: EnvironmentConfig) -> EnvState:
    """Return an ``EnvState`` holding every point of the grid, marked as stopped.

    The batch contains one row per grid point (``size ** dims`` rows). Each
    row has ``stopped`` set, ``is_initial`` cleared, and both masks allowing
    only the last (stop) column.

    NOTE(review): this overwrites ``config.batch_size`` on the caller's object
    with the number of grid points and does not restore it -- confirm that
    callers expect this side effect.
    """
    axis = jnp.arange(size, dtype=jnp.float32)
    coords = jnp.meshgrid(*[axis for _ in range(dims)])
    # (dims, size, ..., size) -> (num_points, dims)
    points = jnp.stack(coords, axis=0).reshape(dims, -1).T

    # `factory` sizes every buffer from `config.batch_size`.
    config.batch_size = points.shape[0]
    template = factory(size, dims, config)

    stop_only = jnp.zeros_like(template.forward_mask).at[:, -1].set(1.0)

    return template.replace(
        state=points,
        forward_mask=stop_only,
        backward_mask=stop_only,
        stopped=jnp.ones_like(template.stopped),
        is_initial=jnp.zeros_like(template.is_initial),
    )


class LogRewardUniform(LogRewardBase):
    """Log-reward that gives every state in the batch the same value."""

    def __call__(self, state: EnvState) -> jax.Array:
        """Return a vector of ones with ``state.batch_size`` entries."""
        per_row = (state.batch_size,)
        return jnp.ones(per_row)


class LogReward(LogRewardBase):
    """Piecewise-constant hypergrid log-reward.

    With ``a = |2 * x / (size - 1) - 1|`` computed per coordinate, the reward is
    ``ro + 0.5 * [all a > 0.5] + 2 * [all 0.6 < a < 0.8]`` and the returned
    value is its logarithm divided by ``self.temperature``.

    ``self.temperature`` is not set in this class (presumably provided by
    ``LogRewardBase`` -- confirm).
    """

    def __init__(self, config: LogRewardConfig, *, size: int, ro: float = 1e-3):
        """Store the grid ``size`` and the additive base reward ``ro``."""
        super().__init__(config)
        self.size = size
        self.ro = ro

    def __call__(self, state: EnvState) -> jax.Array:
        """Return the tempered log-reward of each row of ``state.state``."""
        # Absolute distance from the grid centre, rescaled per coordinate.
        ax = jnp.abs(state.state / (self.size - 1) * 2 - 1)

        outer = (ax > 0.5).prod(-1) * 0.5
        ring = ((ax < 0.8) & (ax > 0.6)).prod(-1) * 2

        return jnp.log(outer + ring + self.ro) / self.temperature

    @property
    def mode_th(self):
        """Threshold ``log(0.5) / temperature``."""
        return jnp.log(0.5) / self.temperature

class LogRewardGaussian(LogRewardBase):
    """Log-density of an equally weighted mixture of nine 2-D Gaussians.

    The centres form a 3x3 lattice over ``{-0.9, 0, 0.9}`` in each coordinate.
    The mean vector and the centres are hard-coded with two components, so
    this reward applies to two-dimensional grids.

    NOTE(review): ``sigma`` is forwarded as the ``cov`` argument of
    ``multivariate_normal.logpdf``, so it appears to act as a variance rather
    than a standard deviation -- confirm.
    """

    def __init__(self, config: LogRewardConfig, *, size: int, sigma: float = 1e-2):
        """Store the grid ``size``, the spread ``sigma`` and the mixture centres."""
        super().__init__(config)
        self.size = size

        # Offsets are taken relative to each centre, hence a zero mean.
        self.mean = jnp.zeros((2,), dtype=jnp.float32)
        self.sigma = sigma

        pad = 0.1
        anchors = (-1 + pad, 0, 1 - pad)  # left, mid, right

        # Row-major 3x3 lattice: the first coordinate varies slowest.
        self.mu = jnp.array([[first, second] for first in anchors for second in anchors])

        self.normal = jsp.stats.multivariate_normal

    def __call__(self, state: EnvState) -> jax.Array:
        """Return the mixture log-density at each row of ``state.state``."""
        half_extent = (self.size - 1) / 2
        centred = state.state - half_extent
        scaled = (2 * centred) / (self.size - 1)  # (B, D), [0, size - 1] -> [-1, 1]

        offsets = scaled[:, None, :] - self.mu[None, ...]  # (B, 9, D)
        component_logp = self.normal.logpdf(offsets, mean=self.mean, cov=self.sigma)  # (B, 9)

        # Equal mixture weights: log-mean-exp over the components.
        return jax.nn.logsumexp(component_logp, axis=-1) - jnp.log(len(self.mu))
