import jax
import jax.numpy as jnp
import flax.nnx as nnx

from nais.policies.base import (
    ForwardPolicyBase,
    ForwardPolicyConfig,
    BackwardPolicyBase,
    FlowFunctionBase,
    FlowFunctionConfig,
)

# from nais.gym.lines import Lines

from nais.gym.base import EnvState

from nais.nn import MLP
from nais.nets.srwm import SRWM
from nais.nets.fwp import FWP
from nais.nets.linear import Perceptron, LinearControl


class FlowFunction(FlowFunctionBase):
    """Flow estimator backed by a single MLP with one input and one output.

    The MLP is built with a fixed ``nnx.Rngs(42)`` seed and a hidden width
    taken from ``config.hidden_dim``.
    """

    def __init__(self, config: FlowFunctionConfig):
        super().__init__(config)

        # One input feature (the state value) and one output per sample.
        n_in, n_out = 1, 1
        self.flow_nn = MLP(n_in, self.config.hidden_dim, n_out, rngs=nnx.Rngs(42))

    def _get_log_flows(self, state: jax.Array) -> jax.Array:
        """Run the flow network on ``state`` after inserting a new axis 1.

        Assumes ``state`` is a 1-D ``(batch,)`` array so the network sees
        ``(batch, 1)`` -- TODO confirm against the environment.
        """
        features = state[:, jnp.newaxis]
        return self.flow_nn(features)

    def __call__(self, state: EnvState) -> jax.Array:
        """Return the network output for the raw array held in ``state.state``."""
        return self._get_log_flows(state.state)


class ForwardPolicy(ForwardPolicyBase):
    """Feed-forward policy: an MLP mapping a scalar state to action logits.

    The network has ``max_step_size + 1`` outputs, is built without layer
    normalisation and uses a fixed ``nnx.Rngs(42)`` seed.
    """

    def __init__(self, length: int, max_step_size: int, config: ForwardPolicyConfig):
        super().__init__(config)
        self.length = length

        # One logit per entry in 0..max_step_size.
        n_outputs = max_step_size + 1
        self.policy_nn = MLP(
            1,
            self.config.hidden_dim,
            n_outputs,
            layer_norm=False,
            rngs=nnx.Rngs(42),
        )

    def _get_logits(self, state: jax.Array) -> jax.Array:
        """Run the policy network on ``state`` after inserting a new axis 1.

        Assumes ``state`` is a 1-D ``(batch,)`` array -- TODO confirm.
        """
        features = state[:, jnp.newaxis]
        return self.policy_nn(features)

    def __call__(self, state: EnvState) -> jax.Array:
        """Return the raw network outputs for ``state.state``.

        No normalisation or masking is applied in this method; the logits
        are returned exactly as produced by the MLP.
        """
        return self._get_logits(state.state)


class ForwardPolicyLSTM(ForwardPolicyBase):
    """Recurrent forward policy: linear embedding -> LSTM cell -> linear head.

    The LSTM carry is stored on the module between calls, so successive
    ``__call__`` invocations continue the same recurrence. ``lazy_init``
    resets the carry to zeros for a given batch size.
    """

    # Recurrent state of the LSTM cell; ``None`` until it is first initialised.
    _carry: nnx.Data[tuple[jax.Array, jax.Array] | None]

    def __init__(self, length: int, max_step_size: int, config: ForwardPolicyConfig):
        super().__init__(config)
        self.length = length
        self.max_step_size = max_step_size

        # NOTE(review): linear_in and lstm_cell are each built from a fresh
        # nnx.Rngs(0). Left unchanged so that initial weights stay reproducible;
        # confirm whether sharing the seed is intended.
        self.linear_in = nnx.Linear(1, self.config.hidden_dim, rngs=nnx.Rngs(0))
        self.lstm_cell = nnx.LSTMCell(
            self.config.hidden_dim, self.config.hidden_dim, rngs=nnx.Rngs(0)
        )
        self.linear_out = nnx.Linear(
            self.config.hidden_dim, self.max_step_size + 1, rngs=nnx.Rngs(1)
        )

        # Using nnx.data ensures that the carry is stored in the computation graph
        self._carry = nnx.data(None)

    def _reset_carry(self, batch_size: int) -> tuple[jax.Array, jax.Array]:
        """Return an all-zero carry pair of shape ``(batch_size, hidden_dim)``."""
        zeros = jnp.zeros((batch_size, self.config.hidden_dim))
        return zeros, zeros

    def _get_logits(
        self, state: jax.Array, carry: tuple[jax.Array, jax.Array]
    ) -> tuple[tuple[jax.Array, jax.Array], jax.Array]:
        """Advance the recurrence by one step.

        Args:
            state: Input passed to ``linear_in``; its last axis must have
                size 1 to match that layer's input width.
            carry: Current LSTM carry pair.

        Returns:
            ``(new_carry, logits)`` where ``logits`` has
            ``max_step_size + 1`` entries on its last axis.
        """
        x = self.linear_in(state)
        carry, hidden = self.lstm_cell(carry, x)
        logits = self.linear_out(hidden)
        return carry, logits

    def __call__(self, state: EnvState) -> jax.Array:
        """Compute logits for ``state`` and store the updated carry.

        If the carry has never been initialised (``lazy_init`` was not
        called), the recurrence starts from the zero carry instead of
        failing inside the LSTM cell.
        """
        carry = self._carry
        if carry is None:
            # Same starting point that lazy_init would have produced.
            carry = self._reset_carry(state.batch_size)
        carry, logits = self._get_logits(state.state[:, None], carry)
        self._carry = nnx.data(carry)
        return logits

    def lazy_init(self, state: EnvState) -> None:
        """Reset the stored carry to zeros sized for ``state.batch_size``."""
        # Wrapped in nnx.data for consistency with __init__ and __call__.
        self._carry = nnx.data(self._reset_carry(state.batch_size))


class ForwardPolicySRWM(ForwardPolicyBase):
    """Forward policy that delegates to an ``SRWM`` network.

    The network takes one input feature, uses ``config.hidden_dim`` as its
    hidden size and emits ``max_step_size + 1`` outputs. It is built without
    layer normalisation and with a fixed ``nnx.Rngs(42)`` seed.
    """

    def __init__(self, length: int, max_step_size: int, config: ForwardPolicyConfig):
        super().__init__(config)
        self.length = length
        self.max_step_size = max_step_size

        n_outputs = self.max_step_size + 1
        self.srwm = SRWM(
            1,
            self.config.hidden_dim,
            n_outputs,
            rngs=nnx.Rngs(42),
            layer_norm=False,
        )

    def __call__(self, state: EnvState) -> jax.Array:
        """Return the SRWM output for ``state.state`` with a new axis 1 added."""
        features = state.state[:, jnp.newaxis]
        return self.srwm(features)

    def lazy_init(self, state: EnvState):
        """Forward the batch size of ``state`` to ``SRWM.lazy_init``."""
        batch = state.batch_size
        self.srwm.lazy_init(batch)


class ForwardPolicyFWP(ForwardPolicyBase):
    """Forward policy that delegates to an ``FWP`` network.

    The network is constructed with one input feature, ``config.hidden_dim``
    as hidden size and ``max_step_size + 1`` outputs.
    """

    def __init__(self, length: int, max_step_size: int, config: ForwardPolicyConfig):
        super().__init__(config)
        self.length = length
        self.max_step_size = max_step_size

        n_outputs = self.max_step_size + 1
        self.fwp = FWP(1, self.config.hidden_dim, n_outputs)

    def __call__(self, state: EnvState) -> jax.Array:
        """Return the FWP output for ``state.state`` with a new axis 1 added."""
        features = state.state[:, jnp.newaxis]
        return self.fwp(features)

    def lazy_init(self, state: EnvState):
        """Forward the batch size of ``state`` to ``FWP.lazy_init``."""
        batch = state.batch_size
        self.fwp.lazy_init(batch)


class ForwardPolicyPerceptron(ForwardPolicyBase):
    """Forward policy built on a ``Perceptron`` with two input features.

    The two features are the state value and a constant column of ones.
    The network emits ``max_step_size + 1`` outputs and is seeded with a
    fixed ``nnx.Rngs(42)``.
    """

    def __init__(self, length: int, max_step_size: int, config: ForwardPolicyConfig):
        super().__init__(config)
        self.length = length
        self.max_step_size = max_step_size

        n_outputs = max_step_size + 1
        self.perceptron = Perceptron(
            2, self.config.hidden_dim, n_outputs, rngs=nnx.Rngs(42)
        )

    def __call__(self, state: EnvState) -> jax.Array:
        """Return the perceptron output for ``[state, 1]`` feature rows."""
        values = state.state[:, jnp.newaxis]
        ones = jnp.ones_like(values)
        # values is at least 2-D here, so joining along axis 1 matches hstack.
        features = jnp.concatenate([values, ones], axis=1)
        return self.perceptron(features)


class ForwardPolicyLC(ForwardPolicyBase):
    """Forward policy built on a ``LinearControl`` network with two inputs.

    The two features are the state value and a constant column of ones.
    The network emits ``max_step_size + 1`` outputs and is seeded with a
    fixed ``nnx.Rngs(42)``.
    """

    def __init__(self, length: int, max_step_size: int, config: ForwardPolicyConfig):
        super().__init__(config)
        self.length = length
        self.max_step_size = max_step_size

        n_outputs = max_step_size + 1
        self.linear_control = LinearControl(
            2, self.config.hidden_dim, n_outputs, rngs=nnx.Rngs(42)
        )

    def __call__(self, state: EnvState) -> jax.Array:
        """Return the network output for ``[state, 1]`` feature rows."""
        values = state.state[:, jnp.newaxis]
        ones = jnp.ones_like(values)
        # values is at least 2-D here, so joining along axis 1 matches hstack.
        features = jnp.concatenate([values, ones], axis=1)
        return self.linear_control(features)

    def lazy_init(self, state: EnvState):
        """Forward the batch size of ``state`` to ``LinearControl.lazy_init``."""
        batch = state.batch_size
        self.linear_control.lazy_init(batch)


class BackwardPolicy(BackwardPolicyBase):
    """Backward policy that returns the same value for every action."""

    def __call__(self, state: EnvState) -> jax.Array:
        """Return an array of ones shaped and typed like ``state.backward_mask``.

        No masking happens here; it is expected to be applied afterwards by
        the base class.
        """
        mask = state.backward_mask
        return jnp.ones_like(mask)
