import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pickle
import json
import re
import copy

from utils import *

# Benchmark whose results are analysed; the branches below handle
# "virtualhome" and "behavior".
dataset_name = "behavior"
# Evaluation task; used to build the results CSV name and the plot directory.
# Values seen so far: "action_sequencing", "action_sequencing_v4",
# "goal_interpretation".
eaval_type = "action_sequencing"

# Select the x-metric set used for the PCA-based plots. The PC_METRIC_NUM_*
# names are not defined in this file (presumably they come from
# `from utils import *`).
if dataset_name == "virtualhome":
    obs_xmetric = PC_METRIC_NUM_3
elif dataset_name == "behavior":
    obs_xmetric = PC_METRIC_NUM_4
else:
    # Fail here with a clear message instead of a NameError on `obs_xmetric`
    # much later in the script.
    raise ValueError(
        f"Unsupported dataset_name {dataset_name!r}; expected 'virtualhome' or 'behavior'"
    )

# Read the merged results table for the configured dataset / eval type
# (despite the variable name, the file is selected by `dataset_name`).
results_csv_path = f'./eval_results/{dataset_name}_{eaval_type}_results_with_flops_and_openllm.csv'
virtualhome_eval = pd.read_csv(results_csv_path)

# Drop, in order, rows with no FLOPs value and rows with no 'Average' value
# (the open LLM score, per the original comment), printing the surviving row
# count after each stage.
for required_column, stage_label in (('FLOPs (1E21)', 'FLOPs'), ('Average', 'Average')):
    virtualhome_eval = virtualhome_eval.dropna(subset=[required_column])
    print(f"After filtering {stage_label}: {len(virtualhome_eval)} models with meaningful performance data")

# NOTE: the model-name filter below is currently disabled, so the count
# printed next is the same as the one printed after the 'Average' stage.
# virtualhome_eval = virtualhome_eval[~virtualhome_eval['Model'].str.contains('chuan', case=False, na=False)]
print(f"After filtering Chat, Instruct, it: {len(virtualhome_eval)} models with meaningful performance data")

# '*_error' columns: replace each value x with (100 - x) / 100. The columns
# keep their '*_error' names but hold the complementary fraction from here on.
error_columns = [name for name in virtualhome_eval.columns if name.endswith('_error')]
virtualhome_eval[error_columns] = virtualhome_eval[error_columns].rsub(100).div(100)
print(f"Converted {len(error_columns)} error columns to success rates")

# '*_rate' and '*_goal' columns: divide by 100.
rate_columns = [name for name in virtualhome_eval.columns if name.endswith(('_rate', '_goal'))]
virtualhome_eval[rate_columns] = virtualhome_eval[rate_columns].div(100)
print(f"Converted {len(rate_columns)} rate columns to success rates")

# Benchmark score columns: divide by 100 as well. Per the original note, the
# PCA step is very sensitive to metric scale, so these are put on the same
# scale as the success rates.
convert_metrics_to_rate = ['Average', 'BBH', 'MATH Lvl 5', 'GPQA', 'MUSR', 'MMLU-PRO', 'IFEval']
virtualhome_eval[convert_metrics_to_rate] = virtualhome_eval[convert_metrics_to_rate].div(100)
print(f"Converted {len(convert_metrics_to_rate)} metrics to success rates")


# Print one line per remaining model, ordered by family and then model name,
# together with a status for its 'Average' value.
separator_line = "=" * 80
print(f"\n📋 {len(virtualhome_eval)} Models remaining after filtering:")
print(separator_line)
virtualhome_eval_sorted = virtualhome_eval.sort_values(['Model Family', 'Model'])
for _, model_row in virtualhome_eval_sorted.iterrows():
    if 'Average' not in model_row:
        # The table has no 'Average' column at all.
        average_status = "⚠️ No open LLM column"
    elif pd.isna(model_row['Average']):
        # The column exists but this model has no value.
        average_status = "❌ No open LLM data"
    else:
        average_status = f"✅ {model_row['Average']:.2f}"

    print(f"  • {model_row['Model']} (Family: {model_row['Model Family']}) {average_status}")
print(separator_line)

# Y-metric groups to plot. The same cycle of four groups (of 2, 4, 4 and 4
# metrics) is repeated three times, giving 12 entries in total.
_Y_METRIC_GROUP_CYCLE = (
    ('task_success_rate', 'execution_success_rate'),
    ('task_success_rate', 'execution_success_rate', 'parsing_error', 'hallucination_error'),
    ('total_goal', 'state_goal', 'relation_goal', 'action_goal'),
    ('wrong_order_error', 'missing_step_error', 'additional_step_error', 'affordance_error'),
)

# Every entry is its own list object, so mutating one entry never affects
# another.
VIRTUALHOME_Y_METRICS = [
    list(metric_group)
    for _ in range(3)
    for metric_group in _Y_METRIC_GROUP_CYCLE
]

# Metric column names, in the order they appear in the title map.
_Y_METRIC_KEYS = (
    'task_success_rate',
    'execution_success_rate',
    'total_goal',
    'state_goal',
    'relation_goal',
    'action_goal',
    'parsing_error',
    'hallucination_error',
    'wrong_order_error',
    'missing_step_error',
    'additional_step_error',
    'affordance_error',

    'node_f1',
    'edge_f1',
    'action_f1',
    'all_f1',
)

# Titles that are not simply the title-cased column name.
_Y_METRIC_TITLE_OVERRIDES = {'all_f1': "Overall F1"}

# Column name -> display title (passed to the plotting helper as
# `ymetric2title_map`). Unless overridden, the title is the column name with
# underscores replaced by spaces and each word capitalised.
VIRTUALHOME_Y_METRIC_MAP = {
    metric_key: _Y_METRIC_TITLE_OVERRIDES.get(metric_key, metric_key.replace('_', ' ').title())
    for metric_key in _Y_METRIC_KEYS
}

# Predictor (x-metric) set for each entry of VIRTUALHOME_Y_METRICS, matched by
# position: four entries using `obs_xmetric`, then four using model size, then
# four using FLOPs. The single-column lists are separate objects per entry.
PLOT_X_METRICS_LIST = (
    [obs_xmetric] * 4
    + [['Model Size (B)'] for _ in range(4)]
    + [['FLOPs (1E21)'] for _ in range(4)]
)

# Benchmark columns used both for imputation and as the PCA input metrics.
pca_imputation_metrics = ['Average', 'BBH', 'MATH Lvl 5', 'GPQA', 'MUSR', 'MMLU-PRO', 'IFEval']
# Alternative set, kept for reference:
# pca_imputation_metrics = ["action_goal", "relation_goal", "state_goal", "total_goal", "parsing_error", "hallucination_error"]

# NOTE: this writes into the NONGSM_PCA_PREPROCESS_KWARGS dict itself (the
# object is not defined in this file), so anything else holding a reference to
# that dict sees the new 'imputation_metrics' value too.
NONGSM_PCA_PREPROCESS_KWARGS['imputation_metrics'] = pca_imputation_metrics

# Setup used when the predictors are principal components: start from a
# shallow copy of the shared preprocessing kwargs, then layer the
# script-specific settings on top.
VIRTUALHOME_DEFAULT_SETUP_KWARGS_PCA = dict(NONGSM_PCA_PREPROCESS_KWARGS)
VIRTUALHOME_DEFAULT_SETUP_KWARGS_PCA.update({
    "pca_metrics": pca_imputation_metrics,
    # Other families tried per the original note: "Llama-2", "Qwen".
    "ref_model_family": "Gemma",
    "stylize_data": True,
    "nonlinearity": "sigmoid-parametric",

    # Y metrics are min-max normalised.
    "y_metric_process_funcs": "minmax_norm",

    # Markers are grouped by model family.
    "df_groupby": 'Model Family',

    # Robust regression; per the original note this is a Huber loss with
    # delta=1.0, chosen for targets normalised to [0, 1].
    "reg_method": "robust",
    "reg_kwargs": {"delta": 1.0},

    # PCA and imputation are both enabled in this configuration.
    "apply_pca": True,
    "apply_imputation": True,

    # Alternative split, kept for reference:
    # "split_method": "cutoff_by_FLOPs (1E21)",
    # "cutoff_threshold": 2000,
    "split_method": "cutoff_by_Model Size (B)",
    # Cutoff on 'Model Size (B)'; the previous comment mentioned 70 / 80,
    # but the value in effect is 40.
    "cutoff_threshold": 40,

    # Style every model family present in the filtered table.
    "stylize_model_family": virtualhome_eval['Model Family'].unique().tolist(),
})

# Every model family present in the filtered table.
no_pca_model_families = virtualhome_eval['Model Family'].unique().tolist()

# Setup used when the predictor is a raw column (model size or FLOPs) rather
# than principal components.
VIRTUALHOME_DEFAULT_SETUP_KWARGS_NO_PCA = {
    # Y metrics are min-max normalised.
    "y_metric_process_funcs": "minmax_norm",

    # Markers are grouped by model family.
    "df_groupby": 'Model Family',

    # Robust regression; per the original note this is a Huber loss with
    # delta=1.0, chosen for targets normalised to [0, 1].
    "reg_method": "robust",
    "reg_kwargs": {"delta": 1.0},

    # A single raw predictor is used, so PCA and imputation are switched off.
    "apply_pca": False,
    "apply_imputation": False,

    # Style all families in the dataset rather than the helper's default set.
    "stylize_model_family": no_pca_model_families,

    "split_method": "cutoff_by_Model Size (B)",
    # Cutoff on 'Model Size (B)'; the previous comment mentioned 70 / 80,
    # but the value in effect is 40.
    "cutoff_threshold": 40,
}

# Flatten the metric groups into one list of metric names (repeats are kept).
all_individual_metrics = [
    metric_name
    for metric_group in VIRTUALHOME_Y_METRICS
    for metric_name in metric_group
]

# Per-metric overrides, keyed by individual metric (not by group). Every
# metric gets its own dict with the same settings: a (0, 1) value range and
# y-axis limits with a small margin for visualisation.
VIRTUALHOME_EVAL_SETUP_SPECIFIC_KWARGS = {
    metric_name: {
        'y_metric_range': (0.0, 1.0),
        "plot_adjust_kwargs": {"ylim": [-0.05, 1.05]},
    }
    for metric_name in all_individual_metrics
}

# Generate the scaling plots: one figure per (x-metric set, y-metric group)
# pair, pairing PLOT_X_METRICS_LIST with VIRTUALHOME_Y_METRICS by position.

# Store all figures in a list so they stay accessible after the loop.
all_figures = []

# Predictor sets that are fitted with the PCA configuration; any other
# predictor (model size, FLOPs) uses the no-PCA configuration.
pc_metric_sets = [PC_METRIC_NUM_3, PC_METRIC_NUM_2, PC_METRIC_NUM_1, PC_METRIC_NUM_4, PC_METRIC_NUM_5]

# The output directory is the same for every figure, so build the path once.
plot_dir = f'plots/{dataset_name}/{eaval_type}/validate'

for (x_metrics, y_metrics) in zip(PLOT_X_METRICS_LIST, VIRTUALHOME_Y_METRICS):

    # Skip groups whose columns are missing from the results table; indexing
    # them below would otherwise raise a KeyError and abort the whole run.
    missing_metrics = [metric for metric in y_metrics if metric not in virtualhome_eval.columns]
    if missing_metrics:
        print(f"⚠️ WARNING: {missing_metrics} not found in the results table, skipping {y_metrics}")
        continue

    # Skip groups in which at least one metric has no data at all.
    if virtualhome_eval[y_metrics].isna().all().any():
        print(f"⚠️ WARNING: One of {y_metrics} is all NaN, skipping")
        continue

    setup_kwargs = (
        VIRTUALHOME_DEFAULT_SETUP_KWARGS_PCA
        if x_metrics in pc_metric_sets
        else VIRTUALHOME_DEFAULT_SETUP_KWARGS_NO_PCA
    )

    fig = plot_multi_scaling_predictions(
        virtualhome_eval, y_metrics, [x_metrics],
        setup_kwargs,
        y_metric_specific_kwargs=VIRTUALHOME_EVAL_SETUP_SPECIFIC_KWARGS,
        filter_model_family=None,  # Include all model families
        ymetric2title_map=VIRTUALHOME_Y_METRIC_MAP,
        plot_legend=True, legend_nrow=2,
    )

    # Store the figure in our list
    all_figures.append(fig)

    # Show the plot
    plt.show()

    # exist_ok=True avoids the check-then-create race of a separate
    # os.path.exists() test.
    os.makedirs(plot_dir, exist_ok=True)
    filename = f'{"_".join(y_metrics)}_vs_{"_".join(x_metrics).replace("Model Size (B)", "Model_Size").replace("FLOPs (1E21)", "FLOPs")}'
    # Saving is currently disabled; uncomment to write each figure to disk.
    # fig.savefig(f'{plot_dir}/{filename}.png', dpi=300, bbox_inches='tight')
    # print(f'Saved plot to {plot_dir}/{filename}.png')