import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from utils import *

# Base-LLM benchmark scores (loader presumably provided by the project's `utils`
# star import) joined with the emergent-capability evaluation results on the
# shared 'Model' column (pandas default: inner join).
base_llm_benchmark_eval = load_base_llm_benchmark_eval()

eval_result_path = "./eval_results/base_llm_emergent_capability_eval.csv"
emerg_cap_eval = pd.read_csv(eval_result_path)

base_llm_eval_with_emg_eval = base_llm_benchmark_eval.merge(emerg_cap_eval, on='Model')

# Per-task configuration for the embodied-AI (EAI) evaluation to correlate:
# the results CSV, the output heatmap file name, and the EAI metric columns.
# Pick the task by changing EAI_TASK only, so these three settings can never
# get out of sync (they previously had to be swapped by hand in three places).
EAI_TASK_CONFIGS = {
    "action_sequencing": {
        "csv": "./eval_results/virtualhome_action_sequencing_results_with_flops.csv",
        "file_name": "obsscaling_correlation_heatmap_action_sequencing.png",
        "metrics": ['task_success_rate', 'state_goal', 'relation_goal', 'action_goal', 'total_goal',
                    'execution_success_rate', 'parsing_error', 'hallucination_error',
                    'predicate_argument_number_error', 'wrong_order_error', 'missing_step_error',
                    'affordance_error', 'additional_step_error'],
    },
    "goal_interpretation": {
        "csv": "./eval_results/virtualhome_goal_interpretation_v4_results.csv",
        "file_name": "obsscaling_correlation_heatmap_goal_interpretation.png",
        "metrics": ['node_precision', 'node_recall', 'node_f1', 'edge_precision', 'edge_recall',
                    'edge_f1', 'action_precision', 'action_recall', 'action_f1',
                    'all_precision', 'all_recall', 'all_f1'],
    },
}
EAI_TASK = "action_sequencing"

eai_csv = EAI_TASK_CONFIGS[EAI_TASK]["csv"]
file_name = EAI_TASK_CONFIGS[EAI_TASK]["file_name"]
eai_metrics = EAI_TASK_CONFIGS[EAI_TASK]["metrics"]
eai_eval = pd.read_csv(eai_csv)

print("Base LLM columns:", base_llm_eval_with_emg_eval.columns.tolist())
print("EAI columns:", eai_eval.columns.tolist())

# Inner join on 'Model': only models present in both sources are kept, together
# with all columns from both frames (non-key column names present in both get
# pandas' default '_x'/'_y' suffixes).
merged_df = pd.merge(base_llm_eval_with_emg_eval, eai_eval, on='Model', how='inner')
print(f"\nMerged dataframe shape: {merged_df.shape}")
# Count distinct models rather than rows: duplicate 'Model' entries in either
# input multiply rows in the join and would inflate len(merged_df).
print(f"Models in both datasets: {merged_df['Model'].nunique()}")
print(f"Columns: {merged_df.columns.tolist()}")
if len(merged_df) < 3:
    # With 0-2 data points Pearson r is NaN or exactly +/-1, i.e. meaningless.
    print(f"Warning: only {len(merged_df)} overlapping rows; correlations computed "
          "from fewer than 3 points are degenerate (NaN or +/-1).")

# Base-LLM benchmark metrics; these become the heatmap rows.
base_metrics = ['MMLU', 'ARC-C', 'HellaSwag', 'Winograd', 'TruthfulQA', 'GSM8K', 'XWinograd', 'HumanEval']

# Pearson correlation (pandas default) over all selected columns; each pair is
# computed from the rows where both values are non-null.
correlation_matrix = merged_df[base_metrics + eai_metrics].corr()

# Keep only the base-metric x EAI-metric block of the full square matrix.
cross_correlation = correlation_matrix.loc[base_metrics, eai_metrics]

print(f"\nCross-correlation matrix shape: {cross_correlation.shape}")

# Draw the base-benchmark x EAI-metric correlations as an annotated heatmap on
# a fixed [-1, 1] diverging colour scale centred at 0.
plt.figure(figsize=(16, 10))

sns.heatmap(cross_correlation,
            annot=True,  # Write each correlation value in its cell
            cmap='RdBu_r',  # Red-Blue diverging colormap (red = positive)
            center=0,  # Center colormap at 0
            vmin=-1, vmax=1,  # Fix the colour range to the full [-1, 1] span
            square=True,  # Make cells square
            cbar_kws={'label': 'Correlation Coefficient'},
            fmt='.3f',  # Show 3 decimal places
            linewidths=0.5,  # Thin separators between cells
            linecolor='white')

plt.title('Correlation between Base LLM Benchmarks and EAI Task Performance',
          fontsize=16, fontweight='bold', pad=20)
plt.xlabel('EAI Task Metrics', fontsize=12, fontweight='bold')
plt.ylabel('Base LLM Benchmark Metrics', fontsize=12, fontweight='bold')

# Rotate x-axis labels so long metric names do not overlap.
plt.xticks(rotation=45, ha='right')
plt.yticks(rotation=0)

# Adjust layout to prevent label cutoff.
plt.tight_layout()

# savefig does not create missing directories (it raises FileNotFoundError),
# so make sure the output directory exists first.
output_dir = './plots/correlation'
os.makedirs(output_dir, exist_ok=True)
output_path = f'{output_dir}/{file_name}'
plt.savefig(output_path, dpi=300, bbox_inches='tight')
print(f"\nCorrelation heatmap saved as '{output_path}'")

# Print the numeric cross-correlation matrix, rounded like the heatmap labels.
print("\nCross-correlation matrix:")
print(cross_correlation.round(3))

# Flatten the matrix into one Series of pairwise correlations, indexed by
# (EAI metric, base metric), so individual pairs can be ranked.
pair_corr = cross_correlation.unstack()

# Largest positive correlations first; non-positive values and NaNs are dropped.
print("\nTop 5 strongest positive correlations:")
positive_corr = pair_corr.sort_values(ascending=False).loc[lambda s: s > 0]
print(positive_corr.head(5).round(3))

# Most negative correlations first; non-negative values and NaNs are dropped.
print("\nTop 5 strongest negative correlations:")
negative_corr = pair_corr.sort_values(ascending=True).loc[lambda s: s < 0]
print(negative_corr.head(5).round(3))

plt.show()




