import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

def load_data(dataset_name, eaval_type,
              base_dir='/Users/qinjielin/Downloads/NWU/25corl/corl_ws/ObsScaling/eval_results'):
    """Load naive and structured-output results for one dataset/evaluation pair.

    Two CSVs are read from ``base_dir``:
      * ``{dataset_name}_{eaval_type}_results_with_flops_and_openllm.csv``
        (naive output), and
      * ``{dataset_name}_{eaval_type}_v4_results_with_flops_and_openllm.csv``
        (structured output).

    Args:
        dataset_name: Dataset prefix used in the CSV file names.
        eaval_type: Either 'action_sequencing' or 'goal_interpretation'.
        base_dir: Directory containing the result CSVs. It defaults to the
            location previously hard-coded here, so existing callers are unaffected.

    Returns:
        ``(df, df_masked, metrics)``: the naive results, the structured results
        and the list of metric column names used for this evaluation type.

    Raises:
        ValueError: If ``eaval_type`` is not supported. This is checked before
            any file is opened.
    """
    import os

    metrics_by_type = {
        'action_sequencing': ['task_success_rate', 'execution_success_rate'],
        'goal_interpretation': ['all_f1'],
    }
    # Validate first: an unknown eval type should report itself, not surface as
    # a FileNotFoundError for a file name built from the typo.
    if eaval_type not in metrics_by_type:
        raise ValueError(f"Invalid evaluation type: {eaval_type}")
    # Hand out a fresh list so callers cannot mutate the shared table.
    metrics = list(metrics_by_type[eaval_type])

    stem = f'{dataset_name}_{eaval_type}'
    df = pd.read_csv(os.path.join(base_dir, f'{stem}_results_with_flops_and_openllm.csv'))
    df_masked = pd.read_csv(os.path.join(base_dir, f'{stem}_v4_results_with_flops_and_openllm.csv'))

    return df, df_masked, metrics

def get_valid_models(df, df_masked, metrics):
    """Select the models that can be compared across both result tables.

    A model is kept when it
      1. appears in the 'Model' column of both ``df`` and ``df_masked``,
      2. does not contain any excluded pattern (a case-insensitive substring
         match), and
      3. has a non-zero, non-missing value in every ``metrics`` column of both
         tables. Only the first row per model is checked.

    Args:
        df: Naive-output results table.
        df_masked: Structured-output results table.
        metrics: Metric column names to check.

    Returns:
        ``(df_common, df_masked_common, valid_models)``:
          * ``df_common`` and ``df_masked_common`` are the two tables restricted
            to valid models and sorted by 'Model Family' and then
            'Model Size (B)';
          * ``valid_models`` lists the valid model names in sorted order.
    """
    # These are substring patterns. 'deepseek' drops every DeepSeek variant, and
    # 'lg' matches any model name that contains "lg".
    exclude_patterns = ('lg', 'deepseek', 'kimi', 'qwen-')

    common_models = set(df['Model']) & set(df_masked['Model'])
    # Sort so the returned order is deterministic; iterating a set of strings
    # varies between runs under hash randomization.
    filtered_models = sorted(
        model for model in common_models
        if not any(pattern in model.lower() for pattern in exclude_patterns)
    )

    def _has_bad_value(frame):
        # For each filtered model, True when its first row has a zero or
        # missing value in any metric. NaN is rejected as well, because it
        # would later produce NaN axis limits when plotting.
        first_rows = frame[frame['Model'].isin(filtered_models)].drop_duplicates('Model')
        values = first_rows.set_index('Model')[list(metrics)]
        return ((values == 0) | values.isna()).any(axis=1)

    naive_bad = _has_bad_value(df)
    structured_bad = _has_bad_value(df_masked)
    valid_models = [
        model for model in filtered_models
        if not (naive_bad.at[model] or structured_bad.at[model])
    ]

    sort_keys = ['Model Family', 'Model Size (B)']
    df_common = df[df['Model'].isin(valid_models)].sort_values(sort_keys)
    df_masked_common = df_masked[df_masked['Model'].isin(valid_models)].sort_values(sort_keys)

    return df_common, df_masked_common, valid_models

def create_model_labels_and_colors(df_common):
    """Build x-tick labels, per-family index spans and a color for each family.

    Families are visited in first-appearance order (``unique()``), and each
    family's rows are emitted consecutively. The labels only line up with the
    bars drawn from ``df_common``'s row order when the rows are already grouped
    by family. The callers in this file ensure that by sorting on
    'Model Family'.

    Returns:
        model_labels: One ``"<name after last '/'>\\n(<size>B)"`` label per row.
        family_positions: ``(first_index, last_index, family)`` tuples.
        family_colors: Mapping from family to a Set3 color, sampled evenly.
    """
    family_names = df_common['Model Family'].unique()
    palette = plt.cm.Set3(np.linspace(0, 1, len(family_names)))
    family_colors = dict(zip(family_names, palette))

    model_labels = []
    family_positions = []
    cursor = 0
    for family in family_names:
        members = df_common[df_common['Model Family'] == family]
        count = len(members)
        family_positions.append((cursor, cursor + count - 1, family))
        model_labels.extend(
            f"{full_name.split('/')[-1]}\n({size}B)"
            for full_name, size in zip(members['Model'], members['Model Size (B)'])
        )
        cursor += count

    return model_labels, family_positions, family_colors

def plot_single_metric(ax, df_common, df_masked_common, metric_name, metric_list, model_labels, 
                      family_positions, family_colors, eaval_type, dataset_name):
    """Draw paired naive-vs-structured bars for one metric group on ``ax``.

    For each row of ``df_common``, taken in row order, the columns in
    ``metric_list`` are averaged twice:
      * from ``df_common``, giving the naive bar;
      * from ``df_masked_common``, looked up by model name, giving the
        structured bar.
    Both bars use the model's family color: alpha 0.3 for naive and 1.0 for
    structured. Gray separators and colored text boxes mark each span in
    ``family_positions``.

    Args:
        ax: Matplotlib Axes to draw on.
        df_common: Naive-output results; only the first row per model is used.
        df_masked_common: Structured-output results; only the first row per
            model is used.
        metric_name: Display name used for the y label and the title.
        metric_list: Column names averaged into one bar height per model.
        model_labels: X tick labels, one per row of ``df_common``, in the same
            order.
        family_positions: ``(first_index, last_index, family)`` spans.
        family_colors: Mapping from family to an RGBA color.
        eaval_type: Used only to build the title.
        dataset_name: Used only to build the title.
    """
    x = np.arange(len(model_labels))
    width = 0.35

    # Index both tables by model once, keeping the first row per model, instead
    # of re-filtering the whole DataFrame for every model (O(n^2)).
    naive_first = df_common.drop_duplicates('Model').set_index('Model')
    structured_first = df_masked_common.drop_duplicates('Model').set_index('Model')
    models = df_common['Model']

    # skipna=False keeps np.mean semantics: any NaN makes the mean NaN.
    naive_means = (naive_first[metric_list].mean(axis=1, skipna=False)
                   .reindex(models).tolist())
    structured_means = (structured_first[metric_list].mean(axis=1, skipna=False)
                        .reindex(models).tolist())
    families = naive_first['Model Family'].reindex(models).tolist()

    # Both bars use the family color; alpha distinguishes naive (0.3) from
    # structured (1.0).
    bar_colors_naive = [(*family_colors[family][:3], 0.3) for family in families]
    bar_colors_structured = [(*family_colors[family][:3], 1.0) for family in families]

    ax.bar(x - width/2, naive_means, width, label='Naive Output',
           color=bar_colors_naive, edgecolor='black', linewidth=0.5)
    ax.bar(x + width/2, structured_means, width, label='Structured Output',
           color=bar_colors_structured, edgecolor='black', linewidth=0.5)

    ax.set_xlabel('Models', fontsize=10, fontweight='bold')
    # Every evaluation type uses the metric name unchanged as the y label.
    ax.set_ylabel(metric_name, fontsize=10, fontweight='bold')
    ax.set_title(f'{metric_name}\n{dataset_name.title()} - {eaval_type.replace("_", " ").title()}', 
                 fontsize=12, fontweight='bold', pad=15)
    ax.set_xticks(x)
    ax.set_xticklabels(model_labels, rotation=45, ha='right', fontsize=8)
    ax.legend(fontsize=9)
    ax.grid(True, alpha=0.3, linestyle='--')

    # Leave 10% headroom above the tallest bar; family labels sit at 105%.
    peak = max(max(naive_means), max(structured_means))
    ax.set_ylim(0, peak * 1.1)

    for start, end, family in family_positions:
        # Separate each family after the first with a vertical line.
        if start > 0:
            ax.axvline(x=start - 0.5, color='gray', linestyle='-', alpha=0.5, linewidth=1)

        ax.text((start + end) / 2, peak * 1.05, family,
                ha='center', va='bottom', fontweight='bold', fontsize=9,
                bbox=dict(boxstyle='round,pad=0.3', facecolor=family_colors[family], alpha=0.8))

def _collect_plot_info(dataset_eval_pairs):
    """Load and filter each pair, returning one plot-spec dict per metric group.

    Pairs with no valid models are reported and skipped.
    """
    all_plot_info = []

    for dataset_name, eaval_type in dataset_eval_pairs:
        print(f"\nProcessing {dataset_name} - {eaval_type}...")

        df, df_masked, metrics = load_data(dataset_name, eaval_type)
        df_common, df_masked_common, valid_models = get_valid_models(df, df_masked, metrics)

        if len(valid_models) == 0:
            print(f"No valid models found for {dataset_name} - {eaval_type}")
            continue

        print(f"Found {len(valid_models)} valid models for {dataset_name} - {eaval_type}")

        model_labels, family_positions, family_colors = create_model_labels_and_colors(df_common)

        # Each metric gets its own subplot.
        if eaval_type == 'action_sequencing':
            metric_groups = {
                'Task Success Rate': ['task_success_rate'],
                'Execution Success Rate': ['execution_success_rate']
            }
        else:  # goal_interpretation
            metric_groups = {
                'F1 Score': ['all_f1']
            }

        for metric_name, metric_list in metric_groups.items():
            all_plot_info.append({
                'dataset_name': dataset_name,
                'eaval_type': eaval_type,
                'metric_name': metric_name,
                'metric_list': metric_list,
                'df_common': df_common,
                'df_masked_common': df_masked_common,
                'model_labels': model_labels,
                'family_positions': family_positions,
                'family_colors': family_colors
            })

    return all_plot_info


def create_combined_bar_plots(dataset_eval_pairs,
                              output_dir='/Users/qinjielin/Downloads/NWU/25corl/corl_ws/ObsScaling/plots'):
    """Render one bar subplot per metric for every pair, then save and show the figure.

    Subplots are laid out three per row. The figure is saved as
    ``combined_bar_comparison_<pair names>.png`` in ``output_dir``; the
    directory is created if it is missing.

    Args:
        dataset_eval_pairs: Iterable of ``(dataset_name, eaval_type)`` tuples.
        output_dir: Directory for the PNG. It defaults to the location
            previously hard-coded here.
    """
    import os

    plt.style.use('default')

    all_plot_info = _collect_plot_info(dataset_eval_pairs)
    total_plots = len(all_plot_info)
    if total_plots == 0:
        print("No valid plots to create.")
        return

    n_cols = 3
    n_rows = (total_plots + n_cols - 1) // n_cols  # Ceiling division

    # With squeeze=False, subplots always returns a 2-D array, so one ravel()
    # handles every grid shape.
    _, axes = plt.subplots(n_rows, n_cols, figsize=(8 * n_cols, 6 * n_rows), squeeze=False)
    axes = axes.ravel()

    for ax, plot_info in zip(axes, all_plot_info):
        plot_single_metric(
            ax,
            plot_info['df_common'],
            plot_info['df_masked_common'],
            plot_info['metric_name'],
            plot_info['metric_list'],
            plot_info['model_labels'],
            plot_info['family_positions'],
            plot_info['family_colors'],
            plot_info['eaval_type'],
            plot_info['dataset_name']
        )

    # Hide the empty cells in the last row.
    for ax in axes[total_plots:]:
        ax.set_visible(False)

    plt.tight_layout()

    # savefig does not create directories. Make sure the target exists so the
    # rendered figure is not lost to a FileNotFoundError.
    os.makedirs(output_dir, exist_ok=True)
    pair_names = "_".join(f"{d}_{e}" for d, e in dataset_eval_pairs)
    plt.savefig(os.path.join(output_dir, f'combined_bar_comparison_{pair_names}.png'),
                dpi=300, bbox_inches='tight')
    plt.show()

def print_combined_summary_statistics(dataset_eval_pairs):
    """Print naive-vs-structured averages for each dataset/evaluation pair.

    For each pair, the result tables are loaded and filtered to the comparable
    models. Each model's metrics are averaged, then those per-model averages
    are averaged across models. The absolute and relative differences are
    reported as well. Action-sequencing numbers are printed as percentages;
    every other type is printed with three decimals.
    """
    heavy_rule = '=' * 80
    light_rule = '-' * 40

    print(f"\n{heavy_rule}")
    print("COMBINED SUMMARY STATISTICS")
    print(heavy_rule)

    for dataset_name, eaval_type in dataset_eval_pairs:
        print(f"\n{light_rule}")
        print(f"{dataset_name.upper()} - {eaval_type.upper()}")
        print(light_rule)

        df, df_masked, metrics = load_data(dataset_name, eaval_type)
        df_common, df_masked_common, valid_models = get_valid_models(df, df_masked, metrics)

        if not valid_models:
            print("No valid models found.")
            continue

        family_list = ', '.join(df_common['Model Family'].unique())
        print(f"Valid models: {len(valid_models)}")
        print(f"Model families: {family_list}")

        # The mean of the per-model means, for the naive and structured tables.
        naive_overall, structured_overall = (
            table[metrics].mean(axis=1).mean() for table in (df_common, df_masked_common)
        )
        improvement = structured_overall - naive_overall

        spec = '.1%' if eaval_type == 'action_sequencing' else '.3f'
        print(f"Naive Output Average: {naive_overall:{spec}}")
        print(f"Structured Output Average: {structured_overall:{spec}}")
        print(f"Overall Improvement: {improvement:+{spec}}")

        print(f"Relative Improvement: {(improvement / naive_overall) * 100:+.1f}%")

def _main():
    """Print the summary, then build and save the combined bar plots."""
    # (dataset, evaluation type) pairs to process. More pairs can be appended,
    # e.g. ('behavior', 'action_sequencing') or ('behavior', 'goal_interpretation').
    dataset_eval_pairs = [
        ('virtualhome', 'action_sequencing'),
        ('virtualhome', 'goal_interpretation'),
    ]

    print_combined_summary_statistics(dataset_eval_pairs)

    print("\nCreating combined bar plots...")
    create_combined_bar_plots(dataset_eval_pairs)

    print("\nCombined bar plot visualization complete! Plot saved to plots/ directory.")


if __name__ == "__main__":
    _main()

