from typing import Any, NamedTuple

from flax import nnx
from jax import Array

from offline.modules.actor.base import DeterministicActor
from offline.modules.base import TargetModel
from offline.modules.critic import QCriticEnsemble
from offline.modules.policy import Policy
from offline.types import ArrayLike


# NNX state filter that matches a variable only when BOTH conditions hold:
# it is an `nnx.Param`, and its path contains the key "actor".
# NOTE(review): presumably used to pick out the parameters stored under
# `SVRPolicy.actor` (e.g. for the actor optimizer) — confirm at the call site.
ActorFilter = nnx.All(nnx.Param, nnx.PathContains("actor"))


class SVRPolicy(Policy[None]):
    """Stateless policy holding a deterministic actor and a Q-critic ensemble.

    Calling the policy only consults the actor; the critic ensemble is kept as
    a sibling attribute on the same object.
    """

    def __init__(
        self,
        action_dim: int,
        ensemble_size: int,
        observation_dim: int,
        rngs: nnx.Rngs,
        **kwargs
    ):
        """Construct the actor first, then the critic ensemble.

        Args:
            action_dim: Forwarded to both the actor and the critic ensemble.
            ensemble_size: Forwarded to the critic ensemble only.
            observation_dim: Forwarded to both the actor and the critic
                ensemble.
            rngs: The same ``nnx.Rngs`` object is handed to both constructors.
            **kwargs: Extra keyword arguments forwarded unchanged to both
                constructors.
        """
        # Keyword arguments that both sub-modules receive identically.
        common = dict(
            action_dim=action_dim,
            observation_dim=observation_dim,
            rngs=rngs,
        )
        # Keep the actor-before-critic order: both receive the same `rngs`.
        self.actor = DeterministicActor(squash=True, **common, **kwargs)
        self.critic = QCriticEnsemble(
            ensemble_size=ensemble_size, **common, **kwargs
        )

    def __call__(
        self, observations: ArrayLike, state: None
    ) -> tuple[Array, None, dict[str, Any]]:
        """Return the actor's actions for ``observations``.

        The second element of the actor's output is discarded, ``state`` is
        handed back untouched, and the info dictionary is always empty.
        """
        actor_output = self.actor(observations)
        actions, _ = actor_output
        info: dict[str, Any] = {}
        return actions, state, info

    def eval(self, **attributes):
        """Call ``eval(**attributes)`` on the actor, then on the critic."""
        for submodule in (self.actor, self.critic):
            submodule.eval(**attributes)

    def train(self, **attributes):
        """Call ``train(**attributes)`` on the actor, then on the critic."""
        for submodule in (self.actor, self.critic):
            submodule.train(**attributes)


class SVRTrainState(NamedTuple):
    """Tuple grouping the two optimizers, the policy and its target wrapper."""

    # Optimizer associated with the actor.
    # NOTE(review): presumably restricted to `ActorFilter` params — confirm
    # where this state is constructed.
    actor_optimizer: nnx.Optimizer
    # Optimizer associated with the critic ensemble (assumed; confirm at the
    # construction site).
    critic_optimizer: nnx.Optimizer
    # The policy object (actor + critic ensemble).
    policy: SVRPolicy
    # `TargetModel` wrapper around an `SVRPolicy`; update rule is defined in
    # `offline.modules.base`, not visible here.
    target_policy: TargetModel[SVRPolicy]
