import os
import argparse
import numpy as np
import torch
import torch.nn.functional as F
from pathlib import Path
import re
import pickle
import json
from scipy.ndimage import map_coordinates
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from PIL import Image
import seaborn as sns

# Import from parent directory
import sys
sys.path.append(str(Path(__file__).parent.parent))

from utils import general
from utils.visualization import visualize_dirlab_warp_fixed, plot_accuracy_histogram, visualize_dirlab_warp_fixed_with_accuracy_mask

# Select seaborn's "husl" palette for the plots created by this script.
sns.set_palette("husl")

# Font size used by create_histogram for axis labels, tick labels and the
# statistics text box.
FONT_SIZE = 16

def calculate_crop_from_landmarks(landmarks_fixed, landmarks_moving, image_shape, landmark_padding):
    """Compute an even-sized crop box around the union of two landmark sets.

    The box is the per-axis min/max over both landmark sets, grown by
    ``landmark_padding`` and clipped to ``[0, image_shape - 1]``. The bounds
    are snapped to whole voxels (min rounded down, max rounded up) *before*
    the size is measured, so the returned size always equals
    ``max_coords - min_coords`` even for non-integer landmark coordinates.

    Args:
        landmarks_fixed: Array-like of shape (N, 3).
        landmarks_moving: Array-like of shape (M, 3).
        image_shape: Length-3 sequence with the image extent per axis.
        landmark_padding: Scalar or length-3 sequence added on each side.

    Returns:
        Tuple ``(crop_size, min_coords, max_coords)`` where ``crop_size`` is a
        list of three ints and the coordinates are integer numpy arrays.

    Raises:
        ValueError: If an axis has an odd extent and neither bound can be
            moved by one voxel without leaving the image.
    """
    upper_bound = np.asarray(image_shape) - 1
    all_landmarks = np.vstack([landmarks_fixed, landmarks_moving])

    # Padded bounding box, clipped to the valid index range of the image.
    min_coords = np.maximum(np.min(all_landmarks, axis=0) - landmark_padding, 0)
    max_coords = np.minimum(np.max(all_landmarks, axis=0) + landmark_padding, upper_bound)

    # Snap outward to integer voxel indices. Both clip limits are integers,
    # so flooring/ceiling cannot push the bounds outside the image.
    min_coords = np.floor(min_coords).astype(int)
    max_coords = np.ceil(max_coords).astype(int)

    # Make every extent even by growing the box by one voxel where needed.
    for axis in range(3):
        if (max_coords[axis] - min_coords[axis]) % 2 == 0:
            continue
        if max_coords[axis] < upper_bound[axis]:
            # Prefer extending the upper bound.
            max_coords[axis] += 1
        elif min_coords[axis] > 0:
            # Otherwise extend the lower bound downwards.
            min_coords[axis] -= 1
        else:
            raise ValueError("Cannot adjust crop size to be even without going out of bounds.")

    crop_size = max_coords - min_coords
    return crop_size.tolist(), min_coords, max_coords


def crop_kp_to_landmarks(kp, min_coords):
    """Shift keypoints into the coordinate frame of a crop starting at ``min_coords``."""
    shifted = kp - min_coords
    return shifted


def crop_img_to_landmarks(img, min_coords, max_coords):
    """Return the sub-volume of ``img`` spanning ``min_coords`` (inclusive) to ``max_coords`` (exclusive)."""
    window = (
        slice(min_coords[0], max_coords[0]),
        slice(min_coords[1], max_coords[1]),
        slice(min_coords[2], max_coords[2]),
    )
    return img[window]

def parse_case_numbers(case_string):
    """Turn a spec such as ``"1-3,7"`` into a sorted list of unique case numbers.

    Comma-separated tokens are either single integers or inclusive
    ``start-end`` ranges. An empty or missing spec selects cases 1 through 10.
    """
    if not case_string:
        # Nothing requested: fall back to all DIR-LAB cases.
        return list(range(1, 11))

    selected = set()
    for token in case_string.split(','):
        token = token.strip()
        first, dash, last = token.partition('-')
        if dash:
            lo = int(first.strip())
            hi = int(last.strip())
            selected.update(range(lo, hi + 1))
        else:
            selected.add(int(token.strip()))

    return sorted(selected)

def _select_seed_for_case(method_dir, case_num, pick_worst):
    """Scan ``seed_*`` subdirectories and pick one by mean accuracy for a case.

    For every seed directory the case folder is looked up as ``CaseNN``
    (zero-padded) and then ``CaseN``; its ``metrics.json`` supplies
    ``accuracy_mean_mm`` (or ``accuracy_mean`` as a fallback). Seeds without
    a case folder, metrics file, or accuracy entry are ignored.

    Args:
        method_dir: ``pathlib.Path`` of the method directory.
        case_num: Integer case number.
        pick_worst: If true select the largest mean error, otherwise the
            smallest.

    Returns:
        ``(seed_name, case_dir)`` of the selected seed, or ``(None, None)``
        when no seed qualifies.
    """
    # Sorted so that ties are resolved the same way on every filesystem.
    seed_dirs = sorted(
        (d for d in method_dir.iterdir() if d.is_dir() and d.name.startswith('seed_')),
        key=lambda d: d.name,
    )
    if not seed_dirs:
        return None, None  # No seeds found

    # Any real accuracy beats the sentinel; a missing accuracy never does.
    sentinel = float('-inf') if pick_worst else float('inf')
    chosen_accuracy = sentinel
    chosen_seed = None
    chosen_case_dir = None

    for seed_dir in seed_dirs:
        candidates = (
            seed_dir / f"Case{case_num:02d}",  # Case01, Case02, etc.
            seed_dir / f"Case{case_num}",      # Case1, Case2, etc.
        )
        case_dir = next((c for c in candidates if c.exists()), None)
        if case_dir is None:
            continue

        metrics_path = case_dir / "metrics.json"
        if not metrics_path.exists():
            continue

        try:
            with open(metrics_path, 'r') as f:
                metrics = json.load(f)

            accuracy = metrics.get('accuracy_mean_mm', metrics.get('accuracy_mean', sentinel))
            if pick_worst:
                improved = accuracy > chosen_accuracy
            else:
                improved = accuracy < chosen_accuracy
            if improved:
                chosen_accuracy = accuracy
                chosen_seed = seed_dir.name
                chosen_case_dir = case_dir
        except Exception as e:
            # Best-effort scan: report the unreadable file and keep going.
            print(f"Error reading metrics from {metrics_path}: {e}")
            continue

    return chosen_seed, chosen_case_dir


def find_worst_seed_for_case(method_dir, case_num):
    """Find the seed with maximum accuracy_mean_mm (largest error) for a given case.

    Returns ``(seed_name, case_dir)`` or ``(None, None)`` if nothing usable
    was found.
    """
    return _select_seed_for_case(method_dir, case_num, pick_worst=True)


def find_best_seed_for_case(method_dir, case_num):
    """Find the seed with minimum accuracy_mean_mm (smallest error) for a given case.

    Returns ``(seed_name, case_dir)`` or ``(None, None)`` if nothing usable
    was found.
    """
    return _select_seed_for_case(method_dir, case_num, pick_worst=False)

def find_method_cases(results_path, case_numbers=None):
    """Map each method directory under ``results_path`` to its selected cases.

    A method directory that contains ``seed_*`` subdirectories contributes,
    for every requested case (default 1-10), the seed chosen by
    ``find_best_seed_for_case``. A method directory without seeds contributes
    its ``Case<N>`` subdirectories directly, filtered by ``case_numbers`` when
    given.

    Returns:
        Dict ``method_name -> [(case_num, case_dir, seed_name_or_None), ...]``
        sorted by case number. Methods with no cases are left out.
    """
    root = Path(results_path)
    found = {}

    for method_dir in root.iterdir():
        if not method_dir.is_dir():
            continue

        entries = []
        has_seeds = any(
            child.is_dir() and child.name.startswith('seed_')
            for child in method_dir.iterdir()
        )

        if has_seeds:
            # One entry per requested case, taken from its best seed.
            for case_num in (case_numbers or list(range(1, 11))):
                best_seed, best_case_dir = find_best_seed_for_case(method_dir, case_num)
                if best_seed and best_case_dir:
                    entries.append((case_num, best_case_dir, best_seed))
                    print(f"Selected {best_seed} for {method_dir.name} Case{case_num}")
        else:
            # Flat layout: case folders sit directly inside the method folder.
            for candidate in method_dir.iterdir():
                if not (candidate.is_dir() and candidate.name.startswith('Case')):
                    continue
                hit = re.search(r'Case(\d+)', candidate.name)
                if not hit:
                    continue
                number = int(hit.group(1))
                if case_numbers is None or number in case_numbers:
                    entries.append((number, candidate, None))  # No seed

        if entries:
            found[method_dir.name] = sorted(entries, key=lambda item: item[0])

    return found

def load_keypoints_from_json(case_dir):
    """Read fixed, warped-fixed and moving keypoints from ``keypoints.json``.

    Each set is looked up under its primary key first and a legacy key
    second; a set present under neither becomes an empty array.

    Returns:
        ``(kp_fixed, kp_fixed_warped, kp_moving)`` as numpy arrays, or
        ``(None, None, None)`` if the file is absent or cannot be read.
    """
    nothing = (None, None, None)

    keypoints_path = case_dir / "keypoints.json"
    if not keypoints_path.exists():
        return nothing

    try:
        with open(keypoints_path, 'r') as handle:
            payload = json.load(handle)

        def as_array(primary, legacy):
            return np.array(payload.get(primary, payload.get(legacy, [])))

        kp_fixed = as_array('kp_fixed', 'landmarks_exp')
        kp_fixed_warped = as_array('kp_fixed_warped', 'warped_landmarks_exp')
        kp_moving = as_array('kp_mov', 'landmarks_insp')
        return kp_fixed, kp_fixed_warped, kp_moving
    except Exception as e:
        print(f"Error loading keypoints from {keypoints_path}: {e}")
        return nothing

def load_metrics_from_json(case_dir):
    """Return the parsed content of ``case_dir / "metrics.json"``.

    Yields ``None`` when the file does not exist or fails to load; a load
    failure is reported on stdout.
    """
    metrics_path = case_dir / "metrics.json"

    if metrics_path.exists():
        try:
            with open(metrics_path, 'r') as handle:
                return json.load(handle)
        except Exception as e:
            print(f"Error loading metrics from {metrics_path}: {e}")

    return None

def load_warped_image(case_dir, img_insp, deformation_field=None):
    """Return the warped moving image for a case.

    Resolution order:
      1. ``warped_moving.npy`` inside ``case_dir`` if it exists;
      2. ``img_insp`` warped through ``general.apply_deformation_field`` when
         a deformation field (numpy array or torch tensor) is supplied;
      3. ``img_insp`` unchanged - also used when warping raises.
    """
    cached = case_dir / "warped_moving.npy"
    if cached.exists():
        return np.load(cached)

    if deformation_field is None:
        return img_insp  # Nothing to warp with

    try:
        # Torch tensors are moved to host memory before warping.
        df_np = (
            deformation_field.detach().cpu().numpy()
            if isinstance(deformation_field, torch.Tensor)
            else deformation_field
        )
        return general.apply_deformation_field(img_insp, df_np)
    except Exception as e:
        print(f"Error creating warped image: {e}")
        return img_insp  # Return original if warping fails

def load_deformation_field(case_dir):
    """Load a deformation field stored in ``case_dir``, if any.

    Lookup order:
      1. The first existing file among ``deformation_field.npy``,
         ``displacement_field.npy``, ``df.npy`` and ``disp_field.npy``.
      2. ``*.pkl`` files in name order: a dict holding one of the keys
         ``deformation_field``, ``displacement_field``, ``df`` or ``flow``
         yields that value; a bare 4-D numpy array is returned as is.

    Pickle files that cannot be read are reported and skipped.

    Returns:
        The loaded field, or ``None`` when nothing suitable is found.
    """
    candidate_names = (
        'deformation_field.npy',
        'displacement_field.npy',
        'df.npy',
        'disp_field.npy',
    )
    for name in candidate_names:
        df_path = case_dir / name
        if df_path.exists():
            return np.load(df_path)

    pickle_keys = ('deformation_field', 'displacement_field', 'df', 'flow')

    # SECURITY: unpickling executes code embedded in the file. Only point this
    # at result directories produced by trusted runs.
    # Sorted so the chosen file does not depend on directory listing order.
    for file_path in sorted(case_dir.glob('*.pkl')):
        try:
            with open(file_path, 'rb') as f:
                data = pickle.load(f)
        except Exception as e:
            # `except Exception` (not a bare except) lets KeyboardInterrupt
            # and SystemExit propagate instead of being swallowed.
            print(f"Skipping unreadable pickle {file_path}: {e}")
            continue

        if isinstance(data, dict):
            for key in pickle_keys:
                if key in data:
                    return data[key]
        elif isinstance(data, np.ndarray) and data.ndim == 4:
            return data

    return None

def create_histogram(accuracies, method_name, case_id, output_path, threshold=3.0, x_limits=None, y_limits=None, show_max_outlier=False):
    """Plot a TRE histogram, save it, and write a companion statistics file.

    The histogram is saved to ``output_path``; bins whose right edge lies
    above ``threshold`` are drawn in a highlight color and a text box lists
    how many values fall within 1, 2 and 3 mm. A plain-text report named
    ``accuracy_statistics.txt`` is written next to the image (UTF-8, since it
    contains the "≤" character).

    Args:
        accuracies: Sequence of per-landmark errors in mm.
        method_name: Method identifier; a leading numeric prefix such as
            ``"01_"`` is stripped for the report title.
        case_id: Case label used in messages and the report title.
        output_path: Destination image path (``str`` or ``pathlib.Path``).
        threshold: Outlier threshold in mm for the marker line, bin
            highlighting and the outlier count.
        x_limits: Optional x-axis limits passed to ``set_xlim``.
        y_limits: Optional y-axis limits passed to ``set_ylim``.
        show_max_outlier: Draw a marker line at the maximum value when it
            exceeds ``threshold``.

    Returns:
        Dict of summary statistics, or ``{}`` when ``accuracies`` is empty or
        ``None``.
    """
    if accuracies is None or len(accuracies) == 0:
        print(f"Warning: No accuracy data for {method_name} Case{case_id}")
        return {}

    values = np.asarray(accuracies)
    total_points = len(accuracies)

    # Outliers relative to the caller-supplied threshold.
    outlier_count = np.sum(values > threshold)
    outlier_percentage = (outlier_count / total_points) * 100

    mean_acc = np.mean(values)
    median_acc = np.median(values)
    std_acc = np.std(values)
    min_acc = np.min(values)
    max_acc = np.max(values)

    # Fixed reporting thresholds (mm), independent of `threshold`.
    thresholds = [1, 2, 3]
    threshold_stats = {}
    for thresh in thresholds:
        within_threshold = int(np.count_nonzero(values <= thresh))
        threshold_stats[thresh] = {
            'count': within_threshold,
            'percentage': (within_threshold / total_points) * 100,
        }
    above_last_th = int(np.count_nonzero(values > thresholds[-1]))
    above_last_th_pct = (above_last_th / total_points) * 100

    # Text box shown inside the plot.
    stats_lines = ['TRE within thresholds:']
    for thresh in thresholds:
        entry = threshold_stats[thresh]
        stats_lines.append(
            f'≤ {thresh}mm: {entry["count"]}/{total_points} ({entry["percentage"]:.1f}%)'
        )
    stats_lines.append(f'> {thresholds[-1]}mm: {above_last_th}/{total_points} ({above_last_th_pct:.1f}%)')
    stats_lines.append(f'Max outlier: {max_acc:.2f}mm')
    stats_text = '\n'.join(stats_lines)

    fig, ax = plt.subplots(figsize=(6, 4))
    try:
        n_bins = min(50, max(10, total_points // 10))
        _, bins, patches = ax.hist(accuracies, bins=n_bins, alpha=0.7,
                                   color='skyblue', edgecolor='black', linewidth=0.5)

        # Marker lines carry labels, but no legend is drawn on the figure.
        ax.axvline(threshold, color='red', linestyle='--', linewidth=2,
                   label=f'Threshold: {threshold}mm')
        if show_max_outlier and max_acc > threshold:
            ax.axvline(max_acc, color='purple', linestyle=':', linewidth=2,
                       label=f'Max outlier: {max_acc:.2f}mm')

        # Highlight every bin that reaches past the threshold.
        for patch, bin_right in zip(patches, bins[1:]):
            if bin_right > threshold:
                patch.set_color('coral')
                patch.set_alpha(0.8)

        ax.set_xlabel('TRE (mm)', fontsize=FONT_SIZE)
        ax.set_ylabel('Frequency', fontsize=FONT_SIZE)

        ax.text(0.98, 0.82, stats_text, transform=ax.transAxes, fontsize=FONT_SIZE,
                verticalalignment='top', horizontalalignment='right',
                bbox=dict(boxstyle='round', facecolor='white', alpha=0.9),
                family='monospace')

        if x_limits:
            ax.set_xlim(x_limits)
        if y_limits:
            ax.set_ylim(y_limits)

        ax.grid(True, alpha=0.3)
        ax.tick_params(axis='both', which='major', labelsize=FONT_SIZE)

        fig.tight_layout()
        fig.savefig(output_path, dpi=300, bbox_inches='tight', pad_inches=0)
    finally:
        # Release the figure even if plotting or saving failed.
        plt.close(fig)

    # Save detailed statistics to a text file next to the image.
    display_method = re.sub(r'^\d+[-_]?', '', method_name)
    stats_file_path = Path(output_path).parent / 'accuracy_statistics.txt'
    detailed_stats = f"Accuracy Statistics for {display_method} - Case {case_id}\n"
    detailed_stats += "=" * 60 + "\n\n"
    detailed_stats += f"Mean accuracy: {mean_acc:.2f} mm\n"
    detailed_stats += f"Std accuracy: {std_acc:.2f} mm\n"
    detailed_stats += f"Median accuracy: {median_acc:.2f} mm\n"
    detailed_stats += f"Min accuracy: {min_acc:.2f} mm\n"
    detailed_stats += f"Max accuracy: {max_acc:.2f} mm\n\n"
    detailed_stats += "Accuracy within thresholds:\n"
    for thresh in thresholds:
        count = threshold_stats[thresh]['count']
        pct = threshold_stats[thresh]['percentage']
        detailed_stats += f"  ≤ {thresh} mm: {count}/{total_points} ({pct:.1f}%)\n"
    detailed_stats += f"  > {thresholds[-1]} mm: {above_last_th}/{total_points} ({above_last_th_pct:.1f}%)\n\n"
    detailed_stats += f"Total data points: {total_points}\n"
    detailed_stats += f"Analysis threshold: {threshold} mm\n"
    detailed_stats += f"Outliers (>{threshold}mm): {outlier_count} ({outlier_percentage:.1f}%)\n"

    # Explicit UTF-8: the report contains "≤", which locale encodings such as
    # cp1252 cannot represent.
    with open(stats_file_path, 'w', encoding='utf-8') as f:
        f.write(detailed_stats)

    print(f"  Statistics saved: {stats_file_path}")

    return {
        'mean': mean_acc,
        'median': median_acc,
        'std': std_acc,
        'min': min_acc,
        'max': max_acc,
        'outlier_count': outlier_count,
        'outlier_percentage': outlier_percentage,
        'total_points': total_points,
        'threshold_stats': threshold_stats,
        'above_3mm': above_last_th,
        'above_3mm_percentage': above_last_th_pct,
    }

def process_method_case(images_path, method_name, case_num, case_dir, output_dir, seed_name=None, error_threshold=3.0, crop_to_landmarks=False):
    """Process a single method-case combination.

    Loads the stored metrics and keypoints from ``case_dir`` and the DIR-LAB
    images for ``case_num``, writes a keypoint overlay image (keypoints whose
    error exceeds ``error_threshold`` are flagged through a boolean mask) and
    an accuracy histogram into the output directory, and returns the summary
    metrics.

    Args:
        images_path: Directory holding the DIR-LAB data; ``"/Case"`` is
            appended before it is passed to ``general.load_image_DIRLab``.
        method_name: Name of the method directory.
        case_num: Integer case number.
        case_dir: ``pathlib.Path`` with ``metrics.json`` and ``keypoints.json``.
        output_dir: ``pathlib.Path`` root for generated files.
        seed_name: Optional seed directory name; adds a path level to the
            output location when given.
        error_threshold: Error in mm above which a keypoint is flagged; values
            ``<= 0`` disable the mask.
        crop_to_landmarks: Crop images and keypoints to the padded landmark
            bounding box before visualizing.

    Returns:
        Dict with the case metrics, or ``None`` when metrics, images or
        keypoints could not be loaded.
    """
    seed_info = f" (seed: {seed_name})" if seed_name else ""
    print(f"Processing {method_name} - Case {case_num}{seed_info}")

    # Load metrics from existing metrics.json file
    metrics_data = load_metrics_from_json(case_dir)
    if metrics_data is None:
        print(f"Warning: No metrics.json found for {method_name} Case {case_num}{seed_info}")
        return None

    # Load original images for visualization
    try:
        (
            img_insp,
            img_exp,
            landmarks_insp,
            landmarks_exp,
            mask_exp,
            mask_insp,
            voxel_size,
        ) = general.load_image_DIRLab(case_num, f"{images_path}/Case")
    except Exception as e:
        print(f"Error loading DIR-LAB case {case_num}: {e}")
        return None

    # Crop box from the original landmarks with 10 voxels of padding per axis.
    crop_size, min_coords, max_coords = calculate_crop_from_landmarks(
        landmarks_insp, landmarks_exp, img_insp.shape, landmark_padding=[10, 10, 10]
    )
    print(f'crop_size {crop_size}, min_coords {min_coords}, max_coords {max_coords}')

    # Load keypoints from JSON file
    kp_fixed, kp_fixed_warped, kp_moving = load_keypoints_from_json(case_dir)
    if kp_fixed is None:
        print(f"Warning: No keypoints.json found for {method_name} Case {case_num}{seed_info}")
        return None

    # Methods with "sinr" in their name stored keypoints in cropped
    # coordinates; shift them back into the full-image frame.
    if 'sinr' in method_name.lower():
        kp_fixed = kp_fixed + min_coords
        kp_fixed_warped = kp_fixed_warped + min_coords
        kp_moving = kp_moving + min_coords

    print(f"  Loaded keypoints: {len(kp_fixed)} fixed, {len(kp_fixed_warped)} warped, {len(kp_moving)} moving")

    if crop_to_landmarks:
        kp_fixed = crop_kp_to_landmarks(kp_fixed, min_coords)
        kp_fixed_warped = crop_kp_to_landmarks(kp_fixed_warped, min_coords)
        kp_moving = crop_kp_to_landmarks(kp_moving, min_coords)
        img_insp = crop_img_to_landmarks(img_insp, min_coords, max_coords)
        img_exp = crop_img_to_landmarks(img_exp, min_coords, max_coords)

    # Create output directory for this method-case-seed combination
    if seed_name:
        case_output_dir = output_dir / method_name / seed_name / f"Case{case_num}"
    else:
        case_output_dir = output_dir / method_name / f"Case{case_num}"
    case_output_dir.mkdir(parents=True, exist_ok=True)

    # Get individual keypoint accuracies from metrics
    all_accuracies = np.array(metrics_data.get('all_accuracies_mm', metrics_data.get('all_accuracies', [])))

    # Create keypoint visualization with individual accuracy-based highlighting
    kp_overlay_path = case_output_dir / f"kp_overlay_Y_case_{case_num}.png"
    print(f'Creating visualization with error threshold: {error_threshold} mm')

    # Create mask based on individual accuracies
    keypoint_mask = None
    if len(all_accuracies) > 0 and error_threshold > 0:
        keypoint_mask = all_accuracies > error_threshold
        print(f"Highlighting {np.sum(keypoint_mask)}/{len(keypoint_mask)} keypoints with error > {error_threshold} mm")

    visualize_dirlab_warp_fixed_with_accuracy_mask(
        img_insp,
        img_exp,
        case_num,
        kp_fixed,
        kp_fixed_warped,
        kp_moving,
        keypoint_mask,
        axis=1,  # Y axis projection
        voxel_size=voxel_size,
        visualize=False,
        save_path=str(kp_overlay_path)
    )

    # Create accuracy histogram using loaded metrics
    if len(all_accuracies) > 0:
        # Also create acc_hist.png for compatibility with analyze_metrics
        acc_hist_path = case_output_dir / 'acc_hist.png'
        create_histogram(
            all_accuracies,
            method_name,
            str(case_num).zfill(2),
            acc_hist_path,
            threshold=3.0,
            x_limits=None,
            y_limits=None,
            show_max_outlier=True
        )

    # Extract metrics from loaded data. The un-suffixed keys are accepted as a
    # fallback, matching the seed-selection helpers and the `all_accuracies`
    # lookup above; otherwise such files would be reported as 0.0 mm.
    accuracy_mean = metrics_data.get('accuracy_mean_mm', metrics_data.get('accuracy_mean', 0.0))
    accuracy_std = metrics_data.get('accuracy_std_mm', metrics_data.get('accuracy_std', 0.0))
    folded_voxels_percent = metrics_data.get('folded_voxels_percent', 0.0)
    peak_memory_mb = metrics_data.get('peak_memory_mb', 0.0)

    print(f"  Accuracy: {accuracy_mean:.2f} ± {accuracy_std:.2f} mm")

    # Return metrics in the expected format
    return {
        'case_id': case_num,
        'method': method_name,
        'seed': seed_name,
        'accuracy_mean': accuracy_mean,
        'accuracy_std': accuracy_std,
        'all_accuracies': all_accuracies.tolist(),
        'folded_voxels_percent': folded_voxels_percent,
        'peak_memory_mb': peak_memory_mb,
        'voxel_size': voxel_size.tolist() if hasattr(voxel_size, 'tolist') else list(voxel_size)
    }

def summarize_results(methods_results, output_dir, xlimits=None, ylimits=None):
    """Aggregate per-case results per method, save them as JSON and print them.

    For every method with at least one usable result this computes the mean
    and spread of the case means, the spread over all landmark errors, and
    folded-voxel / peak-memory statistics. When landmark errors are
    available an overall histogram is written to
    ``output_dir/<method>/overall_accuracy_histogram.png``. The aggregate is
    stored in ``output_dir/results_summary.json``.
    """
    summary = {}

    for method_name, cases_results in methods_results.items():
        if not cases_results:
            continue

        # Failed cases are stored as None/empty and are left out.
        usable = [entry for entry in cases_results if entry]
        case_means = [entry['accuracy_mean'] for entry in usable]
        all_accuracies = [acc for entry in usable for acc in entry['all_accuracies']]
        folded_voxels = [entry.get('folded_voxels_percent', 0.0) for entry in usable]
        peak_memories = [entry.get('peak_memory_mb', 0.0) for entry in usable]

        if not case_means:
            continue

        summary[method_name] = {
            'mean_accuracy': np.mean(case_means),
            'std_accuracy_across_cases': np.std(case_means),  # spread of the per-case means
            'std_accuracy_all_landmarks': np.std(all_accuracies),  # spread over every landmark
            'case_count': len(case_means),
            'case_means': case_means,
            'all_accuracies': all_accuracies,
            'folded_voxels_mean': np.mean(folded_voxels),
            'folded_voxels_std': np.std(folded_voxels),
            'peak_memory_mean': np.mean(peak_memories),
            'peak_memory_max': np.max(peak_memories) if peak_memories else 0.0
        }

        if all_accuracies:
            # One histogram over the landmarks of all cases of this method.
            method_output_dir = output_dir / method_name
            method_output_dir.mkdir(parents=True, exist_ok=True)
            overall_hist_path = method_output_dir / 'overall_accuracy_histogram.png'

            create_histogram(
                all_accuracies,
                method_name,
                'all_cases',
                overall_hist_path,
                threshold=3.0,
                x_limits=xlimits,
                y_limits=ylimits,
                show_max_outlier=True
            )
            print(f"Created overall histogram for {method_name}: {overall_hist_path}")

    # numpy scalars are not JSON serializable, so cast to builtin types first.
    json_summary = {
        method: {
            'mean_accuracy': float(stats['mean_accuracy']),
            'std_accuracy_across_cases': float(stats['std_accuracy_across_cases']),
            'std_accuracy_all_landmarks': float(stats['std_accuracy_all_landmarks']),
            'case_count': int(stats['case_count']),
            'case_means': [float(value) for value in stats['case_means']],
            'all_accuracies': [float(value) for value in stats['all_accuracies']],
            'folded_voxels_mean': float(stats['folded_voxels_mean']),
            'folded_voxels_std': float(stats['folded_voxels_std']),
            'peak_memory_mean': float(stats['peak_memory_mean']),
            'peak_memory_max': float(stats['peak_memory_max'])
        }
        for method, stats in summary.items()
    }

    summary_path = output_dir / "results_summary.json"
    with open(summary_path, 'w') as f:
        json.dump(json_summary, f, indent=2)

    # Console report
    banner = "=" * 80
    print("\n" + banner)
    print("RESULTS SUMMARY")
    print(banner)

    for method_name, data in summary.items():
        print(f"{method_name}:")
        print(f"  Mean accuracy: {data['mean_accuracy']:.2f} ± {data['std_accuracy_across_cases']:.2f} mm (across cases)")
        print(f"  Overall std: {data['std_accuracy_all_landmarks']:.2f} mm (all landmarks)")
        print(f"  Folded voxels: {data['folded_voxels_mean']:.3f} ± {data['folded_voxels_std']:.3f} %")
        print(f"  Peak memory: {data['peak_memory_mean']:.0f} MB (mean), {data['peak_memory_max']:.0f} MB (max)")
        print(f"  Cases processed: {data['case_count']}")
        print()

    print(f"Summary saved to: {summary_path}")

def create_methods_collage(methods_results, output_dir, case_numbers=None, n_cols=4, spacing=(0, 0)):
    """Build one keypoint-overlay collage per case, comparing all methods.

    The overlay images are expected where the per-case processing saved them
    (``<output_dir>/<method>[/<seed>]/Case<id>/kp_overlay_Y_case_<id>.png``).
    Only existing files are used. Without ``case_numbers`` every case that has
    at least one overlay is rendered.
    """
    methods_data = {}

    for method_name, cases_results in methods_results.items():
        if not cases_results:
            continue

        overlays_by_case = {}
        for result in cases_results:
            if result is None:
                continue

            case_id = result['case_id']
            seed = result.get('seed', None)

            # Results of seeded runs live one directory level deeper.
            method_root = output_dir / method_name
            parent = method_root / seed if seed else method_root
            overlay_path = parent / f"Case{case_id}" / f"kp_overlay_Y_case_{case_id}.png"

            if overlay_path.exists():
                overlays_by_case.setdefault(case_id, []).append(overlay_path)

        if overlays_by_case:
            methods_data[method_name] = overlays_by_case

    if not methods_data:
        print("No overlay images found for collage creation!")
        return

    if case_numbers is None:
        # Default: every case for which any method produced an overlay.
        case_numbers = sorted(
            {case_id for per_method in methods_data.values() for case_id in per_method}
        )

    for case_num in case_numbers:
        create_case_collage(methods_data, case_num, output_dir, n_cols, spacing)

def create_histogram_collage(methods_results, output_dir, case_numbers=None, n_cols=4, spacing=(0, 0)):
    """Build one accuracy-histogram collage per case, comparing all methods.

    The histograms are expected where the per-case processing saved them
    (``<output_dir>/<method>[/<seed>]/Case<id>/acc_hist.png``). Only existing
    files are used. Without ``case_numbers`` every case that has at least one
    histogram is rendered.
    """
    methods_data = {}

    for method_name, cases_results in methods_results.items():
        if not cases_results:
            continue

        histograms_by_case = {}
        for result in cases_results:
            if result is None:
                continue

            case_id = result['case_id']
            seed = result.get('seed', None)

            # Results of seeded runs live one directory level deeper.
            method_root = output_dir / method_name
            parent = method_root / seed if seed else method_root
            hist_path = parent / f"Case{case_id}" / 'acc_hist.png'

            if hist_path.exists():
                histograms_by_case.setdefault(case_id, []).append(hist_path)

        if histograms_by_case:
            methods_data[method_name] = histograms_by_case

    if not methods_data:
        print("No histogram images found for collage creation!")
        return

    if case_numbers is None:
        # Default: every case for which any method produced a histogram.
        case_numbers = sorted(
            {case_id for per_method in methods_data.values() for case_id in per_method}
        )

    for case_num in case_numbers:
        create_histogram_case_collage(methods_data, case_num, output_dir, n_cols, spacing)

def create_case_collage(methods_data, case_num, output_dir, n_cols=4, spacing=(0, 0)):
    """Create collage for a specific case across methods.

    One tile per method is laid out on a grid with at most ``n_cols``
    columns; methods are ordered by their leading numeric prefix (methods
    without one come last) and the prefix is removed from the tile title.
    The result is saved as ``kp_overlay_collage_case_<case_num>.png`` in
    ``output_dir``.

    Args:
        methods_data: Dict ``method_name -> {case_num: [image_path, ...]}``;
            the first path of each list is used.
        case_num: Case to render.
        output_dir: ``pathlib.Path`` receiving the collage.
        n_cols: Maximum number of grid columns.
        spacing: Gap between tiles in pixels, either a single number or a
            ``(horizontal, vertical)`` tuple/list.
    """
    case_images = {
        method_name: method_cases[case_num][0]
        for method_name, method_cases in methods_data.items()
        if case_num in method_cases
    }

    if not case_images:
        print(f"No images found for Case {case_num}")
        return

    # Sort methods by number prefix if present
    def sort_key(method):
        match = re.match(r'^(\d+)', method)
        return int(match.group(1)) if match else float('inf')

    methods = sorted(case_images, key=sort_key)

    # Calculate layout
    n_methods = len(methods)
    n_cols = min(n_cols, n_methods)
    n_rows = (n_methods + n_cols - 1) // n_cols

    # Largest tile size determines the size of every grid cell. Images are
    # opened in a context manager so their file handles are released.
    max_img_width = 0
    max_img_height = 0
    for method in methods:
        with Image.open(case_images[method]) as img:
            max_img_width = max(max_img_width, img.width)
            max_img_height = max(max_img_height, img.height)

    # Figure size in inches such that, at the save DPI, each cell matches the
    # pixel size of the largest input image.
    target_dpi = 300
    spacing_w, spacing_h = spacing if isinstance(spacing, (tuple, list)) else (spacing, spacing)

    total_width_inches = (n_cols * max_img_width + (n_cols - 1) * spacing_w) / target_dpi
    total_height_inches = (n_rows * max_img_height + (n_rows - 1) * spacing_h) / target_dpi

    # squeeze=False always yields a 2-D axes array, so a single flatten covers
    # the 1x1, single-row and single-column layouts alike.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(total_width_inches, total_height_inches),
                             dpi=target_dpi, squeeze=False)
    try:
        axes = axes.ravel()

        # Hide unused subplots
        for ax in axes[n_methods:]:
            ax.axis('off')

        for ax, method in zip(axes, methods):
            with Image.open(case_images[method]) as img:
                ax.imshow(img)
            ax.axis('off')

            # Strip number prefix from method name for title
            display_name = re.sub(r'^\d+[-_]?', '', method).strip()

            # Position title at the top of the image
            ax.text(0.5, 0.98, display_name, transform=ax.transAxes,
                    fontsize=18, fontweight='bold', ha='center', va='top',
                    bbox=dict(boxstyle='round,pad=0.1', facecolor='white', alpha=0.8, edgecolor='black'))

        # Remove default padding; apply the requested gaps between tiles.
        if spacing_w == 0 and spacing_h == 0:
            fig.subplots_adjust(left=0.0, right=1.0, top=1.0, bottom=0.0,
                                wspace=0.0, hspace=0.0)
        else:
            # Convert pixel spacing to a fraction of the average cell size.
            relative_spacing_w = spacing_w / (fig.get_figwidth() * fig.dpi / n_cols)
            relative_spacing_h = spacing_h / (fig.get_figheight() * fig.dpi / n_rows)
            fig.subplots_adjust(left=0.0, right=1.0, top=1.0, bottom=0.0,
                                wspace=relative_spacing_w, hspace=relative_spacing_h)

        # Save collage
        collage_path = output_dir / f"kp_overlay_collage_case_{case_num}.png"
        fig.savefig(collage_path, dpi=target_dpi, bbox_inches='tight', facecolor='white', pad_inches=0.0)
    finally:
        # Release the figure even if rendering or saving failed.
        plt.close(fig)

    print(f"Collage for Case {case_num} saved to: {collage_path}")

def create_histogram_case_collage(methods_data, case_num, output_dir, n_cols=4, spacing=(0, 0)):
    """Create a histogram collage for a specific case across methods.

    One histogram image per method is placed on a grid and the result is
    written to ``<output_dir>/histogram_collage_case_<case_num>.png``.

    Args:
        methods_data: Mapping of method name to a mapping of case number to a
            sequence of histogram image paths. Only the first path of each
            sequence is used.
        case_num: Case number whose histograms are collected.
        output_dir: Directory the collage is written to (str or Path).
        n_cols: Maximum number of grid columns; reduced to the number of
            methods when fewer are available.
        spacing: Gap between images (nominally in pixels), either a single
            number or a ``(width, height)`` tuple/list.

    Returns:
        None. A message is printed and nothing is written when no method has
        a histogram for ``case_num``.
    """
    output_dir = Path(output_dir)

    # Use the first available histogram for each method-case combination;
    # methods with an empty path list are skipped instead of raising.
    case_images = {}
    for method_name, method_cases in methods_data.items():
        if case_num not in method_cases:
            continue
        hist_paths = method_cases[case_num]
        if len(hist_paths) > 0:
            case_images[method_name] = hist_paths[0]

    if not case_images:
        print(f"No histogram images found for Case {case_num}")
        return

    # Sort methods by number prefix if present (unprefixed names go last)
    def sort_key(method):
        match = re.match(r'^(\d+)', method)
        return int(match.group(1)) if match else float('inf')

    methods = sorted(case_images.keys(), key=sort_key)

    # Calculate layout
    n_methods = len(methods)
    n_cols = min(n_cols, n_methods)
    n_rows = (n_methods + n_cols - 1) // n_cols

    # Accept a scalar or a (width, height) pair given as tuple or list
    if isinstance(spacing, (tuple, list)):
        spacing_w, spacing_h = spacing
    else:
        spacing_w = spacing_h = spacing

    # squeeze=False always yields a 2-D array, so one flatten() covers the
    # single-axes, single-row and single-column layouts alike.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(6 * n_cols, 5 * n_rows), squeeze=False)
    try:
        axes = axes.flatten()

        # Hide unused subplots
        for ax in axes[n_methods:]:
            ax.axis('off')

        for ax, method in zip(axes, methods):
            # imshow converts the PIL image to an array when called, so the
            # file handle can be released as soon as it returns.
            with Image.open(case_images[method]) as img:
                ax.imshow(img)
            ax.axis('off')

            # Strip number prefix from method name for title
            display_name = re.sub(r'^\d+[-_]?', '', method).strip()

            # Position title at the top of the image
            ax.text(0.55, 0.98, display_name, transform=ax.transAxes,
                    fontsize=14, fontweight='bold', ha='center', va='top',
                    bbox=dict(boxstyle='round,pad=0.1', facecolor='white', alpha=0.8, edgecolor='black'))

        # Adjust spacing - remove default padding and set custom spacing
        if spacing_w == 0 and spacing_h == 0:
            fig.subplots_adjust(left=0.0, right=1.0, top=1.0, bottom=0.0,
                                wspace=0.0, hspace=0.0)
        else:
            # Convert pixel spacing to relative spacing.
            # NOTE(review): matplotlib interprets wspace/hspace as fractions
            # of the average axes size, so this only approximates the
            # requested pixel gap - confirm this is the intended look.
            relative_spacing_w = spacing_w / (fig.get_figwidth() * fig.dpi / n_cols)
            relative_spacing_h = spacing_h / (fig.get_figheight() * fig.dpi / n_rows)
            fig.subplots_adjust(left=0.0, right=1.0, top=1.0, bottom=0.0,
                                wspace=relative_spacing_w, hspace=relative_spacing_h)

        # Save histogram collage
        collage_path = output_dir / f"histogram_collage_case_{case_num}.png"
        fig.savefig(collage_path, dpi=300, bbox_inches='tight', facecolor='white', pad_inches=0.0)
    finally:
        # Always release this figure, even if loading or saving failed
        plt.close(fig)

    print(f"Histogram collage for Case {case_num} saved to: {collage_path}")

def crop_image(img_path, target_width):
    """Open an image and trim it to at most ``target_width`` pixels wide.

    The crop keeps the left part of the image (columns ``0`` up to
    ``target_width``) at full height. Images that are already no wider than
    ``target_width`` are returned as opened, without cropping.

    Args:
        img_path: Path of the image file to open.
        target_width: Maximum width, in pixels, of the returned image.

    Returns:
        The opened (and possibly cropped) PIL image.
    """
    image = Image.open(img_path)
    current_width, current_height = image.size

    # Nothing to trim when the image already fits
    if current_width <= target_width:
        return image

    left_box = (0, 0, target_width, current_height)
    return image.crop(left_box)

def create_overall_histogram_collage(methods_results, output_dir, n_cols=4, spacing=(0, 0)):
    """Create a collage of overall accuracy histogram images from different methods.

    For every method with non-empty results, the file
    ``<output_dir>/<method_name>/overall_accuracy_histogram.png`` is used if
    it exists. The collage is written to
    ``<output_dir>/overall_histogram_collage.png``.

    Args:
        methods_results: Mapping of method name to that method's case
            results. Only emptiness is checked here; methods with no results
            are skipped.
        output_dir: Directory holding the per-method sub-directories and
            receiving the collage (str or Path).
        n_cols: Maximum number of grid columns; reduced to the number of
            methods when fewer are available.
        spacing: Gap between images (nominally in pixels), either a single
            number or a ``(width, height)`` tuple/list.

    Returns:
        None. A message is printed and nothing is written when no overall
        histogram image is found.
    """
    output_dir = Path(output_dir)

    # Collect overall histogram paths for each method
    methods_data = {}
    for method_name, cases_results in methods_results.items():
        if not cases_results:
            continue

        # Look for overall histogram in method directory
        overall_hist_path = output_dir / method_name / 'overall_accuracy_histogram.png'
        if overall_hist_path.exists():
            methods_data[method_name] = overall_hist_path

    if not methods_data:
        print("No overall histogram images found for collage creation!")
        return

    # Sort methods by number prefix if present (unprefixed names go last)
    def sort_key(method):
        match = re.match(r'^(\d+)', method)
        return int(match.group(1)) if match else float('inf')

    methods = sorted(methods_data.keys(), key=sort_key)

    # Calculate layout
    n_methods = len(methods)
    n_cols = min(n_cols, n_methods)
    n_rows = (n_methods + n_cols - 1) // n_cols

    # Get image dimensions to calculate proper aspect ratio. The context
    # manager closes each file once its size has been read.
    max_img_width = 0
    max_img_height = 0
    for method in methods:
        with Image.open(methods_data[method]) as img:
            max_img_width = max(max_img_width, img.width)
            max_img_height = max(max_img_height, img.height)

    # Calculate figure size proportional to actual image dimensions
    # Use the same DPI as we'll save with for consistency
    target_dpi = 300

    # Accept a scalar or a (width, height) pair given as tuple or list
    if isinstance(spacing, (tuple, list)):
        spacing_w, spacing_h = spacing
    else:
        spacing_w = spacing_h = spacing

    # Include spacing in total dimensions
    total_width_inches = (n_cols * max_img_width + (n_cols - 1) * spacing_w) / target_dpi
    total_height_inches = (n_rows * max_img_height + (n_rows - 1) * spacing_h) / target_dpi

    # squeeze=False always yields a 2-D array, so one flatten() covers the
    # single-axes, single-row and single-column layouts alike.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(total_width_inches, total_height_inches),
                             dpi=target_dpi, squeeze=False)
    try:
        axes = axes.flatten()

        # Hide unused subplots
        for ax in axes[n_methods:]:
            ax.axis('off')

        for ax, method in zip(axes, methods):
            # imshow converts the PIL image to an array when called, so the
            # file handle can be released as soon as it returns.
            with Image.open(methods_data[method]) as img:
                ax.imshow(img)
            ax.axis('off')

            # Strip number prefix from method name for title
            display_name = re.sub(r'^\d+[-_]?', '', method).strip()

            # Position title at the top of the image
            ax.text(0.55, 0.98, display_name, transform=ax.transAxes,
                    fontsize=14, fontweight='bold', ha='center', va='top',
                    bbox=dict(boxstyle='round,pad=0.1', facecolor='white', alpha=0.8, edgecolor='black'))

        # Adjust spacing - remove default padding and set custom spacing
        if spacing_w == 0 and spacing_h == 0:
            fig.subplots_adjust(left=0.0, right=1.0, top=1.0, bottom=0.0,
                                wspace=0.0, hspace=0.0)
        else:
            # Convert pixel spacing to relative spacing.
            # NOTE(review): matplotlib interprets wspace/hspace as fractions
            # of the average axes size, so this only approximates the
            # requested pixel gap - confirm this is the intended look.
            relative_spacing_w = spacing_w / (fig.get_figwidth() * fig.dpi / n_cols)
            relative_spacing_h = spacing_h / (fig.get_figheight() * fig.dpi / n_rows)
            fig.subplots_adjust(left=0.0, right=1.0, top=1.0, bottom=0.0,
                                wspace=relative_spacing_w, hspace=relative_spacing_h)

        # Save overall histogram collage
        collage_path = output_dir / "overall_histogram_collage.png"
        fig.savefig(collage_path, dpi=target_dpi, bbox_inches='tight', facecolor='white', pad_inches=0.0)
    finally:
        # Always release this figure, even if loading or saving failed
        plt.close(fig)

    print(f"Overall histogram collage saved to: {collage_path}")

def _parse_limits(text):
    """Parse a ``"min,max"`` option string into a tuple of two floats.

    Returns None when ``text`` is empty or None. Raises ValueError when the
    string does not consist of exactly two comma-separated numbers.
    """
    if not text:
        return None
    parts = text.split(',')
    if len(parts) != 2:
        raise ValueError("Invalid limits format")
    return (float(parts[0]), float(parts[1]))


def _parse_spacing(text):
    """Parse a ``"width"`` or ``"width,height"`` option string into an int pair.

    A single value is used for both directions. Raises ValueError for any
    other number of parts or for non-integer values.
    """
    parts = text.split(',')
    if len(parts) == 1:
        return (int(parts[0]), int(parts[0]))
    if len(parts) == 2:
        return (int(parts[0]), int(parts[1]))
    raise ValueError("Invalid spacing format")


def _is_excluded_method(method_name, exclude_patterns):
    """Return True when ``method_name`` matches one of ``exclude_patterns``.

    All-digit patterns are compared against the method's leading number
    prefix only, so ``'1'`` matches ``"1_foo"`` but not ``"10_foo"`` or
    ``"4_bar_v1"``. Any other pattern is a case-insensitive substring match.
    """
    prefix_match = re.match(r'^(\d+)', method_name)
    numeric_prefix = int(prefix_match.group(1)) if prefix_match else None
    lowered_name = method_name.lower()

    for pattern in exclude_patterns:
        if pattern.isdigit():
            if numeric_prefix is not None and numeric_prefix == int(pattern):
                return True
        elif pattern.lower() in lowered_name:
            return True
    return False


def main():
    """Command-line entry point: process method results and build visualizations.

    Steps, in order:
      1. Parse arguments and check that the image and result paths exist.
      2. Validate ``--xlimits``, ``--ylimits`` and ``--spacing`` up front so
         a malformed option is reported before any processing starts.
      3. Discover method/case directories, optionally filtered by
         ``--methods``, and process each method-case combination. A failure
         in one combination is printed and recorded as ``None``; the
         remaining combinations are still processed.
      4. Write the summary and, when requested, the keypoint-overlay and
         histogram collages.

    Returns:
        None. Errors in the arguments are printed and end the run early.
    """
    parser = argparse.ArgumentParser(description='Process method results and create visualizations')
    parser.add_argument('images_path', type=str,
                        help='Path to original DIR-LAB dataset (containing Case01Pack, Case02Pack, etc.)')
    parser.add_argument('results_path', type=str,
                        help='Path to method results (containing method directories)')
    parser.add_argument('--cases', '-k', type=str, default=None,
                        help='Case numbers: single (1), range (1-5), or comma-separated (1,3,5)')
    parser.add_argument('--output', '-o', type=str, default='processed_results',
                        help='Output directory for processed results')
    parser.add_argument('--methods', '-m', type=str, default=None,
                        help='Comma-separated list of method names to process (default: all)')
    parser.add_argument('--error_threshold', '-t', type=float, default=3.0,
                        help='Error threshold in mm - only highlight keypoints with error above this (default: 3.0)')
    parser.add_argument('--create_collage', '--collage', action='store_true',
                        help='Create collage of keypoint overlay images from different methods')
    parser.add_argument('--create_histogram_collage', '--hist_collage', action='store_true',
                        help='Create collage of accuracy histogram images from different methods')
    parser.add_argument('--columns', '-n', type=int, default=6,
                        help='Number of columns for collage layout (default: 6)')
    parser.add_argument('--crop_to_landmarks', action='store_true',
                        help='Crop images and keypoints to the bounding box defined by landmarks')
    parser.add_argument('--spacing', '-s', type=str, default='0,0',
                        help='Pixel spacing between images in collage as "width,height" (default: 0,0)')
    parser.add_argument('--xlimits', type=str, default=None,
                        help='X-axis limits for overall histogram collage as "min,max" (e.g., "0,10")')
    parser.add_argument('--ylimits', type=str, default=None,
                        help='Y-axis limits for overall histogram collage as "min,max" (e.g., "0,100")')

    args = parser.parse_args()

    # Validate paths
    images_path = Path(args.images_path)
    results_path = Path(args.results_path)
    output_dir = Path(args.output)

    if not images_path.exists():
        print(f"Error: Images path '{images_path}' does not exist!")
        return

    if not results_path.exists():
        print(f"Error: Results path '{results_path}' does not exist!")
        return

    # Validate the free-form option strings before doing any work, so a typo
    # does not surface only after every method/case has been processed.
    try:
        xlimits = _parse_limits(args.xlimits)
    except ValueError:
        print("Error: xlimits must be in format 'min,max' (e.g., '0,10')")
        return

    try:
        ylimits = _parse_limits(args.ylimits)
    except ValueError:
        print("Error: ylimits must be in format 'min,max' (e.g., '0,100')")
        return

    try:
        spacing = _parse_spacing(args.spacing)
    except ValueError:
        print("Error: Spacing must be in format 'width' or 'width,height'")
        return

    # Parse case numbers
    case_numbers = parse_case_numbers(args.cases)
    print(f"Processing cases: {case_numbers}")

    # Find methods and cases
    methods = find_method_cases(results_path, case_numbers)
    if not methods:
        print("No method directories found!")
        return

    # Filter methods if specified
    if args.methods:
        method_filter = [m.strip() for m in args.methods.split(',')]
        methods = {k: v for k, v in methods.items() if k in method_filter}

    print(f"Found methods: {list(methods.keys())}")

    # Create output directory
    output_dir.mkdir(parents=True, exist_ok=True)

    # Process all method-case combinations
    methods_results = {}

    for method_name, cases in methods.items():
        print(f"\nProcessing method: {method_name}")
        methods_results[method_name] = []

        for case_info in cases:
            if len(case_info) == 3:
                case_num, case_dir, seed_name = case_info
            else:
                # Backward compatibility: entries without a seed name
                case_num, case_dir = case_info
                seed_name = None

            try:
                result = process_method_case(
                    str(images_path),
                    method_name,
                    case_num,
                    case_dir,
                    output_dir,
                    seed_name,
                    args.error_threshold,
                    args.crop_to_landmarks
                )
                methods_results[method_name].append(result)
            except Exception as e:
                # Best effort: report the failure, keep a None placeholder
                # and continue with the remaining combinations.
                seed_info = f" (seed: {seed_name})" if seed_name else ""
                print(f"Error processing {method_name} Case {case_num}{seed_info}: {e}")
                methods_results[method_name].append(None)

    # Create summary
    summarize_results(methods_results, output_dir, xlimits, ylimits)

    # Create collage if requested
    if args.create_collage:
        print("\nCreating keypoint overlay collages...")
        create_methods_collage(
            methods_results,
            output_dir,
            case_numbers=case_numbers,
            n_cols=args.columns,
            spacing=spacing
        )

    # Create histogram collage if requested with filtered methods
    if args.create_histogram_collage:
        print("\nCreating histogram collages...")

        # Filter out specific methods from histogram collage. Numeric entries
        # refer to the method's leading number prefix.
        exclude_methods_hist = ['1', '2', '3', '5']  # ['sinr', 'nodeo', 'idir', 'kan-idir']
        methods_results_hist = {}

        for method_name, results in methods_results.items():
            if _is_excluded_method(method_name, exclude_methods_hist):
                print(f'  Excluding "{method_name}" from histogram collage')
                continue
            methods_results_hist[method_name] = results

        # NOTE(review): the histogram collages use fixed layouts and ignore
        # --columns/--spacing; kept as in the original - confirm intent.
        create_histogram_collage(
            methods_results_hist,
            output_dir,
            case_numbers=case_numbers,
            n_cols=1,
            spacing=(0, 10)
        )

        # Also create overall histogram collage
        print("\nCreating overall histogram collage...")
        create_overall_histogram_collage(
            methods_results_hist,
            output_dir,
            n_cols=2,
            spacing=(0, 10)
        )


# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
