import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from utils import *


def create_correlation_heatmap(simulation_name, task_name, merged_df):
    """
    Compute the cross-correlation between base LLM benchmarks and EAI metrics.

    Despite its name this function draws nothing: it picks the EAI metric
    columns that belong to the given simulation/task, flips error-style
    columns with ``100 - value`` and returns the benchmark-vs-EAI block of
    ``DataFrame.corr()``.

    Args:
        simulation_name (str): Name of the simulation (behavior/virtualhome).
            Only consulted when ``task_name`` is "goal_interpretation".
        task_name (str): Name of the task (action_sequencing/goal_interpretation)
        merged_df (pd.DataFrame): Merged dataframe with results. It must
            contain the six benchmark columns and every EAI metric column
            selected for the task. The frame is not modified.

    Returns:
        pd.DataFrame: Cross-correlation matrix with benchmarks as rows and EAI
        metrics as columns. Columns that are entirely NaN are dropped and the
        remaining ones are renamed to display labels.

    Raises:
        ValueError: If ``task_name`` is unknown, or if ``simulation_name`` is
            unknown for the goal_interpretation task.
    """
    # Benchmark columns used as the rows of the result.
    base_metrics = ['GPQA', 'MUSR', 'IFEval', 'MMLU-PRO', 'BBH', 'MATH Lvl 5']

    if task_name == "action_sequencing":
        eai_metrics = ['task_success_rate', 'state_goal', 'relation_goal', 'action_goal', 'total_goal',
                       'execution_success_rate', 'parsing_error', 'hallucination_error',
                       'predicate_argument_number_error', 'wrong_order_error', 'missing_step_error',
                       'affordance_error', 'additional_step_error']
    elif task_name == "goal_interpretation":
        if simulation_name == "virtualhome":
            # Only the F1 variants are analysed; the label mapping below also
            # covers precision/recall should they be added back.
            eai_metrics = ['node_f1', 'edge_f1', 'action_f1', 'all_f1']
        elif simulation_name == "behavior":
            eai_metrics = ['overall_f1', 'state_goal_f1', 'relation_goal_f1',
                           'state_hallucination_rate', 'object_hallucination_rate',
                           'format_error_rate', 'grammatically_valid_rate']
        else:
            raise ValueError(f"Invalid simulation name: {simulation_name}")
    else:
        raise ValueError(f"Invalid task name: {task_name}")

    # Create mapping for better x-axis labels
    eai_label_mapping = {
        'task_success_rate': 'Task Success',
        'state_goal': 'State Goal',
        'relation_goal': 'Relation Goal',
        'action_goal': 'Action Goal',
        'total_goal': 'Total Goal',
        'execution_success_rate': 'Execution Success',
        'parsing_error': 'Parsing',
        'hallucination_error': 'Hallucination',
        'predicate_argument_number_error': 'Predicate Arg',
        'wrong_order_error': 'Wrong Order',
        'missing_step_error': 'Missing Step',
        'affordance_error': 'Affordance',
        'additional_step_error': 'Additional Step',
        # For goal_interpretation metrics
        'node_precision': 'Node Precision',
        'node_recall': 'Node Recall',
        'node_f1': 'Node F1',
        'edge_precision': 'Edge Precision',
        'edge_recall': 'Edge Recall',
        'edge_f1': 'Edge F1',
        'action_precision': 'Action Precision',
        'action_recall': 'Action Recall',
        'action_f1': 'Action F1',
        'all_precision': 'All Precision',
        'all_recall': 'All Recall',
        'all_f1': 'All F1',
        'overall_f1': 'Overall F1',
        'state_goal_f1': 'State Goal F1',
        'relation_goal_f1': 'Relation Goal F1',
        'state_hallucination_num': 'State Hallucination Num',
        'state_hallucination_rate': 'State Hallucination Rate',
        'object_hallucination_num': 'Object Hallucination Num',
        'object_hallucination_rate': 'Object Hallucination Rate',
        'format_error_rate': 'Format Error Rate',
        'grammatically_valid_rate': 'Grammatically Valid Rate',
    }

    # Work on a private copy: the inversions below must not leak into the
    # caller's frame, otherwise a second call on the same data would flip the
    # error columns back and change the sign of their correlations.
    merged_df = merged_df.copy()

    # Process error columns (convert to success rates); presumably the values
    # are percentages in [0, 100] -- TODO confirm against the CSV producer.
    for col in eai_metrics:
        if col.endswith('error'):
            print(f"Processing {col}")
            print(merged_df[col].unique())
            merged_df[col] = 100 - merged_df[col]
    # These columns are flipped whenever they are present, independent of the
    # selected task.
    # NOTE(review): 'grammatically_valid_rate' looks like a success rate
    # already, so flipping it inverts the sign of its correlations -- confirm
    # that this is intended.
    for col in ['state_hallucination_rate', 'object_hallucination_rate', 'format_error_rate', 'grammatically_valid_rate']:
        if col in merged_df.columns:
            print(f"Processing {col}")
            merged_df[col] = 100 - merged_df[col]

    # Create correlation matrix
    correlation_matrix = merged_df[base_metrics + eai_metrics].corr()

    # Extract only the cross-correlations between base metrics and EAI metrics
    cross_correlation = correlation_matrix.loc[base_metrics, eai_metrics]

    # Remove columns that are all NaN
    cross_correlation = cross_correlation.dropna(axis=1, how='all')

    # Rename the columns (x-axis labels) using the mapping
    return cross_correlation.rename(columns=eai_label_mapping)


def create_merged_heatmap(simulation_task_pairs):
    """
    Create a single heatmap by merging all correlation matrices with hierarchical labeling.

    For every (simulation, task) pair the results CSV is read from
    ``./eval_results/``, its cross-correlation is computed with
    ``create_correlation_heatmap`` and the per-pair matrices are concatenated
    column-wise. Columns are relabelled ``<sim code><task code>-<metric id>``,
    where the metric id is shared by equally named metrics across pairs.
    Pairs whose CSV is missing or whose matrix is empty are skipped with a
    warning. The figure is written to
    ``./plots/correlation/openllm_correlation_heatmap_hierarchical.png``
    (the directory is created if needed) and then shown with ``plt.show()``.

    Args:
        simulation_task_pairs (list): List of tuples [(simulation_name, task_name), ...]

    Returns:
        tuple or None: ``(merged_correlation, column_mapping, metric_to_id)``
        where ``column_mapping`` maps each hierarchical label to its metric
        display name and ``metric_to_id`` maps display names to their ids.
        ``None`` is returned when no pair produced a usable matrix.
    """
    # One (simulation_name, task_name, matrix) entry per pair that was
    # actually loaded; keeping the names next to the matrix keeps the group
    # braces correct when earlier pairs were skipped.
    processed = []
    column_mapping = {}  # hierarchical label -> metric display name
    hierarchical_labels = []  # x-axis labels in column order

    # Create metric to ID mapping
    metric_to_id = {}
    id_counter = 1

    # Define color mapping for different metric types
    metric_colors = {
        'Task Success': '#1f77b4',  # blue
        'State Goal': '#ff7f0e',    # orange
        'Relation Goal': '#2ca02c',  # green
        'Action Goal': '#d62728',   # red
        'Total Goal': '#9467bd',    # purple
        'Execution Success': '#8c564b',  # brown
        'Parsing': '#e377c2',       # pink
        'Hallucination': '#7f7f7f',  # gray
        'Predicate Arg': '#bcbd22',  # olive
        'Wrong Order': '#17becf',   # cyan
        'Missing Step': '#ff9896',  # light red
        'Affordance': '#98df8a',    # light green
        'Additional Step': '#c5b0d5',  # light purple
        # Goal interpretation metrics
        'Node Precision': '#1f77b4',
        'Node Recall': '#1f77b4',
        'Node F1': '#1f77b4',
        'Edge Precision': '#ff7f0e',
        'Edge Recall': '#ff7f0e',
        'Edge F1': '#ff7f0e',
        'Action Precision': '#2ca02c',
        'Action Recall': '#2ca02c',
        'Action F1': '#2ca02c',
        'All Precision': '#d62728',
        'All Recall': '#d62728',
        'All F1': '#d62728',
        # Additional metrics
        'Overall F1': '#8B4513',    # saddle brown
        'State Goal F1': '#FF6347',  # tomato
        'Relation Goal F1': '#32CD32',  # lime green
        'Format Error Rate': '#4B0082',  # indigo
        'State Hallucination Rate': '#6A5ACD',  # slate blue
        'Grammatically Valid Rate': '#DC143C',  # crimson
        'Object Hallucination Rate': '#B22222',  # fire brick
    }

    # Define short codes for simulations and tasks
    sim_codes = {'virtualhome': 'V', 'behavior': 'B'}
    task_codes = {'action_sequencing': '1', 'goal_interpretation': '2'}

    legend_info = {}  # prefix code -> human readable "Simulation - Task"
    counter = 1  # fallback task code for tasks missing from task_codes

    for simulation_name, task_name in simulation_task_pairs:
        # Load data
        try:
            merged_df = pd.read_csv(f'./eval_results/{simulation_name}_{task_name}_results_with_flops_and_openllm.csv')
        except FileNotFoundError:
            print(f"Warning: File not found for {simulation_name}_{task_name}")
            continue
        print(f"\nProcessing {simulation_name}_{task_name}")
        print(f"Dataframe shape: {merged_df.shape}")

        # Create correlation matrix
        cross_correlation = create_correlation_heatmap(simulation_name, task_name, merged_df)

        if cross_correlation.empty:
            print(f"Warning: No valid correlations for {simulation_name}_{task_name}")
            continue

        # Create hierarchical labels and store original metric names
        sim_code = sim_codes.get(simulation_name, simulation_name[0].upper())
        task_code = task_codes.get(task_name, str(counter))
        prefix = f"{sim_code}{task_code}"

        # Store legend information
        legend_info[prefix] = f"{simulation_name.title()} - {task_name.replace('_', ' ').title()}"

        new_columns = []
        for original_col in cross_correlation.columns:
            # Assign ID to metric if not already assigned
            if original_col not in metric_to_id:
                metric_to_id[original_col] = id_counter
                id_counter += 1

            hierarchical_label = f"{prefix}-{metric_to_id[original_col]}"
            new_columns.append(hierarchical_label)
            column_mapping[hierarchical_label] = original_col
            hierarchical_labels.append(hierarchical_label)

        cross_correlation.columns = new_columns
        processed.append((simulation_name, task_name, cross_correlation))
        print(f"Added {cross_correlation.shape[1]} metrics for {simulation_name}_{task_name}")
        counter += 1

    # Merge all correlation matrices horizontally
    if not processed:
        print("No valid correlation matrices found!")
        return None

    merged_correlation = pd.concat([matrix for _, _, matrix in processed], axis=1)
    print(f"\nMerged correlation matrix shape: {merged_correlation.shape}")

    # Create the single comprehensive heatmap
    fig, ax = plt.subplots(figsize=(merged_correlation.shape[1] * 0.4, merged_correlation.shape[0] * 0.8))

    sns.heatmap(merged_correlation,
                annot=False,  # Remove correlation values for clarity
                cmap='plasma_r',  # other options: RdBu_r
                center=0,
                vmin=-1, vmax=1,
                square=True,  # Keep cells square
                cbar_kws={'label': 'Correlation Coefficient', 'shrink': 0.6, 'pad': 0.02},  # Make colorbar smaller and closer
                linewidths=0,  # Remove cell borders
                ax=ax)

    # Customize the plot
    ax.set_title('Correlation between Base LLM Benchmarks and EAI Task Performance',
                 fontsize=16, fontweight='bold', pad=20)
    ax.set_xlabel('EAI Task Metrics', fontsize=12, fontweight='bold')
    ax.set_ylabel('OpenLLM Leaderboard', fontsize=12, fontweight='bold')

    # Pin one tick to the centre of every column before labelling. seaborn's
    # automatic tick selection can place fewer ticks than there are columns,
    # and set_xticklabels needs exactly one label per existing tick.
    ax.set_xticks([i + 0.5 for i in range(len(hierarchical_labels))])
    ax.set_xticklabels(hierarchical_labels, rotation=45, ha='right')

    # Add horizontal braces, one per (simulation, task) group, using the
    # column ranges of the matrices that were actually merged.
    simulation_groups = {}
    current_pos = 0
    for simulation_name, task_name, correlation_matrix in processed:
        n_cols = correlation_matrix.shape[1]
        mapping_name = f"{simulation_name}_{task_name}"
        simulation_groups.setdefault(mapping_name, []).append((current_pos, current_pos + n_cols))
        current_pos += n_cols

    # Merge ranges that share a group name (only happens for repeated pairs)
    merged_groups = {
        name: (min(r[0] for r in ranges), max(r[1] for r in ranges))
        for name, ranges in simulation_groups.items()
    }

    # Add braces above the heatmap
    y_pos = len(merged_correlation.index) + 0.5  # one half row beyond the last heatmap row
    brace_height = 0.3

    for sim_name, (start_col, end_col) in merged_groups.items():
        # Draw horizontal line
        ax.plot([start_col, end_col], [y_pos, y_pos], 'k-', linewidth=2)

        # Draw vertical lines at ends
        ax.plot([start_col, start_col], [y_pos, y_pos - brace_height], 'k-', linewidth=2)
        ax.plot([end_col, end_col], [y_pos, y_pos - brace_height], 'k-', linewidth=2)

        # Add simulation name label (text before the first underscore of the group name)
        mid_point = (start_col + end_col) / 2
        print(f"Adding brace (start_col: {start_col}, end_col: {end_col}) for {sim_name} at {mid_point}, {y_pos + 0.2}")
        ax.text(mid_point, y_pos + 0.2, sim_name.title().split('_')[0],
                ha='center', va='bottom', fontsize=12)

    # Adjust y-axis limits to accommodate braces
    # NOTE(review): this sets an increasing y-range; seaborn heatmaps normally
    # use an inverted y-axis, so this presumably flips the row order so that
    # the braces end up on top -- confirm that the flip is intended.
    ax.set_ylim(-0.5, len(merged_correlation.index) + 1.5)

    # Color the x-axis labels based on metric type (black when no color is defined)
    for label in ax.get_xticklabels():
        original_metric = column_mapping[label.get_text()]
        label.set_color(metric_colors.get(original_metric, '#000000'))
        label.set_fontweight('bold')

    # Create legend for simulation/task codes
    legend_text = "Simulation-Task Codes: " + " | ".join(
        f"{code}: {description}" for code, description in legend_info.items()
    )

    # Create metric ID legend with multiple lines
    metric_items = list(metric_to_id.items())
    num_lines = 3
    items_per_line = len(metric_items) // num_lines + (1 if len(metric_items) % num_lines else 0) + 2

    metric_legend_lines = []
    for line_num in range(num_lines):
        line_items = metric_items[line_num * items_per_line:(line_num + 1) * items_per_line]
        if line_items:
            metric_legend_lines.append(
                " | ".join(f"{metric_id}: {metric}" for metric, metric_id in line_items)
            )

    metric_legend_text = "Metric IDs:\n" + "\n".join(metric_legend_lines)

    # Position both legends at bottom of the figure
    fig.text(0.1, 0.22, legend_text, fontsize=8,
             bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.8))

    fig.text(0.1, 0.10, metric_legend_text, fontsize=8,
             bbox=dict(boxstyle='round', facecolor='lightblue', alpha=0.8))

    # Adjust layout to prevent label cutoff and make room for bottom legends
    plt.tight_layout()
    plt.subplots_adjust(bottom=0.4)  # Make room for both bottom legends

    # Save the plot; create the target directory first so savefig cannot fail
    # merely because it does not exist yet.
    output_dir = './plots/correlation'
    file_name = "openllm_correlation_heatmap_hierarchical.png"
    os.makedirs(output_dir, exist_ok=True)
    plt.savefig(f'{output_dir}/{file_name}', dpi=300, bbox_inches='tight')
    print(f"\nHierarchical correlation heatmap saved as './plots/correlation/{file_name}'")

    plt.show()

    return merged_correlation, column_mapping, metric_to_id


# Characters that would break a LaTeX table cell if written verbatim.
_LATEX_ESCAPES = str.maketrans({
    "_": "\\_",
    "&": "\\&",
    "%": "\\%",
    "#": "\\#",
    "$": "\\$",
})


def _escape_latex(text):
    """Return ``text`` with LaTeX special characters backslash-escaped."""
    return str(text).translate(_LATEX_ESCAPES)


def _format_correlation_cell(value):
    """Format one correlation value: bold for |r| >= 0.7, italic for |r| >= 0.5."""
    if pd.isna(value):
        return "--"
    if abs(value) >= 0.7:
        return f"\\textbf{{{value:.3f}}}"
    if abs(value) >= 0.5:
        return f"\\textit{{{value:.3f}}}"
    return f"{value:.3f}"


def create_latex_tables(simulation_task_pairs):
    """
    Create LaTeX tables for all simulation-task pairs and save as a single .tex file.

    Each table has the EAI metrics as rows and the LLM benchmarks as columns.
    Pairs whose results CSV is missing, or whose correlation matrix is empty,
    are skipped with a warning. The output is written to
    ``./plots/correlation/all_correlation_tables.tex``; the directory is
    created if needed, and the file is written even when no table was
    generated.

    Args:
        simulation_task_pairs (list): List of tuples [(simulation_name, task_name), ...]
    """
    all_tables = []
    included_pairs = []  # pairs that actually produced a table, for the summary

    # Document header
    latex_document = "% LaTeX document with correlation tables\n"
    latex_document += "% Required packages: booktabs, threeparttable, longtable, array\n\n"

    for simulation_name, task_name in simulation_task_pairs:
        # Load data
        try:
            merged_df = pd.read_csv(f'./eval_results/{simulation_name}_{task_name}_results_with_flops_and_openllm.csv')
        except FileNotFoundError:
            print(f"Warning: File not found for {simulation_name}_{task_name}")
            continue
        print(f"\nCreating LaTeX table for {simulation_name}_{task_name}")
        print(f"Dataframe shape: {merged_df.shape}")

        # Create correlation matrix
        cross_correlation = create_correlation_heatmap(simulation_name, task_name, merged_df)

        if cross_correlation.empty:
            print(f"Warning: No valid correlations for {simulation_name}_{task_name}")
            continue

        # Transpose the correlation matrix: LLM metrics as columns, EAI metrics as rows
        transposed_correlation = cross_correlation.T.round(3)
        llm_metrics = list(transposed_correlation.columns)

        # Left align the label column, center one column per LLM metric
        col_spec = "l" + "c" * len(llm_metrics)

        table_lines = [
            "\\begin{table}[htbp]",
            "\\centering",
            f"\\caption{{Correlation between Base LLM Benchmarks and {simulation_name.title()} {task_name.replace('_', ' ').title()} Task Performance. Bold values indicate strong correlations ($|r| \\geq 0.7$), italic values indicate moderate correlations ($0.5 \\leq |r| < 0.7$).}}",
            f"\\label{{tab:{simulation_name}_{task_name}_correlation}}",
            "\\footnotesize",  # Make font smaller for better fit
            f"\\begin{{tabular}}{{{col_spec}}}",
            "\\toprule",
            " & ".join(["EAI Task Metrics"] + [_escape_latex(col) for col in llm_metrics]) + " \\\\",
            "\\midrule",
        ]

        # Data rows - EAI metrics as row labels
        for idx_name in transposed_correlation.index:
            # Shorten the raw name first and escape afterwards, so truncation
            # can never cut an escape sequence in half.
            raw_label = str(idx_name)
            if len(raw_label) > 20:
                raw_label = raw_label[:17] + "..."
            cells = [
                _format_correlation_cell(transposed_correlation.loc[idx_name, col_name])
                for col_name in llm_metrics
            ]
            table_lines.append(" & ".join([_escape_latex(raw_label)] + cells) + " \\\\")

        table_lines += ["\\bottomrule", "\\end{tabular}", "\\end{table}"]

        all_tables.append("\n".join(table_lines) + "\n\n")
        included_pairs.append((simulation_name, task_name))
        print(f"Table dimensions: {transposed_correlation.shape[0]} rows × {transposed_correlation.shape[1]} columns")

    # Combine all tables
    latex_document += "".join(all_tables)

    # Add summary note at the end; only pairs that produced a table are listed
    latex_document += "% Summary:\n"
    latex_document += f"% Total tables generated: {len(all_tables)}\n"
    latex_document += "% Tables included:\n"
    for simulation_name, task_name in included_pairs:
        latex_document += f"% - {simulation_name.title()} {task_name.replace('_', ' ').title()}\n"

    # Save to single .tex file, creating the output directory on first use
    output_filename = "./plots/correlation/all_correlation_tables.tex"
    os.makedirs(os.path.dirname(output_filename), exist_ok=True)
    with open(output_filename, 'w', encoding='utf-8') as f:
        f.write(latex_document)

    print(f"\nAll LaTeX tables saved as '{output_filename}'")
    print(f"Total tables created: {len(all_tables)}")


def main():
    """Run the correlation analysis: merged heatmap first, then LaTeX tables."""
    # Define the simulation-task pairs you want to analyze
    simulation_task_pairs = [
        ("virtualhome", "action_sequencing"),
        ("behavior", "action_sequencing"),
        ("virtualhome", "goal_interpretation"),
        ("behavior", "goal_interpretation"),
    ]

    # Create merged heatmap. The function returns None when no pair yielded
    # data, so the result must not be unpacked unconditionally.
    heatmap_result = create_merged_heatmap(simulation_task_pairs)
    if heatmap_result is None:
        print("Warning: merged heatmap was not generated (no valid correlation data)")

    # Create LaTeX tables for each pair
    create_latex_tables(simulation_task_pairs)


# Run the analysis only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()




