from flax import nnx
from jax import numpy as jnp

from offline.modules.mlp import MLP, MLPEnsemble
from offline.types import ArrayLike


class QCritic(nnx.Module):
    """State-action critic backed by a single MLP with one output unit.

    The network input is ``observations`` and ``actions`` joined along their
    last axis, so the MLP is built with ``observation_dim + action_dim`` input
    features. Any extra keyword arguments are forwarded unchanged to ``MLP``.
    """

    def __init__(
        self, action_dim: int, observation_dim: int, rngs: nnx.Rngs, **kwargs
    ) -> None:
        joint_dim = observation_dim + action_dim
        self.model = MLP(
            in_features=joint_dim, out_features=1, rngs=rngs, **kwargs
        )

    def __call__(
        self, observations: ArrayLike, actions: ArrayLike
    ) -> jnp.ndarray:
        """Score each observation/action pair with the MLP.

        The MLP's trailing size-1 output axis is removed with ``jnp.squeeze``
        (which raises if that axis is not of size 1). Leading dimensions are
        presumably preserved by ``MLP`` -- confirm against its definition.
        """
        joint = jnp.concatenate([observations, actions], axis=-1)
        return jnp.squeeze(self.model(joint), axis=-1)


class QCriticEnsemble(nnx.Module):
    """Ensemble of state-action critics backed by ``MLPEnsemble``.

    ``ensemble_size`` is passed straight to ``MLPEnsemble``; each member sees
    the last-axis concatenation of observations and actions, and has a single
    output unit. Extra keyword arguments are forwarded to ``MLPEnsemble``.
    """

    def __init__(
        self,
        action_dim: int,
        ensemble_size: int,
        observation_dim: int,
        rngs: nnx.Rngs,
        **kwargs
    ) -> None:
        joint_dim = observation_dim + action_dim
        self.model = MLPEnsemble(
            ensemble_size=ensemble_size,
            in_features=joint_dim,
            out_features=1,
            rngs=rngs,
            **kwargs
        )

    def __call__(
        self, observations: ArrayLike, actions: ArrayLike
    ) -> jnp.ndarray:
        """Evaluate every ensemble member on the observation/action pairs.

        Only the trailing size-1 output axis is squeezed. Where the ensemble
        axis ends up in the result is decided by ``MLPEnsemble`` (not visible
        here) -- NOTE(review): confirm the layout before indexing members.
        """
        joint = jnp.concatenate([observations, actions], axis=-1)
        return jnp.squeeze(self.model(joint), axis=-1)


class QCriticPair(QCriticEnsemble):
    """``QCriticEnsemble`` with the ensemble size fixed at two members.

    Supplying ``ensemble_size`` through ``kwargs`` is rejected with a
    ``TypeError`` (duplicate keyword), since the size is set here.
    """

    def __init__(
        self, action_dim: int, observation_dim: int, rngs: nnx.Rngs, **kwargs
    ) -> None:
        super().__init__(
            ensemble_size=2,
            action_dim=action_dim,
            observation_dim=observation_dim,
            rngs=rngs,
            **kwargs
        )


class VCritic(nnx.Module):
    """State-only critic: a single MLP mapping observations to one output.

    Extra keyword arguments are forwarded unchanged to ``MLP``.
    """

    def __init__(
        self, observation_dim: int, rngs: nnx.Rngs, **kwargs
    ) -> None:
        self.model = MLP(
            in_features=observation_dim,
            out_features=1,
            rngs=rngs,
            **kwargs
        )

    def __call__(self, observations: ArrayLike) -> jnp.ndarray:
        """Score observations, dropping the MLP's trailing size-1 axis."""
        raw = self.model(observations)
        return jnp.squeeze(raw, axis=-1)


class VCriticEnsemble(nnx.Module):
    """Ensemble of state-only critics backed by ``MLPEnsemble``.

    Each member takes ``observation_dim`` input features and has one output
    unit; extra keyword arguments are forwarded to ``MLPEnsemble``.
    """

    def __init__(
        self, ensemble_size: int, observation_dim: int, rngs: nnx.Rngs, **kwargs
    ) -> None:
        self.model = MLPEnsemble(
            in_features=observation_dim,
            out_features=1,
            ensemble_size=ensemble_size,
            rngs=rngs,
            **kwargs
        )

    def __call__(self, observations: ArrayLike) -> jnp.ndarray:
        """Evaluate every ensemble member on ``observations``.

        Only the trailing size-1 output axis is squeezed; the position of the
        ensemble axis is determined by ``MLPEnsemble`` -- NOTE(review): confirm.
        """
        raw = self.model(observations)
        return jnp.squeeze(raw, axis=-1)
