import torch
import torch.nn.functional as F
from torch import nn
from typing import Dict, Optional, Tuple, Union, Any

from pipeline.registry import registry
from model.vision.lpss import LPSSSubspaceRouting


class FC(nn.Module):
    """Linear projection optionally followed by GELU and dropout.

    The activation and dropout submodules are only created when they are
    enabled, so a disabled stage adds no child module.

    Args:
        in_size: Input feature size of the linear layer.
        out_size: Output feature size of the linear layer.
        pdrop: Dropout probability; dropout is applied only when ``pdrop > 0``.
        use_gelu: Whether to apply GELU after the linear layer.
    """

    def __init__(self, in_size, out_size, pdrop=0.0, use_gelu=True):
        super().__init__()
        # Both flags are kept because forward() branches on them.
        self.pdrop = pdrop
        self.use_gelu = use_gelu

        self.linear = nn.Linear(in_size, out_size)
        if use_gelu:
            self.gelu = nn.GELU()
        if pdrop > 0:
            self.dropout = nn.Dropout(pdrop)

    def forward(self, x):
        """Apply linear -> (GELU) -> (dropout) to ``x`` and return the result."""
        out = self.linear(x)
        out = self.gelu(out) if self.use_gelu else out
        return self.dropout(out) if self.pdrop > 0 else out


class MLP(nn.Module):
    """Two-stage perceptron: an ``FC`` stage followed by a linear projection.

    Args:
        in_size: Input feature size.
        mid_size: Output size of the ``FC`` stage / input size of the
            final linear layer.
        out_size: Output feature size.
        pdrop: Dropout probability forwarded to the ``FC`` stage.
        use_gelu: Forwarded to the ``FC`` stage.
    """

    def __init__(self, in_size, mid_size, out_size, pdrop=0.0, use_gelu=True):
        super().__init__()
        self.fc = FC(in_size, mid_size, pdrop=pdrop, use_gelu=use_gelu)
        self.linear = nn.Linear(mid_size, out_size)

    def forward(self, x):
        """Run ``x`` through the ``FC`` stage, then the output projection."""
        hidden = self.fc(x)
        projected = self.linear(hidden)
        return projected


class AttFlat(nn.Module):
    """Attention-pool a sequence of features into one fixed-size vector.

    An MLP scores every sequence position once per glimpse, the scores are
    softmax-normalised over the sequence axis (dim 1), and the per-glimpse
    weighted sums of the input are concatenated and projected by
    ``linear_merge``.

    Args:
        hidden_size: Size of the last dimension of the input features.
        flat_mlp_size: Hidden width of the scoring MLP.
        flat_glimpses: Number of attention distributions (glimpses).
        flat_out_size: Size of the pooled output vector.
        pdrop: Dropout probability used inside the scoring MLP.
    """

    def __init__(self, hidden_size, flat_mlp_size=512, flat_glimpses=1, flat_out_size=1024, pdrop=0.1):
        super().__init__()
        self.mlp = MLP(
            in_size=hidden_size,
            mid_size=flat_mlp_size,
            out_size=flat_glimpses,
            pdrop=pdrop,
            use_gelu=True,
        )
        self.flat_glimpses = flat_glimpses
        self.linear_merge = nn.Linear(hidden_size * flat_glimpses, flat_out_size)

    def forward(self, x, x_mask):
        """Pool ``x`` over its sequence axis.

        Args:
            x: Features, indexed here as (batch, seq, hidden_size).
            x_mask: Optional boolean tensor that is True at positions to
                exclude from the attention. Any layout that flattens to
                (batch, seq) is accepted, e.g. (batch, seq) or
                (batch, 1, 1, seq). ``None`` disables masking.

        Returns:
            Tensor of shape (batch, flat_out_size).
        """
        att = self.mlp(x)
        if x_mask is not None:
            # flatten(1) rather than squeeze(1).squeeze(1): the squeeze chain
            # raised IndexError for a (batch, 1) mask because the first squeeze
            # already removed the sequence axis.
            exclude = x_mask.flatten(1).unsqueeze(2)
            # -1e9 is not representable in float16, so clamp the fill to the
            # dtype's most negative finite value; float32 / bfloat16 / float64
            # still receive exactly -1e9.
            fill_value = max(-1e9, torch.finfo(att.dtype).min)
            att = att.masked_fill(exclude, fill_value)
        att = F.softmax(att, dim=1)

        # One weighted sum over the sequence axis per glimpse, concatenated
        # glimpse-major along the feature axis.
        pooled = [
            torch.sum(att[:, :, g : g + 1] * x, dim=1)
            for g in range(self.flat_glimpses)
        ]
        x_atted = torch.cat(pooled, dim=1)
        return self.linear_merge(x_atted)


@registry.register_other_model("qa_head_lpss")
class QAHeadWithLPSS(nn.Module):
    """
    QA head with Late LPSS (logits injection into language attention).

    Steps performed by ``forward``:

    1. Object embeddings are attention-pooled by ``attflat_visual``.
    2. If LPSS is enabled, the pooled object feature is passed to
       ``LPSSSubspaceRouting``, which returns an attention bias and an
       info object.
    3. Text embeddings are attention-pooled; the bias (when present) is
       added to the text attention logits before the softmax.
    4. The pooled text and object features are summed, layer-normalised
       and mapped to answer scores by ``answer_qr``.

    Args:
        hidden_size: Feature size of the object and text embeddings.
        mlp_size: Hidden width of both attention-scoring MLPs.
        glimpse: Number of attention glimpses for both modalities.
        flat_out_size: Size of each pooled feature and of the fused feature.
        num_answers: Number of answer classes scored by ``answer_qr``.
        use_lpss: Build and use the LPSS routing module when True.
        num_subspaces, seq_len, lpss_temperature, lpss_dropout,
        lpss_bias_scale, lpss_router_gain, lpss_use_gumbel,
        lpss_gumbel_tau, lpss_gumbel_hard, lpss_top_k, lpss_grad_scale,
        lpss_logits_scale, lpss_enable_metrics: Forwarded unchanged to
            ``LPSSSubspaceRouting``; their meaning is defined there.
        dropout: Dropout probability inside the answer classifier.
    """

    def __init__(
        self,
        hidden_size: int = 768,
        mlp_size: int = 256,
        glimpse: int = 1,
        flat_out_size: int = 512,
        num_answers: int = 8864,
        use_lpss: bool = True,
        num_subspaces: int = 6,
        seq_len: int = 50,
        lpss_temperature: float = 0.3,
        lpss_dropout: float = 0.1,
        lpss_bias_scale: float = 15.0,
        lpss_router_gain: float = 1.5,
        lpss_use_gumbel: bool = True,
        lpss_gumbel_tau: float = 1.0,
        lpss_gumbel_hard: bool = False,
        lpss_top_k: int = 2,
        lpss_grad_scale: float = 10.0,
        lpss_logits_scale: float = 5.0,
        lpss_enable_metrics: bool = False,
        dropout: float = 0.3,
    ):
        super().__init__()

        self.hidden_size = hidden_size
        self.flat_out_size = flat_out_size
        self.use_lpss = use_lpss
        self.seq_len = seq_len

        self.attflat_visual = AttFlat(hidden_size, mlp_size, glimpse, flat_out_size, 0.1)

        # The language branch is spelled out (scoring MLP + merge layer)
        # instead of reusing AttFlat so that the LPSS bias can be added to
        # the logits between scoring and softmax.
        self.attflat_lang_mlp = MLP(
            in_size=hidden_size,
            mid_size=mlp_size,
            out_size=glimpse,
            pdrop=0.1,
            use_gelu=True,
        )
        self.attflat_lang_glimpses = glimpse
        self.attflat_lang_merge = nn.Linear(hidden_size * glimpse, flat_out_size)

        if use_lpss:
            self.lpss = LPSSSubspaceRouting(
                num_subspaces=num_subspaces,
                seq_len=seq_len,
                routing_dim=flat_out_size,
                hidden_dim=256,
                dropout=lpss_dropout,
                temperature=lpss_temperature,
                bias_scale_init=lpss_bias_scale,
                router_gain=lpss_router_gain,
                use_gumbel=lpss_use_gumbel,
                gumbel_tau=lpss_gumbel_tau,
                gumbel_hard=lpss_gumbel_hard,
                top_k=lpss_top_k,
                grad_scale=lpss_grad_scale,
                logits_scale=lpss_logits_scale,
                enable_metrics=lpss_enable_metrics,
            )
        else:
            self.lpss = None

        self.answer_qr = nn.Sequential(
            nn.Linear(flat_out_size, hidden_size),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size, num_answers),
        )

        self.fusion_norm = nn.LayerNorm(flat_out_size)

        # Caches of the most recent forward pass, for external inspection.
        # NOTE(review): _last_lpss_info is stored as returned by the LPSS
        # module (not detached); if it holds graph-attached tensors they stay
        # alive until the next forward call — confirm this is intended.
        self._last_lpss_info = None
        self._last_attention_weights = None

    def _attflat_lang_with_bias(
        self,
        txt_embeds: torch.Tensor,
        txt_masks: Optional[torch.Tensor],
        attention_bias: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Attention-pool the text embeddings with an optional logit bias.

        Args:
            txt_embeds: Text features, indexed here as (batch, seq, hidden).
            txt_masks: Optional boolean mask that is True at positions to
                keep; its inverse is filled with a large negative value
                before the softmax. Presumably (batch, seq) — it is
                unsqueezed on the last axis to broadcast over glimpses.
            attention_bias: Optional bias added to every glimpse's logits;
                presumably (batch, seq), since it is unsqueezed on the last
                axis before the addition.

        Returns:
            ``(lang_feat, attention_weights)``: the merged pooled feature and
            the softmax weights over the sequence axis.
        """
        logits = self.attflat_lang_mlp(txt_embeds)
        if attention_bias is not None:
            logits = logits + attention_bias.unsqueeze(-1)

        if txt_masks is not None:
            exclude = txt_masks.logical_not().unsqueeze(-1)
            # -1e9 is not representable in float16, so clamp the fill to the
            # dtype's most negative finite value; float32 / bfloat16 / float64
            # still receive exactly -1e9.
            fill_value = max(-1e9, torch.finfo(logits.dtype).min)
            logits = logits.masked_fill(exclude, fill_value)

        attention_weights = F.softmax(logits, dim=1)

        # One weighted sum over the sequence axis per glimpse, concatenated
        # along the feature axis.
        pooled = [
            torch.sum(attention_weights[:, :, g : g + 1] * txt_embeds, dim=1)
            for g in range(self.attflat_lang_glimpses)
        ]
        lang_feat = self.attflat_lang_merge(torch.cat(pooled, dim=1))
        return lang_feat, attention_weights

    def forward(
        self,
        obj_embeds: torch.Tensor,
        obj_masks: Optional[torch.Tensor],
        txt_embeds: torch.Tensor,
        txt_masks: Optional[torch.Tensor],
        return_lpss_info: bool = False,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, Any]]]:
        """Score the answer classes for a batch of object/text features.

        Args:
            obj_embeds: Object features passed to the visual attention pool.
            obj_masks: Boolean mask that is True at object positions to keep
                (it is inverted before being handed to ``AttFlat``).
                ``None`` disables object masking.
            txt_embeds: Text features passed to the language attention pool.
            txt_masks: Boolean mask that is True at text positions to keep,
                or ``None`` to disable text masking.
            return_lpss_info: When True and LPSS produced an info object,
                return it alongside the scores.

        Returns:
            The answer scores, or ``(answer_scores, lpss_info)`` when
            ``return_lpss_info`` is True and LPSS info is available. With
            LPSS disabled only the scores are returned, whatever the flag.
        """
        # AttFlat masks positions where its mask is True, so invert the
        # keep-mask; a missing mask is passed through as None (AttFlat
        # accepts it) instead of failing on None.logical_not().
        obj_exclude = None if obj_masks is None else obj_masks.logical_not()
        object_feat = self.attflat_visual(obj_embeds, obj_exclude)

        attention_bias = None
        lpss_info = None
        if self.lpss is not None:
            attention_bias, lpss_info = self.lpss(object_feat, return_info=True)
            self._last_lpss_info = lpss_info

        lang_feat, attention_weights = self._attflat_lang_with_bias(
            txt_embeds, txt_masks, attention_bias
        )
        # Detached so the cached weights do not keep the autograd graph alive.
        self._last_attention_weights = attention_weights.detach()

        fuse_feat = self.fusion_norm(lang_feat + object_feat)
        answer_scores = self.answer_qr(fuse_feat)

        if return_lpss_info and lpss_info is not None:
            return answer_scores, lpss_info
        return answer_scores
