import torch
import os
import argparse
import numpy as np
import wandb
from datasets import load_dataset, load_from_disk
from tqdm import tqdm
from PIL import Image
import random
from torch.utils.data import DataLoader
import json
from utils import normalize_answer, compute_vqa_score, is_supported_by_passages, check_answer_in_passages
from utils import supcon_loss, set_seed, build_faiss_index
from utils import colbert_score_matrix
from collections import defaultdict
from model import Generator
import torch.nn.functional as F
from dataloader import OKVQADataset, GeneratorDataset, EVQADataset
from torch.nn.utils.rnn import pad_sequence

def custom_collate_fn(batch):
    """Transpose a list of 5-field samples into five per-field tuples.

    Each sample is ``(image, question, pos_ids, img_key, answers)``; the
    result keeps that field order and leaves every field un-stacked.
    """
    columns = tuple(zip(*batch))
    # Unpacking enforces that every sample carries exactly five fields.
    image_col, question_col, pos_col, key_col, answer_col = columns
    return image_col, question_col, pos_col, key_col, answer_col

def evaluate_retriever(args, retriever,  test_data, passage_cache, processor, device, all_passage_ids,img2text,pid2text):
    """Score the retriever on ``test_data`` and print/return metrics at k = 1, 5, 10.

    A FAISS index is built over the passage collection (how depends on
    ``args.mode``), every query batch is searched against it, and three
    per-query metrics are averaged for each k:

    * ``Recall@k``       - fraction of the query's positive passages in the top k.
    * ``PRRecall@k``     - 1 if at least one positive passage is in the top k, else 0.
    * ``AnswerRecall@k`` - ``check_answer_in_passages`` applied to the top-k passage texts.

    Args:
        args: Namespace; ``args.mode`` selects the indexing strategy and
            ``args.model`` selects how queries are encoded.
        retriever: Model exposing ``encode_queries`` / ``encode_documents`` (and
            ``norm_doc`` / ``doc_proj`` for the ``maxp`` mode).
        test_data: Iterable of ``(images, questions, pos_ids, img_key, answers)`` batches.
        passage_cache: Mapping passage id -> cached entry holding an ``"embedding"``.
        processor: Unused here; kept so existing callers keep working.
        device: Torch device that query images and chunk embeddings are moved to.
        all_passage_ids: Passage ids, in the order they are indexed.
        img2text: Mapping image key -> auxiliary text; only read when
            ``args.model == 'ravqa'``.
        pid2text: Mapping passage id -> passage text.

    Returns:
        Dict with ``Recall@k``, ``PRRecall@k`` and ``AnswerRecall@k`` for each k.
    """
    retriever.eval()
    ks = [1, 5, 10]
    max_k = max(ks)
    chunk_to_parent = []  # index row -> parent passage id (only filled for 'maxp')

    if args.mode == 'maxp':
        all_chunks = []
        for pid in all_passage_ids:
            passage_data = passage_cache[pid]
            if 'num_chunks' in passage_data:
                all_chunks.append(passage_data['embedding'])  # [num_chunks, D]
                chunk_to_parent.extend([pid] * passage_data['num_chunks'])
            else:
                all_chunks.append(passage_data['embedding'].unsqueeze(0))
                chunk_to_parent.append(pid)

        if not all_chunks:
            passage_vecs = torch.empty(0, retriever.doc_proj[0].in_features)
        else:
            passage_vecs = torch.cat(all_chunks, dim=0).to(device)  # [total_chunks, D]
            with torch.no_grad():
                passage_vecs = retriever.norm_doc(passage_vecs)
                passage_vecs = retriever.doc_proj(passage_vecs)
                passage_vecs = F.normalize(passage_vecs, dim=-1)
        faiss_index = build_faiss_index(passage_vecs.detach())
    elif args.mode == 'lc':
        passage_vecs = retriever.encode_documents(passage_cache, all_passage_ids, eval_mode=True,pid2text=pid2text).detach()
        faiss_index = build_faiss_index(passage_vecs)
    else:
        passage_vecs = retriever.encode_documents(passage_cache, all_passage_ids, eval_mode=True).detach()
        faiss_index = build_faiss_index(passage_vecs)

    recall_scores_by_k = {k: [] for k in ks}
    prr_scores_by_k = {k: [] for k in ks}
    answer_recall_scores_by_k = {k: [] for k in ks}

    for images, questions, pos_ids_list, img_key, answers_list in tqdm(test_data, desc="Evaluating Retriever"):
        # No gradients are needed for evaluation; skipping the autograd graph saves memory.
        with torch.no_grad():
            if args.model == 'ravqa':
                aux_text_batch = [img2text.get(k, "") for k in img_key]
                query_input = [f"Question: {q} Context: {a}" for q, a in zip(questions, aux_text_batch)]
                q_vec = retriever.encode_queries(query_input)
            else:
                images = torch.stack(images).to(device)
                q_vec = retriever.encode_queries(images, questions)

        q_np = q_vec.detach().cpu().numpy().astype("float32")
        _, I = faiss_index.search(q_np, max_k)

        for row, raw_pos_ids, sample_answers in zip(I, pos_ids_list, answers_list):
            # FAISS pads with -1 when the index holds fewer than max_k vectors;
            # a negative index would otherwise silently select the last passage.
            hits = [j for j in row.tolist() if j >= 0]
            if args.mode == 'maxp':
                # Several chunks may share a parent. dict.fromkeys removes the
                # duplicates while keeping best-rank-first order (a plain set
                # would scramble the ranking before the top-k slice).
                # NOTE: only max_k chunks are searched, so fewer than max_k
                # distinct parents may be available.
                topk_passage_ids = list(dict.fromkeys(chunk_to_parent[j] for j in hits))[:max_k]
            else:
                topk_passage_ids = [all_passage_ids[j] for j in hits]

            try:
                pos_ids = set(int(pid) for pid in raw_pos_ids)
            except (TypeError, ValueError):
                # Ids are not integer-like: fall back to the first element.
                # NOTE(review): presumably a nested list of ids for such datasets - confirm.
                pos_ids = raw_pos_ids[0]

            for k in ks:
                top_k_ids = topk_passage_ids[:k]
                hit_count = sum(1 for pid in top_k_ids if pid in pos_ids)
                recall = hit_count / len(pos_ids) if len(pos_ids) > 0 else 0.0
                recall_scores_by_k[k].append(recall)
                prr_scores_by_k[k].append(min(hit_count, 1))

                top_k_passages = [pid2text[pid] for pid in top_k_ids if pid in pid2text]
                answer_found = check_answer_in_passages(sample_answers, top_k_passages)
                answer_recall_scores_by_k[k].append(int(answer_found))

    results = {}
    for k in ks:
        # An empty test set yields 0.0 instead of NaN plus a numpy warning.
        avg_recall = np.mean(recall_scores_by_k[k]) if recall_scores_by_k[k] else 0.0
        avg_prr = np.mean(prr_scores_by_k[k]) if prr_scores_by_k[k] else 0.0
        avg_answer_recall = np.mean(answer_recall_scores_by_k[k]) if answer_recall_scores_by_k[k] else 0.0
        results[f"Recall@{k}"] = avg_recall
        results[f"PRRecall@{k}"] = avg_prr
        results[f"AnswerRecall@{k}"] = avg_answer_recall
        print(f"Recall@{k}: {avg_recall:.4f}")
        print(f"PRRecall@{k}: {avg_prr:.4f}")
        print(f"AnswerRecall@{k}: {avg_answer_recall:.4f}")
    return results
def _passage_cache_key(pid):
    """Return ``int(pid)`` for integer-like ids, otherwise the id unchanged."""
    try:
        return int(pid)
    except (TypeError, ValueError):
        return pid


def _encode_passage_chunk(args, model, input_chunk, attn_chunk):
    """Run the text encoder selected by ``args.model`` on one chunk, returning hidden states."""
    if args.model == 'clip':
        encoder = model.text_encoder
    elif args.model == 'longclip':
        encoder = model.text_model
    elif args.model == 'ravqa':
        encoder = model.doc_encoder
    else:
        raise ValueError(f"Unsupported model for passage caching: {args.model!r}")
    return encoder(
        input_ids=input_chunk,
        attention_mask=attn_chunk,
        output_hidden_states=True,
        return_dict=True
    )


def cache_passage_embeddings(args, passages, tokenizer, model, device, batch_size, cache_path, mode,
                             chunk_size=77, stride_ratio=0.5):
    """Embed every passage once, save the result to ``cache_path`` and return it.

    Three layouts are produced depending on ``mode``:

    * ``'cls'``  - one pooled vector per passage (tokenizer truncation enabled).
    * ``'maxp'`` - one mean-pooled vector per sliding-window chunk, stored as
      ``[num_chunks, D]`` together with ``"num_chunks"``.
    * anything else - one vector per token; tokens covered by several
      overlapping windows get the average of their per-window states.

    Windows are ``chunk_size`` tokens long and advance by
    ``int(chunk_size * (1 - stride_ratio))`` tokens. Chunk and token vectors
    come from the second-to-last hidden layer.

    Args:
        args: Namespace; ``args.model`` selects which encoder of ``model`` is used.
        passages: Dataset sliceable as ``passages[i:j]`` into a dict with
            ``"passage_content"`` and ``"passage_id"`` lists.
        tokenizer: Tokenizer called on each passage text; offset mappings are
            requested in the per-token mode.
        model: Retriever holding the text/document encoder.
        device: Torch device the model and inputs are moved to.
        batch_size: Number of passages sliced from ``passages`` per outer step.
        cache_path: Destination file for ``torch.save``.
        mode: ``'cls'``, ``'maxp'`` or a per-token mode.
        chunk_size: Window length in tokens.
        stride_ratio: Fraction of a window that overlaps the next one.

    Returns:
        Dict mapping passage id (``int`` when convertible) to a dict with
        ``"embedding"``, ``"text"`` and, for ``'maxp'``, ``"num_chunks"``.

    Raises:
        ValueError: If ``args.model`` is unsupported, or the window stride is
            not positive in a windowed mode.
    """
    model.eval()
    model.to(device)
    stride = int(chunk_size * (1 - stride_ratio))
    if mode != 'cls' and stride <= 0:
        raise ValueError(
            f"chunk_size={chunk_size} and stride_ratio={stride_ratio} give a non-positive stride ({stride})"
        )
    cache = {}

    for batch_start in tqdm(range(0, len(passages), batch_size), desc="Caching passages"):
        batch = passages[batch_start:batch_start + batch_size]
        texts = batch["passage_content"]
        ids = batch["passage_id"]

        for pid, text in zip(ids, texts):
            if mode == 'cls':
                # Pooled representation; the untruncated offset-mapping
                # tokenization used by the windowed modes is not needed here.
                if args.model in ['clip', 'longclip']:
                    inputs = tokenizer(text, return_tensors='pt', truncation=True, padding=True).to(device)
                    with torch.no_grad():
                        outputs = model.text_encoder(**inputs, return_dict=True)
                        pooled = outputs.pooler_output.squeeze(0).cpu()  # [D]
                elif args.model in ['ravqa']:
                    inputs = tokenizer(text, return_tensors='pt', truncation=True, padding=True).to(device)
                    with torch.no_grad():
                        # NOTE(review): unlike the CLIP branch this keeps the batch
                        # dimension and stays on `device` - confirm consumers expect that.
                        pooled = model.doc_encoder(**inputs).pooler_output
                else:
                    raise ValueError(f"Unsupported model for passage caching: {args.model!r}")
                cache[_passage_cache_key(pid)] = {"embedding": pooled, "text": text}
                continue

            # Offsets are only consumed by the per-token mode.
            tokens = tokenizer(text, return_tensors='pt', truncation=False,
                               return_offsets_mapping=(mode != 'maxp'))
            input_ids = tokens['input_ids'][0]
            attention_mask = tokens['attention_mask'][0]

            if mode == 'maxp':
                chunk_embeddings = []
                for start in range(0, len(input_ids), stride):
                    input_chunk = input_ids[start:start + chunk_size]
                    attn_chunk = attention_mask[start:start + chunk_size]

                    if len(input_chunk) < chunk_size:
                        pad_len = chunk_size - len(input_chunk)
                        input_chunk = torch.nn.functional.pad(input_chunk, (0, pad_len), value=tokenizer.pad_token_id)
                        attn_chunk = torch.nn.functional.pad(attn_chunk, (0, pad_len), value=0)

                    input_chunk = input_chunk.unsqueeze(0).to(device)
                    attn_chunk = attn_chunk.unsqueeze(0).to(device)

                    with torch.no_grad():
                        outputs = _encode_passage_chunk(args, model, input_chunk, attn_chunk)
                        # Mean over all chunk_size positions of the second-to-last layer.
                        # NOTE(review): padded positions of a short chunk are included
                        # in this mean - confirm that is intended.
                        chunk_emb = outputs.hidden_states[-2].squeeze(0).mean(dim=0)  # [D]

                    chunk_embeddings.append(chunk_emb)

                if len(chunk_embeddings) == 0:
                    chunk_embeddings = [torch.zeros(model.config.hidden_size, device=device)]

                # Chunk vectors stay separate; "num_chunks" records how many belong to this passage.
                doc_embedding = torch.stack(chunk_embeddings)  # [num_chunks, D]
                cache[_passage_cache_key(pid)] = {
                    "embedding": doc_embedding,
                    "num_chunks": len(chunk_embeddings),
                    "text": text,
                }
                continue

            offsets = tokens['offset_mapping'][0]
            token_embeddings = defaultdict(list)  # absolute token index -> per-window states

            for start in range(0, len(input_ids), stride):
                end = start + chunk_size
                input_chunk = input_ids[start:end]
                attn_chunk = attention_mask[start:end]
                offset_chunk = offsets[start:end]

                if len(input_chunk) < chunk_size:
                    pad_len = chunk_size - len(input_chunk)
                    input_chunk = torch.nn.functional.pad(input_chunk, (0, pad_len), value=tokenizer.pad_token_id)
                    attn_chunk = torch.nn.functional.pad(attn_chunk, (0, pad_len), value=0)
                    offset_chunk = torch.nn.functional.pad(offset_chunk, (0, 0, 0, pad_len), value=0)

                input_chunk = input_chunk.unsqueeze(0).to(device)
                attn_chunk = attn_chunk.unsqueeze(0).to(device)

                with torch.no_grad():
                    outputs = _encode_passage_chunk(args, model, input_chunk, attn_chunk)
                    hidden = outputs.hidden_states[-2].squeeze(0).cpu()  # [T, D]

                for pos, (start_c, end_c) in enumerate(offset_chunk.tolist()):
                    # A (0, 0) offset marks padding added above (tokenizers also
                    # typically report it for special tokens); such positions are skipped.
                    if start_c == 0 and end_c == 0:
                        continue
                    token_embeddings[start + pos].append(hidden[pos])

            doc_embedding = [
                torch.stack(token_embeddings[idx], dim=0).mean(dim=0)
                for idx in sorted(token_embeddings)
            ]

            if len(doc_embedding) == 0:
                doc_embedding = torch.zeros((1, model.config.hidden_size))
            else:
                doc_embedding = torch.stack(doc_embedding)  # [T_total, D]

            cache[_passage_cache_key(pid)] = {"embedding": doc_embedding, "text": text}

    # Create the target directory first so a missing folder cannot throw away
    # the whole (expensive) caching pass at save time.
    cache_dir = os.path.dirname(cache_path)
    if cache_dir:
        os.makedirs(cache_dir, exist_ok=True)
    torch.save(cache, cache_path)

    return cache

def train_generator(
    args, retriever, generator, train_data,
    pid2text, device, passage_cache, processor, all_passage_keys,img2text,
    generator_model_path='checkpoint/generator.pt',
    epochs=3, batch_size=2
):
    """Train ``generator`` on samples from ``GeneratorDataset`` with a frozen retriever.

    The retriever is put in eval mode and all its parameters are frozen. The
    generator is optimised with Adam (lr 5e-5) on the loss returned by
    ``generator(input_texts, target_answers, device)``, and its state dict is
    written to ``generator_model_path`` after every epoch.

    Args:
        args, train_data, pid2text, processor, passage_cache, all_passage_keys,
        img2text: Forwarded unchanged to ``GeneratorDataset``.
        retriever: Retriever used by the dataset; frozen here.
        generator: Module whose call returns the training loss.
        device: Torch device handed to the dataset and the generator.
        generator_model_path: Checkpoint path, overwritten every epoch.
        epochs: Number of passes over the dataset.
        batch_size: Samples per optimisation step.
    """
    # Freeze retriever weights
    retriever.eval()
    for p in retriever.parameters():
        p.requires_grad = False

    dataset = GeneratorDataset(args, train_data, pid2text, retriever, processor, passage_cache, device,all_passage_keys,img2text)

    def collate_fn_filter_none(batch):
        """Drop ``None`` samples and transpose the rest into per-field tuples."""
        kept = [sample for sample in batch if sample is not None]
        if not kept:
            # Nothing usable in this batch; the training loop skips it.
            return None, None, None
        return tuple(zip(*kept))

    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn_filter_none)
    optimizer = torch.optim.Adam(generator.parameters(), lr=5e-5)

    # Make sure the checkpoint directory exists before the first save.
    checkpoint_dir = os.path.dirname(generator_model_path)
    if checkpoint_dir:
        os.makedirs(checkpoint_dir, exist_ok=True)

    for epoch in range(epochs):
        generator.train()
        total_loss, steps = 0, 0

        for image, input_texts, target_answers in tqdm(dataloader, desc=f"Generator Training Epoch {epoch+1}"):
            if input_texts is None:
                continue

            loss = generator(list(input_texts), list(target_answers), device)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            steps += 1

        print(f"[Generator Epoch {epoch+1}] Avg Loss: {total_loss / max(steps, 1):.4f}")
        torch.save(generator.state_dict(), generator_model_path)
def evaluate_generator(
    args, retriever, generator, test_data, pid2text,device, passage_cache, processor, \
                    all_passage_keys, img2text,generator_model_path):
    """Run retrieve-then-generate over ``test_data`` and report answer metrics.

    For every item with at least one reference answer, the top
    ``args.k_for_gen`` passages are retrieved through a FAISS index, joined
    into a context, and passed to ``generator.generate`` as
    ``"Question: <question> Context: <context>"``.

    Metrics (each averaged over evaluated items, 0.0 when none were evaluated):

    * Exact Match - normalised prediction equals a normalised reference answer.
    * VQA Score   - ``compute_vqa_score(pred, answers)``.
    * HSR         - ``is_supported_by_passages(pred, topk_passages)``.
    * FSR         - supported *and* exact match.

    Args:
        args: Namespace providing ``model``, ``image_root``, ``k_for_gen`` and ``wandb``.
        retriever: Retriever exposing ``encode_documents`` / ``encode_queries``.
        generator: Model exposing ``generate(list_of_prompts, device)``.
        test_data: Iterable of items with ``"question"``, ``"answers"``,
            ``"img_path"`` and ``"img_key_full"`` or ``"image_id"``.
        pid2text: Mapping passage id -> passage text.
        device: Torch device for query inputs.
        passage_cache: Cached passage embeddings passed to ``encode_documents``.
        processor: Image processor used to build ``pixel_values``.
        all_passage_keys: Passage ids in index order.
        img2text: Mapping image key -> auxiliary text (``ravqa`` only).
        generator_model_path: Unused here; kept for interface compatibility.

    Returns:
        The exact-match score as a float.
    """
    retriever.eval()
    generator.eval()

    exact_matches, vqa_scores, hsr_list, fsr_list = [], [], [], []

    passage_vecs = retriever.encode_documents(passage_cache, all_passage_keys,eval_mode=True,pid2text=pid2text)
    if isinstance(passage_vecs, list):
        # Per-passage tensors (e.g. one row per chunk): average to one vector each.
        pooled_vecs = [
            (vec.mean(dim=0) if vec.dim() == 2 else vec.squeeze(0))
            for vec in passage_vecs
        ]
        index_vecs = torch.stack(pooled_vecs, dim=0).detach()
    else:
        index_vecs = passage_vecs.detach()
    faiss_index = build_faiss_index(index_vecs)

    try:
        all_passages = [pid2text[int(pid)] for pid in all_passage_keys]
    except (KeyError, TypeError, ValueError):
        # Ids are not integer-like (or the text map is keyed by the raw ids).
        all_passages = [pid2text[pid] for pid in all_passage_keys]

    for item in tqdm(test_data, desc="Evaluating Generator"):
        questions = [item["question"]]
        answers = item["answers"]
        if not answers:
            continue
        try:
            img_key = [item["img_key_full"]]
        except KeyError:
            img_key = [item["image_id"]]

        # Encode Query
        with torch.no_grad():
            if args.model=='ravqa':
                # Text-only query: the image itself is not needed, so it is not loaded.
                aux_text_batch = [img2text.get(k, "") for k in img_key]
                query_input = [f"Question: {q} Context: {a}" for q, a in zip(questions, aux_text_batch)]
                q_vec = retriever.encode_queries(query_input)
            else:
                img_path = os.path.join(args.image_root, item["img_path"])
                image = Image.open(img_path).convert("RGB")
                image_tensor = processor(images=image, return_tensors="pt")["pixel_values"].to(device)
                q_vec = retriever.encode_queries(image_tensor, questions)
            q_np = q_vec.detach().cpu().numpy().astype("float32")

        _, I = faiss_index.search(q_np, args.k_for_gen)
        # FAISS pads with -1 when fewer than k vectors are indexed; a negative
        # index would otherwise silently pick the last passage.
        topk_indices = [idx for idx in I[0].tolist() if idx >= 0]
        topk_passages = [all_passages[idx] for idx in topk_indices]
        context = " ".join(topk_passages)
        # Interpolate the question string itself; formatting the one-element
        # list would put "['...']" into the prompt.
        input_text = f"Question: {questions[0]} Context: {context}"

        pred = generator.generate([input_text], device)[0].strip()
        pred_norm = normalize_answer(pred)
        norm_answers = [normalize_answer(a) for a in answers]
        is_exact = pred_norm in norm_answers
        exact_matches.append(int(is_exact))

        vqa_scores.append(compute_vqa_score(pred, answers))
        is_supported = is_supported_by_passages(pred, topk_passages)
        hsr_list.append(int(is_supported))
        fsr_list.append(int(is_supported and is_exact))

    # Final metrics
    em_score = float(np.mean(exact_matches)) if exact_matches else 0.0
    vqa_score = float(np.mean(vqa_scores)) if vqa_scores else 0.0
    hsr = float(np.mean(hsr_list)) if hsr_list else 0.0
    fsr = float(np.mean(fsr_list)) if fsr_list else 0.0

    print(f"\n[Generator Evaluation] Exact Match: {em_score:.4f}")
    print(f"[Generator Evaluation] VQA Score: {vqa_score:.4f}")
    print(f"[Generator Evaluation] HSR: {hsr:.4f}")
    print(f"[Generator Evaluation] FSR: {fsr:.4f}\n")

    if args.wandb:
        wandb.log({"Gen_EM": em_score, "Gen_VQAScore": vqa_score, "Gen_HSR": hsr, "Gen_FSR": fsr})

    return em_score



def _encode_training_queries(args, retriever, images, questions, img_key, img2text, device):
    """Encode one batch of queries: text-only for ``ravqa``, image + question otherwise."""
    if args.model == 'ravqa':
        aux_text_batch = [img2text.get(k, "") for k in img_key]
        query_input = [f"Question: {q} Context: {a}" for q, a in zip(questions, aux_text_batch)]
        return retriever.encode_queries(query_input)
    images = torch.stack(images).to(device)
    return retriever.encode_queries(images, questions)


def _collect_batch_positive_ids(pos_ids_list):
    """Return the set of positive passage ids referenced anywhere in the batch."""
    try:
        return {int(pid) for sub in pos_ids_list for pid in sub}
    except (TypeError, ValueError):
        # Ids are not integer-like: take the first id of the first entry per sample.
        # NOTE(review): presumably a nested list of ids for such datasets - confirm.
        return {pid[0][0] for pid in pos_ids_list}


def _build_positive_mask(pos_ids_list, all_pids_in_batch, device):
    """Build a ``[num_queries, num_docs]`` bool mask that is True at (query, positive doc)."""
    pid_to_idx = {pid: i for i, pid in enumerate(all_pids_in_batch)}
    # Fill on CPU and move once, instead of per-element writes on the device.
    pos_mask = torch.zeros(len(pos_ids_list), len(all_pids_in_batch), dtype=torch.bool)
    for row, pos_ids in enumerate(pos_ids_list):
        for pid in pos_ids:
            try:
                pid = int(pid)
            except (TypeError, ValueError):
                pid = pid[0]

            if pid in pid_to_idx:
                pos_mask[row, pid_to_idx[pid]] = True
    return pos_mask.to(device)


def main(args, dataset, test_dataset, retriever, processor,device,passage_cache, \
         pid2text,img2text=None,img2text_test=None):
    """Train the retriever with a supervised-contrastive loss and evaluate it periodically.

    Each step encodes the batch queries, gathers the batch's positive passages
    plus ``args.batch_size * 4`` randomly sampled other passages, encodes them
    and minimises ``supcon_loss``. Gradients are clipped to norm 1.0.

    The retriever is evaluated once before training, then (unless
    ``args.mode == 'late'``) whenever ``epoch % args.eval_period == 0`` and
    after the last epoch; the weights are saved to ``args.model_save_path`` at
    those points (every epoch in ``'late'`` mode). Final metrics are printed
    and, if ``args.wandb`` is set, logged.

    Args:
        args: Namespace with ``batch_size``, ``lr``, ``num_epochs``, ``mode``,
            ``model``, ``eval_period``, ``model_save_path`` and ``wandb``.
        dataset: Training dataset yielding 5-field samples for ``custom_collate_fn``.
        test_dataset: Evaluation dataset with the same sample layout.
        retriever: Model being trained.
        processor: Passed through to ``evaluate_retriever``.
        device: Torch device for query images, document vectors and masks.
        passage_cache: Mapping passage id -> cached embedding entry.
        pid2text: Mapping passage id -> passage text.
        img2text: Image key -> auxiliary text for training queries (``ravqa``).
        img2text_test: Same mapping for the evaluation queries.
    """
    loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True,collate_fn=custom_collate_fn)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False,collate_fn=custom_collate_fn)
    optimizer = torch.optim.Adam(retriever.parameters(), lr=args.lr)
    all_passage_ids = list(passage_cache.keys())
    # Built once: rebuilding this set on every step is linear in the corpus size.
    all_passage_id_set = set(all_passage_ids)
    num_negatives = args.batch_size * 4

    results  = evaluate_retriever(args, retriever,test_loader, passage_cache, processor,device, all_passage_ids,img2text_test,pid2text)
    # Tracks whether `results` reflects the current weights, so the closing
    # evaluation is not repeated when the last epoch has just evaluated.
    results_are_current = True

    for epoch in range(args.num_epochs):

        retriever.train()
        total_loss = 0
        for images, questions, pos_ids_list, img_key, answers_list in tqdm(loader, desc=f"Epoch {epoch+1}"):
            query_vecs = _encode_training_queries(args, retriever, images, questions, img_key, img2text, device)

            batch_pos_pids = _collect_batch_positive_ids(pos_ids_list)
            available_negatives = list(all_passage_id_set - batch_pos_pids)
            neg_pids = random.sample(available_negatives, min(num_negatives, len(available_negatives)))
            all_pids_in_batch = list(batch_pos_pids) + neg_pids

            if args.mode == 'maxp':
                # Chunk-level vectors per passage, averaged to one vector each.
                doc_chunk_lists = retriever.encode_documents(passage_cache, all_pids_in_batch)
                doc_vecs_list = [chunks.mean(dim=0) if chunks.dim() == 2 else chunks.squeeze(0) for chunks in doc_chunk_lists]
                doc_vecs = torch.stack(doc_vecs_list, dim=0).to(device)
            elif args.mode == 'lc':
                doc_vecs = retriever.encode_documents(passage_cache, all_pids_in_batch, pid2text=pid2text)
            else:
                doc_vecs = retriever.encode_documents(passage_cache, all_pids_in_batch)

            pos_mask = _build_positive_mask(pos_ids_list, all_pids_in_batch, device)
            loss = supcon_loss(query_vecs, doc_vecs, pos_mask)

            optimizer.zero_grad()
            loss.backward()

            torch.nn.utils.clip_grad_norm_(retriever.parameters(), max_norm=1.0)
            optimizer.step()

            total_loss += loss.item()
            results_are_current = False

        # max(..., 1) guards against an empty training loader.
        print(f"Epoch {epoch+1} - Loss: {total_loss / max(len(loader), 1):.4f}")
        if args.mode=='late':
            torch.save(retriever.state_dict(), args.model_save_path)
        else:
            if (epoch) % args.eval_period == 0 or (epoch + 1) == args.num_epochs:
                results  = evaluate_retriever(args, retriever,test_loader, passage_cache, processor,device, all_passage_ids,img2text_test,pid2text)
                results_are_current = True

                print(f"Epoch {epoch+1} – Recall@1: {results['Recall@1']:.4f}, Recall@5: {results['Recall@5']:.4f}, Recall@10: {results['Recall@10']:.4f}")

                if args.wandb:
                    wandb.log({**results, "epoch": epoch + 1})
                torch.save(retriever.state_dict(), args.model_save_path)

    if not results_are_current:
        results  = evaluate_retriever(args, retriever,test_loader, passage_cache, processor,device, all_passage_ids,img2text_test,pid2text)
    print(f"Final Recall@1: {results['Recall@1']:.4f}, Recall@5: {results['Recall@5']:.4f}, Recall@10: {results['Recall@10']:.4f}")

    if args.wandb:
        wandb.log({f"Final {k}": v for k, v in results.items()})
if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    parser.add_argument("--image_root", type=str, default="data/COCO")
    parser.add_argument("--batch_size", type=int, default=16)
    parser.add_argument("--num_epochs", type=int, default=100)
    parser.add_argument("--proj_dim", type=int, default=512)
    parser.add_argument("--d_attn", type=int, default=128)
    parser.add_argument("--eval_period", type=int, default=10)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--lr", type=float, default=1e-4)
    
    parser.add_argument('--dataset', default='okvqa')
    parser.add_argument('--mode', default='cls')
    parser.add_argument('--model', default='clip')
    parser.add_argument("--gate", type=str, default="none")
    parser.add_argument('--K', type=int, default=0, help='Number of frequency bins for STC')
    parser.add_argument('--eval_mode', action='store_true', help='Track gradient norm during training')
    parser.add_argument('--normalize_positions', action='store_true', help='Normalize positional encoding to [0,1]')
    parser.add_argument('--new_test', action='store_true')
    parser.add_argument('--query_mode', default='no_pooler')
    parser.add_argument("--temp", type=float, default=0.02)

    parser.add_argument("--wandb", action="store_true")
    parser.add_argument("--gen_only", action="store_true")
    # Arguments for Generator
    parser.add_argument('--k_for_gen', type=int, default=5, help='Number of passages for generator context')
    parser.add_argument('--gen_epochs', type=int, default=10, help='Number of epochs for generator training')
    parser.add_argument('--gen_batch_size', type=int, default=2, help='Batch size for generator training')

    args, _ = parser.parse_known_args()
    if args.dataset=='evqa':
        args.image_root = args.image_root.replace("COCO","EVQA")
    if args.model=='ravqa':
        chunk_size=512
    elif args.model=='clip':
        chunk_size=77
    elif args.model=='longclip':
        chunk_size=248  

    set_seed(args.seed)
    args.cache_path = f"cache/cache_{args.dataset}_{args.model}.pt"
    if args.mode in ['cls','mean','lc','maxp']:
        run_name = f"train_{args.mode}_{args.dataset}_{args.model}_{args.lr}_{args.batch_size}_{args.num_epochs}_{args.proj_dim}"
    elif args.mode=='stc':
        run_name = f"train_stc_{args.dataset}_{args.model}_{args.lr}_{args.batch_size}_{args.num_epochs}_{args.proj_dim}_K{args.K}_gate{args.gate}"
    elif args.mode=='sat':
        run_name = f"train_sat_{args.dataset}_{args.model}_{args.lr}_{args.batch_size}_{args.num_epochs}_{args.proj_dim}_K{args.K}_gate{args.gate}"
        if args.d_attn != 128:
            run_name = run_name + f"_d_attn{args.d_attn}"
    args.model_save_path = os.path.join('checkpoint',run_name+'.pt')

    if args.wandb:
        if args.gen_only:
            wandb.init(
        project="Aug_VisRAG_gen",
        config=vars(args),
        name=run_name)
        else:
            wandb.init(
            project="new_VisRAG",
            config=vars(args),
            name=run_name)


    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    if args.model == 'clip':
        from transformers import CLIPProcessor
        from model import CLIPRetriever
        processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
        retriever = CLIPRetriever(args,device=device).to(device)
        tokenizer = processor.tokenizer
    elif args.model == 'longclip':
        from transformers import CLIPProcessor
        from model import LongCLIPRetriever
        processor = CLIPProcessor.from_pretrained("zer0int/LongCLIP-GmP-ViT-L-14")
        retriever = LongCLIPRetriever(args, clip_model_name="zer0int/LongCLIP-GmP-ViT-L-14", device=device).to(device)
        tokenizer = processor.tokenizer
    elif args.model == 'ravqa':
        from transformers import BertTokenizerFast
        from model import RAVQARetriever

        # This will not be used.
        from transformers import CLIPProcessor
        processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
        
        retriever = RAVQARetriever(args,device=device).to(device)
        tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
    
    
    print("Generator loaded early to claim GPU resources")
    
    img2text = None
    img2text_test = None
    if args.mode=='cls':
        args.cache_path = args.cache_path.replace(".pt","_cls.pt")
    elif args.mode=='maxp':
        args.cache_path = args.cache_path.replace(".pt","_maxp.pt")
    if args.model in ['ravqa','flmr']:

        img2text = {}
        with open(f"./{args.dataset}_image_to_text/train_image_text.jsonl", "r") as fin:
            for line in fin:
                obj = json.loads(line)
                try:
                    img2text[obj["img_key_full"]] = obj["image_text"]
                except:
                    img2text[obj["image_id"]] = obj["image_text"]
        img2text_test = {}
        with open(f"./{args.dataset}_image_to_text/test_image_text.jsonl", "r") as fin:
            for line in fin:
                obj = json.loads(line)
                try:
                    img2text_test[obj["img_key_full"]] = obj["image_text"]
                except:
                    img2text_test[obj["image_id"]] = obj["image_text"]
    
    
    if args.dataset in ['evqa','infoseek']:
        # Replace with string directly if args.dataset is not available
        dataset_key = f"{args.dataset.upper()}_data"  
        base_path = "filtered_landmark_dataset"

        # Helper function for filtering and saving
        def load_or_filter_split(split_name):
            """Return the landmark-only ``split_name`` split, reusing the on-disk copy when present.

            When no saved copy exists under ``base_path``, the split is loaded
            from the benchmark, filtered to rows whose ``img_id`` contains
            ``'landmark'``, saved to disk and returned.
            """
            split_path = os.path.join(base_path, split_name)

            if not os.path.exists(split_path):
                print(f"{split_name} split not found. Loading and filtering from original dataset...")
                benchmark = load_dataset("BByrneLab/multi_task_multi_modal_knowledge_retrieval_benchmark_M2KR", dataset_key)
                landmark_rows = benchmark[split_name].filter(lambda row: 'landmark' in row['img_id'])

                os.makedirs(split_path, exist_ok=True)
                landmark_rows.save_to_disk(split_path)
                print(f"Saved filtered {split_name} split to disk.")
                return landmark_rows

            print(f"Loading {split_name} split from disk...")
            return load_from_disk(split_path)

        
        train_data = load_or_filter_split("train")
        test_data = load_or_filter_split("test")

        print(f"Filtered train samples: {len(train_data)}")
        print(f"Filtered test samples: {len(test_data)}")
    else:
        train_data = load_dataset("BByrneLab/multi_task_multi_modal_knowledge_retrieval_benchmark_M2KR", f"{args.dataset.upper()}_data")["train"]
        test_data = load_dataset("BByrneLab/multi_task_multi_modal_knowledge_retrieval_benchmark_M2KR", f"{args.dataset.upper()}_data")["test"]
    # The new-test setting keeps its passage embeddings in a separate cache
    # file, derived from the configured cache path.
    if args.new_test:
        if ".pt" in args.cache_path:
            args.cache_path = args.cache_path.replace(".pt",'_new.pt')
        else:
            # Without a ".pt" to rewrite, str.replace would leave the path
            # untouched and both settings would share one cache file. Insert
            # the suffix before whatever extension the path has instead.
            cache_root, cache_ext = os.path.splitext(args.cache_path)
            args.cache_path = f"{cache_root}_new{cache_ext}"

    if not os.path.exists(args.cache_path):
        # No cache yet: load the passage corpus and embed it. The two
        # settings differ only in which passage split they read.
        passage_split = "test_passages" if args.new_test else "train_passages"
        passages = load_dataset("BByrneLab/multi_task_multi_modal_knowledge_retrieval_benchmark_M2KR", f"{args.dataset.upper()}_passages")[passage_split]
        if args.dataset  in ['evqa','infoseek']:
            def collect_passage_ids(data):
                """Return the set of all ids listed under "pos_item_ids" in *data*."""
                all_ids = set()
                for item in data:
                    all_ids.update(item["pos_item_ids"])
                return all_ids

            # Keep only passages that are a positive for at least one train
            # or test question.
            train_pids = collect_passage_ids(train_data)
            test_pids = collect_passage_ids(test_data)
            used_pids = train_pids.union(test_pids)

            original_ids = passages["passage_id"]

            passages = passages.add_column("original_passage_id", original_ids)

            filtered_passages = passages.filter(lambda x: x["original_passage_id"] in used_pids)

            # "passage_id" is dropped and re-created from the copy, so the
            # result carries both columns with identical values.
            filtered_passages = filtered_passages.remove_columns(["passage_id"])
            filtered_passages = filtered_passages.add_column("passage_id", filtered_passages['original_passage_id'])

            passages = filtered_passages
        # cache_passage_embeddings receives the cache path; presumably it
        # persists the result there so later runs take the else branch — confirm.
        passage_cache = cache_passage_embeddings(args, passages, tokenizer, retriever, device, args.batch_size, args.cache_path, args.mode,chunk_size = chunk_size)
    else:
        # Reuse the embeddings stored by an earlier run.
        passage_cache = torch.load(args.cache_path)

    # Passage id -> passage text, taken from the "text" field of each cache entry.
    pid2text = {passage_id: entry["text"] for passage_id, entry in passage_cache.items()}

    # Pick the dataset wrapper for the selected benchmark; any other dataset
    # name leaves `dataset` / `test_dataset` undefined, as before.
    if args.dataset == 'okvqa':
        dataset_cls = OKVQADataset
    elif args.dataset in ['evqa', 'infoseek']:
        dataset_cls = EVQADataset
    else:
        dataset_cls = None

    if dataset_cls is not None:
        dataset = dataset_cls(args, train_data, processor, args.image_root)
        test_dataset = dataset_cls(args, test_data, processor, args.image_root)
        
        
    # Retriever stage: skipped entirely when only the generator is requested.
    if not args.gen_only:
        if not args.eval_mode:
            # Regular run: hand both datasets and all lookup tables to main().
            main(
                args, dataset, test_dataset, retriever, processor, device,
                passage_cache, pid2text, img2text, img2text_test,
            )
        else:
            # Evaluation-only run: score a saved retriever on the test set.
            test_loader = DataLoader(
                test_dataset,
                batch_size=args.batch_size,
                shuffle=False,
                collate_fn=custom_collate_fn,
            )
            all_passage_ids = list(passage_cache.keys())

            # strict=False: parameter names that do not match between the
            # checkpoint and the model are ignored rather than raising.
            checkpoint = torch.load(args.model_save_path, map_location=device)
            retriever.load_state_dict(checkpoint, strict=False)

            results = evaluate_retriever(
                args, retriever, test_loader, passage_cache, processor,
                device, all_passage_ids, img2text_test, pid2text,
            )
            if args.wandb:
                # Log every returned metric under a "Final " prefix.
                final_metrics = {}
                for metric_name, metric_value in results.items():
                    final_metrics[f"Final {metric_name}"] = metric_value
                wandb.log(final_metrics)

    # Generator stage: runs only when the generator-only flag is set.
    if args.gen_only:
        # --- Start of Generator Training ---
        print("\n--- Initializing Generator Training ---")
        all_passage_keys = list(passage_cache.keys())

        # The generator works on top of an already saved retriever.
        # strict=False: parameter names that do not match between the
        # checkpoint and the model are ignored rather than raising.
        print(f"Loading trained retriever from {args.model_save_path}")
        checkpoint = torch.load(args.model_save_path, map_location=device)
        retriever.load_state_dict(checkpoint, strict=False)
        generator = Generator().to(device)

        # Derive the generator checkpoint path from the retriever one.
        if '.pt' in args.model_save_path:
            generator_model_path = args.model_save_path.replace('.pt', '_generator.pt')
        else:
            # Without a ".pt" to rewrite, str.replace would return the
            # retriever path unchanged and both models would share one file.
            # Insert the suffix before whatever extension the path has.
            model_root, model_ext = os.path.splitext(args.model_save_path)
            generator_model_path = f"{model_root}_generator{model_ext}"

        if args.eval_mode:
            # Evaluation only: restore previously saved generator weights.
            generator.load_state_dict(torch.load(generator_model_path, map_location=device))
        else:
            train_generator(args, retriever, generator,train_data, pid2text, device, passage_cache, processor, \
                            all_passage_keys, img2text,generator_model_path,epochs=args.gen_epochs,batch_size = args.gen_batch_size)

        # Both paths finish with an evaluation on the test data.
        evaluate_generator(args, retriever, generator,test_data, pid2text, device, passage_cache, processor, \
                        all_passage_keys, img2text_test,generator_model_path)
    