"""
Multi-seed ABBR value experiment runner.

This module runs the ABBR vs consistency comparison experiment across multiple random seeds
and aggregates the results to show the average generalization performance.
"""

import numpy as np
import pandas as pd
import random
import os
from typing import Dict, List, Tuple
from tqdm import tqdm

from abbr_experiment_config import ExperimentConfig, AVAILABLE_DATASETS
from abbr_rule_generator import (
    generate_random_rules, 
    calculate_rule_metrics, 
    evaluate_rule_on_test,
    analyze_feature_types
)

def run_single_seed_experiment(
    config: ExperimentConfig,
    seed: int,
    verbose: bool = False
) -> Dict:
    """
    Run the ABBR-vs-consistency comparison for one random seed.

    Both the ``random`` and NumPy global RNGs are seeded, the dataset is built
    with the same seed, random rules are generated on the training split, and
    the highest-ABBR rule and the highest-training-consistency rule are each
    re-evaluated on the test split.

    Args:
        config: Experiment configuration
        seed: Random seed applied to ``random``, ``np.random`` and the dataset
        verbose: Whether to print detailed progress

    Returns:
        Dictionary of per-seed results. On an early exit it holds only
        ``'seed'`` and an ``'error'`` message; on success ``'error'`` is None.
    """
    # Seed every RNG the helpers might draw from before anything else runs.
    random.seed(seed)
    np.random.seed(seed)

    data = config.dataset_class(random_seed=seed)

    # Getter order is kept as-is in case a dataset getter consumes RNG state.
    train_features = data.get_X_train()
    test_features = data.get_X_test()
    train_quantiles = data.get_y_train_quantile()
    test_quantiles = data.get_y_test_quantile()  # fetched but not used below

    threshold = config.confidence_threshold
    train_labels = data.get_y_train_preds(threshold)
    test_labels = data.get_y_test_preds(threshold)

    if verbose:
        print(f"\nSeed {seed}:")
        print(f"  Dataset: {config.dataset_class.__name__}")
        print(f"  Train size: {len(train_features)}, Test size: {len(test_features)}")
        print(f"  Features: {len(train_features.columns)}")
        print("  Generating rules...")

    rules = generate_random_rules(
        X=train_features,
        y_quantile=train_quantiles,
        num_rules=config.num_rules_to_generate,
        max_conditions=config.max_conditions_per_rule,
        min_support=config.min_rule_support,
        max_valid_rules=config.max_valid_rules,
        verbose=verbose,
    )
    if not len(rules):
        return {'seed': seed, 'error': 'No valid rules generated'}

    metrics = calculate_rule_metrics(rules, train_features, train_quantiles, train_labels)
    if not len(metrics):
        return {'seed': seed, 'error': 'No rules with valid metrics'}

    def best_by(column):
        # First row holding the maximum of ``column`` (NaNs ignored by nlargest).
        return metrics.nlargest(1, column).iloc[0]

    abbr_pick = best_by('abbr')
    cons_pick = best_by('consistency')

    # Re-score each selected rule on held-out data.
    abbr_test_cons, abbr_test_support = evaluate_rule_on_test(
        abbr_pick['rule'], test_features, test_labels
    )
    cons_test_cons, cons_test_support = evaluate_rule_on_test(
        cons_pick['rule'], test_features, test_labels
    )

    # Generalization gap = train consistency minus test consistency.
    abbr_gap = abbr_pick['consistency'] - abbr_test_cons
    cons_gap = cons_pick['consistency'] - cons_test_cons

    if verbose:
        print(f"  ABBR rule: train_cons={abbr_pick['consistency']:.3f}, test_cons={abbr_test_cons:.3f}, gap={abbr_gap:.3f}")
        print(f"  Cons rule: train_cons={cons_pick['consistency']:.3f}, test_cons={cons_test_cons:.3f}, gap={cons_gap:.3f}")

    outcome = {
        'seed': seed,
        'dataset': config.dataset_class.__name__,
        'num_rules_generated': len(rules),
        'num_rules_with_metrics': len(metrics),
    }
    # ABBR-selected rule results
    outcome['abbr_rule_train_abbr'] = abbr_pick['abbr']
    outcome['abbr_rule_train_consistency'] = abbr_pick['consistency']
    outcome['abbr_rule_train_support'] = abbr_pick['support']
    outcome['abbr_rule_test_consistency'] = abbr_test_cons
    outcome['abbr_rule_test_support'] = abbr_test_support
    outcome['abbr_rule_generalization_gap'] = abbr_gap
    # Consistency-selected rule results
    outcome['consistency_rule_train_abbr'] = cons_pick['abbr']
    outcome['consistency_rule_train_consistency'] = cons_pick['consistency']
    outcome['consistency_rule_train_support'] = cons_pick['support']
    outcome['consistency_rule_test_consistency'] = cons_test_cons
    outcome['consistency_rule_test_support'] = cons_test_support
    outcome['consistency_rule_generalization_gap'] = cons_gap
    # Comparison: a smaller gap means better generalization.
    outcome['abbr_better_generalization'] = abbr_gap < cons_gap
    outcome['generalization_gap_difference'] = cons_gap - abbr_gap  # Positive means ABBR is better
    outcome['error'] = None  # No error occurred
    return outcome

def run_multi_seed_experiment(config: ExperimentConfig) -> Tuple[pd.DataFrame, Dict]:
    """
    Run the ABBR experiment across multiple random seeds.

    Seeds ``config.base_seed`` through ``config.base_seed + config.num_seeds - 1``
    are run one after another. A seed that raises is recorded as an error row
    instead of aborting the whole sweep.

    Args:
        config: Experiment configuration

    Returns:
        Tuple of (detailed_results_df, summary_stats_dict).
        ``detailed_results_df`` has one row per seed, including failed seeds,
        whose ``'error'`` column is set. ``summary_stats_dict`` is empty when
        no seed succeeded (including when ``config.num_seeds`` is 0).
    """
    print("Running multi-seed ABBR experiment")
    print(f"Dataset: {config.dataset_class.__name__}")
    print(f"Seeds: {config.base_seed} to {config.base_seed + config.num_seeds - 1}")
    print(f"Confidence threshold: {config.confidence_threshold}")
    print(f"Min rule support: {config.min_rule_support}")
    print(f"Max conditions per rule: {config.max_conditions_per_rule}")

    # Show feature analysis for the first seed
    dataset_sample = config.dataset_class(random_seed=config.base_seed)
    analyze_feature_types(dataset_sample.get_X_train(), verbose=True)

    results = []
    seeds = range(config.base_seed, config.base_seed + config.num_seeds)

    for seed in tqdm(seeds, desc="Running experiments"):
        try:
            result = run_single_seed_experiment(config, seed, verbose=False)
        except Exception as e:  # sweep boundary: one bad seed must not abort the rest
            # tqdm.write prints above the progress bar instead of garbling it.
            tqdm.write(f"Error in seed {seed}: {e}")
            result = {'seed': seed, 'error': str(e)}
        results.append(result)

    results_df = pd.DataFrame(results)

    # With zero seeds the frame has no columns at all, so the column-based
    # filter below would raise KeyError; report it as "nothing succeeded".
    if results_df.empty:
        print("ERROR: No successful experiments!")
        return results_df, {}

    # Keep only seeds that finished without an error message.
    ok_mask = results_df['seed'].notna() & results_df['error'].isna()
    successful_results = results_df[ok_mask].copy()

    if len(successful_results) == 0:
        print("ERROR: No successful experiments!")
        return results_df, {}

    print(f"\nSuccessful experiments: {len(successful_results)} / {config.num_seeds}")

    # Failed rows leave NaN in the boolean comparison column, which makes it
    # object dtype for the whole frame; cast the (now NaN-free) subset back
    # to bool so the mean is an unambiguous fraction.
    abbr_better = successful_results['abbr_better_generalization'].astype(bool)

    summary_stats = {
        'config': config,
        'total_seeds': config.num_seeds,
        'successful_seeds': len(successful_results),
        'failed_seeds': config.num_seeds - len(successful_results),

        # ABBR rule statistics
        'abbr_rule_avg_train_consistency': successful_results['abbr_rule_train_consistency'].mean(),
        'abbr_rule_std_train_consistency': successful_results['abbr_rule_train_consistency'].std(),
        'abbr_rule_avg_test_consistency': successful_results['abbr_rule_test_consistency'].mean(),
        'abbr_rule_std_test_consistency': successful_results['abbr_rule_test_consistency'].std(),
        'abbr_rule_avg_generalization_gap': successful_results['abbr_rule_generalization_gap'].mean(),
        'abbr_rule_std_generalization_gap': successful_results['abbr_rule_generalization_gap'].std(),

        # Consistency rule statistics
        'consistency_rule_avg_train_consistency': successful_results['consistency_rule_train_consistency'].mean(),
        'consistency_rule_std_train_consistency': successful_results['consistency_rule_train_consistency'].std(),
        'consistency_rule_avg_test_consistency': successful_results['consistency_rule_test_consistency'].mean(),
        'consistency_rule_std_test_consistency': successful_results['consistency_rule_test_consistency'].std(),
        'consistency_rule_avg_generalization_gap': successful_results['consistency_rule_generalization_gap'].mean(),
        'consistency_rule_std_generalization_gap': successful_results['consistency_rule_generalization_gap'].std(),

        # Comparison statistics (std is NaN when only one seed succeeded)
        'avg_generalization_gap_difference': successful_results['generalization_gap_difference'].mean(),
        'std_generalization_gap_difference': successful_results['generalization_gap_difference'].std(),
        'fraction_abbr_better': abbr_better.mean(),

        # Rule generation statistics
        'avg_rules_generated': successful_results['num_rules_generated'].mean(),
        'avg_rules_with_metrics': successful_results['num_rules_with_metrics'].mean(),
    }

    return results_df, summary_stats

def print_summary_results(summary_stats: Dict):
    """
    Print a formatted summary of the experiment results.

    Args:
        summary_stats: Summary dictionary as produced by
            ``run_multi_seed_experiment``. That function returns an empty dict
            when no seed succeeded. Such an input is reported with a short
            message instead of raising ``KeyError``.
    """
    if not summary_stats:
        print("No summary statistics to print (no successful experiments).")
        return

    config = summary_stats['config']

    print("\n" + "=" * 80)
    print("EXPERIMENT SUMMARY")
    print("=" * 80)

    print(f"Dataset: {config.dataset_class.__name__}")
    print(f"Successful seeds: {summary_stats['successful_seeds']} / {summary_stats['total_seeds']}")
    print(f"Confidence threshold: {config.confidence_threshold}")
    print(f"Min rule support: {config.min_rule_support}")

    print("\nRULE GENERATION:")
    print(f"  Average rules generated per seed: {summary_stats['avg_rules_generated']:.1f}")
    print(f"  Average rules with valid metrics: {summary_stats['avg_rules_with_metrics']:.1f}")

    print("\nABBR-SELECTED RULES:")
    print(f"  Train consistency: {summary_stats['abbr_rule_avg_train_consistency']:.3f} ± {summary_stats['abbr_rule_std_train_consistency']:.3f}")
    print(f"  Test consistency:  {summary_stats['abbr_rule_avg_test_consistency']:.3f} ± {summary_stats['abbr_rule_std_test_consistency']:.3f}")
    print(f"  Generalization gap: {summary_stats['abbr_rule_avg_generalization_gap']:.3f} ± {summary_stats['abbr_rule_std_generalization_gap']:.3f}")

    print("\nCONSISTENCY-SELECTED RULES:")
    print(f"  Train consistency: {summary_stats['consistency_rule_avg_train_consistency']:.3f} ± {summary_stats['consistency_rule_std_train_consistency']:.3f}")
    print(f"  Test consistency:  {summary_stats['consistency_rule_avg_test_consistency']:.3f} ± {summary_stats['consistency_rule_std_test_consistency']:.3f}")
    print(f"  Generalization gap: {summary_stats['consistency_rule_avg_generalization_gap']:.3f} ± {summary_stats['consistency_rule_std_generalization_gap']:.3f}")

    print("\nCOMPARISON:")
    print(f"  Average gap difference (Consistency - ABBR): {summary_stats['avg_generalization_gap_difference']:.3f} ± {summary_stats['std_generalization_gap_difference']:.3f}")
    print(f"  Fraction of seeds where ABBR generalizes better: {summary_stats['fraction_abbr_better']:.3f}")

    # Positive difference means the ABBR rule had the smaller gap. A zero or
    # NaN difference must not be reported as a win for the consistency rule.
    gap_difference = summary_stats['avg_generalization_gap_difference']
    if pd.isna(gap_difference):
        print("  -> Average gap difference is undefined (NaN); no comparison possible")
    elif gap_difference > 0:
        print("  -> ABBR rules generalize BETTER on average (smaller gap)")
    elif gap_difference < 0:
        print("  -> Consistency rules generalize BETTER on average (smaller gap)")
    else:
        print("  -> ABBR and consistency rules generalize equally well on average")

    print("=" * 80)

def save_results(results_df: pd.DataFrame, summary_stats: Dict, config: ExperimentConfig):
    """
    Persist experiment output under ``config.output_dir``.

    The directory is created if needed. Depending on the config flags, this
    writes ``<prefix>_detailed.csv`` (the per-seed DataFrame, without its
    index) and ``<prefix>_summary.txt``. The summary file holds a short run
    header followed by one ``key: value`` line per summary entry, skipping
    ``'config'``. ``<prefix>`` comes from ``config.get_output_prefix()``.
    """
    out_dir = config.output_dir
    os.makedirs(out_dir, exist_ok=True)
    prefix = config.get_output_prefix()

    if config.save_detailed_results:
        csv_path = os.path.join(out_dir, f"{prefix}_detailed.csv")
        results_df.to_csv(csv_path, index=False)
        print(f"Detailed results saved to: {csv_path}")

    if not config.save_summary_results:
        return

    txt_path = os.path.join(out_dir, f"{prefix}_summary.txt")
    with open(txt_path, 'w') as handle:
        last_seed = config.base_seed + config.num_seeds - 1
        header = [
            "ABBR Value Experiment Summary\n",
            f"Dataset: {config.dataset_class.__name__}\n",
            f"Seeds: {config.base_seed} to {last_seed}\n",
            f"Confidence threshold: {config.confidence_threshold}\n",
            f"Min rule support: {config.min_rule_support}\n",
            f"Max conditions per rule: {config.max_conditions_per_rule}\n\n",
        ]
        # The config object itself is summarized by the header above.
        body = [f"{name}: {val}\n" for name, val in summary_stats.items() if name != 'config']
        handle.writelines(header + body)

    print(f"Summary saved to: {txt_path}")

if __name__ == "__main__":
    # Example usage
    config = ExperimentConfig(
        dataset_class=AVAILABLE_DATASETS['fico'],
        confidence_threshold=0.9,
        min_rule_support=0.1,
        num_seeds=10,  # Use smaller number for testing
    )

    results_df, summary_stats = run_multi_seed_experiment(config)
    # run_multi_seed_experiment returns an empty summary dict when every seed
    # failed. Guard the print so such a run does not crash before
    # save_results, which still writes the per-seed error rows for inspection.
    if summary_stats:
        print_summary_results(summary_stats)
    save_results(results_df, summary_stats, config)