import argparse
import os
import warnings
from typing import Dict, List, Optional, Tuple

import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import PreTrainedTokenizer

from llm import LLM
from utils import *
from prompt_dataset import PromptDataset

# --- Run-wide configuration ------------------------------------------------
# Experiment tag template; main() replaces its final character with the
# requested number of context documents.
setting = 'gold_distraction3'

# Presumably set to keep the HuggingFace tokenizers library from using its own
# parallelism alongside DataLoader workers -- confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# First CUDA device when available, otherwise CPU. This runs before the
# warnings filter below, so anything it warns about is still shown.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
warnings.filterwarnings('ignore')

SEED = 10

# Input files used by the loaders below.
info = dict(
    data_path='data/10k_train_dataset.json',
    random_results_path="data/10k_random_results_at60.pkl",
    adore_search_results_path="data/adore_search_results_at200.pkl",
    contriever_search_results_path="data/contriever_search_results_at150.pkl",
)


def parse_arguments(argv=None):
    """Parse and validate the command-line options for a generation run.

    Args:
        argv: Argument strings to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, so existing ``parse_arguments()`` calls are
            unaffected. Passing an explicit list lets the parser be driven
            programmatically.

    Returns:
        The parsed ``argparse.Namespace``.

    Exits via ``parser.error`` (``SystemExit``) when ``--num_documents_in_context``
    is missing or not positive, or when ``--gold_position`` lies outside
    ``[0, num_documents_in_context)``.
    """
    parser = argparse.ArgumentParser(description="Run LLM Generation.")
    parser.add_argument('--output_dir', type=str, default='data/gen_res', help='Output directory')
    # Default/help used to be copy-pasted from --output_dir. main() overwrites
    # args.setting with the computed experiment tag, so the default is only a
    # placeholder.
    parser.add_argument('--setting', type=str, default=None, help='Experiment setting name')
    parser.add_argument('--llm_id', type=str, default='meta-llama/Meta-Llama-3.1-8B-Instruct', help='LLM model identifier')
    parser.add_argument('--model_max_length', type=int, help='Maximum input length for the LLM model', default=4096)
    parser.add_argument('--load_full_corpus', type=str2bool, help='Load the full corpus', default=True)
    parser.add_argument('--use_random', type=str2bool, default=False, help='Use random irrelevant documents')
    parser.add_argument('--use_adore', type=str2bool, default=False, help="Use the retrieved documents from ADORE")
    parser.add_argument('--gold_position', type=int, default=None, help='The (0-indexed) position of the gold document in the context')
    parser.add_argument('--num_documents_in_context', type=int, help='Total number of documents in the context')
    parser.add_argument('--get_documents_without_answer', type=str2bool, help='Select only documents without the answer (e.g., distracting)', default=True)
    parser.add_argument('--max_new_tokens', type=int, help='Maximum number of tokens to generate', default=15)
    parser.add_argument('--batch_size', type=int, default=10)
    parser.add_argument('--num_of_query', type=int, default=None)
    parser.add_argument('--save_every', type=int, default=250)
    parser.add_argument('--sparsity', type=float, default=None)
    parser.add_argument('--use_check_point', type=str2bool, default=False)
    parser.add_argument('--reverse', type=str2bool, default=False)
    args = parser.parse_args(argv)

    # argparse has no "required and positive" option, so validate by hand.
    if args.num_documents_in_context is None:
        parser.error("'num_documents_in_context' must be specified.")
    if args.num_documents_in_context <= 0:
        parser.error("'num_documents_in_context' must be a positive integer.")
    if args.gold_position is not None and (args.gold_position < 0 or args.gold_position >= args.num_documents_in_context):
        parser.error("'gold_position' must be within the range of 'num_documents_in_context'.")

    return args


def load_corpus(
    args: argparse.Namespace
) -> Tuple[List[Dict], Optional[Dict[int, int]]]:
    """Load the document corpus selected by the command-line flags.

    Returns:
        ``(corpus, full_to_subset_idx_map)``. With ``--load_full_corpus`` the
        corpus is read from ``data/corpus.json`` and the map is ``None``.
        Otherwise a sub-corpus reader is chosen, with random taking precedence
        over ADORE and Contriever as the fallback. The map is whatever that
        reader returns (presumably full-corpus index -> subset index; confirm
        in utils).
    """
    if args.load_full_corpus:
        return read_corpus_json('data/corpus.json'), None

    if args.use_random:
        reader = read_corpus_with_random
    elif args.use_adore:
        reader = read_corpus_with_adore
    else:
        # Sub-corpus built from Contriever retrievals.
        reader = read_corpus_with_contriever

    corpus, full_to_subset_idx_map = reader()
    return corpus, full_to_subset_idx_map


def load_search_results(args: argparse.Namespace) -> List[Tuple[List[int], List[float]]]:
    """Read the pickled retrieval results that match the chosen document source.

    Precedence: ``--use_random`` first, then ``--use_adore``, otherwise the
    Contriever results. The file path comes from the module-level ``info``.
    """
    if args.use_random:
        key = 'random_results_path'
    elif args.use_adore:
        key = 'adore_search_results_path'
    else:
        key = 'contriever_search_results_path'
    return read_pickle(info[key])


def initialize_dataset_and_loader(
    args: argparse.Namespace,
    corpus: List[Dict],
    full_to_subset_idx_map: Optional[Dict[int, int]],
    search_results: List[Tuple[List[int], List[float]]],
    tokenizer: PreTrainedTokenizer,
    num_workers: int = 8,
) -> DataLoader:
    """Build the prompt dataset, dump it to JSON, and wrap it in a DataLoader.

    Args:
        args: Parsed CLI options; supplies context size, gold position,
            batch size, query limit and model max length.
        corpus: Documents to draw context passages from.
        full_to_subset_idx_map: Index map from ``load_corpus`` (``None`` when
            the full corpus is loaded).
        search_results: Retrieval results from ``load_search_results``.
        tokenizer: Tokenizer used by ``PromptDataset`` to measure prompt length.
        num_workers: DataLoader worker processes (default 8, as before).

    Returns:
        A non-shuffled ``DataLoader`` over the prompt dataset using
        ``args.batch_size``.
    """
    prompt_ds = PromptDataset(args,
        corpus=corpus, data_path=info['data_path'],
        tokenizer=tokenizer,
        # Two tokens of headroom below the model limit -- presumably for special
        # tokens added at generation time; TODO confirm.
        max_tokenized_length=args.model_max_length - 2,
        search_results=search_results,
        full_to_subset_idx_map=full_to_subset_idx_map,
        do_normalize_query=True,
        num_documents_in_context=args.num_documents_in_context,
        gold_position=args.gold_position,
        get_documents_without_answer=args.get_documents_without_answer,
        randomize_gold_position=True,
        num_of_query=args.num_of_query
    )

    prompt_ds.save_as_json()
    # No early exit here: a stray exit(0) used to terminate the process at this
    # point, so the loader was never returned and generation never ran.
    return DataLoader(
        prompt_ds,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
    )


def print_info(args: argparse.Namespace):
    """Echo the key run configuration to stdout as ``LABEL: value`` lines."""
    print("INFO:")
    print(f"DATA: {info['data_path']}")
    # (printed label, attribute of args) in display order; attributes are read
    # lazily, one per line, as they are printed.
    rows = (
        ("MODEL", "llm_id"),
        ("USE RANDOM IN CONTEXT", "use_random"),
        ("USE ADORE", "use_adore"),
        ("GOLD POSITION", "gold_position"),
        ("NUM DOCUMENTS IN CONTEXT", "num_documents_in_context"),
        ("DOCUMENTS WITHOUT ANSWER", "get_documents_without_answer"),
        ("BATCH SIZE", "batch_size"),
        ("SAVE EVERY", "save_every"),
    )
    for label, attr in rows:
        print(f"{label}: {getattr(args, attr)}")


def _extract_response(output: str, marker: str) -> str:
    """Return the stripped text after the first occurrence of *marker* in *output*.

    If *marker* is absent the whole stripped output is returned. Without this
    guard, ``find`` returns -1 and the slice would start at
    ``len(marker) - 1``, silently dropping the first characters.
    """
    pos = output.find(marker)
    if pos == -1:
        return output.strip()
    return output[pos + len(marker):].strip()


def generate_and_save(
    args: argparse.Namespace,
    llm: LLM,
    prompt_dataloader: DataLoader
):
    """Generate answers for every prompt batch, pickle them, and score them.

    Batches (with a ``generated_answer`` field added) are pickled every
    ``args.save_every`` batches and after the final batch, under
    ``{output_dir}/{model}/train/classic/{adore|contriever}/{num_doc}_doc``.
    Each generated answer is compared with the first ``&``-separated
    reference answer via ``are_answers_matching``.

    Returns:
        ``[all_matches, train_matches, test_matches]``: per-example match flags
        for all examples, for those whose ``example_id`` is in the train split,
        and for those in the test split (splits from ``get_train_test``).
    """
    # Imported once up front rather than on every batch; a missing module now
    # fails before any generation work is done.
    from read_generation_results import are_answers_matching

    # Info from arguments
    llm_id = args.llm_id
    num_doc = args.num_documents_in_context
    save_every = args.save_every
    gold_pos = args.gold_position
    retriever_str = "adore" if args.use_adore else "contriever"
    rand_str = "_rand" if args.use_random else ""
    answerless_str = "_answerless" if args.get_documents_without_answer else ""

    # Create the saving directory
    llm_folder = llm_id.split("/")[1] if '/' in llm_id else llm_id
    saving_dir = f"{args.output_dir}/{llm_folder}/train/classic/{retriever_str}/{num_doc}_doc"
    os.makedirs(saving_dir, exist_ok=True)

    # MPT has a different answer string in the prompt
    answer_string_in_prompt = "### Response:" if 'mpt' in llm_id else "Answer:"

    all_info = []
    # Per-example match flags: [all examples, train-split, test-split].
    ans_match_after_norms = [[], [], []]
    train_example_id, test_example_id = get_train_test()
    num_batches = len(prompt_dataloader)
    for idx, prompt_batch in enumerate(tqdm(prompt_dataloader)):
        generated_output = llm.generate(prompt_batch['prompt'], max_new_tokens=args.max_new_tokens)
        generated_answers = [
            _extract_response(output, answer_string_in_prompt) for output in generated_output
        ]

        prompt_batch['generated_answer'] = generated_answers
        all_info.append(prompt_batch)

        if (idx + 1) % save_every == 0 or (idx + 1) == num_batches:
            print(f"Saving at {idx + 1}...")
            file_name = f"{saving_dir}/numdoc{num_doc}_gold_at{gold_pos}{rand_str}{answerless_str}_info_{idx+1}.pkl"
            write_pickle(all_info, file_name)
            all_info = []

        # Walk the examples of the batch in lockstep. (len(prompt_batch) is the
        # number of keys in the collated dict, not the batch size, so indexing
        # by it skipped examples.)
        for generated, raw_answers, example_id in zip(
            generated_answers, prompt_batch['answers'], prompt_batch['example_id']
        ):
            # Only the first '&'-separated reference answer is matched against.
            answers = [raw_answers.split('&')[0]]
            ans_match_after_norm: bool = are_answers_matching(generated, answers)
            ans_match_after_norms[0].append(ans_match_after_norm)
            if example_id in train_example_id:
                ans_match_after_norms[1].append(ans_match_after_norm)
            if example_id in test_example_id:
                ans_match_after_norms[2].append(ans_match_after_norm)
            print('generated answers==', generated)
            print('answers==', answers)

    llm.remove_hooks()
    return ans_match_after_norms


def get_checkpoint_path(path):
    """Return the highest-numbered ``checkpoint-<N>`` sub-directory of *path*.

    Only directories named ``checkpoint-<integer>`` are considered. Files,
    other directory names and non-integer suffixes are skipped.

    The returned path uses the directory's real name. For example,
    ``checkpoint-007`` is returned as-is rather than rebuilt as a nonexistent
    ``checkpoint-7``.

    Raises:
        ValueError: if *path* contains no ``checkpoint-<integer>`` directory.
    """
    prefix = 'checkpoint-'
    candidates = []
    for name in os.listdir(path):
        if not name.startswith(prefix) or not os.path.isdir(os.path.join(path, name)):
            continue
        try:
            number = int(name[len(prefix):])
        except ValueError:
            continue  # e.g. 'checkpoint-final': not a numbered checkpoint
        candidates.append((number, name))
    if not candidates:
        raise ValueError(f"no 'checkpoint-<N>' directories found in {path!r}")
    max_epoch, folder = max(candidates)
    print('max epoch==', max_epoch)
    return os.path.join(path, folder)


def _accuracy(matches) -> float:
    """Return the fraction of truthy entries in *matches*, or NaN when empty.

    NaN (rather than ZeroDivisionError) matches what the earlier
    ``1.0*np.sum(m)/len(m)`` produced for an empty split.
    """
    if not matches:
        return float('nan')
    return sum(matches) / len(matches)


def main():
    """Entry point: load the model and data, generate answers, and log accuracy.

    One whitespace-separated row is appended to ``result.txt``:
    ``llm_id setting sparsity use_check_point acc_all acc_train acc_test``.
    """
    print("Loading LLM...")
    args = parse_arguments()

    # The module-level ``setting`` template ends in a placeholder character;
    # replace it with the actual number of context documents.
    global setting
    setting = setting[:-1] + str(args.num_documents_in_context)
    if args.reverse:
        setting += '_reverse'
    args.setting = setting

    # No (or zero) sparsity: load the model id directly. Otherwise load the
    # model saved under the sparsity-specific directory (presumably pruned
    # weights -- confirm).
    if not args.sparsity:
        llm_id = args.llm_id
    else:
        llm_id = f'/data/somebody/data/wanda/gold/{args.llm_id}/sparsity_{args.sparsity}/model'

    # NOTE(review): this hard-coded LoRA checkpoint is always loaded and
    # --use_check_point is NOT consulted (it is only echoed into result.txt),
    # so that column does not say whether a checkpoint was used. The
    # flag-driven lookup would be:
    #   get_checkpoint_path(f'/data/somebody/data/checkpoints/gold/{args.llm_id}/{args.sparsity}/{setting}')
    check_point = '/home/somebody/codes/RAGE/ret-robust-main/checkpoints/gold/meta-llama/Llama-2-7b-chat-hf/None/gold_only/checkpoint-28'
    use_lora = check_point is not None
    llm = LLM(
        llm_id, setting, device, quantization_bits=4,
        model_max_length=args.model_max_length,
        use_lora=use_lora, check_point=check_point
    )
    tokenizer = llm.tokenizer
    print("LLM loaded")

    print("Loading corpus and search results...")
    corpus, full_to_subset_idx_map = load_corpus(args)
    search_results = load_search_results(args)
    print("Corpus and search results loaded")

    print("Loading prompt dataset...")
    prompt_dataloader = initialize_dataset_and_loader(
        args, corpus, full_to_subset_idx_map, search_results, tokenizer
    )
    print("Prompt dataset loaded")

    print_info(args)
    print("*************************************************************")
    print('setting==', setting)
    print("*************************************************************")
    ans_match_after_norms = generate_and_save(args, llm, prompt_dataloader)
    # One accuracy per list returned by generate_and_save (all, train, test).
    acc = [_accuracy(matches) for matches in ans_match_after_norms]

    print("acc==", acc)
    with open('result.txt', 'a') as f:
        f.write(f'{args.llm_id} {setting} {args.sparsity} {args.use_check_point} {acc[0]} {acc[1]} {acc[2]}\n')


if __name__ == "__main__":
    seed_everything(SEED)
    main()