"""
Wrappers that bind mean-field policy, Q-network and value networks to their parameters and input normalizers.
"""
from typing import Any
import jax.numpy as jnp
import flax.linen as nn

# --- agent-wrapper ---
class SAActorWrapper:
    """Bind a policy module to its parameters and input normalizers.

    Calling an instance normalizes both inputs and evaluates the policy's
    ``mode`` method through ``policy.apply``.

    Each normalizer is always invoked as ``normalizer(x, flag)``, where
    ``flag`` is the matching ``normalize_*`` value; what the normalizer does
    with that flag is defined outside this module.
    """

    def __init__(
        self,
        policy: nn.Module,
        params: dict,
        obs_normalizer: Any,
        normalize_obs: bool,
        local_s_normalizer: Any,
        normalize_local_s: bool,
    ):
        # Network and the variables passed to ``apply``.
        self.policy, self.params = policy, params
        # Normalizer (and its flag) for the global observation.
        self.obs_normalizer, self.normalize_obs = obs_normalizer, normalize_obs
        # Normalizer (and its flag) for the local states.
        self.local_s_normalizer, self.normalize_local_s = (
            local_s_normalizer,
            normalize_local_s,
        )

    def __call__(self, local_states, global_obs):
        """Return the policy's ``mode`` output for normalized inputs."""
        # Global obs is normalized first, then local states (original order).
        obs_in = self.obs_normalizer(global_obs, self.normalize_obs)
        states_in = self.local_s_normalizer(local_states, self.normalize_local_s)
        return self.policy.apply(self.params, states_in, obs_in, method="mode")


class SARecurrentActorWrapper:
    """Bind a recurrent policy to its parameters and input normalizers.

    ``hidden_size`` is read once from ``policy.encoder.hidden_size`` and is
    forwarded to ``policy.init_hidden``. Calling an instance normalizes the
    inputs and returns the ``(action, next_hidden)`` pair produced by
    ``policy.mode(params, local_states, local_obs, hidden_state, done)``.

    NOTE(review): ``local_obs`` is passed through ``obs_normalizer``, the
    normalizer the non-recurrent wrappers apply to ``global_obs`` — confirm
    the naming difference is intentional.
    """

    def __init__(
        self,
        policy: nn.Module,
        params: dict,
        obs_normalizer: Any,
        normalize_obs: bool,
        local_s_normalizer: Any,
        normalize_local_s: bool,
    ):
        self.policy, self.params = policy, params
        self.obs_normalizer, self.normalize_obs = obs_normalizer, normalize_obs
        self.local_s_normalizer, self.normalize_local_s = (
            local_s_normalizer,
            normalize_local_s,
        )
        # Cached so init_hidden does not need to reach into the encoder.
        encoder = policy.encoder
        self.hidden_size = encoder.hidden_size

    def init_hidden(self, batch_size: int) -> jnp.ndarray:
        """Delegate to ``policy.init_hidden(batch_size, hidden_size)``."""
        policy, size = self.policy, self.hidden_size
        return policy.init_hidden(batch_size, size)

    def __call__(self, local_states, local_obs, hidden_state, done=None):
        """Normalize inputs and return ``(action, next_hidden)``.

        When ``done`` is omitted, an all-False boolean vector whose length is
        ``local_obs.shape[0]`` is used instead.
        """
        # Sized from the un-normalized observation, before any normalizer runs.
        flags = (
            jnp.zeros((local_obs.shape[0],), dtype=bool) if done is None else done
        )
        obs_in = self.obs_normalizer(local_obs, self.normalize_obs)
        states_in = self.local_s_normalizer(local_states, self.normalize_local_s)
        act, hidden_out = self.policy.mode(
            self.params, states_in, obs_in, hidden_state, flags
        )
        return act, hidden_out


class SAQNetWrapper:
    """Bind a Q-network to its parameters and input normalizers.

    Calling an instance normalizes the inputs, evaluates the network's
    ``argmax`` method via ``qnet.apply`` and returns only the second element
    of the resulting pair (the action); the first element is discarded.
    """

    def __init__(
        self,
        qnet: nn.Module,
        params: dict,
        obs_normalizer: Any,
        normalize_obs: bool,
        local_s_normalizer: Any,
        normalize_local_s: bool,
    ):
        self.qnet, self.params = qnet, params
        self.obs_normalizer, self.normalize_obs = obs_normalizer, normalize_obs
        self.local_s_normalizer, self.normalize_local_s = (
            local_s_normalizer,
            normalize_local_s,
        )

    def __call__(self, local_obs, global_obs):
        """Return the action selected by ``qnet``'s ``argmax`` method."""
        obs_in = self.obs_normalizer(global_obs, self.normalize_obs)
        # NOTE: despite its name, ``local_obs`` goes through the local-state normalizer.
        states_in = self.local_s_normalizer(local_obs, self.normalize_local_s)
        _, chosen = self.qnet.apply(self.params, states_in, obs_in, method="argmax")
        return chosen


class SARecurrentQNetWrapper:
    """Bind a recurrent Q-network to its parameters and input normalizers.

    ``hidden_size`` is read once from ``qnet.encoder.hidden_size`` and is
    forwarded to ``qnet.init_hidden``. Calling an instance normalizes the
    inputs, evaluates ``qnet.argmax(params, local_states, global_obs,
    hidden_state, done)`` and returns ``(action, next_hidden)``; the first
    element of ``argmax``'s inner pair is discarded.
    """

    def __init__(
        self,
        qnet: nn.Module,
        params: dict,
        obs_normalizer: Any,
        normalize_obs: bool,
        local_s_normalizer: Any,
        normalize_local_s: bool
    ):
        self.qnet = qnet
        self.params = params
        self.obs_normalizer = obs_normalizer
        self.normalize_obs = normalize_obs
        self.local_s_normalizer = local_s_normalizer
        self.normalize_local_s = normalize_local_s
        self.hidden_size = qnet.encoder.hidden_size

    def init_hidden(self, batch_size: int) -> jnp.ndarray:
        """Return an initial hidden state for a batch of ``batch_size``.

        Added for parity with the other recurrent wrappers in this module,
        which expose the same method; previously ``hidden_size`` was stored
        but never used. Delegates to ``qnet.init_hidden(batch_size,
        hidden_size)`` — assumes ``qnet`` provides ``init_hidden`` like the
        recurrent policy/value networks (TODO confirm).
        """
        return self.qnet.init_hidden(batch_size, self.hidden_size)

    def __call__(self, local_states, global_obs, hidden_state, done=None):
        """Normalize inputs and return ``(action, next_hidden)``.

        When ``done`` is omitted, an all-False boolean vector whose length is
        ``global_obs.shape[0]`` is used instead.
        """
        if done is None:
            # Sized from the un-normalized observation.
            done = jnp.zeros((global_obs.shape[0],), dtype=bool)
        global_obs = self.obs_normalizer(global_obs, self.normalize_obs)
        local_states = self.local_s_normalizer(local_states, self.normalize_local_s)
        (_, action), next_hidden = self.qnet.argmax(
            self.params, local_states, global_obs, hidden_state, done
        )
        return action, next_hidden


# --- value-wrapper ---
class SAValueWrapper:
    """Bind a value network to its parameters and input normalizers.

    Calling an instance normalizes both inputs and returns
    ``value.apply(params, local_states, global_obs)`` unchanged.
    """

    def __init__(
        self,
        value: nn.Module,
        params: dict,
        obs_normalizer: Any,
        normalize_obs: bool,
        local_s_normalizer: Any,
        normalize_local_s: bool,
    ):
        self.value, self.params = value, params
        self.obs_normalizer, self.normalize_obs = obs_normalizer, normalize_obs
        self.local_s_normalizer, self.normalize_local_s = (
            local_s_normalizer,
            normalize_local_s,
        )

    def __call__(self, local_states, global_obs):
        """Return the value network's output for normalized inputs."""
        obs_in = self.obs_normalizer(global_obs, self.normalize_obs)
        states_in = self.local_s_normalizer(local_states, self.normalize_local_s)
        return self.value.apply(self.params, states_in, obs_in)


class SARecurrentValueWrapper:
    """Bind a recurrent value network to its parameters and input normalizers.

    ``hidden_size`` is read once from ``value.encoder.hidden_size`` and is
    forwarded to ``value.init_hidden``. Calling an instance normalizes the
    inputs and returns whatever ``value(params, local_states, global_obs,
    hidden_state, done)`` returns.

    NOTE(review): unlike ``SAValueWrapper`` (which uses ``value.apply``), this
    calls the network object directly with ``params`` as its first argument —
    confirm the recurrent value network's ``__call__`` accepts that.
    """

    def __init__(
        self,
        value: nn.Module,
        params: dict,
        obs_normalizer: Any,
        normalize_obs: bool,
        local_s_normalizer: Any,
        normalize_local_s: bool,
    ):
        self.value, self.params = value, params
        self.obs_normalizer, self.normalize_obs = obs_normalizer, normalize_obs
        self.local_s_normalizer, self.normalize_local_s = (
            local_s_normalizer,
            normalize_local_s,
        )
        encoder = value.encoder
        self.hidden_size = encoder.hidden_size

    def init_hidden(self, batch_size: int) -> jnp.ndarray:
        """Delegate to ``value.init_hidden(batch_size, hidden_size)``."""
        net, size = self.value, self.hidden_size
        return net.init_hidden(batch_size, size)

    def __call__(self, local_states, global_obs, hidden_state, done=None):
        """Normalize inputs and evaluate the recurrent value network.

        When ``done`` is omitted, an all-False boolean vector whose length is
        ``global_obs.shape[0]`` is used instead.
        """
        # Sized from the un-normalized observation, before any normalizer runs.
        flags = (
            jnp.zeros((global_obs.shape[0],), dtype=bool) if done is None else done
        )
        obs_in = self.obs_normalizer(global_obs, self.normalize_obs)
        states_in = self.local_s_normalizer(local_states, self.normalize_local_s)
        return self.value(self.params, states_in, obs_in, hidden_state, flags)