from typing import Any

from flax import nnx
from jax import Array

from offline.modules.actor.base import GaussianActor
from offline.modules.critic import QCritic
from offline.modules.policy import Policy
from offline.utils.logger import Logger


# NNX variable filter that matches only `nnx.Param` variables whose graph path
# contains the key "actor" (for example, parameters reached through an `actor`
# attribute such as `BPPOPolicy.actor`). Not referenced in this chunk.
# NOTE(review): presumably used to select actor parameters separately from
# critic parameters (e.g. for an optimizer or nnx.split); confirm at the call sites.
ActorFilter = nnx.All(nnx.Param, nnx.PathContains("actor"))


class BPPOPolicy(Policy[None]):
    """Policy with no recurrent state (``state`` is always ``None``) that acts via a Gaussian actor.

    Producing actions, forwarding ``train()`` / ``eval()``, and saving all go
    through ``actor`` only. ``critic`` is kept on the instance but is not used
    by any method of this class.
    """

    def __init__(self, actor: GaussianActor, critic: QCritic):
        self.actor = actor
        self.critic = critic

    def __call__(
        self, observations: Array, state: None
    ) -> tuple[Array, None, dict[str, Any]]:
        """Compute actions for ``observations`` and pass ``state`` through unchanged.

        The actor's output must unpack into exactly two values. Only the first
        value is returned as the actions. The second is discarded; presumably it
        is log-probabilities or distribution info (TODO confirm). The returned
        info dict is always empty.
        """
        sampled_actions, _discarded = self.actor(observations)
        info: dict[str, Any] = {}
        return sampled_actions, state, info

    def eval(self, **attributes) -> None:
        """Forward ``eval(**attributes)`` to the actor; the critic is not touched."""
        self.actor.eval(**attributes)

    def train(self, **attributes) -> None:
        """Forward ``train(**attributes)`` to the actor; the critic is not touched."""
        self.actor.train(**attributes)

    def save(self, step: int | None, logger: Logger) -> str:
        """Persist the actor through ``logger.save_model``.

        With ``step=None`` the actor is saved under the single name ``"actor"``.
        With any step, including 0, it is saved under ``"checkpoints"`` as
        ``f"actor_{step}"``. Returns whatever ``save_model`` returns. That value
        is annotated as ``str`` and is presumably the saved path.
        """
        # Choose the positional name parts, then issue a single save call.
        location = (
            ("actor",) if step is None else ("checkpoints", f"actor_{step}")
        )
        return logger.save_model(*location, model=self.actor)
