import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import BertModel, BertTokenizer, BertTokenizerFast
from collections import defaultdict
from transformers import T5Tokenizer, T5ForConditionalGeneration
from transformers import CLIPProcessor, CLIPModel, CLIPTokenizerFast
from transformers import AutoTokenizer, AutoModel, CLIPVisionModel
from torchvision.models.detection import fasterrcnn_resnet50_fpn
import torchvision
from utils import chunk_by_sentences

# Embedding extraction
def spectral_token_compression(hidden_states, K=4, gate='none', normalize_positions=True):
    """
    Apply Spectral Token Compression (STC) to token-wise hidden states.

    Each sequence is summarised by its mean token plus its projections onto
    the first ``K`` cosine and sine basis functions over the position axis.

    Args:
        hidden_states: tensor of shape [B, L, D].
        K: number of frequencies.
        gate: ``'softmax'`` re-weights the K frequency components with the
            fixed weights ``softmax([1, ..., K])``; any other value disables it.
        normalize_positions: if True, positions are ``linspace(0, 1, L)`` and
            frequency k is ``pi * k``; otherwise positions are ``1..L`` and
            frequency k is ``pi * k / L``. Either way the phase of basis k is
            ``pi * k * t`` with ``t`` spanning (roughly) [0, 1].

    Returns:
        Tensor of shape [B, (2K+1)*D]; each row is laid out as
        ``[mean (D), C_1..C_K (K*D), S_1..S_K (K*D)]``.
    """
    B, L, D = hidden_states.shape
    device = hidden_states.device
    freqs = torch.arange(1, K + 1, device=device, dtype=torch.float32)  # [K]
    if normalize_positions:
        positions = torch.linspace(0, 1, steps=L, device=device)
        # Positions already span [0, 1]. Dividing the frequency by L as well
        # (as a previous version did) made every basis function nearly
        # constant, so C ~ sum(hs) and S ~ 0. NOTE: this changes the output
        # for normalize_positions=True relative to that version.
        omega = torch.pi * freqs
    else:
        positions = torch.arange(1, L + 1, device=device, dtype=torch.float32)
        omega = torch.pi * freqs / L
    angles = torch.outer(omega, positions)  # [K, L]
    # Build the basis in float32, then match the input dtype so the matmul
    # also works for half / bfloat16 / float64 hidden states.
    cos_matrix = torch.cos(angles).to(hidden_states.dtype)
    sin_matrix = torch.sin(angles).to(hidden_states.dtype)

    # [K, L] @ [B, L, D] broadcasts over the batch -> [B, K, D].
    C = torch.matmul(cos_matrix, hidden_states)
    S = torch.matmul(sin_matrix, hidden_states)
    mean = hidden_states.mean(dim=1, keepdim=True)  # [B, 1, D]

    if gate == 'softmax':
        weights = torch.softmax(freqs, dim=0).to(hidden_states.dtype)[:, None]  # [K, 1]
        C = C * weights
        S = S * weights

    return torch.cat([mean, C, S], dim=1).reshape(B, -1)  # [B, (2K+1)*D]


class MaxPPool(nn.Module):
    """Pick, for each batch item, the document chunk most similar to the query.

    ``forward`` scores every chunk of ``document_embeddings`` (indexed as
    [batch, chunk, dim]) by dot product with the matching row of
    ``query_embedding`` ([batch, dim]) and returns the best-scoring chunk
    embedding for each batch item.
    """

    def __init__(self):
        super().__init__()

    def forward(self, document_embeddings, query_embedding):
        # [B, N, D] @ [B, D, 1] -> [B, N, 1] -> [B, N]
        scores = (document_embeddings @ query_embedding.unsqueeze(-1)).squeeze(-1)
        best = scores.argmax(dim=1)
        rows = torch.arange(document_embeddings.shape[0], device=document_embeddings.device)
        return document_embeddings[rows, best]


class SATPool(nn.Module):
    """Spectral attention pooling: linear attention followed by spectral pooling.

    ``forward`` projects H [B, L, D] to Q, K, V (each ``d_attn`` wide), computes
    the linear-attention output ``Z = phi(Q) (phi(K)^T V)`` [B, L, d_attn] and
    pools it over the sequence: either a plain mean (``saa_mean=True``) or the
    mean plus the first ``K`` cosine/sine projections over positions. The
    pooled vector is projected back to ``d_model``.

    Debug hooks: setting ``_return_attention_scores`` / ``_return_spectral_coeffs``
    to a truthy value on an instance makes ``forward`` store
    ``attention_scores`` ([B, L, L]) / ``spectral_coeffs`` ({'C', 'S'}).
    """

    def __init__(self, d_model=768, K=4, d_attn=128, normalize_pos=True, saa_mean=False, gate=None):
        super().__init__()
        self.K, self.d_a, self.norm_pos = K, d_attn, normalize_pos
        self.w_qkv = nn.Linear(d_model, 3 * d_attn, bias=False)
        # Mean-only pooling yields d_attn features; spectral pooling yields
        # (mean + K cos + K sin) * d_attn.
        self.proj_out = nn.Linear((1 if saa_mean else (2 * K + 1)) * d_attn, d_model, bias=False)
        self.eps = 1e-4  # keeps the feature map strictly positive
        self.sat = saa_mean  # True -> mean pooling only
        self.gate = gate

    @staticmethod
    def phi(x, eps):
        """Positive feature map ``relu(x) + eps`` for linear attention."""
        return F.relu(x) + eps

    def forward(self, H):
        """Pool H [B, L, D] into [B, D]."""
        B, L, _ = H.shape
        device = H.device

        # Named Kmat so it is not confused with the spectral order self.K.
        Q, Kmat, V = torch.chunk(self.w_qkv(H), 3, dim=-1)  # each [B, L, d_a]
        Q, Kmat = self.phi(Q, self.eps), self.phi(Kmat, self.eps)
        if getattr(self, '_return_attention_scores', False):
            # Explicit L x L map for visualisation: (B, L, d_a) @ (B, d_a, L).
            self.attention_scores = torch.einsum('blh,bnh->bln', Q, Kmat)

        # Linear attention Q (K^T V); note there is no row normalisation.
        KV = torch.einsum('blh,blm->bhm', Kmat, V)  # [B, d_a, d_a]
        Z = torch.einsum('blh,bhm->blm', Q, KV)     # [B, L, d_a]

        if self.sat:
            return self.proj_out(Z.mean(dim=1))     # [B, D]

        if self.norm_pos:
            pos = torch.linspace(0, 1, L, device=device)
        else:
            pos = torch.arange(L, device=device).float() / L
        omega = torch.pi * torch.arange(1, self.K + 1, device=device)
        angles = torch.einsum('k,l->kl', omega, pos)  # [K, L]
        # Build the basis in float32 and cast to Z's dtype: einsum needs
        # matching dtypes, so half/bf16/float64 inputs used to fail here.
        cos_pos = torch.cos(angles).to(Z.dtype)
        sin_pos = torch.sin(angles).to(Z.dtype)

        C = torch.einsum('kl,blm->bkm', cos_pos, Z)  # [B, K, d_a]
        S = torch.einsum('kl,blm->bkm', sin_pos, Z)  # [B, K, d_a]
        if getattr(self, '_return_spectral_coeffs', False):
            self.spectral_coeffs = {'C': C, 'S': S}
        if self.gate == "softmax":
            # NOTE(review): this *replaces* C/S with their softmax over the 2K
            # components rather than re-weighting them (e.g. C * w), which
            # discards sign and magnitude. Kept for checkpoint compatibility;
            # confirm the intent.
            gate_weights = F.softmax(torch.cat([C, S], dim=1), dim=1)  # [B, 2K, d_a]
            C, S = gate_weights[:, :self.K], gate_weights[:, self.K:]

        z = torch.cat([Z.mean(dim=1), C.reshape(B, -1), S.reshape(B, -1)], dim=-1)  # [B, (2K+1)d_a]
        return self.proj_out(z)  # [B, D]
    


class CLIPRetriever(nn.Module):
    """Retriever over a frozen CLIP backbone with trainable projection heads.

    Queries are (image, text) pairs encoded by the frozen CLIP vision and text
    towers; their ``pooler_output`` vectors are concatenated and projected by
    ``query_proj``. Documents are read from a precomputed ``passage_cache``,
    pooled according to ``args.mode`` and projected by ``doc_proj``. Both
    sides are L2-normalised.
    """

    def __init__(self, args, clip_model_name="openai/clip-vit-base-patch32", embed_dim=512, device=None):
        super().__init__()
        self.clip = CLIPModel.from_pretrained(clip_model_name)
        self.processor = CLIPProcessor.from_pretrained(clip_model_name)
        self.tokenizer = CLIPTokenizerFast.from_pretrained(clip_model_name)
        self.get_text_features = self.clip.get_text_features
        self.get_image_features = self.clip.get_image_features
        # Freeze all CLIP parameters; only the heads below are trained.
        for param in self.clip.parameters():
            param.requires_grad = False
        self.K = args.K
        self.image_encoder = self.clip.vision_model
        self.text_encoder = self.clip.text_model
        proj_dim = args.proj_dim
        self.args = args
        self.device = device

        # Width of the concatenated [image pooler_output, text pooler_output].
        # Hard-coded to 1280 to match the existing checkpoint (presumably
        # 768 vision + 512 text hidden for the default ViT-B/32 model).
        in_dim = 1280

        self.query_proj = nn.Sequential(
            nn.Linear(in_dim, proj_dim),
            nn.GELU(),
            nn.Linear(proj_dim, proj_dim),
            nn.GELU(),
            nn.Linear(proj_dim, proj_dim)
        )
        # Pooled document width: STC concatenates mean + K cos + K sin vectors.
        if args.mode in ['stc']:
            output_dim = int((2 * args.K + 1) * embed_dim)
        else:
            output_dim = embed_dim
        # Three-layer MLP for document projection.
        self.doc_proj = nn.Sequential(
            nn.Linear(output_dim, proj_dim),
            nn.GELU(),
            nn.Linear(proj_dim, proj_dim),
            nn.GELU(),
            nn.Linear(proj_dim, proj_dim)
        )
        self.norm_doc = nn.LayerNorm(output_dim)
        if args.mode == 'sat':
            self.sat = SATPool(d_model=embed_dim,
                               K=args.K, d_attn=args.d_attn,
                               normalize_pos=args.normalize_positions, saa_mean=args.saa_mean, gate=args.gate)
        elif args.mode == 'lc':
            # Must end in output_dim: encode_documents feeds the result to
            # norm_doc / doc_proj, which expect that width. (It previously
            # ended in proj_dim, which crashed unless proj_dim == embed_dim;
            # when the two are equal the parameter shapes are unchanged.)
            self.latechunk_proj = nn.Sequential(
                nn.Linear(output_dim, proj_dim),
                nn.GELU(),
                nn.Linear(proj_dim, proj_dim),
                nn.GELU(),
                nn.Linear(proj_dim, output_dim)
            )
            self.latechunk_agg = nn.Linear(embed_dim, 1)

    def get_image_layer(self, layer_idx):
        """Return transformer layer ``layer_idx`` of the vision tower."""
        return self.image_encoder.encoder.layers[layer_idx]

    def get_text_layer(self, layer_idx):
        """Return transformer layer ``layer_idx`` of the text tower."""
        return self.text_encoder.encoder.layers[layer_idx]

    def image_layer_feature(self, image, layer_idx):
        """Return the vision tower's hidden states at index ``layer_idx``."""
        outputs = self.image_encoder(pixel_values=image, output_hidden_states=True)
        return outputs.hidden_states[layer_idx]

    def text_layer_feature(self, input_ids, attention_mask, layer_idx):
        """Return the text tower's hidden states at index ``layer_idx``."""
        outputs = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)
        return outputs.hidden_states[layer_idx]

    def encode_image(self, image):
        """Mean over tokens of the vision tower's penultimate hidden states."""
        outputs = self.image_encoder(pixel_values=image, output_hidden_states=True)
        return outputs.hidden_states[-2].mean(dim=1)

    def encode_text(self, input_ids, attention_mask):
        """Mean over tokens of the text tower's penultimate hidden states."""
        outputs = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)
        return outputs.hidden_states[-2].mean(dim=1)

    def late_chunk_pooling(self, feat, spans):
        """Attention-pool mean-pooled spans of ``feat`` [T, D] into a [D] vector.

        Each ``(start, end)`` span with ``end > start`` is mean-pooled over
        rows of ``feat``; span vectors are combined with softmax weights from
        ``latechunk_agg``. Falls back to the mean over all rows when no span
        is non-empty.
        """
        chunk_embeds = [feat[start:end].mean(dim=0) for start, end in spans if end > start]
        if not chunk_embeds:
            return feat.mean(dim=0)  # fallback
        chunk_embeds = torch.stack(chunk_embeds)  # [num_spans, D]
        weights = torch.softmax(self.latechunk_agg(chunk_embeds), dim=0)  # [num_spans, 1]
        return (weights * chunk_embeds).sum(dim=0)

    def encode_queries(self, image, text):
        """Encode (image, text) query pairs into L2-normalised vectors.

        ``image`` is passed to the vision tower as ``pixel_values``; ``text`` is
        tokenized with the CLIP processor. The frozen towers run under
        ``torch.no_grad``; only ``query_proj`` receives gradients.
        """
        q_inputs = self.processor(text=text, return_tensors="pt", padding=True, truncation=True)
        questions = q_inputs["input_ids"].to(self.device)
        attention_mask = q_inputs["attention_mask"].to(self.device)
        with torch.no_grad():
            image_out = self.image_encoder(pixel_values=image, output_hidden_states=True)
            text_out = self.text_encoder(
                input_ids=questions,
                attention_mask=attention_mask,
                output_hidden_states=True,
            )
        # Every mode (including 'cls') uses pooler_output; the former
        # mode-dependent if/else had two identical branches.
        image_feat = image_out.pooler_output
        text_feat = text_out.pooler_output
        query_repr = self.query_proj(torch.cat([image_feat, text_feat], dim=-1))
        return F.normalize(query_repr, dim=-1)

    def encode_documents(self, passage_cache, keys, eval_mode=False, pid2text=None):
        """Pool and project cached document embeddings.

        Args:
            passage_cache: mapping pid -> {'embedding': tensor}. For 'maxp' the
                tensor is used as-is (chunk embeddings, [num_chunks, D]);
                otherwise a leading size-1 dim is squeezed away.
            keys: pids to encode, in output order.
            eval_mode: detach per-document vectors and move them to CPU.
            pid2text: pid -> raw text; required for mode 'lc'.

        Returns:
            For 'maxp', a list with one [num_chunks, proj_dim] tensor per pid;
            otherwise a [len(keys), proj_dim] tensor. Rows are L2-normalised.

        Raises:
            ValueError: if ``args.mode`` is not a supported pooling mode.
        """
        mode = self.args.mode
        doc_feats = []
        for pid in keys:
            if mode == 'maxp':
                feat = passage_cache[pid]['embedding'].to(self.device)  # [num_chunks, D]
            else:
                feat = passage_cache[pid]['embedding'].squeeze(0).to(self.device)  # [T, D]

            if mode == 'cls':
                vec = feat
            elif mode == 'mean':
                vec = feat.mean(dim=0)
            elif mode == 'stc':
                vec = spectral_token_compression(
                    feat.unsqueeze(0), K=self.args.K, gate=self.args.gate,
                    normalize_positions=self.args.normalize_positions
                ).squeeze(0)
            elif mode == 'sat':
                vec = self.sat(feat.unsqueeze(0)).squeeze(0)
            elif mode == 'maxp':
                # Chunk-level embeddings are normalised and projected per
                # document and returned unpooled.
                vec = F.normalize(self.doc_proj(self.norm_doc(feat)), dim=-1)  # [num_chunks, proj_dim]
                if eval_mode:
                    vec = vec.detach().cpu()
                doc_feats.append(vec)
                continue
            elif mode == 'lc':
                text = pid2text[pid]
                span_annotations = chunk_by_sentences(text, tokenizer=self.tokenizer)
                vec = self.late_chunk_pooling(feat, span_annotations)  # [D]
                vec = self.latechunk_proj(vec)
            else:
                raise ValueError(f"Unsupported pooling mode: {mode!r}")
            if eval_mode:
                vec = vec.detach().cpu()
            doc_feats.append(vec)
        if mode == 'maxp':
            return doc_feats
        doc_embeddings = torch.stack(doc_feats).to(self.device)
        doc_embeddings = self.norm_doc(doc_embeddings)
        doc_repr = self.doc_proj(doc_embeddings)
        return F.normalize(doc_repr, dim=-1)


class LongCLIPRetriever(nn.Module):
    """Retriever over a frozen Long-CLIP backbone with trainable projection heads.

    Queries are (image, text) pairs whose vision/text ``pooler_output`` vectors
    are concatenated and projected by ``query_proj``; documents are pooled from
    a precomputed ``passage_cache`` according to ``args.mode`` and projected by
    ``doc_proj``. Head widths are derived from the backbone configs. Both
    sides are L2-normalised.
    """

    def __init__(self, args, clip_model_name="zer0int/LongCLIP-GmP-ViT-L-14", device=None):
        super().__init__()
        self.clip = CLIPModel.from_pretrained(clip_model_name)
        self.processor = CLIPProcessor.from_pretrained(clip_model_name)
        self.tokenizer = CLIPTokenizerFast.from_pretrained(clip_model_name)
        self.get_text_features = self.clip.get_text_features
        self.get_image_features = self.clip.get_image_features

        # Freeze all CLIP parameters; only the heads below are trained.
        for param in self.clip.parameters():
            param.requires_grad = False

        self.K = args.K
        self.image_encoder = self.clip.vision_model
        self.text_encoder = self.clip.text_model

        proj_dim = args.proj_dim
        self.args = args
        self.device = device

        # Query input = concatenated [vision pooler_output, text pooler_output].
        vision_hidden_size = self.clip.vision_model.config.hidden_size
        text_hidden_size = self.clip.text_model.config.hidden_size
        in_dim = vision_hidden_size + text_hidden_size

        self.query_proj = nn.Sequential(
            nn.Linear(in_dim, proj_dim),
            nn.GELU(),
            nn.Linear(proj_dim, proj_dim),
            nn.GELU(),
            nn.Linear(proj_dim, proj_dim)
        )

        # Cached document embeddings are assumed to have the text tower's
        # width — TODO confirm against the code that builds passage_cache.
        model_hidden_size = self.clip.text_model.config.hidden_size
        # STC concatenates mean + K cos + K sin vectors.
        if args.mode in ['stc']:
            output_dim = int((2 * args.K + 1) * model_hidden_size)
        else:
            output_dim = model_hidden_size

        self.doc_proj = nn.Sequential(
            nn.Linear(output_dim, proj_dim),
            nn.GELU(),
            nn.Linear(proj_dim, proj_dim),
            nn.GELU(),
            nn.Linear(proj_dim, proj_dim)
        )
        self.norm_doc = nn.LayerNorm(output_dim)

        if args.mode == 'sat':
            self.sat = SATPool(d_model=model_hidden_size,
                               K=args.K, d_attn=args.d_attn,
                               normalize_pos=args.normalize_positions, saa_mean=args.saa_mean, gate=args.gate)
        elif args.mode == 'lc':
            self.latechunk_proj = nn.Sequential(
                nn.Linear(output_dim, proj_dim),
                nn.GELU(),
                nn.Linear(proj_dim, proj_dim),
                nn.GELU(),
                nn.Linear(proj_dim, output_dim)
            )
            self.latechunk_agg = nn.Linear(model_hidden_size, 1)

    def get_image_layer(self, layer_idx):
        """Return transformer layer ``layer_idx`` of the vision tower."""
        return self.image_encoder.encoder.layers[layer_idx]

    def get_text_layer(self, layer_idx):
        """Return transformer layer ``layer_idx`` of the text tower."""
        return self.text_encoder.encoder.layers[layer_idx]

    def image_layer_feature(self, image, layer_idx):
        """Return the vision tower's hidden states at index ``layer_idx``."""
        outputs = self.image_encoder(pixel_values=image, output_hidden_states=True)
        return outputs.hidden_states[layer_idx]

    def text_layer_feature(self, input_ids, attention_mask, layer_idx):
        """Return the text tower's hidden states at index ``layer_idx``."""
        outputs = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)
        return outputs.hidden_states[layer_idx]

    def encode_image(self, image):
        """Mean over tokens of the vision tower's penultimate hidden states."""
        outputs = self.image_encoder(pixel_values=image, output_hidden_states=True)
        return outputs.hidden_states[-2].mean(dim=1)

    def encode_text(self, input_ids, attention_mask):
        """Mean over tokens of the text tower's penultimate hidden states."""
        outputs = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)
        return outputs.hidden_states[-2].mean(dim=1)

    def late_chunk_pooling(self, feat, spans):
        """Attention-pool mean-pooled spans of ``feat`` [T, D] into a [D] vector.

        Spans with ``end <= start`` are skipped; if none remain, the mean over
        all rows of ``feat`` is returned instead.
        """
        chunk_embeds = [feat[start:end].mean(dim=0) for start, end in spans if end > start]
        if not chunk_embeds:
            return feat.mean(dim=0)
        chunk_embeds = torch.stack(chunk_embeds)  # [num_spans, D]
        weights = torch.softmax(self.latechunk_agg(chunk_embeds), dim=0)  # [num_spans, 1]
        return (weights * chunk_embeds).sum(dim=0)

    def encode_queries(self, image, text):
        """Encode (image, text) query pairs into L2-normalised vectors.

        The frozen towers run under ``torch.no_grad``; only ``query_proj``
        receives gradients.
        """
        q_inputs = self.processor(text=text, return_tensors="pt", padding=True, truncation=True)
        questions = q_inputs["input_ids"].to(self.device)
        attention_mask = q_inputs["attention_mask"].to(self.device)
        with torch.no_grad():
            image_out = self.image_encoder(pixel_values=image, output_hidden_states=True)
            text_out = self.text_encoder(
                input_ids=questions,
                attention_mask=attention_mask,
                output_hidden_states=True,
            )
        # Every mode (including 'cls') uses pooler_output; the former
        # mode-dependent if/else had two identical branches.
        image_feat = image_out.pooler_output
        text_feat = text_out.pooler_output
        query_repr = self.query_proj(torch.cat([image_feat, text_feat], dim=-1))
        return F.normalize(query_repr, dim=-1)

    def encode_documents(self, passage_cache, keys, eval_mode=False, pid2text=None):
        """Pool and project cached document embeddings.

        Args:
            passage_cache: mapping pid -> {'embedding': tensor}. For 'maxp' the
                tensor is used as-is ([num_chunks, D]); otherwise a leading
                size-1 dim is squeezed away.
            keys: pids to encode, in output order.
            eval_mode: detach per-document vectors and move them to CPU.
            pid2text: pid -> raw text; required for mode 'lc'.

        Returns:
            For 'maxp', a list with one [num_chunks, proj_dim] tensor per pid
            (this mode previously raised UnboundLocalError); otherwise a
            [len(keys), proj_dim] tensor. Rows are L2-normalised.

        Raises:
            ValueError: if ``args.mode`` is not a supported pooling mode.
        """
        mode = self.args.mode
        doc_feats = []
        for pid in keys:
            if mode == 'maxp':
                feat = passage_cache[pid]['embedding'].to(self.device)  # [num_chunks, D]
            else:
                feat = passage_cache[pid]['embedding'].squeeze(0).to(self.device)  # [T, D]

            if mode == 'cls':
                vec = feat
            elif mode == 'mean':
                vec = feat.mean(dim=0)
            elif mode == 'stc':
                vec = spectral_token_compression(
                    feat.unsqueeze(0), K=self.args.K, gate=self.args.gate,
                    normalize_positions=self.args.normalize_positions
                ).squeeze(0)
            elif mode == 'sat':
                vec = self.sat(feat.unsqueeze(0)).squeeze(0)
            elif mode == 'maxp':
                # Chunk-level embeddings are normalised and projected per
                # document and returned unpooled.
                vec = F.normalize(self.doc_proj(self.norm_doc(feat)), dim=-1)  # [num_chunks, proj_dim]
                if eval_mode:
                    vec = vec.detach().cpu()
                doc_feats.append(vec)
                continue
            elif mode == 'lc':
                text = pid2text[pid]
                span_annotations = chunk_by_sentences(text, tokenizer=self.tokenizer)
                vec = self.late_chunk_pooling(feat, span_annotations)  # [D]
                vec = self.latechunk_proj(vec)
            else:
                raise ValueError(f"Unsupported pooling mode: {mode!r}")

            if eval_mode:
                vec = vec.detach().cpu()
            doc_feats.append(vec)

        if mode == 'maxp':
            return doc_feats

        doc_embeddings = torch.stack(doc_feats).to(self.device)
        doc_embeddings = self.norm_doc(doc_embeddings)
        doc_repr = self.doc_proj(doc_embeddings)
        return F.normalize(doc_repr, dim=-1)


class RAVQARetriever(nn.Module):
    """Dual-BERT retriever: frozen query/document BERT encoders with trainable heads.

    Queries are encoded on the fly by ``query_encoder``; documents are pooled
    from a precomputed ``passage_cache`` according to ``args.mode``. Both
    sides are projected by MLP heads and L2-normalised.
    """

    def __init__(self, args, device, model_name='bert-base-uncased'):
        super().__init__()
        self.args = args
        self.device = device
        self.tokenizer = BertTokenizerFast.from_pretrained(model_name)
        # Load the encoders.
        self.query_encoder = BertModel.from_pretrained(model_name)
        self.doc_encoder = BertModel.from_pretrained(model_name)

        # Freeze the encoders' parameters; only the heads below are trained.
        for param in self.query_encoder.parameters():
            param.requires_grad = False
        for param in self.doc_encoder.parameters():
            param.requires_grad = False

        embed_dim = self.query_encoder.config.hidden_size
        proj_dim = args.proj_dim

        self.query_proj = nn.Sequential(
            nn.Linear(embed_dim, proj_dim),
            nn.GELU(),
            nn.Linear(proj_dim, proj_dim),
            nn.GELU(),
            nn.Linear(proj_dim, proj_dim)
        )

        # Pooled document width: STC concatenates mean + K cos + K sin vectors.
        if args.mode in ['stc']:
            output_dim = int((2 * args.K + 1) * embed_dim)
        else:
            output_dim = embed_dim
        self.doc_proj = nn.Sequential(
            nn.Linear(output_dim, proj_dim),
            nn.GELU(),
            nn.Linear(proj_dim, proj_dim),
            nn.GELU(),
            nn.Linear(proj_dim, proj_dim)
        )
        self.norm_doc = nn.LayerNorm(output_dim)
        if args.mode == 'sat':
            self.sat = SATPool(d_model=embed_dim,
                               K=args.K, d_attn=args.d_attn,
                               normalize_pos=args.normalize_positions, saa_mean=args.saa_mean, gate=args.gate)
        elif args.mode == 'maxp':
            self.maxp_pool = MaxPPool()
        elif args.mode == 'lc':
            self.latechunk_proj = nn.Sequential(
                nn.Linear(output_dim, proj_dim),
                nn.GELU(),
                nn.Linear(proj_dim, proj_dim),
                nn.GELU(),
                nn.Linear(proj_dim, output_dim)
            )
            self.latechunk_agg = nn.Linear(embed_dim, 1)

    def late_chunk_pooling(self, feat, spans):
        """Attention-pool mean-pooled spans of ``feat`` [T, D] into a [D] vector.

        Spans with ``end <= start`` are skipped; if none remain, the mean over
        all rows of ``feat`` is returned instead.
        """
        chunk_embeds = [feat[start:end].mean(dim=0) for start, end in spans if end > start]
        if not chunk_embeds:
            return feat.mean(dim=0)
        chunk_embeds = torch.stack(chunk_embeds)  # [num_spans, D]
        weights = torch.softmax(self.latechunk_agg(chunk_embeds), dim=0)  # [num_spans, 1]
        return (weights * chunk_embeds).sum(dim=0)

    def encode_queries(self, queries):
        """Encode query strings into L2-normalised vectors [B, proj_dim].

        Mode 'cls' uses BERT's ``pooler_output``; every other mode averages the
        penultimate layer's hidden states over the token axis.
        """
        tokens = self.tokenizer(queries, return_tensors='pt', padding=True, truncation=True).to(self.device)

        if self.args.mode == 'cls':
            embedding = self.query_encoder(**tokens).pooler_output  # [B, H]
        else:
            output = self.query_encoder(**tokens, output_hidden_states=True)
            # Reduce on-device; the previous version copied the whole
            # [B, T, H] tensor to CPU and back just to take this mean.
            # NOTE(review): padded positions are included in the mean; masking
            # with tokens['attention_mask'] would change the embeddings that
            # query_proj was trained on.
            embedding = output.hidden_states[-2].mean(dim=1)  # [B, H]
        projected_embedding = self.query_proj(embedding)
        return F.normalize(projected_embedding, dim=-1)

    def encode_documents(self, passage_cache, keys, eval_mode=False, pid2text=None):
        """Pool and project cached document embeddings.

        Args:
            passage_cache: mapping pid -> {'embedding': tensor}. For 'maxp' the
                tensor is used as-is ([num_chunks, D]); otherwise a leading
                size-1 dim is squeezed away.
            keys: pids to encode, in output order.
            eval_mode: detach per-document vectors and move them to CPU.
            pid2text: pid -> raw text; required for mode 'lc'.

        Returns:
            For 'maxp', a list with one [num_chunks, proj_dim] tensor per pid;
            otherwise a [len(keys), proj_dim] tensor. Rows are L2-normalised.

        Raises:
            ValueError: if ``args.mode`` is not a supported pooling mode.
        """
        mode = self.args.mode
        doc_feats = []
        for pid in keys:
            if mode == 'maxp':
                feat = passage_cache[pid]['embedding'].to(self.device)  # [num_chunks, D]
            else:
                feat = passage_cache[pid]['embedding'].squeeze(0).to(self.device)  # [T, D]

            if mode == 'cls':
                vec = feat
            elif mode == 'mean':
                vec = feat.mean(dim=0)
            elif mode == 'stc':
                vec = spectral_token_compression(
                    feat.unsqueeze(0), K=self.args.K, gate=self.args.gate,
                    normalize_positions=self.args.normalize_positions
                ).squeeze(0)
            elif mode == 'sat':
                vec = self.sat(feat.unsqueeze(0)).squeeze(0)
            elif mode == 'maxp':
                # Chunk-level embeddings are normalised and projected per
                # document and returned unpooled.
                vec = F.normalize(self.doc_proj(self.norm_doc(feat)), dim=-1)  # [num_chunks, proj_dim]
                if eval_mode:
                    vec = vec.detach().cpu()
                doc_feats.append(vec)
                continue
            elif mode == 'lc':
                text = pid2text[pid]
                span_annotations = chunk_by_sentences(text, tokenizer=self.tokenizer)
                vec = self.late_chunk_pooling(feat, span_annotations)  # [D]
                vec = self.latechunk_proj(vec)
            else:
                raise ValueError(f"Unsupported pooling mode: {mode!r}")
            if eval_mode:
                vec = vec.detach().cpu()
            doc_feats.append(vec)
        if mode == 'maxp':
            return doc_feats
        doc_embeddings = torch.stack(doc_feats).to(self.device)
        doc_embeddings = self.norm_doc(doc_embeddings)
        doc_repr = self.doc_proj(doc_embeddings)
        return F.normalize(doc_repr, dim=-1)
    

class Generator(nn.Module):
    """T5 seq2seq wrapper exposing a training loss (``forward``) and decoding (``generate``)."""

    def __init__(self, model_name='t5-large'):
        super().__init__()
        self.model = T5ForConditionalGeneration.from_pretrained(model_name)
        self.tokenizer = T5Tokenizer.from_pretrained(model_name)

    def forward(self, inputs, targets, device):
        """Return the model loss for producing ``targets`` from ``inputs``.

        Both sides are tokenized with padding and truncation to at most 512
        tokens. Padded target positions are replaced with -100 before being
        passed as ``labels`` (the Hugging Face convention for positions the
        loss should ignore). The model is moved to ``device`` on every call.
        """
        tokenizer_kwargs = dict(return_tensors='pt', padding=True, truncation=True, max_length=512)
        encoder_inputs = self.tokenizer(inputs, **tokenizer_kwargs).to(device)
        labels = self.tokenizer(targets, **tokenizer_kwargs).input_ids.to(device)

        labels[labels == self.tokenizer.pad_token_id] = -100
        self.model.to(device)

        return self.model(**encoder_inputs, labels=labels).loss

    def generate(self, inputs, device):
        """Decode ``inputs`` with ``max_length=32`` and return the decoded strings."""
        self.model.to(device)
        encoded = self.tokenizer(inputs, return_tensors='pt', padding=True, truncation=True).to(device)
        output_ids = self.model.generate(**encoded, max_length=32)
        return self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)

class TextClassifier(nn.Module):
    """Classification head over precomputed text embeddings.

    An optional LayerNorm is followed by either a three-layer ReLU MLP
    (``classifier='mlp'``, hidden width ``proj_dim``) or a single linear layer.
    When ``args.mode == 'lc'`` an extra ``latechunk_agg`` scorer is created for
    ``late_chunk_pooling``.
    """

    def __init__(self, args, hidden_size, num_labels, use_norm=False, classifier=None, proj_dim=256):
        super().__init__()
        self.use_norm = use_norm
        if use_norm:
            self.norm = nn.LayerNorm(hidden_size)

        self.run_mlp = classifier == 'mlp'
        if self.run_mlp:
            self.classifier = nn.Sequential(
                nn.Linear(hidden_size, proj_dim),
                nn.ReLU(),
                nn.Linear(proj_dim, proj_dim),
                nn.ReLU(),
                nn.Linear(proj_dim, num_labels),
            )
        else:
            self.classifier = nn.Linear(hidden_size, num_labels)

        if args.mode == 'lc':
            self.latechunk_agg = nn.Linear(hidden_size, 1)

    def late_chunk_pooling(self, feat, spans):
        """Attention-pool mean-pooled spans of ``feat`` [T, D] into a [D] vector.

        Spans with ``end <= start`` are skipped; if none remain, the mean over
        all rows of ``feat`` is returned instead.
        """
        span_means = [feat[lo:hi].mean(dim=0) for lo, hi in spans if hi > lo]
        if not span_means:
            return feat.mean(dim=0)
        stacked = torch.stack(span_means)  # [num_spans, D]
        attn = torch.softmax(self.latechunk_agg(stacked), dim=0)  # [num_spans, 1]
        return (attn * stacked).sum(dim=0)

    def forward(self, embeddings):
        """Return classifier logits for ``embeddings``."""
        x = self.norm(embeddings) if self.use_norm else embeddings
        return self.classifier(x)

class FactualityChecker(nn.Module):
    """Frozen BERT encoder with separate trainable heads for documents and summaries.

    Document and summary embeddings are read from precomputed caches, pooled
    according to ``args.mode`` and projected by ``doc_proj`` / ``summ_proj``
    respectively (no L2 normalisation is applied).
    """

    def __init__(self, args, device, model_name='bert-base-uncased'):
        super().__init__()
        self.args = args
        self.device = device
        self.tokenizer = BertTokenizerFast.from_pretrained(model_name)
        self.text_encoder = BertModel.from_pretrained(model_name)

        # Freeze the encoder's parameters; only the heads below are trained.
        for param in self.text_encoder.parameters():
            param.requires_grad = False

        embed_dim = self.text_encoder.config.hidden_size
        proj_dim = args.proj_dim

        # Pooled width fed to the norms and projections: STC concatenates
        # mean + K cos + K sin vectors.
        if args.mode in ['stc']:
            output_dim = int((2 * args.K + 1) * embed_dim)
        else:
            output_dim = embed_dim

        # Input width is output_dim. It used to be embed_dim, which crashed in
        # 'stc' mode; for every other mode the two are equal, so parameter
        # shapes are unchanged.
        self.doc_proj = nn.Sequential(
            nn.Linear(output_dim, proj_dim),
            nn.GELU(),
            nn.Linear(proj_dim, proj_dim),
            nn.GELU(),
            nn.Linear(proj_dim, proj_dim)
        )
        self.summ_proj = nn.Sequential(
            nn.Linear(output_dim, proj_dim),
            nn.GELU(),
            nn.Linear(proj_dim, proj_dim),
            nn.GELU(),
            nn.Linear(proj_dim, proj_dim)
        )

        self.doc_norm = nn.LayerNorm(output_dim)
        self.summ_norm = nn.LayerNorm(output_dim)
        self.cls = nn.Linear(1, 1)
        if args.mode == 'sat':
            # One SATPool, shared by documents and summaries. It was previously
            # constructed twice, with the first instance immediately discarded.
            self.sat = SATPool(d_model=embed_dim,
                               K=args.K, d_attn=args.d_attn,
                               normalize_pos=args.normalize_positions, saa_mean=args.saa_mean, gate=args.gate)
        elif args.mode == 'lc':
            self.latechunk_proj = nn.Sequential(
                nn.Linear(output_dim, proj_dim),
                nn.GELU(),
                nn.Linear(proj_dim, proj_dim),
                nn.GELU(),
                nn.Linear(proj_dim, output_dim)
            )
            self.latechunk_agg = nn.Linear(embed_dim, 1)

    def late_chunk_pooling(self, feat, spans):
        """Attention-pool mean-pooled spans of ``feat`` [T, D] into a [D] vector.

        Spans with ``end <= start`` are skipped; if none remain, the mean over
        all rows of ``feat`` is returned instead.
        """
        chunk_embeds = [feat[start:end].mean(dim=0) for start, end in spans if end > start]
        if not chunk_embeds:
            return feat.mean(dim=0)  # fallback
        chunk_embeds = torch.stack(chunk_embeds)  # [num_spans, D]
        weights = torch.softmax(self.latechunk_agg(chunk_embeds), dim=0)  # [num_spans, 1]
        return (weights * chunk_embeds).sum(dim=0)

    def _encode(self, passage_cache, keys, eval_mode, pid2text, norm, proj):
        """Shared pool -> norm -> project path used for documents and summaries."""
        mode = self.args.mode
        feats = []
        for pid in keys:
            if mode == 'maxp':
                feat = passage_cache[pid]['embedding'].to(self.device)  # [num_chunks, D]
            else:
                feat = passage_cache[pid]['embedding'].squeeze(0).to(self.device)  # [T, D]

            if mode == 'cls':
                vec = feat
            elif mode == 'mean':
                vec = feat.mean(dim=0)
            elif mode == 'stc':
                vec = spectral_token_compression(
                    feat.unsqueeze(0), K=self.args.K, gate=self.args.gate,
                    normalize_positions=self.args.normalize_positions
                ).squeeze(0)
            elif mode == 'sat':
                vec = self.sat(feat.unsqueeze(0)).squeeze(0)
            elif mode == 'maxp':
                # Chunk-level embeddings are normalised and projected per
                # document and returned unpooled (this mode previously raised
                # UnboundLocalError).
                vec = proj(norm(feat))  # [num_chunks, proj_dim]
                if eval_mode:
                    vec = vec.detach().cpu()
                feats.append(vec)
                continue
            elif mode == 'lc':
                text = pid2text[pid]
                span_annotations = chunk_by_sentences(text, tokenizer=self.tokenizer)
                vec = self.late_chunk_pooling(feat, span_annotations)  # [D]
                vec = self.latechunk_proj(vec)
            else:
                raise ValueError(f"Unsupported pooling mode: {mode!r}")
            if eval_mode:
                vec = vec.detach().cpu()
            feats.append(vec)
        if mode == 'maxp':
            return feats
        embeddings = torch.stack(feats).to(self.device)
        return proj(norm(embeddings))

    def encode_documents(self, passage_cache, keys, eval_mode=False, pid2text=None):
        """Pool and project cached document embeddings with ``doc_norm`` / ``doc_proj``.

        Returns a list of per-pid [num_chunks, proj_dim] tensors for 'maxp',
        otherwise a [len(keys), proj_dim] tensor. Raises ValueError for an
        unsupported ``args.mode``.
        """
        return self._encode(passage_cache, keys, eval_mode, pid2text, self.doc_norm, self.doc_proj)

    def encode_summary(self, passage_cache, keys, eval_mode=False, pid2text=None):
        """Pool and project cached summary embeddings with ``summ_norm`` / ``summ_proj``.

        Returns a list of per-pid [num_chunks, proj_dim] tensors for 'maxp',
        otherwise a [len(keys), proj_dim] tensor. Raises ValueError for an
        unsupported ``args.mode``.
        """
        return self._encode(passage_cache, keys, eval_mode, pid2text, self.summ_norm, self.summ_proj)
    



            
class GTRRetriver(nn.Module):
    """Trainable projection heads on top of a frozen GTR SentenceTransformer.

    Queries are embedded by the frozen encoder and projected by ``query_proj``.
    Documents are pooled from cached encoder features according to
    ``args.mode`` (``'cls'``, ``'mean'``, ``'stc'``, ``'sat'``, ``'maxp'``,
    ``'lc'``) and projected by ``doc_proj``. Both sides are L2-normalized.
    """

    def __init__(self, args, device, model_name='sentence-transformers/gtr-t5-base'):
        """Load the frozen encoder/tokenizer and build the mode-specific heads.

        Args:
            args: config namespace; reads ``mode`` and ``proj_dim``, plus ``K``,
                ``d_attn``, ``normalize_positions``, ``saa_mean`` and ``gate``
                depending on the mode.
            device: device that cached document features are moved to.
            model_name: SentenceTransformer checkpoint. The tokenizer is loaded
                from the same checkpoint so late-chunk spans index the same
                token sequence the encoder produced.
        """
        from sentence_transformers import SentenceTransformer
        from transformers import AutoTokenizer
        super().__init__()
        self.args = args
        self.device = device

        self.doc_encoder = SentenceTransformer(model_name)
        # Must match the encoder checkpoint; a different tokenizer would make
        # the sentence spans used by late_chunk_pooling point at wrong tokens.
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

        # The encoder is frozen; only the projection/pooling heads are trained.
        for param in self.doc_encoder.parameters():
            param.requires_grad = False

        embed_dim = self.doc_encoder.get_sentence_embedding_dimension()
        proj_dim = args.proj_dim

        self.query_proj = nn.Sequential(
            nn.Linear(embed_dim, proj_dim),
            nn.GELU(),
            nn.Linear(proj_dim, proj_dim),
            nn.GELU(),
            nn.Linear(proj_dim, proj_dim)
        )

        # 'stc' concatenates the token mean with K cosine and K sine coefficients.
        if args.mode in ['stc']:
            output_dim = int((2 * args.K + 1) * embed_dim)
        else:
            output_dim = embed_dim
        self.doc_proj = nn.Sequential(
            nn.Linear(output_dim, proj_dim),
            nn.GELU(),
            nn.Linear(proj_dim, proj_dim),
            nn.GELU(),
            nn.Linear(proj_dim, proj_dim)
        )
        self.norm_doc = nn.LayerNorm(output_dim)
        if args.mode == 'sat':
            self.sat = SATPool(d_model=embed_dim,
                               K=args.K, d_attn=args.d_attn,
                               normalize_pos=args.normalize_positions, saa_mean=args.saa_mean, gate=args.gate)
        elif args.mode == 'lc':
            self.latechunk_proj = nn.Sequential(
                nn.Linear(output_dim, proj_dim),
                nn.GELU(),
                nn.Linear(proj_dim, proj_dim),
                nn.GELU(),
                nn.Linear(proj_dim, output_dim)
            )
            # Scores each sentence chunk for the softmax aggregation in late_chunk_pooling.
            self.latechunk_agg = nn.Linear(embed_dim, 1)

    def late_chunk_pooling(self, feat, spans):
        """Mean-pool token features over sentence spans, then softmax-aggregate.

        Spans are clipped to the sequence length. Spans computed on text longer
        than the (possibly truncated) token sequence would otherwise slice an
        empty range whose mean is NaN.

        Args:
            feat: token features ``[T, D]``.
            spans: iterable of ``(start, end)`` token offsets, ``end`` exclusive.

        Returns:
            ``[D]`` tensor. When no span is non-empty, falls back to the mean
            over all tokens, or to zeros if ``feat`` is itself empty.
        """
        seq_len = feat.shape[0]
        chunk_embeds = []
        for start, end in spans:
            end = min(end, seq_len)
            if end > start:
                chunk_embeds.append(feat[start:end].mean(dim=0))

        if not chunk_embeds:
            if feat.numel() == 0:
                return torch.zeros(self.latechunk_agg.in_features, device=feat.device)
            return feat.mean(dim=0)

        chunk_embeds = torch.stack(chunk_embeds)  # [num_chunks, D]
        weights = torch.softmax(self.latechunk_agg(chunk_embeds), dim=0)  # [num_chunks, 1]
        return (weights * chunk_embeds).sum(dim=0)

    def encode_queries(self, queries):
        """Embed ``queries`` with the frozen encoder and return normalized projections."""
        with torch.no_grad():
            embedding = self.doc_encoder.encode(queries, convert_to_tensor=True)
        # The encoder output was computed under no_grad; clone it into a fresh
        # tensor so query_proj can be trained on top of it.
        embedding = embedding.clone().detach().requires_grad_(True)

        projected_embedding = self.query_proj(embedding)
        return F.normalize(projected_embedding, dim=-1)

    def encode_documents(self, passage_cache, keys, eval_mode=False, pid2text=None):
        """Pool cached encoder features per document and project them.

        Args:
            passage_cache: mapping pid -> dict holding an ``'embedding'`` tensor
                and, optionally, precomputed ``'spans'`` (``'lc'`` mode).
            keys: pids to encode, in output order.
            eval_mode: if True, detach each pooled vector and move it to CPU.
            pid2text: pid -> raw text; needed in ``'lc'`` mode when the cache
                entry has no ``'spans'``.

        Returns:
            ``'maxp'``: list (one per key) of L2-normalized per-chunk
            projections. Otherwise: L2-normalized ``[len(keys), proj_dim]``
            projections on ``self.device``.

        Raises:
            ValueError: if ``args.mode`` is not a supported pooling mode.
        """
        mode = self.args.mode
        doc_feats = []
        for pid in keys:
            entry = passage_cache[pid]
            if mode == 'maxp':
                feat = entry['embedding'].to(self.device)  # [num_chunks, D]
            else:
                feat = entry['embedding'].squeeze(0).to(self.device)  # [T, D]

            if mode == 'cls':
                vec = feat
            elif mode == 'mean':
                vec = feat.mean(dim=0)
            elif mode == 'stc':
                vec = spectral_token_compression(
                    feat.unsqueeze(0), K=self.args.K, gate=self.args.gate,
                    normalize_positions=self.args.normalize_positions
                ).squeeze(0)
            elif mode == 'sat':
                vec = self.sat(feat.unsqueeze(0)).squeeze(0)
            elif mode == 'maxp':
                # Chunks are projected individually and kept separate
                # (presumably for max-over-chunks scoring by the caller).
                vec = self.norm_doc(feat)  # [num_chunks, D]
                vec = self.doc_proj(vec)   # [num_chunks, proj_dim]
                vec = F.normalize(vec, dim=-1)
            elif mode == 'lc':
                if 'spans' in entry:
                    span_annotations = entry['spans']
                else:
                    span_annotations = chunk_by_sentences(pid2text[pid], tokenizer=self.tokenizer)
                vec = self.late_chunk_pooling(feat, span_annotations)  # [D]
                vec = self.latechunk_proj(vec)
            else:
                raise ValueError(f"encode_documents does not support mode {mode!r}")

            if eval_mode:
                vec = vec.detach().cpu()

            doc_feats.append(vec)
        if mode == 'maxp':
            return doc_feats
        doc_embeddings = torch.stack(doc_feats).to(self.device)
        doc_embeddings = self.norm_doc(doc_embeddings)
        doc_repr = self.doc_proj(doc_embeddings)
        doc_repr = F.normalize(doc_repr, dim=-1)
        return doc_repr
    

from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
import torch

class PassageSummarizer:
    """Batch abstractive summarizer wrapping a Hugging Face summarization pipeline."""

    def __init__(self, model_name="facebook/bart-large-cnn", device=None,
                 max_new_tokens=80, batch_size=8):
        """Load ``model_name`` into a ``"summarization"`` pipeline.

        Args:
            model_name: seq2seq checkpoint used for both model and tokenizer.
            device: pipeline device. An int index, a device string such as
                ``"cpu"``/``"cuda:1"``, or a ``torch.device`` is passed through.
                ``None`` selects GPU 0 when CUDA is available, else CPU (-1).
            max_new_tokens: summary length cap (forwarded as ``max_length``).
            batch_size: number of passages sent to the pipeline per call.
        """
        self.device = self._resolve_device(device)
        self.max_new_tokens = max_new_tokens
        self.batch_size = batch_size

        self.pipe = pipeline(
            task="summarization",
            model=AutoModelForSeq2SeqLM.from_pretrained(model_name),
            tokenizer=AutoTokenizer.from_pretrained(model_name),
            device=self.device
        )

    @staticmethod
    def _resolve_device(device):
        """Return ``device`` unchanged if given, else 0 (CUDA available) or -1 (CPU)."""
        if device is not None:
            return device
        return 0 if torch.cuda.is_available() else -1

    def summarize_passages(self, passages):
        """Summarize ``passages`` in batches of ``self.batch_size``.

        Returns:
            list of summary strings taken from each pipeline output's
            ``"summary_text"``, in input order. Empty or ``None`` input
            yields an empty list.
        """
        if not passages:
            return []
        summaries = []
        for i in range(0, len(passages), self.batch_size):
            batch = passages[i:i + self.batch_size]

            out = self.pipe(
                batch,
                max_length=self.max_new_tokens,
                # NOTE(review): min_new_tokens and min_length overlap, and
                # min_length=20 can exceed max_length when max_new_tokens < 20.
                # Which one wins depends on the transformers version; consider
                # keeping only one.
                min_new_tokens=min(20, self.max_new_tokens // 2),
                min_length=20,
                truncation=True,
                no_repeat_ngram_size=3
            )

            summaries.extend(o["summary_text"] for o in out)

        return summaries