import flax.nnx as nnx
import flax.struct as struct
import jax
import jax.numpy as jnp

from divgfn.utils import HasState


def forward(y: jax.Array, layer: nnx.Linear):
    """Step function for ``jax.lax.scan`` over the stacked hidden layers.

    ``y`` is the scan carry and ``layer`` is the slice of the stacked layers
    for the current step. The carry is passed through a leaky ReLU and then
    through ``layer``.

    Returns:
        ``(new_carry, None)``: no per-step output is collected.
    """
    activated = nnx.leaky_relu(y)
    return layer(activated), None


class GFlowNet(nnx.Module):
    """MLP with separate forward and backward policy heads.

    Both heads share ``linear_in``. Each head then has its own stack of
    ``nlayers`` hidden ``Linear(dmid, dmid)`` layers, each preceded by a leaky
    ReLU, and its own output projection to ``dout`` actions. Both heads return
    log-probabilities normalized over the last (action) axis.

    ``logz`` is a learnable scalar initialized to 0 and is not used by the
    methods here. It is presumably the log-partition estimate for a
    trajectory-balance style loss (TODO: confirm against the training code).
    """

    def __init__(self, din: int, dmid: int, dout: int, nlayers: int = 2, *, rngs: nnx.Rngs):
        self.din = din
        self.dmid = dmid
        self.dout = dout
        self.nlayers = nlayers

        self.logz = nnx.Param(jnp.array(0.0))

        # Input projection, shared by the forward and backward heads.
        self.linear_in = nnx.Linear(din, dmid, rngs=rngs)

        # Builds `nlayers` hidden layers from independent RNG splits. Their
        # parameters are stacked along a leading axis (vmap out_axes=0), so
        # `jax.lax.scan` can iterate over them one layer per step.
        @nnx.split_rngs(splits=nlayers)
        @nnx.vmap(in_axes=(0,), out_axes=0)
        def create_layer(rngs: nnx.Rngs):
            return nnx.Linear(dmid, dmid, rngs=rngs)

        # Forward policy head
        self.layers = create_layer(rngs)
        self.linear_out = nnx.Linear(dmid, dout, rngs=rngs)

        # Backward policy head
        self.blayers = create_layer(rngs)
        self.blinear_out = nnx.Linear(dmid, dout, rngs=rngs)

    def _head(self, x: HasState, layers: nnx.Linear, linear_out: nnx.Linear):
        """Run ``linear_in``, then the stacked ``layers``, then ``linear_out``; return log-probs."""
        y = self.linear_in(x.state)
        y, _ = jax.lax.scan(
            f=forward,
            init=y,
            xs=layers,
            length=self.nlayers,
        )
        logits = linear_out(nnx.leaky_relu(y))  # (..., dout), e.g. (B, dout)
        # The Linear layers always put actions on the last axis, so normalize
        # there. This is identical to axis=1 for (B, dout) inputs and also
        # handles unbatched or higher-rank states.
        return nnx.log_softmax(logits, axis=-1)

    def __call__(self, x: HasState):
        """Return forward-policy log-probabilities over the ``dout`` actions.

        Only the log-probabilities are computed here. Sampling is done by the
        module-level ``sample`` / ``sample_with_mask`` helpers.
        """
        return self._head(x, self.layers, self.linear_out)

    def bcall(self, x: HasState):
        """Return backward-policy log-probabilities over the ``dout`` actions."""
        return self._head(x, self.blayers, self.blinear_out)


def eps_greedy(logits: jax.Array, eps: float, mask: jax.Array):
    """Mix a policy with the uniform distribution over valid actions.

    Computes, in log space, ``log((1 - eps) * exp(logits) + eps * uniform(mask))``
    along the last axis. ``eps`` values close to 0 or 1 short-circuit to
    ``logits`` or the uniform log-probs respectively.

    Args:
        logits: Per-action log-probabilities, with actions on the last axis.
            Any leading batch shape is allowed. The result is a normalized
            distribution only if ``logits`` is normalized.
        eps: Weight of the uniform component, expected in [0, 1].
        mask: Same shape as ``logits``. Truthy entries mark valid actions.

    Returns:
        Mixed log-probabilities with the shape and dtype of ``logits``.
    """
    dtype = logits.dtype
    # Uniform log-probs over valid actions; invalid actions get -inf. Cast to
    # the dtype of `logits`, because lax.cond requires every branch to
    # return the same dtype.
    ulogits = jax.nn.log_softmax(
        jnp.where(mask, 1.0, -jnp.inf).astype(dtype), axis=-1
    )

    def mix():
        # Mixture weights (1 - eps, eps) in log space. log1p is more accurate
        # than log(1 - eps) for small eps. The reshape broadcasts the weights
        # against logits of any rank, not only 2-D.
        log_w = jnp.stack([jnp.log1p(-eps), jnp.log(eps)]).astype(dtype)
        log_w = log_w.reshape((2,) + (1,) * logits.ndim)
        mixed = jax.nn.logsumexp(jnp.stack([logits, ulogits], axis=0) + log_w, axis=0)
        return mixed.astype(dtype)

    # Edge cases: return each pure distribution exactly.
    return jax.lax.cond(
        jnp.isclose(eps, 0.0),
        lambda: logits,
        lambda: jax.lax.cond(jnp.isclose(eps, 1.0), lambda: ulogits, mix),
    )


def take(x: jax.Array, i: jax.Array):
    """Pick ``x[b, i[b]]`` for every row ``b``.

    Assumes ``x`` is 2-D ``(B, A)`` and ``i`` holds one integer index per row
    ``(B,)``. This follows from the gather along axis 1 and the removal of
    that axis afterwards.
    """
    # Out-of-range indices gather `fill_value=0` instead of failing. This
    # keeps things working when forward and backward actions are not in
    # bijection. Ideally that case would be encoded in `i` itself.
    gathered = jnp.take_along_axis(x, jnp.expand_dims(i, 1), axis=1, fill_value=0)
    return gathered[:, 0]


def sample(gfn: GFlowNet, state: HasState, key: jax.Array, eps: float):
    """Draw one action per row from the eps-greedy forward policy.

    Args:
        gfn: Network whose ``__call__`` returns per-action log-probabilities.
        state: Input passed to ``gfn``.
        key: PRNG key. It is split; one half drives sampling and the other is
            returned.
        eps: Weight of the uniform exploration component.

    Returns:
        ``(log_probs, actions, key)``, where:
        - ``log_probs`` is the log-probability of each sampled action under
          the unmixed policy ``gfn(state)``, not under the eps-mixture.
        - ``actions`` are the sampled action indices.
        - ``key`` is the PRNG key to carry forward.
    """
    policy = gfn(state)
    # No masking here: every action counts as valid for exploration.
    explore = eps_greedy(policy, eps, mask=jnp.ones_like(policy))

    next_key, draw_key = jax.random.split(key, 2)
    actions = jax.random.categorical(key=draw_key, logits=explore, axis=-1)
    return take(policy, actions), actions, next_key


def sample_with_mask(gfn: GFlowNet, state: HasState, key: jax.Array, eps: float, mask: jax.Array):
    """Draw actions from the eps-greedy forward policy restricted to valid actions.

    Policy log-probabilities outside the valid set are set to -inf and
    renormalized. Sampling uses a mixture with the uniform distribution over
    the same valid set, with weight ``eps``.

    Args:
        gfn: Network whose ``__call__`` returns per-action log-probabilities.
        state: Input passed to ``gfn``.
        key: PRNG key. It is split; one half drives sampling and the other is
            returned.
        eps: Weight of the uniform exploration component.
        mask: Entries equal to 1 mark valid actions. It must be
            broadcastable to the shape of ``gfn(state)``; for example, a
            single per-action mask can be shared across the batch.

    Returns:
        ``(log_probs, actions, key)``, where:
        - ``log_probs`` is the log-probability of each sampled action under
          the masked, renormalized policy, not under the eps-mixture.
        - ``actions`` are the sampled action indices.
        - ``key`` is the PRNG key to carry forward.
    """
    logits = gfn(state)
    # A single validity predicate is used both to mask the policy and for the
    # exploration component. Otherwise eps-exploration could draw an action
    # whose returned log-prob is -inf, e.g. when mask has values other than
    # 0/1. Broadcasting gives eps_greedy a mask of the logits' shape.
    valid = jnp.broadcast_to(jnp.asarray(mask) == 1, logits.shape)

    logits = jnp.where(valid, logits, -jnp.inf)
    # Renormalize over the remaining valid actions.
    logits = nnx.log_softmax(logits, axis=-1)

    slogits = eps_greedy(logits, eps, mask=valid)

    key, subkey = jax.random.split(key, 2)

    actions = jax.random.categorical(key=subkey, logits=slogits, axis=-1)
    logits = take(logits, actions)
    return logits, actions, key
