import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from pathlib import Path
import argparse
from typing import Dict, List, Tuple, Optional
import re


def get_marker_and_color(model_family: str) -> Tuple[str, str]:
    """
    Look up the plotting style assigned to a model family.

    Args:
        model_family: Model family name

    Returns:
        Tuple of (marker, color): a matplotlib marker code and a hex color.
        Families that are not in the table get a gray circle.
    """
    # (family, marker, color) triples; every family keeps the same style
    # in every plot that calls this function.
    style_rows = (
        ("Qwen", "^", "#1f77b4"),       # triangle up, blue
        ("Qwen1.5", "D", "#ff7f0e"),    # diamond, orange
        ("Qwen3", "^", "#1f77b4"),      # triangle up, blue (identical to Qwen)
        ("Llama-2", "o", "#2ca02c"),    # circle, green
        ("Llama-3", "s", "#d62728"),    # square, red
        ("Llama", "*", "#9467bd"),      # star, purple
        ("DeepSeek", "P", "#8c564b"),   # filled plus, brown
        ("GPT-OSS", "*", "#e377c2"),    # star, pink
        ("Mistral", "o", "#17becf"),    # circle, cyan
        ("Yi", "v", "#d62728"),         # triangle down, red
        ("Baichuan", "<", "#e377c2"),   # triangle left, pink
        ("Granite", ">", "#bcbd22"),    # triangle right, olive
        ("Gemma", "h", "#7f7f7f"),      # hexagon (hexagon1), gray
        ("Microsoft", "H", "#9467bd"),  # hexagon (hexagon2), purple
        ("Google", "p", "#17becf"),     # pentagon, cyan
        ("Exaone", "8", "#ff9896"),     # octagon, light red
        ("Cohere", "D", "#c5b0d5"),     # diamond, light purple
        ("Kimi", "P", "#fdb462"),       # filled plus, light orange
    )
    style_by_family = {name: (marker, color) for name, marker, color in style_rows}

    if model_family in style_by_family:
        return style_by_family[model_family]

    # Default: circle, gray
    return ("o", "#7f7f7f")


def load_results_data(csv_path: str) -> pd.DataFrame:
    """
    Read the results CSV and attach a log-scaled model-size column.

    Args:
        csv_path: Path to the CSV file

    Returns:
        The loaded DataFrame with an added 'log_model_size' column holding
        log10 of 'Model Size (B)', where missing sizes are replaced by 0.1
        before the logarithm is taken
    """
    results = pd.read_csv(csv_path)

    # 'Model Family' is taken from the CSV as-is; nothing is derived for it here.

    # Missing sizes are substituted with 0.1 only for the log computation;
    # the original 'Model Size (B)' column is left untouched.
    sizes_in_billions = results['Model Size (B)'].fillna(0.1)
    results['log_model_size'] = np.log10(sizes_in_billions)

    return results


def create_scaling_plots(df: pd.DataFrame, output_path: str = "scaling_plots.png"):
    """
    Create scaling plots for all metrics with legends at the top of each row.

    One scatter panel is drawn per metric (metric value against
    'log_model_size'), using one marker/color per model family. A metric
    whose column is absent from ``df`` gets an empty, labelled panel instead
    of raising.

    Args:
        df: DataFrame with results data; must contain the 'Model Family' and
            'log_model_size' columns
        output_path: Path to save the plot

    Returns:
        The matplotlib Figure that was saved
    """
    # Define metrics to plot (column name -> axis label / panel title)
    metrics = {
        'task_success_rate': 'Task Success Rate (%)',
        'execution_success_rate': 'Execution Success Rate (%)',
        'total_goal': 'Total Goal (%)',
        'state_goal': 'State Goal (%)',
        'relation_goal': 'Relation Goal (%)',
        'action_goal': 'Action Goal (%)',
        'parsing_error': 'Parsing Error (%)',
        'hallucination_error': 'Hallucination Error (%)',
        'wrong_order_error': 'Wrong Order Error (%)',
        'missing_step_error': 'Missing Step Error (%)',
        'additional_step_error': 'Additional Step Error (%)',
        'affordance_error': 'Affordance Error (%)'
    }

    # Grid geometry: fixed column count, as many rows as needed
    n_metrics = len(metrics)
    n_cols = 4
    n_rows = (n_metrics + n_cols - 1) // n_cols

    fig_width = 24
    fig_height = 6 * n_rows
    # squeeze=False keeps `axes` two-dimensional for any grid shape, so the
    # panels can always be flattened and addressed by a single index.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(fig_width, fig_height),
                             squeeze=False)
    panels = axes.flatten()

    # Get unique model families for the plots and the legend
    families = sorted(df['Model Family'].unique())

    # Resolve each family's rows and style once instead of once per metric
    family_series = []
    for family in families:
        family_data = df[df['Model Family'] == family]
        if len(family_data) == 0:
            continue
        marker, color = get_marker_and_color(family)
        family_series.append((family, family_data, marker, color))

    # Plot each metric
    for ax, (metric, title) in zip(panels, metrics.items()):
        has_metric = metric in df.columns

        if has_metric:
            for family, family_data, marker, color in family_series:
                ax.scatter(family_data['log_model_size'],
                           family_data[metric],
                           marker=marker,
                           color=color,
                           s=100,
                           alpha=0.7,
                           label=family)

        # Customize the plot
        ax.set_xlabel('Log10(Model Size (B))', fontsize=12)
        ax.set_ylabel(title, fontsize=12)
        ax.set_title(title, fontsize=14, fontweight='bold')
        ax.grid(True, alpha=0.3)

        # Set reasonable y-axis limits; when the metric is missing or has no
        # valid values, matplotlib auto-scales the panel.
        if has_metric:
            valid_data = df[metric].dropna()
            if len(valid_data) > 0:
                min_val = valid_data.min()
                max_val = valid_data.max()
                # Add some padding to the range
                padding = (max_val - min_val) * 0.1
                # Skip constant (zero-width) or non-finite ranges: explicit
                # limits would be degenerate there, so keep auto-scaling.
                if padding > 0:
                    ax.set_ylim(min_val - padding, max_val + padding)

    # Remove empty subplots (grid cells beyond the last metric)
    for unused_ax in panels[n_metrics:]:
        fig.delaxes(unused_ax)

    # Legend entries are identical for every row, so build them once
    handles = []
    labels = []
    for family in families:
        marker, color = get_marker_and_color(family)
        handles.append(plt.Line2D([0], [0], marker=marker, color=color,
                                  markersize=15, linestyle='', alpha=0.7))
        labels.append(family)

    # Spread the entries over 2 legend rows; never ask for 0 columns
    ncol = max(1, (len(families) + 1) // 2)

    # Add one legend per subplot row, centered horizontally.
    # NOTE(review): the 0.33 vertical step presumably targets the three rows
    # produced by the 12 metrics above - confirm if the metric list changes.
    for row in range(n_rows):
        legend_y = 0.98 - (row * 0.33)
        legend = fig.legend(handles, labels, loc='upper center',
                            bbox_to_anchor=(0.5, legend_y),
                            ncol=ncol, frameon=True,
                            fancybox=True, shadow=True, fontsize=14)
        legend.get_frame().set_facecolor('white')
        legend.get_frame().set_alpha(0.9)

    # Adjust layout to make room for legends. Work on `fig` directly rather
    # than on pyplot's implicit "current figure".
    fig.tight_layout()
    fig.subplots_adjust(top=0.92, hspace=0.4, wspace=0.3)

    # Save the plot
    fig.savefig(output_path, dpi=300, bbox_inches='tight')
    print(f"Scaling plots saved to: {output_path}")

    return fig


def create_focused_plots(df: pd.DataFrame, output_path: str = "focused_scaling_plots.png"):
    """
    Create focused plots for key metrics only with legend at the top.

    The key metrics are drawn side by side in a single row (metric value
    against 'log_model_size'), using one marker/color per model family. A
    metric whose column is absent from ``df`` gets an empty, labelled panel
    instead of raising.

    Args:
        df: DataFrame with results data; must contain the 'Model Family' and
            'log_model_size' columns
        output_path: Path to save the plot

    Returns:
        The matplotlib Figure that was saved
    """
    # Define key metrics for focused plots (column name -> label / title)
    key_metrics = {
        'task_success_rate': 'Task Success Rate (%)',
        'execution_success_rate': 'Execution Success Rate (%)',
        'total_goal': 'Total Goal (%)',
        'parsing_error': 'Parsing Error (%)'
    }

    # Create subplots. squeeze=False guarantees an array of axes even if the
    # metric list is ever reduced to a single entry.
    n_metrics = len(key_metrics)
    fig, axes = plt.subplots(1, n_metrics, figsize=(20, 5), squeeze=False)
    panels = axes[0]

    # Get unique model families for the plots and the legend
    families = sorted(df['Model Family'].unique())

    # Resolve each family's rows and style once instead of once per metric
    family_series = []
    for family in families:
        family_data = df[df['Model Family'] == family]
        if len(family_data) == 0:
            continue
        marker, color = get_marker_and_color(family)
        family_series.append((family, family_data, marker, color))

    # Plot each key metric
    for ax, (metric, title) in zip(panels, key_metrics.items()):
        has_metric = metric in df.columns

        if has_metric:
            for family, family_data, marker, color in family_series:
                ax.scatter(family_data['log_model_size'],
                           family_data[metric],
                           marker=marker,
                           color=color,
                           s=120,
                           alpha=0.8,
                           label=family)

        # Customize the plot
        ax.set_xlabel('Log10(Model Size (B))', fontsize=12)
        ax.set_ylabel(title, fontsize=12)
        ax.set_title(title, fontsize=14, fontweight='bold')
        ax.grid(True, alpha=0.3)

        # Set reasonable y-axis limits; when the metric is missing or has no
        # valid values, matplotlib auto-scales the panel.
        if has_metric:
            valid_data = df[metric].dropna()
            if len(valid_data) > 0:
                min_val = valid_data.min()
                max_val = valid_data.max()
                # Add some padding to the range
                padding = (max_val - min_val) * 0.1
                # Skip constant (zero-width) or non-finite ranges: explicit
                # limits would be degenerate there, so keep auto-scaling.
                if padding > 0:
                    ax.set_ylim(min_val - padding, max_val + padding)

    # Build one legend entry per family
    handles = []
    labels = []
    for family in families:
        marker, color = get_marker_and_color(family)
        handles.append(plt.Line2D([0], [0], marker=marker, color=color,
                                  markersize=15, linestyle='', alpha=0.8))
        labels.append(family)

    # Spread the entries over 2 legend rows; never ask for 0 columns
    ncol = max(1, (len(families) + 1) // 2)

    # Add legend at the top of the figure
    fig.legend(handles, labels, loc='upper center', bbox_to_anchor=(0.5, 1.05),
               ncol=ncol, frameon=True, fancybox=True, shadow=True, fontsize=14)

    # Adjust layout. Work on `fig` directly rather than on pyplot's implicit
    # "current figure".
    fig.tight_layout()
    fig.subplots_adjust(top=0.85)  # Make room for legend at the top

    # Save the plot
    fig.savefig(output_path, dpi=300, bbox_inches='tight')
    print(f"Focused scaling plots saved to: {output_path}")

    return fig


def main():
    """
    Command-line entry point: parse options, load the CSV and write the plots.

    The focused plot is always produced; the comprehensive plot is produced
    unless --focused-only is given.

    Returns:
        The DataFrame returned by load_results_data for the --input file
    """
    cli = argparse.ArgumentParser(description="Create scaling plots for evaluation results")
    cli.add_argument(
        "--input",
        type=str,
        required=True,
        help="Path to the CSV file with results data",
    )
    cli.add_argument(
        "--output-dir",
        type=str,
        default="plots",
        help="Output directory for plots (default: plots)",
    )
    cli.add_argument(
        "--focused-only",
        action="store_true",
        help="Only create focused plots (key metrics only)",
    )
    options = cli.parse_args()

    # Make sure the destination folder exists before anything is written
    plots_dir = Path(options.output_dir)
    plots_dir.mkdir(parents=True, exist_ok=True)

    # Read the results table
    print(f"Loading data from: {options.input}")
    results = load_results_data(options.input)

    row_count, column_count = results.shape
    print(f"Loaded {row_count} models with {column_count} metrics")
    family_names = sorted(results['Model Family'].unique())
    print(f"Model families found: {family_names}")

    # Global plot styling
    plt.style.use('seaborn-v0_8')
    sns.set_palette("husl")

    # The comprehensive (all-metric) figure is skipped with --focused-only
    want_comprehensive = not options.focused_only
    if want_comprehensive:
        create_scaling_plots(results, plots_dir.joinpath("comprehensive_scaling_plots.png"))

    # The focused (key-metric) figure is always written
    create_focused_plots(results, plots_dir.joinpath("focused_scaling_plots.png"))

    print(f"Plots saved to: {plots_dir}")

    return results


# Run the command-line entry point only when this file is executed as a
# script. The value returned by main() is bound to the module-level name
# `df`.
if __name__ == "__main__":
    df = main()
