"""
Unified Algorithm Selector Training

A streamlined training program for algorithm selectors that can work with:
- Enhanced features (46 features)
- mzn2feat features (95 features)

Supports multiple selector types:
- Random Forest
- AutoSklearn
- AutoSklearn Conservative

Designed to replace multiple specialized training scripts with a single, clean interface.
"""

import pandas as pd
import numpy as np
from pathlib import Path
import sys
import argparse
import logging
import time
from typing import Dict, List, Tuple, Any
from collections import Counter
from sklearn.model_selection import train_test_split, cross_val_score, StratifiedKFold
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, classification_report, make_scorer
from sklearn.preprocessing import StandardScaler
import joblib

def setup_logging(level=logging.INFO):
    """Configure logging through ``logging.basicConfig`` and return a logger.

    Args:
        level: Logging threshold handed to ``basicConfig`` (default ``INFO``).

    Returns:
        The logger registered under this module's ``__name__``.
    """
    config = {
        'level': level,
        'format': '%(asctime)s - %(levelname)s - %(message)s',
        'datefmt': '%H:%M:%S',  # time of day only, no date component
    }
    logging.basicConfig(**config)

    module_logger = logging.getLogger(__name__)
    return module_logger

def load_dataset(data_dir: Path, logger):
    """Load the train/test feature and performance tables from ``data_dir``.

    The directory must contain ``features_train.csv``, ``features_test.csv``,
    ``performance_train.csv`` and ``performance_test.csv``.

    Args:
        data_dir: Directory holding the four CSV files.
        logger: Logger used for the summary lines (``info``/``warning``).

    Returns:
        Tuple ``(train_features, test_features, train_performance,
        test_performance)`` of DataFrames, exactly as read by pandas.
    """
    train_features = pd.read_csv(data_dir / 'features_train.csv')
    test_features = pd.read_csv(data_dir / 'features_test.csv')
    train_performance = pd.read_csv(data_dir / 'performance_train.csv')
    test_performance = pd.read_csv(data_dir / 'performance_test.csv')

    logger.info(f"Dataset loaded from: {data_dir}")

    # Identifier columns are not features and are left out of the counts.
    exclude_cols = {'Problem', 'Instance', 'ProblemType'}
    train_feature_cols = [col for col in train_features.columns if col not in exclude_cols]
    test_feature_cols = [col for col in test_features.columns if col not in exclude_cols]

    # Each split reports its own column count; previously the train count was
    # printed for the test split as well, hiding any mismatch between files.
    logger.info(f"  Train: {len(train_features)} instances, {len(train_feature_cols)} features")
    logger.info(f"  Test: {len(test_features)} instances, {len(test_feature_cols)} features")

    if train_feature_cols != test_feature_cols:
        logger.warning("  Train and test feature columns differ (names or order)")

    return train_features, test_features, train_performance, test_performance

def clean_features(X: np.ndarray, feature_cols: List[str], logger):
    """Sanitise a feature matrix so it contains only bounded, finite values.

    NaN and +/-infinity entries become 0, every remaining value is clipped to
    the range [-1e6, 1e6], and the result is returned as ``float32``.

    Args:
        X: Numeric feature matrix.
        feature_cols: Feature names (accepted for interface symmetry; not
            used by the cleaning itself).
        logger: Logger used for progress messages.

    Returns:
        A new ``float32`` array with the same shape as ``X``.
    """
    logger.info("Cleaning features...")

    # Work in float64 so very large inputs survive until they are clipped.
    widened = X.astype(np.float64)

    # Infinities are treated like missing values: both collapse to 0.
    finite = np.nan_to_num(widened, nan=0.0, posinf=0.0, neginf=0.0)

    # Conservative bounds to prevent overflow further down the pipeline.
    bounded = np.clip(finite, -1e6, 1e6)

    X = bounded.astype(np.float32)

    logger.info(f"Feature cleaning completed: {X.shape}")

    return X

def prepare_ml_data(features_df: pd.DataFrame, performance_df: pd.DataFrame, logger):
    """Turn the feature/performance tables into arrays ready for training.

    Args:
        features_df: Feature table; ``Problem``, ``Instance`` and (optionally)
            ``ProblemType`` are identifier columns, all others are features.
        performance_df: Performance table providing the ``BestSolver`` target.
        logger: Logger used for the summary output.

    Returns:
        Tuple ``(X, y, instances, feature_cols)``: the cleaned feature matrix,
        the target labels, a copy of the identifier columns, and the list of
        feature column names in matrix order.
    """
    identifier_names = ('Problem', 'Instance', 'ProblemType')
    feature_cols = [name for name in features_df.columns if name not in identifier_names]

    # Raw matrix and labels first, then sanitise the matrix.
    raw_matrix = features_df[feature_cols].values.astype(float)
    y = performance_df['BestSolver'].values
    X = clean_features(raw_matrix, feature_cols, logger)

    # Identifier columns are kept alongside for later per-problem analysis.
    id_columns = ['Problem', 'Instance']
    if 'ProblemType' in features_df.columns:
        id_columns = id_columns + ['ProblemType']
    instances = features_df[id_columns].copy()

    logger.info(f"  Feature matrix: {X.shape}")
    logger.info("  Target distribution:")
    labels, frequencies = np.unique(y, return_counts=True)
    for solver, count in zip(labels, frequencies):
        share = count / len(y) * 100
        logger.info(f"    {solver}: {count} ({share:.1f}%)")

    return X, y, instances, feature_cols

def create_ranking_scorer(performance_data: pd.DataFrame, problem_types: np.ndarray):
    """Create a scikit-learn scorer that rewards a low average solver rank.

    Args:
        performance_data: Per-instance solver performance table.
        problem_types: Problem type of each row of ``performance_data``.

    Returns:
        A scorer built with ``make_scorer``; its value is the negated average
        rank of the predicted solvers, so larger is better.

    Notes:
        The metric function only receives the labels and predictions of the
        rows being scored. To score them against the matching instances, the
        index of ``y_true`` is used when it is a pandas Series whose labels
        all occur in ``performance_data.index``. Otherwise predictions are
        paired positionally with the leading rows of ``performance_data``
        (the historical behaviour), which is only meaningful when the whole
        table is scored at once.
    """
    # Rankings depend only on the performance table, so build them once
    # instead of on every scorer invocation.
    rankings = calculate_rankings(performance_data, problem_types)
    row_index = performance_data.index

    def ranking_score(y_true, y_pred):
        selected = rankings
        if isinstance(y_true, pd.Series) and row_index.is_unique:
            positions = row_index.get_indexer(y_true.index)
            # -1 marks a label that is missing from the performance table.
            if len(positions) == len(y_pred) and (positions >= 0).all():
                selected = [rankings[pos] for pos in positions]

        avg_rank = calculate_average_ranking(y_pred, selected)
        # Return negative because sklearn maximizes (lower rank is better)
        return -avg_rank

    return make_scorer(ranking_score, greater_is_better=True)

def create_borda_scorer(performance_data: pd.DataFrame, problem_types: np.ndarray):
    """Create a scikit-learn scorer based on the average Borda score.

    Args:
        performance_data: Per-instance solver performance table.
        problem_types: Problem type of each row of ``performance_data``.

    Returns:
        A scorer built with ``make_scorer`` (higher is better).

    Notes:
        The metric function only receives the labels and predictions of the
        rows being scored. When ``y_true`` is a pandas Series whose index
        labels all occur in ``performance_data.index``, the matching rows and
        problem types are selected before scoring. Otherwise predictions are
        paired positionally with the leading rows of ``performance_data``
        (the historical behaviour), which is only meaningful when the whole
        table is scored at once.
    """
    row_index = performance_data.index

    def borda_score(y_true, y_pred):
        subset, subset_types = performance_data, problem_types
        if isinstance(y_true, pd.Series) and row_index.is_unique:
            positions = row_index.get_indexer(y_true.index)
            # -1 marks a label that is missing from the performance table.
            if len(positions) == len(y_pred) and (positions >= 0).all():
                subset = performance_data.iloc[positions]
                subset_types = np.asarray(problem_types)[positions]

        return calculate_borda_score(y_pred, subset, subset_types)

    return make_scorer(borda_score, greater_is_better=True)

def train_random_forest(X_train: np.ndarray, y_train: np.ndarray, logger,
                       loss_function: str = 'accuracy',
                       performance_data: pd.DataFrame = None,
                       problem_types: np.ndarray = None, **kwargs):
    """Fit a Random Forest selector.

    Args:
        X_train: Training feature matrix.
        y_train: Training labels (best solver per instance).
        logger: Logger used for progress messages.
        loss_function: ``'accuracy'``, ``'ranking'`` or ``'borda'``.
        performance_data: Solver performance table; needed for the
            ``'ranking'`` and ``'borda'`` modes.
        problem_types: Problem type per instance; needed for the
            ``'ranking'`` and ``'borda'`` modes.
        **kwargs: Overrides for the default ``RandomForestClassifier``
            hyperparameters.

    Returns:
        The fitted ``RandomForestClassifier``.

    Notes:
        Every mode currently performs the same plain ``fit``; the
        ranking/Borda modes rely on ``class_weight='balanced'`` only. A true
        ranking or Borda objective would need a custom loss function.
    """
    logger.info(f"Training Random Forest with {loss_function} loss function...")

    # Optimized default hyperparameters for constraint programming
    params = {
        'n_estimators': 300,      # More trees for better performance
        'max_depth': 20,          # Deeper trees for complex constraint patterns
        'min_samples_split': 5,
        'min_samples_leaf': 2,
        'max_features': 'sqrt',
        'class_weight': 'balanced',  # Handle class imbalance
        'random_state': 42,
        'n_jobs': -1
    }
    params.update(kwargs)

    rf = RandomForestClassifier(**params)
    start_time = time.time()

    has_performance_info = performance_data is not None and problem_types is not None

    if loss_function == 'ranking' and has_performance_info:
        # NOTE: placeholder mode - no per-sample weights are passed to fit().
        logger.info("  Using ranking-based sample weighting...")
    elif loss_function == 'borda' and has_performance_info:
        # NOTE: placeholder mode - identical fit to the accuracy mode.
        logger.info("  Using Borda score-aware training...")
    elif loss_function in ('ranking', 'borda'):
        # A known mode without its required inputs is not an "unknown" loss
        # function; say what is actually missing.
        logger.warning(
            f"Loss function '{loss_function}' requires performance_data and "
            f"problem_types, using standard accuracy training"
        )
    elif loss_function != 'accuracy':
        logger.warning(f"Unknown loss function '{loss_function}', using standard accuracy training")

    rf.fit(X_train, y_train)

    training_time = time.time() - start_time

    logger.info(f"Random Forest trained in {training_time:.2f}s")
    logger.info(f"  Parameters: n_estimators={params['n_estimators']}, max_depth={params['max_depth']}, class_weight={params['class_weight']}")
    logger.info(f"  Loss function: {loss_function}")
    return rf

def train_autosklearn(X_train: np.ndarray, y_train: np.ndarray, logger,
                     time_left=300, conservative=False):
    """Fit an auto-sklearn classifier in either normal or conservative mode.

    Args:
        X_train: Training feature matrix.
        y_train: Training labels.
        logger: Logger used for progress and error messages.
        time_left: Overall time budget in seconds
            (``time_left_for_this_task``).
        conservative: When true, use the smaller holdout-based configuration
            with a capped SMAC run count; otherwise the 3-fold CV
            configuration.

    Returns:
        The fitted ``AutoSklearnClassifier``, or ``None`` when auto-sklearn
        cannot be imported.
    """
    try:
        import autosklearn.classification
    except ImportError:
        logger.error("autosklearn not available. Install with: pip install auto-sklearn")
        return None

    mode = "conservative" if conservative else "normal"
    logger.info(f"Training AutoSklearn ({mode} mode, {time_left}s)...")

    # Settings shared by both modes.
    settings = {
        'time_left_for_this_task': time_left,
        'per_run_time_limit': 30,
        'memory_limit': 3072,
        'delete_tmp_folder_after_terminate': True,
        'n_jobs': 1,
        'seed': 42,
    }

    if conservative:
        # Conservative settings for large datasets (matching reference implementation)
        settings.update({
            'ensemble_size': 20,
            'ensemble_nbest': 50,
            'initial_configurations_via_metalearning': 5,
            'resampling_strategy': 'holdout',
            'resampling_strategy_arguments': {'train_size': 0.8},
            'smac_scenario_args': {'runcount_limit': 50},
        })
    else:
        # Standard settings for moderate datasets (matching reference implementation)
        settings.update({
            'ensemble_size': 50,
            'ensemble_nbest': 200,
            'initial_configurations_via_metalearning': 25,
            'resampling_strategy': 'cv',
            'resampling_strategy_arguments': {'folds': 3},
        })

    automl = autosklearn.classification.AutoSklearnClassifier(**settings)

    began = time.time()
    automl.fit(X_train, y_train)
    training_time = time.time() - began

    logger.info(f"AutoSklearn trained in {training_time:.2f}s")
    logger.info(f"Models explored: {len(automl.get_models_with_weights())}")

    return automl

def determine_problem_types(performance_data: pd.DataFrame):
    """Return the problem type of every instance in ``performance_data``.

    An existing ``ProblemType`` column is returned as-is. Otherwise the type
    is inferred per row from the solver columns (every column other than
    ``Problem``, ``Instance``, ``ProblemType`` and ``BestSolver``):

    * any value <= -1e9  -> ``'maximization'``
    * any value >= 1e9 and no value in [1199, 1200] -> ``'minimization'``
    * anything else -> ``'satisfiability'``

    Returns:
        Array with one problem-type label per row.
    """
    if 'ProblemType' in performance_data.columns:
        return performance_data['ProblemType'].values

    metadata = ('Problem', 'Instance', 'ProblemType', 'BestSolver')
    solver_names = [name for name in performance_data.columns if name not in metadata]

    def classify(scores):
        # +/-1e9 act as "no solution" markers; 1199..1200 is the
        # satisfiability timeout band.
        hit_upper_marker = any(score >= 1000000000 for score in scores)
        hit_lower_marker = any(score <= -1000000000 for score in scores)
        hit_timeout = any(1199 <= score <= 1200 for score in scores)

        if hit_lower_marker:
            return 'maximization'
        if hit_upper_marker and not hit_timeout:
            return 'minimization'
        return 'satisfiability'

    labels = [
        classify([record[name] for name in solver_names])
        for _, record in performance_data.iterrows()
    ]
    return np.array(labels)

def calculate_rankings(performance_data: pd.DataFrame, problem_types: np.ndarray):
    """Rank the solvers on every instance, giving tied solvers the same rank.

    Solver columns are all columns other than ``Problem``, ``Instance``,
    ``ProblemType`` and ``BestSolver``. A solver counts as having finished
    when its value is below 1e9 (minimization), above -1e9 (maximization) or
    below 1199 (satisfiability). Finished solvers are ranked from 1; values
    closer than 1e-9 to the previous one share its rank, and the next
    distinct value takes its 1-based position. Solvers that did not finish
    receive the worst rank, i.e. the number of solver columns.

    Args:
        performance_data: Per-instance solver performance table.
        problem_types: Problem type for each row, addressed by position.

    Returns:
        One ``{solver: rank}`` dict per row, in row order.
    """
    metadata = ('Problem', 'Instance', 'ProblemType', 'BestSolver')
    solver_names = [name for name in performance_data.columns if name not in metadata]
    worst_rank = len(solver_names)

    def finished(kind, score):
        if kind == 'minimization':
            return score < 1000000000
        if kind == 'maximization':
            return score > -1000000000
        if kind == 'satisfiability':
            return score < 1199
        return False

    per_instance = []

    for position, (_, record) in enumerate(performance_data.iterrows()):
        kind = problem_types[position]
        scored = [(name, record[name]) for name in solver_names]
        finishers = [pair for pair in scored if finished(kind, pair[1])]

        ranks = {}
        if finishers:
            if kind == 'minimization' or kind == 'satisfiability':
                # Lower is better
                ordered = sorted(finishers, key=lambda pair: pair[1])
            else:
                # Higher is better
                ordered = sorted(finishers, key=lambda pair: pair[1], reverse=True)

            rank = 1
            for place, (name, score) in enumerate(ordered, start=1):
                # A value that is not tied with its predecessor starts a new
                # rank equal to its position in the ordering.
                if place > 1 and not (abs(score - ordered[place - 2][1]) < 1e-9):
                    rank = place
                ranks[name] = rank

        # Everything not ranked above failed on this instance.
        for name in solver_names:
            ranks.setdefault(name, worst_rank)

        per_instance.append(ranks)

    return per_instance

def calculate_average_ranking(predictions: np.ndarray, rankings: List[Dict[str, int]]):
    """Average the rank achieved by the predicted solver over all instances.

    Predictions and rankings are paired positionally. A predicted solver that
    is missing from an instance's ranking is charged that instance's worst
    (largest) rank.

    Returns:
        The mean rank, or ``float('inf')`` when there is nothing to average.
    """
    rank_sum = 0
    instance_count = 0

    for choice, ranks in zip(predictions, rankings):
        rank_sum += ranks[choice] if choice in ranks else max(ranks.values())
        instance_count += 1

    if not instance_count:
        return float('inf')
    return rank_sum / instance_count

def calculate_borda_score(predictions: np.ndarray, performance_data: pd.DataFrame, problem_types: np.ndarray):
    """Average pairwise Borda points earned by the predicted solvers.

    For each instance the predicted solver is compared with every other
    solver column: 1 point when it is strictly better, 0.5 when the two
    values differ by less than 1e-9, 0 otherwise. Comparisons against a
    failed solver are skipped. An instance earns 0 points when the prediction
    is not a solver column or the predicted solver itself failed. "Failed"
    means >= 1e9 (minimization), <= -1e9 (maximization) or >= 1199
    (satisfiability).

    Args:
        predictions: Predicted solver per instance, paired positionally with
            the rows of ``performance_data``.
        performance_data: Per-instance solver performance table.
        problem_types: Problem type per row, addressed by position.

    Returns:
        Total points divided by the number of predictions; ``0.0`` when
        there are no predictions.
    """
    metadata = ('Problem', 'Instance', 'ProblemType', 'BestSolver')
    solver_names = [name for name in performance_data.columns if name not in metadata]

    def failed(kind, score):
        if kind == 'minimization':
            return score >= 1000000000
        if kind == 'maximization':
            return score <= -1000000000
        if kind == 'satisfiability':
            return score >= 1199
        return False

    def duel(kind, own, rival):
        # Lower wins for minimization/satisfiability, higher wins otherwise.
        if kind == 'minimization' or kind == 'satisfiability':
            ahead = own < rival
        else:
            ahead = own > rival
        if ahead:
            return 1.0
        if abs(own - rival) < 1e-9:
            return 0.5
        return 0

    points_total = 0
    instances_scored = 0

    for position, choice in enumerate(predictions):
        record = performance_data.iloc[position]
        kind = problem_types[position]

        # Every instance counts toward the average, even with 0 points.
        instances_scored += 1

        if choice not in solver_names:
            continue

        own_score = record[choice]
        if failed(kind, own_score):
            continue

        earned = 0
        for rival_name in solver_names:
            if rival_name == choice:
                continue
            rival_score = record[rival_name]
            if failed(kind, rival_score):
                continue
            earned += duel(kind, own_score, rival_score)

        points_total += earned

    if instances_scored == 0:
        return 0.0

    return points_total / instances_scored

def calculate_single_best_solver_metrics(performance_df: pd.DataFrame, problem_types: np.ndarray):
    """Compute baseline metrics for always choosing the single best solver.

    The single best solver (SBS) is the most frequent value of the
    ``BestSolver`` column. The baseline predicts it for every instance and is
    scored with the same ranking and Borda functions as a trained model.

    Args:
        performance_df: Per-instance solver performance table.
        problem_types: Problem type per row of ``performance_df``.

    Returns:
        Dict with ``solver`` (the SBS name), ``accuracy`` (its win rate),
        ``average_ranking`` and ``borda_score``; the three metrics are floats.
    """
    winners = performance_df['BestSolver'].values

    # value_counts() lists the most frequent winner first.
    tally = pd.Series(winners).value_counts()
    sbs_name = tally.index[0]
    win_rate = tally.iloc[0] / len(winners)

    # Constant prediction vector: the SBS for every instance.
    constant_predictions = np.full(len(performance_df), sbs_name)

    per_instance_ranks = calculate_rankings(performance_df, problem_types)
    mean_rank = calculate_average_ranking(constant_predictions, per_instance_ranks)
    borda = calculate_borda_score(constant_predictions, performance_df, problem_types)

    return {
        'solver': sbs_name,
        'accuracy': float(win_rate),
        'average_ranking': float(mean_rank),
        'borda_score': float(borda),
    }

def evaluate_model_comprehensive(model, X_train: np.ndarray, y_train: np.ndarray,
                               X_test: np.ndarray, y_test: np.ndarray,
                               test_instances: pd.DataFrame, train_performance: pd.DataFrame,
                               test_performance: pd.DataFrame, logger):
    """Evaluate a fitted selector on accuracy, ranking and Borda metrics.

    Args:
        model: Fitted estimator exposing ``predict``; it is also handed to
            ``cross_val_score`` for 5-fold accuracy on the training data.
        X_train, y_train: Training features and labels.
        X_test, y_test: Test features and labels.
        test_instances: Identifier columns of the test rows; must contain a
            ``Problem`` column.
        train_performance, test_performance: Solver performance tables whose
            rows correspond positionally to the train/test predictions.
        logger: Logger that receives the formatted report.

    Returns:
        Dict with train/test accuracy, CV mean/std, problem-normalized test
        accuracy, average rankings, Borda scores, single-best-solver baseline
        metrics, improvements over that baseline, and ``results_df`` (a copy
        of ``test_instances`` with ``Actual``, ``Predicted`` and ``Correct``
        columns added).
    """
    logger.info("Comprehensive evaluation with accuracy, ranking, and Borda metrics...")

    start_time = time.time()

    # Determine problem types
    train_problem_types = determine_problem_types(train_performance)
    test_problem_types = determine_problem_types(test_performance)

    # Predictions
    y_train_pred = model.predict(X_train)
    y_test_pred = model.predict(X_test)

    # NOTE: covers problem-type inference and prediction only; the metric
    # computation and cross-validation below are not included in this figure.
    prediction_time = time.time() - start_time

    # === ACCURACY METRICS ===
    train_accuracy = accuracy_score(y_train, y_train_pred)
    test_accuracy = accuracy_score(y_test, y_test_pred)

    # Cross-validation (5-fold)
    cv_scores = cross_val_score(model, X_train, y_train, cv=5, scoring='accuracy')

    # Problem-normalized accuracy
    results_df = test_instances.copy()
    results_df['Actual'] = y_test
    results_df['Predicted'] = y_test_pred
    results_df['Correct'] = (y_test == y_test_pred)

    # One pass over the frame yields accuracy and size for every problem.
    # sort=False keeps problems in order of first appearance, the same order
    # the per-problem accuracies were previously averaged in.
    per_problem = results_df.groupby('Problem', sort=False)['Correct'].agg(['mean', 'size'])
    problem_accuracies = per_problem['mean'].tolist()
    problem_count_total = len(per_problem)

    # Each problem contributes equally, regardless of its instance count.
    problem_normalized_accuracy = sum(problem_accuracies) / len(problem_accuracies)

    # === RANKING METRICS ===
    train_rankings = calculate_rankings(train_performance, train_problem_types)
    test_rankings = calculate_rankings(test_performance, test_problem_types)

    train_avg_ranking = calculate_average_ranking(y_train_pred, train_rankings)
    test_avg_ranking = calculate_average_ranking(y_test_pred, test_rankings)

    # === BORDA SCORE METRICS ===
    train_borda_score = calculate_borda_score(y_train_pred, train_performance, train_problem_types)
    test_borda_score = calculate_borda_score(y_test_pred, test_performance, test_problem_types)

    # === SINGLE BEST SOLVER BASELINES ===
    train_sbs_metrics = calculate_single_best_solver_metrics(train_performance, train_problem_types)
    test_sbs_metrics = calculate_single_best_solver_metrics(test_performance, test_problem_types)

    # === IMPROVEMENTS OVER BASELINE ===
    acc_train_improvement = train_accuracy - train_sbs_metrics['accuracy']
    acc_test_improvement = test_accuracy - test_sbs_metrics['accuracy']

    # For ranking: lower is better, so improvement = baseline - current
    rank_train_improvement = train_sbs_metrics['average_ranking'] - train_avg_ranking
    rank_test_improvement = test_sbs_metrics['average_ranking'] - test_avg_ranking

    borda_train_improvement = train_borda_score - train_sbs_metrics['borda_score']
    borda_test_improvement = test_borda_score - test_sbs_metrics['borda_score']

    # === LOGGING RESULTS ===
    logger.info(f"Evaluation completed in {prediction_time:.3f}s")
    logger.info("")
    logger.info("🎯 ACCURACY METRICS:")
    logger.info(f"  Train accuracy: {train_accuracy:.4f} ({train_accuracy*100:.2f}%)")
    logger.info(f"  Test accuracy: {test_accuracy:.4f} ({test_accuracy*100:.2f}%)")
    logger.info(f"  CV mean±std: {cv_scores.mean():.4f}±{cv_scores.std():.4f}")
    logger.info(f"  Problem-normalized test accuracy: {problem_normalized_accuracy:.4f} ({problem_normalized_accuracy*100:.2f}%)")

    logger.info("")
    logger.info("📊 RANKING METRICS:")
    logger.info(f"  Train avg ranking: {train_avg_ranking:.3f}")
    logger.info(f"  Test avg ranking: {test_avg_ranking:.3f}")

    logger.info("")
    logger.info("🏆 BORDA SCORE METRICS:")
    logger.info(f"  Train Borda score: {train_borda_score:.4f}")
    logger.info(f"  Test Borda score: {test_borda_score:.4f}")

    logger.info("")
    logger.info("📈 SINGLE BEST SOLVER BASELINE:")
    logger.info(f"  Train SBS ({train_sbs_metrics['solver']}): Acc={train_sbs_metrics['accuracy']:.4f}, Rank={train_sbs_metrics['average_ranking']:.3f}, Borda={train_sbs_metrics['borda_score']:.4f}")
    logger.info(f"  Test SBS ({test_sbs_metrics['solver']}): Acc={test_sbs_metrics['accuracy']:.4f}, Rank={test_sbs_metrics['average_ranking']:.3f}, Borda={test_sbs_metrics['borda_score']:.4f}")

    logger.info("")
    logger.info("💪 IMPROVEMENTS OVER BASELINE:")
    logger.info(f"  Train: Acc={acc_train_improvement:+.4f}, Rank={rank_train_improvement:+.3f}(better), Borda={borda_train_improvement:+.4f}")
    logger.info(f"  Test:  Acc={acc_test_improvement:+.4f}, Rank={rank_test_improvement:+.3f}(better), Borda={borda_test_improvement:+.4f}")

    # Show per-problem breakdown for key insights
    logger.info("")
    logger.info("Per-problem accuracy breakdown (first 10):")
    for problem in sorted(per_problem.index)[:10]:
        problem_acc = per_problem.at[problem, 'mean']
        problem_count = per_problem.at[problem, 'size']
        logger.info(f"  {problem}: {problem_acc:.3f} ({problem_count} instances)")
    if problem_count_total > 10:
        logger.info(f"  ... and {problem_count_total - 10} more problems")

    # Classification report (condensed): keep only the accuracy/average rows
    logger.info("")
    logger.info("Classification Report:")
    report = classification_report(y_test, y_test_pred)
    for line in report.split('\n'):
        if line.strip() and ('avg' in line or 'accuracy' in line):
            logger.info(f"  {line}")

    # Return comprehensive results
    return {
        'train_accuracy': train_accuracy,
        'test_accuracy': test_accuracy,
        'cv_mean': cv_scores.mean(),
        'cv_std': cv_scores.std(),
        'problem_normalized_accuracy': problem_normalized_accuracy,
        'train_avg_ranking': train_avg_ranking,
        'test_avg_ranking': test_avg_ranking,
        'train_borda_score': train_borda_score,
        'test_borda_score': test_borda_score,
        'train_sbs_metrics': train_sbs_metrics,
        'test_sbs_metrics': test_sbs_metrics,
        'acc_train_improvement': acc_train_improvement,
        'acc_test_improvement': acc_test_improvement,
        'rank_train_improvement': rank_train_improvement,
        'rank_test_improvement': rank_test_improvement,
        'borda_train_improvement': borda_train_improvement,
        'borda_test_improvement': borda_test_improvement,
        'results_df': results_df
    }

def save_comprehensive_results(model, scaler, eval_results: Dict, feature_cols: List[str],
                             output_dir: Path, model_name: str, feature_type: str,
                             loss_function: str, logger):
    """Persist the trained model, predictions, summary CSV and text report.

    Four files are written to ``output_dir`` (created if missing), all
    prefixed with ``{feature_type}_{model_name}_{loss_function}``:

    * ``_model.joblib``      - ``(model, scaler)`` tuple dumped with joblib
    * ``_results.csv``       - per-instance predictions
      (``eval_results['results_df']``)
    * ``_summary.csv``       - one-row table of all headline metrics
    * ``_comprehensive.txt`` - human-readable report including the feature list

    Args:
        model: Trained selector to serialise.
        scaler: Object stored alongside the model in the joblib file.
        eval_results: Metrics dict in the shape returned by
            ``evaluate_model_comprehensive``.
        feature_cols: Names of the features used, in model order.
        output_dir: Destination directory.
        model_name: Selector type, used in file names and the report.
        feature_type: Feature-set label, used in file names and the report.
        loss_function: Loss-function label, used in file names and the report.
        logger: Logger used to announce each file written.
    """
    output_dir.mkdir(parents=True, exist_ok=True)

    # Save model and scaler
    model_file = output_dir / f"{feature_type}_{model_name}_{loss_function}_model.joblib"
    joblib.dump((model, scaler), model_file)
    logger.info(f"Model saved: {model_file}")

    # Save predictions
    results_file = output_dir / f"{feature_type}_{model_name}_{loss_function}_results.csv"
    eval_results['results_df'].to_csv(results_file, index=False)
    logger.info(f"Results saved: {results_file}")

    # Save comprehensive summary with all metrics
    summary_file = output_dir / f"{feature_type}_{model_name}_{loss_function}_summary.csv"
    summary_data = {
        'Feature_Type': [feature_type],
        'Model_Type': [model_name],
        'Loss_Function': [loss_function],
        'Test_Accuracy': [eval_results['test_accuracy']],
        'Train_Accuracy': [eval_results['train_accuracy']],
        'Problem_Normalized_Accuracy': [eval_results['problem_normalized_accuracy']],
        'CV_Mean': [eval_results['cv_mean']],
        'CV_Std': [eval_results['cv_std']],
        'Test_Avg_Ranking': [eval_results['test_avg_ranking']],
        'Train_Avg_Ranking': [eval_results['train_avg_ranking']],
        'Test_Borda_Score': [eval_results['test_borda_score']],
        'Train_Borda_Score': [eval_results['train_borda_score']],
        'SBS_Test_Accuracy': [eval_results['test_sbs_metrics']['accuracy']],
        'SBS_Test_Ranking': [eval_results['test_sbs_metrics']['average_ranking']],
        'SBS_Test_Borda': [eval_results['test_sbs_metrics']['borda_score']],
        'SBS_Solver': [eval_results['test_sbs_metrics']['solver']],
        'Acc_Test_Improvement': [eval_results['acc_test_improvement']],
        'Rank_Test_Improvement': [eval_results['rank_test_improvement']],
        'Borda_Test_Improvement': [eval_results['borda_test_improvement']],
        'Number_of_Features': [len(feature_cols)],
        'Test_Instances': [len(eval_results['results_df'])],
        'Test_Problems': [eval_results['results_df']['Problem'].nunique()]
    }

    summary_df = pd.DataFrame(summary_data)
    summary_df.to_csv(summary_file, index=False)
    logger.info(f"Comprehensive summary saved: {summary_file}")

    # Save detailed feature info with comprehensive metrics
    feature_file = output_dir / f"{feature_type}_{model_name}_{loss_function}_comprehensive.txt"
    report_lines = [
        "=== ALGORITHM SELECTOR COMPREHENSIVE RESULTS ===\n",
        f"Feature type: {feature_type}\n",
        f"Model type: {model_name}\n",
        f"Loss function: {loss_function}\n",
        f"Number of features: {len(feature_cols)}\n",
        f"Test instances: {len(eval_results['results_df'])}\n",
        f"Test problems: {eval_results['results_df']['Problem'].nunique()}\n",
        "\n=== ACCURACY METRICS ===\n",
        f"Test accuracy: {eval_results['test_accuracy']:.4f} ({eval_results['test_accuracy']*100:.2f}%)\n",
        f"Train accuracy: {eval_results['train_accuracy']:.4f} ({eval_results['train_accuracy']*100:.2f}%)\n",
        f"Problem-normalized accuracy: {eval_results['problem_normalized_accuracy']:.4f} ({eval_results['problem_normalized_accuracy']*100:.2f}%)\n",
        f"Cross-validation: {eval_results['cv_mean']:.4f}±{eval_results['cv_std']:.4f}\n",
        "\n=== RANKING METRICS ===\n",
        f"Test avg ranking: {eval_results['test_avg_ranking']:.3f}\n",
        f"Train avg ranking: {eval_results['train_avg_ranking']:.3f}\n",
        "\n=== BORDA SCORE METRICS ===\n",
        f"Test Borda score: {eval_results['test_borda_score']:.4f}\n",
        f"Train Borda score: {eval_results['train_borda_score']:.4f}\n",
        "\n=== SINGLE BEST SOLVER BASELINE ===\n",
        f"SBS solver: {eval_results['test_sbs_metrics']['solver']}\n",
        f"SBS accuracy: {eval_results['test_sbs_metrics']['accuracy']:.4f}\n",
        f"SBS avg ranking: {eval_results['test_sbs_metrics']['average_ranking']:.3f}\n",
        f"SBS Borda score: {eval_results['test_sbs_metrics']['borda_score']:.4f}\n",
        "\n=== IMPROVEMENTS OVER BASELINE ===\n",
        f"Accuracy improvement: {eval_results['acc_test_improvement']:+.4f}\n",
        f"Ranking improvement: {eval_results['rank_test_improvement']:+.3f} (lower=better)\n",
        f"Borda improvement: {eval_results['borda_test_improvement']:+.4f}\n",
        "\n=== FEATURES USED ===\n",
    ]
    report_lines.extend(
        f"{i:3d}. {feature}\n" for i, feature in enumerate(feature_cols, 1)
    )

    # The report contains non-ASCII text ("±", possibly feature names), so the
    # encoding is fixed instead of depending on the platform default.
    with open(feature_file, 'w', encoding='utf-8') as f:
        f.writelines(report_lines)
    logger.info(f"Comprehensive report saved: {feature_file}")

def main():
    """Command-line entry point: train, evaluate and save an algorithm selector.

    Steps performed:
      1. Parse CLI arguments and configure logging.
      2. Load the train/test feature and performance CSVs from ``data_dir``.
      3. Build the ML matrices with ``prepare_ml_data`` and verify that the
         train and test sets expose the same feature columns in the same order.
      4. Fit a ``StandardScaler`` on the training features.
      5. Train the requested selector (Random Forest on the raw features,
         AutoSklearn variants on the scaled features).
      6. Evaluate with ``evaluate_model_comprehensive`` and persist everything
         with ``save_comprehensive_results``.

    The process exits with status 1 when the data directory is missing, the
    dataset cannot be loaded, the train/test feature columns differ, or the
    training function returns ``None``.

    NOTE(review): within this function ``--loss-function`` is only logged and
    forwarded to ``save_comprehensive_results``; it is not passed to the
    training functions. Confirm that this is intended.
    """
    parser = argparse.ArgumentParser(description='Unified Algorithm Selector Training')
    parser.add_argument('data_dir', help='Directory containing prepared dataset')
    parser.add_argument('--selector-type',
                        choices=['random_forest', 'autosklearn', 'autosklearn_conservative'],
                        default='random_forest',
                        help='Type of algorithm selector to train')
    parser.add_argument('--autosklearn-time', type=int, default=300,
                        help='Time budget for AutoSklearn (seconds)')
    parser.add_argument('--output-dir',
                        help='Output directory for results (default: data_dir/results)')
    parser.add_argument('--loss-function',
                        choices=['accuracy', 'ranking', 'borda'],
                        default='accuracy',
                        help='Loss function for optimization: accuracy (default), ranking (average rank), or borda (MiniZinc Challenge score)')
    parser.add_argument('--rf-n-estimators', type=int, default=300,
                        help='Number of trees for Random Forest (default: 300, optimized for constraint programming)')
    parser.add_argument('--rf-max-depth', type=int, default=20,
                        help='Maximum depth for Random Forest (default: 20, optimized for constraint programming)')

    args = parser.parse_args()
    logger = setup_logging()

    data_dir = Path(args.data_dir)
    # is_dir() (rather than exists()) also rejects a path that points at a file.
    if not data_dir.is_dir():
        logger.error(f"Data directory not found: {data_dir}")
        sys.exit(1)

    # Determine feature type from directory name; the first matching marker
    # wins, in the same priority order as before.
    feature_type = next(
        (marker for marker in ("enhanced", "mzn2feat", "LLM") if marker in data_dir.name),
        "unknown",
    )

    logger.info("=" * 60)
    logger.info("UNIFIED ALGORITHM SELECTOR TRAINING")
    logger.info("=" * 60)
    logger.info(f"Data directory: {data_dir}")
    logger.info(f"Feature type: {feature_type}")
    logger.info(f"Selector type: {args.selector_type}")
    logger.info(f"Loss function: {args.loss_function}")

    # Load dataset
    try:
        train_features, test_features, train_performance, test_performance = load_dataset(data_dir, logger)
    except Exception as e:
        logger.error(f"Failed to load dataset: {e}")
        sys.exit(1)

    # Prepare training data
    logger.info("\nPreparing training data...")
    X_train, y_train, train_instances, feature_cols = prepare_ml_data(train_features, train_performance, logger)

    # Prepare test data
    logger.info("\nPreparing test data...")
    X_test, y_test, test_instances, test_feature_cols = prepare_ml_data(test_features, test_performance, logger)

    # The model is fit on the training columns and applied positionally to the
    # test matrix, so both sets must expose the same columns in the same order.
    train_cols = list(feature_cols)
    test_cols = list(test_feature_cols)
    if train_cols != test_cols:
        only_in_train = [col for col in train_cols if col not in test_cols]
        only_in_test = [col for col in test_cols if col not in train_cols]
        logger.error("Train and test feature columns do not match")
        logger.error(f"  Only in train: {only_in_train}")
        logger.error(f"  Only in test: {only_in_test}")
        if not only_in_train and not only_in_test:
            logger.error("  Same columns, but in a different order")
        sys.exit(1)

    # Scale features (important for some algorithms). The scaler is always
    # fitted because it is saved together with the model below.
    scaler = StandardScaler()
    X_train_scaled = scaler.fit_transform(X_train)
    X_test_scaled = scaler.transform(X_test)

    # Random Forest works on the raw features, AutoSklearn on the scaled ones.
    # Choosing the matrices once keeps training and evaluation consistent.
    use_random_forest = args.selector_type == 'random_forest'
    if use_random_forest:
        X_fit, X_eval = X_train, X_test
    else:
        X_fit, X_eval = X_train_scaled, X_test_scaled

    # Train model
    logger.info(f"\nTraining {args.selector_type} selector...")
    start_total = time.time()

    if use_random_forest:
        model = train_random_forest(X_fit, y_train, logger,
                                    n_estimators=args.rf_n_estimators,
                                    max_depth=args.rf_max_depth)
    else:
        # argparse restricts the remaining choices to the two AutoSklearn modes.
        model = train_autosklearn(X_fit, y_train, logger,
                                  time_left=args.autosklearn_time,
                                  conservative=(args.selector_type == 'autosklearn_conservative'))

    if model is None:
        logger.error("Model training failed!")
        sys.exit(1)

    total_time = time.time() - start_total
    logger.info(f"Total training time: {total_time:.2f}s")

    # Evaluate model comprehensively on the same representation it was fit on
    logger.info("\nEvaluating model...")
    eval_results = evaluate_model_comprehensive(model, X_fit, y_train, X_eval, y_test,
                                                test_instances, train_performance, test_performance, logger)

    # Save comprehensive results
    output_dir = Path(args.output_dir) if args.output_dir else data_dir / 'results'
    save_comprehensive_results(model, scaler, eval_results, feature_cols, output_dir, args.selector_type,
                               feature_type, args.loss_function, logger)

    # Final summary
    sbs = eval_results['test_sbs_metrics']
    logger.info("\n" + "=" * 60)
    logger.info("TRAINING COMPLETED SUCCESSFULLY")
    logger.info("=" * 60)
    logger.info(f"Feature type: {feature_type}")
    logger.info(f"Selector type: {args.selector_type}")
    logger.info(f"Loss function: {args.loss_function}")
    logger.info(f"Features used: {len(feature_cols)}")
    logger.info(f"Training instances: {len(X_train)}")
    logger.info(f"Test instances: {len(X_test)}")
    logger.info(f"Test problems: {eval_results['results_df']['Problem'].nunique()}")
    logger.info(f"Training time: {total_time:.2f}s")
    logger.info("")
    logger.info("COMPREHENSIVE PERFORMANCE METRICS:")
    logger.info(f"  🎯 Test accuracy: {eval_results['test_accuracy']:.4f} ({eval_results['test_accuracy']*100:.2f}%)")
    logger.info(f"  📊 Test avg ranking: {eval_results['test_avg_ranking']:.3f}")
    logger.info(f"  🏆 Test Borda score: {eval_results['test_borda_score']:.4f}")
    logger.info(f"  📚 Problem-normalized accuracy: {eval_results['problem_normalized_accuracy']:.4f} ({eval_results['problem_normalized_accuracy']*100:.2f}%)")
    logger.info(f"  🔄 Cross-validation: {eval_results['cv_mean']:.4f}±{eval_results['cv_std']:.4f}")
    logger.info("")
    logger.info(f"  BASELINE ({sbs['solver']}): Acc={sbs['accuracy']:.4f}, Rank={sbs['average_ranking']:.3f}, Borda={sbs['borda_score']:.4f}")
    logger.info(f"  IMPROVEMENTS: Acc={eval_results['acc_test_improvement']:+.4f}, Rank={eval_results['rank_test_improvement']:+.3f}, Borda={eval_results['borda_test_improvement']:+.4f}")
    logger.info("")
    logger.info(f"Results saved to: {output_dir}")

# Run the training pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
