import torch
import torch.nn as nn
import torch.nn.functional as F

# Pooling mode names; these match the `mode` strings handled by
# build_vision_processor.
POOLMETHODS = ["adapool", "adapoolent"]

import pickle

def save_pickle(obj, path):
    """Serialize ``obj`` with pickle and write it to ``path``.

    The file is opened in binary write mode, so an existing file at ``path``
    is overwritten.
    """
    with open(path, "wb") as sink:
        pickle.Pickler(sink).dump(obj)

def build_vision_processor(mode, hidden_size, visual_use_feature=None):
    """Construct the pooling module selected by ``mode``.

    mode: "adapool" builds AdaPoolVisionToken; "adapoolent" builds
        AdaPoolVisionTokenWithENT with entity conditioning enabled.
    hidden_size: feature width forwarded to the module constructor.
    visual_use_feature: accepted but not used by this function.

    Returns the constructed module, or None when ``mode`` is not one of the
    two names above.
    """
    if mode == "adapool":
        return AdaPoolVisionToken(hidden_size)
    if mode == "adapoolent":
        return AdaPoolVisionTokenWithENT(hidden_size, use_ent_token_in_visual=True)
    # Unrecognised modes yield no processor.
    return None
   



class AdaPoolVisionToken(nn.Module):
    """Attention-weighted pooling of patch tokens into one projected vector.

    A linear layer assigns a scalar score to every token, the scores are
    softmax-normalised along the token axis (dim 1), and the resulting
    weighted sum of tokens is passed through a Linear-GELU-Linear projector.
    """

    def __init__(self, hidden_dim):
        super().__init__()
        # One scalar attention logit per token.
        self.attn_proj = nn.Linear(hidden_dim, 1)
        # Projector applied to the pooled vector; feature width is unchanged.
        mlp_layers = [
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim),
        ]
        self.projector = nn.Sequential(*mlp_layers)

    def forward(self, x):
        """Pool ``x`` over dim 1 and project the result.

        x: patch tokens, expected as [B, N, D] with D == hidden_dim
            (e.g. [B, 576, 4096]).
        Returns: pooled and projected features, [B, D].
        """
        token_logits = self.attn_proj(x)                # [B, N, 1]
        token_weights = token_logits.softmax(dim=1)     # normalised over tokens
        summary = torch.sum(x * token_weights, dim=1)   # [B, D]
        return self.projector(summary)
    


class AdaPoolVisionTokenWithENT(nn.Module):
    """Attention pooling over patch tokens, optionally conditioned on an entity token.

    With ``use_ent_token_in_visual=True`` the attention weights come from a
    scaled dot product between a projected entity token (query) and the
    projected patch tokens (keys). Otherwise a single linear layer scores
    each patch token on its own. In both modes the weighted sum of the
    patch tokens goes through a Linear-GELU-Linear projector.
    """

    def __init__(self, hidden_dim, use_ent_token_in_visual=False):
        super().__init__()
        self.use_ent_token_in_visual = use_ent_token_in_visual

        # Only the layers required by the selected mode are created, so the
        # parameter set (and the state_dict keys) depends on this flag.
        if use_ent_token_in_visual:
            self.q_proj = nn.Linear(hidden_dim, hidden_dim)
            self.k_proj = nn.Linear(hidden_dim, hidden_dim)
        else:
            self.attn_proj = nn.Linear(hidden_dim, 1)

        self.projector = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim),
        )

    def forward(self, x, ent_token=None):
        """Pool the patch tokens and project the pooled vector.

        x: visual patch tokens, expected as [B, N, D] (e.g. [B, 576, 4096]).
        ent_token: entity token features, expected as [B, D]. Required when
            the module was built with ``use_ent_token_in_visual=True``;
            ignored otherwise.

        Returns: pooled and projected features, [B, D].

        Raises:
            ValueError: if entity conditioning is enabled but ``ent_token``
                is None.
        """
        # Branch on the constructor flag rather than on the presence of
        # ``ent_token``: the flag decides which layers exist, so branching on
        # the argument alone could touch a layer that was never created.
        if self.use_ent_token_in_visual:
            if ent_token is None:
                raise ValueError(
                    "ent_token is required when use_ent_token_in_visual=True"
                )
            # Entity-conditioned attention: Q from entity token, K from patches.
            q = self.q_proj(ent_token).unsqueeze(1)       # [B, 1, D]
            k = self.k_proj(x)                            # [B, N, D]

            # Scaled dot product between the query and every patch key.
            attn_scores = (q * k).sum(dim=-1) / (x.size(-1) ** 0.5)     # [B, N]
            attn_weights = F.softmax(attn_scores, dim=1).unsqueeze(-1)  # [B, N, 1]
        else:
            # Unconditioned attention: one learned scalar score per patch.
            attn_scores = self.attn_proj(x)               # [B, N, 1]
            attn_weights = F.softmax(attn_scores, dim=1)

        pooled = (x * attn_weights).sum(dim=1)            # [B, D]
        return self.projector(pooled)


