"""Implementation of Future Predictor in JAX/NNX."""
import jax
import jax.numpy as jnp
import flax.nnx as nnx


class MLP(nnx.Module):
    """Two-layer perceptron: Linear -> GELU -> Linear.

    Args:
        in_dim: Size of the input's last axis.
        out_dim: Size of the output's last axis.
        hidden_dim: Width of the intermediate layer (defaults to 1024).
        rngs: NNX random-number streams used to initialise both layers.
    """

    def __init__(self, in_dim, out_dim, hidden_dim=1024, *, rngs):
        # Attribute names are part of the parameter tree; keep them stable.
        self.dense1 = nnx.Linear(in_dim, hidden_dim, rngs=rngs)
        self.dense2 = nnx.Linear(hidden_dim, out_dim, rngs=rngs)

    def __call__(self, x):
        """Map ``x`` through the hidden layer and GELU, then to ``out_dim``."""
        activated = nnx.gelu(self.dense1(x))
        return self.dense2(activated)


class VModel(nnx.Module):
    """Scalar head on a single input: MLP -> GELU -> Linear(n_embd, 1).

    Args:
        n_embd: Feature width of the input; the inner MLP uses a hidden
            width of ``4 * n_embd`` and maps back to ``n_embd``.
        rngs: NNX random-number streams used to initialise all layers.
    """

    def __init__(self, n_embd, *, rngs):
        expanded = 4 * n_embd
        self.mlp = MLP(n_embd, n_embd, expanded, rngs=rngs)
        self.output_proj = nnx.Linear(n_embd, 1, rngs=rngs)

    def __call__(self, x):
        """Return a tensor whose last axis has size 1."""
        features = nnx.gelu(self.mlp(x))
        return self.output_proj(features)
    
    
class QModel(nnx.Module):
    """Conditioned scalar head: [x, proj(c)] -> MLP -> GELU -> Linear(n_embd, 1).

    Args:
        n_embd: Feature width of ``x`` and of the projected condition; the
            MLP consumes their concatenation (``2 * n_embd``).
        n_cond: Size of the condition input's last axis.
        rngs: NNX random-number streams used to initialise all layers.
    """

    def __init__(self, n_embd, n_cond, *, rngs):
        # Creation order determines how layers draw from ``rngs``.
        self.cond_proj = nnx.Linear(n_cond, n_embd, rngs=rngs)
        self.mlp = MLP(2 * n_embd, n_embd, 4 * n_embd, rngs=rngs)
        self.output_proj = nnx.Linear(n_embd, 1, rngs=rngs)

    def __call__(self, x, c):
        """Score ``x`` given condition ``c``; the last output axis has size 1."""
        cond_embedding = self.cond_proj(c)
        joint = jnp.concatenate((x, cond_embedding), axis=-1)
        hidden = nnx.gelu(self.mlp(joint))
        return self.output_proj(hidden)
    
    
class HierarchicalQModel(nnx.Module):
    """Stack of conditioned MLP levels sharing one scalar output head.

    Level ``i`` projects ``conds[i]`` to ``n_embd`` features and concatenates
    it with the running representation ``x``. The result goes through that
    level's own MLP followed by GELU. The updated ``x`` feeds the next level.
    The single ``output_proj`` is applied after every level, so one
    prediction is produced per level.

    Args:
        n_embd: Feature width of ``x`` and of every projected condition.
        conds: Sequence of input widths, one per level, for the condition
            arrays later passed to ``__call__``.
        rngs: NNX random-number streams used to initialise all layers.
    """

    def __init__(self, n_embd, conds, *, rngs):
        cond_projs = {}
        mlps = {}
        # Build each level's projection and MLP interleaved: the order in
        # which layers draw from ``rngs`` determines their initial weights.
        for i, n_cond in enumerate(conds):
            cond_projs[f'cond_{i}'] = nnx.Linear(n_cond, n_embd, rngs=rngs)
            mlps[f'mlp_{i}'] = MLP(2 * n_embd, n_embd, 4 * n_embd, rngs=rngs)
        # Assign fully populated dicts instead of mutating empty dicts that are
        # already module attributes, so NNX sees the submodules when the
        # attribute is set. NOTE(review): recent Flax NNX releases may require
        # nnx.Dict / nnx.data for containers of submodules — confirm for the
        # pinned Flax version.
        self.cond_projs = cond_projs
        self.mlps = mlps
        self.output_proj = nnx.Linear(n_embd, 1, rngs=rngs)

    def __call__(self, x, conds):
        """Run every level in order and collect one prediction per level.

        Args:
            x: Initial representation. Its last axis must have size ``n_embd``
                so that concatenation with a projected condition matches the
                MLP's ``2 * n_embd`` input.
            conds: Sequence of condition arrays, one per level, in the same
                order as the widths given to ``__init__``.

        Returns:
            A list with one ``output_proj`` result (last axis size 1) per level.

        Raises:
            ValueError: If ``len(conds)`` differs from the number of levels.
        """
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if len(conds) != len(self.cond_projs):
            raise ValueError(
                f'expected {len(self.cond_projs)} conditions, got {len(conds)}'
            )
        outputs = []
        for i, cond in enumerate(conds):
            c = self.cond_projs[f'cond_{i}'](cond)
            x = jnp.concatenate([x, c], axis=-1)
            # This level's output becomes the input representation of the next.
            x = nnx.gelu(self.mlps[f'mlp_{i}'](x))
            # A single output_proj is shared by all levels.
            outputs.append(self.output_proj(x))
        return outputs