import argparse
import os
import re

import torch
from datasets import load_dataset
from tqdm import tqdm
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer


def extract_ans(ans_model):
    """Split generated text at the first line that mentions "answer is".

    Args:
        ans_model: Raw text generated by the model.

    Returns:
        ``(ans, residual)``: ``ans`` holds every line up to and including the
        first line containing ``"answer is"``; ``residual`` holds the lines
        after it. If no line contains ``"answer is"``, ``ans`` is the whole
        text and ``residual`` is ``""``. Both are re-joined with ``"\\n"``.
    """
    lines = ans_model.split("\n")
    # Index just past the first "answer is" line, or the end if there is none.
    cut = next(
        (idx + 1 for idx, line in enumerate(lines) if "answer is" in line),
        len(lines),
    )
    kept, rest = lines[:cut], lines[cut:]
    return "\n".join(kept), "\n".join(rest)

def parse_pred_ans(filename):
    """Score a results file written by the generation loop and print accuracy.

    The file is a sequence of records. Each record has three sections:
    - a line starting with ``"Q: "`` (the question);
    - a line starting with ``"A_model:"`` (the model answer);
    - a line starting with ``"A:"`` (the gold answer).
    Any other line is appended to whichever section is currently open.
    Commas are stripped from every line first, so "1,000" and "1000" are
    scored the same.

    Only records that have all three sections are scored. One summary line
    ``num_q <n> correct <k> ratio <k/n>`` is printed, where ``num_q`` counts
    every ``"Q: "`` header.

    Args:
        filename: Path of the results file to read.

    Returns:
        ``(questions, ans_pred, ans_gold)``: parallel lists with the raw text
        of each complete record's question, model answer and gold answer.

    Raises:
        ValueError: If a line that is not a section header appears before the
            first header.
    """
    with open(filename) as fd:
        lines = fd.readlines()

    records = []  # (question, model_answer, gold_answer) of complete records
    q = am = a = None
    num_q = 0
    current_mode = "none"

    for line in lines:
        line = line.replace(",", "")
        if line.startswith("Q: "):
            if q is not None and am is not None and a is not None:
                records.append((q, am, a))
            current_mode = "q"
            # Reset the answers as well, so a record missing a section never
            # inherits the previous record's answer text.
            q, am, a = line, None, None
            num_q += 1
        elif line.startswith("A_model:"):
            current_mode = "am"
            am = line
        elif line.startswith("A:"):
            current_mode = "a"
            a = line
        elif current_mode == "q":
            q += line
        elif current_mode == "am":
            am += line
        elif current_mode == "a":
            a += line
        else:
            raise ValueError(current_mode)

    # The last record has no following "Q: " header to trigger its append.
    if q is not None and am is not None and a is not None:
        records.append((q, am, a))

    acc = sum(1 for _, pred, gold in records if test_answer(pred, gold))
    # Avoid ZeroDivisionError on a file without any question.
    ratio = acc / num_q if num_q else 0.0
    print("num_q %d correct %d ratio %.4f" % (num_q, acc, ratio))

    questions = [rec[0] for rec in records]
    ans_pred = [rec[1] for rec in records]
    ans_gold = [rec[2] for rec in records]
    return questions, ans_pred, ans_gold

def get_result(text: str) -> str:
    """Return the last number appearing in *text*, or ``""`` if there is none.

    A number is a run of digits with an optional single decimal point (e.g.
    ``"18"``, ``"2.50"``, ``".5"``). Signs and thousands separators are not
    recognised. The match is returned verbatim as a string, so ``"18"`` and
    ``"18.0"`` are different results.

    Args:
        text: Free-form answer text.

    Returns:
        The last matched number as a string, or ``""``.
    """
    # Raw string: "\d" in a plain literal is an invalid escape (SyntaxWarning
    # on Python 3.12+). Requires the module-level ``import re``.
    numbers = re.findall(r"\d*\.?\d+", text)
    return numbers[-1] if numbers else ""


def test_answer(pred_str, ans_str):
    """Return True when the last number in *pred_str* equals the one in *ans_str*.

    Both strings are reduced with ``get_result`` and compared as strings.
    """
    return get_result(pred_str) == get_result(ans_str)



# Command-line configuration.
parser = argparse.ArgumentParser(description="Run model with configurable paths and ratios")
parser.add_argument("--big_model_path", type=str, required=True, help="Path to the big model (e.g., Qwen2-7B)")
parser.add_argument("--small_model_path", type=str, required=True, help="Path to the small model (e.g., Qwen2-0.5B)")
parser.add_argument("--heavy_budget_ratio", type=float, default=0.2, help="Heavy budget ratio")
parser.add_argument("--recent_budget_ratio", type=float, default=0.2, help="Recent budget ratio")
parser.add_argument("--compensate_budget_ratio", type=float, default=0.2, help="Compensate budget ratio")
parser.add_argument("--gsm8k_path", type=str, required=True, help="Path to GSM8K dataset")
parser.add_argument("--prompt_file", type=str, required=True, help="Path to prompt file")
# Only these two series have a patch below. Any other value used to fall
# through silently and run the small model unpatched.
parser.add_argument(
    "--model_series", type=str, required=True, choices=["qwen", "llama"], help="qwen or llama series"
)
args = parser.parse_args()

# Load models and data.
# NOTE(review): the tokenizer is loaded from the big model but used to encode
# prompts for the small model; this assumes both share a vocabulary — confirm.
config = AutoConfig.from_pretrained(args.big_model_path)
tokenizer = AutoTokenizer.from_pretrained(args.big_model_path)

small_model = AutoModelForCausalLM.from_pretrained(
    args.small_model_path,
    torch_dtype="auto",
    device_map="auto",
    attn_implementation="eager"
)

big_model = AutoModelForCausalLM.from_pretrained(
    args.big_model_path,
    torch_dtype="auto",
    device_map="auto",
    attn_implementation="eager"
)

# The budget ratios reach the patching hook through the big model's config.
config.heavy_budget_ratio = args.heavy_budget_ratio
config.recent_budget_ratio = args.recent_budget_ratio
config.compensate_budget_ratio = args.compensate_budget_ratio

if args.model_series == "qwen":
    from modify_qwen import enable_qwen2_small_model
    enable_qwen2_small_model(small_model, big_model, config)
elif args.model_series == "llama":
    from modify_llama import enable_llama_small_model
    enable_llama_small_model(small_model, big_model, config)


gsm8k = load_dataset(args.gsm8k_path)
gsm8k_test = gsm8k["test"]


with open(args.prompt_file, 'r') as f:
    c_prompt = f.read()

# Generation runs on the (patched) small model.
model = small_model
# With device_map="auto" the weights are not guaranteed to sit on cuda:0 (or
# on a GPU at all), so send inputs to the device the model reports.
device = model.device

file_path = "gsm8k_outputs/result.txt"
# Start every run from an empty results file, and make sure the output
# directory exists so the first append below does not fail.
os.makedirs(os.path.dirname(file_path), exist_ok=True)
if os.path.exists(file_path):
    os.remove(file_path)

instruction = (
    "Please reference the following examples to answer the math question,\n"
)
questions = gsm8k_test["question"][:]
answers = gsm8k_test["answer"][:]

for q, a in tqdm(zip(questions, answers), total=len(questions)):
    prompt = (
        instruction
        + c_prompt
        + "\n\nQuestion: "
        + q
        + "\n"
    )

    model_inputs = tokenizer([prompt], return_tensors="pt").to(device)
    generated_ids = model.generate(
        model_inputs.input_ids,
        attention_mask=model_inputs.attention_mask,
        max_new_tokens=500,
        do_sample=False,
    )

    # Drop the prompt tokens so that only the new continuation is decoded.
    generated_ids = [
        output_ids[len(input_ids):]
        for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
    ]
    generated_answer = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    # Keep the text up to the first "answer is" line; the residual is discarded.
    ans_, residual = extract_ans(generated_answer)
    with open(file_path, "a") as fd:
        # "Q:" and "A:" are removed from the model text so parse_pred_ans
        # cannot mistake model output for record headers.
        fd.write(
            "Q: %s\nA_model:\n%s\nA:\n%s\n\n"
            % (q, ans_.replace("Q:", "").replace("A:", ""), a)
        )
    # Re-score the whole file after each question to print a running accuracy.
    _ = parse_pred_ans(file_path)