from typing import Callable, Sequence, Tuple, Optional
import jax
import jax.numpy as jnp
from flax import linen as nn
from src.models.common import MLP


class Critic(nn.Module):
    """Single critic head over (observation, action) pairs.

    Observations and actions are joined along their last axis, passed through
    one ``MLP`` configured with ``(*hidden_dims, 1)``, and the trailing axis of
    the result is squeezed away.
    """

    # Sizes placed ahead of the final ``1`` in the tuple handed to ``MLP``
    # (presumably hidden-layer widths -- confirm against ``MLP``).
    hidden_dims: Sequence[int]
    # Activation callable forwarded to ``MLP`` unchanged.
    activations: Callable[[jax.Array], jax.Array] = nn.relu
    # Forwarded to ``MLP`` as its ``layernorm`` argument.
    layernorm: Optional[bool] = False
    # Forwarded to ``MLP`` as its ``dropout_rate`` argument.
    dropout_rate: Optional[float] = None

    @nn.compact
    def __call__(
        self,
        observations: jax.Array,
        actions: jax.Array,
        training: bool = True,
    ) -> jax.Array:
        """Evaluate the critic on the given observations and actions.

        Args:
            observations: Array concatenated with ``actions`` on the last axis.
            actions: Array concatenated with ``observations`` on the last axis.
            training: Passed through to ``MLP`` as ``training``.

        Returns:
            The ``MLP`` output with its last axis removed.
        """
        layer_sizes = (*self.hidden_dims, 1)
        network = MLP(
            layer_sizes,
            activations=self.activations,
            layernorm=self.layernorm,
            dropout_rate=self.dropout_rate,
        )

        obs_and_actions = jnp.concatenate([observations, actions], axis=-1)
        values = network(obs_and_actions, training=training)

        # The final entry of ``layer_sizes`` is 1, so the last axis is dropped.
        return jnp.squeeze(values, axis=-1)


class DoubleCritic(nn.Module):
    """Two separately constructed ``Critic`` submodules fed the same inputs."""

    # Passed positionally to each ``Critic`` as ``hidden_dims``.
    hidden_dims: Sequence[int]
    # Forwarded to each ``Critic`` as ``activations``.
    activations: Callable[[jax.Array], jax.Array] = nn.relu
    # Forwarded to each ``Critic`` as ``layernorm``.
    layernorm: Optional[bool] = False
    # Forwarded to each ``Critic`` as ``dropout_rate``.
    dropout_rate: Optional[float] = None

    @nn.compact
    def __call__(
        self,
        observations: jax.Array,
        actions: jax.Array,
        training: bool = True,
    ) -> Tuple[jax.Array, jax.Array]:
        """Evaluate both critics on the same observations and actions.

        Args:
            observations: Forwarded unchanged to every ``Critic``.
            actions: Forwarded unchanged to every ``Critic``.
            training: Forwarded unchanged to every ``Critic``.

        Returns:
            A 2-tuple holding the first and second critic outputs, in the
            order the critics were constructed.
        """
        # Settings shared by both members; only the submodule instance differs.
        member_config = dict(
            activations=self.activations,
            layernorm=self.layernorm,
            dropout_rate=self.dropout_rate,
        )

        first_member = Critic(self.hidden_dims, **member_config)
        first_value = first_member(observations, actions, training=training)

        second_member = Critic(self.hidden_dims, **member_config)
        second_value = second_member(observations, actions, training=training)

        return first_value, second_value


class TripleCritic(nn.Module):
    """Three separately constructed ``Critic`` submodules fed the same inputs."""

    # Passed positionally to each ``Critic`` as ``hidden_dims``.
    hidden_dims: Sequence[int]
    # Forwarded to each ``Critic`` as ``activations``.
    activations: Callable[[jax.Array], jax.Array] = nn.relu
    # Forwarded to each ``Critic`` as ``layernorm``.
    layernorm: Optional[bool] = False
    # Forwarded to each ``Critic`` as ``dropout_rate``.
    dropout_rate: Optional[float] = None

    @nn.compact
    def __call__(
        self,
        observations: jax.Array,
        actions: jax.Array,
        training: bool = True,
    ) -> Tuple[jax.Array, jax.Array, jax.Array]:
        """Evaluate all three critics on the same observations and actions.

        Args:
            observations: Forwarded unchanged to every ``Critic``.
            actions: Forwarded unchanged to every ``Critic``.
            training: Forwarded unchanged to every ``Critic``.

        Returns:
            A 3-tuple of critic outputs, in construction order.
        """
        num_critics = 3

        # Members are created one after another inside the compact method so
        # the creation order (and therefore Flax's automatic submodule naming)
        # is the same as when each critic was written out by hand.
        outputs = []
        for _ in range(num_critics):
            member = Critic(
                self.hidden_dims,
                activations=self.activations,
                layernorm=self.layernorm,
                dropout_rate=self.dropout_rate,
            )
            outputs.append(member(observations, actions, training=training))

        # Unpack explicitly so the fixed-length return annotation holds.
        critic1, critic2, critic3 = outputs
        return critic1, critic2, critic3


class QuadCritic(nn.Module):
    """Four separately constructed ``Critic`` submodules fed the same inputs."""

    # Passed positionally to each ``Critic`` as ``hidden_dims``.
    hidden_dims: Sequence[int]
    # Forwarded to each ``Critic`` as ``activations``.
    activations: Callable[[jax.Array], jax.Array] = nn.relu
    # Forwarded to each ``Critic`` as ``layernorm``.
    layernorm: Optional[bool] = False
    # Forwarded to each ``Critic`` as ``dropout_rate``.
    dropout_rate: Optional[float] = None

    @nn.compact
    def __call__(
        self,
        observations: jax.Array,
        actions: jax.Array,
        training: bool = True,
    ) -> Tuple[jax.Array, jax.Array, jax.Array, jax.Array]:
        """Evaluate all four critics on the same observations and actions.

        Args:
            observations: Forwarded unchanged to every ``Critic``.
            actions: Forwarded unchanged to every ``Critic``.
            training: Forwarded unchanged to every ``Critic``.

        Returns:
            A 4-tuple of critic outputs, in construction order.
        """
        num_critics = 4

        # Members are created one after another inside the compact method so
        # the creation order (and therefore Flax's automatic submodule naming)
        # is the same as when each critic was written out by hand.
        outputs = []
        for _ in range(num_critics):
            member = Critic(
                self.hidden_dims,
                activations=self.activations,
                layernorm=self.layernorm,
                dropout_rate=self.dropout_rate,
            )
            outputs.append(member(observations, actions, training=training))

        # Unpack explicitly so the fixed-length return annotation holds.
        critic1, critic2, critic3, critic4 = outputs
        return critic1, critic2, critic3, critic4


class DecaCritic(nn.Module):
    """Ten separately constructed ``Critic`` submodules fed the same inputs."""

    # Passed positionally to each ``Critic`` as ``hidden_dims``.
    hidden_dims: Sequence[int]
    # Forwarded to each ``Critic`` as ``activations``.
    activations: Callable[[jax.Array], jax.Array] = nn.relu
    # Forwarded to each ``Critic`` as ``layernorm``.
    layernorm: Optional[bool] = False
    # Forwarded to each ``Critic`` as ``dropout_rate``.
    dropout_rate: Optional[float] = None

    @nn.compact
    def __call__(
        self,
        observations: jax.Array,
        actions: jax.Array,
        training: bool = True,
    ) -> Tuple[jax.Array, ...]:
        """Evaluate all ten critics on the same observations and actions.

        Args:
            observations: Forwarded unchanged to every ``Critic``.
            actions: Forwarded unchanged to every ``Critic``.
            training: Forwarded unchanged to every ``Critic``.

        Returns:
            A tuple of ten critic outputs, in construction order.
        """
        num_critics = 10

        # The generator is consumed left to right, so members are created in
        # the same order (and therefore receive the same automatic Flax
        # submodule names) as when each critic was written out by hand.
        return tuple(
            Critic(
                self.hidden_dims,
                activations=self.activations,
                layernorm=self.layernorm,
                dropout_rate=self.dropout_rate,
            )(observations, actions, training=training)
            for _ in range(num_critics)
        )
