import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

dataset_name = 'virtualhome'
eaval_type = 'goal_interpretation' # goal_interpretation or action_sequencing

# Metric columns averaged for each evaluation type. This is checked BEFORE any
# file is read, so an invalid eaval_type fails with a clear ValueError instead
# of a confusing FileNotFoundError for a CSV that does not exist.
_METRICS_BY_EVAL_TYPE = {
    # Other columns available for action_sequencing:
    #   total_goal, state_goal, relation_goal, action_goal, parsing_error,
    #   hallucination_error, wrong_order_error, missing_step_error,
    #   additional_step_error, affordance_error
    'action_sequencing': ['task_success_rate', 'execution_success_rate'],
    # Other columns available for goal_interpretation:
    #   {node,edge,action,all}_{precision,recall,f1}
    'goal_interpretation': ['all_f1'],
}
if eaval_type not in _METRICS_BY_EVAL_TYPE:
    raise ValueError(f"Invalid evaluation type: {eaval_type}")
# Copy so later mutation of `metrics` cannot alter the lookup table.
metrics = list(_METRICS_BY_EVAL_TYPE[eaval_type])

results_dir = '/Users/qinjielin/Downloads/NWU/25corl/corl_ws/ObsScaling/eval_results'
# df: results plotted as "Naive Output"; df_masked: the "v4" results plotted
# as "Structured Output". Both are expected to share 'Model', 'Model Family',
# 'Model Size (B)' and the metric columns above.
df = pd.read_csv(f'{results_dir}/{dataset_name}_{eaval_type}_results_with_flops_and_openllm.csv')
df_masked = pd.read_csv(f'{results_dir}/{dataset_name}_{eaval_type}_v4_results_with_flops_and_openllm.csv')

def _has_complete_nonzero_values(values):
    """Return True when the Series *values* has no zero and no missing (NaN/None) entry."""
    return not (values.isna().any() or (values == 0).any())


def _annotate_bar_values(ax, bars, offset, as_percent):
    """Write each bar's height just above it, as a percentage if *as_percent*."""
    for bar in bars:
        height = bar.get_height()
        text = f'{height:.1%}' if as_percent else f'{height:.1f}'
        ax.text(bar.get_x() + bar.get_width() / 2., height + offset, text,
                ha='center', va='bottom', fontsize=8)


def create_bar_plots(*, show_values=False):
    """Create bar plots comparing naive vs structured output.

    Only models present in both ``df`` (naive) and ``df_masked`` (structured)
    are plotted, after dropping:

    * models whose lower-cased name contains an exclusion pattern, and
    * models with a zero or missing value in any of ``metrics`` in either file.

    One row per model is used (the first occurrence). Bars are grouped by
    'Model Family' and ordered by 'Model Size (B)'. One subplot is drawn per
    metric group. The figure is saved under the hard-coded plots directory and
    then shown.

    Args:
        show_values: If True, write each bar's value above it. Defaults to
            False, which matches the previous output.
    """
    exclude_patterns = ['lg', 'deepseek', 'kimi', 'qwen-']

    # Find models evaluated in both result files
    common_models = set(df['Model']).intersection(df_masked['Model'])
    print(f"Found {len(common_models)} common models")

    # Drop unwanted models (case-insensitive substring match on the model name)
    excluded_models = {m for m in common_models
                       if any(p in m.lower() for p in exclude_patterns)}
    filtered_models = common_models - excluded_models
    print(f"Excluded {len(excluded_models)} models ({', '.join(exclude_patterns)}): "
          f"{len(filtered_models)} models remain")

    # One row per model, indexed by name, so naive and structured values are
    # always looked up for the same model regardless of row order in the CSVs.
    naive = (df[df['Model'].isin(filtered_models)]
             .drop_duplicates('Model').set_index('Model'))
    structured = (df_masked[df_masked['Model'].isin(filtered_models)]
                  .drop_duplicates('Model').set_index('Model'))

    # NaN is rejected as well as zero: a NaN bar height would make set_ylim fail.
    valid_models = [m for m in naive.index
                    if _has_complete_nonzero_values(naive.loc[m, metrics])
                    and _has_complete_nonzero_values(structured.loc[m, metrics])]
    print(f"After filtering out models with zero or missing values: {len(valid_models)} models remain")

    if not valid_models:
        print("No models found without zero values. Exiting.")
        return

    # Sort by model family and then by model size; structured follows the same order
    naive = naive.loc[valid_models].sort_values(['Model Family', 'Model Size (B)'])
    structured = structured.loc[naive.index]

    # Families are contiguous after sorting: record each family's bar span and colour
    families = naive['Model Family'].unique()
    family_colors = dict(zip(families, plt.cm.Set3(np.linspace(0, 1, len(families)))))
    family_positions = []
    start = 0
    for family in families:
        count = int((naive['Model Family'] == family).sum())
        family_positions.append((start, start + count - 1, family))
        start += count

    # Tick labels: model name without org prefix, plus size
    model_labels = [f"{model.split('/')[-1]}\n({size}B)"
                    for model, size in zip(naive.index, naive['Model Size (B)'])]
    bar_families = naive['Model Family'].tolist()
    # Same family colour for both bars; naive is more transparent than structured
    naive_colors = [(*family_colors[f][:3], 0.3) for f in bar_families]
    structured_colors = [(*family_colors[f][:3], 1.0) for f in bar_families]

    plt.style.use('default')

    # One subplot per metric group; each group's bar height is the mean of its columns
    if eaval_type == 'action_sequencing':
        metric_groups = {
            'Task Success Rate': ['task_success_rate'],
            'Execution Success Rate': ['execution_success_rate'],
        }
    else:
        metric_groups = {
            'F1 Score': ['all_f1'],
        }
    n_groups = len(metric_groups)
    # squeeze=False keeps a 2-D axes array even for a single subplot
    fig, axes = plt.subplots(1, n_groups, figsize=(8 * n_groups, 8), squeeze=False)

    x = np.arange(len(model_labels))
    width = 0.35

    for ax, (group_name, group_metrics) in zip(axes[0], metric_groups.items()):
        naive_means = naive[group_metrics].to_numpy(dtype=float).mean(axis=1)
        structured_means = structured[group_metrics].to_numpy(dtype=float).mean(axis=1)

        bars_naive = ax.bar(x - width / 2, naive_means, width, label='Naive Output',
                            color=naive_colors, edgecolor='black', linewidth=0.5)
        bars_structured = ax.bar(x + width / 2, structured_means, width, label='Structured Output',
                                 color=structured_colors, edgecolor='black', linewidth=0.5)

        ax.set_xlabel('Models', fontsize=12, fontweight='bold')
        if eaval_type == 'action_sequencing':
            ax.set_ylabel(group_name, fontsize=12, fontweight='bold')
        else:
            ax.set_ylabel(f'Average {group_name}', fontsize=12, fontweight='bold')

        ax.set_title(f'{group_name} Comparison: Naive vs Structured Output\n{dataset_name.title()} Dataset (Non-zero values only)',
                     fontsize=14, fontweight='bold', pad=20)
        ax.set_xticks(x)
        ax.set_xticklabels(model_labels, rotation=45, ha='right', fontsize=10)
        ax.legend(fontsize=11)
        ax.grid(True, alpha=0.3, linestyle='--')

        top = max(naive_means.max(), structured_means.max())
        ax.set_ylim(0, top * 1.1)

        # Family separators and coloured family labels near the top of the axes
        for span_start, span_end, family in family_positions:
            if span_start > 0:
                ax.axvline(x=span_start - 0.5, color='gray', linestyle='-', alpha=0.5, linewidth=1)
            ax.text((span_start + span_end) / 2, top * 1.05, family,
                    ha='center', va='bottom', fontweight='bold', fontsize=11,
                    bbox=dict(boxstyle='round,pad=0.3', facecolor=family_colors[family], alpha=0.8))

        if show_values:
            # Offset relative to the data range so it works for 0-1 and 0-100 scales
            as_percent = eaval_type == 'action_sequencing'
            _annotate_bar_values(ax, bars_naive, top * 0.01, as_percent)
            _annotate_bar_values(ax, bars_structured, top * 0.01, as_percent)

    fig.tight_layout()
    fig.savefig(f'/Users/qinjielin/Downloads/NWU/25corl/corl_ws/ObsScaling/plots/bar_comparison_{dataset_name}_{eaval_type}_nonzero_filtered.png',
                dpi=300, bbox_inches='tight')
    plt.show()

def print_summary_statistics():
    """Print summary statistics for the naive vs structured comparison.

    Model selection mirrors the bar plots:

    * keep models present in both ``df`` (naive) and ``df_masked`` (structured);
    * drop names containing an exclusion pattern;
    * drop models with a zero or missing value in any of ``metrics`` in either file.

    One row per model (first occurrence) is used. Each model's score is the
    mean over ``metrics``. Naive and structured scores are aligned by model
    name, so the two CSV files need not list models in the same order.
    """
    exclude_patterns = ['lg', 'deepseek', 'kimi', 'qwen-']

    def has_complete_nonzero_values(values):
        # True when the Series has no zero and no missing (NaN/None) entry.
        return not (values.isna().any() or (values == 0).any())

    # Find intersecting models, then drop excluded ones
    common_models = set(df['Model']).intersection(df_masked['Model'])
    excluded_models = {m for m in common_models
                       if any(p in m.lower() for p in exclude_patterns)}
    filtered_models = common_models - excluded_models

    # Index one row per model by name so both result sets can be aligned
    naive_all = (df[df['Model'].isin(filtered_models)]
                 .drop_duplicates('Model').set_index('Model'))
    structured_all = (df_masked[df_masked['Model'].isin(filtered_models)]
                      .drop_duplicates('Model').set_index('Model'))

    valid_models = [m for m in naive_all.index
                    if has_complete_nonzero_values(naive_all.loc[m, metrics])
                    and has_complete_nonzero_values(structured_all.loc[m, metrics])]
    naive = naive_all.loc[valid_models]
    structured = structured_all.loc[valid_models]  # same row order as `naive`

    print(f"\n{'='*60}")
    print("SUMMARY STATISTICS")
    print(f"{'='*60}")
    print(f"Dataset: {dataset_name}")
    print(f"Evaluation Type: {eaval_type}")
    print(f"Total common models: {len(common_models)}")
    print(f"Excluded models ({', '.join(exclude_patterns)}): {len(excluded_models)}")
    print(f"Models with non-zero values: {len(valid_models)}")
    print(f"Model families: {', '.join(map(str, naive['Model Family'].unique()))}")

    if not valid_models:
        print("No models found without zero values.")
        return

    # Success rates are shown as percentages; other metrics as plain decimals
    fmt = '.1%' if eaval_type == 'action_sequencing' else '.3f'

    naive_scores = naive[metrics].astype(float).mean(axis=1)
    structured_scores = structured[metrics].astype(float).mean(axis=1)
    naive_overall = naive_scores.mean()
    structured_overall = structured_scores.mean()
    improvement = structured_overall - naive_overall

    print("\nOVERALL PERFORMANCE:")
    print(f"Naive Output Average: {naive_overall:{fmt}}")
    print(f"Structured Output Average: {structured_overall:{fmt}}")
    print(f"Overall Improvement: {improvement:+{fmt}}")
    print(f"Relative Improvement: {(improvement / naive_overall) * 100:+.1f}%")

    # Per-model improvement; Series subtraction aligns on the model-name index
    improvements = structured_scores - naive_scores
    best_model = improvements.idxmax()
    worst_model = improvements.idxmin()

    print("\nMODEL PERFORMANCE:")
    for label, model in (('Best', best_model), ('Worst', worst_model)):
        family = naive.loc[model, 'Model Family']
        print(f"{label} Improvement: {family}-{model.split('/')[-1]} ({improvements[model]:+{fmt}})")

def _run_report():
    """Print the text summary, then build, save and show the comparison bar plots."""
    print_summary_statistics()

    print("\nCreating bar plots...")
    create_bar_plots()

    print("\nBar plot visualization complete! Plot saved to plots/ directory.")


if __name__ == "__main__":
    _run_report()

