import math
from typing import Any, NamedTuple

from flax import nnx
from jax import Array, numpy as jnp

from offline.lbp.tc.modules import select_values
from offline.modules.actor.ensemble import GaussianActorEnsembleWithIndices
from offline.modules.mlp import MLP
from offline.modules.policy import Policy
from offline.utils.logger import Logger


# NNX variable filter that matches ``nnx.Param`` variables whose path contains
# the key "actor" (``nnx.All`` requires every sub-filter to match).
# NOTE(review): presumably used to select only the actor's trainable
# parameters (e.g. for an optimizer or a state split) — confirm at call sites.
ActorFilter = nnx.All(nnx.Param, nnx.PathContains("actor"))


class QCritic(nnx.Module):
    """Q-function with one output head per codebook entry.

    A single ``MLP`` maps the concatenation ``[observations, actions]`` to
    ``codebook_size`` values; ``__call__`` then keeps only the value at the
    codebook index given by ``assignments``.
    """

    def __init__(
        self,
        action_dim: int,
        codebook_size: int,
        observation_dim: int,
        rngs: nnx.Rngs,
        **kwargs,
    ):
        # Input width covers observations followed by actions (the order
        # ``__call__`` concatenates them in); any extra keyword arguments are
        # forwarded to ``MLP`` untouched.
        input_width = observation_dim + action_dim
        self.model = MLP(
            in_features=input_width,
            out_features=codebook_size,
            rngs=rngs,
            **kwargs,
        )

    def __call__(self, assignments: Array, observations: Array, actions: Array):
        """Evaluate the Q-value belonging to each row's assigned codebook entry.

        Args:
            assignments: Codebook indices, shape ``[...]``. They must be
                integers, since they are used as ``take_along_axis`` indices.
            observations: Observation features, ``[..., observation_dim]``.
            actions: Action features, ``[..., action_dim]``.

        Returns:
            Array of shape ``[...]`` holding the selected Q-values.
        """
        # [..., codebook_size]: one value per codebook entry.
        per_entry = self.model(jnp.concatenate([observations, actions], axis=-1))
        # Gather along the codebook axis with a trailing singleton index axis,
        # then drop that axis again.
        index = jnp.expand_dims(assignments, -1)
        gathered = jnp.take_along_axis(per_entry, index, axis=-1)
        return gathered[..., 0]


class BPPOTCPolicy(Policy[None]):
    """Stateless policy: choose a codebook entry, then act with its actor.

    On each call the classifier scores every codebook entry. Entries whose
    logit is at least ``max(logits) - log(threshold)`` form the candidate
    mask. ``select_values`` combines that mask with the ``high_level_critic``
    values, and the argmax index is passed to the ensemble actor.

    Only ``actor`` is affected by ``train``/``eval`` and written by ``save``.
    ``critic`` is stored but not used by ``__call__``.
    """

    def __init__(
        self,
        actor: GaussianActorEnsembleWithIndices,
        classifier: MLP,
        critic: QCritic,
        high_level_critic: MLP,
        threshold: float,
    ):
        self.actor = actor
        self.classifier = classifier
        self.critic = critic
        self.high_level_critic = high_level_critic
        # Kept in log space so it can be compared against logits directly.
        # NOTE(review): every logit is <= the max logit, so a threshold < 1
        # (negative log) yields an all-False candidate mask, and threshold == 1
        # keeps only the top-scoring entries. threshold <= 0 makes ``math.log``
        # raise ValueError. Confirm that callers pass threshold >= 1.
        self.log_threshold = math.log(threshold)

    def __call__(
        self, observations: Array, state: None
    ) -> tuple[Array, None, dict[str, Any]]:
        """Pick an action for ``observations``; ``state`` is passed through.

        Returns:
            ``(actions, state, info)``. ``info`` holds the classifier
            ``logits``, the ``selected`` codebook indices and the
            high-level critic ``values``.
        """
        logits = self.classifier(observations)  # [..., num_embeddings]
        values = self.high_level_critic(observations)  # [..., num_embeddings]

        # Candidate entries: logits within log_threshold of the best one.
        cutoff = jnp.max(logits, axis=-1, keepdims=True) - self.log_threshold
        candidates = logits >= cutoff

        # [...]; select_values presumably masks out non-candidates (confirm).
        selected = jnp.argmax(select_values(candidates, values), axis=-1)
        actions, _ = self.actor(selected, observations)  # [..., action_dim]

        info = {"logits": logits, "selected": selected, "values": values}
        return actions, state, info

    def eval(self, **attributes):
        """Put only the actor in evaluation mode, forwarding ``attributes``."""
        self.actor.eval(**attributes)

    def save(self, step: int | None, logger: Logger) -> str:
        """Persist the actor through ``logger`` and return the result.

        With a ``step``, the actor is saved under ``checkpoints`` as
        ``actor_{step}``; otherwise it is saved as ``actor``.
        """
        if step is not None:
            return logger.save_model(
                "checkpoints", f"actor_{step}", model=self.actor
            )
        return logger.save_model("actor", model=self.actor)

    def train(self, **attributes):
        """Put only the actor in training mode, forwarding ``attributes``."""
        self.actor.train(**attributes)


class BehaviorState(NamedTuple):
    """Immutable pairing of an ensemble actor with an MLP critic.

    NOTE(review): presumably the models produced by a behavior-learning stage;
    confirm against the code that constructs it.
    """

    # Index-conditioned Gaussian actor ensemble.
    actor: GaussianActorEnsembleWithIndices
    # Critic network, typed as a plain MLP here (not the QCritic wrapper).
    critic: MLP
