import os
import json
import time
import argparse
import re
import numpy as np
from tqdm import tqdm
from openai import OpenAI

import sys
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from utils.llm_service import client
from typing import List, Dict, Any, Tuple


# Set up the OpenAI/LLM client.
# NOTE: this rebinds the name `client` (imported above from utils.llm_service
# and called here) to the object that call returns, so after this line the
# imported callable is no longer reachable under the name `client`.
client = client()

def load_math_problems(directory_path: str, topic: str = None, difficulty: str = None, sample_size: int = None,
                       wrong_list_path: str = None) -> List[Dict[str, Any]]:
    """Load the MATH problems that appear in a wrong-answer list.

    Only problems listed under the "wrong_questions" key of a JSON file are
    loaded (by default ``math_wrong.json`` next to this module).

    Args:
        directory_path: Root of the MATH split; each topic is a sub-directory
            containing one JSON file per problem.
        topic: Load only this topic directory. When None, load whichever of the
            four target topics exist under ``directory_path``.
        difficulty: Keep only problems whose "level" equals this value
            (e.g. "Level 5"). None keeps every level.
        sample_size: If given and smaller than the number of loaded problems,
            return a reproducible random subset (seed 42) of this size.
        wrong_list_path: Path of the wrong-question JSON file. Defaults to
            ``math_wrong.json`` in this module's directory.

    Returns:
        List of problem dicts, each augmented with "file_name" and "topic".

    Raises:
        ValueError: If ``topic`` is given but its directory does not exist.
    """
    if topic:
        topic_path = os.path.join(directory_path, topic)
        if not os.path.exists(topic_path):
            raise ValueError(f"主题路径不存在: {topic_path}")
        topics = [topic]
    else:
        # Per the paper's setup, only these four topics are evaluated by default.
        # Keys are directory names; values are display names (not used here).
        target_topics = {
            "counting_and_probability": "Combinatorics & Probability",
            "number_theory": "Number Theory",
            "prealgebra": "Pre-algebra",
            "precalculus": "Pre-calculus"
        }
        topics = [t for t in target_topics if os.path.isdir(os.path.join(directory_path, t))]

    print(f"将评估以下主题: {topics}")

    # Only problems listed in the wrong-answer file are evaluated.
    if wrong_list_path is None:
        wrong_list_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "math_wrong.json")
    with open(wrong_list_path, 'r', encoding='utf-8') as f:
        wrong_entries = json.load(f)["wrong_questions"]

    # MATH file names are only unique within a topic, so entries that record
    # their topic (as Eval() writes them) are matched on (topic, file); entries
    # without a topic fall back to matching the file name in any topic.
    # Sets give O(1) membership tests inside the file loop.
    wrong_by_topic = {(entry["topic"], entry["file"]) for entry in wrong_entries if entry.get("topic")}
    wrong_any_topic = {entry["file"] for entry in wrong_entries if not entry.get("topic")}

    data = []
    for topic_name in topics:
        topic_path = os.path.join(directory_path, topic_name)
        # Sorted so the load order (and therefore the seeded sample below)
        # does not depend on the filesystem's os.listdir order.
        problem_files = sorted(f for f in os.listdir(topic_path) if f.endswith('.json'))

        for file_name in problem_files:
            if (topic_name, file_name) not in wrong_by_topic and file_name not in wrong_any_topic:
                continue
            file_path = os.path.join(topic_path, file_name)
            try:
                with open(file_path, 'r', encoding='utf-8') as f:
                    problem_data = json.load(f)

                # Keep only the requested difficulty level.
                if difficulty and problem_data.get('level') != difficulty:
                    continue
                problem_data['file_name'] = file_name
                problem_data['topic'] = topic_name
                data.append(problem_data)
            except Exception as e:
                # Best effort: report the unreadable file and keep loading the rest.
                print(f"加载文件时出错 {file_path}: {e}")

    print(f"加载了 {len(data)} 个难度为 {difficulty} 的问题")

    if sample_size and sample_size < len(data):
        import random
        # A private generator seeded with 42 yields the same sample as
        # random.seed(42) + random.sample, without reseeding the global RNG.
        data = random.Random(42).sample(data, sample_size)
        print(f"随机选择了 {sample_size} 个问题进行评估")

    return data

def _last_boxed_content(text: str) -> str:
    r"""Return the contents of the last ``\boxed{...}`` (or ``\fbox{...}``) in *text*.

    Braces are matched by depth counting, so nested LaTeX such as
    ``\frac{\sqrt{3}}{2}`` or ``\sqrt{2}+1`` is returned whole. The last
    occurrence is used, following the MATH convention that the final answer
    is the last boxed expression. Returns "" when *text* contains no complete
    boxed group.
    """
    for marker in ("\\boxed{", "\\fbox{"):
        start = text.rfind(marker)
        if start == -1:
            continue
        content_start = start + len(marker)
        depth = 0
        for idx in range(content_start, len(text)):
            char = text[idx]
            if char == "{":
                depth += 1
            elif char == "}":
                if depth == 0:
                    return text[content_start:idx]
                depth -= 1
        # Unbalanced (e.g. truncated) group: try the next marker.
    return ""


def extract_boxed_answer(solution_text: str) -> str:
    r"""Extract the reference answer from a MATH solution text.

    Returns the contents of the last ``\boxed{...}``. If there is none, returns
    the last (optionally signed, optionally decimal) number in the text, and
    "" if there is no number either.
    """
    boxed = _last_boxed_content(solution_text)
    if boxed:
        return boxed

    numbers = re.findall(r'-?\d+(?:\.\d+)?', solution_text)
    return numbers[-1] if numbers else ""


def extract_model_answer(response_text: str) -> str:
    r"""Extract the final answer from a free-form model response.

    Strategies, in order:
      1. the contents of the last ``\boxed{...}``;
      2. scanning lines from the bottom, the last number-like token on a line
         containing "答案是", "answer is" (case-insensitive) or "=";
      3. the last number in the whole response (thousands separators allowed).

    Returns "" for None/empty input or when no strategy finds anything.
    """
    # Eval() may pass None when the agent failed on every retry.
    if not response_text:
        return ""

    boxed = _last_boxed_content(response_text)
    if boxed:
        return boxed

    for line in reversed(response_text.strip().split('\n')):
        if "答案是" in line or "answer is" in line.lower() or "=" in line:
            for part in reversed(line.split()):
                # Drop surrounding punctuation and inline-math "$" delimiters.
                part = part.strip(",.。:;()[]$")
                try:
                    float(part.replace(",", ""))
                except ValueError:
                    continue
                # Return the token as written (thousands separators kept).
                return part

    numbers = re.findall(r'-?\d+(?:,\d+)*(?:\.\d+)?', response_text)
    return numbers[-1] if numbers else ""

def _normalize_answer(answer: str) -> str:
    r"""Canonicalize an answer string before exact comparison.

    Removes thousands separators, "$" delimiters, ``\left``/``\right`` sizing
    commands, all whitespace and trailing periods, and rewrites ``\dfrac`` /
    ``\tfrac`` as ``\frac``, so purely presentational differences between the
    model's answer and the reference do not count as mismatches.
    """
    text = answer.replace(",", "").replace("$", "")
    text = text.replace("\\left", "").replace("\\right", "")
    text = text.replace("\\dfrac", "\\frac").replace("\\tfrac", "\\frac")
    text = "".join(text.split())
    return text.rstrip(".")


def evaluate_problem(agentflow,problem: Dict[str, Any], model: str, temperature: float = 0, max_retries: int = 2) -> Tuple[bool, str, str]:
    """Ask *agentflow* to solve one MATH problem and grade its answer.

    Args:
        agentflow: Object whose ``run(prompt)`` returns either a dict with an
            "answer" key or any value whose ``str()`` contains the answer.
        problem: MATH problem dict with at least "problem" and "solution" keys.
        model: Unused here; kept for interface compatibility.
        temperature: Unused here; kept for interface compatibility.
        max_retries: Total number of ``agentflow.run`` attempts before giving up.

    Returns:
        ``(is_correct, model_answer, correct_answer)``. ``model_answer`` is None
        when every attempt raised; an empty extracted answer is never correct.
    """
    question = problem["problem"]
    solution = problem["solution"]
    correct_answer = extract_boxed_answer(solution)

    user_prompt = f"解决下面的数学问题:\n\n{question}, 如果答案是一个表达式或公式，可以使用LaTeX格式，例如 -\\frac{{4}}{{3}}"

    # Retry only the agent call; grading below is local and deterministic.
    for attempt in range(1, max_retries + 1):
        try:
            response = agentflow.run(user_prompt)
            break
        except Exception as e:  # the agent is an arbitrary external boundary
            print(f"错误: {e}. 重试 ({attempt}/{max_retries})...")
            if attempt < max_retries:
                time.sleep(2)  # back off before the next attempt only
    else:
        # Every attempt failed (or max_retries < 1).
        return False, None, correct_answer

    if isinstance(response, dict) and "answer" in response:
        raw_answer = response["answer"]
        model_answer = "" if raw_answer is None else str(raw_answer)
    else:
        model_answer = extract_model_answer(str(response))

    clean_model_answer = _normalize_answer(model_answer)
    clean_correct_answer = _normalize_answer(correct_answer)

    if not clean_model_answer:
        # Without this guard, "" == "" would grade a missing answer as correct.
        is_correct = False
    else:
        try:
            # Numeric answers: compare with a small absolute tolerance.
            is_correct = abs(float(clean_model_answer) - float(clean_correct_answer)) < 1e-6
        except ValueError:
            # Non-numeric answers: exact comparison of the normalized strings.
            is_correct = clean_model_answer == clean_correct_answer

    return is_correct, model_answer, correct_answer

def _topic_accuracies(topic_stats: Dict[str, Dict[str, int]]) -> Dict[str, float]:
    """Return the accuracy of every topic that has at least one evaluated problem."""
    return {
        name: stats["correct"] / stats["total"]
        for name, stats in topic_stats.items()
        if stats["total"] > 0
    }


def _write_summary(output_file: str, model: str, accuracy: float, correct_count: int, total_count: int,
                   topic_accuracies: Dict[str, float], wrong_questions: List[Dict[str, Any]]) -> None:
    """Write the evaluation summary as UTF-8 JSON, creating the output directory if needed."""
    output_dir = os.path.dirname(output_file)
    if output_dir:
        # The default output path ("results/...") has a directory part that may not exist yet.
        os.makedirs(output_dir, exist_ok=True)
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump({
            "model": model,
            "overall_accuracy": accuracy,
            "correct_count": correct_count,
            "total_count": total_count,
            "topic_accuracies": topic_accuracies,
            "wrong_questions": wrong_questions
        }, f, ensure_ascii=False, indent=2)


def Eval(agentflow, model="gpt-4o-mini", temperature=0.0, output_file="results/math_results.json", data_path="Benchmark/MATH/test",level="Level 5", sample_size=None, mode="train"):
    """Evaluate *agentflow* on the selected MATH problems and write a JSON summary.

    The function arguments are the defaults of the CLI flags defined below;
    a matching flag on the command line overrides them.

    Args:
        agentflow: Agent passed to evaluate_problem for every problem.
        model: Model name recorded in the summary (``--model``).
        temperature: Forwarded to evaluate_problem (``--temperature``).
        output_file: Summary JSON path (``--output_file``); rewritten every
            10 problems and once at the end.
        data_path: MATH split root (``--data_path``).
        level: Difficulty filter (``--difficulty``).
        sample_size: Optional number of problems to sample (``--sample_size``).
        mode: Unused here; kept for interface compatibility.

    Returns:
        Overall accuracy: correct answers divided by the number of loaded
        problems (0.0 when none were loaded). If the run stops early because
        the agent failed on every retry, the problems not reached count as
        incorrect.
    """
    parser = argparse.ArgumentParser(description="使用OpenAI API评估MATH数据集")
    parser.add_argument("--data_path", type=str, default=data_path, help="MATH数据集路径")
    parser.add_argument("--topic", type=str, default=None, help="要评估的数学主题，不指定则评估指定的四个主题")
    parser.add_argument("--difficulty", type=str, default=level, help="要评估的难度级别，默认为Level 5")
    parser.add_argument("--model", type=str, default=model, help="OpenAI模型名称")
    parser.add_argument("--temperature", type=float, default=temperature, help="温度参数")
    parser.add_argument("--output_file", type=str, default=output_file, help="结果输出文件")
    # Default to the function argument so Eval(sample_size=...) is honoured.
    parser.add_argument("--sample_size", type=int, default=sample_size, help="评估样本数量，为None时评估全部")
    # parse_known_args so unrelated flags of a host script that calls Eval()
    # do not make argparse exit with an error.
    args, _ = parser.parse_known_args()

    print(f"加载数据集: {args.data_path}")
    print(f"难度级别: {args.difficulty}")
    data = load_math_problems(args.data_path, args.topic, args.difficulty, args.sample_size)

    print(f"评估模型: {args.model}, 样本数量: {len(data)}")

    correct_count = 0
    topic_stats = {}  # topic -> {"total": int, "correct": int}
    wrong_questions = []

    for i, problem in enumerate(tqdm(data)):
        topic = problem['topic']
        stats = topic_stats.setdefault(topic, {"total": 0, "correct": 0})
        stats["total"] += 1

        is_correct, model_output, correct_answer = evaluate_problem(
            agentflow, problem, args.model, args.temperature)

        if is_correct:
            correct_count += 1
            stats["correct"] += 1
        else:
            wrong_questions.append({
                "file": problem['file_name'],
                "topic": topic,
                "problem": problem['problem'],
                "correct_answer": correct_answer,
                # model_output is None when evaluate_problem exhausted its
                # retries; extract_model_answer must not be given None.
                "model_answer": None if model_output is None else extract_model_answer(model_output)
            })
            print(f"题目: {problem['problem']}")
            print(f"正确答案: {correct_answer}")
            print(f"模型答案: {model_output}")
            print("-" * 50)

            if model_output is None:
                # The agent failed on every retry: abort the run.
                break

        # Save an intermediate summary every 10 problems.
        if (i + 1) % 10 == 0:
            accuracy = correct_count / (i + 1)
            print(f"当前进度: {i+1}/{len(data)}, 总体准确率: {accuracy:.4f}")
            _write_summary(args.output_file, args.model, accuracy, correct_count, i + 1,
                           _topic_accuracies(topic_stats), wrong_questions)

    total = len(data)
    final_accuracy = correct_count / total if total else 0.0
    print(f"最终准确率: {final_accuracy:.4f} ({correct_count}/{len(data)})")

    topic_accuracies = _topic_accuracies(topic_stats)
    for topic, topic_accuracy in topic_accuracies.items():
        stats = topic_stats[topic]
        print(f"{topic} 准确率: {topic_accuracy:.4f} ({stats['correct']}/{stats['total']})")

    _write_summary(args.output_file, args.model, final_accuracy, correct_count, total,
                   topic_accuracies, wrong_questions)

    return final_accuracy
