import flax.nnx as nnx
import jax
import jax.numpy as jnp

from nais.gym.base import EnvState
from nais.nets.fwp import FWP
from nais.nets.linear import LinearControl, Perceptron
from nais.nets.srwm import SRWM, StateConditionalSRWM
from nais.nn import MLP
from nais.policies.base import (
    BackwardPolicyBase,
    BackwardPolicyConfig,
    ForwardPolicyBase,
    ForwardPolicyConfig,
)


class FourierPositionalEncoding(nnx.Module):
    """Sinusoidal (Fourier) encoding of scalar positions.

    Every scalar position ``p`` is mapped to a vector of length ``hidden_dim``
    laid out as ``[sin(p*ω_0), cos(p*ω_0), sin(p*ω_1), cos(p*ω_1), ...]`` with
    ``ω_k = max_period^{-2k/hidden_dim}`` for ``k = 0 .. hidden_dim/2 - 1``.
    """

    def __init__(self, hidden_dim: int, *, max_period: float = 10000.0):
        """
        Args:
            hidden_dim: Length of the encoding. Must be even because sines and
                cosines come in pairs.
            max_period: Base of the geometric frequency progression
                (defaults to 10000).

        Raises:
            ValueError: If ``hidden_dim`` is odd.
        """
        # Explicit check instead of `assert`, which disappears under `python -O`.
        if hidden_dim % 2 != 0:
            raise ValueError(f"hidden_dim must be even, got {hidden_dim}")
        self.d = hidden_dim

        k = jnp.arange(self.d // 2)
        # ω_k = max_period^{-2k/d}
        self.omega = 1.0 / (max_period ** (2 * k / self.d))

    def __call__(self, p: jax.Array) -> jax.Array:
        """Encode positions ``p``.

        Args:
            p: Positions of any shape ``S`` (commonly ``(B,)``). Array-likes
                such as Python lists are accepted.

        Returns:
            Array of shape ``S + (hidden_dim,)`` with sin/cos interleaved.
        """
        # Convert first so `.shape` is available for list/tuple inputs too.
        p = jnp.asarray(p)
        angles = p[..., None] * self.omega  # S + (d/2,)

        # Stack on a new trailing axis, then merge it into the last axis so the
        # result reads [sin0, cos0, sin1, cos1, ...].
        pairs = jnp.stack([jnp.sin(angles), jnp.cos(angles)], axis=-1)  # S + (d/2, 2)
        return pairs.reshape(*p.shape, self.d)


class LinearEmbedding(nnx.Module):
    """Embed state rows as ``[Linear(features / norm) ‖ Fourier(time)]``.

    Column ``tdim`` of the input is treated as a time/position coordinate and
    encoded with ``FourierPositionalEncoding``. The columns before it are
    divided by ``norm`` and passed through a linear layer. Both halves have
    width ``dout``, so the output has ``2 * dout`` columns.
    """

    def __init__(
        self,
        din: int,
        dout: int,
        tdim: int = -1,
        *,
        rngs: nnx.Rngs,
        norm: float = 1,
    ):
        """
        Args:
            din: Number of feature columns, i.e. the width of ``s[:, :tdim]``.
            dout: Width of the linear embedding and of the positional encoding
                (must be even for the latter).
            tdim: Column index of the time coordinate (default: last column).
            rngs: NNX random streams used to initialise the linear layer.
            norm: Divisor applied to the feature columns before the linear layer.
        """
        self.din = din
        self.dout = dout
        self.tdim = tdim
        self.norm = norm

        self.emb = nnx.Linear(din, dout, rngs=rngs)
        self.positional_encoding = FourierPositionalEncoding(dout)

    def __call__(self, s: jax.Array) -> jax.Array:
        """Embed a batch of states.

        Args:
            s: States indexed as ``s[:, column]``. This presumably is a 2-D
                ``(batch, columns)`` array; TODO confirm for callers.

        Returns:
            Concatenation of the linear embedding and the positional encoding
            along the last axis (``2 * dout`` wide).
        """
        features = s[:, : self.tdim] / self.norm
        time = s[:, self.tdim]

        return jnp.concatenate(
            [self.emb(features), self.positional_encoding(time)], axis=-1
        )


class ForwardPolicy(ForwardPolicyBase):
    """Feed-forward forward policy: ``LinearEmbedding`` followed by an ``MLP``.

    The MLP emits ``2 * d + 1`` values per state.
    """

    def __init__(self, d: int, size: int, *, config: ForwardPolicyConfig):
        super().__init__(config)
        self.d, self.size = d, size

        hidden = config.hidden_dim
        # One seeded stream shared by both submodules; the embedding is
        # created (and draws its keys) before the MLP.
        rngs = nnx.Rngs(42)

        self.linear_embedding = LinearEmbedding(d, hidden, rngs=rngs, norm=size)

        # The embedding concatenates features and positional code -> 2 * hidden.
        self.mlp = MLP(
            din=hidden * 2,
            dmid=hidden,
            dout=1 + 2 * d,
            rngs=rngs,
            layer_norm=False,
        )

    def __call__(self, state: EnvState):
        embedded = self.linear_embedding(state.state)
        return self.mlp(embedded)


class ForwardPolicySRWM(ForwardPolicyBase):
    """Forward policy whose head is an ``SRWM`` applied to ``LinearEmbedding`` output."""

    def __init__(self, d: int, size: int, *, config: ForwardPolicyConfig):
        super().__init__(config)
        self.d = d
        self.size = size

        rngs = nnx.Rngs(42)
        width = self.config.hidden_dim

        self.linear_embedding = LinearEmbedding(d, width, rngs=rngs, norm=size)
        # Input width is 2 * width because the embedding concatenates the
        # linear features with the positional encoding.
        self.srwm = SRWM(2 * width, width, 2 * d + 1, rngs=rngs)

    def __call__(self, state: EnvState) -> jax.Array:
        return self.srwm(self.linear_embedding(state.state))

    def lazy_init(self, state: EnvState):
        """Delegate lazy initialisation to the SRWM, sized by ``state.batch_size``."""
        return self.srwm.lazy_init(batch_size=state.batch_size)


class BackwardPolicy(BackwardPolicyBase):
    """Parameter-free backward policy that gives every entry the same score (1)."""

    def __call__(self, state: EnvState):
        # No masking here: invalid entries are filtered later in sample_actions.
        mask = state.backward_mask
        return jnp.ones_like(mask)


class BackwardPolicyMLP(BackwardPolicyBase):
    """Learned backward policy: ``LinearEmbedding`` followed by an ``MLP``."""

    def __init__(self, d: int, size: int, *, config: BackwardPolicyConfig):
        super().__init__(config)
        self.d = d
        self.size = size

        hidden = self.config.hidden_dim
        self.linear_embedding = LinearEmbedding(
            d, hidden, rngs=nnx.Rngs(42), norm=size
        )
        # The MLP is seeded from config.key, separately from the embedding's
        # fixed seed.
        self.mlp = MLP(2 * hidden, hidden, 2 * d + 1, rngs=nnx.Rngs(config.key))

    def __call__(self, state: EnvState) -> jax.Array:
        # Raw (unnormalised) outputs of width 2 * d + 1; presumably turned into
        # a distribution by the caller — TODO confirm.
        return self.mlp(self.linear_embedding(state.state))


class BackwardPolicySRWM(BackwardPolicyBase):
    """Backward policy with a ``StateConditionalSRWM`` head over ``LinearEmbedding`` output."""

    def __init__(self, d: int, size: int, *, config: BackwardPolicyConfig):
        super().__init__(config)
        self.d = d
        self.size = size

        # A single stream is shared by both submodules. Building the SRWM from
        # a second, independent nnx.Rngs(42) would replay the exact key
        # sequence the embedding consumed, correlating their initial weights.
        rngs = nnx.Rngs(42)
        self.linear_embedding = LinearEmbedding(self.d, self.config.hidden_dim, rngs=rngs, norm=self.size)

        # Input width is 2 * hidden_dim: linear features ‖ positional encoding.
        self.lc = StateConditionalSRWM(
            2 * self.config.hidden_dim,
            self.config.hidden_dim,
            2 * self.d + 1,
            rngs=rngs,
        )

    def __call__(self, state: EnvState) -> jax.Array:
        # Raw outputs of width 2 * d + 1 for the given states.
        y = self.linear_embedding(state.state)
        logits = self.lc(y)
        return logits

    def lazy_init(self, state: EnvState):
        """Initialise the SRWM's lazy state from the embedded ``state``."""
        y = self.linear_embedding(state.state)
        return self.lc.lazy_init(y)


if __name__ == "__main__":
    # Smoke test: encode two positions into 32-dimensional vectors.
    width = 32
    encoder = FourierPositionalEncoding(width)
    encoded = encoder(jnp.array([0.0, 1.0]))

    assert encoded.shape == (2, 32)
    jax.debug.print("{}", encoded)
