import os
import re
import sys
import warnings

import matplotlib.pyplot as plt
import pandas as pd

# Suppress every warning category process-wide (this also hides pandas/matplotlib
# deprecation notices).
warnings.filterwarnings("ignore")

# Input root holding the YOLO run folders, and the folder that receives the figures.
base_dir = 'runs/detect'
output_dir = 'yolo_plots'
os.makedirs(output_dir, exist_ok=True)

# (results.csv column, subplot title) pairs in plotting order; the plots lay these
# out row-major on a 2x3 grid, so there must be exactly six.
_metric_specs = [
    ('train/box_loss', 'Train Box Loss'),
    ('metrics/precision(B)', 'Precision'),
    ('metrics/recall(B)', 'Recall'),
    ('metrics/mAP50(B)', 'mAP@0.5'),
    ('metrics/mAP50-95(B)', 'mAP@0.5:0.95'),
    ('val/box_loss', 'Val Box Loss'),
]
target_columns = [column for column, _ in _metric_specs]
column_titles = [title for _, title in _metric_specs]

# Soft colors for the per-version plots, vivid colors for the best-of-each-version plot.
version_palette = ['#ABC6E4', '#C39398', '#FCDABA', '#A7D2BA', '#D0CADE']
best_palette = ['#32037D', '#7C1A97', '#C94E65', '#D9995B']

# Find all YOLO directories
model_dirs = [d for d in os.listdir(base_dir) 
             if (d.startswith('yolo') or d.startswith('V')) and os.path.isdir(os.path.join(base_dir, d))]

if not model_dirs:
    print("No YOLO model directories found!")
    exit()

# Load each run's results.csv and derive a display name plus version/size metadata.
all_data = {}    # display name -> DataFrame restricted to 'epoch' + target_columns
model_info = {}  # display name -> {'version', 'size', 'dir'}

# Run names like 'yolo11n' or 'V8s': a prefix, digits (version), letters (size).
# The two prefixes start with different letters, so one pattern covers both.
run_name_pattern = re.compile(r'(?:[yY]olo|[vV])(\d+)([a-zA-Z]+)')

print(f"Found {len(model_dirs)} YOLO model directories:")
for model_dir in model_dirs:
    csv_path = os.path.join(base_dir, model_dir, 'results.csv')

    if not os.path.exists(csv_path):
        print(f"  ⚠️ results.csv not found in {model_dir}, skipping")
        continue

    print(f"  ✓ Processing {model_dir}...")
    try:
        df = pd.read_csv(csv_path)
    except (OSError, ValueError) as e:
        # ValueError covers pandas' ParserError/EmptyDataError and text-decoding errors.
        print(f"    ❌ Error processing {model_dir}: {str(e)}")
        continue

    # Some Ultralytics releases pad header names with spaces ('      epoch');
    # without stripping, every required column would be reported missing.
    df.columns = df.columns.str.strip()

    missing_cols = [col for col in ['epoch'] + target_columns if col not in df.columns]
    if missing_cols:
        print(f"    ⚠️ Missing columns: {', '.join(missing_cols)}, skipping")
        continue

    match = run_name_pattern.match(model_dir)
    if match:
        version = f"v{match.group(1)}"
        size = match.group(2).upper()
        model_name = f"{version}-{size}"
    else:
        # Fallback: first digit run is the version, trailing letters are the size.
        version_match = re.search(r'(\d+)', model_dir)
        version = f"v{version_match.group(1)}" if version_match else "Unknown"
        size_match = re.search(r'([a-zA-Z]+)$', model_dir)
        size = size_match.group(1).upper() if size_match else "Unknown"
        model_name = model_dir.upper()

    # e.g. 'yolo11n' and 'yolo11n2' both parse to 'v11-N'; keep both runs instead
    # of silently overwriting the first one.
    if model_name in all_data:
        model_name = f"{model_name} ({model_dir})"

    all_data[model_name] = df[['epoch'] + target_columns]
    model_info[model_name] = {'version': version, 'size': size, 'dir': model_dir}

if not all_data:
    print("No valid data to process!")
    sys.exit(1)  # non-zero status: no usable results.csv was found

# Bucket the loaded runs by their parsed version string; insertion order follows all_data.
version_groups = {}
for name, frame in all_data.items():
    version_groups.setdefault(model_info[name]['version'], {})[name] = frame

print("\nGrouping results by version:")
for ver, members in version_groups.items():
    print(f"  {ver}: {', '.join(members)}")

# Find the best run for each version: the one whose peak mAP@0.5:0.95 is highest.
best_models = {}  # version -> {'model', 'mAP', 'epoch', 'size'}
for version, models in version_groups.items():
    best_model = None
    best_mAP = -1
    best_epoch = -1

    for model_name, df in models.items():
        mAP_curve = df['metrics/mAP50-95(B)']
        # Series.max skips NaN; an all-NaN curve yields NaN, which never compares greater.
        max_mAP = mAP_curve.max()
        if max_mAP > best_mAP:
            best_mAP = max_mAP
            best_model = model_name
            # Read the epoch from the CSV's own 'epoch' column instead of assuming
            # epoch == row index + 1.
            best_epoch = int(df.loc[mAP_curve.idxmax(), 'epoch'])

    if best_model is not None:
        best_models[version] = {
            'model': best_model,
            'mAP': best_mAP,
            'epoch': best_epoch,
            'size': model_info[best_model]['size']
        }
        print(f"  {version} best model: {best_model} (mAP@0.5:0.95 = {best_mAP:.4f} @ epoch {best_epoch})")

# Pick the overall winner among the per-version best runs (first one wins on ties).
global_best_model = None
global_best_mAP = -1
global_best_epoch = -1
global_best_version = None

if best_models:
    global_best_version, top_info = max(best_models.items(), key=lambda item: item[1]['mAP'])
    global_best_model = top_info['model']
    global_best_mAP = top_info['mAP']
    global_best_epoch = top_info['epoch']
    print(f"\nGlobal best model: {global_best_model} (mAP@0.5:0.95 = {global_best_mAP:.4f} @ epoch {global_best_epoch})")

# Plot comparison charts for each version (2x3 grid, one subplot per metric).
print("\nGenerating version comparison plots...")
line_styles = ['-', '--', '-.', ':'] * 3  # cycled per model alongside version_palette
for version, models in version_groups.items():
    if not models:
        continue

    fig, axes = plt.subplots(2, 3, figsize=(24, 16))
    fig.suptitle(f'YOLO {version} Model Performance Comparison', fontsize=28, fontweight='bold')
    best_info = best_models.get(version)  # None if no run of this version had a usable mAP

    for idx, (column, title) in enumerate(zip(target_columns, column_titles)):
        ax = axes[idx // 3, idx % 3]

        # Y-axis label depends on whether the column is a loss or a metric.
        if 'loss' in column:
            ax.set_ylabel('Loss', fontsize=16)
        else:
            ax.set_ylabel('Metric Value', fontsize=16)
            if 'mAP' in column:
                ax.set_ylim(0, 1)  # fixed 0-1 range for mAP; precision/recall autoscale

        ax.set_xlabel('Epoch', fontsize=16)
        ax.set_title(title, fontsize=18)
        ax.grid(True, linestyle='--', alpha=0.7)
        ax.tick_params(axis='both', which='major', labelsize=14)

        for j, (model_name, df) in enumerate(models.items()):
            ax.plot(
                df['epoch'],
                df[column],
                label=f"{model_info[model_name]['size']}",
                color=version_palette[j % len(version_palette)],
                linestyle=line_styles[j % len(line_styles)],
                linewidth=3,
                alpha=0.9
            )

            # Star the epoch where this version's best run peaks in mAP@0.5:0.95.
            if best_info and best_info['model'] == model_name:
                # Positional row of the (NaN-skipping) peak; x comes from the 'epoch'
                # column so the star always lies on the curve plotted above.
                peak_row = df['metrics/mAP50-95(B)'].argmax()
                ax.plot(
                    df['epoch'].iloc[peak_row],
                    df[column].iloc[peak_row],
                    marker='*',
                    markersize=15,
                    color='gold',
                    markeredgecolor='black'
                )

        # A single legend on the first subplot; all subplots share the same series.
        if idx == 0:
            ax.legend(fontsize=14, loc='best', title='Model Size')

    fig.tight_layout(rect=[0, 0, 1, 0.96])  # leave room for the suptitle

    save_path = os.path.join(output_dir, f'{version}_comparison.png')
    fig.savefig(save_path, dpi=300)
    print(f"  Saved: {save_path}")
    plt.close(fig)

# Plot the best run of every version on one 2x3 grid (one subplot per metric).
print("\nGenerating best models comparison plot...")
if best_models:
    fig, axes = plt.subplots(2, 3, figsize=(24, 16))
    fig.suptitle('Performance Comparison of Best Models from Each YOLO Version',
                 fontsize=28, fontweight='bold')

    line_styles = ['-', '--', '-.', ':'] * 3  # cycled per version alongside best_palette

    for idx, (column, title) in enumerate(zip(target_columns, column_titles)):
        ax = axes[idx // 3, idx % 3]

        # Y-axis label depends on whether the column is a loss or a metric.
        if 'loss' in column:
            ax.set_ylabel('Loss', fontsize=16)
        else:
            ax.set_ylabel('Metric Value', fontsize=16)
            if 'mAP' in column:
                ax.set_ylim(0, 1)  # fixed 0-1 range for mAP; precision/recall autoscale

        ax.set_xlabel('Epoch', fontsize=16)
        ax.set_title(title, fontsize=18)
        ax.grid(True, linestyle='--', alpha=0.7)
        ax.tick_params(axis='both', which='major', labelsize=14)

        for j, (version, best_info) in enumerate(best_models.items()):
            model_name = best_info['model']
            df = all_data[model_name]

            ax.plot(
                df['epoch'],
                df[column],
                label=f"{version}-{best_info['size']}",
                color=best_palette[j % len(best_palette)],
                linestyle=line_styles[j % len(line_styles)],
                linewidth=3.5,
                alpha=0.9
            )

            # Star the peak-mAP@0.5:0.95 epoch, but only for the global best run.
            if model_name == global_best_model:
                # Positional row of the (NaN-skipping) peak; x comes from the 'epoch'
                # column so the star always lies on the curve plotted above.
                peak_row = df['metrics/mAP50-95(B)'].argmax()
                ax.plot(
                    df['epoch'].iloc[peak_row],
                    df[column].iloc[peak_row],
                    marker='*',
                    markersize=18,
                    color='gold',
                    markeredgecolor='black'
                )

        # A single legend on the first subplot; all subplots share the same series.
        if idx == 0:
            ax.legend(fontsize=14, loc='best', title='Model Version')

    fig.tight_layout(rect=[0, 0, 1, 0.96])  # leave room for the suptitle

    save_path = os.path.join(output_dir, 'best_models_comparison.png')
    fig.savefig(save_path, dpi=300)
    print(f"  Saved: {save_path}")
    plt.close(fig)

# Report where the figures ended up, as an absolute path.
plots_location = os.path.abspath(output_dir)
print(f"\nAll plots saved to: {plots_location}")
print("Processing complete!")