#!/usr/bin/env python3
"""
Feature Importance and Ablation Visualization Toolkit for LLM vs mzn2feat Research

This toolkit provides comprehensive visualization capabilities for analyzing:
1. Feature importance comparison between LLM and mzn2feat extractors
2. Feature ablation studies
3. Feature interpretability analysis

Usage:
    python feature_visualization_toolkit.py --problem FLECC --selector random_forest
"""

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import pickle
import argparse
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import sys
sys.path.append('../algorithm_selection')
from single_dataset_trainer import SingleDatasetTrainer
import warnings
warnings.filterwarnings('ignore')

# Publication-quality plot style with Times New Roman.
# plt.style.use('default') resets *all* rcParams to matplotlib's defaults, so it
# must run BEFORE the custom settings; calling it afterwards would silently
# discard the font configuration below.
plt.style.use('default')
plt.rcParams.update({
    'font.family': 'serif',
    'font.serif': ['Times New Roman'],
    'font.size': 28,           # Base font size - enlarged
    'axes.titlesize': 28,      # Axes title font size - slightly reduced
    'axes.labelsize': 30,      # X and Y label font size - enlarged
    'xtick.labelsize': 26,     # X tick label font size - enlarged
    'ytick.labelsize': 26,     # Y tick label font size - enlarged
    'legend.fontsize': 28,     # Legend font size - enlarged
    'figure.titlesize': 32     # Figure title font size - slightly reduced
})
# Only changes the colour cycle, so it is safe after the rcParams update.
sns.set_palette("husl")

class FeatureVisualizationToolkit:
    """Comprehensive toolkit for feature importance and ablation visualization."""
    
    def __init__(self, problem, selector_type, loss_function='accuracy'):
        """Store the run configuration, ensure ./figures exists, then load features and models.

        Args:
            problem: Problem/dataset name; used as a sub-directory of
                ``../datasets`` and as a prefix of model and figure file names.
            selector_type: Selector identifier embedded in model/figure file names.
            loss_function: Suffix of the ``../../results_<loss_function>``
                directory from which pickled models are read.
        """
        self.problem, self.selector_type, self.loss_function = problem, selector_type, loss_function

        figures = Path('./figures')
        self.figures_dir = figures
        figures.mkdir(exist_ok=True)

        # Order matters: load_models reads self.llm_extractor_name set by load_data.
        self.load_data()
        self.load_models()
    
    def load_data(self):
        """Load the mzn2feat and LLM feature tables (features only - no performance data).

        Reads ``features_train.csv`` / ``features_test.csv`` from
        ``../datasets/<problem>/mzn2feat`` and from the alphabetically first
        ``../datasets/<problem>/lmtuner*`` directory. The tables are stored as
        ``mzn2feat_train``, ``mzn2feat_test``, ``llm_train`` and ``llm_test``.
        Also sets ``llm_extractor_name`` and ``base_path``.

        Raises:
            FileNotFoundError: if no ``lmtuner*`` extractor directory exists
                (pandas raises it too if an expected CSV is missing).
        """
        # Paths are relative to the working directory (expected: src/visualization/).
        base_path = Path('../datasets') / self.problem

        mzn2feat_path = base_path / 'mzn2feat'
        self.mzn2feat_train = pd.read_csv(mzn2feat_path / 'features_train.csv')
        self.mzn2feat_test = pd.read_csv(mzn2feat_path / 'features_test.csv')

        # Path.iterdir() yields entries in arbitrary, filesystem-dependent order;
        # sort so that the chosen extractor is reproducible across machines/runs.
        llm_extractors = sorted(
            d for d in base_path.iterdir()
            if d.is_dir() and d.name.startswith('lmtuner')
        )
        if not llm_extractors:
            # Without this check the method would fail below with an
            # AttributeError on self.llm_extractor_name; fail early and clearly.
            raise FileNotFoundError(
                f"No LLM extractor directory (lmtuner*) found in {base_path}")

        llm_path = llm_extractors[0]
        self.llm_extractor_name = llm_path.name
        self.llm_train = pd.read_csv(llm_path / 'features_train.csv')
        self.llm_test = pd.read_csv(llm_path / 'features_test.csv')

        # "- 1" excludes the 'filename' column.
        print(f"Loaded data for {self.problem}:")
        print(f"  mzn2feat: {self.mzn2feat_train.shape[1]-1} features")
        print(f"  LLM ({self.llm_extractor_name}): {self.llm_train.shape[1]-1} features")

        llm_feature_names = [col for col in self.llm_train.columns if col != 'filename']
        print(f"\n📋 LLM-based feature names ({len(llm_feature_names)} total):")
        for i, feat_name in enumerate(llm_feature_names, 1):
            print(f"  {i:2d}. {feat_name}")

        # Kept for later methods that need the problem's dataset directory.
        self.base_path = base_path
    
    def load_models(self):
        """Load the pickled mzn2feat and LLM selectors used for importance extraction.

        Looks in ``../../results_<loss_function>/`` for
        ``<problem>_mzn2feat_<selector_type>.pkl`` and
        ``<problem>_<llm_extractor_name>_<selector_type>.pkl``. For each file
        that exists, ``<prefix>_model`` and ``<prefix>_scaler`` are set, where
        ``prefix`` is ``mzn2feat`` or ``llm``. When a file is missing, the
        corresponding attributes are left unset; the debug print below checks
        them with ``hasattr``.

        Requires ``self.llm_extractor_name`` (set by ``load_data``).
        """
        results_dir = Path(f'../../results_{self.loss_function}')

        mzn2feat_model_path = results_dir / f'{self.problem}_mzn2feat_{self.selector_type}.pkl'
        loaded = self._load_pickled_model(mzn2feat_model_path)
        if loaded is not None:
            self.mzn2feat_model, self.mzn2feat_scaler = loaded

        llm_model_path = results_dir / f'{self.problem}_{self.llm_extractor_name}_{self.selector_type}.pkl'
        loaded = self._load_pickled_model(llm_model_path)
        if loaded is not None:
            self.llm_model, self.llm_scaler = loaded

        print(f"Loaded models: mzn2feat={mzn2feat_model_path.exists()}, "
              f"LLM={llm_model_path.exists()}")

        # Print model types for debugging
        if hasattr(self, 'mzn2feat_model'):
            print(f"  mzn2feat model type: {type(self.mzn2feat_model).__name__}")
        if hasattr(self, 'llm_model'):
            print(f"  LLM model type: {type(self.llm_model).__name__}")

    @staticmethod
    def _load_pickled_model(model_path):
        """Return ``(model, scaler_or_None)`` read from ``model_path``, or ``None`` if the file is absent."""
        if not model_path.exists():
            return None
        # SECURITY: pickle.load can execute arbitrary code embedded in the file;
        # only load result files produced by this project's own trainer.
        with open(model_path, 'rb') as f:
            model_data = pickle.load(f)
        # The file holds either a bare model or a (model, scaler, ...) tuple.
        if isinstance(model_data, tuple):
            return model_data[0], (model_data[1] if len(model_data) > 1 else None)
        return model_data, None
    
    def extract_feature_importance(self, model, feature_names):
        """
        Extract per-feature importance scores from a trained model.

        Dispatch order:
        - ``feature_importances_`` (e.g. Random Forest): the model's own
          impurity-based importances.
        - ``coef_`` (linear models): mean absolute coefficient across classes.
        - ``estimator_`` wrapper: the wrapped estimator's ``feature_importances_``,
          or, if it is an ensemble exposing ``get_models_with_weights``, the
          ensemble-weighted importance of its members.
        - ``get_models_with_weights`` (AutoSklearn-style ensemble): weighted
          average of the members' importances.
        - Otherwise: ``calculate_permutation_importance`` (currently all zeros).

        Interpretation: for scikit-learn tree ensembles the importances are
        normally non-negative and sum to 1.0, so 0.1 means the feature accounts
        for 10% of the total impurity decrease. Coefficient-based scores are NOT
        normalised and depend on feature scaling.

        Args:
            model: Trained model object.
            feature_names: Feature names in the column order the model was trained on.

        Returns:
            pd.Series of importance values indexed by ``feature_names``.
        """
        if hasattr(model, 'feature_importances_'):
            return pd.Series(model.feature_importances_, index=feature_names)
        if hasattr(model, 'coef_'):
            # coef_ is (n_classes, n_features) for multiclass, (1, n_features) for
            # binary and may be 1-D for some estimators; coef_[0] would keep only
            # the first class, so average |coef| over all rows instead.
            coef = np.atleast_2d(model.coef_)
            return pd.Series(np.abs(coef).mean(axis=0), index=feature_names)
        if hasattr(model, 'estimator_'):
            inner = model.estimator_
            if hasattr(inner, 'feature_importances_'):
                return pd.Series(inner.feature_importances_, index=feature_names)
            if hasattr(inner, 'get_models_with_weights'):
                # The capability was detected on the wrapped estimator, so that is
                # the object whose get_models_with_weights() must be called; the
                # outer wrapper need not expose it.
                return self.extract_autosklearn_ensemble_importance(inner, feature_names)
        elif hasattr(model, 'get_models_with_weights'):
            return self.extract_autosklearn_ensemble_importance(model, feature_names)

        # Fallback when no importance source was found.
        return self.calculate_permutation_importance(model, feature_names)
    
    def extract_autosklearn_ensemble_importance(self, model, feature_names):
        """Compute the ensemble-weighted feature importance of an AutoSklearn-style ensemble.

        ``model.get_models_with_weights()`` is expected to yield
        ``(weight, member)`` pairs.
        - Members exposing ``feature_importances_`` contribute it directly.
        - Members exposing ``coef_`` contribute their mean absolute coefficient
          across classes.
        - Other members are skipped.
        The weighted sum is divided by the total weight of the contributing members.

        Returns:
            pd.Series indexed by ``feature_names``. It is all zeros if no member
            contributes or if the ensemble cannot be inspected; in the latter
            case a message is printed.
        """
        zeros = pd.Series(np.zeros(len(feature_names)), index=feature_names)
        try:
            models_with_weights = model.get_models_with_weights()
            ensemble_importance = np.zeros(len(feature_names))
            total_weight = 0
            for weight, member in models_with_weights:
                if hasattr(member, 'feature_importances_'):
                    member_importance = np.asarray(member.feature_importances_, dtype=float)
                elif hasattr(member, 'coef_'):
                    # Average over all class rows rather than keeping coef_[0] only.
                    member_importance = np.abs(np.atleast_2d(member.coef_)).mean(axis=0)
                else:
                    continue
                ensemble_importance += weight * member_importance
                total_weight += weight
        except Exception as exc:
            # Deliberate best-effort fallback, but no longer silent. print() is
            # used because this module disables Python warnings globally.
            print(f"   Warning: could not extract AutoSklearn ensemble importance "
                  f"({type(exc).__name__}: {exc}); using zeros")
            return zeros

        if total_weight > 0:
            return pd.Series(ensemble_importance / total_weight, index=feature_names)
        return zeros
    
    def calculate_permutation_importance(self, model, feature_names):
        """Fallback importance for models without a native importance attribute.

        Permutation importance is intentionally not computed, because it would
        need the performance data. Every feature therefore receives 0.0.
        ``model`` is accepted for interface compatibility but is unused.
        """
        feature_count = len(feature_names)
        zero_scores = np.zeros(feature_count)
        return pd.Series(zero_scores, index=feature_names)
    
    def plot_feature_importance_comparison(self, top_k=20):
        """Plot the top-k features of both extractors on one importance-sorted bar chart.

        The ``top_k`` most important mzn2feat features (orange, label prefix
        ``M:``) and LLM features (steel blue, prefix ``L:``) are merged and
        sorted by importance, with the most important bar at the top. The figure
        is saved as
        ``<figures_dir>/<problem>_<selector_type>_feature_importance_comparison.pdf``.

        Returns:
            (mzn2feat_importance, llm_importance): full importance Series covering
            every feature of each extractor.
        """
        from matplotlib.patches import Patch

        mzn2feat_features = [col for col in self.mzn2feat_train.columns if col != 'filename']
        llm_features = [col for col in self.llm_train.columns if col != 'filename']

        mzn2feat_importance = self.extract_feature_importance(self.mzn2feat_model, mzn2feat_features)
        llm_importance = self.extract_feature_importance(self.llm_model, llm_features)

        # (importance, display label, bar colour) for each extractor's top-k.
        entries = [(imp, f'M: {self._shorten_feature_name(feat)}', 'orange')
                   for feat, imp in mzn2feat_importance.nlargest(top_k).items()]
        entries += [(imp, f'L: {self._shorten_feature_name(feat)}', 'steelblue')
                    for feat, imp in llm_importance.nlargest(top_k).items()]
        # Stable sort (also with reverse=True): ties keep mzn2feat-before-LLM order.
        entries.sort(key=lambda entry: entry[0], reverse=True)

        importances = [entry[0] for entry in entries]
        labels = [entry[1] for entry in entries]
        colors = [entry[2] for entry in entries]

        fig, ax = plt.subplots(1, 1, figsize=(14, 22))

        y_positions = range(len(importances))
        ax.barh(y_positions, importances, color=colors, alpha=0.8)
        ax.set_yticks(y_positions)
        ax.set_yticklabels(labels)
        ax.set_xlabel('Feature Importance', fontsize=30)
        ax.invert_yaxis()  # most important feature at the top

        legend_elements = [Patch(facecolor='orange', alpha=0.8, label=f'mzn2feat ({len(mzn2feat_features)} features)'),
                           Patch(facecolor='steelblue', alpha=0.8, label=f'LLM ({len(llm_features)} features)')]
        ax.legend(handles=legend_elements, loc='lower right', fontsize=28)

        fig.tight_layout()

        output_path = self.figures_dir / f'{self.problem}_{self.selector_type}_feature_importance_comparison.pdf'
        fig.savefig(output_path, dpi=300, bbox_inches='tight')
        plt.show()
        # Release the figure. With non-interactive backends, show() is a no-op,
        # so figures would otherwise accumulate across calls.
        plt.close(fig)

        return mzn2feat_importance, llm_importance
    
    def _shorten_feature_name(self, feature_name):
        """Return a compact display label (at most 15 characters) for a feature name.

        - ``characteristic_`` becomes ``ch_``.
        - Leading ``v_``, ``c_`` and ``o_`` prefixes are dropped, applied in
          that order.
        - Common fragments are abbreviated: ``_dom_``->``_d_``, ``_avg``->``_a``,
          ``_max``->``_m``, ``_cons``->``_c``, ``_vars``->``_v``.
        - Names still longer than 15 characters are cut to 12 characters plus ``...``.
        """
        name = feature_name.replace('characteristic_', 'ch_')
        # Strip prefixes only at the start: a global str.replace also mangled the
        # middle of names (e.g. 'gc_global_cons' -> 'gglobal_c', 'ratio_x' -> 'ratix').
        for prefix in ('v_', 'c_', 'o_'):
            if name.startswith(prefix):
                name = name[len(prefix):]
        name = name.replace('_dom_', '_d_').replace('_avg', '_a').replace('_max', '_m')
        name = name.replace('_cons', '_c').replace('_vars', '_v')

        if len(name) > 15:
            name = name[:12] + '...'

        return name
    
    def plot_feature_importance_heatmap(self, problems=('FLECC', 'car_sequencing', 'vrp'), top_k=15):
        """Draw single-row heatmaps of the top-k mzn2feat and LLM feature importances.

        Only the currently loaded problem (``self.problem``) is plotted. The
        ``problems`` argument is reserved for a multi-problem version and is
        currently unused; its default is a tuple so no mutable default is
        shared between calls. The figure is saved as
        ``<figures_dir>/<problem>_<selector_type>_importance_heatmap.pdf``.
        """
        def tick_label(name):
            # Keep long feature names readable on the x axis.
            return name[:20] + '...' if len(name) > 20 else name

        mzn2feat_features = [col for col in self.mzn2feat_train.columns if col != 'filename']
        llm_features = [col for col in self.llm_train.columns if col != 'filename']

        mzn2feat_importance = self.extract_feature_importance(self.mzn2feat_model, mzn2feat_features)
        llm_importance = self.extract_feature_importance(self.llm_model, llm_features)

        fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 10))

        panels = ((ax1, mzn2feat_importance, 'Oranges', 'mzn2feat Feature Importance'),
                  (ax2, llm_importance, 'Blues', 'LLM Feature Importance'))
        for ax, importance, cmap, title in panels:
            top = importance.nlargest(top_k)
            sns.heatmap(top.values.reshape(1, -1),
                        xticklabels=[tick_label(name) for name in top.index],
                        yticklabels=[self.problem],
                        cmap=cmap, ax=ax, cbar_kws={'label': 'Importance'})
            ax.set_title(title, fontsize=28)
            ax.tick_params(axis='x', rotation=45)

        plt.suptitle(f'Feature Importance Heatmap: {self.selector_type}',
                     fontsize=14, fontweight='bold')
        fig.tight_layout()

        output_path = self.figures_dir / f'{self.problem}_{self.selector_type}_importance_heatmap.pdf'
        fig.savefig(output_path, dpi=300, bbox_inches='tight')
        plt.show()
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)
    
    def perform_feature_ablation_study(self, n_trials=5):
        """Placeholder for the removed feature ablation study.

        The study was dropped because it required reconstructing performance
        data; the trained-model analyses are used instead. This method prints a
        notice and returns ``(None, None)``. ``n_trials`` is accepted for
        compatibility only.
        """
        notice = "   Feature ablation study skipped - using trained model analysis instead"
        print(notice)
        skipped_result = (None, None)
        return skipped_result
    
    def plot_feature_correlation_matrix(self, top_k=20):
        """Plot intra-extractor correlation heatmaps of the top-k features, side by side.

        The left panel shows mzn2feat and the right panel shows LLM. Each panel
        is the lower triangle of the Pearson correlation matrix of that
        extractor's ``top_k`` most important training features, titled with the
        mean |r| over distinct feature pairs. The figure is saved as
        ``<problem>_<selector_type>_correlation_matrix.pdf``; statistics are then
        printed via ``analyze_correlation_statistics``.
        """
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))

        mzn2feat_corr = self._plot_correlation_panel(
            ax1, self.mzn2feat_train, self.mzn2feat_model, 'mzn2feat', top_k)
        llm_corr = self._plot_correlation_panel(
            ax2, self.llm_train, self.llm_model, 'LLM', top_k)

        fig.tight_layout()

        output_path = self.figures_dir / f'{self.problem}_{self.selector_type}_correlation_matrix.pdf'
        fig.savefig(output_path, dpi=300, bbox_inches='tight')
        plt.show()
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)

        self.analyze_correlation_statistics(mzn2feat_corr, llm_corr)

    def _plot_correlation_panel(self, ax, train_df, model, label, top_k):
        """Draw one lower-triangle correlation heatmap of the top-k features on ``ax``; return the correlation matrix."""
        features = train_df.drop(columns=['filename'])
        importance = self.extract_feature_importance(model, features.columns)
        top_features = importance.nlargest(top_k).index
        corr = features[top_features].corr()

        # Mean |r| over distinct pairs (strict upper triangle). Index it directly
        # rather than via corr.where(mask).stack(), whose NaN dropping is legacy
        # pandas behaviour; undefined correlations (constant features) are skipped.
        values = corr.to_numpy(dtype=float)
        upper = values[np.triu_indices_from(values, k=1)]
        avg_corr = np.mean(np.abs(upper[~np.isnan(upper)]))

        mask = np.triu(np.ones_like(corr, dtype=bool))  # hide diagonal and upper half
        hm = sns.heatmap(corr, mask=mask, cmap='RdBu_r', center=0,
                         square=True, ax=ax, cbar_kws={'label': 'Correlation'},
                         xticklabels=False, yticklabels=False)
        ax.set_title(f'{label} (Avg |r| = {avg_corr:.3f})', fontsize=28)
        colorbar = hm.collections[0].colorbar
        colorbar.ax.tick_params(labelsize=26)
        colorbar.set_label('Correlation', size=26)
        return corr
    
    def analyze_correlation_statistics(self, mzn2feat_corr, llm_corr):
        """Print pairwise-correlation statistics for two feature correlation matrices.

        Each square matrix is summarised over its strict upper triangle (each
        unordered feature pair once, diagonal excluded, NaN entries skipped).
        The summary gives the mean and std of |r| and the number of pairs with
        |r| > 0.7 and |r| > 0.5. A final line states how the LLM mean |r|
        compares with the mzn2feat one.

        Args:
            mzn2feat_corr: Square correlation DataFrame (e.g. from ``DataFrame.corr()``).
            llm_corr: Square correlation DataFrame for the LLM features.
        """
        def upper_abs(corr):
            # Index the upper triangle directly instead of corr.where(mask).stack():
            # stack() only drops the masked NaNs under its legacy implementation
            # (deprecated in pandas 2.1).
            values = np.asarray(corr, dtype=float)
            upper = values[np.triu_indices_from(values, k=1)]
            return np.abs(upper[~np.isnan(upper)])

        mzn2feat_abs = upper_abs(mzn2feat_corr)
        llm_abs = upper_abs(llm_corr)

        print(f"\n📊 CORRELATION ANALYSIS RESULTS:")
        print(f"{'='*50}")
        for header, abs_values in ((f"mzn2feat Features:", mzn2feat_abs),
                                   (f"\nLLM Features:", llm_abs)):
            print(header)
            print(f"  • Mean |correlation|: {np.mean(abs_values):.3f}")
            print(f"  • Std |correlation|:  {np.std(abs_values):.3f}")
            print(f"  • High correlations (|r|>0.7): {np.sum(abs_values > 0.7)}")
            print(f"  • Strong correlations (|r|>0.5): {np.sum(abs_values > 0.5)}")

        mzn2_mean = np.mean(mzn2feat_abs)
        llm_mean = np.mean(llm_abs)

        print(f"\n🎯 INTERPRETATION:")
        if not (np.isfinite(mzn2_mean) and np.isfinite(llm_mean)) or mzn2_mean == 0:
            # A relative percentage is undefined against a zero/NaN baseline.
            print(f"  • Relative comparison not available "
                  f"(mzn2feat mean |r| = {mzn2_mean:.3f}, LLM mean |r| = {llm_mean:.3f})")
        elif llm_mean < mzn2_mean:
            improvement = ((mzn2_mean - llm_mean) / mzn2_mean) * 100
            print(f"  • LLM features show {improvement:.1f}% lower correlation → More diverse features")
        else:
            difference = ((llm_mean - mzn2_mean) / mzn2_mean) * 100
            print(f"  • LLM features show {difference:.1f}% higher correlation")

        print(f"{'='*50}")
    
    def plot_cross_correlation_matrix(self, top_k=20):
        """Plot Pearson correlations between the top-k mzn2feat and top-k LLM features.

        Rows are the ``top_k`` most important mzn2feat features; columns are the
        ``top_k`` most important LLM features (both ranked by
        ``extract_feature_importance``). Training rows are matched on
        ``filename``. The figure is saved as
        ``<problem>_<selector_type>_cross_correlation.pdf``, and summary
        statistics are printed via ``analyze_cross_correlation_statistics``.
        """
        fig, ax = plt.subplots(1, 1, figsize=(12, 10))

        mzn2feat_features = self.mzn2feat_train.drop(columns=['filename'])
        mzn2feat_importance = self.extract_feature_importance(
            self.mzn2feat_model, mzn2feat_features.columns)
        mzn2feat_top_features = mzn2feat_importance.nlargest(top_k).index

        llm_features = self.llm_train.drop(columns=['filename'])
        llm_importance = self.extract_feature_importance(
            self.llm_model, llm_features.columns)
        llm_top_features = llm_importance.nlargest(top_k).index

        mzn2feat_data = self.mzn2feat_train.set_index('filename')[mzn2feat_top_features]
        llm_data = self.llm_train.set_index('filename')[llm_top_features]

        # Correlate every LLM column with every mzn2feat column. corrwith aligns
        # rows on the filename index. Unlike concat(...).corr().loc[...], this
        # stays correct when both extractors use the same feature name.
        cross_corr = pd.DataFrame(
            {llm_col: mzn2feat_data.corrwith(llm_data[llm_col]) for llm_col in llm_top_features},
            index=mzn2feat_top_features,
        )

        # nan-aware statistics: one constant feature (NaN correlation) should not
        # turn the whole summary into NaN.
        abs_values = np.abs(cross_corr.to_numpy(dtype=float))
        avg_cross_corr = np.nanmean(abs_values)
        max_cross_corr = np.nanmax(abs_values)

        hm = sns.heatmap(cross_corr, cmap='RdBu_r', center=0,
                         cbar_kws={'label': 'Cross-Correlation'},
                         xticklabels=False, yticklabels=False, ax=ax)

        # Report the actual number of features shown instead of a hard-coded 20.
        ax.set_xlabel(f'LLM Features (Top {len(llm_top_features)})', fontsize=30)
        ax.set_ylabel(f'mzn2feat Features (Top {len(mzn2feat_top_features)})', fontsize=30)
        ax.set_title(f'Cross-Correlation: mzn2feat vs LLM Features\n(Avg |r| = {avg_cross_corr:.3f}, Max |r| = {max_cross_corr:.3f})', fontsize=28)

        colorbar = hm.collections[0].colorbar
        colorbar.ax.tick_params(labelsize=26)
        colorbar.set_label('Cross-Correlation', size=26)

        fig.tight_layout()

        output_path = self.figures_dir / f'{self.problem}_{self.selector_type}_cross_correlation.pdf'
        fig.savefig(output_path, dpi=300, bbox_inches='tight')
        plt.show()
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)

        self.analyze_cross_correlation_statistics(cross_corr, avg_cross_corr, max_cross_corr)
    
    def analyze_cross_correlation_statistics(self, cross_corr, avg_cross_corr, max_cross_corr):
        """Print a summary of an mzn2feat-vs-LLM cross-correlation matrix.

        Args:
            cross_corr: DataFrame of correlations (rows = mzn2feat features,
                columns = LLM features).
            avg_cross_corr: Mean |r| over the matrix, computed by the caller.
            max_cross_corr: Max |r| over the matrix, computed by the caller.
        """
        corr_abs = np.abs(cross_corr)
        abs_values = corr_abs.to_numpy(dtype=float)
        total_pairs = cross_corr.size

        # np.argmax would select a NaN cell (e.g. from a constant feature) as the
        # "maximum", so search only the defined correlations.
        if total_pairs and not np.all(np.isnan(abs_values)):
            max_idx = np.unravel_index(np.nanargmax(abs_values), abs_values.shape)
            max_mzn2feat = str(corr_abs.index[max_idx[0]])
            max_llm = corr_abs.columns[max_idx[1]]
            max_corr_value = cross_corr.iloc[max_idx[0], max_idx[1]]
            # Mark truncation only when the name was actually shortened.
            mzn2feat_label = max_mzn2feat if len(max_mzn2feat) <= 30 else max_mzn2feat[:30] + '...'
            most_correlated = f"{mzn2feat_label} ↔ {max_llm} (r = {max_corr_value:.3f})"
        else:
            most_correlated = "n/a (no defined correlations)"

        # NaN compares False, so undefined correlations are never counted.
        high_corr_count = (corr_abs > 0.7).sum().sum()
        moderate_corr_count = ((corr_abs > 0.5) & (corr_abs <= 0.7)).sum().sum()
        denominator = total_pairs or 1  # avoid ZeroDivisionError for an empty matrix

        print(f"\n📊 CROSS-CORRELATION ANALYSIS RESULTS:")
        print(f"{'='*55}")
        print(f"Cross-Feature Analysis (mzn2feat vs LLM):")
        print(f"  • Average |cross-correlation|: {avg_cross_corr:.3f}")
        print(f"  • Maximum |cross-correlation|: {max_cross_corr:.3f}")
        print(f"  • High correlations (|r|>0.7): {high_corr_count}/{total_pairs} ({high_corr_count/denominator*100:.1f}%)")
        print(f"  • Moderate correlations (0.5<|r|≤0.7): {moderate_corr_count}/{total_pairs} ({moderate_corr_count/denominator*100:.1f}%)")
        print(f"  • Most correlated pair: {most_correlated}")

        print(f"\n🎯 CROSS-CORRELATION INTERPRETATION:")
        if avg_cross_corr < 0.3:
            print(f"  • Low cross-correlation ({avg_cross_corr:.3f}) → Feature methods capture different aspects")
        elif avg_cross_corr < 0.6:
            print(f"  • Moderate cross-correlation ({avg_cross_corr:.3f}) → Some overlapping information")
        else:
            print(f"  • High cross-correlation ({avg_cross_corr:.3f}) → Feature methods capture similar information")

        if high_corr_count == 0:
            print(f"  • No highly correlated feature pairs → Methods are complementary")
        else:
            print(f"  • {high_corr_count} highly correlated pairs → Some redundancy between methods")

        print(f"{'='*55}")
    
    def calculate_accuracy_vs_top_features(self):
        """Measure Random Forest train/test accuracy when using only the top-N features.

        For N = 1, 3, 5, ... up to min(50, #mzn2feat features, #LLM features),
        and separately for each extractor, the method fits a fresh
        ``RandomForestClassifier(n_estimators=100, random_state=42)`` on the N
        most important features. Importance comes from
        ``extract_feature_importance`` on the loaded models, and features are
        standard-scaled first. Labels are the best algorithm per instance, taken
        from ``performance_{train,test}.csv`` (see
        ``_extract_best_algorithms_from_performance``).

        Returns:
            (mzn2feat_train_accuracies, mzn2feat_test_accuracies,
             llm_train_accuracies, llm_test_accuracies, feature_counts)
        """
        print("\n🔄 Calculating accuracy vs top N features (this may take a few minutes)...")

        base_path = Path('../datasets') / self.problem
        mzn2feat_path = base_path / 'mzn2feat'
        llm_path = base_path / self.llm_extractor_name

        # Labels do not depend on N, so derive them once instead of on every iteration.
        # NOTE(review): performance_*.csv rows are assumed to be in the same order
        # as the matching features_*.csv rows; nothing here joins them by
        # filename — confirm against the dataset generator.
        mzn2feat_y_train = self._extract_best_algorithms_from_performance(
            pd.read_csv(mzn2feat_path / 'performance_train.csv'))
        mzn2feat_y_test = self._extract_best_algorithms_from_performance(
            pd.read_csv(mzn2feat_path / 'performance_test.csv'))
        llm_y_train = self._extract_best_algorithms_from_performance(
            pd.read_csv(llm_path / 'performance_train.csv'))
        llm_y_test = self._extract_best_algorithms_from_performance(
            pd.read_csv(llm_path / 'performance_test.csv'))

        mzn2feat_features = [col for col in self.mzn2feat_train.columns if col != 'filename']
        llm_features = [col for col in self.llm_train.columns if col != 'filename']

        mzn2feat_importance = self.extract_feature_importance(self.mzn2feat_model, mzn2feat_features)
        llm_importance = self.extract_feature_importance(self.llm_model, llm_features)

        # Capped at 50; only every other count is tested, for speed.
        max_features = min(50, len(mzn2feat_features), len(llm_features))
        feature_counts = list(range(1, max_features + 1, 2))

        mzn2feat_train_accuracies = []
        mzn2feat_test_accuracies = []
        llm_train_accuracies = []
        llm_test_accuracies = []

        for n_features in feature_counts:
            print(f"   Testing with {n_features} features...")

            mzn2feat_train_acc, mzn2feat_test_acc = self._top_n_rf_accuracy(
                mzn2feat_importance, n_features, self.mzn2feat_train, self.mzn2feat_test,
                mzn2feat_y_train, mzn2feat_y_test)
            mzn2feat_train_accuracies.append(mzn2feat_train_acc)
            mzn2feat_test_accuracies.append(mzn2feat_test_acc)
            print(f"     mzn2feat: train={mzn2feat_train_acc:.3f}, test={mzn2feat_test_acc:.3f}")

            llm_train_acc, llm_test_acc = self._top_n_rf_accuracy(
                llm_importance, n_features, self.llm_train, self.llm_test,
                llm_y_train, llm_y_test)
            llm_train_accuracies.append(llm_train_acc)
            llm_test_accuracies.append(llm_test_acc)
            print(f"     LLM: train={llm_train_acc:.3f}, test={llm_test_acc:.3f}")

        print(f"   ✓ Completed accuracy calculations for {len(feature_counts)} feature counts")
        return (mzn2feat_train_accuracies, mzn2feat_test_accuracies,
                llm_train_accuracies, llm_test_accuracies, feature_counts)

    @staticmethod
    def _top_n_rf_accuracy(importance, n_features, train_df, test_df, y_train, y_test):
        """Fit a scaled Random Forest on the ``n_features`` most important columns; return (train_acc, test_acc)."""
        top_features = importance.nlargest(n_features).index.tolist()
        scaler = StandardScaler()
        X_train = scaler.fit_transform(train_df[top_features])
        X_test = scaler.transform(test_df[top_features])
        rf = RandomForestClassifier(n_estimators=100, random_state=42, n_jobs=1)
        rf.fit(X_train, y_train)
        return (accuracy_score(y_train, rf.predict(X_train)),
                accuracy_score(y_test, rf.predict(X_test)))
    
    def _extract_best_algorithms_from_performance(self, performance_df: pd.DataFrame) -> np.ndarray:
        """Return, for each row, the name of the algorithm with the lowest value.

        Column 0 is the instance identifier and is skipped; every other column
        is one algorithm's value (lower is better). Details:
        - NaN and +inf entries are ignored.
        - Ties go to the left-most column.
        - Rows with no usable value fall back to the first algorithm column.
        Intended to mirror the label logic of the official trainer.
        """
        candidates = list(performance_df.columns[1:])

        def pick(record):
            usable = [
                (position, record[name])
                for position, name in enumerate(candidates)
                if pd.notna(record[name]) and record[name] != float('inf')
            ]
            if not usable:
                return candidates[0]
            # min() returns the first minimal pair, so earlier columns win ties.
            return candidates[min(usable, key=lambda pair: pair[1])[0]]

        return np.array([pick(record) for _, record in performance_df.iterrows()])
    
    def analyze_model_feature_usage(self):
        """Print a side-by-side summary of how the two trained models use their features.

        The report covers:
        - the model types;
        - how many features have importance > 0.001;
        - the top-5 features of each model;
        - the share of total importance held by each model's top-10 features.
        """
        print(f"\n🔍 MODEL-BASED FEATURE ANALYSIS:")
        print(f"{'='*60}")

        print(f"Feature Importance Definition:")
        print(f"  • Random Forest: Gini impurity-based importance")
        print(f"  • Values 0-1 (higher = more important for algorithm selection)")
        print(f"  • Sum of all importances = 1.0")
        print(f"  • Example: 0.1 = feature contributes 10% to decision process")
        print()

        mzn2feat_features = [col for col in self.mzn2feat_train.columns if col != 'filename']
        llm_features = [col for col in self.llm_train.columns if col != 'filename']

        mzn2feat_importance = self.extract_feature_importance(self.mzn2feat_model, mzn2feat_features)
        llm_importance = self.extract_feature_importance(self.llm_model, llm_features)

        print(f"Model Architecture Analysis:")
        print(f"  mzn2feat model: {type(self.mzn2feat_model).__name__}")
        print(f"  LLM model: {type(self.llm_model).__name__}")

        # A feature counts as "used" when its importance exceeds 0.001.
        mzn2feat_nonzero = (mzn2feat_importance > 0.001).sum()
        llm_nonzero = (llm_importance > 0.001).sum()

        print(f"\nFeature Utilization:")
        print(f"  mzn2feat: {mzn2feat_nonzero}/{len(mzn2feat_features)} features used effectively ({mzn2feat_nonzero/len(mzn2feat_features)*100:.1f}%)")
        print(f"  LLM: {llm_nonzero}/{len(llm_features)} features used effectively ({llm_nonzero/len(llm_features)*100:.1f}%)")

        print(f"\nTop 5 Most Important Features:")
        print(f"  mzn2feat:")
        for i, (feat, imp) in enumerate(mzn2feat_importance.nlargest(5).items()):
            print(f"    {i+1}. {feat[:40]}: {imp:.4f}")

        print(f"  LLM:")
        for i, (feat, imp) in enumerate(llm_importance.nlargest(5).items()):
            print(f"    {i+1}. {feat}: {imp:.4f}")

        def top10_share(importance):
            total = importance.sum()
            # All-zero importances (e.g. the zero fallback) would give 0/0 = NaN.
            return importance.nlargest(10).sum() / total if total > 0 else float('nan')

        mzn2feat_top10_share = top10_share(mzn2feat_importance)
        llm_top10_share = top10_share(llm_importance)

        print(f"\nFeature Concentration (Top 10 features):")
        print(f"  mzn2feat: {mzn2feat_top10_share:.1%} of total importance")
        print(f"  LLM: {llm_top10_share:.1%} of total importance")

        # NaN comparisons are always False and would wrongly fall through to
        # "LLM more concentrated"; ties are also reported explicitly.
        if np.isnan(mzn2feat_top10_share) or np.isnan(llm_top10_share):
            print(f"  → Concentration comparison not available (zero total importance)")
        elif mzn2feat_top10_share > llm_top10_share:
            print(f"  → mzn2feat shows more concentrated feature usage")
        elif llm_top10_share > mzn2feat_top10_share:
            print(f"  → LLM shows more concentrated feature usage")
        else:
            print(f"  → Both show equally concentrated feature usage")
    
    def create_feature_importance_distribution_plot(self):
        """Create feature importance distribution comparison plot."""
        fig, ax = plt.subplots(1, 1, figsize=(12, 8))
        
        # Get feature importance
        mzn2feat_features = [col for col in self.mzn2feat_train.columns if col != 'filename']
        llm_features = [col for col in self.llm_train.columns if col != 'filename']
        
        mzn2feat_importance = self.extract_feature_importance(self.mzn2feat_model, mzn2feat_features)
        llm_importance = self.extract_feature_importance(self.llm_model, llm_features)
        
        # Merged Feature importance distribution with categorical ranges
        # Define importance ranges
        ranges = [(0, 0.001), (0.001, 0.005), (0.005, 0.01), (0.01, 0.02), (0.02, 0.05), (0.05, float('inf'))]
        range_labels = ['0-0.001', '0.001-0.005', '0.005-0.01', '0.01-0.02', '0.02-0.05', '>0.05']
        
        # Calculate ratios for each range
        mzn2feat_ratios = []
        llm_ratios = []
        
        for r_min, r_max in ranges:
            # Count features in range
            mzn2feat_count = ((mzn2feat_importance >= r_min) & (mzn2feat_importance < r_max)).sum()
            llm_count = ((llm_importance >= r_min) & (llm_importance < r_max)).sum()
            
            # Calculate ratios
            mzn2feat_ratios.append(mzn2feat_count / len(mzn2feat_importance))
            llm_ratios.append(llm_count / len(llm_importance))
        
        # Create grouped bar chart
        x = np.arange(len(range_labels))
        width = 0.35
        
        ax.bar(x - width/2, mzn2feat_ratios, width, label='mzn2feat', color='orange', alpha=0.8)
        ax.bar(x + width/2, llm_ratios, width, label='LLM', color='steelblue', alpha=0.8)
        
        ax.set_xlabel('Feature Importance Range', fontsize=36)
        ax.set_ylabel('Ratio of Features', fontsize=36)
        ax.set_title('Feature Importance Distribution Comparison', fontsize=36)
        ax.set_xticks(x)
        ax.set_xticklabels(range_labels, rotation=45, ha='right', fontsize=36)
        ax.legend(fontsize=36, bbox_to_anchor=(1.05, 1), loc='upper left')
        ax.grid(True, alpha=0.3, axis='y')
        ax.tick_params(axis='y', labelsize=36)  # Explicitly set y-tick font size
        
        plt.tight_layout()
        
        output_path = self.figures_dir / f'{self.problem}_{self.selector_type}_feature_distribution.pdf'
        plt.savefig(output_path, dpi=300, bbox_inches='tight')
        plt.show()
    
    def create_accuracy_analysis_plot(self):
        """Plot train/test accuracy of both extractors against the number of top features used.

        Data comes from ``self.calculate_accuracy_vs_top_features()``, which is
        unpacked as ``(mzn2feat_train_acc, mzn2feat_test_acc, llm_train_acc,
        llm_test_acc, feature_counts)``. Training curves are dashed and testing
        curves are solid.

        The figure is saved as
        ``<problem>_<selector_type>_accuracy_analysis.pdf`` in
        ``self.figures_dir``, shown, and then closed so open figures do not
        accumulate.
        """
        # Compute the data first so a failure here cannot leave an orphaned figure open.
        results = self.calculate_accuracy_vs_top_features()
        mzn2feat_train_acc, mzn2feat_test_acc, llm_train_acc, llm_test_acc, feature_counts = results

        fig, ax = plt.subplots(1, 1, figsize=(12, 8))

        # Plot training accuracies (dashed lines)
        ax.plot(feature_counts, mzn2feat_train_acc, 'o--',
                color='orange', label='mzn2feat (train)', linewidth=3, markersize=8, alpha=0.7)
        ax.plot(feature_counts, llm_train_acc, 's--',
                color='steelblue', label='LLM (train)', linewidth=3, markersize=8, alpha=0.7)

        # Plot testing accuracies (solid lines)
        ax.plot(feature_counts, mzn2feat_test_acc, 'o-',
                color='orange', label='mzn2feat (test)', linewidth=3, markersize=8)
        ax.plot(feature_counts, llm_test_acc, 's-',
                color='steelblue', label='LLM (test)', linewidth=3, markersize=8)

        ax.set_xlabel('Number of Top Features', fontsize=36)
        ax.set_ylabel('Accuracy', fontsize=36)
        ax.legend(fontsize=36, bbox_to_anchor=(1.05, 1), loc='upper left')
        ax.grid(True, alpha=0.3)
        ax.tick_params(axis='x', labelsize=30)  # Explicitly set x-tick font size
        ax.tick_params(axis='y', labelsize=30)  # Explicitly set y-tick font size
        ax.set_title('Accuracy vs Top N Features', fontsize=36)

        fig.tight_layout()

        output_path = self.figures_dir / f'{self.problem}_{self.selector_type}_accuracy_analysis.pdf'
        fig.savefig(output_path, dpi=300, bbox_inches='tight')
        plt.show()
        plt.close(fig)
    
    def generate_all_visualizations(self):
        """Generate complete set of visualizations."""
        print(f"Generating visualizations for {self.problem} - {self.selector_type}")
        print("=" * 60)
        
        visualizations = [
            ("1. Creating feature correlation matrices...", self.plot_feature_correlation_matrix),
            ("2. Creating cross-correlation analysis...", self.plot_cross_correlation_matrix),
            ("3. Analyzing model-based feature usage...", self.analyze_model_feature_usage),
            ("4. Creating feature importance distribution plot...", self.create_feature_importance_distribution_plot),
            ("5. Creating accuracy analysis plot...", self.create_accuracy_analysis_plot),
        ]
        
        for desc, func in visualizations:
            try:
                print(desc)
                func()
                print("   ✓ Success")
            except Exception as e:
                print(f"   ✗ Error: {str(e)}")
                continue
        
        print(f"\nVisualizations saved to: {self.figures_dir}/")
        print("Files generated:")
        for file in self.figures_dir.glob(f"{self.problem}_{self.selector_type}_*.pdf"):
            print(f"  - {file.name}")


def main(argv=None):
    """Parse command-line options and generate every visualization for one problem/selector.

    Args:
        argv: Argument list to parse, without the program name. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, so command-line use
            is unchanged. Passing a list allows programmatic invocation.
    """
    parser = argparse.ArgumentParser(
        description="Generate feature importance and ablation visualizations"
    )
    parser.add_argument('--problem', type=str, required=True,
                        choices=['FLECC', 'car_sequencing', 'vrp'],
                        help='Problem to analyze')
    parser.add_argument('--selector', type=str, default='random_forest',
                        choices=['random_forest', 'autosklearn', 'autosklearn_conservative'],
                        help='Selector type to analyze')
    parser.add_argument('--loss-function', type=str, default='accuracy',
                        choices=['accuracy', 'ranking'],
                        help='Loss function used for training')

    args = parser.parse_args(argv)

    toolkit = FeatureVisualizationToolkit(
        problem=args.problem,
        selector_type=args.selector,
        loss_function=args.loss_function
    )

    toolkit.generate_all_visualizations()


if __name__ == "__main__":
    main()