from typing import Any, NamedTuple

from flax import nnx
from jax import Array, numpy as jnp

from offline.modules.actor.base import DeterministicActor
from offline.modules.base import TargetModel
from offline.modules.critic import QCriticPair
from offline.modules.policy import Policy
from offline.types import ArrayLike


# NNX filter that matches variables that are both ``nnx.Param`` and have the
# key "actor" somewhere in their path, i.e. the trainable weights living under
# ``TD3BCPolicy.actor``. Presumably passed as ``wrt=`` when building the actor
# optimizer so only actor parameters are updated -- confirm at the call site.
# NOTE(review): any other module with an "actor" path segment (e.g. inside a
# wrapped target copy) would match too if the filter is applied to it.
ActorFilter = nnx.All(nnx.Param, nnx.PathContains("actor"))


class TD3BCPolicy(Policy[None]):
    """Deterministic TD3+BC policy bundling an actor and a twin Q-critic.

    The policy carries no recurrent state: its state type parameter is
    ``None`` and ``__call__`` returns the incoming ``state`` unchanged.

    Args:
        action_dim: Size of the action space, forwarded to both submodules.
        observation_dim: Size of the observation space, forwarded to both
            submodules.
        rngs: NNX random streams used to initialise the submodules.
        **kwargs: Extra options forwarded verbatim to both
            ``DeterministicActor`` and ``QCriticPair``. Each key therefore has
            to be accepted by both constructors (presumably -- confirm their
            signatures).
    """

    def __init__(
        self, action_dim: int, observation_dim: int, rngs: nnx.Rngs, **kwargs
    ):
        common = {
            "action_dim": action_dim,
            "observation_dim": observation_dim,
            "rngs": rngs,
        }
        # Keep the actor built before the critic: both draw from the same
        # ``rngs`` object, so swapping the order would presumably change which
        # random keys each submodule is initialised with.
        self.actor = DeterministicActor(**common, squash=True, **kwargs)
        self.critic = QCriticPair(**common, **kwargs)

    def __call__(
        self, observations: ArrayLike, state: None
    ) -> tuple[Array, None, dict[str, Any]]:
        """Compute actions for ``observations``.

        Args:
            observations: Observations fed to the actor.
            state: Unused policy state; passed straight through.

        Returns:
            ``(actions, state, info)``: ``actions`` is ``tanh`` applied to the
            first element of the actor's output, ``state`` is the input state,
            and ``info`` is an empty dict.
        """
        # NOTE(review): the actor is constructed with ``squash=True`` and a tanh
        # is applied again below. If ``squash`` already tanh-bounds the actor's
        # output, actions get squashed twice and their range shrinks to about
        # [-0.76, 0.76]. Confirm against DeterministicActor.
        # NOTE(review): ``[0]`` assumes the actor returns a tuple/sequence whose
        # first item is the raw action; if it returns a plain array this would
        # select only the first row -- TODO confirm.
        raw_action = self.actor(observations)[0]
        info: dict[str, Any] = {}
        return jnp.tanh(raw_action), state, info


class TD3BCTrainState(NamedTuple):
    """Immutable (``NamedTuple``) bundle of the TD3+BC training objects.

    Groups the two optimizers with the online policy and its target copy so
    they can be passed around and replaced as a single value.
    """

    # Optimizer for the actor; presumably built with ``wrt=ActorFilter`` so it
    # only touches actor parameters -- confirm at the construction site.
    actor_optimizer: nnx.Optimizer
    # Optimizer for the critic; presumably restricted to critic parameters.
    critic_optimizer: nnx.Optimizer
    # Online policy module, holding both ``actor`` and ``critic`` submodules.
    policy: TD3BCPolicy
    # Target copy of the policy wrapped in ``TargetModel``; presumably the
    # slowly-updated network used for TD targets -- confirm in TargetModel.
    target_policy: TargetModel[TD3BCPolicy]
