import jax
import jax.numpy as jnp
import flax.linen as nn
from typing import Tuple

class BaseMLP(nn.Module):
    """Small MLP: two (Dense -> LayerNorm -> activation) stages and a Dense head.

    Submodules are created inline under ``@nn.compact``, so their parameter
    names are auto-generated by Flax in creation order.
    """
    # Declared input width. Not read by ``__call__``; the Dense layers are
    # constructed from their output widths only.
    input_dim: int
    # Number of features produced by the final Dense layer.
    output_dim: int
    # Number of features produced by each of the two hidden Dense layers.
    hdim: int
    # Name of the activation function, resolved as an attribute of ``jax.nn``.
    activ: str = 'elu'

    @nn.compact
    def __call__(self, x: jnp.ndarray, train: bool = False):
        """Apply the network to ``x`` and return the output of the last Dense layer.

        Args:
            x: Input array fed to the first Dense layer.
            train: Accepted for interface compatibility; not used by this body.

        Returns:
            The result of the final ``nn.Dense(self.output_dim)`` layer.
        """
        nonlinearity = getattr(jax.nn, self.activ)

        hidden = x
        # Two identical hidden stages: Dense -> LayerNorm -> activation.
        for _ in range(2):
            hidden = nn.Dense(self.hdim)(hidden)
            hidden = nn.LayerNorm()(hidden)
            hidden = nonlinearity(hidden)

        # Linear output head (no normalisation or activation afterwards).
        return nn.Dense(self.output_dim)(hidden)

class Encoder(nn.Module):
    """Encode a state and an action, then predict done / next-state-code / reward.

    The state is encoded either by a small convolutional stack
    (``pixel_obs=True``) or by a ``BaseMLP`` (``pixel_obs=False``). The action
    is encoded by a single Dense layer plus activation. Both codes are
    concatenated, passed through a second ``BaseMLP``, and a final Dense layer
    produces ``zs_dim + 2`` outputs that are split into three heads.

    All submodules are declared in ``setup()``; Flax does not allow creating
    submodules inside non-``@compact`` methods of a ``setup()``-style module.
    """
    # Width of the (vector) state input; forwarded to BaseMLP as ``input_dim``.
    state_dim: int
    # Width of the action input. Not read by this class's methods.
    action_dim: int
    # Select the convolutional state encoder instead of the MLP one.
    pixel_obs: bool = False
    # Width of the state code and of the predicted-state head.
    zs_dim: int = 512
    # Width of the action code.
    za_dim: int = 256
    # Width of the joint state-action code.
    zsa_dim: int = 512
    # Hidden width used by both BaseMLP submodules.
    hdim: int = 512
    # Name of the activation function, resolved as an attribute of ``jax.nn``.
    activ: str = 'elu'

    def setup(self):
        """Declare the submodules used by the selected state-encoder path."""
        # NOTE: dataclass fields such as ``zs_dim`` must not be re-assigned
        # here; Flax raises SetAttributeInModuleSetupError if they are.
        if self.pixel_obs:
            self.zs_cnn1 = nn.Conv(32, kernel_size=(3, 3), strides=(2, 2))
            self.zs_cnn2 = nn.Conv(32, kernel_size=(3, 3), strides=(2, 2))
            self.zs_cnn3 = nn.Conv(32, kernel_size=(3, 3), strides=(2, 2))
            self.zs_cnn4 = nn.Conv(32, kernel_size=(3, 3), strides=(1, 1))
            # LayerNorm applied to the flattened conv features in ``cnn_zs``.
            # It has to live here because ``cnn_zs`` is not ``@nn.compact``.
            self.zs_ln = nn.LayerNorm()
            self.zs_lin = nn.Dense(self.zs_dim)
        else:
            self.zs_mlp = BaseMLP(
                input_dim=self.state_dim,
                output_dim=self.zs_dim,
                hdim=self.hdim,
                activ=self.activ
            )

        self.za = nn.Dense(self.za_dim)
        self.zsa = BaseMLP(
            input_dim=self.zs_dim + self.za_dim,
            output_dim=self.zsa_dim,
            hdim=self.hdim,
            activ=self.activ
        )
        self.model = nn.Dense(self.zs_dim + 2)  # +2 for done and scalar reward

    def __call__(self, state: jnp.ndarray, action: jnp.ndarray, train: bool = False) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
        """Run the full encoder and split the prediction head.

        Args:
            state: State input; routed to ``cnn_zs`` when ``pixel_obs`` is set,
                otherwise to ``mlp_zs``.
            action: Action input fed to the action Dense layer.
            train: Forwarded to the state encoder and to the ``zsa`` MLP.

        Returns:
            A tuple ``(done, zs, reward)`` sliced from the last axis of the
            head output: index 0, indices ``1 .. zs_dim``, and the remainder.
        """
        activ = getattr(jax.nn, self.activ)

        # Get state encoding
        if self.pixel_obs:
            zs = self.cnn_zs(state, train)
        else:
            zs = self.mlp_zs(state, train)

        # Get action encoding
        za = activ(self.za(action))

        # Get state-action encoding
        zsa = self.zsa(jnp.concatenate([zs, za], -1), train)

        # Get predictions; slice only the last axis so any number of leading
        # batch axes is accepted (identical to ``[:, ...]`` for 2-D input).
        dzr = self.model(zsa)
        done = dzr[..., 0:1]
        next_zs = dzr[..., 1:self.zs_dim + 1]
        reward = dzr[..., self.zs_dim + 1:]
        return done, next_zs, reward

    def cnn_zs(self, state: jnp.ndarray, train: bool) -> jnp.ndarray:
        """Encode a batch of image observations into a state code.

        The input is rescaled with ``state / 255. - 0.5``, passed through four
        conv + activation stages, flattened per batch element, layer-normalised
        and projected to ``zs_dim`` followed by the activation.

        Args:
            state: Image batch with the batch axis first.
                NOTE(review): presumably uint8-range pixels in Flax's default
                channels-last layout — confirm against the caller.
            train: Unused here; kept so both encoder paths share a signature.

        Returns:
            Array whose first axis is ``state.shape[0]`` and whose last axis
            has ``zs_dim`` features.
        """
        activ = getattr(jax.nn, self.activ)

        x = state / 255. - 0.5
        for conv in (self.zs_cnn1, self.zs_cnn2, self.zs_cnn3, self.zs_cnn4):
            x = activ(conv(x))
        # Flatten everything except the leading batch axis.
        x = x.reshape(state.shape[0], -1)
        x = self.zs_ln(x)
        return activ(self.zs_lin(x))

    def mlp_zs(self, state: jnp.ndarray, train: bool) -> jnp.ndarray:
        """Encode a vector state with the ``zs_mlp`` submodule."""
        return self.zs_mlp(state, train)
