import torch
import os
import argparse
import numpy as np
import wandb
from datasets import load_from_disk, load_dataset
from tqdm import tqdm
import random
from torch.utils.data import DataLoader
import json
from utils import supcon_loss, set_seed, build_faiss_index, normalize_answer
from collections import defaultdict
from model import Generator
import torch.nn.functional as F
from dataloader import RAGDataset, GeneratorRAGDataset
import warnings
from collections import Counter

warnings.filterwarnings("ignore")
import time
def custom_collate_fn(batch):
    """Transpose a batch of ``(question, pos_ids)`` pairs into two parallel tuples.

    Returns ``(questions, pos_ids)`` where each element is a tuple with one
    entry per sample, in batch order. No tensor conversion is performed.
    """
    question_column, positive_column = tuple(zip(*batch))
    return (question_column, positive_column)

def evaluate_retriever(args, retriever, test_data, passage_cache, device, all_passage_ids, pid2text):
    """Compute Recall@k and PRRecall@k (k in 1, 5, 10) for ``retriever``.

    Args:
        args: Namespace; only ``args.mode`` is read here ('maxp', 'lc' or other).
        retriever: Model exposing ``encode_queries`` and ``encode_documents``
            (and ``norm_doc`` / ``doc_proj`` for 'maxp' mode).
        test_data: Iterable of ``(questions, pos_ids_list)`` batches.
        passage_cache: Mapping pid -> dict with at least an ``'embedding'`` entry.
        device: Device the cached chunk embeddings are moved to in 'maxp' mode.
        all_passage_ids: Passage ids in the order they are added to the index.
        pid2text: Mapping pid -> text, forwarded to ``encode_documents`` in 'lc' mode.

    Returns:
        dict with keys ``"Recall@k"`` (fraction of positives found in the top-k)
        and ``"PRRecall@k"`` (1 if at least one positive is in the top-k, else 0),
        each averaged over all questions.
    """
    retriever.eval()
    chunk_to_parent = []  # maxp only: index row -> parent passage id

    if args.mode == 'maxp':
        # Flatten every passage's chunk embeddings into one matrix and remember
        # which passage each row came from.
        all_chunks = []
        for pid in all_passage_ids:
            passage_data = passage_cache[pid]
            if 'num_chunks' in passage_data:
                all_chunks.append(passage_data['embedding'])  # [num_chunks, D]
                chunk_to_parent.extend([pid] * passage_data['num_chunks'])
            else:
                # Fallback: treat the cached embedding as a single chunk.
                all_chunks.append(passage_data['embedding'].unsqueeze(0))
                chunk_to_parent.append(pid)

        if len(all_chunks) == 0:
            passage_vecs = torch.empty(0, retriever.doc_proj[0].in_features)
        else:
            passage_vecs = torch.cat(all_chunks, dim=0).to(device)  # [total_chunks, D]
            with torch.no_grad():
                passage_vecs = retriever.norm_doc(passage_vecs)
                passage_vecs = retriever.doc_proj(passage_vecs)
                passage_vecs = F.normalize(passage_vecs, dim=-1)
        faiss_index = build_faiss_index(passage_vecs.detach())
    else:
        if args.mode == 'lc':
            passage_vecs = retriever.encode_documents(
                passage_cache, all_passage_ids, eval_mode=True, pid2text=pid2text
            )
        else:
            passage_vecs = retriever.encode_documents(passage_cache, all_passage_ids, eval_mode=True)
        if isinstance(passage_vecs, list):
            passage_vecs = torch.cat(passage_vecs, dim=0)
        else:
            passage_vecs = passage_vecs.detach()
        faiss_index = build_faiss_index(passage_vecs)

    ks = [1, 5, 10]
    max_k = max(ks)
    recall_scores_by_k = {k: [] for k in ks}
    prr_scores_by_k = {k: [] for k in ks}

    for questions, pos_ids_list in tqdm(test_data, desc="Evaluating Retriever"):
        with torch.no_grad():
            q_vec = retriever.encode_queries(questions)

        q_np = q_vec.detach().cpu().numpy().astype("float32")
        _, I = faiss_index.search(q_np, max_k)

        for ranked_rows, positives in zip(I.tolist(), pos_ids_list):
            # FAISS pads with -1 when fewer than max_k neighbours exist; a
            # negative index would silently wrap around to the last passage.
            ranked_rows = [j for j in ranked_rows if j >= 0]

            if args.mode == 'maxp':
                # Map chunk hits to parent passages, de-duplicating while
                # KEEPING rank order (a set would scramble the ranking and make
                # the top-k slices below meaningless).
                # NOTE(review): only max_k chunks are searched, so after
                # de-duplication fewer than max_k passages may remain.
                topk_passage_ids = list(dict.fromkeys(chunk_to_parent[j] for j in ranked_rows))[:max_k]
            else:
                topk_passage_ids = [all_passage_ids[j] for j in ranked_rows]

            pos_ids = set(positives)

            for k in ks:
                hit_count = sum(1 for pid in topk_passage_ids[:k] if pid in pos_ids)
                recall = hit_count / len(pos_ids) if len(pos_ids) > 0 else 0.0
                recall_scores_by_k[k].append(recall)
                prr_scores_by_k[k].append(min(hit_count, 1))

    results = {}
    for k in ks:
        # Guard against an empty test set, where np.mean([]) would yield NaN.
        avg_recall = np.mean(recall_scores_by_k[k]) if recall_scores_by_k[k] else 0.0
        avg_prr = np.mean(prr_scores_by_k[k]) if prr_scores_by_k[k] else 0.0
        results[f"Recall@{k}"] = avg_recall
        results[f"PRRecall@{k}"] = avg_prr
        print(f"Recall@{k}: {avg_recall:.4f}")
        print(f"PRRecall@{k}: {avg_prr:.4f}")
    return results
def _encode_padded_chunk(model, tokenizer, input_ids, attention_mask, start, chunk_size, device):
    """Encode one ``chunk_size`` window starting at ``start`` with ``model.doc_encoder``.

    The window is right-padded with ``tokenizer.pad_token_id`` (attention 0)
    when it runs past the end of the sequence. Returns ``hidden_states[-2]``
    with the batch dimension squeezed, left on ``device``.
    """
    input_chunk = input_ids[start:start + chunk_size]
    attn_chunk = attention_mask[start:start + chunk_size]
    pad_len = chunk_size - len(input_chunk)
    if pad_len > 0:
        input_chunk = torch.nn.functional.pad(input_chunk, (0, pad_len), value=tokenizer.pad_token_id)
        attn_chunk = torch.nn.functional.pad(attn_chunk, (0, pad_len), value=0)

    with torch.no_grad():
        outputs = model.doc_encoder(
            input_ids=input_chunk.unsqueeze(0).to(device),
            attention_mask=attn_chunk.unsqueeze(0).to(device),
            output_hidden_states=True,
            return_dict=True,
        )
    return outputs.hidden_states[-2].squeeze(0)  # [chunk_size, D]


def cache_passage_embeddings(args, passages, tokenizer, model, device, batch_size, cache_path, mode,
                             chunk_size=77, stride_ratio=0.5):
    """Pre-compute document embeddings and persist them to ``cache_path``.

    Per ``mode``:
      * ``'cls'``  - one truncated forward pass; stores ``pooler_output`` and the text.
      * ``'maxp'`` - sliding windows; stores one mean-pooled vector per window
        plus ``num_chunks``.
      * other      - sliding windows; stores one vector per token, averaging the
        vectors of tokens that fall into several overlapping windows.

    Args:
        args: Unused here; kept for interface compatibility.
        passages: Sliceable dataset whose batches expose ``"text"`` (items with a
            ``'paragraph'`` list) and ``"wikipedia_id"`` columns.
        tokenizer: Fast tokenizer (offset mappings are requested).
        model: Object exposing ``doc_encoder`` (and ``config.hidden_size`` for
            the empty-document fallbacks).
        device: Device used for encoding.
        batch_size: Number of passages read from ``passages`` per slice.
        cache_path: Destination ``.pt`` file; texts go to ``*_texts.json`` next to it.
        mode: Embedding mode, see above.
        chunk_size: Window length in tokens.
        stride_ratio: Fraction of overlap between consecutive windows.

    Returns:
        dict mapping passage id -> cache entry.
    """
    model.eval()
    model.to(device)
    # Clamp to 1 so stride_ratio >= 1 cannot produce range(..., step=0).
    stride = max(1, int(chunk_size * (1 - stride_ratio)))
    cache = {}
    text_dict = {}

    for batch_start in tqdm(range(0, len(passages), batch_size), desc="Caching passages"):
        batch = passages[batch_start:batch_start + batch_size]
        texts = batch["text"]
        ids = batch["wikipedia_id"]

        for pid, text in zip(ids, texts):
            text = "".join(text['paragraph'])
            # Recorded for every mode (the GTR variant of this function does the same).
            text_dict[pid] = text

            if mode == 'cls':
                # Only the truncated encoding is needed; the full-length
                # tokenization below is skipped entirely.
                inputs = tokenizer(text, return_tensors='pt', truncation=True, padding=True).to(device)
                with torch.no_grad():
                    pooled = model.doc_encoder(**inputs).pooler_output  # [B, H]
                cache[pid] = {"embedding": pooled.detach().cpu(), "text": text}
                continue

            tokens = tokenizer(text, return_tensors='pt', truncation=False, return_offsets_mapping=True)
            input_ids = tokens['input_ids'][0]
            attention_mask = tokens['attention_mask'][0]
            offsets = tokens['offset_mapping'][0]

            if mode == 'maxp':
                # One mean-pooled vector per window. The token-level pass is not
                # needed in this mode, so the encoder runs once per window.
                # NOTE(review): the mean is taken over all chunk_size positions,
                # padded ones included, and the vectors stay on `device`
                # (unchanged from the previous behaviour) - confirm both are intended.
                chunk_embeddings = [
                    _encode_padded_chunk(
                        model, tokenizer, input_ids, attention_mask, start, chunk_size, device
                    ).mean(dim=0)
                    for start in range(0, len(input_ids), stride)
                ]
                if len(chunk_embeddings) == 0:
                    chunk_embeddings = [torch.zeros(model.config.hidden_size, device=device)]

                doc_embedding = torch.stack(chunk_embeddings)  # [num_chunks, D]
                cache[pid] = {"embedding": doc_embedding, "num_chunks": len(chunk_embeddings)}
                continue

            # Token-level mode: collect every window's vector for each absolute
            # token position, then average across overlapping windows.
            token_embeddings = defaultdict(list)
            for start in range(0, len(input_ids), stride):
                hidden = _encode_padded_chunk(
                    model, tokenizer, input_ids, attention_mask, start, chunk_size, device
                ).cpu()  # [chunk_size, D]

                window_offsets = offsets[start:start + chunk_size].tolist()
                for tok_pos, (start_c, end_c) in enumerate(window_offsets):
                    # Tokens with a (0, 0) offset are skipped. Padded positions
                    # lie beyond `window_offsets` and are never visited.
                    if start_c == 0 and end_c == 0:
                        continue
                    token_embeddings[start + tok_pos].append(hidden[tok_pos])

            doc_embedding = [
                torch.stack(token_embeddings[idx], dim=0).mean(dim=0)
                for idx in sorted(token_embeddings)
            ]
            if len(doc_embedding) == 0:
                doc_embedding = torch.zeros((1, model.config.hidden_size))
            else:
                doc_embedding = torch.stack(doc_embedding)  # [T_total, D]

            cache[pid] = {"embedding": doc_embedding}

    # Make sure the target directory exists before writing anything.
    os.makedirs(os.path.dirname(cache_path) or ".", exist_ok=True)

    # Save the embeddings
    torch.save(cache, cache_path)

    # Save the texts separately
    text_cache_path = cache_path.replace('.pt', '_texts.json')
    with open(text_cache_path, 'w') as f:
        json.dump(text_dict, f, ensure_ascii=False, indent=2)
    return cache

def cache_passage_gtr_embeddings(args,passages,model,device,batch_size,cache_path,mode,chunk_size=384,stride_ratio=0.5):
    """Pre-compute document embeddings with ``model.doc_encoder.encode`` and save them.

    Per ``mode``:
      * ``'cls'`` - one sentence embedding per passage, stored with its text.
      * ``'lc'``  - per-token embeddings plus ``chunk_by_sentences`` spans.
      * other     - per-token embeddings only.

    Token embeddings come from overlapping windows of ``chunk_size`` tokens; a
    position covered by several windows gets the mean of its vectors, and a
    position covered by none gets a zero vector.

    Args:
        args: Unused here; kept for interface compatibility.
        passages: Iterable of records with ``"text"['paragraph']`` and ``"wikipedia_id"``.
        model: Object exposing ``tokenizer`` and ``doc_encoder``.
        device: Device passed (as a string) to ``encode``.
        batch_size: Unused here; kept for interface compatibility.
        cache_path: Destination ``.pt`` file; texts go to ``*_texts.json`` next to it.
        mode: Embedding mode, see above.
        chunk_size: Window length in tokens.
        stride_ratio: Fraction of overlap between consecutive windows.

    Returns:
        dict mapping passage id -> cache entry.
    """
    model.eval()
    model.to(device)
    tokenizer = model.tokenizer
    # Clamp to 1 so stride_ratio >= 1 cannot produce range(..., step=0).
    stride = max(1, int(chunk_size * (1 - stride_ratio)))
    cache = {}
    text_dict = {}

    if mode == 'lc':
        # Imported lazily (only 'lc' needs it) and once, not once per passage.
        from utils import chunk_by_sentences

    for passage in tqdm(passages, desc="Caching passages (GTR Full Length Sliding Window)"):
        text = "".join(passage["text"]['paragraph'])
        pid = passage["wikipedia_id"]
        text_dict[pid] = text

        if mode == 'cls':
            with torch.no_grad():
                sent_emb = model.doc_encoder.encode(
                    text, convert_to_tensor=True, device=str(device), output_value='sentence_embedding'
                )
            if sent_emb.dim() == 1:
                sent_emb = sent_emb.unsqueeze(0)
            # "text" is stored as well: the launcher script builds pid2text from
            # cache entries' "text" key in cls mode.
            cache[pid] = {"embedding": sent_emb.detach().cpu(), "text": text}
            continue

        full_input_ids = tokenizer(text, truncation=False, return_tensors='pt')['input_ids'][0]
        embed_dim = model.doc_encoder.get_sentence_embedding_dimension()

        if len(full_input_ids) < 1:
            cache[pid] = {"embedding": torch.zeros((0, embed_dim))}
            continue

        token_embeddings = defaultdict(list)

        for start in range(0, len(full_input_ids), stride):
            id_chunk = full_input_ids[start:start + chunk_size]

            text_chunk = tokenizer.decode(id_chunk, skip_special_tokens=True)
            if not text_chunk.strip():
                continue

            with torch.no_grad():
                chunk_hidden_states = model.doc_encoder.encode(
                    text_chunk, convert_to_tensor=True, device=str(device), output_value='token_embeddings'
                )

            # NOTE(review): row i of the re-encoded chunk is assumed to line up
            # with token start + i of the original tokenization - confirm that
            # decode/re-encode preserves token positions for this encoder.
            for offset, token_vec in enumerate(chunk_hidden_states):
                true_idx = start + offset
                if true_idx < len(full_input_ids):
                    token_embeddings[true_idx].append(token_vec.cpu())

        final_doc_embedding = []
        for idx in range(len(full_input_ids)):
            vec_list = token_embeddings.get(idx)
            if vec_list:
                final_doc_embedding.append(torch.stack(vec_list, dim=0).mean(dim=0))
            else:
                final_doc_embedding.append(torch.zeros(embed_dim))

        doc_embedding_tensor = torch.stack(final_doc_embedding) # [T_total, D]

        # For lc mode, also cache the span information to avoid recomputing
        if mode == 'lc':
            span_annotations = chunk_by_sentences(text, tokenizer=tokenizer)
            cache[pid] = {"embedding": doc_embedding_tensor, "spans": span_annotations}
        else:
            cache[pid] = {"embedding": doc_embedding_tensor}

    # Make sure the target directory exists before writing anything.
    os.makedirs(os.path.dirname(cache_path) or ".", exist_ok=True)

    # Save the cache and texts
    torch.save(cache, cache_path)
    text_cache_path = cache_path.replace('.pt', '_texts.json')
    with open(text_cache_path, 'w') as f:
        json.dump(text_dict, f, ensure_ascii=False, indent=2)

    return cache
def get_f1_stats(prediction, ground_truth):
    """Return token-level ``(tp, fp, fn)`` counts for one prediction/answer pair.

    Both strings are passed through ``normalize_answer`` and split on
    whitespace. ``tp`` is the size of the multiset intersection of the two
    token lists; ``fp`` / ``fn`` are the remaining prediction / ground-truth
    tokens.
    """
    pred_counts = Counter(normalize_answer(prediction).split())
    truth_counts = Counter(normalize_answer(ground_truth).split())

    tp = sum((pred_counts & truth_counts).values())
    fp = sum(pred_counts.values()) - tp
    fn = sum(truth_counts.values()) - tp
    return tp, fp, fn

def f1_score(prediction, ground_truth):
    """Return the word-level F1 of ``prediction`` against ``ground_truth``.

    Yields 0.0 when the two share no tokens; otherwise the harmonic mean of
    token precision and recall derived from ``get_f1_stats``.
    """
    true_pos, false_pos, false_neg = get_f1_stats(prediction, ground_truth)
    if true_pos == 0:
        return 0.0

    p = true_pos / (true_pos + false_pos)
    r = true_pos / (true_pos + false_neg)
    return (2 * p * r) / (p + r)

def exact_match_score(prediction, ground_truth):
    """Return True when both strings are equal after ``normalize_answer``."""
    normalized_prediction = normalize_answer(prediction)
    normalized_truth = normalize_answer(ground_truth)
    return normalized_prediction == normalized_truth

# --- Main Evaluation Function ---

def evaluate_results(file_path):
    """Score a JSONL predictions file with Exact Match and token-level F1.

    Each non-blank line must be a JSON object with ``question``,
    ``ground_truth`` and ``prediction`` keys. Entries missing a question or a
    ground truth, or whose prediction is absent (``None``), are reported and
    skipped. An empty-string prediction is a real (wrong) answer and is scored,
    so that empty generations are not silently dropped from the averages.

    Args:
        file_path: Path to the JSONL file.

    Returns:
        ``(avg_em, macro_avg_f1)``. ``(0.0, 0.0)`` is returned when the file
        holds no entries, so callers can always unpack the result.
    """
    results = []
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            # Tolerate blank lines (e.g. a trailing newline at end of file).
            if line.strip():
                results.append(json.loads(line))

    if not results:
        print("Error: The results file is empty.")
        return 0.0, 0.0

    macro_f1_scores = []
    em_scores = []
    total_tp, total_fp, total_fn = 0, 0, 0  # For micro-F1

    print(f"Evaluating {len(results)} prediction pairs...")
    for entry in tqdm(results, desc="Calculating Metrics"):
        question = entry.get("question")
        ground_truth = entry.get("ground_truth")
        prediction = entry.get("prediction")

        if not question or not ground_truth or prediction is None:
            print(f"Skipping malformed entry: {entry}")
            continue

        tp, fp, fn = get_f1_stats(prediction, ground_truth)

        # Accumulate counts for Micro-F1
        total_tp += tp
        total_fp += fp
        total_fn += fn

        # Per-item F1 for Macro-F1. With tp > 0 both denominators are positive.
        if tp == 0:
            f1 = 0.0
        else:
            precision = tp / (tp + fp)
            recall = tp / (tp + fn)
            f1 = (2 * precision * recall) / (precision + recall)
        macro_f1_scores.append(f1)

        em_scores.append(exact_match_score(prediction, ground_truth))

    # --- Calculate Final Averages ---
    macro_avg_f1 = np.mean(macro_f1_scores) if macro_f1_scores else 0.0
    avg_em = np.mean(em_scores) if em_scores else 0.0

    # Calculate Micro-F1 from total counts
    micro_precision = total_tp / (total_tp + total_fp) if (total_tp + total_fp) > 0 else 0.0
    micro_recall = total_tp / (total_tp + total_fn) if (total_tp + total_fn) > 0 else 0.0

    if micro_precision + micro_recall == 0:
        micro_avg_f1 = 0.0
    else:
        micro_avg_f1 = (2 * micro_precision * micro_recall) / (micro_precision + micro_recall)

    print("\n--- Evaluation Results ---")
    print(f"  Macro-F1 Score:                {macro_avg_f1:.4f} (Average of F1 per item)")
    print(f"  Micro-F1 Score:                {micro_avg_f1:.4f} (Calculated from global TP/FP/FN)")
    print(f"  Exact Match:                   {avg_em:.4f}")
    print("--------------------------\n")
    return avg_em, macro_avg_f1

def evaluate_generator(
    args,
    retriever,
    generator,
    test_data,
    pid2text,
    device,
    passage_cache,
    all_passage_keys,
    summary_map_path,
    run_name
):
    """Run retrieve-then-generate over ``test_data`` and score the predictions.

    For every item the question is encoded, the top ``args.k_for_gen`` index
    rows are retrieved, the summaries of the retrieved passages are joined into
    a context, and ``generator.generate`` produces an answer. Each
    question / ground truth / prediction triple is written to
    ``rag_result/<run_name>.jsonl``, which is then scored by
    ``evaluate_results``. Metrics are logged to wandb when ``args.wandb`` is set.

    Args:
        args: Namespace; reads ``mode``, ``k_for_gen`` and ``wandb``.
        retriever: Model exposing ``encode_documents`` and ``encode_queries``.
        generator: Model exposing ``generate(list_of_inputs, device)``.
        test_data: Iterable of items with ``"input"`` and ``"output"`` fields.
        pid2text: Mapping pid -> text, forwarded to ``encode_documents``.
        device: Device passed to the generator.
        passage_cache: Mapping pid -> cache entry.
        all_passage_keys: Passage ids in index order.
        summary_map_path: JSON file mapping passage id -> summary.
        run_name: Stem of the results file.
    """
    retriever.eval()
    generator.eval()

    output_dir = 'rag_result'
    os.makedirs(output_dir, exist_ok=True)
    output_file_path = os.path.join(output_dir, f"{run_name}.jsonl")

    # Load the pre-computed summary map
    with open(summary_map_path, 'r') as f:
        id_to_summary = json.load(f)

    print("Building FAISS index for evaluation...")
    passage_vecs = retriever.encode_documents(passage_cache, all_passage_keys, eval_mode=True, pid2text=pid2text)

    if isinstance(passage_vecs, list):
        passage_vecs = torch.cat(passage_vecs, dim=0)
    else:
        passage_vecs = passage_vecs.detach()

    # maxp only: index row -> parent passage id.
    chunk_to_parent = []
    if args.mode == 'maxp':
        for pid in all_passage_keys:
            passage_data = passage_cache[pid]
            if 'num_chunks' in passage_data:
                chunk_to_parent.extend([pid] * passage_data['num_chunks'])
            else:
                # Fallback: treat as single chunk
                chunk_to_parent.append(pid)

    faiss_index = build_faiss_index(passage_vecs)
    print("Evaluation index built.")

    # The results file is overwritten ('w') on every run.
    with open(output_file_path, 'w', encoding='utf-8') as f_out:
        for item in tqdm(test_data, desc="Evaluating Generator"):
            question = item["input"]
            outputs = item['output']
            # Only the first listed answer is used as ground truth.
            answers = outputs[0]['answer'] if outputs else None

            if not answers:
                continue

            with torch.no_grad():
                q_vec = retriever.encode_queries([question])
                q_np = q_vec.detach().cpu().numpy().astype("float32")

            _, I = faiss_index.search(q_np, args.k_for_gen)
            # FAISS pads with -1 when fewer than k neighbours exist; a negative
            # index would silently wrap around to the last passage.
            topk_indices = [j for j in I[0].tolist() if j >= 0]

            if args.mode == 'maxp':
                # De-duplicate parents while KEEPING rank order (a set would
                # scramble the ranking before truncation).
                parent_ids = list(dict.fromkeys(chunk_to_parent[j] for j in topk_indices))
                retrieved_pids = parent_ids[:args.k_for_gen]
            else:
                retrieved_pids = [all_passage_keys[j] for j in topk_indices]

            # JSON object keys are always strings, so also try str(pid).
            context_summaries = [
                id_to_summary.get(pid, id_to_summary.get(str(pid), "")) for pid in retrieved_pids
            ]
            context = " ".join(context_summaries)

            input_text = f"Question: {question} Context: {context}"

            pred = generator.generate([input_text], device)[0].strip()

            result_entry = {
                "question": question,
                "ground_truth": answers,
                "prediction": pred
            }
            f_out.write(json.dumps(result_entry) + '\n')

    metrics = evaluate_results(output_file_path)
    if metrics is None:
        # Nothing was scored (e.g. no item had an answer); nothing to log.
        return
    em_score, f1 = metrics
    if args.wandb:
        wandb.log({"Gen_EM": em_score, "Gen_F1": f1})
def train_generator(
    args, retriever, generator, train_data,summary_map_path,
    pid2text, device, passage_cache, all_passage_keys,
    generator_model_path='checkpoint/generator.pt',
    epochs=3, batch_size=2
):
    """Train ``generator`` on retrieval-augmented inputs with a frozen retriever.

    A ``GeneratorRAGDataset`` yields ``(input_text, target_answer)`` samples.
    ``generator(inputs, targets, device)`` is expected to return the loss. The
    generator's ``state_dict`` is saved to ``generator_model_path`` after every
    epoch.

    Args:
        args: Namespace forwarded to ``GeneratorRAGDataset``.
        retriever: Retriever whose parameters are frozen here.
        generator: Model to train.
        train_data: Raw training data for the dataset.
        summary_map_path: Path forwarded to the dataset.
        pid2text, passage_cache, all_passage_keys: Forwarded to the dataset.
        device: Device forwarded to the dataset and the generator.
        generator_model_path: Checkpoint destination.
        epochs: Number of passes over the data.
        batch_size: DataLoader batch size.
    """
    # Freeze retriever weights
    retriever.eval()
    for p in retriever.parameters():
        p.requires_grad = False
    dataset = GeneratorRAGDataset(args, train_data, pid2text, retriever, passage_cache, device,all_passage_keys,summary_map_path=summary_map_path)

    def collate_fn_filter_none(batch):
        # Drop samples the dataset could not build (None) so zip() does not
        # fail; an all-None batch is signalled as (None, None) and skipped below.
        kept = [sample for sample in batch if sample is not None]
        if not kept:
            return None, None
        return tuple(zip(*kept))

    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn_filter_none)
    optimizer = torch.optim.Adam(generator.parameters(), lr=5e-5)

    # Create the checkpoint directory up front so the first save cannot fail
    # after a full epoch of training.
    checkpoint_dir = os.path.dirname(generator_model_path)
    if checkpoint_dir:
        os.makedirs(checkpoint_dir, exist_ok=True)

    for epoch in range(epochs):
        generator.train()
        total_loss, steps = 0, 0

        for input_texts, target_answers in tqdm(dataloader, desc=f"Generator Training Epoch {epoch+1}"):
            if input_texts is None:
                continue

            loss = generator(list(input_texts), list(target_answers), device)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            steps += 1

        print(f"[Generator Epoch {epoch+1}] Avg Loss: {total_loss / max(steps, 1):.4f}")
        torch.save(generator.state_dict(), generator_model_path)

def main(args, dataset, test_dataset, retriever, device, passage_cache, pid2text):
    """Train the retriever with ``supcon_loss`` and evaluate it periodically.

    Each step encodes the batch questions, builds a candidate pool from the
    batch's positive passages plus up to ``4 * args.batch_size`` randomly
    sampled other passages, and minimises
    ``supcon_loss(query_vecs, doc_vecs, pos_mask)``. The retriever is evaluated
    once before training, every ``args.eval_period`` epochs, after the last
    epoch, and once more at the end. In ``'late'`` mode no periodic evaluation
    runs and the checkpoint is saved after every epoch instead.

    Args:
        args: Namespace; reads ``batch_size``, ``lr``, ``num_epochs``, ``mode``,
            ``eval_period``, ``model_save_path`` and ``wandb``.
        dataset: Training dataset yielding ``(question, pos_ids)`` samples.
        test_dataset: Evaluation dataset of the same form.
        retriever: Model to train.
        device: Device for masks and stacked document vectors.
        passage_cache: Mapping pid -> cache entry.
        pid2text: Mapping pid -> text (used in ``'lc'`` mode).
    """
    loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, collate_fn=custom_collate_fn)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, collate_fn=custom_collate_fn)
    optimizer = torch.optim.Adam(retriever.parameters(), lr=args.lr)

    all_passage_ids = list(passage_cache.keys())
    # Built once: rebuilding this set on every batch costs O(corpus) per step.
    all_passage_id_set = set(all_passage_ids)
    num_negatives = args.batch_size * 4

    # Make sure checkpoints can be written before any training time is spent.
    save_dir = os.path.dirname(args.model_save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    # Baseline evaluation before any training step.
    results = evaluate_retriever(args, retriever, test_loader, passage_cache, device, all_passage_ids, pid2text)

    for epoch in range(args.num_epochs):
        retriever.train()
        total_loss = 0.0

        for question, pos_ids_list in tqdm(loader, desc=f"Epoch {epoch+1}"):
            query_vecs = retriever.encode_queries(question)  # [B, D]

            # Candidate pool: every positive in the batch + sampled negatives.
            batch_pos_pids = {pid for sub in pos_ids_list for pid in sub}
            available_negatives = list(all_passage_id_set - batch_pos_pids)
            neg_pids = random.sample(available_negatives, min(num_negatives, len(available_negatives)))
            all_pids_in_batch = list(batch_pos_pids) + neg_pids

            if args.mode == 'maxp':
                # encode_documents yields one chunk tensor per passage here;
                # collapse each to a single vector by averaging its chunks.
                doc_chunk_lists = retriever.encode_documents(passage_cache, all_pids_in_batch)
                doc_vecs = torch.stack(
                    [chunks.mean(dim=0) if chunks.dim() == 2 else chunks.squeeze(0) for chunks in doc_chunk_lists],
                    dim=0,
                ).to(device)
            elif args.mode == 'lc':
                doc_vecs = retriever.encode_documents(passage_cache, all_pids_in_batch, pid2text=pid2text)
            else:
                doc_vecs = retriever.encode_documents(passage_cache, all_pids_in_batch)

            # pos_mask[i, j] is True when candidate j is a positive for question i.
            pid_to_idx = {pid: idx for idx, pid in enumerate(all_pids_in_batch)}
            pos_mask = torch.zeros(len(question), len(all_pids_in_batch), dtype=torch.bool).to(device)
            for row, pos_ids in enumerate(pos_ids_list):
                for pid in pos_ids:
                    if pid in pid_to_idx:
                        pos_mask[row, pid_to_idx[pid]] = True

            loss = supcon_loss(query_vecs, doc_vecs, pos_mask)

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(retriever.parameters(), max_norm=1.0)
            optimizer.step()

            total_loss += loss.item()

        # max(..., 1) avoids ZeroDivisionError when the loader is empty.
        print(f"Epoch {epoch+1} - Loss: {total_loss / max(len(loader), 1):.4f}")

        if args.mode == 'late':
            torch.save(retriever.state_dict(), args.model_save_path)
        elif epoch % args.eval_period == 0 or (epoch + 1) == args.num_epochs:
            results = evaluate_retriever(args, retriever, test_loader, passage_cache, device, all_passage_ids, pid2text)

            print(f"Epoch {epoch+1} – Recall@1: {results['Recall@1']:.4f}, Recall@5: {results['Recall@5']:.4f}, Recall@10: {results['Recall@10']:.4f}")

            if args.wandb:
                wandb.log({**results, "epoch": epoch + 1})
            torch.save(retriever.state_dict(), args.model_save_path)

    results = evaluate_retriever(args, retriever, test_loader, passage_cache, device, all_passage_ids, pid2text)
    print(f"Final Recall@1: {results['Recall@1']:.4f}, Recall@5: {results['Recall@5']:.4f}, Recall@10: {results['Recall@10']:.4f}")

    if args.wandb:
        wandb.log({f"Final {k}": v for k, v in results.items()})
if __name__ == "__main__":
    # Script entry point: parse options, build or load the passage-embedding
    # cache, then either train the retriever (default) or, with --gen_only,
    # train and/or evaluate the generator on top of a saved retriever.
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch_size", type=int, default=16)
    parser.add_argument("--num_epochs", type=int, default=100)
    parser.add_argument("--proj_dim", type=int, default=512)
    parser.add_argument("--eval_period", type=int, default=10)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--lr", type=float, default=1e-5)
    parser.add_argument('--dataset', default='nq')
    parser.add_argument('--model', default='bert')
    parser.add_argument("--temp", type=float, default=0.02)

    parser.add_argument('--mode', default='cls')
    parser.add_argument("--d_attn", type=int, default=128)
    parser.add_argument("--gate", type=str, default="none")
    parser.add_argument('--K', type=int, default=0, help='Number of frequency bins for STC')

    parser.add_argument("--wandb", action="store_true")
    parser.add_argument("--gen_only", action="store_true")
    # args.eval_mode is read in the --gen_only branch below, so the option must
    # exist; without it that branch raised AttributeError.
    parser.add_argument("--eval_mode", action="store_true",
                        help='With --gen_only: load a saved generator and only evaluate it')

    # Arguments for Generator
    parser.add_argument('--k_for_gen', type=int, default=5, help='Number of passages for generator context')
    parser.add_argument('--gen_epochs', type=int, default=1, help='Number of epochs for generator training')
    parser.add_argument('--gen_batch_size', type=int, default=2, help='Batch size for generator training')
    start = time.time()
    args, _ = parser.parse_known_args()

    if args.model == 'bert':
        chunk_size = 512
    elif args.model == 'gtr':
        chunk_size = 768
    else:
        # Fail early with a clear message instead of a later NameError.
        raise ValueError(f"Unsupported --model '{args.model}'; expected 'bert' or 'gtr'.")

    set_seed(args.seed)
    args.cache_path = f"cache/cache_{args.dataset}_{args.model}.pt"

    # Every mode shares the same run-name stem; only 'stc' appends K and gate.
    # NOTE(review): there used to be a second `elif args.mode=='sat'` branch
    # adding K/gate/d_attn suffixes, but it was unreachable because 'sat' was
    # already matched by the generic branch. The effective (generic) name is
    # kept so existing checkpoints still resolve - confirm which was intended.
    run_name = f"train_{args.mode}_{args.dataset}_{args.model}_{args.lr}_{args.batch_size}_{args.num_epochs}_{args.proj_dim}"
    if args.mode == 'stc':
        run_name = run_name + f"_K{args.K}_gate{args.gate}"

    args.model_save_path = os.path.join('checkpoint', run_name + '.pt')
    os.makedirs(os.path.dirname(args.model_save_path), exist_ok=True)

    if args.wandb:
        wandb.init(
            project="LongRAG",
            config=vars(args),
            name=run_name)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    if args.model == 'bert':
        from transformers import BertTokenizerFast
        from model import RAVQARetriever
        retriever = RAVQARetriever(args,device=device).to(device) # RAVQA utilize BERT encoders, so has the same architecture with BERT-DPR.
        tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
    else:
        from model import GTRRetriver
        retriever = GTRRetriver(args, device).to(device)

    # Modes with their own cache layout get a dedicated cache file.
    if args.mode in ('cls', 'maxp', 'lc'):
        args.cache_path = args.cache_path.replace(".pt", f"_{args.mode}.pt")

    if not os.path.exists(args.cache_path):
        if args.dataset == 'nq':
            passages = load_from_disk("wiki_nq_subset")
        elif args.dataset == 'hotpotqa':
            passages = load_from_disk("wiki_hotpot_subset")
        else:
            raise ValueError(
                f"No passage corpus is configured for --dataset '{args.dataset}' "
                f"and no cache exists at {args.cache_path}."
            )
        os.makedirs(os.path.dirname(args.cache_path) or ".", exist_ok=True)
        if args.model == 'gtr':
            passage_cache = cache_passage_gtr_embeddings(args, passages, retriever, device, args.batch_size, args.cache_path, args.mode, chunk_size=chunk_size)
        else:
            passage_cache = cache_passage_embeddings(args, passages, tokenizer, retriever, device, args.batch_size, args.cache_path, args.mode, chunk_size=chunk_size)
    else:
        passage_cache = torch.load(args.cache_path)

    train_data = load_dataset("kilt_tasks", args.dataset, split="train[:20000]")
    test_data = load_dataset("kilt_tasks", args.dataset, split="validation[:1000]")

    # Initialize pid2text. In cls mode the text is read from the cache entries
    # when they carry it; otherwise fall back to the *_texts.json side file
    # (cls caches written by cache_passage_gtr_embeddings have had no "text" key).
    text_cache_path = args.cache_path.replace(".pt", "_texts.json")
    if args.mode == 'cls' and all("text" in v for v in passage_cache.values()):
        pid2text = {pid: v["text"] for pid, v in passage_cache.items()}
    elif os.path.exists(text_cache_path):
        with open(text_cache_path, "r") as f:
            pid2text = json.load(f)
    else:
        # Last resort, kept from the original script.
        # NOTE(review): assumes entries hold a raw {"paragraph": [...]} under "text".
        pid2text = {pid: "".join(v["text"]['paragraph']) for pid, v in passage_cache.items()}

    dataset = RAGDataset(args, train_data)
    test_dataset = RAGDataset(args, test_data)

    if not args.gen_only:
        main(args, dataset, test_dataset, retriever, device, passage_cache, pid2text)
    else:
        # --- Start of Generator Training ---
        print("\n--- Initializing Generator Training ---")
        all_passage_keys = list(passage_cache.keys())
        # Load the trained retriever model
        print(f"Loading trained retriever from {args.model_save_path}")
        checkpoint = torch.load(args.model_save_path, map_location=device)
        retriever.load_state_dict(checkpoint, strict=False)
        generator = Generator().to(device)
        generator_model_path = args.model_save_path.replace('.pt', '_generator.pt')
        summary_map_path = f"cache/{args.dataset}_corpus_summaries.json"

        if args.eval_mode:
            # Evaluation only: reuse a previously trained generator.
            generator.load_state_dict(torch.load(generator_model_path, map_location=device))
        else:
            train_generator(args, retriever, generator, train_data, summary_map_path, pid2text, device, passage_cache,
                            all_passage_keys, generator_model_path, epochs=args.gen_epochs, batch_size=args.gen_batch_size)

        evaluate_generator(args, retriever, generator, test_data, pid2text, device, passage_cache,
                           all_passage_keys, summary_map_path, run_name)
        