import argparse
import json
import pdb
import jsonlines

import util

import re

from fraction import Fraction
# from vllm import LLM, SamplingParams
import sys
import torch
from ntk import load_model
import transformers
from peft import PeftModel, PeftConfig
from tqdm import tqdm

# Effectively "no upper bound"; used as the default end index when slicing the dataset.
MAX_INT = sys.maxsize
# Placeholder string for an unparsable answer (not referenced in the code visible here).
INVALID_ANS = "[invalid]"
# CUDA device index used for inference: the model and inputs are moved to f"cuda:{device_id}".
# Hard-coded; edit this to run on a different GPU.
device_id = "3"

# Completions that lacked the "The answer is: " marker. process_results appends one
# {'question', 'output', 'answer'} dict per such completion; being module-level,
# it accumulates across calls within one process.
invalid_outputs = []
def remove_boxed(s):
    """Return the contents of a ``\\boxed{...}`` wrapper, or None.

    Args:
        s: A string expected to look exactly like ``\\boxed{<content>}``.
            Any other value (including None) is accepted and yields None.

    Returns:
        ``<content>`` when *s* starts with ``\\boxed{`` and ends with ``}``;
        otherwise None.
    """
    left = "\\boxed{"
    # Explicit checks instead of `assert` + bare `except`: asserts are removed
    # under `python -O`, which would make this return a mangled slice instead
    # of None, and the bare except also hid unrelated errors.
    if not isinstance(s, str) or not s.startswith(left) or not s.endswith("}"):
        return None
    return s[len(left):-1]

def process_results(doc, completion, answer):
    """Score one model completion against its reference answer.

    The predicted answer is taken from the text after the last occurrence of
    "The answer is: " in *completion*. That text is cut at the first ".\\n",
    stripped, and has at most one trailing period removed. The result is
    compared with *answer* via ``util.is_equiv``.

    A completion without the marker is recorded in the module-level
    ``invalid_outputs`` list as ``{'question', 'output', 'answer'}`` and
    counts as incorrect.

    Returns:
        True if the extracted answer is judged equivalent, otherwise False.
    """
    marker = 'The answer is: '
    segments = completion.split(marker)
    if len(segments) < 2:
        # No answer marker at all: remember the case for later inspection.
        invalid_outputs.append({'question': doc, 'output': completion, 'answer': answer})
        return False

    candidate = segments[-1].split('.\n')[0].strip()
    # Drop one sentence-ending period, e.g. "The answer is: 5." -> "5".
    if candidate and candidate[-1] == '.':
        candidate = candidate[:-1]
    return bool(util.is_equiv(candidate.strip(), answer))
def batch_data(data_list, batch_size=1):
    """Split *data_list* into consecutive batches of at most *batch_size* items.

    Order is preserved and every item appears exactly once. Only the final
    batch may be shorter than *batch_size*, and an empty input yields no batches.

    Args:
        data_list: A sliceable sequence (e.g. a list of prompt strings).
        batch_size: Maximum number of items per batch; must be >= 1.

    Returns:
        A list of slices of *data_list*, in order.

    Raises:
        ValueError: If *batch_size* is less than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    # The remainder forms its own short final batch rather than being merged
    # into the previous one, which could otherwise grow to 2*batch_size - 1 items.
    return [data_list[i:i + batch_size] for i in range(0, len(data_list), batch_size)]

def _load_hendrycks_math(data_path, problem_prompt):
    """Read a MATH JSONL file into parallel lists of prompts and reference answers.

    Each record must provide "instruction" (substituted into *problem_prompt*)
    and "output" (the reference solution). The reference answer is whatever
    ``util.last_boxed_only_string`` locates in the solution, unwrapped by
    ``remove_boxed`` (None when unwrapping fails).
    """
    prompts = []
    answers = []
    # Read-only: the file is never written, and "r+" would make the script
    # fail on read-only dataset copies.
    with open(data_path, "r", encoding="utf8") as f:
        for item in jsonlines.Reader(f):
            prompts.append(problem_prompt.format(instruction=item["instruction"]))
            answers.append(remove_boxed(util.last_boxed_only_string(item["output"])))
    return prompts, answers


def _generate_completions(model, tokenizer, batches, generation_params, device):
    """Generate one completion per prompt, batch by batch, preserving order.

    Only the continuation is decoded. For causal LMs, ``generate`` returns the
    (padded) input ids followed by the new tokens, so the first
    ``input_ids.shape[1]`` positions of every row are sliced off. Answer
    extraction therefore never sees the echoed problem text.
    """
    completions = []
    for batch in tqdm(batches):
        if not isinstance(batch, list):
            batch = [batch]
        if not batch:
            continue  # nothing to generate for an empty batch
        inputs = tokenizer(batch, return_tensors="pt", padding=True, truncation=True)
        inputs = {key: tensor.to(device) for key, tensor in inputs.items()}
        outputs = model.generate(**inputs, **generation_params)
        prompt_len = inputs["input_ids"].shape[1]
        for output in outputs:
            completions.append(tokenizer.decode(output[prompt_len:], skip_special_tokens=True))
    return completions


def test_hendrycks_math(model, tokenizer, data_path="./MATH_test.jsonl", start=0, end=MAX_INT, batch_size=64, tensor_parallel_size=1, output_path="1B-ntk-math.txt"):
    """Evaluate *model* on the Hendrycks MATH test set and record the accuracy.

    Each problem is wrapped in an instruction/response prompt and decoded
    greedily on ``cuda:{device_id}``. Each completion is scored with
    ``process_results``. The accuracy is printed and written to *output_path*.

    Args:
        model: Causal LM exposing ``.to()``, ``.eval()`` and ``.generate()``.
        tokenizer: Tokenizer matching *model*; ``eval()`` in this file
            configures left padding, which batched generation relies on.
        data_path: JSONL file with "instruction" and "output" fields.
        start, end: Slice bounds applied to the loaded dataset.
        batch_size: Number of prompts per ``generate`` call.
        tensor_parallel_size: Unused (leftover from a vLLM-based version);
            kept so existing callers keep working.
        output_path: File the accuracy (as a string) is written to.
    """
    problem_prompt = (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Response: Let's think step by step."
    )
    print('promt =====', problem_prompt)
    hendrycks_math_ins, hendrycks_math_answers = _load_hendrycks_math(data_path, problem_prompt)

    print('total length ===', len(hendrycks_math_ins))
    hendrycks_math_ins = hendrycks_math_ins[start:end]
    hendrycks_math_answers = hendrycks_math_answers[start:end]
    print('lenght ====', len(hendrycks_math_ins))
    batches = batch_data(hendrycks_math_ins, batch_size=batch_size)

    device = f"cuda:{device_id}"
    model_llm = model.to(device)
    # Make sure dropout etc. are off regardless of how the model was loaded.
    model_llm.eval()
    generation_params = {
        # Cap on *generated* tokens. `max_length` also counted the prompt,
        # contradicting the intended "max tokens to generate".
        "max_new_tokens": 2048,
        # Greedy decoding. Sampling knobs (temperature/top_k/top_p) are ignored
        # when do_sample=False and only produced warnings, so they are omitted.
        "do_sample": False,
        "pad_token_id": tokenizer.pad_token_id,
    }
    res_completions = _generate_completions(model_llm, tokenizer, batches, generation_params, device)

    results = [
        process_results(prompt, completion, prompt_answer)
        for prompt, completion, prompt_answer in zip(hendrycks_math_ins, res_completions, hendrycks_math_answers)
    ]

    # Guard against an empty slice instead of dividing by zero.
    acc = sum(results) / len(results) if results else 0.0
    print('len invalid outputs ====', len(invalid_outputs), ', invalid_outputs===', invalid_outputs)
    print('start===', start, ', end====', end)
    print('length====', len(results), ', acc====', acc)
    with open(output_path, mode="w") as file:
        file.write(str(acc))


def eval(path="Llama-3.2-1B_lora/checkpoint-3000", peft_type="lora",
         data_path="./MATH_test.jsonl", start=0, end=None, batch_size=64):
    """Load a model and tokenizer from *path* and run the MATH evaluation.

    Note: this name shadows the ``eval`` builtin inside this module. It is kept
    as-is for backward compatibility.

    Args:
        path: Checkpoint directory (or hub id) holding the weights/adapter and
            the tokenizer.
        peft_type: "lora" loads *path* with ``AutoModelForCausalLM`` and then
            wraps it with the PEFT adapter at the same *path*. "ntk" uses
            ``ntk.load_model``. Any other value loads a plain causal LM.
        data_path, start, batch_size: Forwarded to ``test_hendrycks_math``.
        end: Forwarded to ``test_hendrycks_math``; None means "through the
            last example".
    """
    if peft_type == "ntk":
        model = load_model(path)
    else:
        model = transformers.AutoModelForCausalLM.from_pretrained(path)
        if peft_type == "lora":
            model = PeftModel.from_pretrained(model, path)

    # Left padding so batched generation continues directly after each prompt.
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        path,
        padding_side="left",
    )
    # Reuse EOS as the padding token.
    tokenizer.pad_token_id = tokenizer.eos_token_id
    tokenizer.pad_token = tokenizer.eos_token

    test_hendrycks_math(
        model,
        tokenizer,
        data_path=data_path,
        start=start,
        end=MAX_INT if end is None else end,
        batch_size=batch_size,
    )


def parse_args():
    """Parse command-line options for the MATH evaluation."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, required=True)  # model / checkpoint path
    parser.add_argument("--data_file", type=str, default='MATH_test.jsonl')  # data path
    parser.add_argument("--start", type=int, default=0)  # start index
    parser.add_argument("--end", type=int, default=MAX_INT)  # end index
    # Default matches the batch size the evaluation has actually been using.
    parser.add_argument("--batch_size", type=int, default=64)
    # vLLM leftover: accepted for CLI compatibility, not used by HF generation.
    parser.add_argument("--tensor_parallel_size", type=int, default=8)
    parser.add_argument("--peft_type", type=str, default="lora")  # "lora", "ntk", or other for a plain model
    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()
    # Forward every data/batching option; previously only model and peft_type
    # were passed, so the other flags were silently ignored.
    eval(
        args.model,
        args.peft_type,
        data_path=args.data_file,
        start=args.start,
        end=args.end,
        batch_size=args.batch_size,
    )