from typing import Callable, Sequence, Tuple

import jax.numpy as jnp
from flax import linen as nn

from koopman.networks.common import MLP

class EncoderState(nn.Module):
    """Map observations to a state representation.

    Applies ``MLP`` (from ``koopman.networks.common``) with hidden widths
    ``hidden_dims`` and ``activate_final=True``, then a linear ``nn.Dense``
    head named ``'final'`` with ``out_dim`` features.
    """

    # Number of output features produced by the final Dense layer.
    out_dim: int = 512
    # Nonlinearity forwarded to the MLP; the final Dense layer applies none.
    activations: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu
    # Hidden-layer widths forwarded to the MLP. Declared after the existing
    # fields so positional construction keeps working; the default equals the
    # previously hard-coded [256, 256].
    hidden_dims: Sequence[int] = (256, 256)

    @nn.compact
    def __call__(self, observations: jnp.ndarray) -> jnp.ndarray:
        """Encode ``observations``.

        Args:
            observations: Input array; presumably shaped
                (batch, observation_dim) — TODO confirm against callers.

        Returns:
            Output of the ``'final'`` Dense layer, whose last axis has size
            ``out_dim``.
        """
        # MLP originally received a list, so convert the (tuple) field back to
        # a list to keep the call identical. With activate_final=True the MLP
        # presumably activates its last hidden layer too (verify in MLP).
        hidden = MLP(
            list(self.hidden_dims),
            activations=self.activations,
            activate_final=True,
        )(observations)
        return nn.Dense(features=self.out_dim, name='final')(hidden)

class DecoderState(nn.Module):
    """Map state representations back to observation space.

    Applies ``MLP`` (from ``koopman.networks.common``) with hidden widths
    ``hidden_dims`` and ``activate_final=True``, then a linear ``nn.Dense``
    head named ``'final'`` with ``observation_dim`` features.
    """

    # Number of output features produced by the final Dense layer.
    observation_dim: int = 17
    # Nonlinearity forwarded to the MLP; the final Dense layer applies none.
    activations: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu
    # Hidden-layer widths forwarded to the MLP. Declared after the existing
    # fields so positional construction keeps working; the default equals the
    # previously hard-coded [256, 256].
    hidden_dims: Sequence[int] = (256, 256)

    @nn.compact
    def __call__(self, state_reps: jnp.ndarray) -> jnp.ndarray:
        """Decode ``state_reps``.

        Args:
            state_reps: Input array; presumably shaped (batch, state_dim),
                e.g. the output of a matching encoder — TODO confirm.

        Returns:
            Output of the ``'final'`` Dense layer, whose last axis has size
            ``observation_dim``.
        """
        # MLP originally received a list, so convert the (tuple) field back to
        # a list to keep the call identical. With activate_final=True the MLP
        # presumably activates its last hidden layer too (verify in MLP).
        hidden = MLP(
            list(self.hidden_dims),
            activations=self.activations,
            activate_final=True,
        )(state_reps)
        return nn.Dense(features=self.observation_dim, name='final')(hidden)
