import flax.nnx as nnx
import jax
import jax.numpy as jnp

from nais.gym.base import EnvState
from nais.nets.fwp import FWP
from nais.nets.srwm import SRWM, StateConditionalSRWM
from nais.nn import MLP
from nais.policies.base import (
    BackwardPolicyBase,
    BackwardPolicyConfig,
    FlowFunctionBase,
    FlowFunctionConfig,
    ForwardPolicyBase,
    ForwardPolicyConfig,
)


class FlowFunction(FlowFunctionBase):
    """MLP estimator of per-state log-flows.

    The wrapped MLP maps ``config.input_dim`` features through
    ``config.hidden_dim`` hidden units to a single output value.
    """

    def __init__(self, config: FlowFunctionConfig):
        super().__init__(config)
        cfg = self.config
        # Fixed seed so parameter initialisation is reproducible across runs.
        self.flow_nn = MLP(cfg.input_dim, cfg.hidden_dim, 1, rngs=nnx.Rngs(42))

    def _get_log_flows(self, state: jax.Array) -> jax.Array:
        """Apply the MLP to a raw state array and return its output."""
        return self.flow_nn(state)

    def __call__(self, state: EnvState) -> jax.Array:
        """Return the predicted log-flows for ``state.state``."""
        return self._get_log_flows(state.state)


class ForwardPolicy(ForwardPolicyBase):
    """Feed-forward (MLP) forward policy producing action logits.

    The wrapped MLP maps ``config.input_dim`` features through
    ``config.hidden_dim`` hidden units to ``config.output_dim`` logits.
    """

    def __init__(self, config: ForwardPolicyConfig):
        super().__init__(config)
        cfg = self.config
        # Fixed seed so parameter initialisation is reproducible across runs.
        self.policy_nn = MLP(cfg.input_dim, cfg.hidden_dim, cfg.output_dim, rngs=nnx.Rngs(42))

    def _get_logits(self, state: jax.Array) -> jax.Array:
        """Apply the MLP to a raw state array and return its logits."""
        return self.policy_nn(state)

    def __call__(self, state: EnvState) -> jax.Array:
        """Return the forward-policy logits for ``state.state``."""
        return self._get_logits(state.state)


class ForwardPolicyLSTM(ForwardPolicyBase):
    """Recurrent forward policy: Linear -> OptimizedLSTMCell -> Linear.

    The LSTM carry is stored on the module and updated on every ``__call__``,
    so successive calls share recurrent memory. ``lazy_init`` resets the carry
    to zeros for a given batch size. If ``__call__`` runs before any reset,
    a zero carry is created on the fly from ``state.batch_size``.
    """

    # Declared as nnx.Data so the carry (a pair of arrays, or None before
    # initialisation) is tracked as module data rather than a static attribute.
    _carry: nnx.Data[tuple[jax.Array, jax.Array] | None]

    def __init__(self, config: ForwardPolicyConfig):
        super().__init__(config)
        # NOTE(review): linear_in and lstm_cell are both seeded with a fresh
        # nnx.Rngs(0), so they start from the same key stream. Left unchanged
        # to keep existing initialisations/checkpoints reproducible.
        self.linear_in = nnx.Linear(self.config.input_dim, self.config.hidden_dim, rngs=nnx.Rngs(0))
        self.lstm_cell = nnx.OptimizedLSTMCell(self.config.hidden_dim, self.config.hidden_dim, rngs=nnx.Rngs(0))
        self.linear_out = nnx.Linear(self.config.hidden_dim, self.config.output_dim, rngs=nnx.Rngs(1))

        # No carry yet: its shape depends on the batch size, which is only
        # known once lazy_init (or the first __call__) sees a state.
        self._carry = nnx.data(None)

    def _reset_carry(self, batch_size: int) -> tuple[jax.Array, jax.Array]:
        """Return a pair of zero arrays, each of shape ``(batch_size, hidden_dim)``."""
        zeros = jnp.zeros((batch_size, self.config.hidden_dim))
        return zeros, zeros

    def _get_logits(self, state: jax.Array, carry: tuple[jax.Array, jax.Array]) -> tuple[tuple[jax.Array, jax.Array], jax.Array]:
        """Advance the LSTM one step.

        Args:
            state: raw state features, projected by ``linear_in``.
            carry: current LSTM carry.

        Returns:
            ``(new_carry, logits)``, where the logits are ``linear_out``
            applied to the cell's output.
        """
        x = self.linear_in(state)
        carry, hidden = self.lstm_cell(carry, x)
        logits = self.linear_out(hidden)
        return carry, logits

    def __call__(self, state: EnvState) -> jax.Array:
        """Return logits for ``state`` and store the updated carry on the module."""
        carry = self._carry
        if carry is None:
            # Without this, a None carry was passed into the LSTM cell and
            # failed; start from a zero carry instead.
            carry = self._reset_carry(state.batch_size)
        carry, logits = self._get_logits(state.state, carry)
        self._carry = nnx.data(carry)
        return logits

    def lazy_init(self, state: EnvState):
        """Reset the carry to zeros sized for ``state.batch_size``."""
        # Wrapped in nnx.data for consistency with __init__ and __call__.
        self._carry = nnx.data(self._reset_carry(state.batch_size))


class ForwardPolicySRWM(ForwardPolicyBase):
    """Forward policy backed by the project's ``SRWM`` network.

    The network maps ``config.input_dim`` features through
    ``config.hidden_dim`` to ``config.output_dim`` logits.
    """

    def __init__(self, config: ForwardPolicyConfig):
        super().__init__(config)
        cfg = self.config
        # Fixed seed so parameter initialisation is reproducible across runs.
        self.srwm = SRWM(cfg.input_dim, cfg.hidden_dim, cfg.output_dim, rngs=nnx.Rngs(42))

    def __call__(self, state: EnvState) -> jax.Array:
        """Return the SRWM's logits for ``state.state``."""
        return self.srwm(state.state)

    def lazy_init(self, state: EnvState):
        """Delegate per-batch initialisation to the SRWM and forward its result."""
        batch_size = state.batch_size
        return self.srwm.lazy_init(batch_size=batch_size)


class ForwardPolicyFWP(ForwardPolicyBase):
    """Forward policy backed by the project's ``FWP`` network.

    The network maps ``config.input_dim`` features through
    ``config.hidden_dim`` to ``config.output_dim`` logits.
    """

    def __init__(self, config: ForwardPolicyConfig):
        super().__init__(config)
        cfg = self.config
        # NOTE(review): unlike the other networks in this module, FWP is
        # built without an explicit ``rngs`` argument — confirm how it seeds
        # its parameters.
        self.fwp = FWP(cfg.input_dim, cfg.hidden_dim, cfg.output_dim)

    def __call__(self, state: EnvState) -> jax.Array:
        """Return the FWP's logits for ``state.state``."""
        return self.fwp(state.state)

    def lazy_init(self, state: EnvState):
        """Delegate per-batch initialisation to the FWP (returns None)."""
        batch_size = state.batch_size
        self.fwp.lazy_init(batch_size)


class BackwardPolicy(BackwardPolicyBase):
    """Parameter-free backward policy that normalises each state row by its sum.

    Each row of ``state.state`` is divided by its sum, so a row of
    non-negative entries becomes weights proportional to those entries.
    Row sums below 1 are clamped to 1 to avoid division by zero. As a
    consequence, an all-zero row maps to all zeros (not a distribution), and
    rows summing to less than 1 are left un-normalised.
    """

    def __call__(self, state: EnvState) -> jax.Array:
        # assumes state.state is (batch, dim); the sum runs over axis 1.
        total = jnp.sum(state.state, axis=1, keepdims=True)
        # jnp.maximum instead of jnp.clip(..., a_min=...): the a_min/a_max
        # keywords are deprecated in recent JAX releases. The result is the same.
        safe_total = jnp.maximum(total, 1.0)
        return state.state / safe_total


class BackwardPolicyMLP(BackwardPolicyBase):
    """Learned (MLP) backward policy producing unnormalised logits.

    The wrapped MLP maps ``config.input_dim`` features through
    ``config.hidden_dim`` hidden units to ``config.output_dim`` logits.
    Unlike the other networks here, it is seeded from ``config.key``.
    """

    def __init__(self, config: BackwardPolicyConfig):
        super().__init__(config)
        cfg = self.config
        self.mlp = MLP(cfg.input_dim, cfg.hidden_dim, cfg.output_dim, rngs=nnx.Rngs(config.key))

    def __call__(self, state: EnvState) -> jax.Array:
        """Return raw backward-policy logits for ``state.state`` (no softmax applied here)."""
        return self.mlp(state.state)


class BackwardPolicySRWM(BackwardPolicyBase):
    """Learned backward policy built on the project's ``StateConditionalSRWM``.

    The network maps ``config.input_dim`` features through
    ``config.hidden_dim`` to ``config.output_dim`` logits.
    """

    def __init__(self, config: BackwardPolicyConfig):
        super().__init__(config)
        cfg = self.config
        # Fixed seed so parameter initialisation is reproducible across runs.
        self.lc = StateConditionalSRWM(cfg.input_dim, cfg.hidden_dim, cfg.output_dim, rngs=nnx.Rngs(42))

    def __call__(self, state: EnvState) -> jax.Array:
        """Return raw backward-policy logits for ``state.state``."""
        return self.lc(state.state)

    def lazy_init(self, state: EnvState):
        """Delegate initialisation to the wrapped network and forward its result.

        NOTE(review): unlike ForwardPolicySRWM, which passes ``batch_size``,
        this forwards the raw ``state.state`` array. Confirm that this matches
        ``StateConditionalSRWM.lazy_init``'s expected argument.
        """
        features = state.state
        return self.lc.lazy_init(features)
