from typing import Any

from flax.linen import compact
from jax import Array, numpy as jnp

from offline.modules.base import Module
from offline.modules.mlp import MLP
from offline.types import ArrayLike


def select_action_with_logits(actions: Array, logits: Array):
    """Return the candidate action with the highest logit.

    The last axis of ``logits`` enumerates the candidates; the matching
    candidates axis of ``actions`` is its second-to-last axis, with the
    last axis holding the action components.

    Args:
        actions: Candidate actions, indexed as ``(..., candidate, component)``.
        logits: Scores, indexed as ``(..., candidate)``.

    Returns:
        The selected actions with the candidates axis removed.
    """
    best = jnp.argmax(logits, axis=-1)
    # Append two singleton axes so the index lines up with the
    # (candidate, component) axes of `actions` and broadcasts over components.
    gather_index = best[..., None, None]
    chosen = jnp.take_along_axis(actions, gather_index, axis=-2)
    # The candidates axis now has length one; drop it.
    return chosen[..., 0, :]


class ConditionalDeterministicActor(Module):
    """Actor that maps an (observation, latent) pair to an action via an MLP.

    The observation and latent are joined along their last axis and passed
    through a single ``MLP`` whose ``out_features`` equals ``action_dim``.
    """

    # Passed to the MLP as `out_features`.
    action_dim: int
    # Passed to the MLP as `hidden_features`.
    hidden_features: int = 256
    # Extra keyword arguments forwarded to the MLP constructor; None means none.
    mlp_kwargs: dict[str, Any] | None = None

    @compact
    def __call__(self, observations: ArrayLike, latents: ArrayLike):
        """Compute actions for the given observations and latents.

        Args:
            observations: Observation features; joined with ``latents`` on
                the last axis.
            latents: Conditioning latents; must be concatenable with
                ``observations`` along the last axis.

        Returns:
            The MLP output (presumably with last dimension ``action_dim`` --
            depends on the ``MLP`` implementation).
        """
        extra_kwargs = self.mlp_kwargs if self.mlp_kwargs is not None else {}
        network = MLP(
            hidden_features=self.hidden_features,
            out_features=self.action_dim,
            **extra_kwargs,
        )
        conditioned = jnp.concatenate((observations, latents), axis=-1)
        return network(conditioned)


class ConditionalGaussianActor(Module):
    """Actor that maps an (observation, latent) pair to Gaussian parameters.

    A single ``MLP`` with ``out_features = 2 * action_dim`` is applied to the
    concatenated inputs; its output is split in half along the last axis into
    means and log standard deviations.
    """

    # Half of the MLP's `out_features` (the MLP is built with action_dim * 2).
    action_dim: int
    # Passed to the MLP as `hidden_features`.
    hidden_features: int = 256
    # Upper bound applied to the log-std before exponentiation.
    log_std_max: float = 2
    # Lower bound applied to the log-std before exponentiation.
    log_std_min: float = -20
    # Extra keyword arguments forwarded to the MLP constructor; None means none.
    mlp_kwargs: dict[str, Any] | None = None

    @compact
    def __call__(self, observations: ArrayLike, latents: ArrayLike):
        """Compute the mean and standard deviation for the given inputs.

        Args:
            observations: Observation features; joined with ``latents`` on
                the last axis.
            latents: Conditioning latents; must be concatenable with
                ``observations`` along the last axis.

        Returns:
            A ``(means, stds)`` tuple, where ``stds`` is the exponential of
            the log-std half clipped to ``[log_std_min, log_std_max]``.
        """
        extra_kwargs = self.mlp_kwargs if self.mlp_kwargs is not None else {}
        network = MLP(
            hidden_features=self.hidden_features,
            out_features=self.action_dim * 2,
            **extra_kwargs,
        )
        conditioned = jnp.concatenate((observations, latents), axis=-1)
        # First half of the last axis -> means, second half -> raw log-stds.
        means, raw_log_stds = jnp.split(network(conditioned), 2, axis=-1)
        # Bound the log-std before exponentiating to keep the std in range.
        stds = jnp.exp(
            jnp.clip(raw_log_stds, self.log_std_min, self.log_std_max)
        )
        return means, stds
