from flax import nnx
from jax import numpy as jnp
from jax.nn import softplus

from offline.modules.actor.base import EPS
from offline.modules.mlp import MLPEnsemble
from offline.types import ArrayLike


class DeterministicActorEnsemble(nnx.Module):
    """Deterministic actor built on a single ``MLPEnsemble``.

    Calling the module forwards the observations through the ensemble and
    returns its outputs as actions, paired with ``None`` in the second slot.
    """

    def __init__(
        self,
        action_dim: int,
        ensemble_size: int,
        observation_dim: int,
        rngs: nnx.Rngs,
        **kwargs
    ):
        """Create the underlying ensemble.

        Args:
            action_dim: Passed to ``MLPEnsemble`` as ``out_features``.
            ensemble_size: Passed to ``MLPEnsemble`` as ``ensemble_size``.
            observation_dim: Passed to ``MLPEnsemble`` as ``in_features``.
            rngs: Random number generators handed to ``MLPEnsemble``.
            **kwargs: Forwarded unchanged to ``MLPEnsemble``.
        """
        self.model = MLPEnsemble(
            ensemble_size=ensemble_size,
            in_features=observation_dim,
            out_features=action_dim,
            rngs=rngs,
            **kwargs
        )

    def __call__(self, observations: ArrayLike):
        """Return ``(actions, None)`` for the given observations.

        The second element is always ``None``; the Gaussian actors in this
        module return standard deviations in that position.
        """
        # Presumably shaped [ensemble_size, ..., action_dim] -- confirm
        # against MLPEnsemble's output layout.
        return self.model(observations), None


class DeterministicActorEnsembleWithIndices(nnx.Module):
    """Deterministic actor ensemble that returns one member's action per input.

    The observations are run through an ``MLPEnsemble`` and, for every
    leading position, ``indices`` selects which entry along axis 0 of the
    ensemble output is returned.
    """

    def __init__(
        self,
        action_dim: int,
        ensemble_size: int,
        observation_dim: int,
        rngs: nnx.Rngs,
        **kwargs
    ):
        """Create the underlying ensemble.

        Args:
            action_dim: Passed to ``MLPEnsemble`` as ``out_features``.
            ensemble_size: Passed to ``MLPEnsemble`` as ``ensemble_size``.
            observation_dim: Passed to ``MLPEnsemble`` as ``in_features``.
            rngs: Random number generators handed to ``MLPEnsemble``.
            **kwargs: Forwarded unchanged to ``MLPEnsemble``.

        Raises:
            NotImplementedError: If ``out_axis`` is supplied through
                ``kwargs`` with a value other than 0.
        """
        # __call__ gathers along axis 0, so any other out_axis forwarded to
        # MLPEnsemble would make the gather select along the wrong axis.
        # This is the same restriction GaussianActorEnsembleWithIndices
        # enforces.
        if kwargs.get("out_axis", 0) != 0:
            raise NotImplementedError("out_axis has to be 0")

        self.model = MLPEnsemble(
            ensemble_size=ensemble_size,
            in_features=observation_dim,
            out_features=action_dim,
            rngs=rngs,
            **kwargs
        )

    def __call__(self, indices: ArrayLike, observations: ArrayLike):
        """Return ``(actions, None)`` using the ensemble entries in ``indices``.

        Args:
            indices: Integer positions along axis 0 of the ensemble output;
                presumably one per leading (batch) position of
                ``observations`` -- confirm against callers.
            observations: Inputs forwarded to the ensemble.

        Returns:
            A tuple ``(actions, None)`` where ``actions`` is the ensemble
            output with axis 0 gathered by ``indices`` and then removed.
        """
        # Add a leading axis (to gather along) and a trailing axis (to
        # broadcast over the feature dimension): [1, ..., 1]
        indices = jnp.expand_dims(indices, (0, -1))
        # Presumably [ensemble_size, ..., action_dim] -- confirm against
        # MLPEnsemble's output layout.
        actions = self.model(observations)
        # Gather along axis 0, then drop that now size-1 axis:
        # [..., action_dim]
        actions = jnp.squeeze(
            jnp.take_along_axis(actions, indices, axis=0), axis=0
        )
        return actions, None


class GaussianActorEnsemble(nnx.Module):
    """Gaussian actor built on a single ``MLPEnsemble``.

    The ensemble emits ``action_dim * 2`` features. The last axis is split
    in half: the first half is returned as means, and the second half is
    passed through ``softplus`` and offset by ``eps`` to give the standard
    deviations.
    """

    def __init__(
        self,
        action_dim: int,
        ensemble_size: int,
        observation_dim: int,
        rngs: nnx.Rngs,
        eps: float = EPS,
        **kwargs
    ):
        """Create the underlying ensemble and store ``eps``.

        Args:
            action_dim: Half of the ensemble's ``out_features``.
            ensemble_size: Passed to ``MLPEnsemble`` as ``ensemble_size``.
            observation_dim: Passed to ``MLPEnsemble`` as ``in_features``.
            rngs: Random number generators handed to ``MLPEnsemble``.
            eps: Constant added to the standard deviations after softplus.
            **kwargs: Forwarded unchanged to ``MLPEnsemble``.
        """
        self.eps = eps
        # Twice the action size: one half for means, one half for stds.
        self.model = MLPEnsemble(
            ensemble_size=ensemble_size,
            in_features=observation_dim,
            out_features=action_dim * 2,
            rngs=rngs,
            **kwargs
        )

    def __call__(self, observations: ArrayLike):
        """Return ``(means, stds)`` for the given observations."""
        raw = self.model(observations)
        means, pre_stds = jnp.split(raw, 2, axis=-1)
        # softplus maps to non-negative values; eps is added on top,
        # presumably to keep the stds away from zero.
        return means, softplus(pre_stds) + self.eps


class GaussianActorEnsembleWithIndices(nnx.Module):
    """Gaussian actor ensemble that returns one member's output per input.

    The observations are run through an ``MLPEnsemble`` emitting
    ``action_dim * 2`` features, ``indices`` selects an entry along axis 0
    for every leading position, and the selected features are split into
    means and (softplus + ``eps``) standard deviations.
    """

    def __init__(
        self,
        action_dim: int,
        ensemble_size: int,
        observation_dim: int,
        rngs: nnx.Rngs,
        eps: float = EPS,
        out_axis: int = 0,
        **kwargs
    ):
        """Create the underlying ensemble and store ``eps``.

        Args:
            action_dim: Half of the ensemble's ``out_features``.
            ensemble_size: Passed to ``MLPEnsemble`` as ``ensemble_size``.
            observation_dim: Passed to ``MLPEnsemble`` as ``in_features``.
            rngs: Random number generators handed to ``MLPEnsemble``.
            eps: Constant added to the standard deviations after softplus.
            out_axis: Passed to ``MLPEnsemble``; only 0 is accepted.
            **kwargs: Forwarded unchanged to ``MLPEnsemble``.

        Raises:
            NotImplementedError: If ``out_axis`` is not 0.
        """
        # __call__ gathers along axis 0, so no other layout is supported.
        if out_axis != 0:
            raise NotImplementedError("out_axis has to be 0")

        self.eps = eps
        # Twice the action size: one half for means, one half for stds.
        self.model = MLPEnsemble(
            ensemble_size=ensemble_size,
            in_features=observation_dim,
            out_axis=out_axis,
            out_features=action_dim * 2,
            rngs=rngs,
            **kwargs
        )

    def __call__(self, indices: ArrayLike, observations: ArrayLike):
        """Return ``(means, stds)`` using the ensemble entries in ``indices``.

        Args:
            indices: Integer positions along axis 0 of the ensemble output;
                presumably one per leading (batch) position of
                ``observations`` -- confirm against callers.
            observations: Inputs forwarded to the ensemble.
        """
        # Leading axis to gather along, trailing axis to broadcast over the
        # feature dimension: [1, ..., 1]
        selector = jnp.expand_dims(indices, (0, -1))
        # Presumably [ensemble_size, ..., action_dim * 2] -- confirm against
        # MLPEnsemble's output layout.
        per_member = self.model(observations)
        # Pick the requested entry, then drop the size-1 gather axis:
        # [..., action_dim * 2]
        gathered = jnp.take_along_axis(per_member, selector, axis=0)
        selected = jnp.squeeze(gathered, axis=0)
        means, pre_stds = jnp.split(selected, 2, axis=-1)
        # softplus maps to non-negative values; eps is added on top,
        # presumably to keep the stds away from zero.
        return means, softplus(pre_stds) + self.eps
