from typing import Any, NamedTuple

from flax import nnx
from flax.nnx.filterlib import Filter
from jax import Array

from offline.modules.actor.base import Actor, DeterministicActor, GaussianActor
from offline.modules.base import TargetModel
from offline.modules.critic import QCriticEnsemble
from offline.modules.policy import Policy
from offline.types import ArrayLike


# Matches any variable whose path contains an ``actor`` or a ``critic`` key.
# No variable-type restriction is applied, so non-``Param`` state under those
# paths is selected too.
ActorCriticFilter = nnx.Any(
    *(nnx.PathContains(path_key) for path_key in ("actor", "critic"))
)
# Narrower selection: only ``nnx.Param`` variables under an ``actor`` path key.
ActorFilter = nnx.All(nnx.Param, nnx.PathContains("actor"))


class HDRPolicy(Policy[None]):
    """Residual policy built on top of a supplied behavior actor.

    The action for an observation is the first output of ``behavior_actor``
    plus the first output of a learned ``actor``. A ``QCriticEnsemble`` is
    constructed alongside the learned actor. The policy has no recurrent
    state: its state type is ``None`` and the state is returned unchanged.
    """

    # Variables of interest for this policy (see ``ActorCriticFilter``).
    POI: Filter | None = ActorCriticFilter

    def __init__(
        self,
        action_dim: int,
        behavior_actor: GaussianActor,
        deterministic: bool,
        ensemble_size: int,
        observation_dim: int,
        rngs: nnx.Rngs,
        **kwargs,
    ):
        """Create the learned actor and critic around ``behavior_actor``.

        Args:
            action_dim: Action dimensionality, forwarded to actor and critic.
            behavior_actor: Pre-built actor whose first output serves as the
                base action; stored as-is, not re-initialised here.
            deterministic: When truthy the learned actor is a
                ``DeterministicActor``; otherwise a ``GaussianActor``.
            ensemble_size: Forwarded to ``QCriticEnsemble``.
            observation_dim: Observation dimensionality, forwarded to actor
                and critic.
            rngs: Random streams, consumed first by the actor, then the critic.
            **kwargs: Extra options forwarded unchanged to both the actor and
                the critic constructors.
        """
        actor_cls = DeterministicActor if deterministic else GaussianActor
        # Build the actor before the critic so ``rngs`` are drawn in the same order.
        self.actor: Actor = actor_cls(
            action_dim=action_dim,
            observation_dim=observation_dim,
            rngs=rngs,
            **kwargs,
        )
        self.behavior_actor = behavior_actor
        self.critic = QCriticEnsemble(
            action_dim=action_dim,
            ensemble_size=ensemble_size,
            observation_dim=observation_dim,
            rngs=rngs,
            **kwargs,
        )

    def __call__(
        self, observations: ArrayLike, state: None
    ) -> tuple[Array, None, dict[str, Any]]:
        """Compute actions for ``observations``.

        Returns:
            ``(actions, state, info)``. ``actions`` is the behavior actor's first
            output plus the learned actor's first output. ``state`` is passed
            through untouched, and ``info`` is an empty dict.
        """
        # The behavior actor is queried first, then the learned residual.
        base, _ = self.behavior_actor(observations)
        residual, _ = self.actor(observations)
        return base + residual, state, {}

    def eval(self, **attributes):
        """Put the learned actor and the critic into evaluation mode.

        ``attributes`` are forwarded to each sub-module's ``eval``.
        ``behavior_actor`` is not switched here. Presumably it is kept in a
        fixed mode; confirm this is intended.
        """
        for submodule in (self.actor, self.critic):
            submodule.eval(**attributes)

    def train(self, **attributes):
        """Put the learned actor and the critic into training mode.

        ``attributes`` are forwarded to each sub-module's ``train``.
        ``behavior_actor`` is not switched here.
        """
        for submodule in (self.actor, self.critic):
            submodule.train(**attributes)


class HDRTrainState(NamedTuple):
    """Bundle of the optimizers, online policy and target policy for HDR.

    Being a ``NamedTuple``, the field declaration order below is the order
    used for positional construction and tuple unpacking.
    """

    # Optimizer for the actor. Presumably it wraps ``policy.actor``; confirm
    # at the construction site.
    actor_optimizer: nnx.Optimizer
    # Optimizer for the critic. Presumably it wraps ``policy.critic``; confirm
    # at the construction site.
    critic_optimizer: nnx.Optimizer
    # The online policy (learned actor, behavior actor and critic ensemble).
    policy: HDRPolicy
    # Target copy of the policy wrapped in ``TargetModel``. How it is updated
    # is not visible here.
    target_policy: TargetModel[HDRPolicy]
