import os
import json
import time
import random
import argparse
import numpy as np
import traceback
from tqdm import tqdm

import sys
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from utils.llm_service import client
from typing import List, Dict, Any, Tuple


# Build the client object at import time. NOTE: this rebinds the module-level
# name `client`, so the `client` callable imported from utils.llm_service is no
# longer reachable under that name afterwards.
# NOTE(review): the resulting object is not referenced anywhere else in the
# visible part of this file -- confirm whether this instantiation is still needed.
client = client()

def load_gsm8k(file_path: str) -> List[Dict[str, Any]]:
    """Load a GSM8K dataset stored in JSON Lines format.

    Args:
        file_path: Path to a UTF-8 text file holding one JSON document per line.

    Returns:
        The decoded records in file order. Lines that are empty after
        stripping surrounding whitespace are skipped.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    with open(file_path, 'r', encoding='utf-8') as handle:
        # Strip first so whitespace-only lines are dropped by the filter below.
        stripped_rows = (raw_line.strip() for raw_line in handle)
        return [json.loads(row) for row in stripped_rows if row]

def extract_answer(response_text: str) -> str:
    """Extract the final numeric answer from a model response.

    Lines are scanned from last to first. The first line (from the end) that
    contains an answer marker ("答案是", "answer is" in any letter case, or "=")
    and at least one number wins, and the last number on that line is
    returned. If no marker line yields a number, the last number anywhere in
    the text is returned.

    Numbers are located with a regular expression rather than by splitting on
    whitespace, so values glued to other characters ("$18.", "18%",
    "答案是18。", "**18**") are still recognised and words such as "nan" or
    "inf" are never mistaken for numbers.

    Args:
        response_text: Raw text produced by the model.

    Returns:
        The matched number exactly as written (thousands separators are
        kept, e.g. "1,234"), or "" when the text contains no number.
    """
    import re

    # Optional sign, digits with optional comma groups, optional decimals.
    # The lookbehind keeps the hyphen in "10-18" from being read as a minus.
    number_pattern = r'(?<!\d)-?\d+(?:,\d+)*(?:\.\d+)?'

    for line in reversed(response_text.strip().split('\n')):
        has_marker = "答案是" in line or "answer is" in line.lower() or "=" in line
        if not has_marker:
            continue
        line_numbers = re.findall(number_pattern, line)
        if line_numbers:
            return line_numbers[-1]

    # No marker line contained a number: fall back to the last number overall.
    all_numbers = re.findall(number_pattern, response_text)
    return all_numbers[-1] if all_numbers else ""

def evaluate_problem(agentflow, problem: Dict[str, Any], model: str, temperature: float = 0, max_retries: int = 2) -> Tuple[bool, Any, str]:
    """Run the agent flow on a single GSM8K problem and score its answer.

    Args:
        agentflow: Object exposing ``run(prompt)``. If the returned value is a
            dict with an ``"answer"`` key, that entry is used as the model
            answer; otherwise the answer is extracted from ``str(response)``
            with ``extract_answer``.
        problem: Record with ``"question"`` and ``"answer"`` keys. The
            reference answer is the text following the ``####`` marker.
        model: Not used by this function; kept for interface compatibility.
        temperature: Not used by this function; kept for interface
            compatibility.
        max_retries: Total number of attempts made before giving up.

    Returns:
        A tuple ``(is_correct, response, reference_answer)``. ``response`` is
        the raw value returned by ``agentflow.run``, or ``None`` when every
        attempt raised an exception (``is_correct`` is then ``False``).
    """
    question = problem["question"]

    # Reference answers look like "<reasoning> #### <final answer>". When the
    # marker is missing, use the whole text instead of raising IndexError.
    answer_parts = problem["answer"].split("####")
    answer = (answer_parts[1] if len(answer_parts) > 1 else answer_parts[0]).strip()

    user_prompt = f"解决下面的数学问题:\n\n{question}"

    for attempt in range(1, max_retries + 1):
        try:
            response = agentflow.run(user_prompt)

            # Prefer a structured answer when the flow returns one.
            if isinstance(response, dict) and "answer" in response:
                model_answer = response["answer"]
                if not isinstance(model_answer, str):
                    model_answer = str(model_answer)
            else:
                model_answer = extract_answer(str(response))

            # Normalise both sides: drop thousands separators and padding.
            clean_model_answer = model_answer.replace(",", "").strip()
            clean_correct_answer = answer.replace(",", "").strip()

            try:
                # Numeric answers are compared with a small tolerance.
                model_value = float(clean_model_answer)
                correct_value = float(clean_correct_answer)
                is_correct = abs(model_value - correct_value) < 1e-6
            except ValueError:
                # Non-numeric answers fall back to exact string comparison.
                is_correct = clean_model_answer == clean_correct_answer

            return is_correct, response, answer

        except Exception as e:
            # Broad on purpose: any failure of the flow counts as one attempt.
            print(f"错误: {e}. 重试 ({attempt}/{max_retries})...")
            print(traceback.format_exc())

            # Pause for one second between attempts; there is nothing to wait
            # for once the last attempt has failed.
            if attempt < max_retries:
                time.sleep(1)

    # Every attempt failed.
    return False, None, answer

def _write_gsm8k_summary(output_file, model, accuracy, correct_count, total_count):
    """Overwrite ``output_file`` with a JSON summary of the evaluation so far."""
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump({
            "model": model,
            "accuracy": accuracy,
            "correct_count": correct_count,
            "total_count": total_count
        }, f, ensure_ascii=False, indent=2)


def Eval(agentflow, model="gpt-4o-mini", temperature=0.0, output_file="results/gsm8k_results.json", data_path="Benchmark/gsm8k-test.jsonl", sample_size=None, mode="train"):
    """Evaluate ``agentflow`` on GSM8K and write an accuracy summary to disk.

    The keyword arguments only provide defaults: the function also parses the
    process command line (``sys.argv``) with argparse, so ``--data_path``,
    ``--model``, ``--temperature``, ``--output_file`` and ``--sample_size``
    given there take precedence.

    Args:
        agentflow: Flow object handed to ``evaluate_problem`` for each problem.
        model: Model name recorded in the summary file.
        temperature: Forwarded to ``evaluate_problem``.
        output_file: Path of the JSON summary. It is rewritten after every
            10 problems and once more at the end.
        data_path: Path of the GSM8K JSON Lines file.
        sample_size: If truthy, evaluate a random sample of this many
            problems (seed fixed to 42); otherwise evaluate all of them.
        mode: Currently unused. It belonged to a disabled "re-evaluate only
            the previously wrong samples" mode.

    Returns:
        The final accuracy, ``correct_count / len(data)``, or ``0.0`` when the
        dataset is empty. If a problem fails on every retry the loop stops
        early and the problems not evaluated count as incorrect.
    """
    parser = argparse.ArgumentParser(description="使用OpenAI API评估GSM8K数据集")
    parser.add_argument("--data_path", type=str, default=data_path, help="GSM8K数据集路径")
    parser.add_argument("--model", type=str, default=model, help="OpenAI模型名称")
    parser.add_argument("--temperature", type=float, default=temperature, help="温度参数")
    parser.add_argument("--output_file", type=str, default=output_file, help="结果输出文件")
    parser.add_argument("--sample_size", type=int, default=sample_size, help="评估样本数量，为None时评估全部")
    args = parser.parse_args()

    data = load_gsm8k(args.data_path)

    if args.sample_size:
        random.seed(42)  # fixed seed so the sample is reproducible
        data = random.sample(data, min(args.sample_size, len(data)))

    # Create the output directory up front so that the first checkpoint write
    # cannot fail after problems have already been evaluated.
    output_dir = os.path.dirname(args.output_file)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # NOTE(review): per-problem records are collected here but never written
    # out; only the aggregate summary is persisted.
    results = []
    correct_count = 0

    for i, problem in enumerate(tqdm(data)):
        is_correct, model_output, correct_answer = evaluate_problem(
            agentflow, problem, args.model, args.temperature)

        if is_correct:
            correct_count += 1
        elif model_output is None:
            # Every retry failed for this problem: abort the whole run.
            break

        results.append({
            "question": problem["question"],
            "model_output": model_output,
            "correct_answer": correct_answer,
            "is_correct": is_correct
        })

        # Report and checkpoint every 10 problems.
        if (i + 1) % 10 == 0:
            accuracy = correct_count / (i + 1)
            print(f"当前进度: {i+1}/{len(data)}, 准确率: {accuracy:.4f}")
            _write_gsm8k_summary(
                args.output_file, args.model, accuracy, correct_count, i + 1)

    # Guard against an empty dataset instead of raising ZeroDivisionError.
    final_accuracy = correct_count / len(data) if data else 0.0
    print(f"最终准确率: {final_accuracy:.4f} ({correct_count}/{len(data)})")

    _write_gsm8k_summary(
        args.output_file, args.model, final_accuracy, correct_count, len(data))

    return final_accuracy