import json
import vllm
import torch
from tqdm import tqdm
import re
import os
import pickle
import time
import argparse
import random
import math

from lm_eval_utils import GSM8KEvaluator
from prompts import GSM8KPrompts, llama_assistant_turn_end

# Strings at which generation is cut off. They appear to be end-of-turn /
# end-of-sequence markers from several chat templates (Llama-3 `<|eot_id|>`,
# ChatML `<|im_end|>`, SentencePiece `</s>`), the header that would open a new
# user turn, and "Q:", which presumably starts the next few-shot question.
GSM8K_STOP_SEQUENCES = [
    "<|eot_id|>", "<|start_header_id|>user<|end_header_id|>", "Q:", "</s>", "<|im_end|>",
]
# NOTE(review): not referenced in this file; eval() uses
# GSM8KPrompts.formatting_prompt instead. Kept in case other modules import it.
FORMATTING_PROMPT = (
    'Your response should end with "The final answer is [answer]" '
    'where [answer] is the response to the problem.'
)
# Seed the module-level RNG that load_questions() uses for subsampling.
# NOTE(review): vLLM engine construction may reseed Python's `random` with its
# own default seed; confirm which seed actually governs the subsample.
random.seed(42)


def parse_args(argv=None):
    """Parse command-line options for the GSM8K speculative-decoding eval.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, matching the original behavior.

    Returns:
        The parsed ``argparse.Namespace``.

    Exits via ``parser.error`` (SystemExit) when the judge is enabled but
    ``--judge_path`` is missing, or when ``--data_size`` is outside ``(0, 1]``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-o", dest="out_file", required=True,
        help="File the result line is appended to.",
    )
    parser.add_argument("--target_model", required=True)
    parser.add_argument("--draft_model", required=True)
    parser.add_argument("--shots", type=int, choices=[0, 8], required=True)
    parser.add_argument("--no_spec_dec", action="store_true")
    parser.add_argument("--no_judge", action="store_true")
    parser.add_argument(
        "--data_size", type=float, default=1.0,
        help="Fraction of the dataset to evaluate, in (0, 1].",
    )
    parser.add_argument(
        "--window_size", type=int, default=32,
        help="Number of speculative tokens per step.",
    )
    parser.add_argument(
        "--judge_path", required=False,
        help="Pickle file with judge weights/mean/scale/bias (required unless --no_judge).",
    )
    parser.add_argument("--judge_threshold", type=float, required=False)
    parser.add_argument("data_file")
    args = parser.parse_args(argv)

    # Fail fast with a usage message instead of a TypeError from open(None)
    # later in get_judge_config().
    if not args.no_judge and args.judge_path is None:
        parser.error("--judge_path is required unless --no_judge is given")
    # Values outside (0, 1] otherwise crash or yield an empty evaluation.
    if not 0.0 < args.data_size <= 1.0:
        parser.error("--data_size must be in (0, 1]")
    return args


def extract_answer(s, suffix="<|eot_id|>"):
    """Extract the final-answer text from a model completion.

    The completion is lower-cased and ``suffix`` is removed case-insensitively.
    The phrase "the final answer is" is then treated like "=", and everything
    after the last "=" is returned with surrounding whitespace stripped.

    Args:
        s: Model completion text.
        suffix: Marker to delete before parsing, e.g. an end-of-turn token.

    Returns:
        The answer substring (lower-cased), or ``None`` if the text contains
        neither "=" nor "the final answer is".
    """
    # Lower-case the suffix too. Otherwise a suffix with uppercase characters
    # could never match the already lower-cased text.
    s = s.lower().replace(suffix.lower(), "").replace("the final answer is", "=")
    _, sep, answer = s.rpartition("=")
    return answer.strip() if sep else None


def extract_float(num_str):
    """Best-effort conversion of a number-like string to ``float``.

    Every character other than digits, "." and "-" is discarded, and leading
    or trailing dots are trimmed (e.g. "$1,234." -> 1234.0).

    Returns:
        The parsed float, or ``None`` if the cleaned text is not a valid float
        or the input is not a string.
    """
    try:
        digits_only = re.sub(r"[^0-9.-]", "", num_str)
        value = float(digits_only.strip("."))
    except (ValueError, TypeError):
        return None
    return value


def load_questions(gsm8k_test_path, data_size):
    """Load GSM8K problems from a JSONL file, optionally subsampling them.

    Each non-blank line must be a JSON object with "question" and "answer"
    keys. The reference answer is the text after the last "#### " marker,
    stripped of whitespace. If the marker is absent, the whole answer string
    is used.

    Args:
        gsm8k_test_path: Path to the JSONL file.
        data_size: Fraction of problems to keep, in (0, 1]. ``1.0`` returns
            every problem in file order. Smaller values draw
            ``ceil(data_size * n)`` problems with the module-level ``random``
            generator, so the subset depends on its current state.

    Returns:
        List of ``{"question": ..., "answer": ...}`` dicts.

    Raises:
        ValueError: If ``data_size`` is outside (0, 1].
    """
    if not 0.0 < data_size <= 1.0:
        raise ValueError(f"data_size must be in (0, 1], got {data_size!r}")

    with open(gsm8k_test_path, encoding="utf-8") as f:
        # Skip blank lines (e.g. a trailing newline), which json.loads rejects.
        records = [json.loads(line) for line in f if line.strip()]

    questions = [
        {
            "question": rec["question"],
            # rpartition yields ("", "", answer) when the marker is missing, so
            # the whole answer is kept. The old rfind(...) + 5 produced a
            # slice starting at index 4 in that case.
            "answer": rec["answer"].rpartition("#### ")[2].strip(),
        }
        for rec in records
    ]
    if data_size == 1.0:
        return questions
    # The epsilon keeps float noise (e.g. 0.07 * 100 == 7.000000000000001)
    # from rounding the sample size up by one.
    sample_size = math.ceil(data_size * len(questions) - 1e-9)
    return random.sample(questions, sample_size)


def eval(
    gsm8k_test_path,
    target_model,
    draft_model,
    eval_shots,
    use_spec_dec,
    data_size,
    window_size,
    judge_config
):
    """Run greedy GSM8K generation with vLLM and score the completions.

    Each problem goes through its own ``llm.generate`` call. The summed
    generation time is therefore per-request latency rather than batched
    throughput.

    Args:
        gsm8k_test_path: Path to the GSM8K JSONL file (see ``load_questions``).
        target_model: Model name/path passed to ``vllm.LLM``.
        draft_model: Draft model for speculative decoding. Unused when
            ``use_spec_dec`` is false.
        eval_shots: Number of few-shot examples in the prompt; must be 0 or 8.
        use_spec_dec: Whether to pass a ``speculative_config`` to vLLM.
        data_size: Fraction of the dataset to evaluate (``load_questions``).
        window_size: Value for ``num_speculative_tokens``.
        judge_config: Optional object added to the speculative config under
            "judge_config". ``None`` omits it.

    Returns:
        Tuple ``(accuracy, total_time, prompt_tokens, output_tokens)``:
        - accuracy: fraction of problems the evaluator scores as 1.0;
        - total_time: summed ``llm.generate`` time in seconds;
        - prompt_tokens / output_tokens: token counts summed over all outputs.

    Raises:
        ValueError: If ``eval_shots`` is not 0 or 8, or no problems are loaded.
    """
    # Validate cheaply before the expensive model load. Previously an
    # unsupported value surfaced later as a NameError on `meta_prompt`.
    meta_prompts = {
        0: GSM8KPrompts.prompt_with_0_shots,
        8: GSM8KPrompts.prompt_with_8_shots,
    }
    if eval_shots not in meta_prompts:
        raise ValueError(f"eval_shots must be 0 or 8, got {eval_shots!r}")
    meta_prompt = meta_prompts[eval_shots]

    # temperature=0.0 -> greedy decoding.
    sampling_params = vllm.SamplingParams(
        temperature=0.0, top_p=1.0, max_tokens=1024, stop=GSM8K_STOP_SEQUENCES
    )

    speculative_config = None
    if use_spec_dec:
        speculative_config = {
            "model": draft_model,
            "num_speculative_tokens": window_size,
        }
        if judge_config is not None:
            speculative_config["judge_config"] = judge_config

    llm = vllm.LLM(
        model=target_model,
        dtype="bfloat16",
        gpu_memory_utilization=0.8,
        # Shard the target model across every visible GPU.
        tensor_parallel_size=torch.cuda.device_count(),
        speculative_config=speculative_config,
    )

    # NOTE(review): questions are sampled AFTER the LLM is built, as in the
    # original. vLLM may reseed Python's `random` during construction, so
    # moving this line could change which subset is evaluated. Confirm
    # before reordering.
    problems = load_questions(gsm8k_test_path, data_size)
    if not problems:
        raise ValueError(f"no problems loaded from {gsm8k_test_path!r}")

    outputs = []
    total_time = 0.0
    for problem in tqdm(problems):
        formatted_prompt = (
            meta_prompt
            + problem["question"]
            + "\n"
            + GSM8KPrompts.formatting_prompt
            + llama_assistant_turn_end
        )

        # perf_counter is monotonic and high-resolution, which suits
        # measuring elapsed time better than time.time().
        start_time = time.perf_counter()
        cur_outputs = llm.generate([formatted_prompt], sampling_params, use_tqdm=False)
        total_time += time.perf_counter() - start_time
        # Outside the timed region: release cached CUDA allocator blocks.
        torch.cuda.empty_cache()
        outputs.extend(cur_outputs)

    prompt_tokens = sum(len(out.prompt_token_ids) for out in outputs)
    output_tokens = sum(
        len(out.token_ids) for output in outputs for out in output.outputs
    )

    evaluator = GSM8KEvaluator()
    correct = 0
    scored = 0
    for problem, output in zip(problems, outputs):
        score = evaluator(
            generations=[output.outputs[0].text], references=[problem["answer"]]
        )
        correct += score == 1.0
        scored += 1

    return correct / scored, total_time, prompt_tokens, output_tokens

def get_judge_config(args):
    """Build a ``JudgeConfig`` (imported from ``vllm.config``) from a pickled file.

    Args:
        args: Parsed CLI namespace. Reads ``args.judge_path``, a pickle file
            holding a mapping with "weights", "mean", "scale" and "bias"
            entries, and ``args.judge_threshold``, which is passed through as
            ``threshold``.

    Returns:
        The constructed ``JudgeConfig``.

    Raises:
        ValueError: If ``args.judge_path`` is not set.
    """
    # Give a clear error instead of a TypeError from open(None).
    if args.judge_path is None:
        raise ValueError("--judge_path is required unless --no_judge is given")
    from vllm.config import JudgeConfig

    # SECURITY: pickle.load can execute arbitrary code. Only load judge files
    # from trusted sources.
    with open(args.judge_path, "rb") as f:
        data = pickle.load(f)
    return JudgeConfig(
        weights=data["weights"],
        mean=data["mean"],
        scale=data["scale"],
        bias=data["bias"],
        threshold=args.judge_threshold,
    )

if __name__ == "__main__":
    args = parse_args()
    # --no_judge skips loading the pickled judge entirely.
    judge_config = None if args.no_judge else get_judge_config(args)
    res = eval(
        args.data_file,
        args.target_model,
        args.draft_model,
        args.shots,
        not args.no_spec_dec,
        args.data_size,
        args.window_size,
        judge_config
    )

    # Append (mode "a") one whitespace-separated line:
    # threshold, accuracy, total_time, prompt_tokens, output_tokens.
    with open(args.out_file, "a") as f:
        print(args.judge_threshold, *res, file=f)
