import flax.linen as nn
import jax
import jax.numpy as jnp

from .weights_overcooked import B_Module, W_Module


class QFix_Overcooked(nn.Module):
    """
    QFIX fixing network for projecting IGM-incomplete fixees into IGM-complete values.

    The joint Q-value is computed as ``w * (Q_fixee - V_fixee) + b``, where
    ``Q_fixee`` / ``V_fixee`` are the fixee applied to the individual Q- / V-values,
    ``w`` comes from ``W_Module`` and ``b`` from ``B_Module``.

    Attributes:
        hidden_size: hidden size forwarded to ``W_Module`` and ``B_Module``.
        fixee: module called as ``fixee(individual_values, states)``.
        debug_recover_fixee_w: if True, ``w`` is the constant 1 and no
            ``W_Module`` is created.
        debug_recover_fixee_b: if True, ``b`` is the fixee V-value and no
            ``B_Module`` is created. With both debug flags set, ``qvalues``
            returns exactly the fixee Q-values.
    """

    hidden_size: int
    fixee: nn.Module

    debug_recover_fixee_w: bool
    debug_recover_fixee_b: bool

    def setup(self):
        # Submodules are only built when they are actually used, so the debug
        # configurations carry no extra parameters.
        if not self.debug_recover_fixee_w:
            self.w_module = W_Module(self.hidden_size, 1)
        if not self.debug_recover_fixee_b:
            self.b_module = B_Module(self.hidden_size)

    def __call__(
        self,
        individual_qvalues: jax.Array,
        individual_vvalues: jax.Array,
        states: jax.Array,
        joint_action_n_hot: jax.Array,
        *,
        w_delta: float,
        w_gt: float,
    ) -> jax.Array:
        """Return the joint Q-values; thin wrapper around ``qvalues``."""
        return self.qvalues(
            individual_qvalues,
            individual_vvalues,
            states,
            joint_action_n_hot,
            w_delta=w_delta,
            w_gt=w_gt,
        )

    def qvalues(
        self,
        individual_qvalues: jax.Array,
        individual_vvalues: jax.Array,
        states: jax.Array,
        joint_action_n_hot: jax.Array,
        *,
        w_delta: float,
        w_gt: float,
    ) -> jax.Array:
        """Compute ``w * (Q_fixee - V_fixee) + b``.

        Args:
            individual_qvalues: per-agent Q-values, shape (N, T, B).
            individual_vvalues: per-agent V-values, shape (N, T, B).
            states: global states, shape (T, B, DS).
            joint_action_n_hot: N-hot joint action encoding, shape (T, B, N*A).
            w_delta: forwarded to ``W_Module``.
            w_gt: forwarded to ``W_Module``.

        Returns:
            Joint Q-values of shape (T, B).
        """
        fixee_qvalues = self.fixee(individual_qvalues, states)
        # fixee_qvalues.shape == (T, B)
        fixee_vvalues = self.fixee(individual_vvalues, states)
        # fixee_vvalues.shape == (T, B)
        fixee_advantages = fixee_qvalues - fixee_vvalues
        # fixee_advantages.shape == (T, B)

        if self.debug_recover_fixee_w:
            w = 1
        else:
            w = self.w_module(
                states,
                joint_action_n_hot,
                w_delta=w_delta,
                w_gt=w_gt,
            )
            # w.shape == (T, B, 1)
            # NOTE: keeping the last dimension for compatibility with qfix-lin
            # Drop the trailing singleton axis before combining with the (T, B)
            # advantages: (T, B, 1) * (T, B) would otherwise broadcast to
            # (T, B, B) (or raise when T != B) instead of giving (T, B).
            if jnp.ndim(w) == fixee_advantages.ndim + 1:
                w = jnp.squeeze(w, axis=-1)
            # w.shape == (T, B)

        if self.debug_recover_fixee_b:
            b = fixee_vvalues
            # b.shape == (T, B)
        else:
            b = self.b_module(states).squeeze(-1)
            # b.shape == (T, B)

        joint_qvalues = w * fixee_advantages + b
        # joint_qvalues.shape == (T, B)

        return joint_qvalues

    def vvalues(
        self,
        individual_vvalues: jax.Array,
        states: jax.Array,
    ) -> jax.Array:
        """Compute the joint V-values, which are just the bias term ``b``.

        Args:
            individual_vvalues: per-agent V-values, shape (N, T, B); only used
                when ``debug_recover_fixee_b`` is set.
            states: global states, shape (T, B, DS).

        Returns:
            Joint V-values of shape (T, B).
        """
        if self.debug_recover_fixee_b:
            fixee_vvalues = self.fixee(individual_vvalues, states)
            # fixee_vvalues.shape == (T, B)
            b = fixee_vvalues
            # b.shape == (T, B)
        else:
            b = self.b_module(states).squeeze(-1)
            # b.shape == (T, B)

        joint_vvalues = b
        # joint_vvalues.shape == (T, B)

        return joint_vvalues


class AdditiveQFix_Overcooked(nn.Module):
    """
    Q+FIX fixing network for projecting IGM-incomplete fixees into IGM-complete values.

    The joint Q-value is computed as ``Q_fixee + w * (Q_fixee - V_fixee) + b``,
    i.e. a learned correction added on top of the fixee's own Q-values.

    Attributes:
        hidden_size: hidden size forwarded to ``W_Module`` and ``B_Module``.
        fixee: module called as ``fixee(individual_values, states)``.
        detach_advantages: if True, gradients are not propagated through the
            fixee advantages inside the ``w`` term.
        debug_recover_fixee_w: if True, ``w`` is all zeros and no ``W_Module``
            is created.
        debug_recover_fixee_b: if True, ``b`` is all zeros and no ``B_Module``
            is created. With both debug flags set, the joint Q-values equal the
            fixee Q-values.
    """

    hidden_size: int
    fixee: nn.Module

    detach_advantages: bool

    debug_recover_fixee_w: bool
    debug_recover_fixee_b: bool

    def setup(self):
        # Submodules are only built when they are actually used, so the debug
        # configurations carry no extra parameters.
        if not self.debug_recover_fixee_w:
            self.w_module = W_Module(self.hidden_size, 1)
        if not self.debug_recover_fixee_b:
            self.b_module = B_Module(self.hidden_size)

    def __call__(
        self,
        individual_qvalues: jax.Array,
        individual_vvalues: jax.Array,
        states: jax.Array,
        joint_action_n_hot: jax.Array,
        *,
        w_delta: float,
        w_gt: float,
    ) -> jax.Array:
        """Compute ``w`` and ``b`` and return the joint Q-values.

        Args:
            individual_qvalues: per-agent Q-values, shape (N, T, B).
            individual_vvalues: per-agent V-values, shape (N, T, B).
            states: global states, shape (T, B, DS).
            joint_action_n_hot: N-hot joint action encoding, shape (T, B, N*A).
            w_delta: forwarded to ``W_Module``.
            w_gt: forwarded to ``W_Module``.

        Returns:
            Joint Q-values of shape (T, B).
        """
        w = self.w(states, joint_action_n_hot, w_delta=w_delta, w_gt=w_gt)
        # w.shape == (T, B, 1)
        b = self.b(states)
        # b.shape == (T, B)
        return self.qvalues(
            individual_qvalues,
            individual_vvalues,
            states,
            joint_action_n_hot,
            w=w,
            b=b,
        )

    def qvalues(
        self,
        individual_qvalues: jax.Array,
        individual_vvalues: jax.Array,
        states: jax.Array,
        joint_action_n_hot: jax.Array,
        *,
        w: jax.Array,
        b: jax.Array,
    ) -> jax.Array:
        """Compute ``Q_fixee + w * (Q_fixee - V_fixee) + b`` from precomputed w, b.

        Args:
            individual_qvalues: per-agent Q-values, shape (N, T, B).
            individual_vvalues: per-agent V-values, shape (N, T, B).
            states: global states, shape (T, B, DS).
            joint_action_n_hot: N-hot joint action encoding, shape (T, B, N*A);
                not used here, kept for interface parity.
            w: advantage weights, shape (T, B, 1) (shape (T, B) also accepted).
            b: bias term, shape (T, B).

        Returns:
            Joint Q-values of shape (T, B).
        """
        fixee_qvalues = self.fixee(individual_qvalues, states)
        # fixee_qvalues.shape == (T, B)
        fixee_vvalues = self.fixee(individual_vvalues, states)
        # fixee_vvalues.shape == (T, B)
        fixee_advantages = fixee_qvalues - fixee_vvalues
        # fixee_advantages.shape == (T, B)

        if self.detach_advantages:
            fixee_advantages = jax.lax.stop_gradient(fixee_advantages)

        # Drop w's trailing singleton axis before combining with the (T, B)
        # advantages: (T, B, 1) * (T, B) would otherwise broadcast to
        # (T, B, B) (or raise when T != B) instead of giving (T, B).
        if jnp.ndim(w) == fixee_advantages.ndim + 1:
            w = jnp.squeeze(w, axis=-1)
        # w.shape == (T, B)

        joint_qvalues = fixee_qvalues + w * fixee_advantages + b
        # joint_qvalues.shape == (T, B)

        return joint_qvalues

    def vvalues(
        self,
        individual_vvalues: jax.Array,
        states: jax.Array,
        b: jax.Array,
    ) -> jax.Array:
        """Compute the joint V-values ``V_fixee + b``.

        Args:
            individual_vvalues: per-agent V-values, shape (N, T, B).
            states: global states, shape (T, B, DS).
            b: bias term, shape (T, B).

        Returns:
            Joint V-values of shape (T, B).
        """
        fixee_vvalues = self.fixee(individual_vvalues, states)
        # fixee_vvalues.shape == (T, B)

        joint_vvalues = fixee_vvalues + b
        # joint_vvalues.shape == (T, B)

        return joint_vvalues

    def w(
        self,
        states: jax.Array,
        joint_action_n_hot: jax.Array,
        *,
        w_delta: float,
        w_gt: float,
    ) -> jax.Array:
        """Return the advantage weights, shape (T, B, 1).

        When ``debug_recover_fixee_w`` is set, returns float32 zeros instead of
        calling ``W_Module``.
        """
        # states.shape == (T, B, DS)
        # joint_action_n_hot.shape == (T, B, N*A), N-hot encoding

        if self.debug_recover_fixee_w:
            T, B, *_ = states.shape
            shape = T, B, 1
            w = jnp.zeros_like(states, shape=shape, dtype=jnp.float32)
        else:
            w = self.w_module(
                states,
                joint_action_n_hot,
                w_delta=w_delta,
                w_gt=w_gt,
            )
            # w.shape == (T, B, 1)
            # NOTE: keeping the last dimension for compatibility with qfix-lin

        return w

    def b(self, states: jax.Array) -> jax.Array:
        """Return the bias term, shape (T, B).

        When ``debug_recover_fixee_b`` is set, returns float32 zeros instead of
        calling ``B_Module``.
        """
        # states.shape == (T, B, DS)

        if self.debug_recover_fixee_b:
            T, B, *_ = states.shape
            shape = T, B
            b = jnp.zeros_like(states, shape=shape, dtype=jnp.float32)
            # b.shape == (T, B)
        else:
            b = self.b_module(states).squeeze(-1)
            # b.shape == (T, B)

        return b
