
from PIL import Image
import random
from torch.utils.data import Dataset
import torch.nn.functional as F
import torch
import os
from utils import build_faiss_index, normalize_answer
import json

class OKVQADataset(Dataset):
    """OK-VQA retrieval examples yielding (image, question, positive ids, image key, answers).

    Each entry of ``data`` is a dict read with the keys ``img_path``,
    ``question``, ``pos_item_ids``, ``img_key_full`` and (optionally)
    ``answers``.
    """

    def __init__(self, args, data, processor, image_root):
        """
        Args:
            args: config object; only ``args.model`` is read.
            data: indexable sequence of example dicts (see class docstring).
            processor: callable invoked as
                ``processor(images=..., return_tensors="pt")`` whose result
                is indexed with ``"pixel_values"``.
            image_root: directory that ``item["img_path"]`` is joined onto.
        """
        self.data = data
        self.processor = processor
        self.image_root = image_root
        self.model = args.model

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Load, preprocess and return example ``idx``.

        Returns:
            ``(pixel_values, question, pos_ids, img_key, answers)`` where
            ``pixel_values`` is the processor output with dimension 0
            squeezed, ``pos_ids`` is ``item["pos_item_ids"]`` converted to
            ``int`` and ``answers`` defaults to ``[]``.
        """
        item = self.data[idx]
        image_path = os.path.join(self.image_root, item["img_path"])
        # Close the file deterministically: per Pillow's file-handling docs,
        # the handle can stay open after loading for multi-frame formats.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")
        image = self.processor(images=image, return_tensors="pt")["pixel_values"].squeeze(0)
        pos_ids = [int(pid) for pid in item["pos_item_ids"]]
        img_key = item["img_key_full"]
        answers = item.get("answers", [])
        return image, item["question"], pos_ids, img_key, answers

    def get_img_path_from_idx(self, idx):
        """Return the stored (root-relative) image path of example ``idx``."""
        return self.data[idx]["img_path"]

class EVQADataset(Dataset):
    """E-VQA retrieval examples yielding (image, question, positive ids, image key, answers).

    Each entry of ``data`` is a dict read with the keys ``img_path``,
    ``question``, ``pos_item_ids``, ``image_id`` and (optionally)
    ``answers``.
    """

    def __init__(self, args, data, processor, image_root):
        """
        Args:
            args: config object; only ``args.model`` is read.
            data: indexable sequence of example dicts (see class docstring).
            processor: callable invoked as
                ``processor(images=..., return_tensors="pt")`` whose result
                is indexed with ``"pixel_values"``.
            image_root: directory that ``item["img_path"]`` is joined onto.
        """
        self.data = data
        self.processor = processor
        self.image_root = image_root
        self.model = args.model

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Load, preprocess and return example ``idx``.

        Returns:
            ``(pixel_values, question, pos_ids, img_key, answers)`` where
            ``pixel_values`` is the processor output with dimension 0
            squeezed, ``pos_ids`` is ``[item["pos_item_ids"]]``,
            ``img_key`` is ``item["image_id"]`` and ``answers`` defaults
            to ``[]``.
        """
        item = self.data[idx]
        image_path = os.path.join(self.image_root, item["img_path"])
        # Close the file deterministically: per Pillow's file-handling docs,
        # the handle can stay open after loading for multi-frame formats.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")
        image = self.processor(images=image, return_tensors="pt")["pixel_values"].squeeze(0)
        # NOTE(review): unlike OKVQADataset, pos_item_ids is wrapped as-is
        # (no int conversion); presumably a single id per example — confirm.
        pos_ids = [item["pos_item_ids"]]
        img_key = item["image_id"]
        answers = item.get("answers", [])
        return image, item["question"], pos_ids, img_key, answers

    def get_img_path_from_idx(self, idx):
        """Return the stored (root-relative) image path of example ``idx``."""
        return self.data[idx]["img_path"]
class RAGDataset(Dataset):
    """Text queries paired with the wikipedia id of their first provenance.

    At construction, entries lacking
    ``entry['output'][0]['provenance'][0]['wikipedia_id']`` are dropped
    (only ``KeyError``/``IndexError``/``ValueError`` are treated as
    "missing"; any other error propagates).
    """

    def __init__(self, args, data):
        self.model = args.model
        self.data = [entry for entry in data if self._has_wikipedia_id(entry)]

    @staticmethod
    def _has_wikipedia_id(entry):
        """Return True when the first provenance of the first output has a wikipedia_id."""
        try:
            entry['output'][0]['provenance'][0]['wikipedia_id']
        except (KeyError, IndexError, ValueError):
            return False
        return True

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(query_text, [positive_wikipedia_id])`` for entry ``idx``."""
        entry = self.data[idx]
        first_provenance = entry['output'][0]['provenance'][0]
        return entry["input"], [first_provenance['wikipedia_id']]


class GeneratorDataset(Dataset):
    """VQA examples turned into retrieval-augmented generator inputs.

    For each example, the query (image + question, or question + image
    caption text for ``args.model == 'ravqa'``) is encoded with
    ``retriever``, the top ``args.k_for_gen`` passages are retrieved
    (MaxP over per-passage chunk embeddings when ``args.mode == 'maxp'``,
    otherwise a FAISS index), and the passages are joined into the context.
    """

    def __init__(self, args, data, pid2text, retriever, processor, passage_cache, device, all_passage_ids, img2text):
        """
        Args:
            args: config; reads ``image_root``, ``k_for_gen``, ``mode`` and
                (in ``__getitem__``) ``model``.
            data: indexable sequence of example dicts.
            pid2text: mapping passage id -> passage text.
            retriever: object exposing ``encode_documents`` / ``encode_queries``.
            processor: image processor returning ``"pixel_values"``.
            passage_cache: forwarded to ``retriever.encode_documents``.
            device: device used for image tensors and MaxP tensors.
            all_passage_ids: passage ids, position-aligned with the encoded
                passage vectors.
            img2text: mapping image key -> auxiliary text (used for 'ravqa').
        """
        self.data = data
        self.pid2text = pid2text
        self.retriever = retriever
        self.processor = processor
        self.image_root = args.image_root
        self.device = device
        self.args = args
        self.img2text = img2text
        self.all_passage_ids = all_passage_ids
        self.k = args.k_for_gen
        self.passage_cache = passage_cache

        passage_vecs_result = self.retriever.encode_documents(passage_cache, all_passage_ids, eval_mode=True, pid2text=pid2text)
        if isinstance(passage_vecs_result, list):
            # A list means per-passage (chunked) embeddings.
            # NOTE(review): if args.mode != 'maxp' here, neither the MaxP
            # tensors nor faiss_index is built and __getitem__ will fail.
            self.passage_vecs = passage_vecs_result
            if args.mode == 'maxp':
                self._precompute_maxp_embeddings()
        else:
            self.passage_vecs = passage_vecs_result.detach()
            self.faiss_index = build_faiss_index(self.passage_vecs)

    def _precompute_maxp_embeddings(self):
        """Pad per-passage chunk embeddings into one tensor plus a validity mask.

        Sets ``maxp_all_chunks`` ([num_docs, max_chunks, D], zero padded),
        ``maxp_chunk_mask`` ([num_docs, max_chunks], 1.0 for real chunks)
        and ``maxp_chunk_counts`` (real chunk count per passage).
        """
        self.maxp_chunk_embeddings = []
        self.maxp_chunk_counts = []

        for i, _pid in enumerate(self.all_passage_ids):
            doc_chunks = self.passage_vecs[i]  # [num_chunks, D]
            if isinstance(doc_chunks, torch.Tensor):
                chunk_tensor = doc_chunks.to(self.device)
            else:
                chunk_tensor = torch.tensor(doc_chunks).to(self.device)

            # Normalise every passage to a 2-D [num_chunks, D] tensor.
            if chunk_tensor.dim() == 1:
                chunk_tensor = chunk_tensor.unsqueeze(0)  # [1, D]
            elif chunk_tensor.dim() > 2:
                chunk_tensor = chunk_tensor.view(-1, chunk_tensor.shape[-1])

            self.maxp_chunk_embeddings.append(chunk_tensor)
            self.maxp_chunk_counts.append(chunk_tensor.shape[0])

        if not self.maxp_chunk_embeddings:
            self.maxp_all_chunks = torch.empty(0, 0, 0, device=self.device)
            self.maxp_chunk_mask = torch.empty(0, 0, dtype=torch.bool, device=self.device)
            return

        max_chunks = max(self.maxp_chunk_counts)
        chunk_dim = self.maxp_chunk_embeddings[0].shape[1]

        padded_chunks = []
        for chunk_tensor in self.maxp_chunk_embeddings:
            if chunk_tensor.shape[0] < max_chunks:
                padding = torch.zeros(max_chunks - chunk_tensor.shape[0], chunk_dim,
                                      device=chunk_tensor.device, dtype=chunk_tensor.dtype)
                padded_chunk = torch.cat([chunk_tensor, padding], dim=0)
            else:
                padded_chunk = chunk_tensor
            padded_chunks.append(padded_chunk)

        self.maxp_all_chunks = torch.stack(padded_chunks, dim=0)

        self.maxp_chunk_mask = torch.zeros(len(self.all_passage_ids), max_chunks, device=self.device)
        for i, count in enumerate(self.maxp_chunk_counts):
            self.maxp_chunk_mask[i, :count] = 1.0

    def _maxp_topk_indices(self, q_vec, doc_batch_size=256):
        """Return passage indices ranked by max chunk similarity (best first).

        At most ``self.k`` indices are returned; fewer when the corpus is
        smaller (``torch.topk`` would raise otherwise), ``[]`` when empty.
        """
        q = q_vec
        if q.dim() == 1:
            q = q.unsqueeze(0)  # [1, D]
        q = q.to(self.maxp_all_chunks.device)

        num_docs = self.maxp_all_chunks.shape[0]
        k = min(self.k, num_docs)
        if k <= 0:
            return []

        per_doc_max_scores = []
        # Score documents in batches to bound the [1, b, max_chunks] buffer.
        for start in range(0, num_docs, doc_batch_size):
            chunk_slice = self.maxp_all_chunks[start:start + doc_batch_size]  # [b, max_chunks, D]
            mask_slice = self.maxp_chunk_mask[start:start + doc_batch_size]   # [b, max_chunks]
            sims = torch.einsum('nd,bkd->nbk', q, chunk_slice)
            # Padding rows must never win the max.
            sims = sims.masked_fill(mask_slice.unsqueeze(0) == 0, float('-inf'))
            per_doc_max_scores.append(sims.max(dim=-1).values)  # [1, b]

        max_similarities = torch.cat(per_doc_max_scores, dim=1).squeeze(0)  # [num_docs]
        _, topk_indices = torch.topk(max_similarities, k)
        return topk_indices.tolist()

    def _faiss_topk_indices(self, q_np):
        """Return passage indices from the FAISS index, best first."""
        _, neighbor_ids = self.faiss_index.search(q_np, self.k)
        # FAISS pads missing results with -1; as a list index that would
        # silently select the last passage, so drop it.
        return [i for i in neighbor_ids[0].tolist() if i >= 0]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Build ``(image_tensor, input_text, target_answer)`` for example ``idx``.

        Returns ``(None, None)`` when the example has no answers or its image
        file does not exist.
        """
        item = self.data[idx]
        question = item["question"]
        answers = item["answers"]
        try:
            img_key = item["img_key_full"]
        except KeyError:
            img_key = item["image_id"]
        if not answers:
            return None, None

        target_answer = normalize_answer(item.get("gold_answer", answers[0]))

        img_path = os.path.join(self.image_root, item["img_path"])
        try:
            with Image.open(img_path) as raw_image:
                image = raw_image.convert("RGB")
        except FileNotFoundError:
            print(f"Warning: Image not found at {img_path}. Skipping item.")
            return None, None

        # Encoding AND scoring run without autograd: the chunked passage
        # vectors are not detached in __init__, so scoring outside no_grad
        # would build a graph over the whole corpus.
        with torch.no_grad():
            image_tensor = self.processor(images=image, return_tensors="pt")["pixel_values"].to(self.device)

            if self.args.model == 'ravqa':
                aux_text = self.img2text.get(img_key, "")
                query_input = f"Question: {question} Context: {aux_text}"
                q_vec = self.retriever.encode_queries([query_input])
            else:
                q_vec = self.retriever.encode_queries(image_tensor, [question])

            if self.args.mode == 'maxp':
                topk_indices = self._maxp_topk_indices(q_vec)
            else:
                q_np = q_vec.detach().cpu().numpy().astype("float32")
                topk_indices = self._faiss_topk_indices(q_np)

        topk_passages = [self.pid2text[self.all_passage_ids[i]] for i in topk_indices]
        context = " ".join(topk_passages)
        input_text = f"Question: {question} Context: {context}"

        return image_tensor, input_text, target_answer

class GeneratorRAGDataset(Dataset):
    """Text-only RAG examples whose context is built from passage summaries.

    Each query is encoded with ``retriever``, the top ``args.k_for_gen``
    FAISS hits are mapped to passage ids (via ``chunk_to_parent`` when
    ``args.mode == 'maxp'``), and the summaries of those passages (from the
    JSON at ``summary_map_path``) are joined into the context.
    """

    def __init__(self, args, data, pid2text, retriever, passage_cache, device, all_passage_id, summary_map_path):
        """
        Args:
            args: config; reads ``k_for_gen`` and ``mode``.
            data: indexable sequence of dicts with ``input`` and
                ``output[0]['answer']``.
            pid2text: forwarded to ``retriever.encode_documents``.
            retriever: object exposing ``encode_documents`` / ``encode_queries``.
            passage_cache: forwarded to ``retriever.encode_documents``.
            device: stored as ``self.device``.
            all_passage_id: passage ids, position-aligned with the encoded
                passage vectors.
            summary_map_path: path to a JSON object mapping passage id ->
                summary text.
        """
        self.args = args
        self.data = data
        self.pid2text = pid2text
        self.retriever = retriever
        self.device = device
        self.all_passage_ids = all_passage_id
        self.k = args.k_for_gen

        print(f"Loading summary map from {summary_map_path}...")
        # JSON is UTF-8 by spec; don't depend on the platform's locale encoding.
        with open(summary_map_path, 'r', encoding='utf-8') as f:
            # NOTE(review): JSON object keys are always str, so passage ids
            # must be str for the lookups in __getitem__ to succeed.
            self.id_to_summary = json.load(f)
        print("Summary map loaded.")

        print("Encoding documents and building FAISS index for training...")

        if self.args.mode == 'maxp':
            doc_embeddings_list = self.retriever.encode_documents(
                passage_cache, all_passage_id, pid2text=pid2text
            )
            if len(doc_embeddings_list) == 0:
                # presumably doc_proj[0] is a Linear whose input width equals
                # the embedding dim — TODO confirm
                self.passage_vecs = torch.empty(0, self.retriever.doc_proj[0].in_features)
                self.chunk_to_parent = []
            else:
                processed_embeddings = []
                for emb in doc_embeddings_list:
                    if emb.dim() == 1:
                        emb = emb.unsqueeze(0)  # [1, embedding_dim]
                    elif emb.dim() == 0:
                        emb = emb.unsqueeze(0).unsqueeze(0)  # [1, 1]
                    processed_embeddings.append(emb)

                # Flatten all chunks of all passages into one index ...
                self.passage_vecs = torch.cat(processed_embeddings, dim=0).detach()

                # ... and remember which passage every flattened row came from.
                self.chunk_to_parent = []
                for i, pid in enumerate(all_passage_id):
                    num_chunks = processed_embeddings[i].shape[0]
                    self.chunk_to_parent.extend([pid] * num_chunks)
        else:
            self.passage_vecs = self.retriever.encode_documents(
                passage_cache, all_passage_id, pid2text=pid2text
            ).detach()
            self.chunk_to_parent = None

        self.faiss_index = build_faiss_index(self.passage_vecs)
        print("FAISS index built.")

    def __len__(self):
        return len(self.data)

    def _retrieve_pids(self, q_np):
        """Return retrieved passage ids for query matrix ``q_np``, best first.

        In 'maxp' mode the top-k chunk hits are collapsed to their parent
        passages, keeping each passage once at its best rank.
        """
        _, topk_indices = self.faiss_index.search(q_np, self.k)
        # FAISS pads missing results with -1; as a list index that would
        # silently select the last entry.
        hits = [j for j in topk_indices[0].tolist() if j >= 0]
        if self.args.mode == 'maxp':
            # dict.fromkeys dedupes while keeping retrieval order (a set
            # would scramble it, non-deterministically for str ids).
            return list(dict.fromkeys(self.chunk_to_parent[j] for j in hits))
        return [self.all_passage_ids[i] for i in hits]

    def __getitem__(self, idx):
        """Return ``(input_text, target_answer)``, or ``(None, None)`` if the answer is empty."""
        item = self.data[idx]
        question = item["input"]
        answers = item['output'][0]['answer']

        if not answers:
            return None, None

        target_answer = normalize_answer(answers)

        with torch.no_grad():
            q_vec = self.retriever.encode_queries([question])
            q_np = q_vec.detach().cpu().numpy().astype("float32")

        retrieved_pids = self._retrieve_pids(q_np)
        context_summaries = [self.id_to_summary[pid] for pid in retrieved_pids]
        context = " ".join(context_summaries)

        input_text = f"Question: {question} Context: {context}"

        return input_text, target_answer
    
class EmbeddingDataset(Dataset):
    """Index-aligned view over precomputed embeddings and their labels.

    ``embeddings`` and ``labels`` may be any indexable sequences; the
    dataset length is taken from ``embeddings``.
    """

    def __init__(self, embeddings, labels):
        self.embeddings, self.labels = embeddings, labels

    def __len__(self):
        return len(self.embeddings)

    def __getitem__(self, idx):
        """Return ``{'embeddings': embeddings[idx], 'labels': labels[idx]}``."""
        return dict(embeddings=self.embeddings[idx], labels=self.labels[idx])


class LateChunkingEmbeddingDataset(Dataset):
    """Embeddings and labels plus a per-example ``spans`` entry.

    NOTE(review): the span format is not visible here; spans are passed
    through unchanged. Length is taken from ``embeddings``.
    """

    def __init__(self, embeddings, labels, spans):
        self.spans = spans
        self.embeddings = embeddings
        self.labels = labels

    def __len__(self):
        return len(self.embeddings)

    def __getitem__(self, idx):
        """Return a dict with keys ``spans``, ``embeddings``, ``labels`` for ``idx``."""
        sample = {}
        for field in ('spans', 'embeddings', 'labels'):
            sample[field] = getattr(self, field)[idx]
        return sample
