import torch
from tqdm import tqdm
import json
from utils import set_seed
from model import PassageSummarizer
import warnings
import argparse
import os
from run_rag import cache_passage_embeddings
from datasets import load_from_disk, load_dataset
warnings.filterwarnings("ignore")
from collections import OrderedDict
from collections import OrderedDict
from tqdm import tqdm
import json

def summarize_all_documents(pid2text, summarizer, output_path, batch_size=32):
    """Summarize every passage in ``pid2text`` and save the result as JSON.

    Passages are handed to ``summarizer.summarize_passages`` in batches of at
    most ``batch_size`` texts, following the iteration order of ``pid2text``.
    The resulting id -> summary mapping is written to ``output_path`` (with
    4-space indentation) once all batches have been processed.

    Args:
        pid2text: Mapping from passage id to passage text.
        summarizer: Object exposing ``summarize_passages(texts)``.
            Assumed to return one summary per input text, in input order —
            TODO confirm against ``PassageSummarizer``.
        output_path: Destination JSON file. Missing parent directories are
            created so a long run is not lost at the final write.
        batch_size: Maximum number of passages per summarizer call.

    Returns:
        An ``OrderedDict`` mapping each passage id to its summary, in the
        iteration order of ``pid2text``.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        # A non-positive size would otherwise push the whole corpus through
        # the summarizer as one giant batch.
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    print(f"Starting to summarize {len(pid2text)} documents...")

    # Materialize the (id, text) pairs once so batches can be taken by slicing;
    # this only copies references, not the passage texts themselves.
    entries = list(pid2text.items())
    id_to_summary = OrderedDict()

    # The context manager closes the progress bar even if the summarizer raises.
    with tqdm(total=len(entries), desc="Summarizing Corpus") as pbar:
        for start in range(0, len(entries), batch_size):
            chunk = entries[start:start + batch_size]
            batch_pids = [pid for pid, _ in chunk]
            batch_texts = [text for _, text in chunk]

            batch_summaries = summarizer.summarize_passages(batch_texts)
            # NOTE: zip pairs ids and summaries positionally and stops at the
            # shorter sequence, so a short summarizer result drops trailing ids.
            id_to_summary.update(zip(batch_pids, batch_summaries))

            pbar.update(len(batch_pids))

    out_dir = os.path.dirname(output_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(id_to_summary, f, indent=4)

    print(f"Corpus summarization complete. Saved to {output_path}")
    return id_to_summary

if __name__ == '__main__':
    # Script entry point: make sure the passage-embedding cache for the chosen
    # dataset/model exists, load the passage texts stored next to it, and
    # summarize the whole corpus into cache/<dataset>_corpus_summaries.json.
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch_size", type=int, default=16)
    parser.add_argument("--num_epochs", type=int, default=100)
    parser.add_argument("--proj_dim", type=int, default=512)
    parser.add_argument("--eval_period", type=int, default=10)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--lr", type=float, default=1e-5)

    parser.add_argument('--dataset', default='nq')
    parser.add_argument('--mode', default='cls')
    parser.add_argument('--model', default='bert')
    parser.add_argument("--gate", type=str, default="none")
    parser.add_argument('--K', type=int, default=0, help='Number of frequency bins for STC')
    # NOTE(review): this help text looks copied from another flag — confirm what --saa_mean controls.
    parser.add_argument('--saa_mean', action='store_true', help='Track gradient norm during training')
    parser.add_argument('--normalize_positions', action='store_true', help='Normalize positional encoding to [0,1]')

    parser.add_argument("--temp", type=float, default=0.02)

    parser.add_argument("--wandb", action="store_true")
    parser.add_argument("--gen_only", action="store_true")
    parser.add_argument("--eval_mode", action="store_true")

    # Arguments for Generator
    parser.add_argument('--k_for_gen', type=int, default=5, help='Number of passages for generator context')
    parser.add_argument('--gen_epochs', type=int, default=10, help='Number of epochs for generator training')
    parser.add_argument('--gen_batch_size', type=int, default=2, help='Batch size for generator training')

    # parse_known_args tolerates extra flags; the leftovers are ignored here.
    args, _ = parser.parse_known_args()

    # Chunk size forwarded to cache_passage_embeddings. It used to be bound
    # only for --model bert, so the gtr path crashed with NameError whenever
    # the cache had to be built.
    # NOTE(review): 512 for gtr is an assumption — confirm against cache_passage_embeddings.
    chunk_size = 512

    set_seed(args.seed)
    args.cache_path = f"cache/cache_{args.dataset}_{args.model}.pt"

    # Run name, used below to derive the checkpoint path. 'stc' appends its K
    # and gate settings; every other supported mode uses the plain name.
    # 'sat' stays in the plain group because that is the branch that was
    # actually reachable; the later 'sat' branch could never run and read
    # args.d_attn, which this parser does not define.
    base_run_name = (
        f"train_{args.mode}_{args.dataset}_{args.model}_{args.lr}"
        f"_{args.batch_size}_{args.num_epochs}_{args.proj_dim}"
    )
    if args.mode in ('cls', 'sat', 'mean', 'lc', 'maxp'):
        run_name = base_run_name
    elif args.mode == 'stc':
        run_name = f"{base_run_name}_K{args.K}_gate{args.gate}"
    else:
        # Fail with a usage message instead of a NameError on run_name.
        parser.error(f"unsupported --mode: {args.mode!r}")

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    args.model_save_path = os.path.join('checkpoint', run_name + '.pt')

    if args.model == 'bert':
        from transformers import BertTokenizerFast
        from model import RAVQARetriever
        # RAVQA uses BERT encoders, so it has the same architecture as BERT-DPR.
        retriever = RAVQARetriever(args, device=device).to(device)
        tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
    elif args.model == 'gtr':
        from model import GTRRetriver
        from transformers import AutoTokenizer
        retriever = GTRRetriver(args, device).to(device)
        tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/gtr-t5-base")
    else:
        # An unknown model is only a problem if the cache has to be built.
        retriever = tokenizer = None

    if args.mode == 'cls':
        args.cache_path = args.cache_path.replace(".pt", "_cls.pt")

    if not os.path.exists(args.cache_path):
        if retriever is None:
            parser.error(f"unsupported --model for building the cache: {args.model!r}")
        if args.dataset == 'nq':
            passages = load_from_disk("wiki_nq_subset")
        elif args.dataset == 'hotpotqa':
            passages = load_from_disk("wiki_hotpot_subset")
        else:
            parser.error(f"unsupported --dataset for building the cache: {args.dataset!r}")
        passage_cache = cache_passage_embeddings(
            args, passages, tokenizer, retriever, device,
            args.batch_size, args.cache_path, args.mode, chunk_size=chunk_size,
        )
    else:
        # NOTE(review): passage_cache is not read again in the visible code;
        # the load is kept so the name stays bound as before.
        passage_cache = torch.load(args.cache_path)

    # Passage texts live next to the embedding cache. Every mode reads the
    # same file, so load it once, before the summarization model is built.
    # Presumably written by cache_passage_embeddings — verify.
    text_cache_path = args.cache_path.replace(".pt", "_texts.json")
    with open(text_cache_path, "r") as f:
        pid2text = json.load(f)

    summarizer = PassageSummarizer(
        model_name="facebook/bart-large-cnn",
        device=device,
        max_new_tokens=80,
        batch_size=8
    )

    summarize_all_documents(
        pid2text=pid2text,
        summarizer=summarizer,
        output_path=f"cache/{args.dataset}_corpus_summaries.json",
        batch_size=32  # Larger batch size for the outer loop
    )
