import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from utils import *


def create_correlation_heatmap(simulation_name, task_name, merged_df):
    """
    Compute the cross-correlation between base LLM benchmarks and EAI metrics.

    Every EAI column whose name ends in ``error`` is replaced by
    ``100 - value`` before correlating, so the error columns are inverted
    relative to the raw data. The inversion is applied to an internal copy:
    ``merged_df`` is never modified, so the same frame can be passed in
    repeatedly without the sign of the error correlations flipping between
    calls.

    Args:
        simulation_name (str): Name of the simulation (behavior/virtualhome).
            Not used by the computation; kept so existing callers keep working.
        task_name (str): ``"action_sequencing"`` selects the goal/error
            metrics; any other value selects the precision/recall/F1 metrics.
        merged_df (pd.DataFrame): Merged dataframe with results. It must
            contain every base metric column and every EAI metric column of
            the selected task.

    Returns:
        pd.DataFrame: Cross-correlation matrix with the base metrics as rows
        and the EAI metrics (renamed to display labels) as columns. EAI
        columns whose correlations are all NaN are dropped.

    Raises:
        KeyError: If ``merged_df`` lacks one of the required columns.
    """
    # Define the columns for correlation analysis
    base_metrics = ['GPQA', 'MUSR', 'IFEval', 'MMLU-PRO', 'BBH', 'MATH Lvl 5', 'Average']

    if task_name == "action_sequencing":
        eai_metrics = ['task_success_rate', 'state_goal', 'relation_goal', 'action_goal', 'total_goal',
                       'execution_success_rate', 'parsing_error', 'hallucination_error',
                       'predicate_argument_number_error', 'wrong_order_error', 'missing_step_error',
                       'affordance_error', 'additional_step_error']
    else:
        eai_metrics = ['node_precision', 'node_recall', 'node_f1', 'edge_precision', 'edge_recall', 'edge_f1',
                       'action_precision', 'action_recall', 'action_f1', 'all_precision', 'all_recall', 'all_f1']

    # Create mapping for better x-axis labels
    eai_label_mapping = {
        'task_success_rate': 'Task Success',
        'state_goal': 'State Goal',
        'relation_goal': 'Relation Goal',
        'action_goal': 'Action Goal',
        'total_goal': 'Total Goal',
        'execution_success_rate': 'Execution Success',
        'parsing_error': 'Parsing',
        'hallucination_error': 'Hallucination',
        'predicate_argument_number_error': 'Predicate Arg',
        'wrong_order_error': 'Wrong Order',
        'missing_step_error': 'Missing Step',
        'affordance_error': 'Affordance',
        'additional_step_error': 'Additional Step',
        # For goal_interpretation metrics
        'node_precision': 'Node Precision',
        'node_recall': 'Node Recall',
        'node_f1': 'Node F1',
        'edge_precision': 'Edge Precision',
        'edge_recall': 'Edge Recall',
        'edge_f1': 'Edge F1',
        'action_precision': 'Action Precision',
        'action_recall': 'Action Recall',
        'action_f1': 'Action F1',
        'all_precision': 'All Precision',
        'all_recall': 'All Recall',
        'all_f1': 'All F1',
    }

    # Work on a private copy of just the needed columns so the caller's
    # dataframe is left exactly as it was passed in.
    metrics_df = merged_df[base_metrics + eai_metrics].copy()

    # Process error columns (convert to success rates).
    # NOTE(review): assumes the error columns are percentages on a 0-100
    # scale -- confirm against whatever produces the results CSV.
    for col in eai_metrics:
        if col.endswith('error'):
            print(f"Processing {col}")
            metrics_df[col] = 100 - metrics_df[col]

    # Create correlation matrix
    correlation_matrix = metrics_df.corr()

    # Extract only the cross-correlations between base metrics and EAI metrics
    cross_correlation = correlation_matrix.loc[base_metrics, eai_metrics]

    # Remove columns that are all NaN
    cross_correlation = cross_correlation.dropna(axis=1, how='all')

    # Rename the columns (x-axis labels) using the mapping
    return cross_correlation.rename(columns=eai_label_mapping)


def _load_cross_correlations(simulation_task_pairs):
    """Load each pair's results CSV once and return its correlation matrix.

    Pairs whose CSV file is missing, or whose correlation matrix is empty,
    are reported with a warning and left out of the result.

    Args:
        simulation_task_pairs (list): List of tuples
            [(simulation_name, task_name), ...]

    Returns:
        list: ``(simulation_name, task_name, cross_correlation)`` tuples, in
        input order, one per pair that produced a non-empty matrix.
    """
    loaded = []
    for simulation_name, task_name in simulation_task_pairs:
        csv_path = f'./eval_results/{simulation_name}_{task_name}_results_with_flops_and_openllm.csv'
        try:
            merged_df = pd.read_csv(csv_path)
        except FileNotFoundError:
            print(f"Warning: File not found for {simulation_name}_{task_name}")
            continue

        print(f"\nProcessing {simulation_name}_{task_name}")
        print(f"Dataframe shape: {merged_df.shape}")

        cross_correlation = create_correlation_heatmap(simulation_name, task_name, merged_df)
        if cross_correlation.empty:
            print(f"Warning: No valid correlations for {simulation_name}_{task_name}")
            continue

        print(f"Remaining EAI metrics: {list(cross_correlation.columns)}")
        loaded.append((simulation_name, task_name, cross_correlation))

    return loaded


def plot_multiple_heatmaps(simulation_task_pairs):
    """
    Create multiple heatmaps in a single figure.

    Each pair's results are read from
    ``./eval_results/<simulation>_<task>_results_with_flops_and_openllm.csv``
    and drawn as one heatmap in a single-row grid that shares one horizontal
    colorbar fixed to the range [-1, 1]. The figure is written to
    ``./plots/correlation/openllm_correlation_heatmap_multiple.png`` (the
    directory is created when missing) and then shown.

    Pairs that cannot be plotted (missing file, or no valid correlations) are
    skipped with a warning and take up no slot in the figure. When no pair
    can be plotted at all, a warning is printed and no figure is created.

    Args:
        simulation_task_pairs (list): List of tuples
            [(simulation_name, task_name), ...]
    """
    # Read and correlate every input exactly once; the result drives both the
    # figure sizing and the drawing.
    heatmaps = _load_cross_correlations(simulation_task_pairs)
    if not heatmaps:
        # plt.subplots() rejects a zero-column grid, and an empty figure
        # carries no information, so stop here.
        print("Warning: No correlation data available; no figure was created")
        return

    n_plots = len(heatmaps)

    # Force single row layout
    rows, cols = 1, n_plots

    # Size the figure from the tallest matrix (one row per base metric).
    max_rows = max(corr.shape[0] for _, _, corr in heatmaps)
    cell_size = 0.8  # inches per grid cell
    fig_width = 6 * n_plots  # Simple width calculation
    fig_height = max_rows * cell_size + 2  # Add space for titles and colorbar

    # squeeze=False always yields a 2-D array, so a single plot needs no
    # special casing.
    fig, axes = plt.subplots(rows, cols, figsize=(fig_width, fig_height), squeeze=False)
    axes = axes.flatten()

    for idx, (simulation_name, task_name, cross_correlation) in enumerate(heatmaps):
        ax = axes[idx]

        # Create the heatmap without annotations and without individual colorbar
        sns.heatmap(cross_correlation,
                    annot=False,  # Remove correlation values
                    cmap='RdBu_r',
                    center=0,
                    vmin=-1, vmax=1,
                    square=False,  # Aspect is set explicitly below
                    cbar=False,  # Remove individual colorbar
                    linewidths=0.5,
                    linecolor='white',
                    ax=ax)

        # Customize subplot
        title = f'{simulation_name.title()} - {task_name.replace("_", " ").title()}'
        ax.set_title(title, fontsize=14, fontweight='bold', pad=20)

        # Show x-axis labels on subplots in the bottom row. With the forced
        # single-row layout that is every subplot, so each heatmap keeps the
        # labels of its own (possibly different) EAI metric columns.
        if idx // cols == rows - 1:
            ax.set_xlabel('EAI Task Metrics', fontsize=10, fontweight='bold')
            ax.tick_params(axis='x', rotation=45)
            for label in ax.get_xticklabels():
                label.set_ha('right')
        else:
            ax.set_xlabel('')
            ax.set_xticklabels([])

        # Only show y-axis labels on the leftmost subplot
        if idx % cols == 0:
            ax.set_ylabel('Base LLM Benchmark Metrics', fontsize=10, fontweight='bold')
        else:
            ax.set_ylabel('')
            ax.set_yticklabels([])

        # Ensure all subplots have the same aspect ratio
        ax.set_aspect('equal')

    # Add a single horizontal colorbar at the bottom
    cbar_ax = fig.add_axes([0.1, 0.05, 0.8, 0.02])  # [left, bottom, width, height]
    sm = plt.cm.ScalarMappable(cmap='RdBu_r', norm=plt.Normalize(vmin=-1, vmax=1))
    cbar = fig.colorbar(sm, cax=cbar_ax, orientation='horizontal')
    cbar.set_label('Correlation Coefficient', fontsize=12, fontweight='bold')

    # Add horizontal space between subplots
    plt.subplots_adjust(wspace=0.3)

    # Save the plot, creating the output directory first so savefig cannot
    # fail on a fresh checkout.
    output_dir = './plots/correlation'
    os.makedirs(output_dir, exist_ok=True)
    file_name = "openllm_correlation_heatmap_multiple.png"
    plt.savefig(f'{output_dir}/{file_name}', dpi=300, bbox_inches='tight')
    print(f"\nMultiple correlation heatmaps saved as './plots/correlation/{file_name}'")

    plt.show()


def main():
    """Entry point: plot the correlation heatmaps for the configured pairs."""
    # Each entry is (simulation_name, task_name); edit this list to change
    # which result sets are analyzed.
    plot_multiple_heatmaps([
        ("virtualhome", "action_sequencing"),
        ("behavior", "action_sequencing"),
        ("virtualhome", "goal_interpretation"),
    ])


# Run the analysis only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()




