"""
Wrappers for mean-field value functions (critics).
"""

import jax.numpy as jnp
from typing import Optional, Dict, Any
import jax

from mfax.utils.nets.value import ValueNetwork
from mfax.utils.nets.base import RecurrentEncoder


# --- Recurrent Mean Field Value ---
class RecurrentMeanFieldValue:

    def __init__(
        self,
        state_type: str,
        num_states: int | None = None,
        value_kwargs: Optional[Dict[str, Any]] = None,
        encoder_kwargs: Optional[Dict[str, Any]] = None,
    ):
        self.state_type = state_type  # "states" or "indices"
        self.num_states = num_states
        default_value_kwargs = dict(
            activation="tanh",
            hidden_layer_sizes=(64, 64, 64),
            state_type=self.state_type,
            num_states=self.num_states,
        )
        if value_kwargs:
            default_value_kwargs.update(value_kwargs)
        
        default_encoder_kwargs = dict(hidden_size=64, embed_size=64, activation="tanh")
        if encoder_kwargs:
            default_encoder_kwargs.update(encoder_kwargs)
        self.global_encoder = RecurrentEncoder(**default_encoder_kwargs)
        self.value = ValueNetwork(**default_value_kwargs)

    @staticmethod
    def init_hidden(batch_size: int, hidden_size: int) -> jnp.ndarray:
        return RecurrentEncoder.init_hidden(batch_size, hidden_size)

    def _broadcast_global_obs(self, local_states: jnp.ndarray, global_embedding: jnp.ndarray) -> jnp.ndarray:
        # --- expand global_embedding from [d,] to [1, d], and then broadcast to [N, d] ---
        if self.state_type == "states":
            assert local_states.ndim == 2 and global_embedding.ndim == 1
            return jnp.broadcast_to(global_embedding[None, :], (local_states.shape[0], global_embedding.size))
        else:
            assert local_states.ndim == 1 and global_embedding.ndim == 1
            return jnp.broadcast_to(global_embedding[None, :], (local_states.size, global_embedding.size))

    def _with_global_embedding(
        self,
        params: dict,
        local_states: jnp.ndarray,
        global_obs: jnp.ndarray,
        global_hidden_state: jnp.ndarray,
        global_done: jnp.ndarray,
        fn,
        *fn_inputs,
    ):
        in_axes = (0, 0, 0) + tuple(0 for _ in fn_inputs)

        def _step(_obs, _hidden, _done, *extra):
            new_hidden, global_embedding = self.global_encoder.apply(
                {"params": params["encoder"]},
                _hidden,
                _obs,
                _done,
            )
            broadcasted_global = self._broadcast_global_obs(local_states, global_embedding)
            out = fn(broadcasted_global, *extra)
            return out, new_hidden

        return jax.vmap(_step, in_axes=in_axes)(global_obs, global_hidden_state, global_done, *fn_inputs)

    def __call__(
        self,
        params: dict,
        local_states: jnp.ndarray,
        global_obs: jnp.ndarray,
        global_hidden_state: jnp.ndarray,
        global_done: jnp.ndarray,
    ):
        value_params = {"params": params["value"]}
        def fn(broadcasted_global):
            return self.value.apply(value_params, local_states, broadcasted_global)
        return self._with_global_embedding(params, local_states, global_obs, global_hidden_state, global_done, fn)

    def init(
        self,
        rng: jnp.ndarray,
        local_states: jnp.ndarray,
        global_obs: jnp.ndarray,
        global_hidden_state: jnp.ndarray,
        global_done: jnp.ndarray,
    ):
        rng_enc, rng_val = jax.random.split(rng)
        encoder_params = self.global_encoder.init(
            rng_enc, global_hidden_state, global_obs, global_done
        )["params"]

        # Compute one embedding to size the value inputs.
        _, init_embeddings = self.global_encoder.apply(
            {"params": encoder_params},
            global_hidden_state,
            global_obs,
            global_done,
        )
        broadcasted = self._broadcast_global_obs(local_states, init_embeddings[0])
        value_params = self.value.init(
            rng_val, local_states, broadcasted
        )["params"]
        return {"encoder": encoder_params, "value": value_params}


class MeanFieldValue:

    def __init__(
        self,
        state_type: str,
        num_states: int | None = None,
        value_kwargs: Optional[Dict[str, Any]] = None,
    ):
        self.state_type = state_type  # "states" or "indices"
        self.num_states = num_states
        default_value_kwargs = dict(
            activation="tanh",
            hidden_layer_sizes=(64, 64, 64),
            state_type=self.state_type,
            num_states=self.num_states,
        )
        if value_kwargs:
            default_value_kwargs.update(value_kwargs)
        
        self.value = ValueNetwork(**default_value_kwargs)

    def _broadcast_global_obs(self, local_states: jnp.ndarray, global_embedding: jnp.ndarray) -> jnp.ndarray:
        # --- expand global_embedding from [d,] to [1, d], and then broadcast to [N, d] ---
        if self.state_type == "states":
            assert local_states.ndim == 2 and global_embedding.ndim == 1
            return jnp.broadcast_to(global_embedding[None, :], (local_states.shape[0], global_embedding.size))
        else:
            assert local_states.ndim == 1 and global_embedding.ndim == 1
            return jnp.broadcast_to(global_embedding[None, :], (local_states.size, global_embedding.size))

    def _with_broadcasted_global(
        self,
        params: dict,
        local_states: jnp.ndarray,
        global_obs: jnp.ndarray,
        fn,
        *fn_inputs,
    ):
        in_axes = (0,) + tuple(0 for _ in fn_inputs)

        def _step(_obs, *extra):
            broadcasted_global = self._broadcast_global_obs(local_states, _obs)
            out = fn(broadcasted_global, *extra)
            return out

        return jax.vmap(_step, in_axes=in_axes)(global_obs, *fn_inputs)

    def __call__(
        self,
        params: dict,
        local_states: jnp.ndarray,
        global_obs: jnp.ndarray,
    ):
        def fn(broadcasted_global):
            return self.value.apply(params, local_states, broadcasted_global)
        return self._with_broadcasted_global(params, local_states, global_obs, fn)

    def init(
        self,
        rng: jnp.ndarray,
        local_states: jnp.ndarray,
        global_obs: jnp.ndarray,
    ):
        broadcasted = self._broadcast_global_obs(local_states, global_obs[0])
        return self.value.init(
            rng, local_states, broadcasted
    )