import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from tqdm import tqdm
import re
import json
import argparse
import os


def load_dataset_from_file(file_path):
    """Load a dataset from a ``.json`` or ``.jsonl`` file.

    The format is picked from the file extension, compared
    case-insensitively. A ``.jsonl`` file is read line by line, skipping
    blank lines; a ``.json`` file must contain a top-level list.

    Args:
        file_path: Path to the dataset file (``str`` or ``os.PathLike``).

    Returns:
        list: The parsed records.

    Raises:
        FileNotFoundError: If ``file_path`` does not exist.
        TypeError: If a ``.json`` file's top-level value is not a list.
        ValueError: If the extension is unsupported or the content is not
            valid JSON (the ``json.JSONDecodeError`` is attached as the
            cause).
    """
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"错误: 数据集文件未找到于 '{file_path}'")

    # Lower-case the path so "DATA.JSONL" / "data.Json" are accepted as well.
    lowered = os.fspath(file_path).lower()
    if lowered.endswith('.jsonl'):
        is_jsonl = True
    elif lowered.endswith('.json'):
        is_jsonl = False
    else:
        raise ValueError("不支持的文件格式。请使用 .json 或 .jsonl。")

    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            if is_jsonl:
                data = [json.loads(line) for line in f if line.strip()]
            else:
                data = json.load(f)
    except json.JSONDecodeError as e:
        # Chain explicitly so the underlying parse position stays visible.
        raise ValueError(f"解析文件 '{file_path}' 时出错: {e}") from e

    if not isinstance(data, list):
        raise TypeError("JSON 文件应包含一个对象列表。")

    print(f"成功从 '{file_path}' 加载 {len(data)} 条数据。")
    return data

def build_prompt(item):
    """Build the prompt text for one word-problem record.

    Args:
        item: Mapping providing the "Body" and "Question" entries.

    Returns:
        str: Two instruction lines, a "Problem:" section containing
        "<Body>. <Question>", and a trailing "Final Answer:" cue.
    """
    problem_statement = f'{item["Body"]}. {item["Question"]}'
    prompt_lines = (
        "Solve the following math word problem.",
        "Your final answer must be a single number.",
        "",
        "Problem:",
        problem_statement,
        "",
        "Final Answer:",
    )
    return "\n".join(prompt_lines)

# An optionally negative number: either digits grouped with comma thousands
# separators ("1,250", "2,000,000") or a plain digit run, followed by an
# optional fractional part. The (?!\d) guard keeps malformed groupings such
# as "1,2345" from being read as one number.
_NUMBER_PATTERN = re.compile(r'-?(?:\d{1,3}(?:,\d{3})+(?!\d)|\d+)(?:\.\d+)?')


def parse_answer(model_output):
    """Extract the last number mentioned in the model output.

    Thousands separators are understood, so "1,250" is read as 1250.0
    rather than as the two numbers 1 and 250.

    Args:
        model_output: Generated text to scan.

    Returns:
        float | None: The last number found in the text, or ``None`` when
        the text contains no number.
    """
    matches = _NUMBER_PATTERN.findall(model_output)
    if not matches:
        return None
    # Every match is digits, commas, an optional sign and an optional
    # fraction, so it is always a valid float once the commas are dropped.
    return float(matches[-1].replace(',', ''))

def calculate_ground_truth(equation_str):
    """Evaluate a dataset equation and return its value as a float.

    Args:
        equation_str: Python expression text, e.g. "( 8.0 - 3.0 )".

    Returns:
        float | None: The evaluated result, or ``None`` (after printing a
        warning) when the expression cannot be evaluated or its result
        cannot be converted to a float.
    """
    # NOTE(security): eval() executes arbitrary Python taken from the dataset
    # file. Only run this against datasets from a trusted source.
    try:
        value = float(eval(equation_str))
    except Exception as err:
        # Best effort: any failure just marks the sample as unusable.
        print(f"\n[警告] 无法计算方程: '{equation_str}'. 错误: {err}. 将跳过此样本。")
        value = None
    return value

def main(args):
    """Evaluate a causal language model on a math word-problem dataset.

    Loads the dataset and the model, generates answers in batches with
    greedy decoding, compares the last number in each generation against
    the value of the sample's "Equation", prints summary statistics and
    optionally writes per-sample results as JSON Lines.

    Args:
        args: Parsed command-line options. The attributes read here are
            ``dataset_path``, ``model_path``, ``batch_size`` and
            ``output_file`` (results are only saved when it is non-empty).
    """
    # --- 1. Load the dataset ---
    dataset = load_dataset_from_file(args.dataset_path)

    # --- 2. Load the model and tokenizer ---
    print(f"正在加载模型: {args.model_path}")
    model = AutoModelForCausalLM.from_pretrained(
        args.model_path,
        torch_dtype="auto",
        device_map="auto"
    )
    tokenizer = AutoTokenizer.from_pretrained(args.model_path)
    # Batched generation with a decoder-only model needs LEFT padding: with
    # right padding the pad tokens sit between the prompt and the generated
    # continuation, which corrupts the output for shorter prompts and breaks
    # the "drop the first input_len columns" slicing used below.
    tokenizer.padding_side = "left"
    # Some tokenizers ship without a pad token; padding=True would then fail.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    print("模型加载完成。")

    # --- 3. Run the batched evaluation ---
    results = []
    correct_count = 0
    total_output_length = 0
    skipped_count = 0

    print(f"\n开始评测... (共 {len(dataset)} 条数据, 批大小: {args.batch_size})")

    for i in tqdm(range(0, len(dataset), args.batch_size)):
        batch_data = dataset[i:i+args.batch_size]
        prompts = [build_prompt(item) for item in batch_data]

        # One single-turn conversation per prompt, rendered to text with the
        # model's chat template.
        messages = [[{"role": "user", "content": p}] for p in prompts]
        texts = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

        inputs = tokenizer(texts, return_tensors="pt", padding=True).to(model.device)
        outputs = model.generate(
            **inputs,
            max_new_tokens=2048,
            do_sample=False,
            pad_token_id=tokenizer.pad_token_id
        )

        # With left padding every row's prompt ends at the same column, so
        # everything after the padded input length is newly generated text.
        output_tokens = outputs[:, inputs.input_ids.shape[1]:]
        generated_texts = tokenizer.batch_decode(output_tokens, skip_special_tokens=True)

        for item, gen_text in zip(batch_data, generated_texts):
            ground_truth_answer = calculate_ground_truth(item["Equation"])

            if ground_truth_answer is None:
                # The reference value is unavailable: exclude the sample from
                # the accuracy and length statistics.
                skipped_count += 1
                result_item = { "Is_Correct": "Skipped", "Calculated_Ground_Truth": "N/A (Skipped)" }
            else:
                total_output_length += len(gen_text)
                print("len(gen_text):",len(gen_text))
                parsed_model_answer = parse_answer(gen_text)
                is_correct = False
                # Compare with a small absolute tolerance to absorb float noise.
                if parsed_model_answer is not None and abs(parsed_model_answer - ground_truth_answer) < 1e-6:
                    correct_count += 1
                    is_correct = True
                result_item = {
                    "Is_Correct": is_correct,
                    "Calculated_Ground_Truth": ground_truth_answer,
                    "Parsed_Answer": parsed_model_answer,
                }

            # Store the full per-sample record.
            results.append({
                "ID": item["ID"],
                "Question": item["Question"],
                "Equation": item["Equation"],
                "Model_Output_Raw": gen_text.strip(),
                **result_item
            })

    # --- 4. Compute and print the final statistics ---
    total_items = len(dataset)
    total_evaluable_items = total_items - skipped_count

    # Guard against division by zero when every sample was skipped.
    accuracy = (correct_count / total_evaluable_items) * 100 if total_evaluable_items > 0 else 0.0
    avg_output_length = total_output_length / total_evaluable_items if total_evaluable_items > 0 else 0.0

    print("\n--- 评测完成 ---")
    print(f"总样本数: {total_items}")
    print(f"跳过 (方程无法计算) 的样本数: {skipped_count}")
    print(f"有效评测样本数: {total_evaluable_items}")
    print(f"正确数量: {correct_count}")
    print(f"✅ 正确率 (基于有效样本): {accuracy:.2f}%")
    print(f"✍️ 平均输出字符长度 (基于有效样本): {avg_output_length:.2f} chars")

    # --- 5. Save the results to a file ---
    if args.output_file:
        try:
            with open(args.output_file, 'w', encoding='utf-8') as f:
                for res in results:
                    f.write(json.dumps(res) + '\n')
            print(f"\n📄 详细结果已保存到: {args.output_file}")
        except IOError as e:
            # A failed save must not discard the summary already printed.
            print(f"\n[错误] 无法写入输出文件: {e}")

if __name__ == "__main__":
    # Command-line entry point: parse the options and run the evaluation.
    parser = argparse.ArgumentParser(description="使用本地大模型批量评测SVAMP数据集。")
    parser.add_argument("--dataset_path", type=str, default="",
                        help="指向SVAMP数据集文件的路径 (.json or .jsonl)。")
    parser.add_argument("--model_path", type=str, default="",
                        help="本地模型路径或Hugging Face模型名称。请务必修改此默认值！")
    parser.add_argument("--batch_size", type=int, default=16,
                        help="推理时使用的批处理大小。")
    parser.add_argument("--output_file", type=str, default="",
                        help="用于保存详细评测结果的输出文件路径 (JSONL格式)。")

    cli_args = parser.parse_args()

    # Warn when no real model path was given. The argparse default is the
    # empty string, so it has to be checked alongside the documentation
    # placeholder; otherwise the warning can never fire for the default.
    # The warning is advisory only: execution continues afterwards.
    if cli_args.model_path in ("", "/path/to/your/Qwen3-8B-model"):
        print("🚨 警告: 您正在使用默认的模型路径。")
        print("请使用 --model_path 参数指定您本地的Qwen3-8B模型路径。")

    main(cli_args)