import torch
from tqdm import tqdm
import pandas as pd
# from eval.utils import load_dexperts_model_and_tokenizer, load_dexperts_model_and_tokenizer_vllm
# from analysis.utils import flatten_batch_results, summarize_results, trim_output
from vllm import LLM, SamplingParams
import jsonlines
import os
# from vllm_inject import sequence_inject, sample_output_inject, model_runner_inject, llm_engine_inject, scheduler_inject, config_inject
# from vllm_inject.utils import *
import json, re
import evaluate
from transformers import AutoTokenizer
from typing import Iterable, Dict, List, Optional, Union
import gzip
import numpy as np


# Set at import time, before any tokenizer is instantiated below.
# NOTE(review): presumably silences the Hugging Face `tokenizers` fork/parallelism
# warning when worker processes are spawned -- confirm this is still needed.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
def create_prompt(row):
    """Build the prompt for one example.

    ``row`` is any object supporting ``row["question"]`` (a dict or a pandas
    row). The result is the question followed by an open ``Answer:`` cue for
    the model to complete.
    """
    question = row["question"]
    return "Question: {}\nAnswer:".format(question)
@torch.inference_mode()
def get_triviaqa_output(base_model,
                        tokenizer,
                        max_tokens,
                        batch_size,
                        temperature,
                        top_p,
                        use_chat_format=False,
                        save_dir="outputs/triviaqa",
                        icl=False):
    """Generate one completion per TriviaQA dev question.

    Args:
        base_model: Object exposing ``generate(prompts, sampling_params)`` that
            returns one result per prompt, in order, where
            ``result.outputs[0].text`` is the completion (``main`` passes a
            ``vllm.LLM``).
        tokenizer: Object exposing ``encode(text, add_special_tokens=False)``;
            only used to derive the stop token id.
        max_tokens: Generation length limit passed to ``SamplingParams``.
        batch_size: Number of prompts handed to ``generate`` per call.
        temperature: Sampling temperature passed to ``SamplingParams``.
        top_p: Nucleus sampling threshold passed to ``SamplingParams``.
        use_chat_format: Unused; kept so existing callers keep working.
        save_dir: Unused here; kept so existing callers keep working.
        icl: When truthy, every prompt is prefixed with the first five
            question/answer pairs of the training file as few-shot examples.

    Returns:
        ``(test_df, all_results)``: the dev-set DataFrame and a list with one
        ``{"inputs": prompt, "output": completion}`` dict per row, in row order.
    """
    print("Loading data...")
    # use dev set because test set answers are hidden
    test_df = pd.read_json("/xx/data/eval/triviaqa/dev.jsonl", lines=True)

    prompt_prefix = ""
    if icl:
        num_shots = 5
        shots = []
        # JSON is UTF-8; be explicit so the result does not depend on the locale
        # (pd.read_json above already defaults to UTF-8).
        with open("/xx/data/eval/triviaqa/train.jsonl", "r", encoding="utf-8") as f:
            # zip() stops after `num_shots` lines (or earlier on a short file).
            for _, line in zip(range(num_shots), f):
                data = json.loads(line)
                shots.append(
                    "Question: " + data["question"].strip()
                    + "\nAnswer:" + data["answer"].strip() + "\n"
                )
        prompt_prefix = "".join(shots)

    # Create prompts
    prompts = [prompt_prefix + create_prompt(row) for _, row in test_df.iterrows()]

    # Stop generation on the last token id of the encoded "\n\n".
    # NOTE(review): taking [-1] presumably skips a leading prefix token that some
    # tokenizers emit -- confirm for the tokenizer in use.
    new_line_token = [tokenizer.encode("\n\n", add_special_tokens=False)[-1]]

    sampling_params = SamplingParams(
        temperature=temperature,
        top_p=top_p,
        max_tokens=max_tokens,
        stop_token_ids=new_line_token,
    )

    all_results = []
    for start in tqdm(range(0, len(prompts), batch_size), desc="Batches"):
        batch_prompts = prompts[start:start + batch_size]
        batch_outputs = base_model.generate(batch_prompts, sampling_params)
        # Results come back in prompt order, so pair them positionally.
        all_results.extend(
            {"inputs": prompt, "output": result.outputs[0].text}
            for prompt, result in zip(batch_prompts, batch_outputs)
        )

    return test_df, all_results

@torch.inference_mode()
def main(*,
         model_name: str = "meta-llama/CodeLlama-7b-hf",
         batch_size: int = 1024,
         temperature: float = 0.1,
         top_p: float = 0.9,
         tensor_parallel_size: int = 1,
         max_num_seqs: int = 256,
         max_tokens: int = 256,
         save_dir: str = "outputs/triviaqa",
         icl: int = 0):
    """Evaluate a model on the TriviaQA dev set and write predictions and accuracy.

    A prediction counts as correct when the stripped, lower-cased completion
    equals one of the stripped, lower-cased reference answers of its row.
    Writes ``predictions.jsonl`` and ``metrics_<model_name>.json`` (with ``/``
    replaced by ``#``) into ``save_dir``.

    :param model_name: Model identifier passed to the tokenizer and to vLLM.
    :param batch_size: Number of prompts submitted per generate call.
    :param temperature: Sampling temperature.
    :param top_p: Nucleus sampling threshold.
    :param tensor_parallel_size: Tensor parallel size passed to vLLM.
    :param max_num_seqs: Value passed to vLLM as ``max_num_seqs``.
    :param max_tokens: Maximum number of generated tokens per prompt.
    :param save_dir: Output directory; created if it does not exist.
    :param icl: Non-zero enables few-shot (in-context) prompting.
    """
    # Create the output directory before the expensive model load so a bad path
    # fails immediately instead of after the whole generation run.
    os.makedirs(save_dir, exist_ok=True)

    # load model
    icl_type = icl != 0
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    base_model = LLM(
        model=model_name,
        tensor_parallel_size=tensor_parallel_size,
        gpu_memory_utilization=0.75,
        enforce_eager=True,
        max_num_seqs=max_num_seqs,
    )

    test_df, all_results = get_triviaqa_output(
        base_model, tokenizer, max_tokens, batch_size, temperature, top_p,
        icl=icl_type, save_dir=save_dir,
    )

    test_df['output'] = [o["output"].strip() for o in all_results]

    # Exact match against any reference answer, ignoring case and outer whitespace.
    cors = [
        pred.lower() in {a.strip().lower() for a in answers}
        for pred, answers in zip(test_df['output'], test_df['answers'])
    ]

    test_df['correct'] = cors
    acc = np.nanmean(cors)
    print(f"Accuracy: {np.round(acc, 5)}")

    test_df.to_json(os.path.join(save_dir, "predictions.jsonl"), lines=True, orient='records')

    # save results
    metrics_path = os.path.join(save_dir, f"metrics_{model_name.replace('/', '#')}.json")
    with open(metrics_path, "w", encoding="utf-8") as fo:
        json.dump({
            "acc": acc,
            "tot": len(test_df)
        }, fo)
    
    
if __name__ == "__main__":
    import defopt
    try:
        defopt.run(main)
    except Exception as exc:
        # Debugging aid: open a post-mortem pdb session on real failures.
        # Catching Exception (not a bare except) lets SystemExit from argument
        # parsing / --help and KeyboardInterrupt terminate the script normally.
        import bdb
        import pdb
        import sys

        # Quitting an active debugger session should just end the program.
        if type(exc) is bdb.BdbQuit:
            sys.exit()
        print(type(exc), exc)
        pdb.post_mortem(exc.__traceback__)
