import math
from typing import Any, NamedTuple

from flax import nnx
from flax.nnx.filterlib import Filter
from jax import Array, lax, numpy as jnp

from offline.modules.actor.ensemble import GaussianActorEnsemble
from offline.modules.base import TargetModel
from offline.modules.mlp import MLP, MLPEnsemble
from offline.modules.policy import Policy


# Matches any state whose path contains the key "actor" or the key "critic";
# used below as ``LBPTCPolicy.POI``.
# NOTE(review): presumably meant to leave out ``behavior_actor`` and
# ``classifier`` -- confirm that PathContains matches whole path keys rather
# than substrings in the installed flax version.
ActorCriticFilter = nnx.Any(
    nnx.PathContains("actor"), nnx.PathContains("critic")
)
# Matches only ``nnx.Param`` state whose path contains the key "actor".
ActorFilter = nnx.All(nnx.Param, nnx.PathContains("actor"))


def mask_values(mask: Array, values: Array) -> Array:
    """Return ``values`` with every entry where ``mask`` is false set to -inf.

    The result is suitable for an ``argmax`` that must ignore the masked-out
    entries.

    Args:
        mask: Boolean array, broadcastable to the shape of ``values``.
        values: Floating-point array of scores.

    Returns:
        An array shaped like the broadcast of ``mask`` and ``values``, holding
        ``values`` where ``mask`` is true and ``-inf`` elsewhere. For floating
        ``values`` the dtype of ``values`` is preserved.
    """
    # ``jnp.where`` promotes the Python scalar to the dtype of ``values`` and
    # broadcasts ``mask``; ``lax.select`` requires both branches to already
    # share a dtype, which fails for e.g. float16/bfloat16 inputs when the
    # fill array is created with the default float dtype.
    return jnp.where(mask, values, -jnp.inf)


class QCriticEnsemble(nnx.Module):
    """Ensemble of Q-functions over concatenated (observation, action) inputs.

    Wraps an ``MLPEnsemble`` that takes ``observation_dim + action_dim``
    input features and emits ``num_behaviors`` outputs per ensemble member.
    """

    def __init__(
        self,
        action_dim: int,
        ensemble_size: int,
        num_behaviors: int,
        observation_dim: int,
        rngs: nnx.Rngs,
        **kwargs
    ):
        """Build the underlying ``MLPEnsemble``.

        Args:
            action_dim: Size of the trailing axis of the actions.
            ensemble_size: Forwarded to ``MLPEnsemble`` as ``ensemble_size``.
            num_behaviors: Number of outputs of each ensemble member.
            observation_dim: Size of the trailing axis of the observations.
            rngs: Random number generators forwarded to ``MLPEnsemble``.
            **kwargs: Extra keyword arguments forwarded to ``MLPEnsemble``.
        """
        # The network consumes observations and actions joined on the last axis.
        input_width = action_dim + observation_dim
        self.model = MLPEnsemble(
            ensemble_size=ensemble_size,
            in_features=input_width,
            out_features=num_behaviors,
            rngs=rngs,
            **kwargs
        )

    def __call__(self, observations: Array, actions: Array):
        """Evaluate the ensemble on ``observations`` joined with ``actions``.

        The two arrays are concatenated along their last axis (observations
        first) and the result of ``self.model`` is returned unchanged.
        """
        joined = jnp.concatenate([observations, actions], axis=-1)
        return self.model(joined)


class LBPTCPolicy(Policy):
    """Policy picking the best-valued candidate action among behaviors.

    For each observation one candidate action per behavior is built by adding
    the scaled output of ``actor`` to the mean returned by ``behavior_actor``.
    The candidates are scored with ``critic`` and the highest-scoring
    candidate whose behavior passes the ``classifier`` mask is returned.
    """

    POI: Filter | None = ActorCriticFilter

    def __init__(
        self,
        action_dim: int,
        behavior_actor: GaussianActorEnsemble,
        classifier: MLP,
        deltas_multiplier: float,
        ensemble_size: int,
        num_behaviors: int,
        observation_dim: int,
        rngs: nnx.Rngs,
        threshold: float,
        **kwargs
    ):
        """Create the actor and critic and store the given behavior modules.

        Args:
            action_dim: Size of the trailing axis of the actions.
            behavior_actor: Module returning ``(means, stds)`` for
                observations; stored as given.
            classifier: Module returning per-behavior logits for
                observations; stored as given.
            deltas_multiplier: Scale applied to the actor output before it is
                added to the behavior means.
            ensemble_size: Number of members of the critic ensemble.
            num_behaviors: Number of behaviors; used as the actor's ensemble
                size and as the critic's output width.
            observation_dim: Size of the trailing axis of the observations.
            rngs: Random number generators passed to the actor, then to the
                critic.
            threshold: Only its logarithm is kept. A behavior is kept when
                its logit is at least ``max_logit - log(threshold)``, so the
                mask is non-empty only for ``threshold >= 1``.
            **kwargs: Extra keyword arguments forwarded to both the actor and
                the critic constructors.
        """
        # Arguments common to both sub-modules. The actor is constructed
        # before the critic so that ``rngs`` is consumed in the same order.
        shared = dict(
            action_dim=action_dim,
            observation_dim=observation_dim,
            rngs=rngs,
            **kwargs
        )
        self.actor = GaussianActorEnsemble(
            ensemble_size=num_behaviors, out_axis=-2, **shared
        )
        self.behavior_actor = behavior_actor
        self.classifier = classifier
        self.critic = QCriticEnsemble(
            ensemble_size=ensemble_size, num_behaviors=num_behaviors, **shared
        )
        self.deltas_multiplier = deltas_multiplier
        self.log_threshold = math.log(threshold)

    def __call__(
        self, observations: Array, state: None
    ) -> tuple[Array, None, dict[str, Any]]:
        """Select one action per observation.

        Returns:
            A tuple ``(actions, state, info)`` where ``state`` is passed
            through untouched and ``info`` holds the keys ``selected``,
            ``values``, ``deltas``, ``mask``, ``means``, ``logits`` and
            ``stds``.
        """
        logits = self.classifier(observations)
        # A behavior passes when its logit is within ``log_threshold`` of the
        # largest logit. mask: [..., num_behaviors]
        cutoff = jnp.max(logits, axis=-1, keepdims=True) - self.log_threshold
        mask = logits >= cutoff
        # means, stds, deltas: [..., num_behaviors, action_dim]
        means, stds = self.behavior_actor(observations)
        deltas, _ = self.actor(observations)
        actions, info = self.compute_actions(
            deltas=deltas, mask=mask, means=means, observations=observations
        )
        info.update(
            {
                "deltas": deltas,
                "mask": mask,
                "means": means,
                "logits": logits,
                "stds": stds,
            }
        )
        return actions, state, info

    def compute_actions(
        self, deltas: Array, mask: Array, means: Array, observations: Array
    ):
        """Return the best unmasked candidate action and selection details.

        Returns:
            A tuple ``(actions, info)``; ``info`` maps ``"selected"`` to the
            index of the chosen behavior and ``"values"`` to the unmasked
            candidate values.
        """
        # candidates: [..., num_behaviors, action_dim]
        # masked_values, values: [..., num_behaviors]
        candidates, masked_values, values = self.compute_candidates_values(
            deltas=deltas, mask=mask, means=means, observations=observations
        )
        # [...]
        best = jnp.argmax(masked_values, axis=-1)
        # Two trailing unit axes let the index broadcast over action_dim:
        # the gathered array is [..., 1, action_dim].
        picked = jnp.take_along_axis(
            candidates, best[..., None, None], axis=-2
        )
        # Drop the singleton behavior axis -> [..., action_dim]
        return jnp.squeeze(picked, axis=-2), {
            "selected": best,
            "values": values,
        }

    def compute_candidates_values(
        self, deltas: Array, mask: Array, means: Array, observations: Array
    ):
        """Build one candidate action per behavior and score it.

        Returns:
            A tuple ``(actions, masked_values, values)``: the candidate
            actions, their values with masked-out behaviors set to ``-inf``,
            and the raw values.
        """
        # [..., num_behaviors, action_dim]
        candidates = means + self.deltas_multiplier * deltas
        # Repeat every observation once per behavior:
        # [..., observation_dim] -> [..., num_behaviors, observation_dim]
        tiled_shape = (*means.shape[:-1], observations.shape[-1])
        tiled_observations = jnp.broadcast_to(
            observations[..., None, :], tiled_shape
        )
        # [num_ensemble, ..., num_behaviors, num_behaviors]
        ensemble_values = self.critic(tiled_observations, candidates)
        # Minimum over the leading (ensemble) axis:
        # [..., num_behaviors, num_behaviors]
        lowest = jnp.min(ensemble_values, axis=0)
        # Keep, for candidate i, only critic output i: [..., num_behaviors]
        values = jnp.diagonal(lowest, axis1=-2, axis2=-1)
        return candidates, mask_values(mask, values), values

    def eval(self, **attributes):
        """Call ``eval`` on the actor and the critic only."""
        for module in (self.actor, self.critic):
            module.eval(**attributes)

    def train(self, **attributes):
        """Call ``train`` on the actor and the critic only."""
        for module in (self.actor, self.critic):
            module.train(**attributes)


class BehaviorState(NamedTuple):
    """Immutable bundle of three behavior-side modules.

    Not referenced elsewhere in this part of the file.
    NOTE(review): presumably the result of a behavior pre-training stage --
    confirm against the code that constructs it.
    """

    # Gaussian actor ensemble.
    actor: GaussianActorEnsemble
    # Classifier network.
    classifier: MLP
    # Critic network.
    critic: MLP


class LBPTCTrainState(NamedTuple):
    """Immutable bundle of the objects used while training ``LBPTCPolicy``.

    NOTE(review): which parameters each optimizer updates is not visible
    here -- presumably the policy's actor and critic respectively; confirm
    against the training code.
    """

    # Optimizer named for the actor.
    actor_optimizer: nnx.Optimizer
    # Optimizer named for the critic.
    critic_optimizer: nnx.Optimizer
    # The policy being trained.
    policy: LBPTCPolicy
    # Target-model wrapper around an ``LBPTCPolicy``.
    target_policy: TargetModel[LBPTCPolicy]
