import jax
import jax.numpy as jnp
import jax.scipy as jsp

from flax import struct
from typing import Self
from functools import partial

from nais.gflownet import GFlowNet
from nais.gym.base import (
    Environment,
    EnvironmentConfig,
    LogRewardBase,
    LogRewardConfig,
    EnvState,
)


# @struct.dataclass
# class LineState(struct.PyTreeNode):
#     state: jax.Array
#     forward_mask: jax.Array
#     backward_mask: jax.Array
#     stopped: jax.Array
#     is_initial: jax.Array


def move(
    env_state: EnvState,
    actions: jax.Array,
    active_mask: jax.Array,
    batch_ids: jax.Array,
    length: int,
    max_step_size: int,
    stop_action_index: int,
    direction: int,  # Either 1 (forward) or -1 (backward)
) -> EnvState:
    """Apply one batched step along the line and refresh masks and flags.

    A non-stop action ``a`` moves the agent by ``a + 1`` positions in the
    given ``direction``; the action equal to ``stop_action_index`` moves by 0.
    Rows where ``active_mask`` is False are returned unchanged.

    Args:
        env_state: Current environment state. ``state`` is indexed as a 1-D
            array and the masks as 2-D arrays whose last column is the stop
            action.
        actions: One action index per batch row.
        active_mask: Boolean per-row mask selecting the rows to update.
        batch_ids: Not used by this function; kept for the call signature.
        length: Forward moves are allowed only while the target position is
            strictly below this value.
        max_step_size: Number of distinct (non-stop) step sizes.
        stop_action_index: Index of the stop action.
        direction: 1 for a forward step, -1 for a backward step.

    Returns:
        A copy of ``env_state`` with ``state``, ``forward_mask``,
        ``backward_mask``, ``is_initial`` and ``stopped`` replaced.
    """
    going_forward = direction == 1
    going_backward = direction == -1
    chose_stop = actions == stop_action_index

    # Forward: an active row is stopped iff it just picked the stop action.
    # Backward: any active row ends up un-stopped.
    stopped = jnp.where(
        active_mask,
        (chose_stop & going_forward).astype(env_state.stopped.dtype),
        env_state.stopped,
    )
    stopped = jnp.where(active_mask & going_backward, 0, stopped)

    has_stopped = stopped > 0
    halted_rows = active_mask & has_stopped
    moving_rows = active_mask & ~has_stopped

    # Action index a stands for a step of a + 1; the stop action does not move.
    step_sizes = jnp.where(chose_stop, 0, actions + 1)

    # Only active, non-stopped rows change position.
    old_position = env_state.state
    new_position = jnp.where(
        moving_rows,
        old_position + step_sizes * direction,
        old_position,
    )
    state = env_state.state.at[:].set(new_position)

    # Candidate step sizes 1..max_step_size, broadcast against the batch.
    offsets = jnp.linspace(1, max_step_size, num=max_step_size, endpoint=True)[None, :]
    ahead = state[:, None] + offsets  # (batch_size, max_step_size)
    behind = state[:, None] - offsets  # (batch_size, max_step_size)

    # Mask that permits nothing but the stop action (last column).
    stop_only = jnp.zeros_like(env_state.forward_mask).at[:, -1].set(1.0)

    moving_col = moving_rows[:, None]
    halted_col = halted_rows[:, None]

    # Forward mask: a move is legal while its target stays below `length`.
    # The stop column is carried over for moving rows; halted rows may only stop.
    forward_moves = jnp.where(
        moving_col,
        (ahead < length).astype(env_state.forward_mask.dtype),
        env_state.forward_mask[:, :-1],
    )
    forward_mask = env_state.forward_mask.at[:, :-1].set(forward_moves)
    forward_mask = jnp.where(halted_col, stop_only, forward_mask)

    # Backward mask: a move is legal while its target stays non-negative.
    backward_moves = jnp.where(
        moving_col,
        (behind >= 0).astype(env_state.backward_mask.dtype),
        env_state.backward_mask[:, :-1],
    )
    backward_mask = env_state.backward_mask.at[:, :-1].set(backward_moves)
    # A row that just stopped (forward) can only be undone via the stop action.
    backward_mask = jnp.where(halted_col & going_forward, stop_only, backward_mask)
    # After any backward step the row is no longer stopped, so the stop
    # action is not a valid backward action for it.
    backward_mask = jnp.where(
        active_mask[:, None] & going_backward,
        backward_mask.at[:, -1].set(0.0),
        backward_mask,
    )

    # Backward steps re-enter the initial state exactly when position 0 is
    # reached; any forward step of an active row leaves it.
    is_initial = jnp.where(
        active_mask & going_backward,
        (state == 0).astype(env_state.is_initial.dtype),
        env_state.is_initial,
    )
    is_initial = jnp.where(active_mask & going_forward, 0, is_initial)

    return env_state.replace(
        state=state,
        forward_mask=forward_mask,
        backward_mask=backward_mask,
        is_initial=is_initial,
        stopped=stopped,
    )


def apply_fn(gflownet_key: tuple[GFlowNet, jax.Array], _):
    """Take one forward step for every trajectory that has not stopped.

    The ``(carry, _) -> (carry, outputs)`` shape suggests this is meant as a
    scan-style step function; confirm against the caller.

    Args:
        gflownet_key: Pair ``(gflownet, key)``. ``gflownet.state`` is
            reassigned on the passed object.
        _: Ignored.

    Returns:
        ``((gflownet, new_key), (actions, is_active, env_state))`` where
        ``new_key`` is the key returned by the forward sampler, ``is_active``
        is the mask computed before the move and ``env_state`` is the state
        after it.
    """
    gflownet, key = gflownet_key
    env_before = gflownet.state.env_state

    # Rows that already stopped are frozen for this step.
    is_active = env_before.stopped == 0.0

    # forward_mask: (B, max_step_size + 1), +1 for the stop action
    n_step_sizes = env_before.forward_mask.shape[-1] - 1
    line_length = env_before.max_trajectory_length - 1

    out_pf = gflownet.pf.sample_actions(env_before, key)

    env_after = move(
        env_before,
        out_pf.actions,
        is_active,
        env_before.batch_ids,
        line_length,
        n_step_sizes,
        stop_action_index=n_step_sizes,
        direction=1,
    )

    # The backward policy is queried on the post-move state with the very
    # same actions that were just taken forward.
    out_pb = gflownet.pb.sample_actions(env_after, actions=out_pf.actions)

    tracker = gflownet.state
    step = tracker.idx

    # Inactive rows contribute a log-probability of 0 at this step.
    step_log_pf = jnp.where(is_active, out_pf.log_pf, 0.0)
    step_log_pb = jnp.where(is_active, out_pb.log_pb, 0.0)
    log_pf = tracker.log_pf.at[:, step].set(step_log_pf)
    log_pb = tracker.log_pb.at[:, step].set(step_log_pb)

    gflownet.state = tracker.replace(
        env_state=env_after, log_pf=log_pf, log_pb=log_pb, idx=step + 1
    )

    carry = (gflownet, out_pf.key)
    outputs = (out_pf.actions, is_active, env_after)
    return carry, outputs


def backward_fn(gflownet_key: tuple[GFlowNet, jax.Array], _):
    """Take one backward step for every trajectory not yet at the initial state.

    Mirrors ``apply_fn`` in signature: ``(carry, _) -> (carry, outputs)``.

    Args:
        gflownet_key: Pair ``(gflownet, key)``. ``gflownet.state`` is
            reassigned on the passed object.
        _: Ignored.

    Returns:
        ``((gflownet, new_key), (actions, is_active, env_state))`` where
        ``new_key`` is the key returned by the backward sampler, ``is_active``
        is the mask computed before the move and ``env_state`` is the state
        after it.
    """
    gflownet, key = gflownet_key
    env_before = gflownet.state.env_state

    # Rows already back at the initial state are frozen for this step.
    is_active = ~env_before.is_initial.astype(bool)

    # forward_mask: (B, max_step_size + 1), +1 for the stop action
    n_step_sizes = env_before.forward_mask.shape[-1] - 1
    line_length = env_before.max_trajectory_length - 1

    out_pb = gflownet.pb.sample_actions(env_before, key=key)

    env_after = move(
        env_before,
        out_pb.actions,
        is_active,
        env_before.batch_ids,
        line_length,
        n_step_sizes,
        stop_action_index=n_step_sizes,
        direction=-1,
    )

    tracker = gflownet.state
    # Walking backwards: the step counter is decremented first and the
    # log-probability is written into that (previous) column.
    idx = tracker.idx - 1

    step_log_pb = jnp.where(is_active, out_pb.log_pb, 0.0)
    log_pb = tracker.log_pb.at[:, idx].set(step_log_pb)

    gflownet.state = tracker.replace(
        env_state=env_after,
        log_pb=log_pb,
        idx=idx,
    )

    carry = (gflownet, out_pb.key)
    outputs = (out_pb.actions, is_active, env_after)
    return carry, outputs


def factory(length: int, max_step_size: int, config: EnvironmentConfig):
    """Build the initial batched environment state for a line of ``length``.

    Every row starts at position 0, flagged as initial and not stopped, with
    all forward actions enabled and all backward actions disabled.

    Args:
        length: Line length; stored as ``max_trajectory_length = length + 1``.
        max_step_size: Number of distinct step sizes; the masks get one extra
            column for the stop action.
        config: Only ``config.batch_size`` is read here.

    Returns:
        A freshly constructed ``EnvState``.
    """
    batch_size = config.batch_size
    num_actions = max_step_size + 1  # step sizes plus the stop action

    return EnvState(
        state=jnp.zeros((batch_size,)),
        forward_mask=jnp.ones((batch_size, num_actions)),
        backward_mask=jnp.zeros((batch_size, num_actions)),
        batch_ids=jnp.arange(batch_size),
        stopped=jnp.zeros((batch_size,)),
        is_initial=jnp.ones((batch_size,)),
        max_trajectory_length=length + 1,
        batch_size=batch_size,
    )


# Each action corresponds to a forward step along the line (plus a dedicated stop action).
# A state is the agent's current position on the line, stored as one scalar per
# batch element (see `factory`, which starts every position at 0).
# This easily visualized environment allows for assessing the effect of trajectory length on learning.
# class Lines(Environment):
#     state: jax.Array

#     def __init__(self, length: int, max_step_size: int, config: EnvironmentConfig):
#         super().__init__(config)
#         self.length = length
#         self.max_step_size = max_step_size

#         # The number of actions in a state s is
#         # max_step_size + 1 (for the stop action). This holds for both forward and backward processes.
#         self.num_actions = max_step_size + 1

#         self.stop_action_index = max_step_size

#         self._state = LineState(
#             state=jnp.zeros((self.batch_size,)),
#             forward_mask=jnp.ones((self.batch_size, self.num_actions)),
#             backward_mask=jnp.zeros((self.batch_size, self.num_actions)),
#             stopped=self.stopped,
#             is_initial=self.is_initial,
#         )

#         self.max_trajectory_length = length + 1  # length + stop action

#         self._sync_views()

#     def _sync_views(self):
#         self.state = self._state.state
#         self.forward_mask = self._state.forward_mask
#         self.backward_mask = self._state.backward_mask
#         self.stopped = self._state.stopped
#         self.is_initial = self._state.is_initial

#     def merge(self, batch_state: Self):
#         super().merge(batch_state)
#         self._state = LineState(
#             state=self.state,
#             forward_mask=self.forward_mask,
#             backward_mask=self.backward_mask,
#             stopped=self.stopped,
#             is_initial=self.is_initial,
#         )

#     def apply(self, actions: jax.Array, active_mask: jax.Array | None = None) -> jax.Array:
#         active_mask = self._active_mask(active_mask)
#         self._state = _move_impl(
#             self._state,
#             actions,
#             active_mask,
#             self.batch_ids,
#             self.length,
#             self.max_step_size,
#             self.stop_action_index,
#             1,
#         )
#         self._sync_views()
#         return actions

#     def backward(self, actions: jax.Array, active_mask: jax.Array | None = None) -> jax.Array:
#         active_mask = self._active_mask(active_mask)
#         self._state = _move_impl(
#             self._state,
#             actions,
#             active_mask,
#             self.batch_ids,
#             self.length,
#             self.max_step_size,
#             self.stop_action_index,
#             -1,
#         )
#         self._sync_views()
#         return actions

#     def fwd_to_bcw_actions(self, actions: jax.Array):
#         return actions

#     @property
#     def log_space_size(self):
#         return jnp.log(self.length)

#     def get_state(self):
#         # We include an extra dimension for consistency with other environments
#         return self.state[:, None]


def get_lines(length: int, max_step_size: int, config: EnvironmentConfig) -> EnvState:
    """Return a batch holding every position ``0..length-1`` as a stopped state.

    NOTE: this overwrites ``config.batch_size`` on the caller's config object
    with ``length`` so that ``factory`` allocates one row per position.

    Args:
        length: Number of positions to enumerate.
        max_step_size: Forwarded to ``factory`` to size the action masks.
        config: Environment configuration; its ``batch_size`` is mutated.

    Returns:
        An ``EnvState`` with one row per position, each marked as stopped and
        non-initial, where both masks allow only the stop action.
    """
    positions = jnp.arange(length, dtype=jnp.float32)

    # One batch row per enumerated position (side effect on `config`).
    config.batch_size = len(positions)
    template = factory(length, max_step_size=max_step_size, config=config)

    # Only the last column (the stop action) is enabled.
    stop_only = jnp.zeros_like(template.forward_mask).at[:, -1].set(1.0)

    return template.replace(
        state=positions,
        forward_mask=stop_only,
        backward_mask=stop_only,
        stopped=jnp.ones_like(template.stopped),
        is_initial=jnp.zeros_like(template.is_initial),
    )


class LogRewardUniform(LogRewardBase):
    """Constant log-reward: every state receives the same value (1)."""

    def __call__(self, state: EnvState) -> jax.Array:
        # Same shape and dtype as the positions; no temperature is applied.
        positions = state.state
        return jnp.ones_like(positions)


class LogReward(LogRewardBase):
    """Log-reward that decays linearly with distance to the nearest spot."""

    def __init__(self, length: int, config: LogRewardConfig, *, ro: float = 0.5):
        """Place up to four high-reward spots along the line.

        Args:
            length: Line length; the spots span from 0 up to about
                ``length - 2``.
            config: Forwarded to the base-class constructor.
            ro: Decay rate applied to the distance to the nearest spot.
        """
        super().__init__(config)
        self.ro = ro

        # Four points spaced geometrically (base 2) between 1 and length - 1,
        # truncated to int32 and shifted down by one.
        geometric_points = jnp.logspace(
            0,
            jnp.log2(length - 1),
            num=4,
            base=2,
            dtype=jnp.int32,
            endpoint=True,
        )
        # Integer truncation can produce duplicates on short lines; drop them.
        distinct_spots = jnp.unique(geometric_points - 1)
        # Leading axis of size 1 so the spots broadcast against the batch.
        self.reward_spots = distinct_spots[None, :]

    def __call__(self, state: EnvState) -> jax.Array:
        """Return ``-ro * min_k |state - spot_k| / temperature`` per row."""
        # `self.temperature` is not set in this class; presumably provided by
        # LogRewardBase.
        gaps = jnp.abs(state.state[:, None] - self.reward_spots)  # (batch_size, n_spots)
        nearest_gap = gaps.min(axis=1)  # (batch_size,)
        return -self.ro * nearest_gap / self.temperature


class LogRewardNormal(LogRewardBase):
    """Gaussian log-reward centred near the far end of the line."""

    def __init__(self, length: int, config: LogRewardConfig):
        """Set the mean to ``0.9 * length`` and the std to ``length / 8``."""
        super().__init__(config)
        self.sigma = length / 8
        self.mu = length * 0.9

    def __call__(self, state: EnvState) -> jax.Array:
        """Return the normal log-density of each position over the temperature."""
        # `self.temperature` is not set in this class; presumably provided by
        # LogRewardBase.
        log_density = jsp.stats.norm.logpdf(
            state.state,
            loc=self.mu,
            scale=self.sigma,
        )
        return log_density / self.temperature


class LogRewardCentered(LogRewardBase):
    """Mixture of four Laplace-like bumps, two near each end of the line."""

    def __init__(self, length: int, config: LogRewardConfig):
        """Set up the bump centres, the shared scale and the mixture weights.

        Args:
            length: Line length; must be greater than 2.
            config: Forwarded to the base-class constructor.
        """
        super().__init__(config)
        self.length = length
        assert self.length > 2

        # Centres at 3, 4, length - 4 and length - 3, shaped (1, 4) so they
        # broadcast against the batch.
        centres = jnp.array([3, 4, self.length - 4, self.length - 3])
        self.center = centres[None, :]

        self.sigma = self.length // 2

        # Start from equal shares, override the first two, then renormalise
        # so the weights sum to 1.
        n_centres = self.center.shape[1]
        mixture = jnp.ones_like(self.center) / n_centres
        mixture = mixture.at[:, 0].set(0.15)
        mixture = mixture.at[:, 1].set(0.2)
        self.weights = mixture / mixture.sum()

    def __call__(self, state: EnvState) -> jax.Array:
        """Return the log of the weighted sum of ``exp(-|x - c_k| / sigma)``."""
        # `self.temperature` is not set in this class; presumably provided by
        # LogRewardBase.
        scaled_distance = jnp.abs(state.state[:, None] - self.center) / self.sigma
        log_mixture = jax.nn.logsumexp(
            -scaled_distance + jnp.log(self.weights),
            axis=1,
        )
        return log_mixture / self.temperature

class LogRewardSteep(LogRewardBase):
    """Two-level reward: ``alpha`` on most positions, ``beta`` on the rest."""

    def __init__(
        self,
        length: int,
        config: LogRewardConfig,
        *,
        alpha: float = 1e-2,
        beta: float = 1.0,
    ):
        """Store the two reward levels and the line length.

        Args:
            length: Line length, used for the upper bound of the ``alpha`` band.
            config: Forwarded to the base-class constructor.
            alpha: Reward at positions 0, 1 and ``3 <= position < length - 1``.
            beta: Reward at every other position.
        """
        super().__init__(config)
        self.alpha = alpha
        self.beta = beta
        self.length = length

    def __call__(self, x: EnvState):
        """Return ``log(reward) / temperature`` for each position in ``x.state``."""
        # `self.temperature` is not set in this class; presumably provided by
        # LogRewardBase.
        position = x.state
        near_start = (position == 0) | (position == 1)
        interior = (position >= 3) & (position < self.length - 1)
        rewards = jnp.where(near_start | interior, self.alpha, self.beta)
        return jnp.log(rewards) / self.temperature
