import copy

import flax.struct as struct
import jax
import jax.numpy as jnp
import jax.scipy as jsp

from nais.gflownet import GFlowNet
from nais.gym.base import EnvironmentConfig, EnvState, LogRewardBase, LogRewardConfig


@struct.dataclass
class DiscreteDiffMetadata:
    """Parameters of the discrete-diffusion lattice, attached to ``EnvState.metadata`` by ``factory``.

    NOTE(review): with ``flax.struct.dataclass``, fields not declared via
    ``struct.field(pytree_node=False)`` are pytree leaves, so inside jit/scan these
    may be traced arrays rather than Python ints despite the ``int`` annotations — confirm.
    """

    # Horizon: a forward trajectory is marked stopped once its timestamp equals T.
    T: int
    # Half-width of the box: moves are masked so each position coordinate stays in [-size, size].
    size: int


def move(env_state: EnvState, actions: jax.Array, direction: jax.Array, is_active: jax.Array):
    """Apply one lattice step to every row of ``env_state.state``.

    ``env_state.state`` is ``(B, d + 1)``: d position coordinates followed by a
    timestamp column. Actions are encoded as:

      * ``a`` in ``[0, d)``   -> coordinate ``a`` changes by ``+direction``
      * ``a`` in ``[d, 2d)``  -> coordinate ``a - d`` changes by ``-direction``
      * ``a == 2d``           -> position unchanged ("stay")

    The timestamp of every active row changes by ``direction`` (the callers in this
    file pass ``+1`` for forward steps and ``-1`` for backward steps). Rows where
    ``is_active`` is False are left untouched.

    Returns:
        The updated ``(B, d + 1)`` array; ``env_state`` itself is not modified.
    """
    positions_and_time = env_state.state  # (B, d + 1)
    _, n_cols = positions_and_time.shape
    d = n_cols - 1

    stays = actions == 2 * d
    # `actions % d` lies in [0, d - 1]; the stay action (2d) lands on coordinate 0,
    # but its increment is zeroed below, so it never moves the walker.
    coord = actions % d
    sign = 1 - 2 * (actions // d)  # +1 for [0, d), -1 for [d, 2d)

    step = jnp.where(stays | ~is_active, 0, direction * sign)
    positions_and_time = positions_and_time.at[env_state.batch_ids, coord].add(step)

    # Advance (or rewind) the clock of active rows only.
    tick = jnp.where(is_active, direction, 0)
    return positions_and_time.at[:, -1].add(tick)


def update_mask(state: jax.Array, size: int, T: int):
    """Compute forward and backward action masks for a batch of lattice states.

    Args:
        state: ``(B, d + 1)`` array; columns ``[:d]`` are the position, column ``d``
            is the timestamp.
        size: positions must stay within ``[-size, size]`` in every coordinate.
        T: not used by this function (accepted for a uniform call signature).

    Returns:
        ``(forward_mask, backward_mask)``, each ``(B, 2 * d + 1)`` with dtype
        ``state.dtype``. Column layout: ``[+1 per coord | -1 per coord | stay]``.

    Forward mask: a move is allowed if it keeps the position inside the box; "stay"
    is always allowed.

    Backward mask: undoing action ``a`` must land on a state at time ``t - 1`` that is
    both inside the box and reachable from the origin, i.e. ``|y|_1 <= t - 1``, which
    on the integer lattice is ``t - |y|_1 > 0``. A backward step runs ``move`` with
    ``direction = -1``, so undoing a "+1" action *decrements* the coordinate (and
    vice versa); hence the decrement checks feed the first d columns.
    """
    n_rows, n_cols = state.shape
    d = n_cols - 1
    pos = state[:, :d]
    t = state[:, d]

    can_inc = pos < size  # (B, d): x_i + 1 <= size
    can_dec = pos > -size  # (B, d): x_i - 1 >= -size

    forward_mask = jnp.hstack(
        [can_inc, can_dec, jnp.ones((n_rows, 1))],
        dtype=state.dtype,
    )

    unit = jnp.eye(d, dtype=state.dtype)[None, ...]  # (1, d, d); row i is e_i
    l1_if_inc = jnp.abs(pos[:, None, :] + unit).sum(axis=2)  # (B, d): |x + e_i|_1
    l1_if_dec = jnp.abs(pos[:, None, :] - unit).sum(axis=2)  # (B, d): |x - e_i|_1
    l1_here = jnp.abs(pos).sum(axis=1)  # (B,): |x|_1

    reach_if_inc = t[:, None] - l1_if_inc > 0
    reach_if_dec = t[:, None] - l1_if_dec > 0
    # "Stay" backwards keeps x, so x itself must be reachable at time t - 1.
    reach_if_stay = t > l1_here

    backward_mask = jnp.hstack(
        [
            reach_if_dec & can_dec,
            reach_if_inc & can_inc,
            reach_if_stay[:, None],
        ],
        dtype=state.dtype,
    )

    return forward_mask, backward_mask


# Environment-specific forward transition for the discrete-diffusion lattice.
def apply_fn(
    gflownet_key: tuple[GFlowNet, jax.Array], _
) -> tuple[tuple[GFlowNet, jax.Array], tuple[jax.Array, jax.Array, EnvState]]:
    """Take one forward (P_F) step for every trajectory in the batch.

    Shaped like a ``(carry, x) -> (carry, y)`` scan body; the second argument is ignored.

    Rows whose ``stopped`` flag is 0 are advanced by the sampled action and their
    timestamp increases by 1. A row becomes stopped once its timestamp equals
    ``metadata.T``. Masks are recomputed from the new states, and the step's log P_F
    is written to column ``gflownet.state.idx`` of ``log_pf`` (0 for inactive rows),
    after which ``idx`` is incremented.

    Note: ``gflownet.state`` is replaced by attribute assignment on ``gflownet``.

    Returns:
        ``((gflownet, next_key), (actions, is_active, env_state))`` where ``next_key``
        is the key returned by ``pf.sample_actions`` and ``env_state`` is the updated state.
    """
    gflownet, key = gflownet_key

    env_state = gflownet.state.env_state
    is_active = env_state.stopped == 0

    out_pf = gflownet.pf.sample_actions(
        env_state,
        key,
    )  # (B, 2 * d + 1) [adding +1, -1 to each dim, or adding 0]

    state = move(
        env_state,
        out_pf.actions,
        jnp.ones((env_state.batch_size,)),
        is_active,
    )

    # Recompute the forward/backward masks so the agent cannot step outside
    # the environment's boundary.
    forward_mask, backward_mask = update_mask(state, env_state.metadata.size, env_state.metadata.T)

    # Only rows that moved this step can transition to stopped (timestamp == T);
    # the others keep their previous flags.
    stopped = jnp.where(
        is_active,
        jnp.isclose(state[:, -1], env_state.metadata.T).astype(env_state.stopped.dtype),
        env_state.stopped,
    )
    is_initial = jnp.where(
        is_active,
        jnp.zeros_like(env_state.is_initial),
        env_state.is_initial,
    )

    env_state = env_state.replace(
        state=state,
        stopped=stopped,
        is_initial=is_initial,
        forward_mask=forward_mask,
        backward_mask=backward_mask,
    )

    # Inactive rows contribute 0 so summing log_pf over time ignores padding steps.
    log_pf = gflownet.state.log_pf.at[:, gflownet.state.idx].set(
        jnp.where(is_active, out_pf.log_pf, 0.0),
    )

    gflownet.state = gflownet.state.replace(
        env_state=env_state,
        log_pf=log_pf,
        idx=gflownet.state.idx + 1,
    )

    return (gflownet, out_pf.key), (out_pf.actions, is_active, env_state)


def backward_fn(gflownet_key: tuple[GFlowNet, jax.Array], _):
    """Take one backward (P_B) step for every trajectory in the batch.

    Shaped like a ``(carry, x) -> (carry, y)`` scan body; the second argument is ignored.

    Rows whose ``is_initial`` flag is 0 are moved by an action sampled from
    ``gflownet.pb`` and their timestamp decreases by 1. Moved rows are no longer
    stopped, and become initial once their timestamp reaches 0. ``idx`` is decremented
    first and the step's log P_B is written to that column of ``log_pb`` (0 for
    inactive rows).

    Note: ``gflownet.state`` is replaced by attribute assignment on ``gflownet``.

    Returns:
        ``((gflownet, next_key), (actions, is_active, env_state))``.
    """
    gflownet, key = gflownet_key

    env = gflownet.state.env_state
    meta = env.metadata
    active = env.is_initial == 0

    sample = gflownet.pb.sample_actions(
        env,
        key,
    )  # (B, 2 * d + 1) [adding +1, -1 to each dim]

    # direction = -1: rewinds the clock and inverts the meaning of each move action.
    new_state = move(env, sample.actions, -jnp.ones(env.batch_size), active)

    # Keep the masks consistent with the new positions so the walker stays in bounds.
    fwd_mask, bwd_mask = update_mask(new_state, meta.size, meta.T)

    back_at_origin = jnp.isclose(new_state[:, -1], 0).astype(env.is_initial.dtype)
    new_env = env.replace(
        state=new_state,
        stopped=jnp.where(active, jnp.zeros_like(env.stopped), env.stopped),
        is_initial=jnp.where(active, back_at_origin, env.is_initial),
        forward_mask=fwd_mask,
        backward_mask=bwd_mask,
    )

    prev_idx = gflownet.state.idx - 1
    log_pb = gflownet.state.log_pb.at[:, prev_idx].set(
        jnp.where(active, sample.log_pb, 0.0),
    )

    gflownet.state = gflownet.state.replace(
        env_state=new_env,
        log_pb=log_pb,
        idx=prev_idx,
    )

    return (gflownet, sample.key), (sample.actions, active, new_env)


def factory(
    T: int,
    size: int,
    d: int,
    config: EnvironmentConfig,
):
    """Create the initial batched ``EnvState``: every walker at the origin at time 0.

    All forward actions start unmasked, and no backward action is allowed from the
    initial state. Trajectories have at most ``T + 1`` entries.
    """
    batch = config.batch_size
    n_actions = 2 * d + 1  # +1 / -1 for each coordinate, plus "stay"
    return EnvState(
        state=jnp.zeros((batch, d + 1)),  # d position coords + timestamp
        forward_mask=jnp.ones((batch, n_actions)),
        backward_mask=jnp.zeros((batch, n_actions)),
        batch_ids=jnp.arange(batch),
        is_initial=jnp.ones((batch,)),
        stopped=jnp.zeros((batch,)),
        batch_size=batch,
        max_trajectory_length=T + 1,
        metadata=DiscreteDiffMetadata(T=T, size=size),
    )


def get_discrete_diff_grid(T: int, size: int, d: int, config: EnvironmentConfig) -> EnvState:
    """Build a batch of terminal states covering every point of the 2-D lattice.

    Each row is a position in ``[-size, size]^2`` with timestamp ``T``. Rows are marked
    stopped and not initial, and their masks are computed with ``update_mask``.
    Intended for plotting over the whole grid.

    Args:
        T: terminal timestamp assigned to every grid state.
        size: half-width of the lattice.
        d: number of position coordinates; must be 2.
        config: environment configuration. It is NOT modified: a shallow copy with
            ``batch_size`` set to the number of grid points is passed to ``factory``.

    Returns:
        An ``EnvState`` with ``(2 * size + 1) ** d`` rows.
    """
    assert d == 2  # We only plot a 2-dimensional grid

    coords = jnp.arange(2 * size + 1, dtype=jnp.float32) - size
    positions = jnp.stack(jnp.meshgrid(*[coords for _ in range(d)]), axis=0)
    positions = positions.reshape(d, -1).T
    states = jnp.hstack(
        [positions, jnp.ones((len(positions), 1)) * T],
    )

    # Work on a copy: assigning batch_size on the caller's config would silently
    # resize every environment later built from that same config object.
    grid_config = copy.copy(config)
    grid_config.batch_size = len(states)
    env_state = factory(T, size, d, grid_config)

    forward_mask, backward_mask = update_mask(states, size, T)

    return env_state.replace(
        state=states,
        forward_mask=forward_mask,
        backward_mask=backward_mask,
        stopped=jnp.ones_like(env_state.stopped),
        is_initial=jnp.zeros_like(env_state.is_initial),
    )


class LogReward(LogRewardBase):
    """Zero-mean Gaussian log-reward over positions rescaled by ``size``.

    Only the first ``d`` state columns (the position) are used; the result is divided
    by ``self.temperature``.
    """

    def __init__(self, d: int, sigma: float, size: int, *, config: LogRewardConfig):
        super().__init__(config=config)
        self.d = d
        self.sigma = sigma
        self.size = size

        # NOTE: `sigma` scales the covariance directly (cov = sigma * I),
        # so it acts as a variance rather than a standard deviation.
        self.mu = jnp.zeros((d,))
        self.V = jnp.eye(d) * self.sigma

    def __call__(self, state: EnvState):
        scaled_positions = state.state[:, : self.d] / self.size
        log_density = jsp.stats.multivariate_normal.logpdf(
            scaled_positions,
            mean=self.mu,
            cov=self.V,
        )
        return log_density / self.temperature


def two_rings_logpdf(
    x: jax.Array,
    r: jax.Array = jnp.array([0.2, 0.9]),
    s: float = 0.05,
    p: float = 0.4,
    eps: float = 1e-6,
):
    """Unnormalized log-density of a mixture of concentric rings.

    The density depends on ``x`` only through ``rho = ||x||_2``:
    ``log sum_j w_j * exp(-0.5 * ((rho - r_j) / s) ** 2)``.
    The last ring ``r[-1]`` gets weight ``p``; the remaining ``1 - p`` is split
    evenly over the other ``k - 1`` rings, so at least two rings are required
    (``k == 1`` divides by zero).

    Args:
        x: points of shape ``(..., d)``; any number of leading batch dims, including none.
        r: ring radii, shape ``(k,)`` with ``k >= 2``.
        s: common width of every ring.
        p: weight of the last ring.
        eps: unused; kept so existing call sites remain valid.

    Returns:
        Array of shape ``x.shape[:-1]``.
    """
    k = len(r)

    # (1 - p) / (k - 1), ..., (1 - p) / (k - 1), p
    w = (1 - p) / (k - 1) * jnp.ones_like(r)
    w = w.at[-1].set(p)
    log_w = jnp.log(w)

    # rho = ||x||, shape (...)
    rho = jnp.linalg.norm(x, axis=-1)

    # (..., 1) - (k,) -> (..., k). Broadcasting against plain `r` and reducing over the
    # last (ring) axis keeps this valid for unbatched and multi-batch inputs.
    log_g = -0.5 * ((rho[..., None] - r) / s) ** 2
    return jax.nn.logsumexp(log_w + log_g, axis=-1)


class LogRewardRings(LogRewardBase):
    """Two-rings log-reward evaluated on positions rescaled by ``size``.

    Only the first ``d`` state columns (the position) are used; the result is divided
    by ``self.temperature``.
    """

    def __init__(self, d: int, size: int, *, config: LogRewardConfig):
        super().__init__(config=config)
        self.d, self.size = d, size

    def __call__(self, state: EnvState):
        normalized = state.state[:, : self.d] / self.size
        return two_rings_logpdf(normalized) / self.temperature


def lp(x: jax.Array, p: int = 2, axis: int = -1):
    """Return ``sum(|x| ** p)`` reduced over ``axis``.

    This is the p-th power of the Lp norm (no ``1/p`` root is taken); for ``p=2``
    it is the squared Euclidean norm.
    """
    magnitudes = jnp.abs(x)
    return (magnitudes**p).sum(axis=axis)


def four_gaussians_logpdf(
    x: jax.Array,
    mu1: jax.Array,
    mu2: jax.Array,
    mu3: jax.Array,
    mu4: jax.Array,
    sigma: float,
):
    """Log-density of an equal-weight mixture of four isotropic Gaussians.

    ``log( (1/4) * sum_j N(x; mu_j, sigma^2 I_d) )``

    Args:
        x: points of shape ``(..., d)``; any number of leading batch dims, including none.
        mu1, mu2, mu3, mu4: component means, each of shape ``(d,)``.
        sigma: standard deviation shared by every component.

    Returns:
        Array of shape ``x.shape[:-1]``.
    """
    mus = jnp.stack([mu1, mu2, mu3, mu4], axis=0)  # (4, d)
    d = x.shape[-1]

    # Squared Euclidean distance to each mean: (..., 4). The Gaussian exponent uses it
    # exactly once; squaring an already-squared distance gives a quartic, non-Gaussian bump.
    sq_dist = jnp.sum((x[..., None, :] - mus) ** 2, axis=-1)
    log_kernels = -0.5 * sq_dist / sigma**2

    # Mixture weight 1/4 and the d-dimensional Gaussian normalizer (2*pi*sigma^2)^(d/2).
    log_norm = jnp.log(4.0) + 0.5 * d * jnp.log(2 * jnp.pi * sigma**2)

    # logsumexp rather than log(sum(exp(.))) so points far from every mean give a
    # finite (very negative) value instead of log(0) = -inf.
    return jax.nn.logsumexp(log_kernels, axis=-1) - log_norm


class LogRewardFourGaussians(LogRewardBase):
    """Log-reward from a mixture of four Gaussians near the corners of [-1, 1]^2.

    Positions (the first ``d`` state columns) are divided by ``size`` before
    evaluation, and the result is divided by ``self.temperature``. For ``d > 2`` every
    mean is zero in the extra coordinates.
    """

    def __init__(
        self,
        d: int,
        size: int,
        *,
        config: LogRewardConfig,
        pad: float = 0.2,
        sigma: float = 0.1,
    ):
        """
        Args:
            d: number of position coordinates; must be >= 2.
            size: lattice half-width used to rescale positions.
            config: forwarded to ``LogRewardBase``.
            pad: per-coordinate offset of each mean from its corner of [-1, 1]^2.
            sigma: width passed to ``four_gaussians_logpdf``.

        Raises:
            ValueError: if ``d < 2``. The means are 2-D, so fewer coordinates would
                silently broadcast instead of failing.
        """
        super().__init__(config=config)
        if d < 2:
            raise ValueError(f"LogRewardFourGaussians requires d >= 2, got d={d}")
        self.d = d
        self.size = size
        self.pad = pad
        self.sigma = sigma

        # We include the mu's at the corners of the 2-dimensional hypercube
        self.mus = jnp.array(
            [
                [-1 + self.pad, -1 + self.pad],
                [-1 + self.pad, 1 - self.pad],
                [1 - self.pad, -1 + self.pad],
                [1 - self.pad, 1 - self.pad],
            ]
        )
        if d > 2:
            # Extend every mean with zeros in the remaining d - 2 coordinates.
            self.mus = jnp.concatenate([self.mus, jnp.zeros((4, d - 2))], axis=-1)

    def __call__(self, state: EnvState):
        positions = state.state[:, : self.d] / self.size
        log_r = four_gaussians_logpdf(
            positions,
            self.mus[0],
            self.mus[1],
            self.mus[2],
            self.mus[3],
            self.sigma,
        )
        return log_r / self.temperature


class LogRewardGraded(LogRewardBase):
    """Periodic log-reward ``cos(omega * x[axis]) / temperature`` along one position coordinate.

    Like the other log-rewards in this module, the value is divided by ``self.temperature``.
    """

    def __init__(self, d: int, *, omega: float = 1.0, axis: int = 0, config: LogRewardConfig):
        """
        Args:
            d: number of position coordinates.
            omega: angular frequency of the cosine.
            axis: index of the position coordinate to use; negative values count from
                the last position coordinate, never from the timestamp column.
            config: forwarded to ``LogRewardBase``.

        Raises:
            ValueError: if ``axis`` does not index one of the ``d`` position coordinates.
        """
        super().__init__(config=config)
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if not -d <= axis < d:
            raise ValueError(f"axis must index one of the {d} position coordinates, got {axis}")
        self.d = d
        self.omega = omega
        self.axis = axis

    def __call__(self, state: EnvState):
        # Drop the trailing timestamp column first, so `axis` (including negative
        # values) always selects a position coordinate.
        x = state.state[:, : self.d][:, self.axis]
        return jnp.cos(x * self.omega) / self.temperature
