import torch
from tqdm import tqdm
import pandas as pd
# from eval.utils import load_dexperts_model_and_tokenizer, load_dexperts_model_and_tokenizer_vllm
# from analysis.utils import flatten_batch_results, summarize_results, trim_output
from vllm import LLM, SamplingParams
from openai import OpenAI
import jsonlines
import os
from vllm_inject import sequence_inject, sample_output_inject, model_runner_inject, llm_engine_inject, scheduler_inject, config_inject
from vllm_inject.utils import *
import json, re
import evaluate

def trim_output(output):
    """Truncate a completion at the first sign that the model started a new turn.

    Each stop marker is applied in turn; everything from the marker onward is
    discarded, and text without the marker is left unchanged.

    Args:
        output: raw generated text.

    Returns:
        The text preceding the stop markers.
    """
    stop_markers = (
        "Answer the following question",
        "Question:",
        "Comment:",  # for some reason, Llama 13B likes to generate these comments indefinitely
    )
    for marker in stop_markers:
        # partition() yields the whole string as the head when the marker is absent.
        output, _sep, _tail = output.partition(marker)
    return output

# Single characters that may appear inside an arithmetic expression around '='.
# A frozenset gives true per-token membership. A plain-string ``in`` check would
# also accept "" and multi-character substrings such as "+-" or "$€".
_EQUATION_SYMBOLS = frozenset(",$€+-x*/")


def get_equation_lhs_rhs_indices(tokens):
    """
    Returns two lists of indices, one for tokens in the LHS of equations and one for those in the RHS.

    For every '=' token, the contiguous run of "math" tokens immediately to its
    left is added to the LHS list and the run immediately to its right to the RHS
    list. A math token is either all digits (``str.isdigit``) or exactly one of
    the characters in ``,$€+-x*/``.

    LHS indices for a given '=' are listed nearest-first (right to left); RHS
    indices left to right. In chained equations (``a = b = c``) the middle
    operand's indices appear in both lists.

    Args:
        tokens: list of str

    Returns:
        (lhs_idx, rhs_idx): two lists of int indices into ``tokens``.
    """
    def is_math_token(tok):
        return tok.isdigit() or tok in _EQUATION_SYMBOLS

    n_tokens = len(tokens)
    lhs_idx, rhs_idx = [], []

    for equal_idx, tok in enumerate(tokens):
        if tok != '=':
            continue

        # go left until it's no longer a number or symbol
        left_idx = equal_idx - 1
        while left_idx >= 0 and is_math_token(tokens[left_idx]):
            lhs_idx.append(left_idx)
            left_idx -= 1

        # go right until it's no longer a number or symbol
        right_idx = equal_idx + 1
        while right_idx < n_tokens and is_math_token(tokens[right_idx]):
            rhs_idx.append(right_idx)
            right_idx += 1

    return lhs_idx, rhs_idx


@torch.inference_mode()
def get_gsm_output(base_model,
                   batch_size,
                   temperature,
                   top_p,
                   use_chat_format=False):
    print("Loading data...")
    test_data = []
    # with open("/xx/data/eval/gsm/test.jsonl") as fin:
    #     for line in fin:
    #         example = json.loads(line)
    #         test_data.append({
    #             "question": example["question"],
    #             "answer": example["answer"].split("####")[1].strip()
    #         })
            # if len(test_data) >= 100: break
    with open("/xx/data/eval/math/math_test.jsonl") as fin:
        for line in fin:
            example = json.loads(line)
            # import pdb;pdb.set_trace()
            matches = re.findall(r'\\boxed\{(.*?(?:\{.*?\}.*?)*?)\}', example["output"], re.DOTALL)
            test_data.append({ 
                "question": example["instruction"],
                "answer": matches[-1]
            })
    # some numbers are in the `x,xxx` format, and we want to remove the comma
    # for example in test_data:
        # example["answer"] = re.sub(r"(\d),(\d)", r"\1\2", example["answer"])
        # assert float(example["answer"]), f"answer is not a valid number: {example['answer']}"
    prompt_prefix = "Answer the following question.\n\n"

    prompts = []
    chat_formatting_function = eval.templates.create_prompt_with_tulu_chat_format if use_chat_format else None
    for example in test_data:
        prompt = prompt_prefix + "Question: " + example["question"].strip()
        if use_chat_format:
            messages = [{"role": "user", "content": prompt}]
            prompt = chat_formatting_function(messages, add_bos=False)
            if prompt[-1] in ["\n", " "]:
                prompt += "Answer:"
            else:
                prompt += " Answer:"
        else:
            prompt += "\nAnswer:"
        prompts.append(prompt)
        
    sampling_params = SamplingParams(temperature=temperature, top_p=top_p, max_tokens=1024)
    all_results = []
    for i in tqdm(range(0, len(prompts), batch_size), desc="Batches"):
        batch_prompts = prompts[i: i + batch_size]
        base_output = base_model.generate(batch_prompts, sampling_params)
        for j in range(len(base_output)):
            all_results.append(
                {"inputs": batch_prompts[j],
                 "output": base_output[j].outputs[0].text,
                 "logits": base_output[j].outputs[0].logits_list}
            )
    return test_data, all_results

@torch.inference_mode()
def main(*,
         model_name: str = "meta-llama/Llama-2-13b-hf",
         batch_size: int = 1024,
         temperature: float = 0.1,
         top_p: float = 0.9,
         tensor_parallel_size : int = 1,
         save_dir: str = "outputs/gsm"):
    # load model
    # clear_share_io()
    exact_match = evaluate.load("exact_match")
    base_model = LLM(model=model_name, tensor_parallel_size=tensor_parallel_size, gpu_memory_utilization=0.95, enforce_eager=True, max_num_seqs=110)
    test_data, all_results = get_gsm_output(base_model, batch_size, temperature, top_p)
    outputs = [trim_output(o["output"]) for o in all_results]

    predictions = []
    for output in outputs:
        # replace numbers like `x,xxx` with `xxxx`
        # output = re.sub(r"(\d),(\d)", r"\1\2", output)
        matches = re.findall(r'\\boxed\{(.*?(?:\{.*?\}.*?)*?)\}',output, re.DOTALL)
        if len(matches) != 0:
            predictions.append(matches[-1])
        else:
            predictions.append(output)
    
    print("Calculating accuracy...")
    targets = [example["answer"] for example in test_data]

    em_score = exact_match.compute(predictions=predictions, references=targets, ignore_case=True, ignore_punctuation=True)["exact_match"]
    print(f"Exact match : {em_score}")

    predictions = [{
        "question": example["question"],
        "answer": example["answer"],
        "model_output": output,
        "prediction": pred
    } for example, output, pred in zip(test_data, outputs, predictions)]
    if model_name == "meta-llama/Llama-2-13b-hf":
        with open("predictions.jsonl", "w") as fout:
            for prediction in predictions:
                fout.write(json.dumps(prediction) + "\n")
    
if __name__ == "__main__":
    import bdb
    import pdb
    import sys

    import defopt

    try:
        defopt.run(main)
    except SystemExit:
        # defopt/argparse exit via SystemExit for --help and usage errors; the
        # previous bare ``except:`` dropped those into the debugger.
        raise
    except bdb.BdbQuit:
        # User quit an interactive debugger session: exit quietly.
        sys.exit()
    except BaseException:
        # Post-mortem debugging on any crash (including Ctrl-C, as before).
        exc_type, exc_value, exc_tb = sys.exc_info()
        print(exc_type, exc_value)
        pdb.post_mortem(exc_tb)
