
import faiss
import numpy as np
import random
import re
import torch
import torch.nn.functional as F
def build_faiss_index(embeddings):
    """Build a flat inner-product FAISS index over a batch of embeddings.

    Args:
        embeddings: ``torch.Tensor`` whose rows are the vectors to index; its
            second dimension is used as the index dimensionality. It may live
            on any device and may be attached to an autograd graph.

    Returns:
        A ``faiss.IndexFlatIP`` with every row of ``embeddings`` added.
    """
    # detach() first: Tensor.numpy() raises on tensors that require grad.
    vectors = embeddings.detach().cpu().numpy()
    # Hand FAISS C-contiguous float32 rows; a transposed or sliced tensor
    # would otherwise keep its strided layout through the dtype cast.
    vectors = np.ascontiguousarray(vectors, dtype="float32")
    index = faiss.IndexFlatIP(vectors.shape[1])
    index.add(vectors)
    return index

def tokenize_with_offsets(text, tokenizer):
    """Encode ``text`` with ``tokenizer``, requesting offsets and truncation.

    Args:
        text: Input passed to the tokenizer as its first positional argument.
        tokenizer: Callable accepting the keyword options set below.

    Returns:
        Whatever ``tokenizer`` returns for this call.
    """
    encode_options = {
        "return_tensors": "pt",
        "return_offsets_mapping": True,
        "truncation": True,
    }
    return tokenizer(text, **encode_options)

def set_seed(seed=42):
    """Seed the Python, NumPy and PyTorch RNGs and pin the cuDNN flags.

    Args:
        seed: Value handed to every seeding function.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_rng in seeders:
        seed_rng(seed)

    # Ask cuDNN for deterministic behaviour and turn off benchmark mode.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
def chunk_by_sentences(input_text: str, tokenizer: callable):
    """Split ``input_text`` into sentence-level spans of token indices.

    A '.' token counts as a sentence boundary when it is the last token, when
    the next token's start offset lies beyond this token's end offset (a gap
    in the text), or when the next token is '[SEP]'.

    Args:
        input_text: Text to segment.
        tokenizer: Callable tokenizer that also provides
            ``convert_tokens_to_ids`` and returns ``input_ids`` and
            ``offset_mapping`` when called.

    Returns:
        A list of ``(start, end)`` token-index pairs. The first span starts at
        index 1 (presumably skipping a leading special token) and every later
        span starts at the previous boundary token; each span ends at its own
        boundary token. Tokens after the final boundary belong to no span.
        ``[(0, 1)]`` is returned for blank text or an empty encoding, and
        ``[(0, num_tokens)]`` when no boundary or no valid span exists.
    """
    if not input_text or not input_text.strip():
        return [(0, 1)]

    encoding = tokenizer(input_text, return_tensors='pt', return_offsets_mapping=True)
    period_id = tokenizer.convert_tokens_to_ids('.')
    sep_token_id = tokenizer.convert_tokens_to_ids('[SEP]')
    offsets = encoding['offset_mapping'][0]
    ids = encoding['input_ids'][0]

    num_tokens = len(ids)
    if num_tokens == 0:
        return [(0, 1)]
    whole_sequence = [(0, num_tokens)]

    # A following token can only be inspected while it exists in both sequences.
    lookahead_limit = min(len(offsets), num_tokens)

    boundaries = []
    for idx, (tok, (_tok_start, tok_end)) in enumerate(zip(ids, offsets)):
        if tok == period_id:
            nxt = idx + 1
            ends_sentence = (
                nxt >= lookahead_limit
                or offsets[nxt][0] - tok_end > 0
                or ids[nxt] == sep_token_id
            )
            if ends_sentence:
                boundaries.append(idx)

    if not boundaries:
        return whole_sequence

    # Pair every boundary with the one before it; the first span opens at 1.
    span_starts = [1, *boundaries[:-1]]
    spans = [
        (begin, stop)
        for begin, stop in zip(span_starts, boundaries)
        if begin < stop <= num_tokens
    ]
    return spans or whole_sequence



def supcon_loss(query, docs, pos_mask, temp=0.2):
    """Supervised contrastive loss of queries against candidate documents.

    Every query row is scored against every document row (dot product divided
    by ``temp``). A query's loss is the mean negative log-softmax probability
    of its positive documents.

    Args:
        query: Tensor of query vectors, one per row.
        docs: Tensor of document vectors, one per row, same width as ``query``.
        pos_mask: Per-query rows marking positives; row ``i`` must have one
            entry per document and be non-zero where that document is a
            positive for query ``i``.
        temp: Softmax temperature.

    Returns:
        Scalar tensor: the sum of per-query losses divided by the total number
        of queries. Queries without positives, or with a NaN loss, add nothing
        to the sum but still count in the divisor. When every query is
        skipped, a zero leaf tensor with ``requires_grad=True`` is returned so
        a caller's ``backward()`` still works.
    """
    logits = query @ docs.T / temp

    # Clamp logits to prevent overflow/underflow.
    logits = torch.clamp(logits, min=-50.0, max=50.0)

    # Row-wise log-softmax, computed once for all queries.
    log_probs = logits - torch.logsumexp(logits, dim=-1, keepdim=True)

    total = 0
    num_valid = 0
    for log_prob, row_mask in zip(log_probs, pos_mask):
        pos = row_mask.float()
        pos_count = pos.sum()

        # No positives for this query: nothing to pull towards.
        if pos_count < 1e-6:
            continue

        sample_loss = -torch.sum(pos * log_prob) / pos_count
        if torch.isnan(sample_loss):
            continue

        total = total + sample_loss
        num_valid += 1

    # Count contributing samples explicitly. Testing the accumulated value
    # against 0 would mistake a genuinely zero loss for "all skipped" and
    # return a tensor detached from the autograd graph.
    if num_valid == 0:
        return torch.tensor(
            0.0, dtype=logits.dtype, device=query.device, requires_grad=True
        )

    return total / query.size(0)

class PassageVectoriser:
    """Turn cached per-passage token embeddings into unit-length vectors.

    Each passage goes through ``saa`` -> ``norm_doc_fn`` -> ``doc_proj`` and is
    then L2-normalised. Vectors computed with ``train=False`` are memoised on
    the CPU; training calls always recompute with gradients enabled.
    """

    def __init__(self, passage_cache, saa, doc_proj, device, norm_doc_fn):
        """Store the passage store, the three transforms and the target device.

        Args:
            passage_cache: Mapping pid -> token tensor ``[L_i, D]``.
            saa: Callable applied to the batched tokens ``[1, L_i, D]``.
            doc_proj: Callable applied after ``norm_doc_fn``.
            device: Device the tokens are moved to before encoding.
            norm_doc_fn: Callable applied between ``saa`` and ``doc_proj``.
        """
        self.cache = passage_cache         # pid → [L_i, D]
        self.saa = saa
        self.doc_proj = doc_proj
        self.device = device
        # get() reads self.norm_doc_fn, so it has to be stored here.
        self.norm_doc_fn = norm_doc_fn
        self.vec_cache = {}                # pid → CPU vector (inference only)

    def get(self, pid, train=False):
        """Return the normalised vector for passage ``pid`` on ``self.device``.

        Args:
            pid: Key into the passage cache.
            train: When True, compute with gradients enabled and bypass the
                vector cache (neither reading nor writing it).

        Returns:
            The passage vector after projection and L2 normalisation.

        Raises:
            KeyError: If ``pid`` is missing from a dict-like passage cache.
        """
        if not train and pid in self.vec_cache:
            # Cached copies live on the CPU; hand back a tensor on the same
            # device a cache miss would produce.
            return self.vec_cache[pid].to(self.device)

        tokens = self.cache[pid].to(self.device)        # [L, D]
        ctx = torch.enable_grad() if train else torch.no_grad()
        with ctx:
            v = self.saa(tokens.unsqueeze(0)).squeeze(0)   # [D]
            v = self.norm_doc_fn(v)
            v = self.doc_proj(v)                           # [D]
            v = F.normalize(v, dim=-1)

        if not train:
            self.vec_cache[pid] = v.cpu()
        return v


def normalize_answer(s):
    """Lowercase, remove punctuation/articles/extra whitespace.

    Steps run in this order: lowercase, strip every character found in
    ``string.punctuation``, replace the standalone words "a", "an" and "the"
    with a space, then collapse whitespace runs and trim the ends.

    Args:
        s: Answer string to normalise.

    Returns:
        The normalised string (possibly empty).
    """
    import string

    # One translate pass removes all punctuation without rebuilding a set of
    # punctuation characters for every character of the input.
    text = s.lower().translate(str.maketrans('', '', string.punctuation))
    text = re.sub(r'\b(a|an|the)\b', ' ', text)
    return ' '.join(text.split())

def compute_vqa_score(pred, answers):
    """Score a prediction with the VQA consensus rule.

    The prediction and every reference answer are normalised, exact matches
    are counted, and the score is ``min(matches / 3, 1.0)``.

    Args:
        pred: Predicted answer string.
        answers: Iterable of reference answer strings.

    Returns:
        A float between 0.0 and 1.0.
    """
    target = normalize_answer(pred)
    hits = 0
    for answer in answers:
        if normalize_answer(answer) == target:
            hits += 1
    return min(hits / 3.0, 1.0)

def is_supported_by_passages(pred, passages):
    """
    Returns True if the normalized prediction appears in any of the retrieved passages.

    The prediction is matched as a whole word or phrase against each
    normalized passage. A prediction that normalizes to the empty string
    (e.g. "" or "the") is never considered supported.
    """
    norm_pred = normalize_answer(pred)
    # An empty prediction would give the pattern r'\b\b', which matches any
    # passage containing a word, so reject it up front.
    if not norm_pred:
        return False
    pattern = re.compile(r'\b' + re.escape(norm_pred) + r'\b')
    return any(pattern.search(normalize_answer(p)) for p in passages)

def check_answer_in_passages(answers, passages):
    """
    Report whether at least one gold answer occurs in the retrieved passages.

    Answers and passages are normalized first; each non-empty answer is then
    searched for as a whole word or phrase. Returns False when either input
    is empty or no answer is found.
    """
    if not (answers and passages):
        return False

    cleaned_answers = [normalize_answer(a) for a in answers]
    cleaned_passages = [normalize_answer(p) for p in passages]

    for candidate in cleaned_answers:
        # Answers that normalize to nothing cannot be searched for.
        if not candidate.strip():
            continue
        needle = r'\b' + re.escape(candidate) + r'\b'
        for text in cleaned_passages:
            if re.search(needle, text):
                return True
    return False
