#!/usr/bin/env python3
"""
Statistical analysis utilities for PINN experiments.
"""

import numpy as np
import pandas as pd
from typing import Dict, List, Tuple, Any
from scipy import stats
from scipy.stats import ttest_ind, f_oneway, mannwhitneyu, kruskal
import warnings


class StatisticalAnalyzer:
    """Statistical analysis for experimental results.

    Input layout read by :meth:`analyze_experimental_results`::

        {problem_name: {method_name: [run_dict, ...], ...}, ...}

    where every ``run_dict`` provides at least ``'final_train_loss'`` and
    ``'training_time'``. For each problem the analyzer reports descriptive
    statistics, pairwise two-sample tests, an omnibus test across all methods,
    Cohen's d effect sizes and t-based confidence intervals for the mean. It
    also pools losses per method across problems and ranks methods within each
    problem by mean final loss (lower is better).
    """

    # Per-run metrics analysed for every method, in report/dict order.
    _METRICS = ('losses', 'times')

    def __init__(self, alpha: float = 0.05):
        """
        Initialize statistical analyzer.

        Args:
            alpha: Significance level. Used both by the Shapiro-Wilk normality
                checks that choose between parametric and non-parametric tests,
                and for flagging p-values as significant.
        """
        self.alpha = alpha
        # Most recent input passed to analyze_experimental_results.
        self.results = {}

    def analyze_experimental_results(self, results: Dict[str, Any]) -> Dict[str, Any]:
        """
        Perform comprehensive statistical analysis on experimental results.

        Args:
            results: Mapping ``problem -> method -> list of run dicts``; each
                run dict must contain ``'final_train_loss'`` and
                ``'training_time'``.

        Returns:
            Mapping with one entry per problem (see ``_analyze_problem``) plus
            a ``'cross_problem'`` entry (see ``_cross_problem_analysis``).
            NOTE: a problem literally named ``'cross_problem'`` would be
            overwritten by the cross-problem summary.
        """
        self.results = results
        analysis = {
            problem: self._analyze_problem(problem, problem_results)
            for problem, problem_results in results.items()
        }

        # Cross-problem analysis
        analysis['cross_problem'] = self._cross_problem_analysis(results)

        return analysis

    def _analyze_problem(self, problem: str, problem_results: Dict[str, List]) -> Dict[str, Any]:
        """Analyze results for a specific problem.

        Methods whose run list is empty are skipped. If fewer than two methods
        remain, no comparison is possible and every section is returned empty.
        """
        analysis = {
            'descriptive_stats': {},
            'pairwise_tests': {},
            'anova_results': {},
            'effect_sizes': {},
            'confidence_intervals': {}
        }

        # Extract per-method arrays of final losses and training times.
        data = {
            method: {
                'losses': np.array([r['final_train_loss'] for r in runs]),
                'times': np.array([r['training_time'] for r in runs]),
            }
            for method, runs in problem_results.items()
            if runs
        }

        if len(data) < 2:
            return analysis

        analysis['descriptive_stats'] = self._compute_descriptive_stats(data)
        analysis['pairwise_tests'] = self._pairwise_t_tests(data)
        analysis['anova_results'] = self._anova_analysis(data)
        analysis['effect_sizes'] = self._compute_effect_sizes(data)
        analysis['confidence_intervals'] = self._compute_confidence_intervals(data)

        return analysis

    @staticmethod
    def _method_pairs(data: Dict[str, Dict]) -> List[Tuple[str, str]]:
        """Return every unordered pair of method names, in insertion order."""
        methods = list(data)
        return [(m1, m2) for i, m1 in enumerate(methods) for m2 in methods[i + 1:]]

    def _is_normal(self, sample: np.ndarray) -> bool:
        """Return False only if Shapiro-Wilk rejects normality at ``alpha``.

        Samples with fewer than 3 observations cannot be tested and are
        treated as normal. A NaN p-value counts as "not normal".
        """
        if len(sample) < 3:
            return True
        return stats.shapiro(sample)[1] > self.alpha

    def _two_sample_test(self, group1: np.ndarray, group2: np.ndarray) -> Tuple[str, float]:
        """Compare two samples: Student's t-test if both look normal, else Mann-Whitney U."""
        if self._is_normal(group1) and self._is_normal(group2):
            _, p_value = ttest_ind(group1, group2)
            return 't-test', p_value
        _, p_value = mannwhitneyu(group1, group2, alternative='two-sided')
        return 'Mann-Whitney U', p_value

    def _multi_sample_test(self, groups: List[np.ndarray]) -> Tuple[str, float]:
        """Compare k samples: one-way ANOVA if all look normal, else Kruskal-Wallis."""
        if all(self._is_normal(group) for group in groups):
            _, p_value = f_oneway(*groups)
            return 'ANOVA', p_value
        try:
            _, p_value = kruskal(*groups)
        except ValueError:
            # kruskal raises when every observation is identical; no valid
            # p-value exists, and NaN is never flagged as significant.
            p_value = float('nan')
        return 'Kruskal-Wallis', p_value

    def _compute_descriptive_stats(self, data: Dict[str, Dict]) -> Dict[str, Dict]:
        """Compute descriptive statistics of losses and times for each method.

        ``'std'`` is the population standard deviation (``ddof=0``).
        """
        return {
            method: {metric: self._summarize(values[metric]) for metric in self._METRICS}
            for method, values in data.items()
        }

    @staticmethod
    def _summarize(sample: np.ndarray) -> Dict[str, Any]:
        """Return mean/std(ddof=0)/median/quartiles/min/max/n for one sample."""
        return {
            'mean': np.mean(sample),
            'std': np.std(sample),
            'median': np.median(sample),
            'q25': np.percentile(sample, 25),
            'q75': np.percentile(sample, 75),
            'min': np.min(sample),
            'max': np.max(sample),
            'n': len(sample)
        }

    def _pairwise_t_tests(self, data: Dict[str, Dict]) -> Dict[str, Dict]:
        """Perform pairwise two-sample tests between methods.

        For every pair and every metric, the test (t-test vs. Mann-Whitney U)
        is chosen from the normality of *that metric's* samples. Timing data
        is not judged by the losses' normality. Each entry also carries the
        Cohen's d effect size.
        """
        results = {}
        for method1, method2 in self._method_pairs(data):
            results[f"{method1}_vs_{method2}"] = {
                metric: self._pairwise_metric(data[method1][metric], data[method2][metric])
                for metric in self._METRICS
            }
        return results

    def _pairwise_metric(self, group1: np.ndarray, group2: np.ndarray) -> Dict[str, Any]:
        """Run the appropriate two-sample test for one metric of one method pair."""
        test_type, p_value = self._two_sample_test(group1, group2)
        return {
            'test_type': test_type,
            'p_value': p_value,
            'significant': p_value < self.alpha,
            'effect_size': self._cohens_d(group1, group2)
        }

    def _anova_analysis(self, data: Dict[str, Dict]) -> Dict[str, Dict]:
        """Perform an omnibus test across all methods, separately for each metric.

        One-way ANOVA is used when every group passes the normality check;
        otherwise Kruskal-Wallis.
        """
        methods = list(data.keys())
        anova = {}
        for metric in self._METRICS:
            test_type, p_value = self._multi_sample_test([data[m][metric] for m in methods])
            anova[metric] = {
                'test_type': test_type,
                'p_value': p_value,
                'significant': p_value < self.alpha,
                'methods': methods
            }
        return anova

    def _compute_effect_sizes(self, data: Dict[str, Dict]) -> Dict[str, Dict]:
        """Compute Cohen's d for every method pair, for losses and times."""
        return {
            f"{method1}_vs_{method2}": {
                metric: self._cohens_d(data[method1][metric], data[method2][metric])
                for metric in self._METRICS
            }
            for method1, method2 in self._method_pairs(data)
        }

    def _compute_confidence_intervals(self, data: Dict[str, Dict], confidence: float = 0.95) -> Dict[str, Dict]:
        """Compute two-sided t-based confidence intervals for each method's means.

        Args:
            data: ``method -> {'losses': array, 'times': array}``.
            confidence: Coverage level of the interval (e.g. 0.95).

        Each metric uses its own degrees of freedom (``n - 1``). With fewer
        than two observations the interval is undefined and the margin is NaN.
        """
        return {
            method: {metric: self._mean_ci(values[metric], confidence) for metric in self._METRICS}
            for method, values in data.items()
        }

    @staticmethod
    def _mean_ci(sample: np.ndarray, confidence: float) -> Dict[str, Any]:
        """Return the mean and its t-based confidence interval for one sample."""
        n = len(sample)
        mean = np.mean(sample)
        if n < 2:
            # The sample std (ddof=1) and t quantile (df=0) are undefined.
            margin = float('nan')
        else:
            alpha = 1 - confidence
            t_critical = stats.t.ppf(1 - alpha / 2, df=n - 1)
            margin = t_critical * (np.std(sample, ddof=1) / np.sqrt(n))
        return {
            'mean': mean,
            'ci_lower': mean - margin,
            'ci_upper': mean + margin,
            'margin': margin
        }

    def _cohens_d(self, group1: np.ndarray, group2: np.ndarray) -> float:
        """Calculate Cohen's d (mean1 - mean2 over the pooled sample std).

        Returns NaN when there are too few observations to pool a variance
        (``n1 + n2 < 3``). Returns 0.0 when the pooled std is zero, where d is
        otherwise undefined.
        """
        n1, n2 = len(group1), len(group2)
        dof = n1 + n2 - 2
        if dof <= 0:
            return float('nan')

        # A single observation contributes zero squared deviation. Its ddof=1
        # std would be NaN and poison the pooled estimate, so use 0.0 directly.
        ss1 = (n1 - 1) * np.std(group1, ddof=1) ** 2 if n1 > 1 else 0.0
        ss2 = (n2 - 1) * np.std(group2, ddof=1) ** 2 if n2 > 1 else 0.0
        pooled_std = np.sqrt((ss1 + ss2) / dof)

        if pooled_std == 0:
            return 0.0

        return (np.mean(group1) - np.mean(group2)) / pooled_std

    def _cross_problem_analysis(self, results: Dict[str, Any]) -> Dict[str, Any]:
        """Analyze performance across different problems.

        ``'overall_performance'`` pools each method's final losses from every
        problem; its ``std_loss`` is the population std (``ddof=0``).
        ``'method_rankings'`` lists, per problem, the methods with results
        sorted by mean final loss (ascending). ``'consistency_analysis'`` is
        currently left empty.
        """
        cross_analysis = {
            'method_rankings': {},
            'consistency_analysis': {},
            'overall_performance': {}
        }

        # NOTE(review): losses from different problems may be on very
        # different scales, so pooled statistics can be dominated by one problem.
        pooled_losses: Dict[str, List[float]] = {}
        for problem, problem_results in results.items():
            method_means = {}
            for method, method_results in problem_results.items():
                if not method_results:
                    continue
                losses = [r['final_train_loss'] for r in method_results]
                pooled_losses.setdefault(method, []).extend(losses)
                method_means[method] = np.mean(losses)

            # Sort by performance (lower is better); ties keep insertion order.
            cross_analysis['method_rankings'][problem] = sorted(method_means, key=method_means.get)

        for method, losses in pooled_losses.items():
            cross_analysis['overall_performance'][method] = {
                'mean_loss': np.mean(losses),
                'std_loss': np.std(losses),
                'median_loss': np.median(losses),
                'n_experiments': len(losses)
            }

        return cross_analysis

    def generate_report(self, analysis: Dict[str, Any]) -> str:
        """Generate a human-readable text report.

        Args:
            analysis: The dictionary returned by ``analyze_experimental_results``.

        Returns:
            The report as a single newline-joined string.
        """
        report = []
        report.append("=" * 80)
        report.append("STATISTICAL ANALYSIS REPORT")
        report.append("=" * 80)

        # Overall performance summary
        if 'cross_problem' in analysis:
            report.append("\n📊 OVERALL PERFORMANCE SUMMARY")
            report.append("-" * 50)
            overall = analysis['cross_problem']['overall_performance']
            # Named `perf` (not `stats`) so the scipy.stats module is not shadowed.
            for method, perf in overall.items():
                report.append(f"{method.upper():>12}: Mean Loss = {perf['mean_loss']:.6f} ± {perf['std_loss']:.6f} "
                              f"(n={perf['n_experiments']})")

        # Problem-specific analysis
        for problem, problem_analysis in analysis.items():
            if problem == 'cross_problem':
                continue

            report.append(f"\n🔬 {problem.upper().replace('_', ' ')} PROBLEM ANALYSIS")
            report.append("-" * 50)

            # Descriptive statistics
            if 'descriptive_stats' in problem_analysis:
                report.append("\nDescriptive Statistics:")
                for method, desc in problem_analysis['descriptive_stats'].items():
                    report.append(f"  {method.upper():>12}: Loss = {desc['losses']['mean']:.6f} ± {desc['losses']['std']:.6f}")

            # ANOVA results (empty dict when fewer than two methods had runs)
            if 'anova_results' in problem_analysis:
                anova = problem_analysis['anova_results']
                if 'losses' in anova:
                    report.append("\nANOVA Results (Losses):")
                    report.append(f"  Test: {anova['losses']['test_type']}")
                    report.append(f"  p-value: {anova['losses']['p_value']:.6f}")
                    report.append(f"  Significant: {'Yes' if anova['losses']['significant'] else 'No'}")

            # Pairwise comparisons
            if 'pairwise_tests' in problem_analysis:
                report.append(f"\nPairwise Comparisons (α = {self.alpha}):")
                for pair, pair_results in problem_analysis['pairwise_tests'].items():
                    loss_result = pair_results['losses']
                    report.append(f"  {pair}: p = {loss_result['p_value']:.6f} "
                                  f"({'*' if loss_result['significant'] else 'ns'}) "
                                  f"d = {loss_result['effect_size']:.3f}")

        report.append("\n" + "=" * 80)
        report.append("Legend: * = significant, ns = not significant")
        # Cohen's (1988) conventional thresholds: 0.2 small, 0.5 medium, 0.8 large.
        report.append("Effect size interpretation: |d| < 0.2 (negligible), 0.2 ≤ |d| < 0.5 (small), "
                      "0.5 ≤ |d| < 0.8 (medium), |d| ≥ 0.8 (large)")
        report.append("=" * 80)

        return "\n".join(report)
