import flax.linen as nn
import jax
import jax.numpy as jnp

from .weights import B_Module, W_Module
from .weights_overcooked import B_Module as B_Module_Overcooked
from .weights_overcooked import W_Module as W_Module_Overcooked


class QFixLin(nn.Module):
    """
    QFIX-lin fixing network for projecting IGM-incomplete fixees into IGM-complete values.

    The joint Q-value is computed as ``sum_i w_i * (Q_i - V_i) + b``, where
    ``w`` holds one weight per agent and ``b`` is a scalar bias per (T, B) entry.

    Attributes:
        hidden_size: Passed to the W and B sub-module constructors.
        num_agents: Passed to the W sub-module constructor.
        debug_recover_fixee_w: If True, no W sub-module is created and ``w`` is
            fixed to all ones.
        debug_recover_fixee_b: If True, no B sub-module is created and ``b`` is
            the sum of the individual V-values over agents. With both debug
            flags set the output reduces to the sum of individual Q-values.
        is_additive: Not read inside this class; presumably a tag consumed by
            callers -- TODO confirm.
        is_overcooked: Selects the ``weights_overcooked`` variants of the W and
            B sub-modules instead of the default ones.
    """

    hidden_size: int
    num_agents: int

    debug_recover_fixee_w: bool
    debug_recover_fixee_b: bool

    is_additive: bool = False
    is_overcooked: bool = False

    def setup(self):
        """Create the W and B sub-modules unless the matching debug flag is set."""
        if not self.debug_recover_fixee_w:
            W_Module_class = W_Module_Overcooked if self.is_overcooked else W_Module
            self.w_module = W_Module_class(self.hidden_size, self.num_agents)
        if not self.debug_recover_fixee_b:
            # Assign ``b_module`` exactly once, using the environment-specific class.
            B_Module_class = B_Module_Overcooked if self.is_overcooked else B_Module
            self.b_module = B_Module_class(self.hidden_size)

    def qvalues(
        self,
        individual_qvalues: jax.Array,
        individual_vvalues: jax.Array,
        states: jax.Array,
        joint_action_n_hot: jax.Array,
        *,
        w_delta: float,
        w_gt: float,
    ) -> jax.Array:
        """
        Compute joint Q-values from per-agent Q/V values.

        Args:
            individual_qvalues: Per-agent Q-values, shape (N, T, B).
            individual_vvalues: Per-agent V-values, shape (N, T, B).
            states: States, shape (T, B, DS).
            joint_action_n_hot: N-hot joint-action encoding, shape (T, B, N*A).
            w_delta: Forwarded unchanged to the W sub-module; unused when
                ``debug_recover_fixee_w`` is set.
            w_gt: Forwarded unchanged to the W sub-module; unused when
                ``debug_recover_fixee_w`` is set.

        Returns:
            Joint Q-values, shape (T, B).
        """
        # individual_qvalues.shape == (N, T, B)
        # individual_vvalues.shape == (N, T, B)
        # states.shape == (T, B, DS)
        # joint_action_n_hot.shape == (T, B, N*A), N-hot encoding

        individual_advantages = individual_qvalues - individual_vvalues
        # individual_advantages.shape == (N, T, B)

        if self.debug_recover_fixee_w:
            # Unit weights: the weighted sum below becomes a plain sum over agents.
            N, T, B = individual_advantages.shape
            w = jnp.ones((T, B, N))
            # w.shape == (T, B, N)
        else:
            w = self.w_module(
                states,
                joint_action_n_hot,
                w_delta=w_delta,
                w_gt=w_gt,
            )
            # w.shape == (T, B, N)

        if self.debug_recover_fixee_b:
            # Summed V-values as bias: together with unit weights this yields
            # sum_i (Q_i - V_i) + sum_i V_i == sum_i Q_i.
            b = individual_vvalues.sum(axis=0)
            # b.shape == (T, B)
        else:
            b = self.b_module(states).squeeze(-1)
            # b.shape == (T, B)

        # Contract the agent axis: weighted sum of advantages, plus bias.
        joint_qvalues = jnp.einsum("TBN,NTB->TB", w, individual_advantages) + b
        # joint_qvalues.shape == (T, B)

        return joint_qvalues

    # def vvalues(
    #     self,
    #     individual_vvalues: jax.Array,
    #     states: jax.Array,
    # ) -> jax.Array:
    #     # individual_vvalues.shape == (N, T, B)
    #     # states.shape == (T, B, DS)
    #
    #     b = self.b_module(states).squeeze(-1)
    #     # b.shape == (T, B)
    #
    #     joint_vvalues = b
    #     # joint_vvalues.shape == (T, B)
    #
    #     return joint_vvalues


class AdditiveQFixLin(nn.Module):
    """
    Q+FIX-lin fixing network for projecting IGM-incomplete fixees into IGM-complete values.

    The fixee is the sum of the individual values over agents. This module adds
    an "intervention" term on top of it:

    * Q: ``sum_i Q_i + (sum_i w_i * (Q_i - V_i) + b)``
    * V: ``sum_i V_i + b``

    Attributes:
        hidden_size: Passed to the W and B sub-module constructors.
        num_agents: Passed to the W sub-module constructor; also the size of
            the last axis of the debug ``w``.
        detach_advantages: If True, gradients are stopped through the
            advantages used in the intervention term (the fixee sum is not
            detached).
        debug_recover_fixee_w: If True, no W sub-module is created and ``w`` is
            all zeros.
        debug_recover_fixee_b: If True, no B sub-module is created and ``b`` is
            all zeros. With both debug flags set the intervention vanishes and
            the outputs equal the fixee sums.
        is_additive: Not read inside this class; presumably a tag consumed by
            callers -- TODO confirm.
        is_overcooked: Selects the ``weights_overcooked`` variants of the W and
            B sub-modules instead of the default ones.
    """

    hidden_size: int
    num_agents: int

    detach_advantages: bool

    debug_recover_fixee_w: bool
    debug_recover_fixee_b: bool

    is_additive: bool = True
    is_overcooked: bool = False

    def setup(self):
        """Create the W and B sub-modules unless the matching debug flag is set."""
        overcooked = self.is_overcooked
        if not self.debug_recover_fixee_w:
            w_cls = W_Module_Overcooked if overcooked else W_Module
            self.w_module = w_cls(self.hidden_size, self.num_agents)
        if not self.debug_recover_fixee_b:
            b_cls = B_Module_Overcooked if overcooked else B_Module
            self.b_module = b_cls(self.hidden_size)

    def qvalues(
        self,
        individual_qvalues: jax.Array,
        individual_vvalues: jax.Array,
        states: jax.Array,
        joint_action_n_hot: jax.Array,
        *,
        w_delta: float,
        w_gt: float,
    ) -> tuple[jax.Array, dict]:
        """
        Compute joint Q-values as the fixee sum plus the learned intervention.

        Args:
            individual_qvalues: Per-agent Q-values, shape (N, T, B).
            individual_vvalues: Per-agent V-values, shape (N, T, B).
            states: States, shape (T, B, DS).
            joint_action_n_hot: N-hot joint-action encoding, shape (T, B, N*A).
            w_delta: Forwarded unchanged to ``self.w``.
            w_gt: Forwarded unchanged to ``self.w``.

        Returns:
            ``(joint_qvalues, info)`` where ``joint_qvalues`` has shape (T, B)
            and ``info`` holds the ``"w"``, ``"b"`` and ``"intervention"``
            intermediates.
        """
        # Fixee: plain sum of the per-agent Q-values, shape (T, B).
        fixee_qvalues = individual_qvalues.sum(axis=0)

        # Per-agent advantages, shape (N, T, B).
        advantages = individual_qvalues - individual_vvalues
        if self.detach_advantages:
            # Block gradients through the advantages of the intervention term only.
            advantages = jax.lax.stop_gradient(advantages)

        weights = self.w(states, joint_action_n_hot, w_delta=w_delta, w_gt=w_gt)
        # weights.shape == (T, B, N)
        bias = self.b(states)
        # bias.shape == (T, B)

        # Weighted sum of advantages over the agent axis, plus bias; shape (T, B).
        weighted_advantages = jnp.einsum("TBN,NTB->TB", weights, advantages)
        intervention = weighted_advantages + bias

        joint_qvalues = fixee_qvalues + intervention
        # joint_qvalues.shape == (T, B)

        return joint_qvalues, {
            "w": weights,
            "b": bias,
            "intervention": intervention,
        }

    def vvalues(
        self,
        individual_vvalues: jax.Array,
        states: jax.Array,
    ) -> tuple[jax.Array, dict]:
        """
        Compute joint V-values as the fixee sum plus the bias.

        Args:
            individual_vvalues: Per-agent V-values, shape (N, T, B).
            states: States, shape (T, B, DS).

        Returns:
            ``(joint_vvalues, info)`` where ``joint_vvalues`` has shape (T, B)
            and ``info`` holds ``"b"`` and ``"intervention"`` (the same array).
        """
        # Fixee: plain sum of the per-agent V-values, shape (T, B).
        fixee_vvalues = individual_vvalues.sum(axis=0)

        # For V-values the intervention is the bias alone, shape (T, B).
        bias = self.b(states)

        joint_vvalues = fixee_vvalues + bias
        # joint_vvalues.shape == (T, B)

        return joint_vvalues, {
            "b": bias,
            "intervention": bias,
        }

    def w(
        self,
        states: jax.Array,
        joint_action_n_hot: jax.Array,
        *,
        w_delta: float,
        w_gt: float,
    ) -> jax.Array:
        """
        Return the per-agent weights of shape (T, B, N).

        Delegates to the W sub-module, or returns float32 zeros when
        ``debug_recover_fixee_w`` is set.

        Args:
            states: States, shape (T, B, DS).
            joint_action_n_hot: N-hot joint-action encoding, shape (T, B, N*A).
            w_delta: Forwarded unchanged to the W sub-module.
            w_gt: Forwarded unchanged to the W sub-module.
        """
        if not self.debug_recover_fixee_w:
            return self.w_module(
                states,
                joint_action_n_hot,
                w_delta=w_delta,
                w_gt=w_gt,
            )

        # Debug path: zero weights switch off the advantage part of the intervention.
        num_steps, batch_size, _ = states.shape
        return jnp.zeros_like(
            states,
            shape=(num_steps, batch_size, self.num_agents),
            dtype=jnp.float32,
        )

    def b(
        self,
        states: jax.Array,
    ) -> jax.Array:
        """
        Return the bias, intended shape (T, B).

        Delegates to the B sub-module (squeezing its last axis), or returns
        float32 zeros when ``debug_recover_fixee_b`` is set.

        Args:
            states: States, shape (T, B, DS).
        """
        # TODO shape wrong (carried over from the original author; which shape
        # is meant is not stated -- confirm against the B sub-module's output).

        if not self.debug_recover_fixee_b:
            return self.b_module(states).squeeze(-1)

        # Debug path: zero bias.
        num_steps, batch_size, *_ = states.shape
        return jnp.zeros_like(
            states,
            shape=(num_steps, batch_size),
            dtype=jnp.float32,
        )
