import torch
from tqdm import tqdm
import pandas as pd
# from eval.utils import load_dexperts_model_and_tokenizer, load_dexperts_model_and_tokenizer_vllm
# from analysis.utils import flatten_batch_results, summarize_results, trim_output
from vllm import LLM, SamplingParams
from openai import OpenAI
import jsonlines
import os
from vllm_inject.utils import *
import json, re
import evaluate, random
from transformers import AutoTokenizer
import numpy as np

# Column names read from each reference-answer row by format_question_mc
# (presumably TruthfulQA-style CSV columns — TODO confirm against the data source).
BEST_COL = 'Best Answer'
ANSWER_COL = 'Correct Answers'  # not referenced in the visible code
INCORRECT_COL = 'Incorrect Answers'
# Valid option letters; a parsed prediction is only accepted if it is one of these.
CHOICES = 'ABCD'
# Same text as the answer cue get_mmlu_output appends to every prompt;
# the constant itself is not referenced in the visible code.
MC_ANSWER_PREFIX = 'The answer is:'

def format_best(best_ans, close=True):
    """Format the best answer to match the format of reference answers.

    Args:
        best_ans: raw best-answer string.
        close: when true, append a trailing period if the answer does not
            already end with one.

    Returns:
        The stripped answer, optionally closed with a period. An empty or
        whitespace-only input yields '' (previously this raised IndexError).
    """
    best = best_ans.strip()
    # `best and` guards the empty case that used to crash on best[-1]
    if close and best and not best.endswith('.'):
        best += '.'
    return best

def split_multi_answer(ans, sep=';', close=True):
    """Split a `sep`-delimited string of reference answers into a list.

    Each piece is stripped of surrounding whitespace and empty pieces are
    discarded. When `close` is true, a period is appended to every piece
    that does not already end with one.
    """
    pieces = (piece.strip() for piece in ans.strip().split(sep))
    result = []
    for piece in pieces:
        if not piece:
            continue
        needs_period = close and not piece.endswith('.')
        result.append(piece + '.' if needs_period else piece)
    return result

def format_question_mc(row):
    """Format a shuffled multiple-choice prompt for MC evaluation.

    The options are the row's best answer plus up to len(CHOICES) - 1
    randomly sampled incorrect answers, shuffled with the global `random`
    RNG and labelled with the letters in CHOICES.

    Args:
        row: mapping with a 'Question' entry, a BEST_COL entry and an
            INCORRECT_COL entry (a ';'-separated string of wrong answers).

    Returns:
        (prompt, answer_idx): the question text followed by one
        "<newline><letter>. <answer>" line per option, and the letter that
        labels the best answer.
    """
    ref_true = format_best(row[BEST_COL])
    refs_false = split_multi_answer(row[INCORRECT_COL])
    # Keep at most one fewer distractor than there are letters so every
    # option (including the correct one) gets a label.
    refs_false = random.sample(refs_false, min(len(CHOICES) - 1, len(refs_false)))

    shuffled_answers = [ref_true] + refs_false
    random.shuffle(shuffled_answers)

    option_lines = []
    answer_idx = None
    for letter, answer in zip(CHOICES, shuffled_answers):
        option_lines.append(f'\n{letter}. {answer}')
        if answer == ref_true:
            answer_idx = letter
    prompt = row['Question'] + ''.join(option_lines)
    return prompt, answer_idx

def trim_answer(answer):
    """
    Trim generated answer for open-ended evaluation setting.

    Strips surrounding whitespace and a leading "Answer:" prefix, removes
    [INST] / [/INST] instruction tags (plus the literal backslash-escaped
    "[\\/INST]" form the earlier code targeted), turns bracketed capital
    letters like "[B]" into a line break followed by the letter, and
    collapses double newlines into single ones.
    """
    answer = answer.strip()
    if answer.startswith('Answer:'):
        answer = answer[len('Answer:'):].strip()
    # Remove instruction tags. The previous literal '[\/INST]' was an invalid
    # escape sequence (SyntaxWarning on 3.12+) that kept the backslash, so the
    # real '[/INST]' tag was never removed; both forms are stripped now.
    for tag in ('[INST]', '[/INST]', r'[\/INST]'):
        answer = answer.replace(tag, '')
    # reformat line-breaks for long-form answers
    answer = re.sub(r"\[([A-Z])\]", lambda m: f"\n\n{m.group(1)}", answer)
    answer = answer.replace('\n\n', '\n')
    return answer

@torch.inference_mode()
def get_mmlu_output(base_model,
                    tokenizer,
                    max_tokens,
                    batch_size,
                    temperature,
                    top_p,
                    use_chat_format=False,
                    icl=False,
                    system_prompt="You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.",
                    data_path="/xx/data/eval/mmlu/test.jsonl",
                    few_shot_path="/xx/data/eval/few_shot.json"):
    """Build MMLU prompts and generate a completion for each with vLLM.

    Args:
        base_model: object whose ``generate(prompts, sampling_params)`` returns
            one result per prompt with ``.outputs[0].text`` (a vLLM ``LLM``).
        tokenizer: accepted for interface compatibility; not used here.
        max_tokens, temperature, top_p: forwarded to ``SamplingParams``.
        batch_size: number of prompts passed to each ``generate`` call.
        use_chat_format: wrap each prompt with
            ``eval.templates.create_prompt_with_tulu_chat_format``.
        icl: if true, prefix each question with the few-shot context stored
            under ``["mmlu"][<example class>]`` in ``few_shot_path`` instead
            of the generic instruction prefix.
        system_prompt: accepted for interface compatibility; not used here.
        data_path: JSONL file; each line must have "question", "answer" and
            "class" keys.
        few_shot_path: JSON file with the few-shot contexts (read only when
            ``icl`` is true).

    Returns:
        (test_data, all_results): ``test_data`` is a list of
        {"question", "answer", "class"} dicts; ``all_results`` is a parallel
        list of {"inputs": prompt, "output": generated text}.
    """
    print("Loading data...")
    test_data = []
    with open(data_path) as fin:
        for line in fin:
            example = json.loads(line)
            test_data.append({
                "question": example["question"],
                "answer": example["answer"],
                "class": example["class"],
            })

    prompt_prefix = "The following are multiple choice questions (with answers). Please choose a correct answer.\n"

    icl_context_dict = None
    if icl:
        # `with` closes the handle (json.load(open(...)) leaked it).
        with open(few_shot_path, "r") as f:
            icl_context_dict = json.load(f)["mmlu"]

    chat_formatting_function = None
    if use_chat_format:
        # The previous code referenced `eval.templates...` without importing it,
        # so `eval` resolved to the builtin function. Import it explicitly;
        # assumes the project-local `eval.templates` module — TODO confirm path.
        from eval.templates import create_prompt_with_tulu_chat_format
        chat_formatting_function = create_prompt_with_tulu_chat_format

    prompts = []
    for example in test_data:
        question = example["question"].strip()
        if icl:
            prompt = icl_context_dict[example["class"]] + "\nQuestion: " + question
        else:
            prompt = prompt_prefix + "Question: " + question
        if use_chat_format:
            messages = [{"role": "user", "content": prompt}]
            prompt = chat_formatting_function(messages, add_bos=False)
            # Avoid a doubled separator before the answer cue.
            if prompt.endswith(("\n", " ")):
                prompt += "The answer is:"
            else:
                prompt += " The answer is:"
        else:
            prompt += "\nThe answer is:"
        prompts.append(prompt)

    sampling_params = SamplingParams(temperature=temperature, top_p=top_p, max_tokens=max_tokens)
    all_results = []
    for i in tqdm(range(0, len(prompts), batch_size), desc="Batches"):
        batch_prompts = prompts[i: i + batch_size]
        batch_outputs = base_model.generate(batch_prompts, sampling_params)
        for prompt, request_output in zip(batch_prompts, batch_outputs):
            all_results.append({
                "inputs": prompt,
                "output": request_output.outputs[0].text,
            })

    return test_data, all_results

def _parse_prediction(output, prompt):
    """Map one trimmed model output to a letter in CHOICES, or '' if none found.

    First, if the output's first line appears verbatim in the prompt directly
    after an "X. " label (the model echoed an option's text), return X.
    Otherwise strip common wrappers and take the output's first character if
    it is a valid choice letter.
    """
    first_line = output.split("\n")[0].strip()
    # rfind("") returns len(prompt), which would inspect the prompt's tail,
    # so only search for a non-empty first line.
    if first_line:
        last_index = prompt.rfind(first_line)
        if (last_index - 3 >= 0
                and prompt[last_index - 3] in CHOICES
                and prompt[last_index - 2] == "."):
            return prompt[last_index - 3]

    # remove strings that sometimes appear before the answer option
    to_remove = ['(', '\\begin{blockquote}', '\\begin{code}', '<blockquote>', ' **', '>', '```\n', "I think the answer is ", "I think it is "]
    for r in to_remove:
        output = output.replace(r, '')
    output = output.lstrip()
    # interpret first character as prediction
    if output and output[0] in CHOICES:
        return output[0]
    return ''


@torch.inference_mode()
def main(*,
         model_name: str = "meta-llama/Llama-2-13b-hf",
         batch_size: int = 1024,
         temperature: float = 0.1,
         top_p: float = 0.9,
         tensor_parallel_size : int = 1,
         max_num_seqs : int = 256,
         max_tokens : int = 256,
         save_dir: str = "outputs/mmlu",
         icl: int = 0):
    """Run MMLU generation with vLLM, score the predictions and save results.

    Writes ``mmlu_results.jsonl`` (one record per question, with prompt, raw
    output, parsed letter and correctness) and ``mmlu_metrics.json`` into
    ``save_dir``, creating the directory if it does not exist.
    """
    icl_type = bool(icl)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    base_model = LLM(model=model_name, tensor_parallel_size=tensor_parallel_size, gpu_memory_utilization=0.75, enforce_eager=True, max_num_seqs=max_num_seqs)
    test_data, all_results = get_mmlu_output(base_model, tokenizer, max_tokens, batch_size, temperature, top_p, icl=icl_type)

    parsed_outputs = [
        _parse_prediction(trim_answer(r["output"]), r["inputs"])
        for r in all_results
    ]

    # 1.0 = correct, 0.0 = wrong, NaN = nothing parseable. dtype=float keeps
    # the array (and the stored "correct" strings) consistent even when every
    # prediction is valid; otherwise numpy would build a bool array.
    # Assumes true["answer"] is a letter from CHOICES — TODO confirm data.
    correct = np.array([
        pred == true["answer"] if pred else float('nan')
        for pred, true in zip(parsed_outputs, test_data)
    ], dtype=float)
    for record, r, pout, cor in zip(test_data, all_results, parsed_outputs, correct):
        record["prompt"] = r["inputs"]
        record["output"] = r["output"]
        record["parsed_output"] = pout
        record["correct"] = str(cor)

    # Invalid predictions count as wrong; guard the empty case explicitly.
    acc = float(np.mean(correct == 1)) if len(correct) else float('nan')
    num_invalid_pred = int(np.isnan(correct).sum())

    print(f"Invalid predictions: {num_invalid_pred}/{len(parsed_outputs)}")
    os.makedirs(save_dir, exist_ok=True)
    with open(os.path.join(save_dir, "mmlu_results.jsonl"), "w") as f:
        for record in test_data:
            f.write(json.dumps(record) + "\n")

    results = {
        'acc': acc,
        'num_invalid_predictions': num_invalid_pred,
        'tot': len(parsed_outputs)
    }
    print(results)

    with open(os.path.join(save_dir, 'mmlu_metrics.json'), 'w') as f:
        json.dump(results, f, indent=2)
    
if __name__ == "__main__":
    import bdb
    import pdb
    import sys

    import defopt
    try:
        defopt.run(main)
    except Exception:
        # Drop into the post-mortem debugger on failure. Catching Exception
        # (not a bare except) lets SystemExit from `--help` / bad CLI args and
        # KeyboardInterrupt propagate instead of opening pdb.
        exc_type, exc_value, tb = sys.exc_info()
        if exc_type is bdb.BdbQuit:
            # user quit an interactive debugger session: exit quietly
            sys.exit()
        print(exc_type, exc_value)
        pdb.post_mortem(tb)
