#!/usr/bin/env python3
"""
Statistical Significance Tests for New Baselines (IOC, Rule-Based, MPC-XAI) vs HCA
Comparing Answer Correctness, Faithfulness, and ROUGE-L across three domains
"""

import pandas as pd
import numpy as np
from scipy import stats
from pathlib import Path
import json
import logging

# Root logging is configured at import time: INFO level, with a timestamp and
# the level name in front of every message.
_LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

# Module-level logger used by the functions defined below.
logger = logging.getLogger(__name__)


def load_baseline_results(results_dir: Path):
    """Load the per-domain evaluation CSVs of every baseline method.

    For each baseline (IOC, Rule-Based, MPC-XAI) and each domain
    (greenhouse, tep, electricity) the file
    ``<prefix>_complete_metrics_<domain>.csv`` is read from *results_dir*
    when it exists.  Every loaded frame gets two extra columns, ``domain``
    and ``method``, and the frames of one baseline are concatenated.

    Args:
        results_dir: Directory holding the CSV files.  A plain string or any
            other path-like object is accepted as well as a ``Path``.

    Returns:
        A dict mapping the method name to its concatenated DataFrame.
        Baselines for which no file was found are left out, so the dict may
        be empty.
    """
    # Coerce so that callers passing a str do not fail on the "/" operator.
    results_dir = Path(results_dir)

    # Method name -> file-name prefix; insertion order fixes the load order.
    file_prefixes = {
        'IOC': 'ioc',
        'Rule-Based': 'rule_based',
        'MPC-XAI': 'mpc_xai',
    }
    domains = ('greenhouse', 'tep', 'electricity')

    baselines = {}
    for method, prefix in file_prefixes.items():
        frames = []
        for domain in domains:
            file_path = results_dir / f'{prefix}_complete_metrics_{domain}.csv'
            if not file_path.exists():
                continue
            df = pd.read_csv(file_path)
            df['domain'] = domain
            df['method'] = method
            frames.append(df)

        if frames:
            baselines[method] = pd.concat(frames, ignore_index=True)
            logger.info(f"Loaded {method}: {len(baselines[method])} samples")

    return baselines


def load_hca_results(results_dir: Path):
    """Load the HCA evaluation results for all domains.

    Reads ``<domain>_complete_metrics.csv`` for the greenhouse, tep and
    electricity domains from *results_dir*, skipping files that do not
    exist.  Every loaded frame gets a ``domain`` column and a ``method``
    column set to ``'HCA'``.

    Args:
        results_dir: Directory holding the CSV files.  A plain string or any
            other path-like object is accepted as well as a ``Path``.

    Returns:
        The concatenated DataFrame of all files found, or ``None`` when no
        file exists.
    """
    # Coerce so that callers passing a str do not fail on the "/" operator.
    results_dir = Path(results_dir)

    hca_data = []
    for domain in ['greenhouse', 'tep', 'electricity']:
        file_path = results_dir / f'{domain}_complete_metrics.csv'
        if not file_path.exists():
            continue
        df = pd.read_csv(file_path)
        df['domain'] = domain
        df['method'] = 'HCA'
        hca_data.append(df)

    if not hca_data:
        return None

    hca_df = pd.concat(hca_data, ignore_index=True)
    logger.info(f"Loaded HCA: {len(hca_df)} samples")
    return hca_df


def compute_effect_size(group1, group2):
    """Return Cohen's d for two samples, using the pooled standard deviation.

    The difference of the two sample means is divided by the pooled sample
    standard deviation (``ddof=1``) of the groups.  The groups are treated
    as two separate samples: the pairing of observations is not used.

    Returns ``0.0`` when the pooled standard deviation is exactly zero.
    """
    size_a, size_b = len(group1), len(group2)
    mean_gap = np.mean(group1) - np.mean(group2)
    sd_a = np.std(group1, ddof=1)
    sd_b = np.std(group2, ddof=1)

    # Weight each sample variance by its degrees of freedom before pooling.
    weighted_var = (size_a - 1) * sd_a**2 + (size_b - 1) * sd_b**2
    pooled_sd = np.sqrt(weighted_var / (size_a + size_b - 2))

    if pooled_sd != 0:
        return mean_gap / pooled_sd
    return 0.0


def paired_t_test_with_bonferroni(hca_df, baseline_df, metric, alpha=0.05):
    """Run a per-domain paired t-test of HCA against one baseline.

    Within every domain present in ``hca_df['domain']`` the two frames are
    inner-joined on ``question_id`` so that each HCA score is paired with the
    baseline score of the same question.  Pairs in which either score is
    missing (NaN) are dropped before testing; otherwise the NaN would
    propagate through the t-test and the domain would silently be reported
    as not significant.

    Args:
        hca_df: DataFrame with ``domain``, ``question_id`` and *metric*
            columns holding the HCA scores.
        baseline_df: DataFrame with the same columns for the baseline.
        metric: Name of the score column to compare.
        alpha: Family-wise significance level.  It is divided by the number
            of distinct domains in *hca_df* (Bonferroni correction).

    Returns:
        A dict keyed by domain.  Each value is a dict with the keys
        ``n_samples``, ``hca_mean``, ``hca_std``, ``baseline_mean``,
        ``baseline_std``, ``t_statistic``, ``p_value``,
        ``p_value_bonferroni``, ``significant``, ``cohen_d`` and
        ``mean_difference``.  Note that ``p_value_bonferroni`` holds the
        corrected significance *threshold* (``alpha / n_domains``), not an
        adjusted p-value.  All values are built-in Python numbers/bools so
        the result can be passed to ``json.dump`` directly.  Domains with
        fewer than two usable pairs are omitted (a warning is logged).
    """
    results = {}
    domains = hca_df['domain'].unique()
    n_comparisons = len(domains)
    if n_comparisons == 0:
        # Nothing to test; returning early also avoids alpha / 0 below.
        return results
    bonferroni_alpha = alpha / n_comparisons

    hca_col = f'{metric}_hca'
    baseline_col = f'{metric}_baseline'

    for domain in domains:
        hca_domain = hca_df[hca_df['domain'] == domain].sort_values('question_id')
        baseline_domain = baseline_df[baseline_df['domain'] == domain].sort_values('question_id')

        # Align by question_id; only questions present in both frames remain.
        merged = pd.merge(
            hca_domain[['question_id', metric]],
            baseline_domain[['question_id', metric]],
            on='question_id',
            how='inner',
            suffixes=('_hca', '_baseline'),
        )

        if len(merged) == 0:
            logger.warning(f"No matching questions for {domain}")
            continue

        if merged['question_id'].duplicated().any():
            # Repeated ids on either side multiply rows in the join, which
            # breaks the one-pair-per-question assumption of the test.
            logger.warning(
                f"Duplicate question_id values for {domain} ({metric}); "
                f"some questions contribute more than one pair"
            )

        paired = merged.dropna(subset=[hca_col, baseline_col])
        n_dropped = len(merged) - len(paired)
        if n_dropped:
            logger.warning(
                f"Dropped {n_dropped} pair(s) with missing {metric} for {domain}"
            )

        if len(paired) < 2:
            # A paired t-test and a sample standard deviation both need at
            # least two observations.
            logger.warning(
                f"Fewer than 2 usable pairs for {domain} ({metric}); skipping"
            )
            continue

        hca_values = paired[hca_col].to_numpy()
        baseline_values = paired[baseline_col].to_numpy()

        # Paired t-test
        t_stat, p_value = stats.ttest_rel(hca_values, baseline_values)

        # Effect size
        cohen_d = compute_effect_size(hca_values, baseline_values)

        hca_mean = float(np.mean(hca_values))
        baseline_mean = float(np.mean(baseline_values))

        results[domain] = {
            'n_samples': len(paired),
            'hca_mean': hca_mean,
            'hca_std': float(np.std(hca_values, ddof=1)),
            'baseline_mean': baseline_mean,
            'baseline_std': float(np.std(baseline_values, ddof=1)),
            't_statistic': float(t_stat),
            'p_value': float(p_value),
            'p_value_bonferroni': bonferroni_alpha,
            # bool(...) turns numpy.bool_ into a JSON-serializable bool; a
            # NaN p-value compares False and is therefore "not significant".
            'significant': bool(p_value < bonferroni_alpha),
            'cohen_d': float(cohen_d),
            'mean_difference': hca_mean - baseline_mean,
        }

    return results


def run_all_statistical_tests(hca_df, baselines):
    """Compare HCA with every baseline on each metric and log the outcome.

    For each entry of *baselines* (method name -> DataFrame) and each of the
    metrics ``answer_correctness``, ``faithfulness`` and ``rouge_l``, the
    per-domain paired t-test is run and its statistics are written to the
    log.

    Returns:
        A nested dict ``{baseline_name: {metric: {domain: stats}}}`` holding
        the per-domain result dicts produced by the t-test function.
    """
    banner = '=' * 80
    metric_names = ('answer_correctness', 'faithfulness', 'rouge_l')
    collected = {}

    for name, frame in baselines.items():
        logger.info(f"\n{banner}")
        logger.info(f"Statistical Tests: HCA vs {name}")
        logger.info(f"{banner}")

        per_metric = {}
        for metric_name in metric_names:
            logger.info(f"\n--- Metric: {metric_name.upper()} ---")

            outcome = paired_t_test_with_bonferroni(hca_df, frame, metric_name)
            per_metric[metric_name] = outcome

            # Log one block of statistics per domain.
            for domain_name, entry in outcome.items():
                logger.info(f"\n{domain_name.upper()}:")
                logger.info(f"  HCA:       {entry['hca_mean']:.4f} ± {entry['hca_std']:.4f}")
                logger.info(f"  {name}: {entry['baseline_mean']:.4f} ± {entry['baseline_std']:.4f}")
                logger.info(f"  Difference: {entry['mean_difference']:+.4f}")
                logger.info(f"  t-statistic: {entry['t_statistic']:.4f}")
                logger.info(f"  p-value: {entry['p_value']:.6f} (Bonferroni: {entry['p_value_bonferroni']:.6f})")
                logger.info(f"  Cohen's d: {entry['cohen_d']:.4f}")
                verdict = 'YES ✓' if entry['significant'] else 'NO ✗'
                logger.info(f"  Significant: {verdict}")

        collected[name] = per_metric

    return collected


def save_results_to_csv(all_results, output_dir: Path):
    """Flatten the nested test results into a table and write it as CSV.

    *all_results* is expected in the form
    ``{baseline: {metric: {domain: stats_dict}}}``; one row is produced per
    (baseline, metric, domain) combination.  The table is written to
    ``statistical_tests_new_baselines.csv`` inside *output_dir* without the
    index column.

    Returns:
        The DataFrame that was written.
    """
    # Statistics copied from each result dict, in output column order.
    stat_columns = (
        'n_samples',
        'hca_mean',
        'hca_std',
        'baseline_mean',
        'baseline_std',
        'mean_difference',
        't_statistic',
        'p_value',
        'p_value_bonferroni',
        'significant',
        'cohen_d',
    )

    records = []
    for baseline_label, per_metric in all_results.items():
        for metric_label, per_domain in per_metric.items():
            for domain_label, entry in per_domain.items():
                record = {
                    'baseline': baseline_label,
                    'metric': metric_label,
                    'domain': domain_label,
                    **{column: entry[column] for column in stat_columns},
                }
                records.append(record)

    table = pd.DataFrame(records)
    output_file = output_dir / 'statistical_tests_new_baselines.csv'
    table.to_csv(output_file, index=False)
    logger.info(f"\nSaved results to {output_file}")

    return table


def generate_summary_report(all_results, output_dir: Path):
    """Write and print a plain-text summary of the statistical tests.

    For every baseline in *all_results* and each of the three metrics the
    report lists, per domain, the means and standard deviations, the mean
    difference, the t statistic with its degrees of freedom, the p-value,
    Cohen's d and the significance flag, followed by a summary across the
    domains that have results.

    The report is saved as ``statistical_tests_new_baselines_report.txt`` in
    *output_dir* (UTF-8 encoded, since it contains characters such as
    ``±`` and ``✓``) and also printed to standard output.

    Args:
        all_results: Nested dict ``{baseline: {metric: {domain: stats}}}``.
        output_dir: Directory in which the report file is created.
    """
    report = []
    report.append("="*80)
    report.append("STATISTICAL SIGNIFICANCE TESTS: HCA vs NEW BASELINES")
    report.append("="*80)
    report.append("")
    # NOTE: the three header lines below are fixed text; the alpha value and
    # the per-domain sample counts are not derived from all_results.
    report.append("Test: Paired t-test with Bonferroni correction (α = 0.05/3 domains = 0.0167)")
    report.append("Metrics: Answer Correctness (AC), Faithfulness (F), ROUGE-L")
    report.append("Domains: Greenhouse (67), TEP (55), Electricity (54)")
    report.append("")

    for baseline_name, baseline_results in all_results.items():
        report.append(f"\n{'='*80}")
        report.append(f"HCA vs {baseline_name}")
        report.append(f"{'='*80}")

        for metric in ['answer_correctness', 'faithfulness', 'rouge_l']:
            report.append(f"\n{metric.upper()}:")
            report.append("-" * 40)

            metric_results = baseline_results.get(metric, {})
            if not metric_results:
                # Without this guard the min/max in the summary below would
                # raise ValueError on an empty list.
                report.append("\n  No domain results available.")
                continue

            for domain in ['greenhouse', 'tep', 'electricity']:
                if domain not in metric_results:
                    continue

                stats_dict = metric_results[domain]
                report.append(f"\n{domain.capitalize()}:")
                report.append(f"  HCA:            {stats_dict['hca_mean']:.4f} ± {stats_dict['hca_std']:.4f}")
                report.append(f"  {baseline_name:15s} {stats_dict['baseline_mean']:.4f} ± {stats_dict['baseline_std']:.4f}")
                report.append(f"  Δ (HCA - {baseline_name}): {stats_dict['mean_difference']:+.4f}")
                report.append(f"  t({stats_dict['n_samples']-1}) = {stats_dict['t_statistic']:.4f}, p = {stats_dict['p_value']:.6f}")
                report.append(f"  Cohen's d = {stats_dict['cohen_d']:.4f}")
                report.append(f"  Significant: {'YES ✓' if stats_dict['significant'] else 'NO ✗'}")

            # Overall summary across the domains that produced a result
            all_diffs = [s['mean_difference'] for s in metric_results.values()]
            all_cohens = [s['cohen_d'] for s in metric_results.values()]
            sig_count = sum(1 for s in metric_results.values() if s['significant'])

            report.append(f"\n  Overall Summary:")
            report.append(f"    Mean difference: {np.mean(all_diffs):+.4f} (range: {np.min(all_diffs):+.4f} to {np.max(all_diffs):+.4f})")
            report.append(f"    Mean Cohen's d:  {np.mean(all_cohens):.4f} (range: {np.min(all_cohens):.4f} to {np.max(all_cohens):.4f})")
            # Denominator is the number of domains actually tested.
            report.append(f"    Significant domains: {sig_count}/{len(metric_results)}")

    report_text = '\n'.join(report)

    # Save report; explicit UTF-8 because the text is not pure ASCII and the
    # platform default encoding may be unable to represent it.
    report_file = output_dir / 'statistical_tests_new_baselines_report.txt'
    with open(report_file, 'w', encoding='utf-8') as f:
        f.write(report_text)

    logger.info(f"\nSaved summary report to {report_file}")

    # Print to console
    print(report_text)


def _to_json_native(value):
    """Return *value* as a built-in Python object if it is a NumPy scalar."""
    # np.generic covers every NumPy scalar type, including numpy.bool_,
    # which json.dump cannot serialize on its own.
    if isinstance(value, np.generic):
        return value.item()
    return value


def main():
    """Load all results, run the tests and write CSV, text and JSON outputs.

    Inputs are read from ``./results`` and the outputs are written to the
    same directory.  The function logs an error and returns early when the
    HCA results or all baseline results are missing.
    """
    # Paths
    results_dir = Path('./results')  # Relative path for anonymous submission
    output_dir = results_dir

    # Load data
    logger.info("Loading baseline results...")
    baselines = load_baseline_results(results_dir)

    logger.info("Loading HCA results...")
    hca_df = load_hca_results(results_dir)

    if hca_df is None or len(hca_df) == 0:
        logger.error("Could not load HCA results. Exiting.")
        return

    if not baselines:
        logger.error("Could not load any baseline results. Exiting.")
        return

    # Run statistical tests
    logger.info("\nRunning statistical tests...")
    all_results = run_all_statistical_tests(hca_df, baselines)

    # Save results
    logger.info("\nSaving results...")
    save_results_to_csv(all_results, output_dir)

    # Generate summary report
    generate_summary_report(all_results, output_dir)

    # Save JSON for programmatic access.  The conversion to native Python
    # types happens before the file is opened, so a conversion error cannot
    # leave a truncated file behind.
    serializable_results = {}
    for baseline_name, baseline_results in all_results.items():
        serializable_results[baseline_name] = {}
        for metric, metric_results in baseline_results.items():
            serializable_results[baseline_name][metric] = {}
            for domain, stats_dict in metric_results.items():
                serializable_results[baseline_name][metric][domain] = {
                    k: _to_json_native(v) for k, v in stats_dict.items()
                }

    json_file = output_dir / 'statistical_tests_new_baselines.json'
    with open(json_file, 'w', encoding='utf-8') as f:
        json.dump(serializable_results, f, indent=2)

    logger.info(f"Saved JSON results to {json_file}")
    logger.info("\n✓ Statistical tests complete!")


# Run the analysis only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == '__main__':
    main()
