import torch
from tqdm import tqdm
import pandas as pd
# from eval.utils import load_dexperts_model_and_tokenizer, load_dexperts_model_and_tokenizer_vllm
# from analysis.utils import flatten_batch_results, summarize_results, trim_output
from vllm import LLM, SamplingParams
from openai import OpenAI
import jsonlines
import os
from vllm_inject import sequence_inject, sample_output_inject, model_runner_inject, llm_engine_inject, scheduler_inject, config_inject
from vllm_inject.utils import *
import json, re
import evaluate
from transformers import AutoTokenizer

def trim_output(output):
    """Cut a generated completion at the first run-on marker.

    The text is truncated, in turn, before "Answer the following question",
    "Question:" and "Comment:". Whatever precedes the earliest surviving
    marker is returned; text containing none of them is returned unchanged.

    Args:
        output: str, raw model completion.

    Returns:
        str, the completion with everything from the first marker onwards removed.
    """
    stop_markers = (
        "Answer the following question",
        "Question:",
        # for some reason, Llama 13B likes to generate these comments indefinitely
        "Comment:",
    )

    trimmed = output
    for marker in stop_markers:
        # partition() returns the whole string as the head when the marker is absent
        trimmed = trimmed.partition(marker)[0]
    return trimmed

def get_equation_lhs_rhs_indices(tokens):
    """
    Collects the indices of tokens sitting on either side of every '=' token.

    Starting next to each '=' token, the scan walks outwards for as long as
    the token is all digits (``str.isdigit``) or is a substring of
    ",$€+-x*/". Left-hand indices are recorded walking away from the '='
    (so in descending order for each equation), right-hand indices in
    ascending order.

    Args:
        tokens: list of str

    Returns:
        (lhs_idx, rhs_idx): two lists of int indices into ``tokens``.
    """
    symbols = ",$€+-x*/"

    def _is_equation_token(tok):
        # number, or a (sub)string of the accepted symbol characters
        return tok.isdigit() or tok in symbols

    n_tokens = len(tokens)
    lhs_idx, rhs_idx = [], []

    for pos, tok in enumerate(tokens):
        if tok != '=':
            continue

        # walk left while the tokens still look like part of the equation
        cursor = pos - 1
        while cursor >= 0 and _is_equation_token(tokens[cursor]):
            lhs_idx.append(cursor)
            cursor -= 1

        # walk right under the same rule
        cursor = pos + 1
        while cursor < n_tokens and _is_equation_token(tokens[cursor]):
            rhs_idx.append(cursor)
            cursor += 1

    return lhs_idx, rhs_idx


@torch.inference_mode()
def get_gsm_output(base_model,
                    tokenizer,
                    max_tokens,
                    batch_size,
                    temperature,
                    top_p,
                    use_chat_format=False,
                    system_prompt=None,
                    icl=False):
    """Build GSM prompts from the test split and generate a completion for each.

    Args:
        base_model: object exposing ``generate(prompts, sampling_params)``;
            ``main`` passes a vllm ``LLM``.
        tokenizer: accepted for interface compatibility; not used here.
        max_tokens: generation budget forwarded to ``SamplingParams``.
        batch_size: number of prompts sent per ``generate`` call.
        temperature: sampling temperature forwarded to ``SamplingParams``.
        top_p: nucleus-sampling threshold forwarded to ``SamplingParams``.
        use_chat_format: when true, wrap each prompt with the tulu chat
            template from ``eval.templates``. ``system_prompt`` is not
            applied on this path.
        system_prompt: optional text prepended to every prompt when
            ``use_chat_format`` is false.
        icl: when truthy, prepend up to five examples read from the train
            split to every prompt.

    Returns:
        (test_data, all_results): ``test_data`` is a list of
        ``{"question", "answer"}`` dicts; ``all_results`` holds one
        ``{"inputs", "output", "logits"}`` dict per prompt, in prompt order.
    """
    print("Loading data...")
    test_data = []
    with open("/xx/data/eval/gsm/test.jsonl") as fin:
        for line in fin:
            example = json.loads(line)
            test_data.append({
                "question": example["question"],
                # the final numeric answer follows the "####" separator
                "answer": example["answer"].split("####")[1].strip()
            })
    # some numbers are in the `x,xxx` format, and we want to remove the comma
    for example in test_data:
        example["answer"] = re.sub(r"(\d),(\d)", r"\1\2", example["answer"])

    prompt_prefix = "Answer the following question.\n\n"
    if icl:
        # NOTE(review): the train split is read with "prompt"/"completion" keys
        # while the test split uses "question"/"answer" -- confirm the schema.
        shots = []
        with open("/xx/data/eval/gsm/train.jsonl", "r") as f:
            for shot_count, line in enumerate(f, start=1):
                data = json.loads(line)
                shots.append(prompt_prefix + "Question: " + data["prompt"].strip()
                             + " Answer:" + data["completion"].strip() + "\n")
                if shot_count >= 5:
                    break
        prompt_prefix = "".join(shots) + prompt_prefix

    chat_formatting_function = None
    if use_chat_format:
        # The formatter used to be looked up as `eval.templates.<name>`, but no
        # `eval` package is imported in this file, so the name resolved to the
        # builtin `eval` and raised AttributeError. Import it lazily so the
        # default (non-chat) path keeps working without the package installed.
        from eval.templates import create_prompt_with_tulu_chat_format
        chat_formatting_function = create_prompt_with_tulu_chat_format

    prompts = []
    for example in test_data:
        prompt = prompt_prefix + "Question: " + example["question"].strip()
        if use_chat_format:
            messages = [{"role": "user", "content": prompt}]
            prompt = chat_formatting_function(messages, add_bos=False)
            # avoid a doubled separator when the template already ends in whitespace
            if prompt.endswith(("\n", " ")):
                prompt += "Answer:"
            else:
                prompt += " Answer:"
        else:
            if system_prompt:
                prompt = system_prompt + prompt
            prompt += "\nAnswer:"
        prompts.append(prompt)

    sampling_params = SamplingParams(temperature=temperature, top_p=top_p, max_tokens=max_tokens)
    all_results = []
    for start in tqdm(range(0, len(prompts), batch_size), desc="Batches"):
        batch_prompts = prompts[start: start + batch_size]
        base_output = base_model.generate(batch_prompts, sampling_params)
        for batch_prompt, request_output in zip(batch_prompts, base_output):
            completion = request_output.outputs[0]
            all_results.append(
                {"inputs": batch_prompt,
                 "output": completion.text,
                 # presumably `logits_list` is attached by the vllm_inject patches -- TODO confirm
                 "logits": completion.logits_list}
            )
    return test_data, all_results

@torch.inference_mode()
def main(*,
         model_name: str = "meta-llama/Llama-2-13b-hf",
         batch_size: int = 1024,
         temperature: float = 0.1,
         top_p: float = 0.9,
         tensor_parallel_size : int = 1,
         max_num_seqs : int = 256,
         max_tokens : int = 256,
         save_dir: str = "outputs/gsm",
         system_prompt_type: int = 1,
         icl: int = 0):
    """Run the GSM evaluation for one model and report exact-match accuracy.

    Args:
        model_name: model identifier passed to the tokenizer and the vllm engine.
        batch_size: number of prompts per generate call.
        temperature: sampling temperature.
        top_p: nucleus-sampling threshold.
        tensor_parallel_size: forwarded to the vllm engine.
        max_num_seqs: forwarded to the vllm engine.
        max_tokens: maximum number of generated tokens per prompt.
        save_dir: directory receiving alpha.txt and predictions.jsonl.
        system_prompt_type: currently unused.
        icl: any non-zero value enables in-context examples.
    """
    # load model
    icl_type = icl != 0
    system_prompt = "[INST]You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.[/INST]"
    exact_match = evaluate.load("exact_match")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    base_model = LLM(model=model_name, tensor_parallel_size=tensor_parallel_size, gpu_memory_utilization=0.95, enforce_eager=True, max_num_seqs=max_num_seqs)
    test_data, all_results = get_gsm_output(base_model, tokenizer, max_tokens, batch_size, temperature, top_p, system_prompt=system_prompt, icl=icl_type)
    outputs = [trim_output(result["output"]) for result in all_results]

    predictions = []
    for output in outputs:
        # replace numbers like `x,xxx` with `xxxx`
        output = re.sub(r"(\d),(\d)", r"\1\2", output)
        numbers = re.findall(r"[-+]?\d*\.\d+|\d+", output)
        # the last number in the completion is taken as the answer; with no
        # number at all, the whole (comma-stripped) text is the prediction
        predictions.append(numbers[-1] if numbers else output)

    print("Calculating accuracy...")
    targets = [example["answer"] for example in test_data]

    em_score = exact_match.compute(predictions=predictions, references=targets, ignore_case=True, ignore_punctuation=True)["exact_match"]
    print(f"Exact match : {em_score}")

    records = [{
        "question": example["question"],
        "answer": example["answer"],
        "model_output": output,
        "prediction": pred
    } for example, output, pred in zip(test_data, outputs, predictions)]

    # NOTE(review): results are only persisted for this one model name -- confirm
    # that skipping every other model is intentional.
    if model_name == "meta-llama/Llama-2-13b-hf":
        # Create the output directory up front; otherwise a fresh checkout fails
        # with FileNotFoundError only after the whole generation run has finished.
        os.makedirs(save_dir, exist_ok=True)
        # alpha.txt accumulates one score per run (append mode)
        with open(os.path.join(save_dir, "alpha.txt"), "a") as fout:
            fout.write(str(em_score) + "\n")
        # predictions.jsonl is overwritten on every run
        with open(os.path.join(save_dir, "predictions.jsonl"), "w") as fout:
            for record in records:
                fout.write(json.dumps(record) + "\n")
    
if __name__ == "__main__":
    import defopt
    try:
        defopt.run(main)
    except Exception:
        # Only real errors should open the debugger. A bare `except:` also
        # trapped SystemExit (e.g. `--help` or argument errors) and
        # KeyboardInterrupt, dropping the user into pdb on a normal exit.
        import bdb
        import pdb
        import sys
        exc_type, exc_value, exc_tb = sys.exc_info()
        # quitting an active debugger session should end the program quietly
        if exc_type == bdb.BdbQuit:
            sys.exit()
        print(exc_type, exc_value)
        pdb.post_mortem(exc_tb)
