#!/usr/bin/env python3
"""
Adaptive SAT Solver Selection Framework
Heuristic combination selection based on instance features

Implementation of the DASHCO framework described in the paper:
1. Instance Space Partitioning via Performance-Based Clustering
2. Adaptive Heuristic Selection for New Instances
"""

import numpy as np
import pandas as pd
import json
import os
import pickle
from pathlib import Path
from typing import Dict, List, Tuple, Optional
from sklearn.cluster import KMeans
from sklearn.metrics.pairwise import euclidean_distances
from sklearn.preprocessing import StandardScaler
import argparse
from datetime import datetime


class AdaptiveSolverSelector:
    """Adaptive SAT solver (heuristic-combination) selector.

    Training (``train_adaptive_selector``):

    1. Load the heuristic combination library H, the instance feature vectors
       and the PAR-2 performance matrix p(h_i, j).
    2. Optionally drop the weakest solvers (``filter_weak_solvers``).
    3. Map every instance j to its best combination h*(j), group instances by
       h*(j) into clusters C_i, and average each cluster's feature vectors into
       a centroid v̄_i.

    Prediction (``predict_best_solver``) returns the combination whose centroid
    is nearest, in Euclidean distance, to a new instance's feature vector.
    """

    # Status strings (compared after ``.upper()``) that count as a successful
    # solve in result files; any other status is scored with timeout_penalty.
    SUCCESS_STATES = frozenset({'SAT', 'UNSAT', 'SATISFIABLE', 'UNSATISFIABLE'})
    # Extension removed from instance names so feature keys and result keys match.
    CNF_SUFFIX = '.cnf'

    def __init__(self, combinations_dir: str = "./combinations",
                 eval_results_dir: str = "./eval_results/train_new/train_new",
                 features_dir: str = "./features"):
        self.combinations_dir = Path(combinations_dir)
        self.eval_results_dir = Path(eval_results_dir)
        self.features_dir = Path(features_dir)

        # Core data structures
        self.heuristic_library = {}  # H: {solver_name: strategy description}
        self.instance_features = {}  # {instance_name: feature vector}
        self.performance_matrix = {}  # p(h_i, j): {solver_name: {instance_name: PAR-2 score}}
        self.clusters = {}  # C_i: {solver_name: [instance_name, ...]}
        self.cluster_centroids = {}  # v̄_i: {solver_name: mean feature vector}
        self.optimal_mapping = {}  # h*(j): {instance_name: best solver_name}
        # True/False once features were loaded from the normalized/raw file;
        # None before loading (or for models saved before this flag existed).
        self.features_normalized: Optional[bool] = None

        # Configuration parameters
        # Score assigned to unsolved or failed runs (PAR-2 timeout penalty).
        self.timeout_penalty = 2000
        # Stored with the model; note that nothing in this class fits it.
        self.feature_scaler = StandardScaler()

        print("🤖 Initializing adaptive solver selector")
        print(f"  Combinations directory: {self.combinations_dir}")
        print(f"  Evaluation results: {self.eval_results_dir}")
        print(f"  Features directory: {self.features_dir}")

    @classmethod
    def _strip_cnf_suffix(cls, name) -> str:
        """Return *name* as ``str`` with one trailing ``.cnf`` removed, if present.

        Used for both feature-file names and result-file names so the two
        sources produce identical instance keys. Only the suffix is removed;
        ``.cnf`` elsewhere in the name is kept.
        """
        name = str(name)  # feature files may hold numpy string scalars
        if name.endswith(cls.CNF_SUFFIX):
            return name[:-len(cls.CNF_SUFFIX)]
        return name

    def load_heuristic_library(self) -> Dict[str, str]:
        """Load the heuristic combination library H.

        Scans ``combinations_dir`` for ``solver_combination_*.cpp`` files and
        maps each file stem (the solver name) to the description returned by
        ``evaluate_combinations.get_solver_strategy_info``.

        Returns:
            ``{solver_name: strategy_description}``; also stored in
            ``self.heuristic_library``.
        """
        print("📚 Loading heuristic combination library...")

        # Project-local helper that describes each combination's strategy.
        from evaluate_combinations import get_solver_strategy_info

        library = {}
        # Sorted so the library and its log output have a deterministic order
        # (raw glob order depends on the filesystem).
        for cpp_file in sorted(self.combinations_dir.glob("solver_combination_*.cpp")):
            solver_name = cpp_file.stem
            strategy_info = get_solver_strategy_info(solver_name)
            library[solver_name] = strategy_info
            print(f"  ✓ {solver_name}: {strategy_info}")

        self.heuristic_library = library
        print(f"  Total loaded {len(library)} heuristic combinations")
        return library

    def load_instance_features(self, dataset: str = "train_new", use_normalized: bool = True) -> Dict[str, np.ndarray]:
        """Load instance feature vectors.

        Args:
            dataset: Dataset name used to build the feature file path.
            use_normalized: Prefer the file
                ``normalized_features/<dataset>_normalized_<dataset>_standard.npz``
                (relative to the current working directory). If it is missing,
                fall back to ``<features_dir>/<dataset>/<dataset>_features.npz``.

        Returns:
            ``{instance_name: feature_vector}`` with a trailing ``.cnf`` removed
            from each name; also stored in ``self.instance_features``.
            ``self.features_normalized`` records which file was actually used.

        Raises:
            FileNotFoundError: If the raw feature file is needed but missing.
        """
        print(f"🎯 Loading instance features ({dataset})...")

        if use_normalized:
            normalized_dir = Path("normalized_features")
            feature_file = normalized_dir / f"{dataset}_normalized_{dataset}_standard.npz"

            if not feature_file.exists():
                print(f"  ⚠️  Normalized features do not exist: {feature_file}")
                print("  Trying to use original features...")
                use_normalized = False
            else:
                print(f"  ✓ Using normalized features: {feature_file}")

        if not use_normalized:
            feature_file = self.features_dir / dataset / f"{dataset}_features.npz"
            if not feature_file.exists():
                raise FileNotFoundError(f"Feature file does not exist: {feature_file}")
            print(f"  ✓ Using original features: {feature_file}")

        # allow_pickle is required because the raw format stores a dict object.
        # NOTE(review): this unpickles file contents; only load trusted files.
        with np.load(feature_file, allow_pickle=True) as data:
            feature_names = data['feature_names']

            if use_normalized:
                # Normalized format: one row per instance, aligned with 'filenames'.
                features_matrix = data['features']  # (n_instances, n_features)
                filenames = data['filenames']

                print(f"  ✓ Loaded normalized features for {len(filenames)} instances")
                print(f"  ✓ Feature dimensions: {features_matrix.shape[1]}")

                features = {
                    self._strip_cnf_suffix(filename): features_matrix[i]
                    for i, filename in enumerate(filenames)
                }
            else:
                # Raw format: a pickled {instance_name: feature sequence} dict.
                instances = data['instances'].item()

                print(f"  ✓ Loaded original features for {len(instances)} instances")
                print(f"  ✓ Feature dimensions: {len(feature_names)}")

                features = {
                    self._strip_cnf_suffix(instance_name): np.array(feature_vector, dtype=float)
                    for instance_name, feature_vector in instances.items()
                }

        self.instance_features = features
        self.features_normalized = use_normalized
        print(f"  ✓ Processed instances: {len(features)}")
        return features

    def load_performance_matrix(self) -> Dict[str, Dict[str, float]]:
        """
        Load the performance matrix p(h_i, j).

        For each ``solver_combination_*`` directory under ``eval_results_dir``,
        the lexicographically last ``results_*.txt`` file is parsed. A solver
        whose results are missing or unparseable keeps an empty entry; parse
        failures are reported and skipped rather than aborting the load.

        Returns: {solver_name: {instance_name: par2_score}}
        """
        print("📊 Loading performance matrix...")

        performance = {}

        print(f"  Searching directory: {self.eval_results_dir}")
        for solver_dir in sorted(self.eval_results_dir.glob("solver_combination_*")):
            if not solver_dir.is_dir():
                continue

            solver_name = solver_dir.name
            # Registered up front so solvers without usable results still appear.
            performance[solver_name] = {}

            result_files = sorted(solver_dir.glob("results_*.txt"))
            print(f"  {solver_name}: Found {len(result_files)} result files")
            if not result_files:
                print(f"  ⚠️  No result files found for {solver_name}")
                continue

            # "Latest" means the lexicographically last name; presumably the
            # names embed a sortable timestamp — verify against the producer.
            result_file = result_files[-1]
            print(f"  {solver_name}: Using result file {result_file.name}")

            try:
                instance_results = self._parse_result_file(result_file)
            except Exception as e:  # best effort: one bad file must not abort the load
                print(f"  ❌ Failed to parse {solver_name}: {e}")
                continue
            performance[solver_name] = instance_results
            print(f"  ✓ {solver_name}: {len(instance_results)} instances")

        self.performance_matrix = performance
        print(f"  Total: {len(performance)} solvers")
        return performance

    def _parse_result_file(self, result_file: Path) -> Dict[str, float]:
        """Parse a single tab-separated result file into PAR-2 scores.

        The first line is a header and is skipped. Each following line with at
        least three tab-separated columns is read as ``cnf_path, duration,
        status``; shorter lines are ignored. A successful status (see
        ``SUCCESS_STATES``) scores the parsed duration; any other status, or
        an unparseable duration, scores ``self.timeout_penalty``.

        Returns:
            ``{instance_name: par2_score}`` where the instance name is the file
            basename without a trailing ``.cnf``.
        """
        results = {}

        with open(result_file, 'r', encoding='utf-8') as f:
            next(f, None)  # skip the header line
            for line in f:
                parts = line.strip().split('\t')
                if len(parts) < 3:
                    continue
                cnf_file, duration, situation = (part.strip() for part in parts[:3])

                # Drop the directory prefix and the .cnf suffix.
                instance_name = self._strip_cnf_suffix(os.path.basename(cnf_file))

                try:
                    time_val = float(duration)
                except ValueError:
                    # Unparseable duration: treat as a timeout.
                    results[instance_name] = self.timeout_penalty
                    continue

                if situation.upper() in self.SUCCESS_STATES:
                    results[instance_name] = time_val
                else:  # TIMEOUT or any other failure status
                    results[instance_name] = self.timeout_penalty

        return results

    def find_optimal_mappings(self, min_time_threshold: float = 3.0) -> Dict[str, str]:
        """
        Find the optimal combination for each instance: h*(j) = argmin p(h_i, j).

        Only scores that are successful solves (below ``timeout_penalty``) and
        take at least ``min_time_threshold`` seconds are candidates; the
        threshold is meant to exclude timeout and system-error effects.
        Instances without a candidate are excluded and reported as either
        "all timeout" (no solver succeeded) or "all too fast".

        Args:
            min_time_threshold: Minimum solve time in seconds for a score to
                be a candidate (default 3.0, the previously hard-coded value).

        Returns:
            ``{instance_name: best_solver_name}``; also stored in
            ``self.optimal_mapping``.
        """
        print("🎯 Computing optimal mappings...")

        optimal_mappings = {}
        timeout_instances = []
        too_fast_instances = []

        # Collect every instance seen by any solver.
        all_instances = set()
        for solver_results in self.performance_matrix.values():
            all_instances.update(solver_results.keys())

        # Sorted so mapping/cluster order (and hence the float summation order
        # of the centroids) does not depend on per-process string hashing.
        for instance in sorted(all_instances):
            best_solver = None
            best_score = float('inf')
            any_solved = False

            for solver_name, results in self.performance_matrix.items():
                if instance not in results:
                    continue
                score = results[instance]
                if score < self.timeout_penalty:  # successfully solved
                    any_solved = True
                    # Strict '<' keeps the first solver (dict order) on ties.
                    if score >= min_time_threshold and score < best_score:
                        best_score = score
                        best_solver = solver_name

            if best_solver is not None:
                optimal_mappings[instance] = best_solver
            elif not any_solved:
                timeout_instances.append(instance)
            else:
                too_fast_instances.append(instance)

        self.optimal_mapping = optimal_mappings
        print(f"  ✓ 完成 {len(optimal_mappings)} 个实例的最优映射")
        print(f"  ⚠️  排除 {len(timeout_instances)} 个全timeout实例")
        print(f"  ⚠️  排除 {len(too_fast_instances)} 个全<{min_time_threshold}s实例")

        # Number of instances on which each solver is the best one.
        solver_counts = {}
        for solver in optimal_mappings.values():
            solver_counts[solver] = solver_counts.get(solver, 0) + 1

        print("  各求解器优势实例数:")
        for solver, count in sorted(solver_counts.items()):
            strategy = self.heuristic_library.get(solver, "unknown")
            print(f"    {solver}: {count} 个实例 ({strategy})")

        return optimal_mappings

    def partition_instance_space(self) -> Dict[str, List[str]]:
        """
        Performance-based partition of the instance space:
        C_i = {j ∈ I_train | h*(j) = h_i}.

        Returns:
            ``{solver_name: [instance_name, ...]}`` built from
            ``self.optimal_mapping``; also stored in ``self.clusters``.
        """
        print("🔀 分割实例空间...")

        clusters: Dict[str, List[str]] = {}
        # Group instances by their optimal solver.
        for instance, best_solver in self.optimal_mapping.items():
            clusters.setdefault(best_solver, []).append(instance)

        self.clusters = clusters

        print("  聚类结果:")
        for solver, instances in clusters.items():
            strategy = self.heuristic_library.get(solver, "unknown")
            print(f"    {solver}: {len(instances)} 个实例")
            print(f"      策略: {strategy}")

        return clusters

    def compute_cluster_centroids(self) -> Dict[str, np.ndarray]:
        """
        Compute each cluster's centroid in feature space:
        v̄_i = mean(features of instances in C_i).

        Instances without a loaded feature vector are skipped; a cluster with
        no usable vectors gets no centroid.

        Returns:
            ``{solver_name: centroid}``; also stored in
            ``self.cluster_centroids``.
        """
        print("📍 计算聚类中心...")

        centroids = {}

        for solver, instances in self.clusters.items():
            valid_instances = [inst for inst in instances if inst in self.instance_features]

            if valid_instances:
                feature_vectors = [self.instance_features[inst] for inst in valid_instances]
                centroids[solver] = np.mean(feature_vectors, axis=0)
                print(f"  ✓ {solver}: {len(valid_instances)}/{len(instances)} 有效实例")
            else:
                print(f"  ⚠️  {solver}: 无有效特征向量")

        self.cluster_centroids = centroids
        return centroids

    def train_adaptive_selector(self, dataset: str = "train_new", filter_weak_solvers: bool = False):
        """Run the full training pipeline for the adaptive selector.

        Args:
            dataset: Dataset name passed to ``load_instance_features``.
            filter_weak_solvers: If True, drop the weakest solvers before the
                optimal mappings are computed.

        Returns:
            A summary dict with keys ``num_heuristics``, ``num_instances``,
            ``num_clusters`` and ``performance_coverage``.
        """
        print("🚀 开始训练自适应选择器...")
        print("=" * 50)

        # Step 1: load data
        self.load_heuristic_library()
        self.load_instance_features(dataset)
        self.load_performance_matrix()

        # Step 2: optional weak-solver filtering
        if filter_weak_solvers:
            print("🔍 启用弱求解器过滤功能...")
            self.filter_weak_solvers()

        # Step 3: performance analysis and clustering
        self.find_optimal_mappings()
        self.partition_instance_space()
        self.compute_cluster_centroids()

        # Step 4: features are used exactly as loaded; no re-standardization
        # happens here, so report which feature space the model lives in.
        if self.features_normalized:
            print("  ✓ 特征处理完成（使用归一化特征）")
        else:
            print("  ⚠️  Using original (unnormalized) features; inputs to "
                  "predict_best_solver must be in the same raw feature space")

        print("=" * 50)
        print("✅ 训练完成！")

        return {
            'num_heuristics': len(self.heuristic_library),
            'num_instances': len(self.instance_features),
            'num_clusters': len(self.clusters),
            'performance_coverage': len(self.optimal_mapping)
        }

    def filter_weak_solvers(self, remove_fraction: float = 0.4):
        """Drop the weakest solvers by mean PAR-2 score (bottom 40% by default).

        Solvers are ranked by the mean of their PAR-2 scores (lower is better).
        Solvers with no results are not ranked and are kept. At least one
        ranked solver is removed, but at least one ranked solver is always
        kept, so a single-solver setup is never emptied.

        Args:
            remove_fraction: Fraction of ranked solvers to remove.

        Updates ``self.heuristic_library`` and ``self.performance_matrix``.
        """
        print("  📊 分析求解器性能...")

        # Mean PAR-2 score per solver (solvers without results are skipped).
        solver_scores = {
            solver_name: sum(results.values()) / len(results)
            for solver_name, results in self.performance_matrix.items()
            if results
        }

        sorted_solvers = sorted(solver_scores.items(), key=lambda x: x[1])

        print("  📈 求解器性能排名:")
        for i, (solver, score) in enumerate(sorted_solvers):
            print(f"    {i+1}. {solver}: {score:.2f}")

        num_ranked = len(sorted_solvers)
        num_to_remove = min(max(1, int(num_ranked * remove_fraction)), max(num_ranked - 1, 0))
        # Explicit start index: a '[-0:]' slice would select every solver.
        weak_solvers = [solver for solver, _ in sorted_solvers[num_ranked - num_to_remove:]]

        print(f"  🗑️  去除 {len(weak_solvers)} 个弱求解器: {weak_solvers}")

        weak = set(weak_solvers)
        self.heuristic_library = {
            solver: strategy
            for solver, strategy in self.heuristic_library.items()
            if solver not in weak
        }
        self.performance_matrix = {
            solver: results
            for solver, results in self.performance_matrix.items()
            if solver not in weak
        }

        print(f"  ✅ 过滤完成，保留 {len(self.heuristic_library)} 个求解器")

    def predict_best_solver(self, instance_features: np.ndarray) -> Tuple[str, float]:
        """
        Predict the best solver for a new instance (nearest cluster centroid).

        Args:
            instance_features: The instance's feature vector, in the same
                feature space as the training features. No scaling is applied
                here. A 1-D vector, a single-row 2-D array or a plain sequence
                is accepted.

        Returns:
            (best_solver, confidence): the solver whose centroid is nearest in
            Euclidean distance, and ``1 / (1 + distance)`` in (0, 1].

        Raises:
            ValueError: If the model is untrained, or the vector's size does
                not match a centroid.
        """
        if not self.cluster_centroids:
            raise ValueError("模型未训练，请先调用 train_adaptive_selector()")

        features = np.ravel(np.asarray(instance_features, dtype=float))

        distances = {}
        for solver, centroid in self.cluster_centroids.items():
            # Guard against silent broadcasting (e.g. a length-1 vector).
            if features.shape != np.shape(centroid):
                raise ValueError(
                    f"Feature vector has {features.size} values but the centroid "
                    f"for {solver} has shape {np.shape(centroid)}"
                )
            distances[solver] = float(np.linalg.norm(features - centroid))

        # Nearest cluster; ties go to the first solver in dict order.
        best_solver = min(distances, key=distances.get)
        min_distance = distances[best_solver]

        confidence = 1.0 / (1.0 + min_distance)

        return best_solver, confidence

    def save_model(self, filepath: str):
        """Pickle the trained model state (library, clusters, centroids, scaler,
        mapping, feature-space flag and a timestamp) to *filepath*."""
        model_data = {
            'heuristic_library': self.heuristic_library,
            'clusters': self.clusters,
            'cluster_centroids': {k: v.tolist() for k, v in self.cluster_centroids.items()},
            'feature_scaler': self.feature_scaler,
            'optimal_mapping': self.optimal_mapping,
            'features_normalized': self.features_normalized,
            'timestamp': datetime.now().isoformat()
        }

        with open(filepath, 'wb') as f:
            pickle.dump(model_data, f)

        print(f"💾 模型已保存到: {filepath}")

    def load_model(self, filepath: str):
        """Restore model state written by ``save_model`` from *filepath*.

        SECURITY: ``pickle.load`` can execute arbitrary code; only load model
        files from trusted sources.
        """
        with open(filepath, 'rb') as f:
            model_data = pickle.load(f)

        self.heuristic_library = model_data['heuristic_library']
        self.clusters = model_data['clusters']
        self.cluster_centroids = {k: np.array(v) for k, v in model_data['cluster_centroids'].items()}
        self.feature_scaler = model_data['feature_scaler']
        self.optimal_mapping = model_data['optimal_mapping']
        # Absent from models saved before this flag was introduced.
        self.features_normalized = model_data.get('features_normalized')

        print(f"📂 模型已加载: {filepath}")
        print(f"  训练时间: {model_data.get('timestamp', 'unknown')}")


def main(argv: Optional[List[str]] = None):
    """Command-line entry point for the adaptive SAT solver selection framework.

    Modes:
        train    -- train the selector on ``--dataset`` and save it to
                    ``--model-path``.
        predict  -- load the model from ``--model-path``; per-instance
                    prediction is not implemented yet.
        evaluate -- placeholder; cross-validation is not implemented yet.

    Args:
        argv: Arguments to parse, excluding the program name. ``None`` (the
            default) parses ``sys.argv[1:]``, i.e. the previous behaviour.
    """
    parser = argparse.ArgumentParser(description='自适应SAT求解器选择框架')
    parser.add_argument('--mode', choices=['train', 'predict', 'evaluate'],
                        default='train', help='运行模式')
    parser.add_argument('--dataset', default='train_new',
                        help='数据集名称 (default: train_new)')
    parser.add_argument('--model-path', default='adaptive_selector_model.pkl',
                        help='模型保存/加载路径')
    parser.add_argument('--combinations-dir', default='./combinations',
                        help='组合目录')
    parser.add_argument('--eval-results-dir', default='./eval_results/train_new/train_new',
                        help='评估结果目录')
    parser.add_argument('--features-dir', default='./features',
                        help='特征目录')
    # argparse applies %-formatting to help strings, so a literal percent sign
    # must be escaped as '%%'; an unescaped '%' makes `--help` fail.
    parser.add_argument('--filter-weak-solvers', action='store_true',
                        help='启用弱求解器过滤功能（去除后40%%的弱求解器）')

    args = parser.parse_args(argv)

    # Initialize the selector
    selector = AdaptiveSolverSelector(
        combinations_dir=args.combinations_dir,
        eval_results_dir=args.eval_results_dir,
        features_dir=args.features_dir
    )

    if args.mode == 'train':
        print("🎓 训练模式")
        summary = selector.train_adaptive_selector(args.dataset, args.filter_weak_solvers)

        print("\n训练摘要:")
        print(f"  启发式组合数: {summary['num_heuristics']}")
        print(f"  训练实例数: {summary['num_instances']}")
        print(f"  聚类数: {summary['num_clusters']}")
        print(f"  性能覆盖: {summary['performance_coverage']}")

        # Save the trained model
        selector.save_model(args.model_path)

    elif args.mode == 'predict':
        print("🔮 预测模式")
        # TODO: implement single-instance prediction
        selector.load_model(args.model_path)
        print("模型加载完成，等待实例预测...")

    elif args.mode == 'evaluate':
        print("📈 评估模式")
        # TODO: implement cross-validation evaluation
        print("交叉验证评估功能开发中...")


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not on import; main()
    # returns None, so the process exits with status 0 on success.
    raise SystemExit(main())