from typing import Any, NamedTuple

from flax import nnx
from flax.nnx.filterlib import Filter
from jax import Array, numpy as jnp

from offline.modules.actor.base import GaussianActor
from offline.modules.base import TargetModel
from offline.modules.critic import QCriticEnsemble, VCritic
from offline.modules.policy import Policy
from offline.types import ArrayLike


# Matches any variable whose path contains an "actor" element OR a "critic"
# element.
# NOTE(review): confirm whether `behavior_actor` is meant to be excluded here.
# It is excluded only if PathContains compares whole path elements.
ActorCriticFilter = nnx.Any(
    nnx.PathContains("actor"),
    nnx.PathContains("critic"),
)
# Matches only `nnx.Param` variables that also sit under an "actor" path element.
ActorFilter = nnx.All(
    nnx.Param,
    nnx.PathContains("actor"),
)


class LBPPolicy(Policy[None]):
    """Stateless policy whose action is a behavior mean plus a learned delta.

    The action for a batch of observations is the first output of
    ``behavior_actor`` added to the first output of the trainable ``actor``.
    A ``QCriticEnsemble`` is held alongside as ``critic``.
    """

    # Filter exposed to the `Policy` machinery (consumer not visible here).
    POI: Filter | None = ActorCriticFilter

    def __init__(
        self,
        action_dim: int,
        behavior_actor: GaussianActor,
        ensemble_size: int,
        observation_dim: int,
        rngs: nnx.Rngs,
        zero_mean: bool,
        **kwargs,
    ):
        """Build the trainable actor and critic and store ``behavior_actor``.

        Args:
            action_dim: Action dimensionality, forwarded to actor and critic.
            behavior_actor: Pre-built actor whose first output is used as the
                base action. It is stored by reference, not copied.
            ensemble_size: Number of Q-critics in the ensemble.
            observation_dim: Observation dimensionality, forwarded to both.
            rngs: Shared RNG container passed to both constructors.
            zero_mean: If true, zero the first half (along axis 1) of the
                actor's last linear kernel after construction.
            **kwargs: Extra options forwarded unchanged to BOTH constructors.
        """
        # Both constructors receive the same `rngs`, so the actor is built
        # before the critic to keep the initialization order stable.
        self.actor = GaussianActor(
            observation_dim=observation_dim,
            action_dim=action_dim,
            rngs=rngs,
            **kwargs,
        )
        self.behavior_actor = behavior_actor
        self.critic = QCriticEnsemble(
            observation_dim=observation_dim,
            action_dim=action_dim,
            ensemble_size=ensemble_size,
            rngs=rngs,
            **kwargs,
        )
        if not zero_mean:
            return

        # Split the output kernel's columns into two equal halves and zero
        # the first. Only the kernel is modified; any bias is left as-is.
        # Assumes the first half feeds the mean output (TODO confirm against
        # GaussianActor's output layout). jnp.split raises if the column
        # count is odd.
        output_head = self.actor.model.linears[-1]
        mean_columns, std_columns = jnp.split(output_head.kernel.value, 2, axis=1)
        output_head.kernel.value = jnp.concatenate(
            [jnp.zeros_like(mean_columns), std_columns], axis=1
        )

    def __call__(
        self, observations: ArrayLike, state: None
    ) -> tuple[Array, None, dict[str, Any]]:
        """Compute actions for ``observations``.

        Returns:
            A triple ``(actions, state, info)``. ``actions`` is
            ``behavior_actor(obs)[0] + actor(obs)[0]``, ``state`` is returned
            unchanged, and ``info`` is an empty dict.
        """
        # The behavior actor is evaluated before the trainable actor.
        base_action, _ = self.behavior_actor(observations)
        correction, _ = self.actor(observations)
        return base_action + correction, state, {}

    def eval(self, **attributes):
        """Put ``actor`` and ``critic`` into eval mode, forwarding ``attributes``.

        NOTE(review): ``behavior_actor`` is not switched here. Confirm that its
        mode is intentionally managed elsewhere.
        """
        for submodule in (self.actor, self.critic):
            submodule.eval(**attributes)

    def train(self, **attributes):
        """Put ``actor`` and ``critic`` into train mode, forwarding ``attributes``.

        NOTE(review): ``behavior_actor`` is not switched here either.
        """
        for submodule in (self.actor, self.critic):
            submodule.train(**attributes)


class BehaviorState(NamedTuple):
    """Immutable pair of a ``GaussianActor`` and a ``VCritic``.

    NOTE(review): presumably the modules of a behavior-modelling stage, with
    ``actor`` later supplied to ``LBPPolicy`` as ``behavior_actor``. Confirm
    with callers.
    """

    actor: GaussianActor  # positional field 0
    critic: VCritic  # positional field 1


class LBPTrainState(NamedTuple):
    """Immutable bundle of everything an LBP training step carries.

    Holds two optimizers, the live ``LBPPolicy``, and a ``TargetModel``
    wrapping an ``LBPPolicy``. Field order is positional and must not change.
    """

    actor_optimizer: nnx.Optimizer  # optimizer associated with the actor
    critic_optimizer: nnx.Optimizer  # optimizer associated with the critic
    policy: LBPPolicy  # the policy being trained
    target_policy: TargetModel[LBPPolicy]  # target-model wrapper of a policy
