"""
Wrappers for mean-field Q-networks.
"""

from typing import Optional, Dict, Any

import jax
import jax.numpy as jnp

from mfax.utils.nets.qnet import DiscreteQNet, OrdinalQNet
from mfax.utils.nets.base import RecurrentEncoder


class RecurrentMeanFieldQNet:

    def __init__(
        self,
        state_type: str,
        num_states: int | None = None,
        q_net_type: str = "discrete",
        q_net_kwargs: Optional[Dict[str, Any]] = None,
        encoder_kwargs: Optional[Dict[str, Any]] = None,
    ):
        self.state_type = state_type  # "states" or "indices"
        self.num_states = num_states

        default_q_net_kwargs = dict(
            activation="tanh",
            hidden_layer_sizes=(64, 64, 64),
            n_actions=1,
            state_type=self.state_type,
            num_states=self.num_states,
        )
        if q_net_kwargs:
            default_q_net_kwargs.update(q_net_kwargs)

        default_encoder_kwargs = dict(hidden_size=64, embed_size=64, activation="tanh")
        if encoder_kwargs:
            default_encoder_kwargs.update(encoder_kwargs)
        self.global_encoder = RecurrentEncoder(**default_encoder_kwargs)
        if q_net_type == "discrete":
            self.q_net = DiscreteQNet(**default_q_net_kwargs)
        elif q_net_type == "ordinal":
            self.q_net = OrdinalQNet(**default_q_net_kwargs)
        else:
            raise ValueError(f"Invalid q_net_type: {q_net_type}. Expected 'discrete' or 'ordinal'.")

    @staticmethod
    def init_hidden(batch_size: int, hidden_size: int) -> jnp.ndarray:
        return RecurrentEncoder.init_hidden(batch_size, hidden_size)

    def _broadcast_global_obs(self, local_states: jnp.ndarray, global_embedding: jnp.ndarray) -> jnp.ndarray:
        # --- expand global_embedding from [d,] to [1, d], and then broadcast to [N, d] ---
        if self.state_type == "states":
            assert local_states.ndim == 2 and global_embedding.ndim == 1
            return jnp.broadcast_to(global_embedding[None, :], (local_states.shape[0], global_embedding.size))
        else:
            assert local_states.ndim == 1 and global_embedding.ndim == 1
            return jnp.broadcast_to(global_embedding[None, :], (local_states.size, global_embedding.size))

    def _with_global_embedding(
        self,
        params: dict,
        local_states: jnp.ndarray,
        global_obs: jnp.ndarray,
        global_hidden_state: jnp.ndarray,
        global_done: jnp.ndarray,
        fn,
        *fn_inputs,
    ):
        in_axes = (0, 0, 0) + tuple(0 for _ in fn_inputs)

        def _step(_obs, _hidden, _done, *extra):
            new_hidden, global_embedding = self.global_encoder.apply(
                {"params": params["encoder"]},
                _hidden,
                _obs,
                _done,
            )
            broadcasted_global = self._broadcast_global_obs(local_states, global_embedding)
            out = fn(broadcasted_global, *extra)
            return out, new_hidden

        return jax.vmap(_step, in_axes=in_axes)(global_obs, global_hidden_state, global_done, *fn_inputs)

    def __call__(
        self,
        params: dict,
        local_states: jnp.ndarray,
        global_obs: jnp.ndarray,
        global_hidden_state: jnp.ndarray,
        global_done: jnp.ndarray,
    ):
        q_params = {"params": params["q_net"]}
        def fn(broadcasted_global):
            return self.q_net.apply(q_params, local_states, broadcasted_global)
        return self._with_global_embedding(params, local_states, global_obs, global_hidden_state, global_done, fn)

    def init(
        self,
        rng: jnp.ndarray,
        local_states: jnp.ndarray,
        global_obs: jnp.ndarray,
        global_hidden_state: jnp.ndarray,
        global_done: jnp.ndarray,
    ):
        rng_enc, rng_q = jax.random.split(rng)
        encoder_params = self.global_encoder.init(
            rng_enc, global_hidden_state, global_obs, global_done
        )["params"]

        # Compute one embedding to size the Q-net inputs.
        _, init_embeddings = self.global_encoder.apply(
            {"params": encoder_params},
            global_hidden_state,
            global_obs,
            global_done,
        )
        broadcasted = self._broadcast_global_obs(local_states, init_embeddings[0])
        q_params = self.q_net.init(rng_q, local_states, broadcasted)["params"]
        return {"encoder": encoder_params, "q_net": q_params}

    def softmax(
        self,
        params: dict,
        local_states: jnp.ndarray,
        global_obs: jnp.ndarray,
        global_hidden_state: jnp.ndarray,
        global_done: jnp.ndarray,
    ):
        q_params = {"params": params["q_net"]}
        def fn(broadcasted_global):
            return self.q_net.apply(q_params, local_states, broadcasted_global, method="softmax")
        return self._with_global_embedding(
            params, local_states, global_obs, global_hidden_state, global_done, fn
        )

    def epsilon_greedy(
        self,
        params: dict,
        local_states: jnp.ndarray,
        global_obs: jnp.ndarray,
        global_hidden_state: jnp.ndarray,
        global_done: jnp.ndarray,
        eps: float,
        rng: jnp.ndarray,
    ):
        q_params = {"params": params["q_net"]}
        rngs = jax.random.split(rng, global_obs.shape[0])
        def fn(broadcasted_global, rng_i):
            return self.q_net.apply(q_params, local_states, broadcasted_global, eps, rng_i, method="epsilon_greedy")
        return self._with_global_embedding(
            params, local_states, global_obs, global_hidden_state, global_done, fn, rngs
        )

    def argmax(
        self,
        params: dict,
        local_states: jnp.ndarray,
        global_obs: jnp.ndarray,
        global_hidden_state: jnp.ndarray,
        global_done: jnp.ndarray,
    ):
        q_params = {"params": params["q_net"]}
        def fn(broadcasted_global):
            return self.q_net.apply(q_params, local_states, broadcasted_global, method="argmax")
        return self._with_global_embedding(
            params, local_states, global_obs, global_hidden_state, global_done, fn
        )

    def sample_softmax(
        self,
        params: dict,
        local_states: jnp.ndarray,
        global_obs: jnp.ndarray,
        global_hidden_state: jnp.ndarray,
        global_done: jnp.ndarray,
        rng: jnp.ndarray,
    ):
        q_params = {"params": params["q_net"]}
        rngs = jax.random.split(rng, global_obs.shape[0])
        def fn(broadcasted_global, rng_i):
            return self.q_net.apply(q_params, local_states, broadcasted_global, rng_i, method="sample_softmax")
        return self._with_global_embedding(
            params, local_states, global_obs, global_hidden_state, global_done, fn, rngs
        )

    def take_action(
        self,
        params: dict,
        local_states: jnp.ndarray,
        global_obs: jnp.ndarray,
        global_hidden_state: jnp.ndarray,
        global_done: jnp.ndarray,
        action_idxs: jnp.ndarray,
    ):
        q_params = {"params": params["q_net"]}
        def fn(broadcasted_global, _action_idxs):
            return self.q_net.apply(q_params, local_states, broadcasted_global, _action_idxs, method="take_action")
        return self._with_global_embedding(
            params, local_states, global_obs, global_hidden_state, global_done, fn, action_idxs
        )


class MeanFieldQNet:

    def __init__(
        self,
        state_type: str,
        num_states: int | None = None,
        q_net_type: str = "discrete",
        q_net_kwargs: Optional[Dict[str, Any]] = None,
    ):
        self.state_type = state_type  # "states" or "indices"
        self.num_states = num_states

        default_q_net_kwargs = dict(
            activation="tanh",
            hidden_layer_sizes=(64, 64, 64),
            n_actions=1,
            state_type=self.state_type,
            num_states=self.num_states,
        )
        if q_net_kwargs:
            default_q_net_kwargs.update(q_net_kwargs)
        if q_net_type == "discrete":
            self.q_net = DiscreteQNet(**default_q_net_kwargs)
        elif q_net_type == "ordinal":
            self.q_net = OrdinalQNet(**default_q_net_kwargs)
        else:
            raise ValueError(f"Invalid q_net_type: {q_net_type}. Expected 'discrete' or 'ordinal'.")

    def _broadcast_global_obs(self, local_states: jnp.ndarray, global_embedding: jnp.ndarray) -> jnp.ndarray:
        if self.state_type == "states":
            assert local_states.ndim == 2 and global_embedding.ndim == 1
            return jnp.broadcast_to(global_embedding[None, :], (local_states.shape[0], global_embedding.size))
        else:
            assert local_states.ndim == 1 and global_embedding.ndim == 1
            return jnp.broadcast_to(global_embedding[None, :], (local_states.size, global_embedding.size))

    def _with_broadcasted_global(
        self,
        params: dict,
        local_states: jnp.ndarray,
        global_obs: jnp.ndarray,
        fn,
        *fn_inputs,
    ):
        in_axes = (0,) + tuple(0 for _ in fn_inputs)

        def _step(_obs, *extra):
            broadcasted_global = self._broadcast_global_obs(local_states, _obs)
            out = fn(broadcasted_global, *extra)
            return out

        return jax.vmap(_step, in_axes=in_axes)(global_obs, *fn_inputs)

    def __call__(
        self,
        params: dict,
        local_states: jnp.ndarray,
        global_obs: jnp.ndarray,
    ):
        def fn(broadcasted_global):
            return self.q_net.apply(params, local_states, broadcasted_global)
        return self._with_broadcasted_global(
            params, local_states, global_obs, fn
        )

    def init(
        self,
        rng: jnp.ndarray,
        local_states: jnp.ndarray,
        global_obs: jnp.ndarray,
    ):
        broadcasted = self._broadcast_global_obs(local_states, global_obs[0])
        return self.q_net.init(
            rng, local_states, broadcasted
        )

    def softmax(
        self,
        params: dict,
        local_states: jnp.ndarray,
        global_obs: jnp.ndarray,
    ):
        def fn(broadcasted_global):
            return self.q_net.apply(params, local_states, broadcasted_global, method="softmax")
        return self._with_broadcasted_global(
            params, local_states, global_obs, fn
        )

    def epsilon_greedy(
        self,
        params: dict,
        local_states: jnp.ndarray,
        global_obs: jnp.ndarray,
        eps: float,
        rng: jnp.ndarray,
    ):
        rngs = jax.random.split(rng, global_obs.shape[0])
        def fn(broadcasted_global, rng_i):
            return self.q_net.apply(params, local_states, broadcasted_global, eps, rng_i, method="epsilon_greedy")
        return self._with_broadcasted_global(
            params, local_states, global_obs, fn, rngs
        )

    def argmax(
        self,
        params: dict,
        local_states: jnp.ndarray,
        global_obs: jnp.ndarray,
    ):
        def fn(broadcasted_global):
            return self.q_net.apply(params, local_states, broadcasted_global, method="argmax")
        return self._with_broadcasted_global(
            params, local_states, global_obs, fn
        )

    def sample_softmax(
        self,
        params: dict,
        local_states: jnp.ndarray,
        global_obs: jnp.ndarray,
        rng: jnp.ndarray,
    ):
        rngs = jax.random.split(rng, global_obs.shape[0])
        def fn(broadcasted_global, rng_i):
            return self.q_net.apply(params, local_states, broadcasted_global, rng_i, method="sample_softmax")
        return self._with_broadcasted_global(
            params, local_states, global_obs, fn, rngs
        )

    def take_action(
        self,
        params: dict,
        local_states: jnp.ndarray,
        global_obs: jnp.ndarray,
        action_idxs: jnp.ndarray,
    ):
        def fn(broadcasted_global, _action_idxs):
            return self.q_net.apply(params, local_states, broadcasted_global, _action_idxs, method="take_action")
        return self._with_broadcasted_global(
            params, local_states, global_obs, fn, action_idxs
        )
