import flax.linen as nn
import jax
import jax.numpy as jnp

from .weights_overcooked import B_Module, W_Module


class QFixLin_Overcooked(nn.Module):
    """
    QFIX-lin fixing network for projecting IGM-incomplete fixees into IGM-complete values.

    Joint Q-values are computed as ``sum_i w_i * (Q_i - V_i) + b`` and joint
    V-values as ``b``.

    Attributes:
        hidden_size: forwarded to `W_Module` and `B_Module`.
        num_agents: forwarded to `W_Module`.
        debug_recover_fixee_w: when true, no `W_Module` is created and the
            weights are fixed to ones.
        debug_recover_fixee_b: when true, no `B_Module` is created and the bias
            is the sum of the individual V-values.
    """

    hidden_size: int
    num_agents: int

    debug_recover_fixee_w: bool
    debug_recover_fixee_b: bool

    def setup(self):
        # Submodules are only instantiated when their debug override is off,
        # so every use of `self.w_module` / `self.b_module` must be guarded by
        # the corresponding flag.
        if not self.debug_recover_fixee_w:
            self.w_module = W_Module(self.hidden_size, self.num_agents)
        if not self.debug_recover_fixee_b:
            self.b_module = B_Module(self.hidden_size)

    def __call__(
        self,
        individual_qvalues: jax.Array,
        individual_vvalues: jax.Array,
        states: jax.Array,
        joint_action_n_hot: jax.Array,
        *,
        w_delta: float,
        w_gt: float,
    ) -> jax.Array:
        """Compute the joint Q-values; thin alias for `qvalues`."""
        # individual_qvalues.shape == (N, T, B)
        # individual_vvalues.shape == (N, T, B)
        # states.shape == (T, B, DS)
        # joint_action_n_hot.shape == (T, B, N*A), N-hot encoding
        return self.qvalues(
            individual_qvalues,
            individual_vvalues,
            states,
            joint_action_n_hot,
            w_delta=w_delta,
            w_gt=w_gt,
        )

    def qvalues(
        self,
        individual_qvalues: jax.Array,
        individual_vvalues: jax.Array,
        states: jax.Array,
        joint_action_n_hot: jax.Array,
        *,
        w_delta: float,
        w_gt: float,
    ) -> jax.Array:
        """
        Compute the joint Q-values from per-agent Q- and V-values.

        `w_delta` and `w_gt` are forwarded unchanged to `W_Module`; they are
        unused when `debug_recover_fixee_w` is set.
        """
        # individual_qvalues.shape == (N, T, B)
        # individual_vvalues.shape == (N, T, B)
        # states.shape == (T, B, DS)
        # joint_action_n_hot.shape == (T, B, N*A), N-hot encoding

        individual_advantages = individual_qvalues - individual_vvalues
        # individual_advantages.shape == (N, T, B)

        if self.debug_recover_fixee_w:
            N, T, B = individual_advantages.shape
            w = jnp.ones((T, B, N))
            # w.shape == (T, B, N)
        else:
            w = self.w_module(
                states,
                joint_action_n_hot,
                w_delta=w_delta,
                w_gt=w_gt,
            )
            # w.shape == (T, B, N)

        if self.debug_recover_fixee_b:
            b = individual_vvalues.sum(axis=0)
            # b.shape == (T, B)
        else:
            # assumes B_Module returns a trailing singleton axis -- TODO confirm
            b = self.b_module(states).squeeze(-1)
            # b.shape == (T, B)

        # Weighted sum of advantages over the agent axis, plus the bias.
        joint_qvalues = jnp.einsum("TBN,NTB->TB", w, individual_advantages) + b
        # joint_qvalues.shape == (T, B)

        return joint_qvalues

    def vvalues(
        self,
        individual_vvalues: jax.Array,
        states: jax.Array,
    ) -> jax.Array:
        """
        Compute the joint V-values, i.e. the bias term used by `qvalues`.

        Uses the same bias as `qvalues`: the sum of individual V-values when
        `debug_recover_fixee_b` is set (no `B_Module` exists in that case),
        otherwise the `B_Module` output.
        """
        # individual_vvalues.shape == (N, T, B)
        # states.shape == (T, B, DS)

        if self.debug_recover_fixee_b:
            # `self.b_module` is not defined in this mode; mirror `qvalues`.
            b = individual_vvalues.sum(axis=0)
            # b.shape == (T, B)
        else:
            b = self.b_module(states).squeeze(-1)
            # b.shape == (T, B)

        joint_vvalues = b
        # joint_vvalues.shape == (T, B)

        return joint_vvalues


class AdditiveQFixLin_Overcooked(nn.Module):
    """
    Q+FIX-lin fixing network for projecting IGM-incomplete fixees into IGM-complete values.

    Joint Q-values are ``sum_i Q_i + sum_i w_i * (Q_i - V_i) + b`` and joint
    V-values are ``sum_i V_i + b``.
    """

    hidden_size: int
    num_agents: int

    detach_advantages: bool

    debug_recover_fixee_w: bool
    debug_recover_fixee_b: bool

    def setup(self):
        # Each submodule exists only while its debug override is disabled.
        if not self.debug_recover_fixee_w:
            self.w_module = W_Module(self.hidden_size, self.num_agents)
        if not self.debug_recover_fixee_b:
            self.b_module = B_Module(self.hidden_size)

    def __call__(
        self,
        individual_qvalues: jax.Array,
        individual_vvalues: jax.Array,
        states: jax.Array,
        joint_action_n_hot: jax.Array,
        *,
        w_delta: float,
        w_gt: float,
    ) -> jax.Array:
        """Build the weights and bias, then return the joint Q-values."""
        # individual_qvalues.shape == (N, T, B)
        # individual_vvalues.shape == (N, T, B)
        # states.shape == (T, B, DS)
        # joint_action_n_hot.shape == (T, B, N*A), N-hot encoding

        # Weights are evaluated before the bias, as keyword arguments are
        # evaluated left to right.
        return self.qvalues(
            individual_qvalues,
            individual_vvalues,
            states,
            joint_action_n_hot,
            # w.shape == (T, B, N)
            w=self.w(states, joint_action_n_hot, w_delta=w_delta, w_gt=w_gt),
            # b.shape == (T, B)
            b=self.b(states),
        )

    def qvalues(
        self,
        individual_qvalues: jax.Array,
        individual_vvalues: jax.Array,
        states: jax.Array,
        joint_action_n_hot: jax.Array,
        *,
        w: jax.Array,
        b: jax.Array,
    ) -> jax.Array:
        """
        Combine per-agent values with precomputed weights `w` and bias `b`.

        `states` and `joint_action_n_hot` are accepted but not read here.
        """
        # individual_qvalues.shape == (N, T, B)
        # individual_vvalues.shape == (N, T, B)
        # states.shape == (T, B, DS)
        # joint_action_n_hot.shape == (T, B, N*A), N-hot encoding
        # w.shape == (T, B, N)
        # b.shape == (T, B)

        advantages = individual_qvalues - individual_vvalues
        # advantages.shape == (N, T, B)
        if self.detach_advantages:
            # Block gradients from flowing through the advantage term.
            advantages = jax.lax.stop_gradient(advantages)

        summed_qvalues = individual_qvalues.sum(axis=0)
        # summed_qvalues.shape == (T, B)

        weighted_advantages = jnp.einsum("TBN,NTB->TB", w, advantages)
        # weighted_advantages.shape == (T, B)

        # Same left-to-right addition order: (fixee + correction) + bias.
        return summed_qvalues + weighted_advantages + b

    def vvalues(
        self,
        individual_vvalues: jax.Array,
        states: jax.Array,
        b: jax.Array,
    ) -> jax.Array:
        """Return the joint V-values: summed individual V-values plus `b`."""
        # individual_vvalues.shape == (N, T, B)
        # states.shape == (T, B, DS)
        # b.shape == (T, B)

        summed_vvalues = individual_vvalues.sum(axis=0)
        # summed_vvalues.shape == (T, B)

        # result shape == (T, B)
        return summed_vvalues + b

    def w(
        self,
        states: jax.Array,
        joint_action_n_hot: jax.Array,
        *,
        w_delta: float,
        w_gt: float,
    ) -> jax.Array:
        """Return the per-agent weights: `W_Module` output, or zeros in debug mode."""
        # states.shape == (T, B, DS)
        # joint_action_n_hot.shape == (T, B, N*A), N-hot encoding

        if not self.debug_recover_fixee_w:
            # result shape == (T, B, N)
            return self.w_module(
                states,
                joint_action_n_hot,
                w_delta=w_delta,
                w_gt=w_gt,
            )

        # Debug override: zero weights remove the advantage correction.
        num_steps, batch_size = states.shape[:2]
        # result shape == (T, B, N)
        return jnp.zeros_like(
            states,
            dtype=jnp.float32,
            shape=(num_steps, batch_size, self.num_agents),
        )

    def b(
        self,
        states: jax.Array,
    ) -> jax.Array:
        """Return the bias: squeezed `B_Module` output, or zeros in debug mode."""
        # states.shape == (T, B, DS)
        # TODO: original author flagged "shape wrong" here -- confirm that the
        # B_Module output squeezes to (T, B).

        if not self.debug_recover_fixee_b:
            # result shape == (T, B)
            return self.b_module(states).squeeze(-1)

        # Debug override: zero bias.
        num_steps, batch_size = states.shape[:2]
        # result shape == (T, B)
        return jnp.zeros_like(
            states,
            dtype=jnp.float32,
            shape=(num_steps, batch_size),
        )
