from typing import Dict, List, Optional
import numpy as np
from scipy import stats

from rich import print


def perform_paired_ttests(
    results: Dict[str, Dict[str, np.ndarray]],
    base_method: str = "rescue",
    compare_methods: Optional[List[str]] = None,
    alpha: float = 0.05,
    bonferroni_correction: bool = False,
) -> Dict:
    """
    Perform paired t-tests between `base_method` and each method in `compare_methods`.

    This compares the final regret at the common budget (last grid point) across seeds.
    Assumption: `results[method]['per_run']` is an array shaped [num_seeds, n_grid]
    with rows ordered consistently across methods (this holds for runs produced by
    `plot_regret_vs_cost` in this script).

    Prints t-statistic, p-value, effect size, and whether the difference is significant.
    
    Args:
        results : Dict[str, Dict[str, np.ndarray]]
            Dictionary containing results for each method
        base_method : str
            The baseline method to compare against
        compare_methods : Optional[List[str]]
            List of methods to compare with base_method. If None, compares with all other methods.
        alpha : float
            Significance level (default: 0.05)
        bonferroni_correction : bool
            Whether to apply Bonferroni correction for multiple comparisons
    
    Returns:
        Dict: Dictionary containing statistical results with structure:
            {
                'base_method': str,
                'base_method_stats': {
                    'mean': float,
                    'std': float,
                    'n_samples': int
                },
                'comparisons': {
                    'method_name': {
                        'n_samples': int,
                        'method_mean': float,
                        'method_std': float,
                        'mean_difference': float,
                        't_statistic': float,
                        'p_value': float,
                        'is_significant': bool,
                        'cohens_d': float,
                        'wilcoxon_statistic': float,
                        'wilcoxon_p_value': float,
                        'ci_95_lower': float,
                        'ci_95_upper': float,
                        'corrected_alpha': float
                    },
                    ...
                }
            }
    """
    # Initialize return dictionary
    stats_results = {
        'base_method': base_method,
        'base_method_stats': {},
        'comparisons': {}
    }
    
    if compare_methods is None:
        compare_methods = [m for m in results.keys() if m != base_method]

    if base_method not in results:
        print(f"Base method '{base_method}' not found in results. Available: {list(results.keys())}")
        return stats_results

    base_arr = results[base_method]["per_run"]
    if base_arr.size == 0:
        print(f"No runs for base method '{base_method}'.")
        return stats_results

    # Apply Bonferroni correction if requested
    corrected_alpha = alpha / len(compare_methods) if bonferroni_correction and len(compare_methods) > 0 else alpha
    if bonferroni_correction and len(compare_methods) > 1:
        print(f"Applying Bonferroni correction: alpha = {alpha} / {len(compare_methods)} = {corrected_alpha:.4f}\n")

    print(f"{'='*80}")
    print(f"Statistical Comparison: Base Method = '{base_method}'")
    print(f"{'='*80}\n")
    
    # Store base method stats (will be computed in first valid comparison)
    base_stats_computed = False

    for cmp in compare_methods:
        if cmp not in results:
            print(f"Comparison method '{cmp}' not found in results; skipping.\n")
            continue

        cmp_arr = results[cmp]["per_run"]

        if base_arr.shape[0] != cmp_arr.shape[0]:
            print(f"Warning: different number of runs for '{base_method}' ({base_arr.shape[0]}) and '{cmp}' ({cmp_arr.shape[0]}).")

        # Use min number of paired samples
        n_pairs = min(base_arr.shape[0], cmp_arr.shape[0])
        if n_pairs < 2:
            print(f"Not enough paired samples to run t-test for '{base_method}' vs '{cmp}' (n={n_pairs}). Need at least 2.\n")
            continue

        # Compare final regret at the last grid point (final budget)
        base_final = base_arr[:n_pairs, -1]
        cmp_final = cmp_arr[:n_pairs, -1]

        # Drop any NaNs pairwise
        mask = ~(np.isnan(base_final) | np.isnan(cmp_final))
        base_final = base_final[mask]
        cmp_final = cmp_final[mask]

        if base_final.size < 2:
            print(f"After removing NaNs, not enough samples to run t-test for '{base_method}' vs '{cmp}'.\n")
            continue

        print(f"Comparing: {base_method} vs {cmp}")
        print(f"{'-'*80}")
        print(f"Number of paired samples: {base_final.size}")
        
        # Descriptive statistics
        base_mean = np.mean(base_final)
        cmp_mean = np.mean(cmp_final)
        base_std = np.std(base_final, ddof=1)
        cmp_std = np.std(cmp_final, ddof=1)
        
        # Store base method stats (only once)
        if not base_stats_computed:
            stats_results['base_method_stats'] = {
                'mean': float(base_mean),
                'std': float(base_std),
                'n_samples': int(base_final.size)
            }
            base_stats_computed = True
        
        print(f"\nDescriptive Statistics:")
        print(f"  {base_method}: mean = {base_mean:.6f}, std = {base_std:.6f}")
        print(f"  {cmp}: mean = {cmp_mean:.6f}, std = {cmp_std:.6f}")
        
        # Mean difference and direction
        mean_diff = base_mean - cmp_mean
        if abs(mean_diff) < 1e-10:
            direction_str = "essentially equal to"
        elif mean_diff < 0:
            direction_str = "LOWER (better) than"
        else:
            direction_str = "HIGHER (worse) than"
        
        print(f"  Mean difference: {base_method} is {abs(mean_diff):.6f} {direction_str} {cmp}")
        
        # Paired t-test
        t_stat, p_val = stats.ttest_rel(base_final, cmp_final)
        significance = "SIGNIFICANT" if (p_val < corrected_alpha) else "NOT significant"
        
        print(f"\nPaired t-test:")
        print(f"  t-statistic = {t_stat:.4f}")
        print(f"  p-value = {p_val:.4e}")
        print(f"  Result: {significance} at alpha = {corrected_alpha:.4f}")
        
        # Effect size (Cohen's d for paired samples)
        diff = base_final - cmp_final
        cohens_d = np.mean(diff) / np.std(diff, ddof=1)
        
        # Interpret effect size
        if abs(cohens_d) < 0.2:
            effect_interpretation = "negligible"
        elif abs(cohens_d) < 0.5:
            effect_interpretation = "small"
        elif abs(cohens_d) < 0.8:
            effect_interpretation = "medium"
        else:
            effect_interpretation = "large"
        
        print(f"\nEffect Size:")
        print(f"  Cohen's d = {cohens_d:.4f} ({effect_interpretation})")
        
        # Wilcoxon signed-rank test (non-parametric alternative)
        try:
            w_stat, w_pval = stats.wilcoxon(base_final, cmp_final)
            w_significance = "SIGNIFICANT" if (w_pval < corrected_alpha) else "NOT significant"
            print(f"\nWilcoxon signed-rank test (non-parametric):")
            print(f"  W-statistic = {w_stat:.4f}")
            print(f"  p-value = {w_pval:.4e}")
            print(f"  Result: {w_significance} at alpha = {corrected_alpha:.4f}")
        except Exception as e:
            print(f"\nWilcoxon test failed: {e}")
            w_stat, w_pval = np.nan, np.nan
        
        # Confidence interval for mean difference
        se_diff = np.std(diff, ddof=1) / np.sqrt(len(diff))
        ci_95 = stats.t.interval(0.95, len(diff)-1, loc=np.mean(diff), scale=se_diff)
        print(f"\n95% Confidence Interval for mean difference:")
        print(f"  [{ci_95[0]:.6f}, {ci_95[1]:.6f}]")
        
        # Store comparison results
        stats_results['comparisons'][cmp] = {
            'n_samples': int(base_final.size),
            'method_mean': float(cmp_mean),
            'method_std': float(cmp_std),
            'mean_difference': float(mean_diff),
            't_statistic': float(t_stat),
            'p_value': float(p_val),
            'is_significant': bool(p_val < corrected_alpha),
            'cohens_d': float(cohens_d),
            'wilcoxon_statistic': float(w_stat) if not np.isnan(w_stat) else None,
            'wilcoxon_p_value': float(w_pval) if not np.isnan(w_pval) else None,
            'ci_95_lower': float(ci_95[0]),
            'ci_95_upper': float(ci_95[1]),
            'corrected_alpha': float(corrected_alpha)
        }
        
        # Overall conclusion
        print(f"\n{'*'*80}")
        if p_val < corrected_alpha:
            winner = base_method if mean_diff < 0 else cmp
            print(f"CONCLUSION: {winner} performs significantly better (lower regret)")
        else:
            print(f"CONCLUSION: No significant difference between {base_method} and {cmp}")
        print(f"{'*'*80}\n\n")
    
    return stats_results


def _cost_grid_for(method_data: Dict[str, np.ndarray], n_points: int) -> np.ndarray:
    """Return the x-axis for integration: 'grid', else 'cost_grid', else 0..n_points-1."""
    if 'grid' in method_data:
        return method_data['grid']
    if 'cost_grid' in method_data:
        return method_data['cost_grid']
    return np.arange(n_points)


def _per_seed_auc(regrets: np.ndarray, costs: np.ndarray) -> np.ndarray:
    """Integrate every row of `regrets` over `costs` with the trapezoidal rule."""
    # `np.trapz` is deprecated since NumPy 2.0 in favour of `np.trapezoid`;
    # prefer the new name and only fall back on older NumPy versions.
    integrate = getattr(np, "trapezoid", None)
    if integrate is None:
        integrate = np.trapz
    return np.asarray(integrate(regrets, costs, axis=1))


def analyze_area_under_curve(
    results: Dict[str, Dict[str, np.ndarray]],
    base_method: str = "rescue",
    compare_methods: Optional[List[str]] = None,
    alpha: float = 0.05,
) -> Dict:
    """
    Compare methods by computing the area under the regret curve (AUC).
    Lower AUC = better overall performance across the entire budget.

    This gives a single metric for overall convergence quality. For every
    method, `results[method]['per_run']` is integrated row by row (one row per
    seed) with the trapezoidal rule. The x-axis is `results[method]['grid']`
    if present, otherwise `results[method]['cost_grid']`, otherwise the column
    indices. Base and comparison AUCs are paired by row index over the first
    `min(n_base, n_cmp)` seeds and compared with a paired t-test.

    Args:
        results : Dict[str, Dict[str, np.ndarray]]
            Dictionary containing results for each method
        base_method : str
            The baseline method to compare against
        compare_methods : Optional[List[str]]
            Methods to compare with base_method. If None, all other methods.
            Methods missing from `results` are silently skipped.
        alpha : float
            Significance level for the paired t-test (default: 0.05)

    Returns:
        Dict: Dictionary containing AUC analysis results with structure:
            {
                'base_method': str,
                'base_method_auc': {        # over ALL base seeds
                    'mean': float,
                    'std': float,
                    'per_seed': list of floats,
                    'n_seeds': int
                },
                'comparisons': {
                    'method_name': {
                        'mean_auc': float,
                        'std_auc': float,
                        'per_seed_auc': list of floats,
                        'n_seeds': int,             # number of paired seeds
                        'auc_difference': float,    # base mean - method mean
                        'percent_difference': float,  # relative to |method mean|
                        't_statistic': float or None,
                        'p_value': float or None,
                        'is_significant': bool,
                        'cohens_d': float or None
                    },
                    ...
                }
            }

        The t-test fields are None when they are undefined (fewer than two
        paired seeds, or NaN results).
    """

    # Initialize return dictionary
    auc_results = {
        'base_method': base_method,
        'base_method_auc': {},
        'comparisons': {}
    }

    if compare_methods is None:
        compare_methods = [m for m in results if m != base_method]

    if base_method not in results:
        print(f"Base method '{base_method}' not found.")
        return auc_results

    print(f"{'='*80}")
    print(f"Area Under Curve (AUC) Analysis: Base Method = '{base_method}'")
    print(f"{'='*80}\n")

    base_data = results[base_method]
    base_regrets = base_data['per_run']
    base_costs = _cost_grid_for(base_data, base_regrets.shape[1])

    # One AUC per seed for the base method
    base_aucs = _per_seed_auc(base_regrets, base_costs)

    # Store base method AUC stats
    auc_results['base_method_auc'] = {
        'mean': float(np.mean(base_aucs)),
        'std': float(np.std(base_aucs, ddof=1)) if len(base_aucs) > 1 else 0.0,
        'per_seed': base_aucs.tolist(),
        'n_seeds': int(len(base_aucs))
    }

    for cmp in compare_methods:
        if cmp not in results:
            continue

        cmp_data = results[cmp]
        cmp_regrets = cmp_data['per_run']
        cmp_costs = _cost_grid_for(cmp_data, cmp_regrets.shape[1])

        # Only seeds available for both methods can be paired
        n_seeds = min(base_regrets.shape[0], cmp_regrets.shape[0])
        cmp_aucs = _per_seed_auc(cmp_regrets[:n_seeds], cmp_costs)
        base_aucs_matched = base_aucs[:n_seeds]

        base_mean_auc = np.mean(base_aucs_matched)
        cmp_mean_auc = np.mean(cmp_aucs)

        # Negative difference => base method has the lower (better) AUC
        auc_diff = base_mean_auc - cmp_mean_auc
        percent_better = (auc_diff / abs(cmp_mean_auc)) * 100

        # Statistical test
        if n_seeds >= 2:
            t_stat, p_val = stats.ttest_rel(base_aucs_matched, cmp_aucs)
            significance = "SIGNIFICANT" if p_val < alpha else "NOT significant"

            diff = base_aucs_matched - cmp_aucs
            diff_std = np.std(diff, ddof=1)
            cohens_d = np.mean(diff) / diff_std if diff_std > 0 else np.nan
        else:
            t_stat, p_val, cohens_d = np.nan, np.nan, np.nan
            significance = "N/A"

        # The direction is carried by the word, so the percentage is shown as
        # a magnitude (previously both branches printed a negative number).
        relation = "better" if auc_diff < 0 else "worse"

        print(f"Comparing: {base_method} vs {cmp}")
        print(f"{'-'*80}")
        print(f"  {base_method} mean AUC: {base_mean_auc:.2f}")
        print(f"  {cmp} mean AUC: {cmp_mean_auc:.2f}")
        print(f"  AUC difference: {auc_diff:.2f} ({abs(percent_better):.1f}% {relation})")
        print(f"  Statistical test: t={t_stat:.4f}, p={p_val:.4e} -> {significance}")
        print(f"  Effect size (Cohen's d): {cohens_d:.4f}")

        if auc_diff < 0:
            print(f"\n  => {base_method} has {abs(percent_better):.1f}% BETTER overall performance (lower AUC)")
        else:
            print(f"\n  => {base_method} has {abs(percent_better):.1f}% WORSE overall performance (higher AUC)")
        print()

        # Store comparison results
        auc_results['comparisons'][cmp] = {
            'mean_auc': float(cmp_mean_auc),
            'std_auc': float(np.std(cmp_aucs, ddof=1)) if len(cmp_aucs) > 1 else 0.0,
            'per_seed_auc': cmp_aucs.tolist(),
            'n_seeds': int(n_seeds),
            'auc_difference': float(auc_diff),
            'percent_difference': float(percent_better),
            't_statistic': float(t_stat) if not np.isnan(t_stat) else None,
            'p_value': float(p_val) if not np.isnan(p_val) else None,
            'is_significant': bool(significance == "SIGNIFICANT"),
            'cohens_d': float(cohens_d) if not np.isnan(cohens_d) else None
        }

    return auc_results