from typing import Callable, Optional, Sequence, Tuple

import jax
import jax.numpy as jnp
from flax import linen as nn

from src.models.common import MLP


class RND(nn.Module):
    """Pair of identically configured MLPs over (observation, action) inputs.

    Both networks receive the same concatenated input and produce
    ``feature_dims``-sized outputs; the second network's output is wrapped
    in ``jax.lax.stop_gradient``.

    NOTE(review): the name suggests Random Network Distillation (trainable
    predictor vs. fixed random target) -- confirm against the training code.
    """

    # Widths of the hidden layers shared by both MLPs.
    hidden_dims: Sequence[int]
    # Activation callable forwarded to both MLPs.
    activations: Callable[[jax.Array], jax.Array] = nn.relu
    # Forwarded to MLP as ``layernorm``.
    layernorm: Optional[bool] = False
    # Forwarded to MLP as ``dropout_rate``.
    dropout_rate: Optional[float] = None
    # Size of the final (output) layer appended after ``hidden_dims``.
    feature_dims: Optional[int] = 64

    @nn.compact
    def __call__(
        self,
        observations: jax.Array,
        actions: jax.Array,
        training: bool = True,
    ) -> Tuple[jax.Array, jax.Array]:
        """Embed the (observation, action) pair with both networks.

        Args:
            observations: Observation array; joined with ``actions`` along
                the last axis.
            actions: Action array; must be concatenable with
                ``observations`` along the last axis.
            training: Forwarded unchanged to both MLPs (presumably toggles
                dropout -- confirm in ``MLP``).

        Returns:
            A ``(pred, target)`` tuple: the first MLP's output, and the
            second MLP's output with gradients stopped.
        """
        features_in = jnp.concatenate([observations, actions], -1)

        # One shared configuration so the two networks cannot drift apart.
        layer_sizes = (*self.hidden_dims, self.feature_dims)
        mlp_kwargs = dict(
            activations=self.activations,
            layernorm=self.layernorm,
            dropout_rate=self.dropout_rate,
        )

        # Keep the predictor constructed before the target: compact
        # submodules are auto-named by creation order, so swapping them
        # would change the parameter names.
        predictor = MLP(layer_sizes, **mlp_kwargs)
        pred = predictor(features_in, training=training)

        target_net = MLP(layer_sizes, **mlp_kwargs)
        target = target_net(features_in, training=training)

        # No gradient flows back through the target output.
        return pred, jax.lax.stop_gradient(target)
