from typing import Any, Generic, TypeVar

from flax import nnx
from flax.nnx.filterlib import Filter
from jax import Array

from offline.utils.logger import Logger


# Type of the state a Policy receives and returns on each __call__; each
# concrete Policy subclass chooses it via Policy[...] parametrization.
StateT = TypeVar("StateT")


class Policy(nnx.Module, Generic[StateT]):
    """Abstract NNX policy module, generic over the state it threads through calls.

    Subclasses must override ``__call__``. ``save`` persists the module through
    a project ``Logger``.
    """

    # Filter forwarded to ``Logger.save_model`` as ``poi`` when saving.
    # NOTE(review): how ``None`` vs. a concrete Filter is interpreted is
    # decided by ``save_model``, which is not visible here — confirm there.
    POI: Filter | None = None

    def __call__(
        self, observations: Array, state: StateT
    ) -> tuple[Array, StateT, dict[str, Any]]:
        """Run the policy on ``observations`` given the incoming ``state``.

        Returns a 3-tuple: an Array (presumably the actions — confirm with
        subclasses), the next state, and a dict of auxiliary outputs.

        Raises:
            NotImplementedError: always; concrete policies must override this.
        """
        raise NotImplementedError

    def save(self, step: int | None, logger: Logger) -> str:
        """Save this module via ``logger.save_model`` and return its result.

        With ``step`` set to None the path parts are ``("policy",)``.
        Otherwise they are ``("checkpoints", f"policy_{step}")``.
        ``self.POI`` is passed through as ``poi``.
        """
        path_parts = (
            ("policy",) if step is None else ("checkpoints", f"policy_{step}")
        )
        return logger.save_model(*path_parts, model=self, poi=self.POI)
