import argparse

import equinox as eqx
import jax
import jax.numpy as jnp

from flows import nn
from gridworld.utils import js_divergence


class RNNBeliefModel(eqx.Module):
    """Recurrent belief model: a multi-layer LSTM followed by an MLP head.

    For a single sequence the recurrent state ``h`` is indexed as
    ``h[layer, feature]`` with ``2 * rnn.hidden_size`` features per layer
    (see ``reset``). The head reads the first ``rnn.hidden_size`` features of
    the last layer.

    NOTE(review): presumably the first half of the feature axis is the LSTM
    hidden state and the second half the cell state -- confirm against
    ``nn.MultiLayerLSTM``.
    """

    # Recurrent backbone; called as ``rnn(obs, h)`` and returns ``(out, h_new)``.
    rnn: nn.MultiLayerLSTM
    # Read-out network applied to the recurrent features.
    head: eqx.nn.MLP

    def __call__(self, obs: jax.Array, h: jax.Array) -> tuple[jax.Array, jax.Array]:
        """Advance the recurrent state by one observation.

        Args:
            obs: Observation for the current step.
            h: Current recurrent state.

        Returns:
            ``(head(out), h_new)``: the head applied to the backbone output
            (no softmax applied) and the updated recurrent state.
        """
        out, h_new = self.update(obs, h)
        return self.head(out), h_new

    def reset(self, obs_shape: int) -> jax.Array:
        """Return an all-zero initial recurrent state.

        Args:
            obs_shape: Size of the leading axis of the returned state.

        Returns:
            Zeros of shape
            ``(obs_shape, rnn.num_layers, 2 * rnn.hidden_size)``.
        """
        return jnp.zeros((obs_shape, self.rnn.num_layers, self.rnn.hidden_size * 2))

    def predict(self, h: jax.Array) -> jax.Array:
        """Return the softmax of the head output for recurrent state ``h``."""
        return jax.nn.softmax(self.head(self._head_input(h)))

    def update(self, obs: jax.Array, h: jax.Array) -> tuple[jax.Array, jax.Array]:
        """Run the backbone for one step and return its ``(out, h_new)`` pair."""
        return self.rnn(obs, h)

    def predict_seq(
        self, obss: jax.Array, initial_h: jax.Array
    ) -> tuple[jax.Array, jax.Array]:
        """Scan the model over a sequence of observations.

        Args:
            obss: Observations, scanned along the leading axis.
            initial_h: Recurrent state used for the first step.

        Returns:
            ``(logits_seq, h_final)``: the head output after every step
            (no softmax applied) and the state after the last step.
        """

        def _scan_fn(h, obs):
            # The backbone's direct output is discarded here; the head is fed
            # from the updated state instead.
            _, h_new = self.update(obs, h)
            return h_new, self.head(self._head_input(h_new))

        h_final, logits_seq = jax.lax.scan(_scan_fn, initial_h, obss)
        return logits_seq, h_final

    def _head_input(self, h: jax.Array) -> jax.Array:
        """Slice the head's input features out of recurrent state ``h``."""
        return h[-1, : self.rnn.hidden_size]


def build_rnn_belief_net(args: argparse.Namespace, key: jax.Array) -> RNNBeliefModel:
    """Construct an ``RNNBeliefModel`` from parsed command-line options.

    Args:
        args: Namespace providing ``rnn_nlayers``, ``rnn_hidden_size``,
            ``grid_size``, ``ndim``, ``mlp_hidden_size`` and ``mlp_nlayers``.
        key: PRNG key; split once so the LSTM and the MLP head are initialised
            from different sub-keys.

    Returns:
        A model whose head maps ``rnn_hidden_size`` features to
        ``grid_size ** ndim`` outputs.
    """
    lstm_key, mlp_key = jax.random.split(key, 2)

    # NOTE(review): the leading ``1`` is presumably the per-step input size --
    # confirm against ``nn.MultiLayerLSTM``.
    lstm = nn.MultiLayerLSTM(1, args.rnn_nlayers, args.rnn_hidden_size, key=lstm_key)

    num_outputs = args.grid_size**args.ndim
    readout = eqx.nn.MLP(
        in_size=args.rnn_hidden_size,
        out_size=num_outputs,
        width_size=args.mlp_hidden_size,
        depth=args.mlp_nlayers,
        key=mlp_key,
    )

    return RNNBeliefModel(rnn=lstm, head=readout)


@eqx.filter_jit
def train(model, opt, opt_state, obss, targets):
    """Run one gradient step on the mean Jensen-Shannon divergence loss.

    Args:
        model: Model exposing ``reset`` and ``predict_seq``.
        opt: Optimiser whose ``update(grads, opt_state)`` returns
            ``(updates, new_opt_state)``.
        opt_state: Current optimiser state.
        obss: Batch of observation sequences; the leading axis is the batch.
        targets: Target distributions; the last axis indexes the classes.

    Returns:
        ``(model, opt_state, loss)`` after the update.
    """

    def _mean_js_loss(net, obs_batch, init_state, target_batch):
        num_classes = target_batch.shape[-1]

        # Unroll every sequence in the batch; the final states are not needed.
        batched_predict = jax.vmap(net.predict_seq)
        logits, _final_state = batched_predict(obs_batch, init_state)

        # Flatten all leading axes so each row is one distribution.
        flat_logits = logits.reshape((-1, num_classes))
        flat_targets = target_batch.reshape(-1, num_classes)

        probs = jax.nn.softmax(flat_logits, axis=-1)
        per_row = jax.vmap(js_divergence)(probs, flat_targets)
        return jnp.mean(per_row)

    # Gradients are taken with respect to the first argument (the model).
    loss_and_grad = eqx.filter_value_and_grad(_mean_js_loss)

    batch_size = obss.shape[0]
    h0 = model.reset(batch_size)
    loss, grads = loss_and_grad(model, obss, h0, targets)

    updates, new_opt_state = opt.update(grads, opt_state)
    new_model = eqx.apply_updates(model, updates)

    return new_model, new_opt_state, loss
