import flax.nnx as nnx
import jax
import jax.numpy as jnp

# MLP


class MLP(nnx.Module):
    """Three-layer perceptron with leaky-ReLU activations.

    Layout::

        Linear(din -> dmid) -> leaky_relu
        -> Linear(dmid -> dmid) -> leaky_relu
        -> [LayerNorm(dmid), only if ``layer_norm``]
        -> Linear(dmid -> dout)
    """

    def __init__(
        self,
        din: int,
        dmid: int,
        dout: int,
        layer_norm: bool = False,
        *,
        rngs: nnx.Rngs,
    ):
        """Create the layers.

        Args:
            din: Input feature size.
            dmid: Width of both hidden layers.
            dout: Output feature size.
            layer_norm: If true, insert a ``LayerNorm`` before the output layer.
            rngs: RNG streams shared by every submodule for initialisation.
        """
        # Construction order is kept as-is: all layers share one ``rngs`` object,
        # so reordering could change the initial weights.
        self.linear1 = nnx.Linear(din, dmid, rngs=rngs)
        self.linear2 = nnx.Linear(dmid, dmid, rngs=rngs)
        # Layer norm seems unhelpful here for some environments, so it is opt-in.
        self._layer_norm = layer_norm
        if layer_norm:
            self.layer_norm = nnx.LayerNorm(dmid, rngs=rngs)
        self.linear3 = nnx.Linear(dmid, dout, rngs=rngs)

    def __call__(self, x: jax.Array):
        """Apply the network to ``x`` (features on the last axis)."""
        hidden = x
        for layer in (self.linear1, self.linear2):
            hidden = nnx.leaky_relu(layer(hidden))
        if self._layer_norm:
            hidden = self.layer_norm(hidden)
        return self.linear3(hidden)


# LSTM


def create_cell(in_size: int, hidden_size: int, *, rngs: nnx.Rngs | None = None):
    """Build the LSTM cell used by ``LSTMModel``.

    Args:
        in_size: Number of input features per time step.
        hidden_size: Size of the hidden and cell state vectors.
        rngs: Optional RNG streams for parameter initialisation. When omitted,
            ``nnx.Rngs(0)`` is used, which reproduces the previous hard-coded
            seed exactly.

    Returns:
        An ``nnx.OptimizedLSTMCell`` mapping ``in_size`` inputs to
        ``hidden_size`` states.
    """
    if rngs is None:
        # Preserve the historical fixed seed so existing results stay reproducible.
        rngs = nnx.Rngs(0)
    return nnx.OptimizedLSTMCell(in_size, hidden_size, rngs=rngs)


def scan_fn(cell: nnx.LSTMCell):
    """Wrap ``cell`` as a step function for ``jax.lax.scan``.

    Flax NNX LSTM cells take a carry ordered ``(c, h)`` (cell state first,
    hidden state second) and return ``((new_c, new_h), new_h)``. That matches
    ``jax.lax.scan``'s ``(carry, y)`` contract directly: the new ``(c, h)``
    tuple is threaded to the next step and the new hidden state is emitted as
    the per-step output.

    Args:
        cell: The LSTM cell to apply at each time step. ``create_cell`` actually
            supplies an ``nnx.OptimizedLSTMCell``, which uses the same carry
            convention despite the annotation.

    Returns:
        ``step(carry, x_t) -> ((c_t, h_t), h_t_out)``.
    """

    def inner_scan_op(carry, x_t):
        # NOTE: the order is (c, h), not (h, c). Swapping it silently
        # feeds the cell state in as the hidden state.
        c_prev, h_prev = carry

        (c_t, h_t), y_t = cell((c_prev, h_prev), x_t)

        # New carry for the next step, and this step's hidden-state output.
        return (c_t, h_t), y_t

    return inner_scan_op


@nnx.vmap(in_axes=(None, 0), out_axes=0)
def process_batch(cell: nnx.LSTMCell, sequence_batch: jax.Array):
    """Run ``cell`` over a batch of sequences starting from a zero state.

    ``nnx.vmap`` maps over axis 0 of ``sequence_batch`` (the batch axis) while
    sharing ``cell`` across the batch (``in_axes=None``). Inside the mapped
    function, each sequence is scanned along its own leading (time) axis.

    Args:
        cell: LSTM cell whose carry is ordered ``(c, h)``.
        sequence_batch: Inputs laid out as (batch, seq, ...). Presumably
            (batch, seq, features); confirm against callers.

    Returns:
        ``(all_h_states, final_h)``: the hidden state at every step, stacked
        along time and batched on axis 0, and the hidden state after the last
        step.
    """
    # The cell's carry is ordered (c, h); both start at zero.
    c0 = jnp.zeros(cell.hidden_features)
    h0 = jnp.zeros(cell.hidden_features)

    (_, final_h), all_h_states = jax.lax.scan(
        f=scan_fn(cell), init=(c0, h0), xs=sequence_batch
    )

    return all_h_states, final_h


class LSTMModel(nnx.Module):
    """LSTM encoder followed by a linear head applied to the final hidden state."""

    def __init__(
        self,
        in_size: int,
        hidden_size: int,
        out_size: int,
        *,
        rngs: nnx.Rngs | None = None,
    ):
        """Create the LSTM cell and the output layer.

        Args:
            in_size: Number of input features per time step.
            hidden_size: Size of the LSTM hidden state.
            out_size: Size of the model output.
            rngs: Optional RNG streams shared by the cell and the linear head.
                When omitted, the historical fixed seeds are used
                (``nnx.Rngs(0)`` for the cell via ``create_cell`` and
                ``nnx.Rngs(1)`` for the linear head), so default initialisation
                is unchanged.
        """
        if rngs is None:
            self.lstm_cell = create_cell(in_size, hidden_size)
            # Use a separate rngs for Linear
            self.linear = nnx.Linear(hidden_size, out_size, rngs=nnx.Rngs(1))
        else:
            self.lstm_cell = nnx.OptimizedLSTMCell(in_size, hidden_size, rngs=rngs)
            self.linear = nnx.Linear(hidden_size, out_size, rngs=rngs)

    def __call__(self, x: jax.Array):
        """Encode ``x`` with the LSTM and project the last hidden state.

        Args:
            x: Either a single sequence ``(seq, features)`` or a batch
                ``(batch, seq, features)``.

        Returns:
            An array of shape ``(batch, out_size)``. A 2-D (unbatched) input is
            treated as a batch of one, so it yields ``(1, out_size)``.

        Raises:
            ValueError: If ``x`` is neither 2-D nor 3-D.
        """
        if x.ndim == 2:
            # Single sequence: add a leading batch axis of size 1.
            x = jnp.expand_dims(x, axis=0)
        elif x.ndim != 3:
            raise ValueError(
                "LSTMModel expects a (seq, features) or (batch, seq, features) "
                f"input, got shape {x.shape}"
            )

        _, final_h_state = process_batch(self.lstm_cell, x)

        return self.linear(final_h_state)
