import flax.linen as nn
import jax
import jax.numpy as jnp

from .weights import B_Module, W_Module
from .weights_overcooked import B_Module as B_Module_Overcooked
from .weights_overcooked import W_Module as W_Module_Overcooked


class QFix(nn.Module):
    """
    QFIX fixing network for projecting IGM-incomplete fixees into IGM-complete values.

    Joint Q-values are computed as ``w * (Q_fixee - V_fixee) + b``, where
    ``Q_fixee``/``V_fixee`` are the fixee's outputs on the individual Q- and
    V-values. ``w`` and ``b`` come from learned state-conditioned modules
    unless the corresponding ``debug_recover_fixee_*`` flag is set, in which
    case ``w = 1`` and/or ``b = V_fixee``. With both flags set the output
    equals ``Q_fixee``.
    """

    # Width passed to the W/B weight modules.
    hidden_size: int
    # Mixer being "fixed"; called as ``fixee(individual_values, states)``.
    fixee: nn.Module

    # When True, no W module is created and ``w`` is the constant 1.
    debug_recover_fixee_w: bool
    # When True, no B module is created and ``b`` is the fixee's V-values.
    debug_recover_fixee_b: bool

    # Not read inside this class; presumably lets callers tell QFix apart from
    # AdditiveQFix (which defaults this to True) — confirm against callers.
    is_additive: bool = False
    # Selects the Overcooked-specific W/B module implementations.
    is_overcooked: bool = False

    def setup(self):
        # Only instantiate the learned modules that are actually used.
        if not self.debug_recover_fixee_w:
            W_Module_class = W_Module_Overcooked if self.is_overcooked else W_Module
            self.w_module = W_Module_class(self.hidden_size, 1)
        if not self.debug_recover_fixee_b:
            B_Module_class = B_Module_Overcooked if self.is_overcooked else B_Module
            self.b_module = B_Module_class(self.hidden_size)

    def qvalues(
        self,
        individual_qvalues: jax.Array,
        individual_vvalues: jax.Array,
        states: jax.Array,
        joint_action_n_hot: jax.Array,
        *,
        w_delta: float,
        w_gt: float,
    ) -> jax.Array:
        """Compute fixed joint Q-values ``w * (Q_fixee - V_fixee) + b``.

        Args:
            individual_qvalues: per-agent Q-values, shape (N, T, B).
            individual_vvalues: per-agent V-values, shape (N, T, B).
            states: global states, shape (T, B, DS).
            joint_action_n_hot: N-hot encoded joint action, shape (T, B, N*A).
            w_delta: forwarded unchanged to the W module.
            w_gt: forwarded unchanged to the W module.

        Returns:
            Joint Q-values of shape (T, B).
        """
        fixee_qvalues = self.fixee(individual_qvalues, states)
        # fixee_qvalues.shape == (T, B)
        fixee_vvalues = self.fixee(individual_vvalues, states)
        # fixee_vvalues.shape == (T, B)
        fixee_advantages = fixee_qvalues - fixee_vvalues
        # fixee_advantages.shape == (T, B)

        if self.debug_recover_fixee_w:
            # Scalar 1 broadcasts against (T, B) and leaves the advantages unchanged.
            w = 1
        else:
            w = self.w_module(
                states,
                joint_action_n_hot,
                w_delta=w_delta,
                w_gt=w_gt,
            )
            # w.shape == (T, B, 1)
            # NOTE: the W module keeps the last dimension for compatibility
            # with qfix-lin. Drop it here: multiplying (T, B, 1) with the
            # (T, B) advantages would otherwise broadcast to (T, B, B), or
            # fail outright when T != B.
            w = w.squeeze(-1)
            # w.shape == (T, B)

        if self.debug_recover_fixee_b:
            # Together with w = 1 this reproduces the fixee's Q-values exactly.
            b = fixee_vvalues
            # b.shape == (T, B)
        else:
            b = self.b_module(states).squeeze(-1)
            # b.shape == (T, B)

        joint_qvalues = w * fixee_advantages + b
        # joint_qvalues.shape == (T, B)

        return joint_qvalues



class AdditiveQFix(nn.Module):
    """
    Q+FIX fixing network for projecting IGM-incomplete fixees into IGM-complete values.

    Unlike QFix, the learned term is added on top of the fixee:

        joint_Q = Q_fixee + (w * (Q_fixee - V_fixee) + b)
        joint_V = V_fixee + b

    The bracketed term is the "intervention". The ``debug_recover_fixee_*``
    flags replace ``w`` and/or ``b`` with zeros. With both flags set the
    intervention vanishes and the fixee's outputs are returned unchanged.
    """

    # Width passed to the W/B weight modules.
    hidden_size: int
    # Mixer being "fixed"; called as ``fixee(individual_values, states)``.
    fixee: nn.Module

    # When True, gradients do not flow through the fixee advantages inside
    # the intervention term (they still flow through the Q_fixee summand).
    detach_advantages: bool

    # When True, no W module is created and ``w`` is all zeros.
    debug_recover_fixee_w: bool
    # When True, no B module is created and ``b`` is all zeros.
    debug_recover_fixee_b: bool

    # Not read inside this class; presumably lets callers tell AdditiveQFix
    # apart from QFix (which defaults this to False) — confirm against callers.
    is_additive: bool = True
    # Selects the Overcooked-specific W/B module implementations.
    is_overcooked: bool = False

    def setup(self):
        # Only instantiate the learned modules that are actually used.
        if not self.debug_recover_fixee_w:
            W_Module_class = W_Module_Overcooked if self.is_overcooked else W_Module
            self.w_module = W_Module_class(self.hidden_size, 1)
        if not self.debug_recover_fixee_b:
            B_Module_class = B_Module_Overcooked if self.is_overcooked else B_Module
            self.b_module = B_Module_class(self.hidden_size)

    def qvalues(
        self,
        individual_qvalues: jax.Array,
        individual_vvalues: jax.Array,
        states: jax.Array,
        joint_action_n_hot: jax.Array,
        *,
        w_delta: float,
        w_gt: float,
    ) -> tuple[jax.Array, dict]:
        """Compute fixed joint Q-values ``Q_fixee + w * A_fixee + b``.

        Args:
            individual_qvalues: per-agent Q-values, shape (N, T, B).
            individual_vvalues: per-agent V-values, shape (N, T, B).
            states: global states, shape (T, B, DS).
            joint_action_n_hot: N-hot encoded joint action, shape (T, B, N*A).
            w_delta: forwarded unchanged to ``self.w``.
            w_gt: forwarded unchanged to ``self.w``.

        Returns:
            A ``(joint_qvalues, info)`` pair with the following contents.
            ``joint_qvalues`` has shape (T, B).
            ``info`` holds ``"w"`` with shape (T, B, 1), ``"b"`` with shape
            (T, B), and ``"intervention"`` with shape (T, B).
        """
        fixee_qvalues = self.fixee(individual_qvalues, states)
        # fixee_qvalues.shape == (T, B)
        fixee_vvalues = self.fixee(individual_vvalues, states)
        # fixee_vvalues.shape == (T, B)
        fixee_advantages = fixee_qvalues - fixee_vvalues
        # fixee_advantages.shape == (T, B)

        if self.detach_advantages:
            fixee_advantages = jax.lax.stop_gradient(fixee_advantages)

        w = self.w(states, joint_action_n_hot, w_delta=w_delta, w_gt=w_gt)
        # w.shape == (T, B, 1)
        b = self.b(states)
        # b.shape == (T, B)

        # Squeeze the trailing singleton so w lines up with the (T, B) advantages.
        intervention = w.squeeze(-1) * fixee_advantages + b
        # intervention.shape == (T, B)
        joint_qvalues = fixee_qvalues + intervention
        # joint_qvalues.shape == (T, B)

        info = {
            "w": w,
            "b": b,
            "intervention": intervention,
        }
        return joint_qvalues, info

    def vvalues(
        self,
        individual_vvalues: jax.Array,
        states: jax.Array,
    ) -> tuple[jax.Array, dict]:
        """Compute fixed joint V-values ``V_fixee + b``.

        Only ``b`` contributes here. Presumably this is because the advantage
        term is zero for V-values, as in ``qvalues`` with Q = V.

        Args:
            individual_vvalues: per-agent V-values, shape (N, T, B).
            states: global states, shape (T, B, DS).

        Returns:
            A ``(joint_vvalues, info)`` pair with the following contents.
            ``joint_vvalues`` has shape (T, B).
            ``info`` holds ``"b"`` and ``"intervention"``, both with shape (T, B).
        """
        fixee_vvalues = self.fixee(individual_vvalues, states)
        # fixee_vvalues.shape == (T, B)

        b = self.b(states)
        # b.shape == (T, B)

        intervention = b
        # intervention.shape == (T, B)
        joint_vvalues = fixee_vvalues + intervention
        # joint_vvalues.shape == (T, B)

        info = {
            "b": b,
            "intervention": intervention,
        }
        return joint_vvalues, info

    def w(
        self,
        states: jax.Array,
        joint_action_n_hot: jax.Array,
        *,
        w_delta: float,
        w_gt: float,
    ) -> jax.Array:
        """Return the advantage weights ``w`` with shape (T, B, 1).

        Args:
            states: global states, shape (T, B, DS).
            joint_action_n_hot: N-hot encoded joint action, shape (T, B, N*A).
            w_delta: forwarded unchanged to the W module.
            w_gt: forwarded unchanged to the W module.
        """
        if self.debug_recover_fixee_w:
            # Zero weights remove the advantage term from the intervention.
            T, B, *_ = states.shape
            shape = T, B, 1
            w = jnp.zeros_like(states, shape=shape, dtype=jnp.float32)
        else:
            w = self.w_module(
                states,
                joint_action_n_hot,
                w_delta=w_delta,
                w_gt=w_gt,
            )
            # w.shape == (T, B, 1)
            # NOTE: keeping the last dimension for compatibility with qfix-lin

        return w

    def b(self, states: jax.Array) -> jax.Array:
        """Return the state-dependent offset ``b`` with shape (T, B).

        Args:
            states: global states, shape (T, B, DS).
        """
        if self.debug_recover_fixee_b:
            # A zero offset leaves the fixee's values untouched.
            T, B, *_ = states.shape
            shape = T, B
            b = jnp.zeros_like(states, shape=shape, dtype=jnp.float32)
            # b.shape == (T, B)
        else:
            b = self.b_module(states).squeeze(-1)
            # b.shape == (T, B)

        return b
