"""
旅行商问题 (TSP) 评估器 (可反馈错误模式)。

此评估器在“全有或全无”的基础上进行了增强。当评估失败时，
它不仅返回一个极低的分数，还会返回一个包含具体失败原因的字符串。
这个错误信息可以被上层框架捕获，并用于指导大语言模型进行有针对性的代码修复。

评分逻辑:
- 适应度 = -误差。
- 严格评估: 必须成功解决所有问题实例。
- 错误反馈: 任何失败都会中止评估，并以结构化的方式返回失败的具体原因。
"""

import numpy as np
import time
import os
import subprocess
import tempfile
import traceback
import sys
import pickle
import re

# --- 问题与评估器配置 ---
PROBLEM_FOLDER = 'all'
TIME_BUDGET_PER_PROBLEM = 600.0
FAILURE_SCORE = -1e9


class TimeoutError(Exception):
    """Signal that an operation ran past its time limit.

    Within this module the name shadows the built-in ``TimeoutError``; this
    class derives directly from ``Exception`` (not from ``OSError``).
    """

# --- 核心功能函数 (无变动) ---

def validate_tsp_tour(tour, num_cities):
    """Check that *tour* is a permutation of the indices ``0 .. num_cities - 1``.

    Args:
        tour: Candidate tour; must be a ``list`` or ``numpy.ndarray``.
        num_cities: Number of cities the tour has to cover.

    Returns:
        A ``(is_valid, message)`` tuple. *message* describes the first check
        that failed, or states that the tour is valid.
    """
    if not isinstance(tour, (list, np.ndarray)):
        return False, f"验证错误: 路径不是列表或numpy数组 (实际类型: {type(tour)})。"
    if len(tour) != num_cities:
        return False, f"验证错误: 路径包含 {len(tour)} 个城市，应为 {num_cities} 个。"

    # Uniqueness alone is not sufficient: out-of-range entries, negative
    # entries (which numpy indexing would wrap around to another city) and
    # non-integer entries must be rejected too, so the entries are compared
    # against the exact expected index set. Booleans are excluded because
    # they would be interpreted as masks, not indices, by numpy.
    entries_are_indices = all(
        isinstance(city, (int, np.integer)) and not isinstance(city, bool)
        for city in tour
    )
    if not entries_are_indices or set(tour) != set(range(num_cities)):
        return False, "验证错误: 路径未包含所有必需的城市或包含重复城市。"
    return True, "路径有效"


def calculate_tour_distance(tour, dist_matrix):
    """Return the length of the closed tour, including the edge back to the start.

    Each city is paired with its successor (the last city with the first) and
    the corresponding ``dist_matrix[from, to]`` entries are summed in order.
    """
    cities = list(tour)
    # Rotate by one position so that the final city is paired with the first.
    successors = cities[1:] + cities[:1]
    return sum(dist_matrix[origin, target] for origin, target in zip(cities, successors))


def _build_solver_driver(program_path, matrix_file_path, results_path):
    """Return the source of the driver script executed in the child process.

    Every path is embedded through ``repr()`` so that quotes and backslashes
    (for example in Windows paths) always produce a valid string literal in
    the generated code.
    """
    program_literal = repr(str(program_path))
    lines = [
        "import sys, os, pickle, traceback",
        "import importlib.util",
        f"sys.path.insert(0, os.path.dirname({program_literal}))",
        "results = {}",
        "try:",
        f"    spec = importlib.util.spec_from_file_location('program', {program_literal})",
        "    program = importlib.util.module_from_spec(spec)",
        "    spec.loader.exec_module(program)",
        f"    with open({matrix_file_path!r}, 'rb') as f:",
        "        dist_matrix_data = pickle.load(f)",
        "    tour, reported_distance = program.solve_tsp_approximate(dist_matrix_data)",
        "    results['tour'] = tour",
        "    results['reported_distance'] = reported_distance",
        "except Exception as e:",
        "    results['error'] = str(e)",
        "    results['trace'] = traceback.format_exc()",
        "finally:",
        f"    with open({results_path!r}, 'wb') as f:",
        "        pickle.dump(results, f)",
        "",
    ]
    return "\n".join(lines)


def run_solver_with_timeout(program_path, dist_matrix, timeout_seconds=600):
    """Run ``solve_tsp_approximate`` from *program_path* in a child interpreter.

    The distance matrix is pickled to a temporary file. A generated driver
    script imports the candidate program, calls
    ``program.solve_tsp_approximate(dist_matrix)`` and pickles either the
    result or the captured exception to a results file, which is read back
    here. All temporary files are removed before returning or raising.

    Args:
        program_path: Path of the Python file defining ``solve_tsp_approximate``.
        dist_matrix: Picklable distance matrix handed to the solver.
        timeout_seconds: Time limit passed to ``Popen.communicate``.

    Returns:
        The ``(tour, reported_distance)`` pair produced by the solver.

    Raises:
        TimeoutError: The child did not finish within *timeout_seconds*.
        RuntimeError: The child exited with a non-zero code, reported an
            exception, or left no usable results file.
    """
    matrix_file_path = results_path = script_file_path = None
    try:
        with tempfile.NamedTemporaryFile(suffix=".pkl", delete=False) as matrix_file:
            matrix_file_path = matrix_file.name
            results_path = f"{matrix_file_path}.results"
            pickle.dump(dist_matrix, matrix_file)

        with tempfile.NamedTemporaryFile(suffix=".py", delete=False, mode='w', encoding='utf-8') as script_file:
            script_file_path = script_file.name
            script_file.write(_build_solver_driver(program_path, matrix_file_path, results_path))

        process = subprocess.Popen([sys.executable, script_file_path], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        try:
            stdout, stderr = process.communicate(timeout=timeout_seconds)
        except subprocess.TimeoutExpired:
            process.kill()
            process.wait()
            raise TimeoutError(f"进程在 {timeout_seconds} 秒后超时")

        # Solver output is arbitrary bytes; never let a decoding problem hide
        # the real failure.
        stdout_text = stdout.decode(errors="replace")
        stderr_text = stderr.decode(errors="replace")

        if process.returncode != 0:
            raise RuntimeError(f"子进程以代码 {process.returncode} 退出\n{stderr_text}")
        if not os.path.exists(results_path):
            raise RuntimeError(f"未找到结果文件。stdout: {stdout_text}, stderr: {stderr_text}")

        # NOTE(security): the results file is written by the process that ran
        # the candidate program, and unpickling it executes in this process.
        # Only evaluate programs you are willing to run with these privileges.
        with open(results_path, "rb") as f:
            results = pickle.load(f)

        if "error" in results:
            raise RuntimeError(f"程序执行失败: {results['error']}\n{results['trace']}")
        if "tour" not in results or "reported_distance" not in results:
            # Happens when the solver leaves the driver's try block without
            # returning and without raising an Exception (e.g. sys.exit(0)).
            raise RuntimeError("结果文件缺少 'tour' 或 'reported_distance' (求解器可能提前退出)。")
        return results["tour"], results["reported_distance"]
    finally:
        for p in (matrix_file_path, results_path, script_file_path):
            if p is not None and os.path.exists(p):
                os.unlink(p)


# --- 主评估函数 (已修改返回值) ---

def _load_tsp_problem(file_path):
    """Parse one problem file into ``(optimal_distance, dist_matrix)``.

    The matrix is read with ``np.loadtxt``; the line starting with
    ``Total Distance:`` is skipped as a comment there and supplies the value
    that is used as the optimal distance.

    Raises:
        ValueError: The ``Total Distance:`` value is missing, malformed or
            not strictly positive.
    """
    with open(file_path, encoding='utf-8') as handle:
        content = handle.read()
    match = re.search(r"Total Distance:\s*([\d.]+)", content)
    if match is None:
        raise ValueError("问题文件中缺少 'Total Distance:' 行。")
    optimal_distance = float(match.group(1))
    if optimal_distance <= 0:
        # The relative error divides by this value.
        raise ValueError(f"参考最优距离必须为正数 (实际值: {optimal_distance})。")
    dist_matrix = np.loadtxt(file_path, comments="Total Distance:", encoding='utf-8')
    return optimal_distance, dist_matrix


def evaluate(program_path):
    """Evaluate a TSP solver program in strict all-or-nothing mode.

    Every ``.txt`` file in ``PROBLEM_FOLDER`` is solved in turn. The first
    failure (unreadable problem, solver error or timeout, invalid tour) aborts
    the evaluation. A missing, unreadable or empty problem folder is also a
    failure, so a solver can never score by being evaluated on nothing.

    Args:
        program_path: Path of the solver file handed to
            ``run_solver_with_timeout``.

    Returns:
        A dict with the keys ``average_fitness``, ``validity_ratio``,
        ``total_eval_time`` and ``error_reason``. On failure
        ``average_fitness`` is ``FAILURE_SCORE``, ``validity_ratio`` is 0.0
        and ``error_reason`` holds the cause; on success ``error_reason`` is
        ``None`` and ``average_fitness`` is the mean of ``-relative_error``.
    """
    eval_start_time = time.time()
    all_fitness_scores = []
    failure_reason = None  # set as soon as anything goes wrong

    problem_files_path = os.path.abspath(PROBLEM_FOLDER)
    try:
        problem_files = sorted(f for f in os.listdir(problem_files_path) if f.endswith(".txt"))
    except OSError as e:
        problem_files = []
        failure_reason = f"无法读取问题目录 {problem_files_path}: {e}"
    else:
        if not problem_files:
            failure_reason = f"在 {problem_files_path} 中未找到任何 .txt 问题文件。"
    if failure_reason is not None:
        print(f"❌ {failure_reason}")
    total_problems = len(problem_files)

    for filename in problem_files:
        try:
            file_path = os.path.join(problem_files_path, filename)
            optimal_distance, dist_matrix = _load_tsp_problem(file_path)

            tour, _ = run_solver_with_timeout(program_path, dist_matrix, timeout_seconds=TIME_BUDGET_PER_PROBLEM)
            is_valid, reason = validate_tsp_tour(tour, dist_matrix.shape[0])

            if not is_valid:
                failure_reason = f"对于文件 {filename} 的解无效: {reason}"
                print(f"❌ {failure_reason}")
                break

            # The distance is recomputed here; the solver's own reported
            # distance is deliberately not used for scoring.
            actual_distance = calculate_tour_distance(tour, dist_matrix)
            error = (actual_distance - optimal_distance) / optimal_distance
            fitness_score = -error
            all_fitness_scores.append(fitness_score)
            print(f"✅ 已处理 {filename}: 误差={error:+.4%}, 适应度={fitness_score:.4f}")

        except Exception as e:
            # Top-level boundary: any problem with this instance becomes the
            # structured failure reason instead of crashing the evaluator.
            failure_reason = f"处理 {filename} 时评估失败: {e}"
            print(f"❌ {failure_reason}")
            break

    eval_time = time.time() - eval_start_time

    if failure_reason is not None:
        print("\n评估总结: 未能成功解决所有问题。评测中止。")
        return {
            "average_fitness": FAILURE_SCORE,
            "validity_ratio": 0.0,
            "total_eval_time": float(eval_time),
            "error_reason": failure_reason
        }

    average_fitness = np.mean(all_fitness_scores)
    print(f"\n评估总结: 成功解决所有 {total_problems} 个问题。")
    # error_reason is present (as None) so both outcomes share one structure.
    return {
        "average_fitness": float(average_fitness),
        "validity_ratio": 1.0,
        "total_eval_time": float(eval_time),
        "error_reason": None
    }


# --- 独立运行入口 (已修改以展示新返回值) ---

if __name__ == '__main__':
    # Stand-alone entry point: evaluate the hard-coded solution file and
    # print the resulting metrics.
    solution_file_path = 'evaluator.py'
    if os.path.exists(solution_file_path):
        print(f"--- 正在评估 {solution_file_path} (可反馈错误模式) ---")
        metrics = evaluate(solution_file_path)
        print("\n--- 评估指标 ---")
        for key, value in metrics.items():
            # error_reason is only worth showing when it carries a message.
            if key == "error_reason" and value is None:
                continue
            print(f"  - {key}: {value}")
        print("--------------------")
    else:
        print(f"错误: 未在 '{solution_file_path}' 找到解决方案文件。")