#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Adaptive SAT Solver Evaluator
Uses the same interface as evaluate.py, but internally automatically selects the best solver combination based on instance features

Usage:
python adaptive_eval.py --eval_data_dir ./dataset/test --results_save_path ./adaptive_results --model_path ./adaptive_selector_train_new_fixed.pkl
"""

import os
import shutil
import argparse
import subprocess
from datetime import datetime
import time
import re
import yaml
import math
import warnings
import json
import pickle
import numpy as np
from pathlib import Path
from typing import Dict, List, Tuple
import logging
import tempfile
import concurrent.futures
from threading import Lock

from autosat.utils import revise_file, clean_files, collect_results_eval
from autosat.execution.execution_worker import ExecutionWorker
from autosat.feature_extraction.compact_feature_extractor import CompactSATFeatureExtractor


class AdaptiveEvaluator:
    """Adaptive SAT Solver Evaluator"""
    
    def __init__(self, model_path: str, combinations_dir: str = "./combinations"):
        """Create an evaluator backed by a pickled adaptive-selection model.

        Args:
            model_path: Location of the trained model pickle.
            combinations_dir: Directory holding the solver-combination ``.cpp``
                sources and their shared headers.
        """
        self.model_path, self.combinations_dir = Path(model_path), Path(combinations_dir)

        # Fills heuristic_library, clusters, cluster_centroids, optimal_mapping
        # and the feature normalizer attributes.
        self.load_adaptive_model()

        self.feature_extractor = CompactSATFeatureExtractor()

        # solver name -> compiled executable path; mutations go through the lock.
        self.compilation_lock = Lock()
        self.compiled_solvers = {}

        summary_lines = (
            "Adaptive SAT solver evaluator initialized",
            f"  Model: {self.model_path}",
            f"  Combinations: {len(self.heuristic_library)}",
            f"  Clusters: {len(self.cluster_centroids)}",
        )
        for summary_line in summary_lines:
            print(summary_line)
    
    def load_adaptive_model(self):
        """Load trained adaptive selection model"""
        if not self.model_path.exists():
            raise FileNotFoundError(f"Adaptive model does not exist: {self.model_path}")
        
        with open(self.model_path, 'rb') as f:
            model_data = pickle.load(f)
        
        self.heuristic_library = model_data['heuristic_library']
        self.clusters = model_data['clusters']
        self.cluster_centroids = {k: np.array(v) for k, v in model_data['cluster_centroids'].items()}
        self.optimal_mapping = model_data['optimal_mapping']
        
        # Always try to load external normalizer instead of using the one in the model
        self.feature_scaler = None
        try:
            normalizer_path = Path("normalized_features/normalizer_train_new_standard.pkl")
            if normalizer_path.exists():
                with open(normalizer_path, 'rb') as f:
                    normalizer_data = pickle.load(f)
                
                # Check normalizer format
                if isinstance(normalizer_data, dict):
                    if 'scaler' in normalizer_data:
                        # Use internal scaler object
                        self.feature_scaler = normalizer_data['scaler']
                        print("  ✓ Loaded external normalizer (scaler object)")
                    elif 'global_stats' in normalizer_data:
                        # Use statistics from global_stats
                        stats = normalizer_data['global_stats']
                        if 'mean' in stats and 'std' in stats:
                            self.feature_mean = np.array(stats['mean'])
                            self.feature_std = np.array(stats['std'])
                            self.feature_scaler = 'manual'
                            print("  ✓ Loaded external normalizer (manual standardization)")
                        else:
                            print("  ⚠️  Normalizer format not supported, will use original features")
                            self.feature_scaler = None
                    else:
                        print("  ⚠️  Normalizer format not supported, will use original features")
                        self.feature_scaler = None
                else:
                    self.feature_scaler = normalizer_data  # Try to use directly
                    print("  ✓ Loaded external normalizer (direct use)")
            else:
                print("  ⚠️  Normalizer not found, will use original features")
        except Exception as e:
            print(f"  ⚠️  Failed to load normalizer: {e}")
            self.feature_scaler = None
        
        print(f"Loaded adaptive model: {len(self.heuristic_library)} solver combinations")
    
    def predict_best_solver(self, cnf_file_path: str) -> Tuple[str, float]:
        """
        Predict the best solver for a single CNF instance
        
        Args:
            cnf_file_path: CNF file path
            
        Returns:
            (best_solver, confidence): Best solver name and confidence
        """
        # Extract instance features
        try:
            feature_vector = self.feature_extractor.extract_features_to_vector(cnf_file_path)
            
            # Normalize feature vector
            if self.feature_scaler == 'manual':
                # Use manual standardization
                try:
                    feature_array = np.array(feature_vector, dtype=float)
                    normalized_features = (feature_array - self.feature_mean) / (self.feature_std + 1e-8)
                except Exception as e:
                    print(f"  ⚠️  Manual standardization failed, using original features: {e}")
                    normalized_features = np.array(feature_vector, dtype=float)
            elif self.feature_scaler:
                # Use sklearn scaler
                try:
                    normalized_features = self.feature_scaler.transform([feature_vector])[0]
                except Exception as e:
                    print(f"  ⚠️  Standardization failed, using original features: {e}")
                    normalized_features = np.array(feature_vector, dtype=float)
            else:
                # If no scaler, assume cluster centroids are based on original features
                normalized_features = np.array(feature_vector, dtype=float)
            
            # Calculate distances to cluster centroids
            distances = {}
            for solver, centroid in self.cluster_centroids.items():
                dist = np.linalg.norm(normalized_features - centroid)
                distances[solver] = dist
            
            # Sort by distance to generate rankings
            sorted_distances = sorted(distances.items(), key=lambda x: x[1])
            rankings = {solver: rank + 1 for rank, (solver, _) in enumerate(sorted_distances)}
            
            # Select the nearest cluster
            best_solver = min(distances.keys(), key=lambda k: distances[k])
            min_distance = distances[best_solver]
            
            # Calculate confidence (inverse function of distance)
            confidence = 1.0 / (1.0 + min_distance)
            
            # Don't output detailed information, let caller decide output format
            
            return best_solver, confidence
            
        except Exception as e:
            print(f"  ⚠️  Feature extraction failed {cnf_file_path}: {e}")
            # Fallback to default selection
            default_solver = list(self.heuristic_library.keys())[0]
            return default_solver, 0.0
    
    def compile_solver(self, solver_name: str, temp_dir: str, timeout: int, data_dir: str) -> str:
        """Compile the specified solver combination"""
        with self.compilation_lock:
            if solver_name in self.compiled_solvers:
                return self.compiled_solvers[solver_name]
        
        # Get solver cpp file path
        solver_cpp_path = self.combinations_dir / f"{solver_name}.cpp"
        if not solver_cpp_path.exists():
            raise FileNotFoundError(f"Solver combination does not exist: {solver_cpp_path}")
        
        # Create temporary compilation directory
        compile_dir = os.path.join(temp_dir, f"compile_{solver_name}")
        print(f"  Creating compile directory: {compile_dir}")
        os.makedirs(compile_dir, exist_ok=True)
        print(f"  Directory created: {os.path.exists(compile_dir)}")
        
        # Process template placeholders, similar to evaluate.py
        tmp_cpp_path = os.path.join(compile_dir, f"{solver_name}.cpp")
        
        # Use revise_file to process template variables
        from autosat.utils import revise_file
        
        # Create results directory
        results_dir = os.path.join(compile_dir, "results")
        os.makedirs(results_dir, exist_ok=True)
        
        # Use absolute paths to ensure solver can find data directory
        abs_data_dir = os.path.abspath(data_dir)
        abs_results_dir = os.path.abspath(results_dir)
        
        revise_file(
            file_name=str(solver_cpp_path),
            save_dir=tmp_cpp_path,
            timeout=timeout,
            data_dir='"' + abs_data_dir + '"',
            results_dir=abs_results_dir
        )
        
        # Copy header files
        for header_file in ['EasySAT.hpp', 'heap.hpp']:
            header_path = self.combinations_dir / header_file
            if header_path.exists():
                dst_header_path = os.path.join(compile_dir, header_file)
                shutil.copy2(header_path, dst_header_path)
        
        # Compile
        executable_path = os.path.join(compile_dir, solver_name)
        
        try:
            print(f"  Compiling {solver_name} in {compile_dir}...")
            
            # Change to compile directory and compile
            original_cwd = os.getcwd()
            os.chdir(compile_dir)
            
            compile_cmd = f"g++ -O3 -std=c++17 {solver_name}.cpp -o {solver_name}"
            result = subprocess.run(compile_cmd, shell=True, capture_output=True, text=True, timeout=30)
            
            # Change back to original directory
            os.chdir(original_cwd)
            
            if result.returncode != 0:
                print(f"  Compilation failed: {result.stderr}")
                raise RuntimeError(f"Compilation failed: {result.stderr}")
            
            # Check if executable was created
            if not os.path.exists(executable_path):
                print(f"  Executable not found at {executable_path}")
                print(f"  Directory contents: {os.listdir(compile_dir) if os.path.exists(compile_dir) else 'Directory does not exist'}")
                raise RuntimeError(f"Executable not created: {executable_path}")
            
            print(f"  ✓ {solver_name} compiled successfully")
            
            with self.compilation_lock:
                self.compiled_solvers[solver_name] = executable_path
            
            return executable_path
            
        except Exception as e:
            raise RuntimeError(f"Compilation of {solver_name} failed: {e}")
    
    def run_single_instance(self, cnf_file_path: str, temp_dir: str, timeout: int, data_dir: str, compiled_solvers: Dict = None) -> Dict:
        """Run adaptive solving for a single instance"""
        instance_name = os.path.basename(cnf_file_path)
        
        # Predict best solver
        best_solver, confidence = self.predict_best_solver(cnf_file_path)
        strategy = self.heuristic_library.get(best_solver, "unknown")
        
        print(f"  {instance_name}: selected {best_solver} (confidence: {confidence:.3f})")
        print(f"      Strategy: {strategy}")
        
        # Get pre-compiled solver
        if compiled_solvers is None or best_solver not in compiled_solvers:
            print(f"  Pre-compiled solver not found: {best_solver}")
            print(f"  Available solvers: {list(compiled_solvers.keys()) if compiled_solvers else 'None'}")
            return {
                'cnf_file': cnf_file_path,
                'duration': timeout,
                'situation': 'COMPILE_ERROR',
                'solver': best_solver,
                'confidence': confidence,
                'strategy': strategy
            }
        
        executable_path = compiled_solvers[best_solver]
        print(f"  Using executable: {executable_path}")
        
        # Strictly follow evaluate.py's approach
        start_time = time.time()
        try:
            from autosat.execution.execution_worker import ExecutionWorker
            from autosat.utils import revise_file
            
            # Create temporary directory (same as evaluate.py)
            # Add instance-specific suffix to avoid conflicts in parallel execution
            instance_name = os.path.basename(cnf_file_path).replace('.cnf', '')
            solver_temp_dir = os.path.join(temp_dir, f"exec_{best_solver}_{instance_name}")
            os.makedirs(solver_temp_dir, exist_ok=True)
            
            # Create results directory (same as evaluate.py)
            results_dir = os.path.join(solver_temp_dir, "results")
            os.makedirs(results_dir, exist_ok=True)
            
            # Copy solver file to temporary directory (same as evaluate.py)
            tmp_cpp_source_path = os.path.join(solver_temp_dir, 'SAT_Solver_tmp.cpp')
            tmp_executable_file_path = os.path.join(solver_temp_dir, 'SAT_Solver_tmp')
            
            # Copy the solver file directly to the intermediate directory
            shutil.copy2(executable_path.replace('.exe', '.cpp'), tmp_cpp_source_path)
            
            # Also copy necessary header files from the solver's directory (same as evaluate.py)
            solver_dir = os.path.dirname(executable_path.replace('.exe', '.cpp'))
            for header_file in ['EasySAT.hpp', 'heap.hpp']:
                header_path = os.path.join(solver_dir, header_file)
                if os.path.exists(header_path):
                    dst_header_path = os.path.join(solver_temp_dir, header_file)
                    shutil.copy2(header_path, dst_header_path)
            
            # Create a temporary data directory with only this file
            temp_data_dir = os.path.join(solver_temp_dir, "temp_data")
            os.makedirs(temp_data_dir, exist_ok=True)
            temp_cnf_file = os.path.join(temp_data_dir, os.path.basename(cnf_file_path))
            shutil.copy2(cnf_file_path, temp_cnf_file)
            
            # Use revise_file exactly like evaluate.py
            # Get the original source file path from combinations directory
            original_source_path = os.path.join(self.combinations_dir, f"{best_solver}.cpp")
            
            revise_file(
                file_name=original_source_path,
                save_dir=tmp_cpp_source_path,
                timeout=timeout,
                data_dir="\"" + temp_data_dir + "\"",
                results_dir=f"{solver_temp_dir}/results"
            )
            
            # Use ExecutionWorker exactly like evaluate.py
            execution_worker = ExecutionWorker(temp_base_dir=solver_temp_dir)
            success = execution_worker.execute_eval(
                source_cpp_path=tmp_cpp_source_path,
                executable_file_path=tmp_executable_file_path,
                data_parallel_size=1
            )
            
            if not success:
                raise RuntimeError("cannot correctly execute... plz check again")
            
            # Wait for results exactly like evaluate.py
            filename = "1_0.txt"  # Since data_parallel_size=1
            start_wait_time = time.time()
            while True:
                end_wait_time = time.time()
                if end_wait_time - start_wait_time > timeout * 2 + 30:
                    raise ValueError("Infinite loop error!!!")
                
                if os.path.exists(os.path.join(f'{solver_temp_dir}/results/', 'finished' + filename)):
                    break
            
            end_time = time.time()
            duration = end_time - start_time
            
            # Parse result from the output file (same as evaluate.py)
            result_file = os.path.join(solver_temp_dir, "results", filename)
            if os.path.exists(result_file):
                try:
                    with open(result_file, 'r', encoding='utf-8', errors='ignore') as f:
                        lines = f.readlines()
                        if len(lines) > 1:  # Skip header line
                            for line in lines[1:]:  # Skip header
                                if os.path.basename(cnf_file_path) in line:
                                    parts = line.strip().split('\t')
                                    if len(parts) >= 3:
                                        duration = float(parts[1])
                                        situation = parts[2].lower()
                                        break
                            else:
                                situation = 'unknown'
                        else:
                            situation = 'unknown'
                except Exception as e:
                    print(f"  Error reading result file: {e}")
                    situation = 'error'
            else:
                situation = 'error'
                
        except subprocess.TimeoutExpired:
            duration = timeout
            situation = 'timeout'
        except Exception as e:
            duration = timeout
            situation = 'error'
            print(f"  Runtime error: {e}")
            print(f"  Executable path: {executable_path}")
            print(f"  File exists: {os.path.exists(executable_path)}")
            if os.path.exists(executable_path):
                print(f"  File size: {os.path.getsize(executable_path)}")
                print(f"  File permissions: {oct(os.stat(executable_path).st_mode)}")
        
        return {
            'cnf_file': cnf_file_path,
            'duration': duration,
            'situation': situation,
            'solver': best_solver,
            'confidence': confidence,
            'strategy': strategy
        }
    
    def adaptive_evaluate(self, args):
        """Execute adaptive evaluation"""
        print("Starting adaptive SAT solver evaluation...")
        print("=" * 60)
        

        
        # Create result directory
        dataset_name = os.path.basename(os.path.normpath(args.eval_data_dir))
        method_name = args.method_name or "adaptive_solver"
        eval_session_dir = os.path.join(args.results_save_path, dataset_name, method_name)
        os.makedirs(eval_session_dir, exist_ok=True)
        
        # Get all CNF files
        cnf_files = []
        for f in os.listdir(args.eval_data_dir):
            if f.endswith('.cnf'):
                cnf_files.append(os.path.join(args.eval_data_dir, f))
        
        if not cnf_files:
            raise ValueError(f"No CNF files found in {args.eval_data_dir}")
        
        print(f"Dataset: {dataset_name}")
        print(f"Instances: {len(cnf_files)}")
        print(f"Timeout: {args.eval_timeout}s")
        print(f"Parallel size: {args.eval_parallel_size}")
        print("")
        
        # Create temporary working directory
        temp_base_dir = os.path.join(eval_session_dir, 'tmp')
        os.makedirs(temp_base_dir, exist_ok=True)
        
        # Pre-compile all needed solvers (avoid repeated compilation)
        print("Pre-compiling required solvers...")
        needed_solvers = set()
        instance_predictions = {}
        
        for cnf_file in cnf_files:
            best_solver, confidence = self.predict_best_solver(cnf_file)
            needed_solvers.add(best_solver)
            instance_predictions[cnf_file] = (best_solver, confidence)
        
        print(f"Required solvers: {list(needed_solvers)}")
        compiled_solvers = {}
        for solver_name in needed_solvers:
            try:
                executable_path = self.compile_solver(solver_name, temp_base_dir, args.eval_timeout, args.eval_data_dir)
                compiled_solvers[solver_name] = executable_path
                print(f"  ✓ {solver_name} compiled")
            except Exception as e:
                print(f"  ❌ {solver_name} compilation failed: {e}")
                return None
        
        print("All solvers compiled successfully!")
        print("")
        
        # 统计求解器选择分布
        print("Solver selection distribution:")
        solver_selection_count = {}
        for cnf_file, (best_solver, confidence) in instance_predictions.items():
            solver_selection_count[best_solver] = solver_selection_count.get(best_solver, 0) + 1
        
        # Show selection distribution
        print("Solver selection distribution:")
        for solver, count in sorted(solver_selection_count.items(), key=lambda x: x[1], reverse=True):
            percentage = count / len(cnf_files) * 100
            print(f"  {solver}: {count} instances ({percentage:.1f}%)")
        print("")
        
        # 并行执行评估
        results = []
        solver_usage = {}
        
        start_time = time.time()
        
        with concurrent.futures.ThreadPoolExecutor(max_workers=args.eval_parallel_size) as executor:
            # 提交所有任务
            future_to_cnf = {
                executor.submit(self.run_single_instance, cnf_file, temp_base_dir, args.eval_timeout, args.eval_data_dir, compiled_solvers): cnf_file
                for cnf_file in cnf_files
            }
            
            # 收集结果
            for future in concurrent.futures.as_completed(future_to_cnf):
                cnf_file = future_to_cnf[future]
                try:
                    result = future.result()
                    results.append(result)
                    
                    # 统计solver使用情况
                    solver = result['solver']
                    solver_usage[solver] = solver_usage.get(solver, 0) + 1
                    
                    # 实时进度
                    progress = len(results) / len(cnf_files) * 100
                    print(f"  [{progress:5.1f}%] {os.path.basename(result['cnf_file'])}: "
                          f"{result['duration']:.1f}s ({result['situation']})")
                    
                except Exception as exc:
                    print(f"  ❌ {cnf_file} 执行失败: {exc}")
                    results.append({
                        'cnf_file': cnf_file,
                        'duration': args.eval_timeout,
                        'situation': 'error',
                        'solver': 'unknown',
                        'confidence': 0.0,
                        'strategy': 'unknown'
                    })
        
        end_time = time.time()
        total_time = end_time - start_time
        
        # 保存结果
        self.save_results(results, eval_session_dir, solver_usage, total_time, args)
        
        # 清理临时文件
        if not args.keep_intermediate_results:
            shutil.rmtree(temp_base_dir, ignore_errors=True)
        
        print("=" * 60)
        print("Adaptive evaluation completed!")
        
        return eval_session_dir
    
    def save_results(self, results: List[Dict], eval_session_dir: str, 
                     solver_usage: Dict, total_time: float, args):
        """Save evaluation results in the same format as evaluate.py"""
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        
        # Calculate statistics (same as evaluate.py)
        total_time = sum(result['duration'] for result in results)
        satisfiable_count = len([r for r in results if r['situation'] in ['satisfiable', 'sat']])
        unsatisfiable_count = len([r for r in results if r['situation'] in ['unsatisfiable', 'unsat']])
        timeout_count = len([r for r in results if r['situation'] == 'timeout'])
        total_questions = len(results)
        
        # Calculate PAR-2 score
        par2_scores = []
        for result in results:
            if result['situation'] in ['satisfiable', 'unsatisfiable', 'sat', 'unsat']:
                par2_scores.append(result['duration'])
            else:
                par2_scores.append(2 * args.eval_timeout)  # PAR-2 timeout penalty
        
        avg_par2 = round(sum(par2_scores) / total_questions, 2) if total_questions > 0 else 0
        
        # Create result_dict (same format as evaluate.py)
        result_dict = {
            'PAR-2': avg_par2,
            'satisfiable': satisfiable_count,
            'unsatisfiable': unsatisfiable_count,
            'timeout': timeout_count,
            'total time': total_time,
            '#question': total_questions
        }
        
        # 1. Save standard format results (compatible with existing tools)
        results_file = os.path.join(eval_session_dir, f"results_{timestamp}.txt")
        with open(results_file, 'w', encoding='utf-8') as f:
            f.write("cnf File\tDuration\tSituation\n")
            for result in results:
                f.write(f"{result['cnf_file']}\t{result['duration']:.2f}\t{result['situation']}\n")
            f.write(str(result_dict) + '\n')
        
        # 2. Save to eval_results directory (same as evaluate.py)
        # Use the same directory structure as evaluate.py: results_save_path/dataset_name/method_name/
        dataset_name = os.path.basename(os.path.normpath(args.eval_data_dir))
        eval_results_dir = os.path.join(args.results_save_path, dataset_name, args.method_name)
        os.makedirs(eval_results_dir, exist_ok=True)
        eval_results_file = os.path.join(eval_results_dir, f"results_{timestamp}.txt")
        with open(eval_results_file, 'w', encoding='utf-8') as f:
            f.write("cnf File\tDuration\tSituation\n")
            for result in results:
                f.write(f"{result['cnf_file']}\t{result['duration']:.2f}\t{result['situation']}\n")
            f.write(str(result_dict) + '\n')
        
        # 3. Save detailed adaptive results
        adaptive_results_file = os.path.join(eval_session_dir, f"adaptive_results_{timestamp}.json")
        adaptive_data = {
            'timestamp': timestamp,
            'total_time': total_time,
            'total_instances': len(results),
            'solver_usage': solver_usage,
            'results': results
        }
        
        with open(adaptive_results_file, 'w', encoding='utf-8') as f:
            json.dump(adaptive_data, f, indent=2, ensure_ascii=False)
        
        # 4. Generate evaluation report
        self.generate_report(results, eval_session_dir, solver_usage, total_time, timestamp)
        
        print(f"Results saved:")
        print(f"  Standard format: {results_file}")
        print(f"  Eval results: {eval_results_file}")
        print(f"  Detailed results: {adaptive_results_file}")
    
    def generate_report(self, results: List[Dict], eval_session_dir: str,
                       solver_usage: Dict, total_time: float, timestamp: str):
        """生成评估报告"""
        report_file = os.path.join(eval_session_dir, f"adaptive_report_{timestamp}.txt")
        
        # 计算统计信息
        total_instances = len(results)
        solved_count = len([r for r in results if r['situation'] in ['satisfiable', 'unsatisfiable']])
        timeout_count = len([r for r in results if r['situation'] == 'timeout'])
        error_count = len([r for r in results if r['situation'] in ['error', 'COMPILE_ERROR']])
        
        solve_rate = solved_count / total_instances * 100 if total_instances > 0 else 0
        avg_confidence = np.mean([r['confidence'] for r in results])
        
        # 计算PAR-2分数
        par2_scores = []
        timeout_penalty = 2000
        for result in results:
            if result['situation'] in ['satisfiable', 'unsatisfiable']:
                par2_scores.append(result['duration'])
            else:
                par2_scores.append(timeout_penalty)
        
        avg_par2 = np.mean(par2_scores)
        
        with open(report_file, 'w', encoding='utf-8') as f:
            f.write("自适应SAT求解器评估报告\n")
            f.write("=" * 50 + "\n\n")
            
            f.write(f"评估时间: {timestamp}\n")
            f.write(f"总执行时间: {total_time:.2f}秒\n\n")
            
            f.write("📊 总体统计:\n")
            f.write(f"  总实例数: {total_instances}\n")
            f.write(f"  成功求解: {solved_count} ({solve_rate:.1f}%)\n")
            f.write(f"  超时: {timeout_count}\n")
            f.write(f"  错误: {error_count}\n")
            f.write(f"  平均PAR-2: {avg_par2:.2f}\n")
            f.write(f"  平均置信度: {avg_confidence:.3f}\n\n")
            
            f.write("🎯 Solver使用分布:\n")
            for solver, count in sorted(solver_usage.items(), key=lambda x: x[1], reverse=True):
                percentage = count / total_instances * 100
                strategy = self.heuristic_library.get(solver, "unknown")
                f.write(f"  {solver}: {count} ({percentage:.1f}%)\n")
                f.write(f"    策略: {strategy}\n")
            
            f.write(f"\n📁 详细结果保存在: adaptive_results_{timestamp}.json\n")
        
        print(f"  评估报告: {report_file}")

    def analyze_and_filter_solvers(self, eval_results_dir: str = "./eval_results/train_new/train_new"):
        """
        分析所有求解器的性能，去除后30%的弱求解器，然后重新训练模型
        """
        print("🔍 分析求解器性能并过滤弱求解器...")
        
        # 加载所有求解器的性能数据
        performance_data = {}
        eval_results_path = Path(eval_results_dir)
        
        for solver_dir in eval_results_path.glob("solver_combination_*"):
            if not solver_dir.is_dir():
                continue
                
            solver_name = solver_dir.name
            result_files = list(solver_dir.glob("results_*.txt"))
            
            # 如果当前目录没有结果文件，检查嵌套目录
            if not result_files:
                # 检查嵌套目录结构：solver_dir/train_new/solver_combination_X/
                nested_dirs = list(solver_dir.glob("train_new/solver_combination_*"))
                for nested_dir in nested_dirs:
                    if nested_dir.is_dir():
                        nested_result_files = list(nested_dir.glob("results_*.txt"))
                        if nested_result_files:
                            result_files = nested_result_files
                            break
            
            if not result_files:
                print(f"  ⚠️  未找到 {solver_name} 的结果文件")
                continue
            
            # 使用最新的结果文件
            result_file = sorted(result_files)[-1]
            
            # 解析结果文件
            solver_performance = self._parse_result_file(result_file)
            if solver_performance:
                performance_data[solver_name] = solver_performance
        
        if not performance_data:
            print("  ❌ 没有找到任何求解器性能数据")
            return False
        
        # 计算每个求解器的平均PAR-2分数
        solver_scores = {}
        for solver_name, results in performance_data.items():
            scores = list(results.values())
            avg_score = sum(scores) / len(scores)
            solver_scores[solver_name] = avg_score
        
        # 按性能排序
        sorted_solvers = sorted(solver_scores.items(), key=lambda x: x[1])
        
        print("  📊 求解器性能排名:")
        for i, (solver, score) in enumerate(sorted_solvers):
            print(f"    {i+1}. {solver}: {score:.2f}")
        
        # 去除后40%的弱求解器
        num_to_remove = max(1, int(len(sorted_solvers) * 0.4))
        weak_solvers = [solver for solver, _ in sorted_solvers[-num_to_remove:]]
        
        print(f"  🗑️  去除 {len(weak_solvers)} 个弱求解器: {weak_solvers}")
        
        # 重新训练模型（排除弱求解器）
        return self._retrain_model_with_filtered_solvers(performance_data, weak_solvers)
    
    def _retrain_model_with_filtered_solvers(self, performance_data: Dict, weak_solvers: List[str]):
        """
        使用过滤后的求解器重新训练模型
        """
        print("  🔄 重新训练模型...")
        
        # 过滤掉弱求解器的性能数据
        filtered_performance = {}
        for solver, results in performance_data.items():
            if solver not in weak_solvers:
                filtered_performance[solver] = results
        
        if len(filtered_performance) < 2:
            print("  ❌ 过滤后求解器数量不足，无法重新训练")
            return False
        
        # 重新计算最优映射
        optimal_mappings = {}
        for instance in set().union(*[set(results.keys()) for results in filtered_performance.values()]):
            best_solver = None
            best_score = float('inf')
            
            for solver_name, results in filtered_performance.items():
                if instance in results:
                    score = results[instance]
                    if score < best_score:
                        best_score = score
                        best_solver = solver_name
            
            if best_solver:
                optimal_mappings[instance] = best_solver
        
        # 重新聚类
        clusters = {}
        for instance, best_solver in optimal_mappings.items():
            if best_solver not in clusters:
                clusters[best_solver] = []
            clusters[best_solver].append(instance)
        
        # 重新计算聚类中心
        cluster_centroids = {}
        for solver, instances in clusters.items():
            if not instances:
                continue
            
            # 加载特征
            features = []
            for instance in instances:
                try:
                    # 找到实例对应的CNF文件路径
                    cnf_file_path = None
                    for root, dirs, files in os.walk("./dataset/train_new"):
                        for file in files:
                            if file.endswith('.cnf') and instance in file:
                                cnf_file_path = os.path.join(root, file)
                                break
                        if cnf_file_path:
                            break
                    
                    if cnf_file_path:
                        # 提取真实特征
                        feature_vector = self.feature_extractor.extract_features_to_vector(cnf_file_path)
                        if self.feature_scaler:
                            if hasattr(self.feature_scaler, 'transform'):
                                feature_vector = self.feature_scaler.transform([feature_vector])[0]
                            elif self.feature_scaler == 'manual':
                                feature_vector = (feature_vector - self.feature_mean) / self.feature_std
                        features.append(feature_vector)
                except Exception as e:
                    print(f"    ⚠️  无法加载实例 {instance} 的特征: {e}")
                    continue
            
            if features:
                # 计算真实的聚类中心
                centroid = np.mean(features, axis=0)
                cluster_centroids[solver] = centroid.tolist()
        
        # 更新模型数据
        self.clusters = clusters
        self.cluster_centroids = cluster_centroids
        self.optimal_mapping = optimal_mappings
        
        # 更新启发式库（排除弱求解器）
        filtered_heuristic_library = {}
        for solver, strategy in self.heuristic_library.items():
            if solver not in weak_solvers:
                filtered_heuristic_library[solver] = strategy
        
        self.heuristic_library = filtered_heuristic_library
        
        print(f"  ✅ 模型重新训练完成")
        print(f"     - 保留求解器: {list(filtered_heuristic_library.keys())}")
        print(f"     - 聚类数量: {len(clusters)}")
        print(f"     - 实例数量: {len(optimal_mappings)}")
        
        return True
    
    def _parse_result_file(self, result_file: Path) -> Dict[str, float]:
        """
        解析结果文件，返回实例名到PAR-2分数的映射
        """
        results = {}
        timeout_penalty = 2000  # PAR-2 timeout penalty
        
        try:
            with open(result_file, 'r', encoding='utf-8') as f:
                lines = f.readlines()
            
            for line in lines[1:]:  # 跳过标题行
                parts = line.strip().split('\t')
                if len(parts) >= 3:
                    instance_name = parts[0]
                    duration_str = parts[1]
                    situation = parts[2]
                    
                    try:
                        if situation.lower() in ['sat', 'unsat', 'satisfiable', 'unsatisfiable']:
                            duration = float(duration_str)
                            # PAR-2 score = duration
                            results[instance_name] = duration
                        else:
                            # Timeout or error
                            results[instance_name] = timeout_penalty
                    except ValueError:
                        results[instance_name] = timeout_penalty
            
            return results
        except Exception as e:
            print(f"    ⚠️  解析结果文件 {result_file} 失败: {e}")
            return {}


def main():
    parser = argparse.ArgumentParser(description='自适应SAT求解器评估器')
    
    # 与evaluate.py兼容的参数
    parser.add_argument('--config', type=str, default='./examples/EasySAT/ada_config.yaml',
                       help='Path to the config file')
    parser.add_argument('--eval_data_dir', default='./dataset/test', type=str,
                       help='CNF文件目录')
    parser.add_argument('--results_save_path', type=str, default='./adaptive_eval_results/',
                       help='结果保存路径')
    parser.add_argument('--eval_parallel_size', type=int, default=8, 
                       help='并行评估进程数')
    parser.add_argument('--eval_timeout', type=int, default=1000, 
                       help='单实例超时时间(秒)')
    parser.add_argument('--method_name', type=str, default='adaptive_solver',
                       help='方法名称')
    parser.add_argument('--keep_intermediate_results', type=bool, default=False,
                       help='保留中间文件')
    parser.add_argument('--rand_seed', type=int, default=42, help='随机种子')
    
    # 自适应特有参数
    parser.add_argument('--model_path', type=str, 
                       default='./adaptive_selector_train_new_fixed.pkl',
                       help='训练好的自适应选择模型路径')
    parser.add_argument('--combinations_dir', type=str, default='./combinations',
                       help='Solver组合目录')
    parser.add_argument('--filter_weak_solvers', action='store_true',
                       help='启用弱求解器过滤功能（去除后40%的弱求解器）')
    
    args = parser.parse_args()
    
    # Load config file first (same as evaluate.py)
    if os.path.exists(args.config):
        print('Loading config file:', args.config)
        with open(args.config, 'r') as file:
            config = yaml.safe_load(file)
            # Override all arguments with config values
            for key, value in config.items():
                if hasattr(args, key):
                    setattr(args, key, value)
                    print(f'  {key}: {value}')
    else:
        print(f'Config file not found: {args.config}')
        print('Using command line arguments only')
    
    # 初始化自适应评估器
    try:
        evaluator = AdaptiveEvaluator(args.model_path, args.combinations_dir)
    except Exception as e:
        print(f"❌ 初始化自适应评估器失败: {e}")
        return 1
    
    # 执行评估
    try:
        result_dir = evaluator.adaptive_evaluate(args)
        print(f"🎉 评估完成，结果保存在: {result_dir}")
        return 0
    except Exception as e:
        print(f"❌ 评估失败: {e}")
        return 1


if __name__ == "__main__":
    # Raise SystemExit directly: the bare `exit` builtin is injected by the `site`
    # module and is not guaranteed to exist (e.g. `python -S`, frozen executables).
    raise SystemExit(main())