import os
import json
import argparse
from pathlib import Path
import numpy as np
import re
from typing import Dict, List, Tuple
from scipy.stats import wilcoxon

def load_metrics(case_dir: Path) -> Dict:
    """Load ``metrics.json`` from a case directory.

    Args:
        case_dir: Directory expected to contain a ``metrics.json`` file.

    Returns:
        The parsed metrics when the file exists, can be read and decoded, and
        holds a JSON object; otherwise an empty dict. A warning is printed for
        every failure except a missing file.
    """
    metrics_path = case_dir / "metrics.json"
    if not metrics_path.exists():
        return {}

    try:
        # JSON is UTF-8 by specification; do not depend on the locale encoding.
        with open(metrics_path, 'r', encoding='utf-8') as f:
            metrics = json.load(f)
    except (OSError, ValueError) as e:
        # ValueError covers both json.JSONDecodeError and UnicodeDecodeError.
        print(f"Warning: Could not load {metrics_path}: {e}")
        return {}

    if not isinstance(metrics, dict):
        # Callers use ``.get`` on the result, so a top-level list/number/string
        # must not leak out of this function.
        print(f"Warning: Could not load {metrics_path}: "
              f"expected a JSON object, got {type(metrics).__name__}")
        return {}

    return metrics

def analyze_outliers(accuracies: List[float], threshold: float) -> Tuple[int, float]:
    """Count the values strictly above ``threshold``.

    Args:
        accuracies: Accuracy values to inspect.
        threshold: Values greater than this are counted as outliers.

    Returns:
        A ``(count, percentage)`` tuple; ``(0, 0.0)`` for empty input.
    """
    if not accuracies:
        return 0, 0.0

    n_above = sum(1 for value in accuracies if value > threshold)
    share = (n_above / len(accuracies)) * 100
    return n_above, share

def calculate_statistics(accuracies: List[float], threshold: float = 5.0) -> Dict:
    """Summarise a list of accuracy values.

    Args:
        accuracies: Accuracy values to summarise.
        threshold: Outlier cut-off forwarded to ``analyze_outliers``.

    Returns:
        An empty dict for empty input. Otherwise a dict with ``mean``,
        ``median``, ``std``, ``min``, ``max``, ``outlier_count``,
        ``outlier_percentage``, ``total_points``, ``threshold_stats`` (count
        and percentage of values ``<=`` 1, 2 and 3) and ``above_3mm`` /
        ``above_3mm_percentage`` (values greater than the last cut-off).
    """
    if not accuracies:
        return {}

    n_outliers, pct_outliers = analyze_outliers(accuracies, threshold)

    n_points = len(accuracies)
    cutoffs = [1, 2, 3]

    def as_percentage(count):
        return (count / n_points) * 100

    # Share of values at or below each fixed cut-off.
    per_cutoff = {}
    for cutoff in cutoffs:
        n_within = len([value for value in accuracies if value <= cutoff])
        per_cutoff[cutoff] = {'count': n_within, 'percentage': as_percentage(n_within)}

    # Everything beyond the largest cut-off.
    n_beyond = len([value for value in accuracies if value > cutoffs[-1]])

    summary = {
        'mean': np.mean(accuracies),
        'median': np.median(accuracies),
        'std': np.std(accuracies),
        'min': np.min(accuracies),
        'max': np.max(accuracies),
        'outlier_count': n_outliers,
        'outlier_percentage': pct_outliers,
        'total_points': n_points,
        'threshold_stats': per_cutoff,
        'above_3mm': n_beyond,
        'above_3mm_percentage': as_percentage(n_beyond),
    }
    return summary

def parse_case_numbers(case_string):
    """Parse a case selection string into case-number strings.

    The string is a comma-separated list whose items are either a single
    number (``"3"``, ``"03"``) or an inclusive range (``"1-5"``). Leading
    zeros are dropped, duplicates are merged, empty items produced by stray
    commas are ignored, and a reversed range (``"5-1"``) is treated like
    ``"1-5"``.

    Args:
        case_string: The raw selection string, or a falsy value.

    Returns:
        ``None`` when ``case_string`` is falsy; otherwise the selected case
        numbers as strings, ordered numerically (so ``"2"`` precedes ``"10"``).

    Raises:
        ValueError: If an item is not an integer or an ``a-b`` integer range.
    """
    if not case_string:
        return None

    numbers = set()
    for token in case_string.split(','):
        token = token.strip()
        if not token:
            # Tolerate stray commas such as "1,2,".
            continue
        if '-' in token:
            first, last = (int(bound.strip()) for bound in token.split('-', 1))
            low, high = sorted((first, last))
            numbers.update(range(low, high + 1))
        else:
            numbers.add(int(token))

    # Sort as integers; sorting the strings would put "10" before "2".
    return [str(number) for number in sorted(numbers)]

def find_method_cases(base_path: Path, case_number: str = None) -> Dict[str, List[Path]]:
    """Find all method directories and their case directories.

    Every sub-directory of ``base_path`` is treated as a method. If a method
    contains ``seed_*`` sub-directories, cases are looked up inside each of
    them; otherwise cases are looked up directly in the method directory.

    Args:
        base_path: Directory containing one sub-directory per method.
        case_number: Case to select, or ``None`` to select every ``Case*``
            directory. A numeric value also matches zero-padded directory
            names, so ``"1"`` selects both ``Case1`` and ``Case01``.

    Returns:
        Mapping from method directory name to its sorted matching case
        directories. Methods without any matching case are omitted.
    """
    methods = {}

    for method_dir in base_path.iterdir():
        if not method_dir.is_dir():
            continue

        seed_dirs = [d for d in method_dir.iterdir() if d.is_dir() and d.name.startswith('seed_')]
        # Seeded layout: method/seed_*/CaseN. Flat (old) layout: method/CaseN.
        search_roots = seed_dirs if seed_dirs else [method_dir]

        cases = [
            case_dir
            for root in search_roots
            for case_dir in root.iterdir()
            if case_dir.is_dir() and _is_requested_case(case_dir.name, case_number)
        ]

        if cases:
            methods[method_dir.name] = sorted(cases)

    return methods


def _is_requested_case(dir_name, case_number):
    """Return True if ``dir_name`` is a ``Case*`` directory matching ``case_number``."""
    if not dir_name.startswith('Case'):
        return False
    if case_number is None:
        return True
    if dir_name == f'Case{case_number}':
        return True

    # Fall back to a numeric comparison so that zero padding does not matter.
    suffix = dir_name[len('Case'):]
    wanted = str(case_number)
    return suffix.isdecimal() and wanted.isdecimal() and int(suffix) == int(wanted)

def calculate_overall_statistics(base_path: Path, case_numbers: List[str]) -> Dict:
    """Calculate statistics per method, pooled over all given cases.

    For every case number the method directories are discovered with
    ``find_method_cases`` and the accuracy lists found in each case's metrics
    (key ``all_accuracies_mm``, falling back to ``all_accuracies``) are
    concatenated per method, across cases and seeds.

    Args:
        base_path: Directory containing one sub-directory per method.
        case_numbers: Case numbers to include; ``None`` entries are skipped.

    Returns:
        Mapping from method name to the dict produced by
        ``calculate_statistics`` (default outlier threshold), ordered by the
        numeric prefix of the method name and then by name. Methods without
        accuracy data are omitted. Empty if no method was found at all.
    """
    # Scan the directory tree once per case and pool the values as we go,
    # instead of rescanning it for every (method, case) combination.
    pooled = {}
    for case_num in case_numbers:
        if case_num is None:
            continue
        for method_name, case_dirs in find_method_cases(base_path, case_num).items():
            method_pool = pooled.setdefault(method_name, [])
            for case_dir in case_dirs:
                metrics = load_metrics(case_dir)
                accuracies = metrics.get('all_accuracies_mm') or metrics.get('all_accuracies')
                if accuracies and isinstance(accuracies, list):
                    method_pool.extend(accuracies)

    if not pooled:
        print("No methods found for overall statistics")
        return {}

    def sort_key(method_name):
        # Numeric prefix first; the name itself breaks ties so the order does
        # not depend on dict/set iteration order.
        match = re.match(r'^(\d+)', method_name)
        prefix = int(match.group(1)) if match else float('inf')
        return prefix, method_name

    overall_stats = {}
    for method_name in sorted(pooled, key=sort_key):
        all_accuracies = pooled[method_name]
        if not all_accuracies:
            continue

        stats = calculate_statistics(all_accuracies)
        if stats:
            overall_stats[method_name] = stats

    return overall_stats

def perform_wilcoxon_tests(base_path: Path, case_numbers: List[str]) -> Dict:
    """Perform one-sided Wilcoxon signed-rank tests comparing methods.

    Accuracy values are pooled per method over the given cases. Methods whose
    name (numeric prefix removed, lower-cased) contains ``kan`` form the KAN
    group; methods containing ``idir``, ``sinr`` or ``nodeo`` form the
    baseline group. Every KAN method is tested against every baseline and
    against every later KAN method, with the alternative hypothesis that the
    first method has lower values. A report is printed to stdout.

    Args:
        base_path: Directory containing one sub-directory per method.
        case_numbers: Case numbers to include; ``None`` entries are skipped.

    Returns:
        Nested dict ``{kan_name: {baseline_name: result}}`` holding only the
        KAN-vs-baseline comparisons. Each result has the keys ``statistic``,
        ``p_value``, ``mean_diff``, ``sample_size``, ``kan_mean`` and
        ``baseline_mean``. KAN-vs-KAN results are printed but not returned.
    """
    # Collect all accuracy data by method
    method_accuracies = {}
    for case_num in case_numbers:
        if case_num is None:
            continue
        for method_name, cases in find_method_cases(base_path, case_num).items():
            pool = method_accuracies.setdefault(method_name, [])
            for case_dir in cases:
                metrics = load_metrics(case_dir)
                accuracies = metrics.get('all_accuracies_mm') or metrics.get('all_accuracies')
                if accuracies and isinstance(accuracies, list):
                    pool.extend(accuracies)

    # Identify method categories based on name patterns
    kan_methods = []
    baseline_methods = []
    baseline_tags = ('idir', 'sinr', 'nodeo')  # 'idir' also covers ccidir / idir_jac
    for method_name in method_accuracies:
        clean_name = _strip_numeric_prefix(method_name).lower().strip()
        print(f"Processing method: {clean_name} ({method_name})")
        if 'kan' in clean_name:
            kan_methods.append(method_name)
        elif any(tag in clean_name for tag in baseline_tags):
            baseline_methods.append(method_name)

    print("\nWILCOXON SIGNED-RANK TEST RESULTS")
    print("=" * 60)
    kan_method_names = [_strip_numeric_prefix(m) for m in kan_methods]
    baseline_method_names = [_strip_numeric_prefix(m) for m in baseline_methods]
    print(f"KAN-based methods: {kan_method_names}")
    print(f"Baseline methods: {baseline_method_names}")
    print()

    # Compare each KAN method against each baseline method
    test_results = {}
    for kan_method in kan_methods:
        kan_clean = _strip_numeric_prefix(kan_method)
        test_results[kan_clean] = {}

        for baseline_method in baseline_methods:
            baseline_clean = _strip_numeric_prefix(baseline_method)
            result = _compare_methods(
                kan_clean, method_accuracies[kan_method],
                baseline_clean, method_accuracies[baseline_method],
                mean_keys=('kan_mean', 'baseline_mean'),
            )
            if result is not None:
                test_results[kan_clean][baseline_clean] = result

    # Compare KAN methods against each other
    print("KAN vs KAN COMPARISONS:")
    print("-" * 40)

    kan_vs_kan_results = {}
    for i, kan_method1 in enumerate(kan_methods):
        kan_clean1 = _strip_numeric_prefix(kan_method1)
        kan_vs_kan_results[kan_clean1] = {}

        # Only later methods: skips self-comparison and duplicate pairs.
        for kan_method2 in kan_methods[i + 1:]:
            kan_clean2 = _strip_numeric_prefix(kan_method2)
            result = _compare_methods(
                kan_clean1, method_accuracies[kan_method1],
                kan_clean2, method_accuracies[kan_method2],
                mean_keys=('kan1_mean', 'kan2_mean'),
            )
            if result is not None:
                kan_vs_kan_results[kan_clean1][kan_clean2] = result

    # Summary of significant results
    print("SUMMARY OF STATISTICAL SIGNIFICANCE:")
    print("-" * 40)
    significant_improvements = 0
    total_comparisons = 0

    for kan_method, comparisons in test_results.items():
        for baseline_method, result in comparisons.items():
            total_comparisons += 1
            if result['p_value'] < 0.05 and result['mean_diff'] > 0:
                significant_improvements += 1
                significance_level = _significance_level(result['p_value'])
                print(f"✓ {kan_method} significantly better than {baseline_method} ({significance_level})")

    # Add KAN vs KAN significant results to the same summary
    for kan_method1, comparisons in kan_vs_kan_results.items():
        for kan_method2, result in comparisons.items():
            total_comparisons += 1
            if result['p_value'] < 0.05 and result['mean_diff'] > 0:
                significant_improvements += 1
                significance_level = _significance_level(result['p_value'])
                print(f"✓ {kan_method1} significantly better than {kan_method2} ({significance_level})")
            else:
                print(f"✗ {kan_method1} not significantly better than {kan_method2} (p={result['p_value']:.6f})")

    if significant_improvements == 0:
        print("No statistically significant improvements found.")
    else:
        print(f"\nTotal: {significant_improvements}/{total_comparisons} comparisons show significant improvement")

    print()
    return test_results


def _strip_numeric_prefix(method_name):
    """Remove a leading ordering prefix such as ``12_`` or ``3-`` from a name."""
    return re.sub(r'^\d+[-_]?', '', method_name)


def _significance_stars(p_value):
    """Map a p-value to the marker used in the per-comparison report."""
    if p_value < 0.001:
        return "***"
    if p_value < 0.01:
        return "**"
    if p_value < 0.05:
        return "*"
    return "ns"


def _significance_level(p_value):
    """Map a significant p-value to the label used in the summary."""
    if p_value < 0.001:
        return "p<0.001"
    if p_value < 0.01:
        return "p<0.01"
    return "p<0.05"


def _compare_methods(name_a, values_a, name_b, values_b, mean_keys):
    """Test whether method A has lower values than method B and print a report.

    Returns the result dict, or ``None`` if the comparison was skipped (fewer
    than 10 samples) or the test could not be computed.
    """
    min_size = min(len(values_a), len(values_b))
    if min_size < 10:
        print(f"Skipping {name_a} vs {name_b}: insufficient data ({min_size} samples)")
        return None

    if len(values_a) == len(values_b):
        # The signed-rank test works on per-index differences. With equal
        # lengths, keep the collection order so sample i of A is compared
        # with sample i of B instead of shuffling each side independently.
        # NOTE(review): assumes both methods list their measurements in the
        # same case/landmark order - confirm against the metrics writer.
        sample_a = np.asarray(values_a)
        sample_b = np.asarray(values_b)
    else:
        # Lengths differ, so no natural pairing exists; fall back to equally
        # sized random subsamples. A local, fixed-seed generator keeps this
        # reproducible without touching NumPy's global random state.
        rng = np.random.RandomState(42)
        sample_a = rng.choice(values_a, min_size, replace=False)
        sample_b = rng.choice(values_b, min_size, replace=False)

    # Alternative hypothesis: method A has lower errors (better performance)
    try:
        statistic, p_value = wilcoxon(sample_a, sample_b, alternative='less')
    except ValueError as e:
        print(f"Could not perform test for {name_a} vs {name_b}: {e}")
        return None

    mean_a = np.mean(sample_a)
    mean_b = np.mean(sample_b)
    # Effect size as a plain mean difference; positive means A is lower.
    mean_diff = mean_b - mean_a

    significance = _significance_stars(p_value)
    improvement = "BETTER" if mean_diff > 0 else "WORSE"

    print(f"{name_a} vs {name_b}:")
    print(f"  Mean accuracy: {mean_a:.3f} vs {mean_b:.3f} mm")
    print(f"  Mean difference: {mean_diff:.3f} mm ({improvement})")
    print(f"  Wilcoxon p-value: {p_value:.6f} {significance}")
    print(f"  Sample size: {min_size}")
    print()

    return {
        'statistic': statistic,
        'p_value': p_value,
        'mean_diff': mean_diff,
        'sample_size': min_size,
        mean_keys[0]: mean_a,
        mean_keys[1]: mean_b,
    }

def main():
    """Command-line entry point: print per-case and overall accuracy reports.

    Parses the command line, then for each requested case (or for all cases
    in one pass when ``--case`` is omitted) prints statistics for every
    method/seed/case run and a per-method summary. If at least one pass
    produced results, the overall comparison table, the best method per
    metric and the Wilcoxon test report are printed for the processed cases.
    """
    parser = argparse.ArgumentParser(description='Analyze accuracy metrics and perform statistical tests')
    parser.add_argument('base_path', type=str, help='Base path containing method directories')
    parser.add_argument('--case', '-k', type=str, default=None,
                       help='Case numbers: single (01), range (1-5), or comma-separated (01,03,05)')
    parser.add_argument('--threshold', '-t', type=float, default=5.0,
                       help='Outlier threshold in mm (default: 5.0)')

    args = parser.parse_args()

    # Validate base path
    base_path = Path(args.base_path)
    if not base_path.exists():
        print(f"Error: Base path '{base_path}' does not exist!")
        return
    if not base_path.is_dir():
        # Directory scanning below would otherwise fail with a traceback.
        print(f"Error: Base path '{base_path}' is not a directory!")
        return

    # Parse case numbers
    case_numbers = parse_case_numbers(args.case)
    if case_numbers is None:
        case_numbers = [None]  # Process all cases if no specific cases given

    print(f"Outlier threshold: {args.threshold}mm")
    print("-" * 50)

    processed_cases = []

    # Process each case
    for case_num in case_numbers:
        # Find methods and cases
        methods = find_method_cases(base_path, case_num)
        if not methods:
            case_info = f" for case {case_num}" if case_num else ""
            print(f"No method directories with cases found{case_info}!")
            continue

        case_info = f" (Case {case_num})" if case_num else ""
        print(f"Found methods{case_info}: {list(methods.keys())}")

        # Count seeds per method for informative output
        for method_name, cases in methods.items():
            method_dir = base_path / method_name
            seed_dirs = [d for d in method_dir.iterdir() if d.is_dir() and d.name.startswith('seed_')]
            if seed_dirs:
                print(f"  {method_name}: {len(seed_dirs)} seeds, {len(cases)} total cases")
            else:
                print(f"  {method_name}: {len(cases)} cases (no seeds)")

        # Collect accuracy data and calculate statistics
        all_results = {}

        for method_name, cases in methods.items():
            for case_dir in cases:
                case_id = case_dir.name.replace('Case', '')

                # The same case exists once per seed directory. Key and label
                # each run by seed as well, otherwise later seeds overwrite
                # earlier ones and the summary only reflects the last seed.
                seed_name = case_dir.parent.name
                if seed_name.startswith('seed_'):
                    result_key = f"{seed_name}/{case_id}"
                    run_label = f"{method_name} {seed_name} Case{case_id}:"
                else:
                    result_key = case_id
                    run_label = f"{method_name} Case{case_id}:"

                # Load metrics
                metrics = load_metrics(case_dir)
                accuracies = metrics.get('all_accuracies_mm') or metrics.get('all_accuracies')
                if not (accuracies and isinstance(accuracies, list)):
                    continue

                # Calculate statistics
                results = calculate_statistics(accuracies, args.threshold)
                if not results:
                    continue

                all_results.setdefault(method_name, {})[result_key] = results
                print(run_label)
                print(f"  Mean: {results['mean']:.2f}mm, Median: {results['median']:.2f}mm")
                print(f"  Outliers: {results['outlier_count']}/{results['total_points']} "
                      f"({results['outlier_percentage']:.1f}%)")

        print()

        # Summary statistics for this case
        if all_results:
            processed_cases.append(case_num)

            print("=" * 50)
            print(f"SUMMARY{case_info}")
            print("=" * 50)

            for method_name, method_results in all_results.items():
                display_name = re.sub(r'^\d+[-_]?', '', method_name)
                print(f"\n{display_name}:")

                all_outlier_percentages = [r['outlier_percentage'] for r in method_results.values()]
                all_means = [r['mean'] for r in method_results.values()]

                avg_outlier_pct = np.mean(all_outlier_percentages)
                avg_mean = np.mean(all_means)

                print(f"  Average outlier percentage: {avg_outlier_pct:.1f}%")
                print(f"  Average mean accuracy: {avg_mean:.2f}mm")
                print(f"  Cases processed: {len(method_results)}")
            print()

    # Create overall statistics if cases were processed
    if processed_cases:
        # NOTE(review): when --case is omitted, processed_cases is [None] and
        # this filter yields an empty list, so the overall statistics and the
        # Wilcoxon tests get no case numbers in that mode - confirm intent.
        explicit_cases = [c for c in processed_cases if c is not None]
        overall_stats = calculate_overall_statistics(base_path, explicit_cases)

        # Print overall statistics
        if overall_stats:
            print("=" * 60)
            print("OVERALL STATISTICS ACROSS ALL CASES")
            print("=" * 60)

            # Create comparison table
            print(f"{'Method':<20} {'≤1mm':<8} {'≤2mm':<8} {'≤3mm':<8} {'>3mm':<8} {'Total':<8}")
            print("-" * 80)

            for method_name, stats in overall_stats.items():
                display_name = re.sub(r'^\d+[-_]?', '', method_name)
                pct_1mm = f"{stats['threshold_stats'][1]['percentage']:.2f}%"
                pct_2mm = f"{stats['threshold_stats'][2]['percentage']:.2f}%"
                pct_3mm = f"{stats['threshold_stats'][3]['percentage']:.2f}%"
                pct_above = f"{stats['above_3mm_percentage']:.2f}%"
                total = stats['total_points']

                print(f"{display_name:<20} {pct_1mm:<8} {pct_2mm:<8} {pct_3mm:<8} {pct_above:<8} {total:<8}")

            print()

            # Find best performing method for each metric
            best_mean = min(overall_stats.items(), key=lambda x: x[1]['mean'])
            best_1mm = max(overall_stats.items(), key=lambda x: x[1]['threshold_stats'][1]['percentage'])
            best_2mm = max(overall_stats.items(), key=lambda x: x[1]['threshold_stats'][2]['percentage'])
            best_3mm = max(overall_stats.items(), key=lambda x: x[1]['threshold_stats'][3]['percentage'])

            print("BEST PERFORMING METHODS:")
            best_mean_name = re.sub(r'^\d+[-_]?', '', best_mean[0])
            best_1mm_name = re.sub(r'^\d+[-_]?', '', best_1mm[0])
            best_2mm_name = re.sub(r'^\d+[-_]?', '', best_2mm[0])
            best_3mm_name = re.sub(r'^\d+[-_]?', '', best_3mm[0])

            print(f"  Lowest mean error: {best_mean_name} ({best_mean[1]['mean']:.2f}mm)")
            print(f"  Highest ≤1mm rate: {best_1mm_name} ({best_1mm[1]['threshold_stats'][1]['percentage']:.2f}%)")
            print(f"  Highest ≤2mm rate: {best_2mm_name} ({best_2mm[1]['threshold_stats'][2]['percentage']:.2f}%)")
            print(f"  Highest ≤3mm rate: {best_3mm_name} ({best_3mm[1]['threshold_stats'][3]['percentage']:.2f}%)")
            print()

            # Perform Wilcoxon tests (the report is printed by the callee)
            perform_wilcoxon_tests(base_path, explicit_cases)

# Run the command-line analysis only when executed as a script, not on import.
if __name__ == "__main__":
    main()
