import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.linen.initializers import constant, orthogonal

from .cnn import CNN
from .mlp import MLP


class HyperNetwork(nn.Module):
    """HyperNetwork for generating weights of QMix' mixing network.

    The current body is a single linear projection of the conditioning
    input ``x`` to ``output_dim`` features.

    Attributes:
        hidden_dim: width of a hidden layer; declared in the interface but
            not used by the current single-layer body.
        output_dim: number of output features produced per input row.
        init_scale: gain passed to the orthogonal kernel initializer
            (biases are initialised to zero).
    """

    hidden_dim: int
    output_dim: int
    init_scale: float

    @nn.compact
    def __call__(self, x):
        """Project ``x`` linearly to ``output_dim`` features."""
        projection = nn.Dense(
            features=self.output_dim,
            kernel_init=orthogonal(self.init_scale),
            bias_init=constant(0.0),
        )
        return projection(x)


class QMIX_Overcooked(nn.Module):
    """Monotonic mixer projecting individual utilities into joint Q-values.

    A state-conditioned hypernetwork emits ``embedding_dim`` weights per
    agent. The weights are made non-negative with ``abs`` so the joint value
    is monotonic in every agent's utility. A second hypernetwork emits a
    scalar bias.

    The joint value is::

        q_tot[t, b] = sum_{n, d} q[n, t, b] * |w[t, b, n, d]| + bias[t, b]

    NOTE: unlike the two-layer mixer of the original QMix paper, there is no
    hidden mixing layer, ELU, or second weight hypernetwork here.

    Attributes:
        embedding_dim: number of weights generated per agent (D).
        hypernet_hidden_dim: declared but not used by the current body.
        init_scale: forwarded as ``init_scale`` to each ``MLP`` hypernetwork.
        state_module: optional module applied to ``states`` before the
            hypernetworks; skipped when ``None``.
    """

    embedding_dim: int
    hypernet_hidden_dim: int
    init_scale: float

    state_module: nn.Module | None = None

    @nn.compact
    def __call__(self, individual_qvalues: jax.Array, states):
        """Mix per-agent utilities into one value per (time, batch) entry.

        Args:
            individual_qvalues: array of shape (N, T, B) holding each agent's
                chosen utility.
            states: conditioning input for the hypernetworks. Its MLP output
                is reshaped to (T, B, N, D), so it presumably has leading
                dims (T, B); confirm against callers.

        Returns:
            Array of shape (T, B) with the mixed joint values.
        """
        n_agents, n_steps, n_batch = individual_qvalues.shape

        encoded = states if self.state_module is None else self.state_module(states)

        # NOTE: the two MLPs below must keep this creation order; Flax
        # auto-names inline submodules by order (MLP_0, MLP_1), which fixes
        # both parameter names and their init RNG streams.

        # Per-agent mixing weights, made non-negative for monotonicity.
        raw_weights = MLP(
            features=[n_agents * self.embedding_dim],
            init_scale=self.init_scale,
        )(encoded)
        agent_weights = jnp.abs(
            raw_weights.reshape(n_steps, n_batch, n_agents, self.embedding_dim)
        )  # (T, B, N, D)

        # State-dependent bias with a trailing singleton feature axis (T, B, 1).
        bias = MLP(
            features=[1],
            init_scale=self.init_scale,
        )(encoded)

        # Contract over both agents (n) and embedding dims (d).
        weighted_sum = jnp.einsum(
            "ntb,tbnd->tb", individual_qvalues, agent_weights
        )
        return weighted_sum + bias.squeeze(-1)
