import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pickle
import json
import re
import copy

from utils import *

# Run configuration: which dataset / evaluation task this PCA analysis covers.
# Both values are interpolated into the input CSV path and the plot directory.
dataset_name = "virtualhome" #"behavior" #"virtualhome"
eaval_type = "action_sequencing" #"goal_interpretation" #"action_sequencing"

# Load the per-model results table and keep only rows that have a FLOPs value,
# since FLOPs is the x-axis of the correlation plots produced later.
# NOTE: `eval` shadows the Python builtin; the name is kept because the rest of
# the script refers to this DataFrame as `eval`.
eval_path = f"./eval_results/{dataset_name}_{eaval_type}_results_with_flops_and_openllm.csv"
eval = pd.read_csv(eval_path)
eval = eval.dropna(subset=['FLOPs (1E21)'])

# Make sure the output directory exists. `exist_ok=True` already turns this
# into a no-op when the directory is present, so no separate existence check
# (and no check-then-create race) is needed.
os.makedirs(f"./plots/{dataset_name}/{eaval_type}/pca", exist_ok=True)

# if eaval_type == "action_sequencing":
#     EVAL_METRIC_LIST = [
#         'task_success_rate', 'execution_success_rate', 'total_goal', 'state_goal',
#         'relation_goal', 'action_goal', 'parsing_error', 'hallucination_error',
#         'wrong_order_error', 'missing_step_error', 'additional_step_error', 'affordance_error'
#     ]
# elif eaval_type == "goal_interpretation":
#     EVAL_METRIC_LIST = ['node_precision', 'edge_precision', 'action_precision', 'all_precision',
#                 'node_recall', 'edge_recall', 'action_recall', 'all_recall',
#                 'node_f1', 'edge_f1', 'action_f1', 'all_f1']
# else:
#     raise ValueError(f"Invalid evaluation type: {eaval_type}")
# Score columns of the results table that are fed into the PCA below.
EVAL_METRIC_LIST = [
    'Average',
    'BBH',
    'MATH Lvl 5',
    'GPQA',
    'MUSR',
    'MMLU-PRO',
    'IFEval',
]

# Dropping all-NaN columns is currently disabled:
# eval = eval.dropna(axis=1, how='all')

# Keep only models that have an 'Average' score.
eval = eval.dropna(subset=['Average'])

# Model-name fragments to exclude (these don't have LLM data). They are joined
# into a single regex alternation and matched case-insensitively as substrings;
# rows with a missing model name never match and are therefore kept.
_excluded_model_patterns = [
    'Chat',
    'google/gemma-2-2b-it',
    'google/gemma-2-27b-it',
    'google/gemma-2-9b-it',
    'deepseek-ai/DeepSeek-R1-Distill-Qwen-14B',
    '01-ai/Yi-1.5-9B',
    'gemma-2b',
    'ibm-granite/granite-3.1-2b-base',
    'ibm-granite/granite-3.1-8b-base',
    'meta-llama/Llama-3.3-70B-Instruct',
    'meta-llama/Meta-Llama-3-70B-Instruct',
    'meta-llama/Meta-Llama-3-8B-Instruct',
]
_is_excluded = eval['Model'].str.contains(
    '|'.join(_excluded_model_patterns), case=False, na=False
)
eval = eval[~_is_excluded]

# Order by family, then by model name, and list every model that survived.
eval = eval.sort_values(by=['Model Family', 'Model'], ascending=True)
for _model, _family in zip(eval['Model'], eval['Model Family']):
    print(f"  • {_model} (Family: {_family})")

# Print diagnostic information about the table as it stands at this point,
# i.e. after the row filters applied earlier in the script. (No all-NaN column
# removal is performed, so nothing is reported about it.)
print(f"Filtered DataFrame shape: {eval.shape}")
print(f"Remaining columns: {list(eval.columns)}")

# Check that all required metrics are available BEFORE touching them: the
# conversion loop below indexes every listed column and would raise KeyError
# on a missing one, which would make this fallback unreachable.
missing_metrics = [metric for metric in EVAL_METRIC_LIST if metric not in eval.columns]
if missing_metrics:
    print(f"\nWarning: Missing metrics: {missing_metrics}")
    # Update the list to only include available columns
    EVAL_METRIC_LIST = [metric for metric in EVAL_METRIC_LIST if metric in eval.columns]
    print(f"Updated EVAL_METRIC_LIST: {EVAL_METRIC_LIST}")

# Convert each metric to a rate by dividing by 100
# (assumes the CSV stores these scores on a 0-100 scale - TODO confirm).
for x in EVAL_METRIC_LIST:
    eval[x] = eval[x] / 100

# Report any column in which more than half of the rows are NaN.
nan_percentages = eval.isnull().sum() / len(eval) * 100
high_nan_cols = nan_percentages[nan_percentages > 50]
if not high_nan_cols.empty:
    print(f"\nColumns with >50% NaN values:")
    for col, pct in high_nan_cols.items():
        print(f"  {col}: {pct:.1f}% NaN")

# Restrict the table to the metric columns; this is the input to imputation.
metric_df = eval[EVAL_METRIC_LIST]

# Start from the project's default imputation settings (a fresh dict, so the
# shared defaults are not mutated) and switch verbose output on.
impute_kwargs = dict(DEFAULT_PCA_PREPROCESS_KWARGS['imputation_kwargs'], verbose=True)
imputed_metric_df, _ = pca_impute(metric_df, **impute_kwargs)

# Fit the PCA on the imputed metrics using the project's default PCA settings.
_pca_kwargs = DEFAULT_PCA_PREPROCESS_KWARGS['pca_kwargs']
pca, imputed_metric_pc, _ = perform_pca(imputed_metric_df, **_pca_kwargs)

# Append the principal-component columns to the results table.
# NOTE(review): a column-wise concat aligns on the row index; this presumably
# relies on perform_pca returning a frame indexed like `eval` - confirm.
eval = pd.concat((eval, imputed_metric_pc), axis="columns")

import matplotlib.patches as patches
from matplotlib.path import Path

# Bar chart of the explained-variance ratio per principal component, with a
# bracket over the leading `top_n` bars annotated with their combined ratio.
n_components = DEFAULT_PCA_PREPROCESS_KWARGS['pca_kwargs']['n_components']
pc_positions = list(range(1, n_components + 1))  # 1-based x position of each PC bar

fig, ax = plt.subplots(1, 1, figsize=(5, 4))
ax.bar(pc_positions, pca.explained_variance_ratio_)

# Fix the tick locations first and only then assign their labels; labelling
# before the locations are set attaches the labels to the auto-generated ticks.
ax.set_xticks(pc_positions)
ax.set_xticklabels([str(pos) for pos in pc_positions])
ax.set_xlabel("PC")
ax.set_ylabel("Explained variance ratio")

top_n = 3

sum_top_n = pca.explained_variance_ratio_[:top_n].sum()

# Draw a custom half bracket above the top `top_n` bars
brace_height = max(pca.explained_variance_ratio_[:top_n]) + 0.02  # Slightly higher to clear the bars
brace_x_start = 1 - 0.4  # Slightly before the first bar
brace_x_end = top_n + 0.4  # Slightly past the last of the top_n bars

# Define the points for the path (a simple upside-down half-bracket)
vertices = [
    (brace_x_start, brace_height),  # Left bottom of the brace
    (brace_x_start, brace_height + 0.02),  # Left top of the brace
    (brace_x_end, brace_height + 0.02),  # Right top of the brace
    (brace_x_end, brace_height),  # Right bottom of the brace
]
codes = [Path.MOVETO, Path.LINETO, Path.LINETO, Path.LINETO]

path = Path(vertices, codes)
patch = patches.PathPatch(path, facecolor='none', lw=1.5, edgecolor='gray')
ax.add_patch(patch)

# Annotation for the brace: the summed ratio of the top_n components
ax.annotate(f"{sum_top_n:.3f}",
            ((brace_x_start + brace_x_end) / 2, brace_height + 0.04),  # Position for the text
            textcoords="data",
            ha="center", va="bottom", fontsize=14)
ax.set_title("PCA Explained Variance")

# NOTE(review): the annotation anchor sits at (largest top_n ratio + 0.06), so
# it leaves this fixed y-range if that ratio exceeds 0.88.
ax.set_ylim([0.0, 0.94])

plt.tight_layout()

##########################plot the pca explained variance##########################
# plt.show()
plt.savefig(f"./plots/{dataset_name}/{eaval_type}/pca/pca_explained_variance.png")
print("="*100)
print(f"Saved PCA explained variance plot to ./plots/{dataset_name}/{eaval_type}/pca/pca_explained_variance.png")
print("="*100)


# Visualize the PCA components as a heatmap: one row per principal component,
# one column per metric in EVAL_METRIC_LIST.

fig, ax = plt.subplots(figsize=(10, 5))

# Nicer visualization: make the PC dims mostly "positively" correlated to model
# performance by flipping the sign of selected component rows. Work on a copy so
# the fitted `pca.components_` is left untouched. Indices at or beyond the
# number of fitted components are skipped instead of raising IndexError.
vis_weights = pca.components_.copy()
for idx in (1, 3, 4):
    if idx < len(vis_weights):
        vis_weights[idx] *= -1

sns.heatmap(vis_weights, annot=True, fmt='.2f', cmap='coolwarm', ax=ax)
ax.set_yticklabels(imputed_metric_pc.columns)
ax.set_xticklabels(EVAL_METRIC_LIST, rotation=30,fontsize=10)

##########################plot the pca components##########################
# plt.show()
plt.savefig(f"./plots/{dataset_name}/{eaval_type}/pca/pca_components.png")
print("="*100)
print(f"Saved PCA components plot to ./plots/{dataset_name}/{eaval_type}/pca/pca_components.png")
print("="*100)


# PLOT_COMPARABLE_FLOPS_MODEL_FAMILY = [
#     'Llama-3', 
#     'Yi', 
#     'Qwen1.5', 
#     'Mistral', 
#     'Gemma', 
#     'Granite', 
#     'Qwen3', 
#     'GPT-OSS', 
#     'Qwen', 
#     'Llama-2', 
#     'Llama',
#     'DeepSeek'
# ]

# Count rows for each model family and keep the families that have at least
# two rows, minus the explicitly excluded ones.
# `value_counts()` omits missing values while `unique()` does not, so missing
# family labels are dropped before the lookup; otherwise indexing the counts
# with NaN raises KeyError. Families keep their order of first appearance.
_excluded_families = ("Exaone", "Baichuan")
model_family_counts = eval['Model Family'].value_counts()
PLOT_COMPARABLE_FLOPS_MODEL_FAMILY = [
    family
    for family in eval['Model Family'].dropna().unique()
    if model_family_counts[family] >= 2 and family not in _excluded_families
]
print(f"PLOT_COMPARABLE_FLOPS_MODEL_FAMILY: {PLOT_COMPARABLE_FLOPS_MODEL_FAMILY}")

# Correlation plots of FLOPs (log-scaled x-axis) against the first three
# principal components: first one set restricted to the comparable model
# families, then one "unified" set over EVAL_BASE_MODEL_WITH_FLOPS_FAMILIES.
# Each figure is saved right after it is created and its path is printed.
_pca_plot_dir = f"./plots/{dataset_name}/{eaval_type}/pca"
_banner = "="*100

# (PC number, extra keyword arguments). PC-1 passes no ylim at all, so the
# plotting helper's own default applies there.
_per_family_specs = [
    (1, {}),
    (2, {'ylim': [-0.45, 0.45]}),
    (3, {'ylim': [-0.25, 0.35]}),
]
for _pc_number, _extra_kwargs in _per_family_specs:
    fig = plot_linear_correlation(
        eval, 'FLOPs (1E21)', f'PC-{_pc_number}',
        PLOT_COMPARABLE_FLOPS_MODEL_FAMILY, log_x_metric=True, num_cols=6,
        **_extra_kwargs,
    )
    _out_path = f"{_pca_plot_dir}/pca_components_vs_model_size_pc{_pc_number}.png"
    plt.savefig(_out_path)
    _message = f"Saved PCA components vs model size PC-{_pc_number} plot to {_out_path}"
    if _pc_number == 1:
        # Only the PC-1 message is framed by banner lines.
        print(_banner)
        print(_message)
        print(_banner)
    else:
        print(_message)

# Unified variants: same x/y metrics, drawn with unified_plot=True.
for _pc_number in (1, 2, 3):
    fig = plot_linear_correlation(
        eval, 'FLOPs (1E21)', f'PC-{_pc_number}',
        EVAL_BASE_MODEL_WITH_FLOPS_FAMILIES, log_x_metric=True, unified_plot=True,
    )
    _out_path = f"{_pca_plot_dir}/pca_components_vs_model_size_pc{_pc_number}_unified.png"
    plt.savefig(_out_path)
    print(f"Saved PCA components vs model size PC-{_pc_number} unified plot to {_out_path}")


# Final recap: list every plot file written by this script, framed by banners.
_summary_dir = f"./plots/{dataset_name}/{eaval_type}/pca"
_summary_entries = [
    ("PCA explained variance plot", "pca_explained_variance.png"),
    ("PCA components plot", "pca_components.png"),
    ("PCA components vs model size PC-1 plot", "pca_components_vs_model_size_pc1.png"),
    ("PCA components vs model size PC-2 plot", "pca_components_vs_model_size_pc2.png"),
    ("PCA components vs model size PC-3 plot", "pca_components_vs_model_size_pc3.png"),
    ("PCA components vs model size PC-1 unified plot", "pca_components_vs_model_size_pc1_unified.png"),
    ("PCA components vs model size PC-2 unified plot", "pca_components_vs_model_size_pc2_unified.png"),
    ("PCA components vs model size PC-3 unified plot", "pca_components_vs_model_size_pc3_unified.png"),
]

print("="*100)
for _label, _file_name in _summary_entries:
    print(f"Saved {_label} to {_summary_dir}/{_file_name}")
print("="*100)