import flax.nnx as nnx
import jax
import jax.numpy as jnp

# from nais.gym.sequences import Sequences
from nais.gym.base import EnvState
from nais.nets.fwp import FWP
from nais.nets.srwm import SRWM
from nais.nets.transformer import Transformer
from nais.nn import MLP
from nais.policies.base import (
    BackwardPolicyBase,
    FlowFunctionBase,
    FlowFunctionConfig,
    ForwardPolicyBase,
    ForwardPolicyConfig,
)


def get_in_nn(state: EnvState) -> jax.Array:
    """Extract the network input from an environment state.

    Args:
        state: Environment state; only its ``state`` attribute is read.

    Returns:
        The ``state.state`` value, returned unchanged.
    """
    network_input = state.state
    return network_input


class FlowFunction(FlowFunctionBase):
    """Flow estimator backed by a single-output ``MLP``.

    Args:
        seq_size: Passed as the first ``MLP`` argument (presumably the input
            width of the network -- confirm against ``nais.nn.MLP``).
        config: Flow-function configuration forwarded to the base class;
            ``hidden_dim`` is read back through ``self.config``.
    """

    def __init__(self, seq_size: int, config: FlowFunctionConfig):
        super().__init__(config)

        # NOTE(review): the seed is hard-coded, so every instance is built from
        # the same RNG stream -- confirm that this is intended.
        self.flow_nn = MLP(
            seq_size,
            self.config.hidden_dim,
            1,
            rngs=nnx.Rngs(42),
        )

    def _get_log_flows(self, state: jax.Array) -> jax.Array:
        """Apply the flow network to an already-extracted input array."""
        return self.flow_nn(state)

    def __call__(self, state: EnvState) -> jax.Array:
        """Return the flow network output for ``state``.

        The network input is obtained through ``get_in_nn`` -- the same helper
        the forward policies in this module use -- so that the flow function
        and the policies always read the same representation of the state.
        """
        return self._get_log_flows(get_in_nn(state))


class ForwardPolicy(ForwardPolicyBase):
    """Forward policy backed by an ``MLP`` with ``vocab_size**ngram_size`` outputs.

    Args:
        seq_size: Passed as the first ``MLP`` argument.
        vocab_size: Base of the output-size computation.
        ngram_size: Exponent of the output-size computation (keyword-only).
        config: Policy configuration forwarded to the base class;
            ``hidden_dim`` is read back through ``self.config``.
    """

    def __init__(
        self,
        seq_size: int,
        vocab_size: int,
        *,
        ngram_size: int = 1,
        config: ForwardPolicyConfig,
    ):
        super().__init__(config)

        hidden_dim = self.config.hidden_dim
        # One output per sequence of ``ngram_size`` symbols over the vocabulary.
        num_outputs = vocab_size**ngram_size
        self.policy_nn = MLP(seq_size, hidden_dim, num_outputs, rngs=nnx.Rngs(42))

    def _get_logits(self, state: jax.Array) -> jax.Array:
        """Apply the policy network to an already-extracted input array."""
        network_output = self.policy_nn(state)
        return network_output

    def __call__(self, state: EnvState) -> jax.Array:
        """Return the policy network output for the given environment state."""
        return self._get_logits(get_in_nn(state))


class ForwardPolicyLSTM(ForwardPolicyBase):
    """Recurrent forward policy: linear projection -> LSTM cell -> linear head.

    The LSTM carry is stored on the module between calls. ``lazy_init`` must be
    called before the first forward pass, and again whenever the carry has to
    be reset to zeros.

    Args:
        seq_size: Input feature size of the first linear layer.
        vocab_size: Base of the output-size computation.
        ngram_size: Exponent of the output-size computation (keyword-only); the
            head has ``vocab_size**ngram_size`` outputs.
        config: Policy configuration forwarded to the base class;
            ``hidden_dim`` is read back through ``self.config``.
    """

    # Declared as NNX data so the carry is tracked as part of the module's
    # pytree state rather than as static metadata.
    _carry: nnx.Data[tuple[jax.Array, jax.Array] | None]

    def __init__(
        self,
        seq_size: int,
        vocab_size: int,
        *,
        ngram_size: int = 1,
        config: ForwardPolicyConfig,
    ):
        super().__init__(config)

        initializer = nnx.initializers.glorot_uniform()
        # A single Rngs object is shared by the three layers below.
        rngs = nnx.Rngs(42)

        self.linear_in = nnx.Linear(seq_size, self.config.hidden_dim, rngs=rngs, kernel_init=initializer)
        self.lstm_cell = nnx.OptimizedLSTMCell(
            self.config.hidden_dim,
            self.config.hidden_dim,
            rngs=rngs,
            kernel_init=initializer,
        )
        self.linear_out = nnx.Linear(
            self.config.hidden_dim,
            vocab_size**ngram_size,
            rngs=rngs,
            kernel_init=initializer,
        )

        # No carry exists until ``lazy_init`` is called. ``nnx.data`` marks the
        # attribute as pytree data so NNX keeps it in the module state.
        self._carry = nnx.data(None)

    def _reset_carry(self, batch_size: int) -> tuple[jax.Array, jax.Array]:
        """Return a zero-filled carry pair, each of shape ``(batch_size, hidden_dim)``."""
        zeros = jnp.zeros((batch_size, self.config.hidden_dim))
        return zeros, zeros

    def _get_logits(
        self,
        state: jax.Array,
        carry: tuple[jax.Array, jax.Array],
    ) -> tuple[tuple[jax.Array, jax.Array], jax.Array]:
        """Advance the LSTM by one step.

        Args:
            state: Input array fed to the first linear layer.
            carry: Carry pair from the previous step.

        Returns:
            ``(new_carry, logits)``.
        """
        x = self.linear_in(state)
        carry, hidden = self.lstm_cell(carry, x)
        logits = self.linear_out(hidden)
        return carry, logits

    def __call__(self, state: EnvState) -> jax.Array:
        """Run one recurrent step on ``state`` and store the updated carry.

        Raises:
            RuntimeError: If the carry has not been initialised via ``lazy_init``.
        """
        carry = self._carry
        if carry is None:
            # Fail early with an actionable message instead of letting the LSTM
            # cell fail while trying to unpack ``None``.
            raise RuntimeError(
                "ForwardPolicyLSTM carry is not initialised; "
                "call lazy_init(state) before the first forward pass."
            )
        carry, logits = self._get_logits(get_in_nn(state), carry)
        self._carry = nnx.data(carry)
        return logits

    def lazy_init(self, state: EnvState):
        """Reset the carry to zeros, sized from ``state.batch_size``."""
        # Wrapped with ``nnx.data`` like every other assignment to ``_carry``.
        self._carry = nnx.data(self._reset_carry(state.batch_size))


class ForwardPolicySRWM(ForwardPolicyBase):
    """Forward policy that delegates to an ``SRWM`` network.

    Args:
        seq_size: Passed as the first ``SRWM`` argument.
        vocab_size: Base of the output-size computation.
        ngram_size: Exponent of the output-size computation (keyword-only).
        config: Policy configuration; forwarded to the base class and used for
            ``hidden_dim``.
    """

    def __init__(
        self,
        seq_size: int,
        vocab_size: int,
        *,
        ngram_size: int = 1,
        config: ForwardPolicyConfig,
    ):
        super().__init__(config)
        num_outputs = vocab_size**ngram_size
        self.srwm = SRWM(
            seq_size,
            config.hidden_dim,
            num_outputs,
            rngs=nnx.Rngs(42),
        )

    def __call__(self, state: EnvState) -> jax.Array:
        """Return the ``SRWM`` output for the given environment state."""
        network_input = get_in_nn(state)
        return self.srwm(network_input)

    def lazy_init(self, state: EnvState):
        """Forward ``state.batch_size`` to the network's own ``lazy_init``."""
        batch_size = state.batch_size
        self.srwm.lazy_init(batch_size)


class ForwardPolicyFWP(ForwardPolicyBase):
    """Forward policy that delegates to an ``FWP`` network.

    Args:
        seq_size: Passed as the first ``FWP`` argument.
        vocab_size: Base of the output-size computation.
        ngram_size: Exponent of the output-size computation (keyword-only).
        config: Policy configuration; forwarded to the base class and used for
            ``hidden_dim``.
    """

    def __init__(
        self,
        seq_size: int,
        vocab_size: int,
        *,
        ngram_size: int = 1,
        config: ForwardPolicyConfig,
    ):
        super().__init__(config)
        num_outputs = vocab_size**ngram_size
        # NOTE(review): no ``rngs`` argument is passed here, unlike the other
        # networks in this module -- confirm ``FWP`` handles its own initialisation.
        self.fwp = FWP(seq_size, config.hidden_dim, num_outputs)

    def __call__(self, state: EnvState) -> jax.Array:
        """Return the ``FWP`` output for the given environment state."""
        network_input = get_in_nn(state)
        return self.fwp(network_input)

    def lazy_init(self, state: EnvState):
        """Forward ``state.batch_size`` to the network's own ``lazy_init``."""
        batch_size = state.batch_size
        self.fwp.lazy_init(batch_size)


class BackwardPolicy(BackwardPolicyBase):
    """Backward policy for autoregressive generation.

    Because generation is autoregressive, there is a single backward action
    and going back has unit probability.
    """

    def __call__(self, state: EnvState) -> jax.Array:
        """Return an all-ones array with a singleton axis inserted at position 1.

        The result takes its shape and dtype from ``state.is_initial``
        (NOTE(review): presumably one entry per batch element -- confirm).
        """
        # Unit value for every entry: the backward move is deterministic.
        unit = jnp.ones_like(state.is_initial)
        # The new axis enumerates the (single) backward action.
        return jnp.expand_dims(unit, axis=1)
