import torch
from tqdm import tqdm
import pandas as pd
# from eval.utils import load_dexperts_model_and_tokenizer, load_dexperts_model_and_tokenizer_vllm
# from analysis.utils import flatten_batch_results, summarize_results, trim_output
from vllm import LLM, SamplingParams
import jsonlines
import os
# from vllm_inject import sequence_inject, sample_output_inject, model_runner_inject, llm_engine_inject, scheduler_inject, config_inject
# from vllm_inject.utils import *
import json, re
import evaluate
from transformers import AutoTokenizer

def trim_output(output):
    """Truncate a generation at the first marker of a runaway continuation.

    Everything from the first occurrence of any of the markers below onward
    is dropped; text containing none of them is returned unchanged.

    Args:
        output: the generated text (str).

    Returns:
        The leading part of ``output`` that precedes every marker.
    """
    stop_markers = (
        "Answer the following question",  # model starts re-emitting the instruction
        'Question:',                      # model starts inventing a new question
        'Comment:',                       # Llama 13B tends to generate these indefinitely
    )

    for marker in stop_markers:
        # partition() leaves the text untouched in `head` when the marker is absent
        head, _, _ = output.partition(marker)
        output = head

    return output

def get_equation_lhs_rhs_indices(tokens):
    """
    Collect the token positions that sit on either side of each '=' token.

    For every token equal to '=', the scan walks outward in both directions
    and records positions for as long as the token is all digits or is found
    in the symbol string ",$€+-x*/". The membership test is a substring test,
    so an empty-string token also counts as a symbol.

    Args:
        tokens: list of str

    Returns:
        (lhs_idx, rhs_idx): two lists of int. LHS positions are listed from
        the '=' outward (descending), RHS positions ascending. A position
        lying between two '=' tokens can appear in both lists.
    """
    equation_symbols = ",$€+-x*/"

    def belongs_to_equation(tok):
        return tok.isdigit() or tok in equation_symbols

    lhs_idx = []
    rhs_idx = []
    n_tokens = len(tokens)

    for eq_pos, tok in enumerate(tokens):
        if tok != '=':
            continue

        # walk left from the '=' while we are still on numbers/symbols
        cursor = eq_pos - 1
        while cursor >= 0 and belongs_to_equation(tokens[cursor]):
            lhs_idx.append(cursor)
            cursor -= 1

        # walk right from the '=' while we are still on numbers/symbols
        cursor = eq_pos + 1
        while cursor < n_tokens and belongs_to_equation(tokens[cursor]):
            rhs_idx.append(cursor)
            cursor += 1

    return lhs_idx, rhs_idx

@torch.inference_mode()
def get_gsm_output(base_model,
                   tokenizer,
                   max_tokens,
                   batch_size,
                   temperature,
                   top_p,
                   use_chat_format=False,
                   icl=False):
    """Build GSM prompts, generate completions, and return data plus outputs.

    The test set is read from the hard-coded path
    ``/xx/data/eval/gsm/test.jsonl``. When ``icl`` is set, the first five
    records of ``/xx/data/eval/gsm/train.jsonl`` are prepended to every prompt
    as in-context examples. The first prompt is also written to
    ``example_prompt.txt`` in the current directory for inspection.

    Args:
        base_model: object exposing ``generate(prompts, sampling_params)``
            whose results provide ``.outputs[0].text`` (a vLLM ``LLM`` in
            this file).
        tokenizer: not used by this function; kept so existing callers keep
            working.
        max_tokens: generation length limit passed to ``SamplingParams``.
        batch_size: number of prompts sent per ``generate`` call.
        temperature: sampling temperature passed to ``SamplingParams``.
        top_p: nucleus-sampling threshold passed to ``SamplingParams``.
        use_chat_format: wrap each prompt with the Tulu chat template.
        icl: prepend five in-context examples when truthy.

    Returns:
        ``(test_data, all_results)``. ``test_data`` is a list of
        ``{"question", "answer"}`` dicts with thousands separators removed
        from the answer; ``all_results`` is a list of ``{"inputs", "output"}``
        dicts in the same order.

    Raises:
        ValueError: if a gold answer is not parseable as a number.
    """
    print("Loading data...")
    test_data = []
    with open("/xx/data/eval/gsm/test.jsonl", encoding="utf-8") as fin:
        for line in fin:
            example = json.loads(line)
            # the final answer follows the "####" separator in the solution text
            answer = example["answer"].split("####")[1].strip()
            # some numbers are in the `x,xxx` format, and we want to remove the comma
            answer = re.sub(r"(\d),(\d)", r"\1\2", answer)
            # Validate explicitly: a bare `assert float(answer)` would reject
            # a legitimate answer of 0 and disappears under `python -O`.
            try:
                float(answer)
            except ValueError as err:
                raise ValueError(f"answer is not a valid number: {answer}") from err
            test_data.append({
                "question": example["question"],
                "answer": answer
            })

    prompt_prefix = "Answer the following question.\n\n"
    if icl:
        shots = []
        with open("/xx/data/eval/gsm/train.jsonl", "r", encoding="utf-8") as f:
            for line in f:
                data = json.loads(line)
                # NOTE(review): train records are read with "prompt"/"completion"
                # keys while test records use "question"/"answer" — confirm schema.
                shots.append(
                    prompt_prefix + "Question: " + data["prompt"].strip()
                    + " Answer:" + data["completion"].strip() + "\n"
                )
                if len(shots) >= 5:
                    break
        prompt_prefix = "".join(shots) + prompt_prefix

    chat_formatting_function = None
    if use_chat_format:
        # Imported lazily: the module never imports the project's `eval`
        # package, so the bare name `eval` would resolve to the builtin
        # function and `eval.templates` would raise AttributeError.
        from eval.templates import create_prompt_with_tulu_chat_format
        chat_formatting_function = create_prompt_with_tulu_chat_format

    prompts = []
    for example in test_data:
        prompt = prompt_prefix + "Question: " + example["question"].strip()
        if use_chat_format:
            messages = [{"role": "user", "content": prompt}]
            prompt = chat_formatting_function(messages, add_bos=False)
            # avoid a doubled separator when the template already ends in whitespace
            if prompt.endswith(("\n", " ")):
                prompt += "Answer:"
            else:
                prompt += " Answer:"
        else:
            prompt += "\nAnswer:"
        prompts.append(prompt)

    # dump one prompt for manual inspection; skip when the test set is empty
    if prompts:
        with open("example_prompt.txt", 'w', encoding="utf-8") as fout:
            fout.write(prompts[0])

    sampling_params = SamplingParams(temperature=temperature, top_p=top_p, max_tokens=max_tokens)
    all_results = []
    for start in tqdm(range(0, len(prompts), batch_size), desc="Batches"):
        batch_prompts = prompts[start: start + batch_size]
        base_output = base_model.generate(batch_prompts, sampling_params)
        all_results.extend(
            {"inputs": batch_prompt,
             "output": request_output.outputs[0].text}
            for batch_prompt, request_output in zip(batch_prompts, base_output)
        )
    return test_data, all_results

@torch.inference_mode()
def main(*,
         model_name: str = "meta-llama/Llama-2-13b-hf",
         batch_size: int = 1024,
         temperature: float = 0.1,
         top_p: float = 0.9,
         tensor_parallel_size : int = 1,
         max_num_seqs : int = 256,
         max_tokens : int = 256,
         save_dir: str = "outputs/gsm",
         icl: int = 0):
    """Evaluate a model on GSM with vLLM and write per-example predictions.

    The model is loaded with vLLM, completions are generated for every test
    question, the last number found in each trimmed completion is taken as
    the prediction, and the exact-match score against the gold answers is
    printed. One JSON record per example is written to a predictions file
    inside save_dir.

    Args:
        model_name: model identifier or path given to the tokenizer and vLLM.
        batch_size: number of prompts submitted per generate call.
        temperature: sampling temperature.
        top_p: nucleus-sampling threshold.
        tensor_parallel_size: tensor-parallel degree given to vLLM.
        max_num_seqs: max_num_seqs setting given to vLLM.
        max_tokens: maximum number of generated tokens per prompt.
        save_dir: directory for the predictions file; created if missing.
            An empty string writes to the current directory.
        icl: 0 disables in-context examples, any other value enables them.
    """
    # load model
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    exact_match = evaluate.load("exact_match")
    base_model = LLM(model=model_name, tensor_parallel_size=tensor_parallel_size, gpu_memory_utilization=0.95, enforce_eager=True, max_num_seqs=max_num_seqs)
    test_data, all_results = get_gsm_output(base_model, tokenizer, max_tokens, batch_size, temperature, top_p, icl=(icl != 0))
    outputs = [trim_output(o["output"]) for o in all_results]

    thousands_sep = re.compile(r"(\d),(\d)")
    number = re.compile(r"[-+]?\d*\.\d+|\d+")
    predictions = []
    for output in outputs:
        # replace numbers like `x,xxx` with `xxxx`
        output = thousands_sep.sub(r"\1\2", output)
        numbers = number.findall(output)
        # the final number is taken as the answer; fall back to the raw text
        predictions.append(numbers[-1] if numbers else output)

    print("Calculating accuracy...")
    targets = [example["answer"] for example in test_data]

    em_score = exact_match.compute(predictions=predictions, references=targets, ignore_case=True, ignore_punctuation=True)["exact_match"]
    print(f"Exact match : {em_score}")

    records = [{
        "question": example["question"],
        "answer": example["answer"],
        "model_output": output,
        "prediction": pred
    } for example, output, pred in zip(test_data, outputs, predictions)]

    # Honour save_dir: it was previously accepted but ignored, so the file
    # always landed in the current working directory.
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    out_path = os.path.join(save_dir, f"predictions_{model_name.replace('/', '#')}.jsonl")
    with open(out_path, "w", encoding="utf-8") as fout:
        for record in records:
            fout.write(json.dumps(record) + "\n")
    
if __name__ == "__main__":
    # CLI entry point: defopt builds the argument parser from main's signature.
    import defopt
    try:
        defopt.run(main)
    except Exception:
        # Only real errors open the post-mortem debugger. A bare `except:`
        # would also trap SystemExit (raised on --help or bad arguments) and
        # KeyboardInterrupt, dropping the user into pdb on a normal exit.
        import bdb
        import pdb
        import sys
        exc_type, exc_value, exc_tb = sys.exc_info()
        # quitting an active debugger session should just end the program
        if exc_type is bdb.BdbQuit:
            sys.exit()
        print(exc_type, exc_value)
        pdb.post_mortem(exc_tb)
