import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from utils import *

# Input tables. Both CSVs are expected to carry a 'Model' column: it is
# selected from the EAI table and used as the merge key later in this script.
# Alternative base-LLM table that can be swapped in:
#   "./eval_results/other/livebench_paper_table.csv"
eval_result_path = "./eval_results/other/livebench_image_name_mapping_v4.csv"
base_llm_eval_with_emg_eval = pd.read_csv(eval_result_path)

eai_csv = "./eval_results/other/trajectory_eval_action_sequencing.csv"
eai_eval = pd.read_csv(eai_csv)

# Print the available columns up front so a schema mismatch is visible before
# any column selection fails.
print("Base LLM columns:", base_llm_eval_with_emg_eval.columns.tolist())
print("EAI columns:", eai_eval.columns.tolist())

# EAI columns used for the correlation analysis. Each metric appears twice:
# the '_V' column is the VirtualHome score, the '_B' column the Behavior score.
eai_metrics = ['Task_SR_V', 'Task_SR_B', 
              'Exec_SR_V', 'Exec_SR_B', 
              'Grammar_Parsing_V', 'Grammar_Parsing_B', 
              'Grammar_Halluc_V', 'Grammar_Halluc_B', 
              'Grammar_PredArgNum_V', 'Grammar_PredArgNum_B', 
              'Runtime_WrongOrder_V', 'Runtime_WrongOrder_B', 
              'Runtime_MissingStep_V', 'Runtime_MissingStep_B', 
              'Runtime_Affordance_V', 'Runtime_Affordance_B', 
              'Runtime_AdditionalStep_V', 'Runtime_AdditionalStep_B']


def _strip_suffix(name, suffix):
    """Return *name* with a trailing *suffix* removed.

    Only the end of the string is touched. ``str.replace`` would also delete
    the same characters wherever they occur inside a metric name (for example
    the '_V' in a hypothetical 'Exec_Valid_V').
    """
    if suffix and name.endswith(suffix):
        return name[:-len(suffix)]
    return name


# Partition the metric columns by simulator suffix
eai_metrics_v = [metric for metric in eai_metrics if metric.endswith('_V')]
eai_metrics_b = [metric for metric in eai_metrics if metric.endswith('_B')]

# Original column name -> suffix-free name, so both simulators end up sharing
# the same metric names
metric_renaming_v = {metric: _strip_suffix(metric, '_V') for metric in eai_metrics_v}
metric_renaming_b = {metric: _strip_suffix(metric, '_B') for metric in eai_metrics_b}

# One frame per simulator: 'Model', the renamed metrics, then a 'dataset' tag.
# rename() and assign() both return new frames, so eai_eval is left untouched.
eai_eval_virtualhome = (
    eai_eval[['Model'] + eai_metrics_v]
    .rename(columns=metric_renaming_v)
    .assign(dataset='VirtualHome')
)
eai_eval_behavior = (
    eai_eval[['Model'] + eai_metrics_b]
    .rename(columns=metric_renaming_b)
    .assign(dataset='Behavior')
)

print("VirtualHome metrics (renamed):", list(metric_renaming_v.values()))
print("Behavior metrics (renamed):", list(metric_renaming_b.values()))

# Join the base-LLM benchmark table with each simulator's EAI table. The inner
# join on 'Model' keeps only the models present in both tables.
merged_df_v = base_llm_eval_with_emg_eval.merge(
    eai_eval_virtualhome, how='inner', on='Model'
)
merged_df_b = base_llm_eval_with_emg_eval.merge(
    eai_eval_behavior, how='inner', on='Model'
)

print(f"\nVirtualHome merged dataframe shape: {merged_df_v.shape}")
print(f"Behavior merged dataframe shape: {merged_df_b.shape}")

# Benchmark columns of the base-LLM table; these become the heatmap rows
base_metrics = [
    'Overall',
    'Coding',
    'Data Analysis',
    'Instruction Following',
    'Language',
    'Math',
    'Reasoning',
]

# Suffix-free EAI metric names (the rename targets); these become the columns
eai_metrics_clean_v = [*metric_renaming_v.values()]
eai_metrics_clean_b = [*metric_renaming_b.values()]

# Full pairwise correlation over benchmark + EAI columns, one per simulator
correlation_matrix_v = merged_df_v.loc[:, [*base_metrics, *eai_metrics_clean_v]].corr()
correlation_matrix_b = merged_df_b.loc[:, [*base_metrics, *eai_metrics_clean_b]].corr()

# Keep only the benchmark-vs-EAI quadrant of each square matrix
cross_correlation_v = correlation_matrix_v.loc[base_metrics, eai_metrics_clean_v]
cross_correlation_b = correlation_matrix_b.loc[base_metrics, eai_metrics_clean_b]

print(f"\nVirtualHome cross-correlation shape: {cross_correlation_v.shape}")
print(f"Behavior cross-correlation shape: {cross_correlation_b.shape}")

def _draw_correlation_heatmap(data, ax, title, xlabel):
    """Draw one benchmark-vs-EAI correlation heatmap on *ax*.

    Both panels go through this helper so their colour scale, annotation
    format and label styling cannot drift apart.

    Args:
        data: Cross-correlation table (rows: base metrics, columns: EAI metrics).
        ax: Matplotlib axes to draw on.
        title: Title of the panel.
        xlabel: x-axis label naming the EAI metric family.
    """
    sns.heatmap(
        data,
        annot=True,
        cmap='RdBu',  # Red-Blue diverging colormap (red for negative, blue for positive)
        center=0,  # Center at 0 to show positive/negative distinction
        vmin=-1, vmax=1,  # Full range from -1 to 1
        square=True,
        cbar_kws={'label': 'Correlation Coefficient'},
        fmt='.3f',
        linewidths=0.5,
        linecolor='white',
        ax=ax,
    )
    ax.set_title(title, fontsize=14, fontweight='bold')
    ax.set_xlabel(xlabel, fontsize=12, fontweight='bold')
    ax.set_ylabel('Base LLM Benchmark Metrics', fontsize=12, fontweight='bold')
    ax.tick_params(axis='x', rotation=45)


# Two panels side by side: VirtualHome on the left, Behavior on the right
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 8))

_draw_correlation_heatmap(
    cross_correlation_v, ax1,
    title='VirtualHome Metrics Correlation',
    xlabel='VirtualHome EAI Metrics',
)
_draw_correlation_heatmap(
    cross_correlation_b, ax2,
    title='Behavior Simulation Metrics Correlation',
    xlabel='Behavior EAI Metrics',
)

# Adjust layout
plt.tight_layout()

# Save the plot. savefig does not create missing directories, so make sure the
# target folder exists instead of failing after all the work is done.
filename = "correlation_heatmap_eai_action_seq.png"
output_path = f'./plots/correlation/{filename}'
os.makedirs(os.path.dirname(output_path), exist_ok=True)
plt.savefig(output_path, dpi=300, bbox_inches='tight')
print(f"\nSeparated correlation heatmaps saved as '{output_path}'")

# Print both cross-correlation tables, rounded to three decimals
print("\nVirtualHome cross-correlation matrix:")
print(cross_correlation_v.round(3))

print("\nBehavior cross-correlation matrix:")
print(cross_correlation_b.round(3))

# Rank every (EAI metric, base metric) pair from highest to lowest
# correlation, keep the strictly positive ones and show the first five.
print("\nTop 5 strongest positive correlations (VirtualHome):")
positive_corr_v = (
    cross_correlation_v
    .unstack()
    .sort_values(ascending=False)
    .loc[lambda ranked: ranked > 0]
)
print(positive_corr_v.head(5).round(3))

print("\nTop 5 strongest positive correlations (Behavior):")
positive_corr_b = (
    cross_correlation_b
    .unstack()
    .sort_values(ascending=False)
    .loc[lambda ranked: ranked > 0]
)
print(positive_corr_b.head(5).round(3))

# Open the figure window last, after everything has been printed
plt.show()




