#!/usr/bin/env python3
"""
Comprehensive 8 Datasets Analysis
- Analyze all 8 datasets with the same alpha values
- Cross-validation for reliable results
- Save all results to JSON for future analysis
- Include feature weights and category analysis
"""

import json
import warnings

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.base import clone
from sklearn.calibration import calibration_curve
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score, brier_score_loss
from sklearn.model_selection import cross_val_score, StratifiedKFold
from sklearn.preprocessing import StandardScaler
warnings.filterwarnings('ignore')

# Import the final feature map
from feature_map_final import extract_final_stable_features, get_feature_info_stable, FEATURE_MAP_FINAL_STABLE

# Feature metadata from the final (stable) feature map; its keys are reported
# as the feature names alongside every result.
feature_info = get_feature_info_stable()
feature_names = list(feature_info.keys())

# Four feature groups, each a contiguous half-open index range [start, stop)
# over the 48 extracted features.
_CATEGORY_BOUNDS = [
    ('Dynamics', 0, 19),    # indices 0-18
    ('Position', 19, 33),   # indices 19-32
    ('Stability', 33, 43),  # indices 33-42
    ('Structure', 43, 48),  # indices 43-47
]
feature_categories = {
    name: list(range(start, stop)) for name, start, stop in _CATEGORY_BOUNDS
}

# Dataset name -> evaluated JSONL file name, relative to the evaluation
# directory that load_dataset_data reads from.
dataset_files = {
    'SimpleQA': 'gpt-4.1_SimpleQA__full_20250902_evaluated.jsonl',
    'SimpleQA_large': 'gpt-4.1__SimpleQA_large_evaluated.jsonl',
    'HotpotQA': 'gpt-4.1_HotpotQA__full_20250902_evaluated.jsonl',
    'StrategyQA': 'gpt-4.1_StrategyQA__full_20250902_evaluated.jsonl',
    'MATH500': 'gpt-4.1_MATH500__full_20250902_evaluated.jsonl',
    'MMLU-Pro': 'gpt-4.1_MMLU-Pro__full_20250902_evaluated.jsonl',
    'GPQA-diamond': 'gpt-4.1_GPQA-diamond__full_20250902_evaluated.jsonl',
    'HLE': 'gpt-4.1_HLE__full_20250902_evaluated.jsonl',
    'GAIA': 'gpt-4.1_GAIA___full_20250902_evaluated.jsonl'
}

# L1 strengths to sweep (shared with the SimpleQA analysis); each alpha is
# turned into C = 1 / alpha for LogisticRegression.
alpha_values = [0.001, 0.01, 0.1, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 20.0, 50.0]

def load_dataset_data(dataset_name, data_dir="evaluations_0902"):
    """Load the evaluated JSONL records for one dataset.

    Reads ``<data_dir>/<dataset_files[dataset_name]>`` line by line. Blank
    lines are ignored, lines that are not valid JSON are skipped (and counted),
    and only JSON objects containing a ``'detailed_confidence_analysis'`` key
    are kept.

    Args:
        dataset_name: Key into the module-level ``dataset_files`` mapping.
        data_dir: Directory holding the evaluation files. The default is the
            previously hard-coded location, so existing callers are unaffected.

    Returns:
        List of record dicts (possibly empty), or ``None`` if the file does
        not exist.

    Raises:
        ValueError: If ``dataset_name`` is not in ``dataset_files``.
    """
    print(f"📂 Loading {dataset_name} data...")

    if dataset_name not in dataset_files:
        raise ValueError(f"Dataset {dataset_name} not found in mapping")

    data = []
    n_malformed = 0
    file_path = f"{data_dir}/{dataset_files[dataset_name]}"

    try:
        # Explicit UTF-8: the platform default encoding (e.g. cp1252 on
        # Windows) can fail on, or mis-decode, non-ASCII model output.
        with open(file_path, 'r', encoding='utf-8') as f:
            for raw_line in f:
                line = raw_line.strip()
                if not line:
                    continue
                try:
                    item = json.loads(line)
                except json.JSONDecodeError:
                    # Best-effort: skip the broken line but keep a count so
                    # data loss is visible instead of silent.
                    n_malformed += 1
                    continue
                if isinstance(item, dict) and 'detailed_confidence_analysis' in item:
                    data.append(item)
    except FileNotFoundError:
        print(f"File not found: {file_path}")
        return None

    if n_malformed:
        print(f"Skipped {n_malformed} malformed JSON line(s) in {file_path}")
    print(f"Loaded {len(data)} samples from {dataset_name}")
    return data

def extract_features_and_labels(data):
    """Turn loaded records into a feature matrix and a label vector.

    Delegates to ``extract_final_stable_features`` from the final feature map,
    then prints the matrix shape and the count/rate of positive labels.

    Args:
        data: List of record dicts as returned by ``load_dataset_data``.

    Returns:
        ``(features_array, labels_array)`` on success, or ``(None, None)`` if
        extraction or the summary printing raises; the error is printed
        rather than propagated.
    """
    print("🔍 Extracting 48 features using final feature map...")

    try:
        feats, labels = extract_final_stable_features(data)
        print(f"Extracted features shape: {feats.shape}")
        n_positive = np.sum(labels)
        positive_rate = np.mean(labels)
        print(f"Positive samples: {n_positive} ({positive_rate:.1%})")
    except Exception as err:
        print(f"Error extracting features: {err}")
        return None, None
    else:
        return feats, labels

def calculate_ece(y_true, y_prob, n_bins=10):
    """Compute the Expected Calibration Error (ECE) using equal-width bins.

    Predicted probabilities are grouped into ``n_bins`` equal-width bins over
    [0, 1]. For each non-empty bin, the absolute gap between the mean
    predicted probability and the observed positive rate is weighted by the
    fraction of all samples that fall in the bin; the weighted gaps are summed.

    Bins are right-closed, ``(lower, upper]``, except the first bin, which is
    closed on both ends so that a probability of exactly 0.0 is counted rather
    than silently dropped.

    Args:
        y_true: Binary labels (0/1), any array-like.
        y_prob: Predicted positive-class probabilities, any array-like with
            the same length as ``y_true``.
        n_bins: Number of equal-width bins.

    Returns:
        The ECE as a float; 0.0 for empty input.
    """
    # Accept lists as well as arrays: the comparisons and boolean-mask
    # indexing below require NumPy arrays.
    y_true = np.asarray(y_true, dtype=float)
    y_prob = np.asarray(y_prob, dtype=float)
    if y_prob.size == 0:
        return 0.0

    bin_boundaries = np.linspace(0, 1, n_bins + 1)
    bin_lowers = bin_boundaries[:-1]
    bin_uppers = bin_boundaries[1:]

    ece = 0.0
    for bin_idx, (bin_lower, bin_upper) in enumerate(zip(bin_lowers, bin_uppers)):
        # First bin includes its lower edge so y_prob == 0.0 is not lost.
        if bin_idx == 0:
            above_lower = y_prob >= bin_lower
        else:
            above_lower = y_prob > bin_lower
        in_bin = above_lower & (y_prob <= bin_upper)
        prop_in_bin = in_bin.mean()

        if prop_in_bin > 0:
            accuracy_in_bin = y_true[in_bin].mean()
            avg_confidence_in_bin = y_prob[in_bin].mean()
            ece += abs(avg_confidence_in_bin - accuracy_in_bin) * prop_in_bin

    return float(ece)

def evaluate_model_with_cv(model, X, y, cv_folds=5):
    """Estimate AUROC, Brier score and ECE with stratified k-fold CV.

    Each fold trains a fresh, unfitted copy of ``model`` (``sklearn.base.clone``)
    so the estimator passed in by the caller is never refit. Refitting it in
    place would overwrite coefficients the caller already fitted with those of
    the last fold.

    Args:
        model: An unfitted-or-fitted sklearn classifier exposing
            ``predict_proba``; only its hyper-parameters are used.
        X: 2-D feature matrix (samples x features).
        y: Binary label vector.
        cv_folds: Number of stratified folds (shuffled, ``random_state=42``).

    Returns:
        Dict with the mean and (population) standard deviation across folds
        of AUROC, Brier score and ECE: ``auroc_mean``, ``auroc_std``,
        ``brier_mean``, ``brier_std``, ``ece_mean``, ``ece_std``.

    Raises:
        ValueError: Propagated from sklearn, e.g. when a validation fold
            contains only one class (AUROC undefined).
    """
    cv = StratifiedKFold(n_splits=cv_folds, shuffle=True, random_state=42)
    X = np.asarray(X)
    y = np.asarray(y)

    auroc_scores = []
    brier_scores = []
    ece_scores = []

    for train_idx, val_idx in cv.split(X, y):
        X_train_fold, X_val_fold = X[train_idx], X[val_idx]
        y_train_fold, y_val_fold = y[train_idx], y[val_idx]

        # Same hyper-parameters, independent state: the caller's model is untouched.
        fold_model = clone(model)
        fold_model.fit(X_train_fold, y_train_fold)

        y_prob = fold_model.predict_proba(X_val_fold)[:, 1]

        auroc_scores.append(roc_auc_score(y_val_fold, y_prob))
        brier_scores.append(brier_score_loss(y_val_fold, y_prob))
        ece_scores.append(calculate_ece(y_val_fold, y_prob))

    return {
        'auroc_mean': np.mean(auroc_scores),
        'auroc_std': np.std(auroc_scores),
        'brier_mean': np.mean(brier_scores),
        'brier_std': np.std(brier_scores),
        'ece_mean': np.mean(ece_scores),
        'ece_std': np.std(ece_scores)
    }

def analyze_feature_categories(selected_features, coefficients, categories=None):
    """Summarize selected features and their weights per feature category.

    Args:
        selected_features: Indices of selected (non-zero-weight) features;
            any iterable of ints (list, set, NumPy array, ...).
        coefficients: Full coefficient vector indexed by feature index; a list
            (as stored in the ``'feature_weights'`` results) or a NumPy array.
        categories: Optional mapping of category name -> list of feature
            indices. Defaults to the module-level ``feature_categories``.

    Returns:
        Dict mapping each category name to a dict with ``total_features``,
        ``selected_features`` (count), ``selection_ratio`` (0 for an empty
        category), ``avg_importance`` (mean absolute coefficient of the
        category's selected features, 0.0 if none) and ``selected_indices``.
        Indices beyond the end of ``coefficients`` are ignored.
    """
    if categories is None:
        categories = feature_categories

    # np.asarray allows fancy indexing even when coefficients is a plain list;
    # a set makes the per-index membership test O(1) instead of an array scan.
    coeffs = np.asarray(coefficients, dtype=float)
    selected = {int(i) for i in selected_features}

    category_analysis = {}

    for category_name, feature_indices in categories.items():
        selected_in_category = [
            i for i in feature_indices if i in selected and i < len(coeffs)
        ]

        if selected_in_category:
            avg_importance = float(np.mean(np.abs(coeffs[selected_in_category])))
        else:
            avg_importance = 0.0

        n_total = len(feature_indices)
        category_analysis[category_name] = {
            'total_features': n_total,
            'selected_features': len(selected_in_category),
            'selection_ratio': len(selected_in_category) / n_total if n_total > 0 else 0,
            'avg_importance': avg_importance,
            'selected_indices': selected_in_category
        }

    return category_analysis

def optimize_alpha_parameter(X, y, dataset_name):
    """Sweep the L1 strength over ``alpha_values`` and score each setting.

    For each alpha (``C = 1 / alpha``), an L1-penalized logistic regression
    (liblinear) is fit on all of ``X`` to choose the features with non-zero
    coefficients. The same configuration is then scored with cross-validation
    on just those columns.

    NOTE(review): features are selected on the full dataset before CV, so the
    CV metrics are optimistically biased (selection leakage). Kept as-is so
    results stay comparable with earlier runs.

    Args:
        X: 2-D standardized feature matrix (samples x features).
        y: Binary label vector.
        dataset_name: Dataset name; not used here, kept for API compatibility.

    Returns:
        List of per-alpha result dicts; alphas that select no feature are
        skipped. ``'feature_weights'`` is the full-length coefficient vector
        of the full-data fit.
    """
    n_features = X.shape[1]
    results = []

    for alpha in alpha_values:
        C = 1.0 / alpha
        print(f"\nTesting Alpha = {alpha} (C = {C})")

        model_params = dict(penalty='l1', C=C, solver='liblinear',
                            max_iter=1000, random_state=42)

        # Full-data fit decides which features the L1 penalty keeps.
        lasso = LogisticRegression(**model_params)
        lasso.fit(X, y)
        # Snapshot the weights of this fit before any other fitting happens.
        full_weights = lasso.coef_[0].copy()

        selected_features = np.where(full_weights != 0)[0]
        n_selected = len(selected_features)

        if n_selected == 0:
            print(f"  No features selected, skipping...")
            continue

        X_selected = X[:, selected_features]

        # Evaluate a separate, unfitted estimator so CV fold fits cannot
        # overwrite ``lasso``; otherwise 'feature_weights' would hold the last
        # fold's coefficients on the reduced matrix instead of these weights.
        metrics = evaluate_model_with_cv(LogisticRegression(**model_params), X_selected, y)

        result = {
            'alpha': float(alpha),
            'C': float(C),
            'n_selected_features': int(n_selected),
            'selected_features': selected_features.tolist(),
            'feature_weights': full_weights.tolist(),
            'auroc_mean': float(metrics['auroc_mean']),
            'auroc_std': float(metrics['auroc_std']),
            'brier_mean': float(metrics['brier_mean']),
            'brier_std': float(metrics['brier_std']),
            'ece_mean': float(metrics['ece_mean']),
            'ece_std': float(metrics['ece_std'])
        }
        results.append(result)

        print(f"  Selected features: {n_selected}/{n_features}")
        print(f"  AUROC: {metrics['auroc_mean']:.4f} ± {metrics['auroc_std']:.4f}")
        print(f"  Brier Score: {metrics['brier_mean']:.4f} ± {metrics['brier_std']:.4f}")
        print(f"  ECE: {metrics['ece_mean']:.4f} ± {metrics['ece_std']:.4f}")

    return results

def find_best_alpha(results):
    """Return the alpha result with the highest combined score.

    The score rewards AUROC and penalizes Brier score and ECE with equal
    weight: ``auroc_mean - brier_mean - ece_mean``. Ties keep the earliest
    entry. Returns ``None`` when ``results`` is empty or ``None``.
    """
    def _combined_score(entry):
        return entry['auroc_mean'] - 1.0 * entry['brier_mean'] - 1.0 * entry['ece_mean']

    return max(results, key=_combined_score) if results else None

def analyze_dataset(dataset_name):
    """Run the full alpha sweep for one dataset and package the results.

    Steps: load records, extract features/labels, standardize the features
    with ``StandardScaler``, sweep alpha via ``optimize_alpha_parameter``, pick
    the best setting with ``find_best_alpha``, then refit that setting on all
    data to record its complete coefficient vector.

    Args:
        dataset_name: Key into ``dataset_files``.

    Returns:
        Dict with dataset metadata, all per-alpha results and the best result
        (augmented with ``'complete_feature_weights'``), or ``None`` when the
        file is missing, it has no usable records, feature extraction fails,
        the labels contain a single class, or no alpha selects any feature.
    """
    print(f"\n{'='*70}")
    print(f"ANALYZING DATASET: {dataset_name}")
    print(f"{'='*70}")

    # Load data
    data = load_dataset_data(dataset_name)
    if data is None:
        return None
    if not data:
        # An existing file without usable records would otherwise only fail
        # further down with a less specific error.
        print(f"No usable records for {dataset_name}")
        return None

    X, y = extract_features_and_labels(data)
    if X is None or y is None:
        return None

    # Logistic regression and stratified CV both require two label classes;
    # stop here with a clear message instead of an sklearn exception.
    if np.unique(y).size < 2:
        print(f"Only one label class present in {dataset_name}; skipping")
        return None

    # Standardize features so the L1 penalty treats all features on one scale.
    scaler = StandardScaler()
    X_scaled = scaler.fit_transform(X)

    print(f"  Features: {X_scaled.shape[1]}")
    print(f"  Samples: {X_scaled.shape[0]}")
    print(f"  Positive samples: {np.sum(y)} ({np.mean(y):.1%})")

    # Optimize alpha parameter
    results = optimize_alpha_parameter(X_scaled, y, dataset_name)

    if not results:
        print(f"No valid results for {dataset_name}")
        return None

    best_result = find_best_alpha(results)

    print(f"  Best Alpha Configuration for {dataset_name}:")
    print(f"  Alpha: {best_result['alpha']}")
    print(f"  C: {best_result['C']}")
    print(f"  Selected features: {best_result['n_selected_features']}")
    print(f"  AUROC: {best_result['auroc_mean']:.4f} ± {best_result['auroc_std']:.4f}")
    print(f"  Brier Score: {best_result['brier_mean']:.4f} ± {best_result['brier_std']:.4f}")
    print(f"  ECE: {best_result['ece_mean']:.4f} ± {best_result['ece_std']:.4f}")

    best_alpha = best_result['alpha']
    C_best = 1.0 / best_alpha

    # Refit the best configuration on all data to obtain the full-length
    # coefficient vector (zeros for features the L1 penalty dropped).
    lasso_final = LogisticRegression(penalty='l1', C=C_best, solver='liblinear',
                                     max_iter=1000, random_state=42)
    lasso_final.fit(X_scaled, y)

    complete_weights = lasso_final.coef_[0].tolist()

    # best_result is the same dict object stored in ``results``, so this key
    # also appears in that entry of 'alpha_optimization_results'.
    best_result['complete_feature_weights'] = complete_weights

    dataset_result = {
        'dataset_name': dataset_name,
        'n_samples': int(X_scaled.shape[0]),
        'n_features': int(X_scaled.shape[1]),
        'positive_rate': float(np.mean(y)),
        'alpha_optimization_results': results,
        'best_alpha_result': best_result,
        'feature_names': feature_names
    }

    return dataset_result

def create_summary_analysis(all_results):
    """Build a cross-dataset ranking and an alpha-sensitivity table.

    Args:
        all_results: Per-dataset result dicts as produced by
            ``analyze_dataset``; ``None`` entries (failed datasets) are skipped.

    Returns:
        Dict with:
          * ``'dataset_ranking'``: one summary per dataset, sorted by
            descending combined score
            ``auroc - 0.5 * brier - 0.5 * ece`` of its best alpha;
          * ``'alpha_sensitivity_analysis'``: alpha -> mean AUROC, mean number
            of selected features, and number of datasets with a result for
            that alpha.
    """
    rule = '=' * 70
    print(f"\n{rule}")
    print("SUMMARY ANALYSIS ACROSS ALL DATASETS")
    print(rule)

    ranking = []
    for dataset_result in all_results:
        if not (dataset_result and dataset_result['best_alpha_result']):
            continue
        top = dataset_result['best_alpha_result']
        ranking.append({
            'dataset': dataset_result['dataset_name'],
            'auroc': top['auroc_mean'],
            'brier': top['brier_mean'],
            'ece': top['ece_mean'],
            'combined_score': top['auroc_mean'] - 0.5 * top['brier_mean'] - 0.5 * top['ece_mean'],
            'n_features': top['n_selected_features'],
            'alpha': top['alpha']
        })

    ranking = sorted(ranking, key=lambda entry: entry['combined_score'], reverse=True)

    print("Dataset Performance Ranking:")
    for rank, entry in enumerate(ranking, start=1):
        print(f"  {rank}. {entry['dataset']:15s}: Score={entry['combined_score']:.4f}, "
              f"AUROC={entry['auroc']:.4f}, Features={entry['n_features']}, Alpha={entry['alpha']}")

    sensitivity = {}
    for alpha in alpha_values:
        matches = [
            sweep_entry
            for dataset_result in all_results
            if dataset_result and dataset_result['alpha_optimization_results']
            for sweep_entry in dataset_result['alpha_optimization_results']
            if sweep_entry['alpha'] == alpha
        ]
        if not matches:
            continue
        sensitivity[alpha] = {
            'avg_auroc': float(np.mean([m['auroc_mean'] for m in matches])),
            'avg_features': float(np.mean([m['n_selected_features'] for m in matches])),
            'n_datasets': len(matches)
        }

    return {
        'dataset_ranking': ranking,
        'alpha_sensitivity_analysis': sensitivity
    }

def main():
    """Analyze every dataset, summarize, and write the JSON report.

    Each dataset is analyzed independently: an exception for one dataset is
    printed and recorded as ``None`` so the remaining datasets still run.

    Returns:
        The dict written to ``comprehensive_8_datasets_analysis.json``.
    """
    print("Comprehensive 8 Datasets Analysis")
    print("=" * 70)

    datasets_to_analyze = [
        'SimpleQA', 'HotpotQA', 'StrategyQA', 'MATH500',
        'MMLU-Pro', 'GPQA-diamond', 'HLE', 'GAIA',
    ]

    all_results = []
    for name in datasets_to_analyze:
        try:
            outcome = analyze_dataset(name)
        except Exception as err:
            # Top-level boundary: report and keep going with the next dataset.
            print(f"Error analyzing {name}: {err}")
            outcome = None
        all_results.append(outcome)

    summary = create_summary_analysis(all_results)
    n_successful = sum(1 for r in all_results if r is not None)

    final_results = {
        'analysis_metadata': {
            'total_datasets': len(datasets_to_analyze),
            'successful_analyses': n_successful,
            'alpha_values_tested': alpha_values,
            'feature_names': feature_names
        },
        'individual_dataset_results': all_results,
        'summary_analysis': summary
    }

    output_file = 'comprehensive_8_datasets_analysis.json'
    with open(output_file, 'w') as fh:
        json.dump(final_results, fh, indent=2)

    print("Comprehensive analysis complete!")
    print(f"Results saved to: {output_file}")
    print(f"Analyzed {n_successful} datasets successfully")

    return final_results


if __name__ == "__main__":
    main()
