import jax
import jax.numpy as jnp
import equinox as eqx
from typing import Tuple

from pomdp.utils import js_divergence
import nn

class RNNBeliefModel(eqx.Module):
    """Recurrent belief model: an LSTM stack followed by an MLP read-out head.

    The recurrent state ``h`` is threaded explicitly through every method.
    ``predict`` and ``predict_seq`` read the head input from
    ``h[-1, :hidden_size]``, i.e. the last entry along the first state axis,
    truncated to the first ``hidden_size`` features.
    """

    # Recurrent core; called as ``rnn(obs, h)`` and unpacked as ``(out, h_new)``.
    rnn: nn.MultiLayerLSTM
    # Read-out network mapping recurrent features to logits.
    head: eqx.nn.MLP

    def __call__(self, obs: jax.Array, h: jax.Array) -> Tuple[jax.Array, jax.Array]:
        """Advance the state by one observation and return ``(logits, h_new)``.

        The logits are the head applied to the first output of ``update``;
        no softmax is applied here.
        """
        out, h_new = self.update(obs, h)
        return self.head(out), h_new

    def reset(self, obs_shape: int) -> jax.Array:
        """Return an all-zero initial state.

        Args:
            obs_shape: Size of the leading axis of the returned state. ``train``
                passes ``obss.shape[0]`` here and the result is vmapped over
                that axis, so despite the name it acts as a batch size.

        Returns:
            Zeros of shape ``(obs_shape, num_layers, hidden_size * 2)``.
        """
        # NOTE(review): the factor 2 presumably packs the LSTM hidden and cell
        # states side by side -- confirm against nn.MultiLayerLSTM.
        return jnp.zeros((obs_shape, self.rnn.num_layers, self.rnn.hidden_size * 2))

    def predict(self, h: jax.Array) -> jax.Array:
        """Return the softmax of the head applied to ``h[-1, :hidden_size]``.

        ``h`` is indexed with two indices, so it is expected to be a single
        (unbatched) state -- presumably ``(num_layers, hidden_size * 2)``.
        """
        return jax.nn.softmax(self.head(h[-1, :self.rnn.hidden_size]))

    def update(self, obs: jax.Array, h: jax.Array) -> Tuple[jax.Array, jax.Array]:
        """Run the recurrent core on one observation.

        Returns whatever ``self.rnn(obs, h)`` returns; callers in this class
        unpack it as ``(out, h_new)``.
        """
        return self.rnn(obs, h)

    def predict_seq(self, obss: jax.Array, initial_h: jax.Array) -> Tuple[jax.Array, jax.Array]:
        """Scan over the leading axis of ``obss`` and collect per-step logits.

        Args:
            obss: Observations; ``jax.lax.scan`` iterates over the leading axis.
            initial_h: State used as the scan carry for the first step.

        Returns:
            ``(logits_seq, h_final)``: the stacked per-step head outputs
            (logits, no softmax applied) and the state after the last step.
        """

        def _scan_fn(h, obs):
            # The rnn's direct output is discarded; logits are read from the
            # updated state, using the same slice as ``predict``.
            _, h_new = self.update(obs, h)
            return h_new, self.head(h_new[-1, :self.rnn.hidden_size])

        h_final, logits_seq = jax.lax.scan(_scan_fn, initial_h, obss)
        return logits_seq, h_final
    

@eqx.filter_value_and_grad
def seq_loss(model: RNNBeliefModel, obss: jax.Array, h0: jax.Array, targets: jax.Array) -> float:
    """Mean JS divergence between predicted and target distributions.

    ``model.predict_seq`` is vmapped over the leading axis of ``obss`` and
    ``h0``. The resulting logits and ``targets`` are both flattened to rows of
    width ``targets.shape[-1]``, the logits are passed through a softmax, and
    ``js_divergence`` is evaluated row by row before averaging.

    Because of the ``eqx.filter_value_and_grad`` decorator, calling
    ``seq_loss`` yields a ``(loss, grads)`` pair, with gradients taken with
    respect to ``model`` (the first argument).
    """
    width = targets.shape[-1]

    # Final recurrent states are not needed for the loss.
    batched_logits, _ = jax.vmap(model.predict_seq)(obss, h0)

    # Collapse every leading axis so each row is one distribution.
    flat_logits = batched_logits.reshape((-1, width))
    flat_targets = targets.reshape(-1, width)

    flat_probs = jax.nn.softmax(flat_logits, axis=-1)
    row_divergence = jax.vmap(js_divergence)(flat_probs, flat_targets)
    return jnp.mean(row_divergence)

def train(model, opt, opt_state, obss, targets):
    """Perform one gradient step on ``model`` and return the updated state.

    Args:
        model: Model exposing ``reset``; differentiated by ``seq_loss``.
        opt: Optimizer exposing ``update(grads, opt_state)``.
        opt_state: Current optimizer state.
        obss: Observations; ``obss.shape[0]`` sizes the initial state.
        targets: Target distributions forwarded to ``seq_loss``.

    Returns:
        ``(model, opt_state, loss_value)`` after applying the update.
    """
    # Fresh zero state, one entry per leading-axis element of ``obss``.
    initial_state = model.reset(obss.shape[0])
    loss_value, gradients = seq_loss(model, obss, initial_state, targets)

    # NOTE(review): only grads and state are passed to ``opt.update``; an
    # optimizer that needs the current parameters (e.g. weight decay) would
    # require them as an extra argument -- confirm which optimizers are used.
    step, next_opt_state = opt.update(gradients, opt_state)
    updated_model = eqx.apply_updates(model, step)
    return updated_model, next_opt_state, loss_value