import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pickle
import json
import re
import copy

from utils import *

# Evaluation type to analyse: "action_sequencing" or "goal_interpretation".
# NOTE: the variable keeps its historical spelling (`eaval_type`) because the
# plotting/saving code later in this script refers to it by that name.
eaval_type = "action_sequencing"

# Directory holding the per-benchmark evaluation CSVs. Defaults to the original
# author's local checkout; set the OBS_SCALING_EVAL_DIR environment variable to
# run the script from another machine without editing the source.
eval_results_dir = os.environ.get(
    "OBS_SCALING_EVAL_DIR",
    "/Users/qinjielin/Downloads/NWU/25corl/corl_ws/ObsScaling/eval_results",
)

virtualhome_eval_path = os.path.join(
    eval_results_dir, f"virtualhome_{eaval_type}_results_with_flops_and_openllm.csv"
)
virtualhome_eval = pd.read_csv(virtualhome_eval_path)
behavior_eval_path = os.path.join(
    eval_results_dir, f"behavior_{eaval_type}_results_with_flops_and_openllm.csv"
)
behavior_eval = pd.read_csv(behavior_eval_path)

# Report the size of both tables before merging. Also list any Behavior models
# with no VirtualHome counterpart: the merge below is a left merge on the
# VirtualHome table, so those models do not appear in the combined data.
for benchmark_label, frame in (("VirtualHome", virtualhome_eval), ("Behavior", behavior_eval)):
    print(f"Before merge - {benchmark_label} models: {len(frame)}")

behavior_models_not_in_virtualhome = set(behavior_eval['Model']).difference(virtualhome_eval['Model'])
if len(behavior_models_not_in_virtualhome) > 0:
    print(f"Warning: {len(behavior_models_not_in_virtualhome)} behavior models not found in virtualhome:")
    print("\n".join(f"  • {orphan}" for orphan in sorted(behavior_models_not_in_virtualhome)))

# Metrics reported by both benchmarks. Behavior copies get a "behavior_" prefix
# so they can live next to the VirtualHome columns in one dataframe.
behavior_metrics_to_add = [
    # Success metrics
    'task_success_rate', 'execution_success_rate',
    # Goal achievement metrics
    'total_goal', 'state_goal', 'relation_goal', 'action_goal',
    # Error metrics
    'parsing_error', 'hallucination_error',
    'wrong_order_error', 'missing_step_error', 'additional_step_error', 'affordance_error',
]

# Only rename (and later merge) metrics that actually exist in the Behavior CSV.
behavior_metric_rename_map = {
    metric: f"behavior_{metric}"
    for metric in behavior_metrics_to_add
    if metric in behavior_eval.columns
}

behavior_eval_renamed = behavior_eval.rename(columns=behavior_metric_rename_map)

# Merge only the join key plus the prefixed metrics, so no other Behavior
# columns collide with VirtualHome columns of the same name.
behavior_columns_to_merge = ['Model'] + list(behavior_metric_rename_map.values())
behavior_eval_for_merge = behavior_eval_renamed[behavior_columns_to_merge]

# Left merge: every VirtualHome row is kept, and models without Behavior
# results get NaN Behavior metrics. If a Model appears more than once in the
# Behavior CSV, the merge silently duplicates VirtualHome rows, so warn
# whenever the row count changes.
n_rows_before_merge = len(virtualhome_eval)
virtualhome_eval = virtualhome_eval.merge(
    behavior_eval_for_merge,
    on='Model',
    how='left'
)

print(f"After merge - Combined models: {len(virtualhome_eval)}")
if len(virtualhome_eval) != n_rows_before_merge:
    duplicated_behavior_models = (
        behavior_eval_for_merge.loc[behavior_eval_for_merge['Model'].duplicated(), 'Model']
        .unique()
        .tolist()
    )
    print(
        f"Warning: merge changed row count from {n_rows_before_merge} to {len(virtualhome_eval)}; "
        f"duplicated behavior models: {duplicated_behavior_models}"
    )

# Print a numbered inventory of the merged columns, then report which
# prefixed Behavior metrics ended up with NaNs (VirtualHome models that had
# no matching Behavior row).
merged_columns = virtualhome_eval.columns
divider = "=" * 80
print(f"\nAvailable columns after merge ({len(merged_columns)} total):")
print(divider)
for position, column_name in enumerate(merged_columns, start=1):
    print(f"  {position:2d}. {column_name}")
print(divider)

missing_behavior_metrics = (
    virtualhome_eval[list(behavior_metric_rename_map.values())]
    .isnull()
    .sum()
)
print("\nMissing behavior metrics after merge:")
for metric_name, n_missing in missing_behavior_metrics[missing_behavior_metrics > 0].items():
    print(f"  • {metric_name}: {n_missing} missing values")

# Give the VirtualHome copies of the shared metrics a "virtualhome_" prefix,
# so both benchmarks follow one "<benchmark>_<metric>" naming scheme.
virtualhome_metrics_to_rename = [
    # Success metrics
    'task_success_rate', 'execution_success_rate',
    # Goal achievement metrics
    'total_goal', 'state_goal', 'relation_goal', 'action_goal',
    # Error metrics
    'parsing_error', 'hallucination_error',
    'wrong_order_error', 'missing_step_error', 'additional_step_error', 'affordance_error',
]

# Only columns actually present in the merged dataframe are renamed.
virtualhome_metric_rename_map = {
    name: "virtualhome_" + name
    for name in virtualhome_metrics_to_rename
    if name in virtualhome_eval.columns
}
virtualhome_eval = virtualhome_eval.rename(columns=virtualhome_metric_rename_map)

# Echo both rename maps for traceability.
print("\nRenamed metrics:")
for section_header, rename_map in (
    ("VirtualHome metrics:", virtualhome_metric_rename_map),
    ("Behavior metrics:", behavior_metric_rename_map),
):
    print(section_header)
    for before, after in rename_map.items():
        print(f"  • {before} → {after}")


# Y-metric groups to plot, one figure per group. The loop that plots them
# zips these groups with PLOT_X_METRICS_LIST, which cycles through
# Model Size, FLOPs and PC predictors. Each group therefore appears three
# times in a row, once per x predictor.
_success_metrics = ['task_success_rate', 'execution_success_rate']
_goal_metrics = ['total_goal', 'state_goal', 'relation_goal', 'action_goal']

VIRTUALHOME_Y_METRICS = (
    # Success metrics for both benchmarks (4 metrics per group), x3
    [[f"{dataset}_{metric}" for dataset in ('virtualhome', 'behavior') for metric in _success_metrics]
     for _ in range(3)]
    # VirtualHome goal metrics (4 metrics per group), x3
    + [[f"virtualhome_{metric}" for metric in _goal_metrics] for _ in range(3)]
    # Behavior goal metrics (4 metrics per group), x3
    + [[f"behavior_{metric}" for metric in _goal_metrics] for _ in range(3)]
)

# X predictors, cycled in a fixed order: parameter count, training compute, then
# PC_METRIC_NUM_4 from utils (presumably the principal-component columns
# derived from the Open LLM benchmark scores; confirm in utils).
# The cycle is repeated four times (12 entries), but only as many entries are
# used as there are y-metric groups, because the plotting loop zips the two
# lists.
PLOT_X_METRICS_LIST = [
    x_predictor
    for _ in range(4)
    for x_predictor in (['Model Size (B)'], ['FLOPs (1E21)'], PC_METRIC_NUM_4)
]

# Per-y-metric override of the PC predictor set. Per the original author's
# note, it only takes effect when the x predictor is not Model Size / FLOPs.
# NOTE(review): exact override semantics live in plot_multi_scaling_predictions
# (utils); confirm there.
special_x_metrics_mapping = dict(
    behavior_task_success_rate=PC_METRIC_NUM_4,
    behavior_execution_success_rate=PC_METRIC_NUM_4,
    virtualhome_task_success_rate=PC_METRIC_NUM_3,
    virtualhome_execution_success_rate=PC_METRIC_NUM_3,
    behavior_total_goal=PC_METRIC_NUM_4,
    behavior_state_goal=PC_METRIC_NUM_5,
    behavior_relation_goal=PC_METRIC_NUM_5,
    behavior_action_goal=PC_METRIC_NUM_5,
    virtualhome_total_goal=PC_METRIC_NUM_3,
    virtualhome_state_goal=PC_METRIC_NUM_3,
    virtualhome_relation_goal=PC_METRIC_NUM_3,
    virtualhome_action_goal=PC_METRIC_NUM_3,
)

# Human-readable plot titles for every prefixed metric column, e.g.
# "virtualhome_task_success_rate" -> "Task Success Rate (VirtualHome)".
# Keys are inserted metric by metric, VirtualHome first, then Behavior.
_Y_METRIC_DISPLAY_NAMES = {
    'task_success_rate': "Task Success Rate",
    'execution_success_rate': "Execution Success Rate",
    'total_goal': "Total Goal",
    'state_goal': "State Goal",
    'relation_goal': "Relation Goal",
    'action_goal': "Action Goal",
    'parsing_error': "Parsing Error",
    'hallucination_error': "Hallucination Error",
    'wrong_order_error': "Wrong Order Error",
    'missing_step_error': "Missing Step Error",
    'additional_step_error': "Additional Step Error",
    'affordance_error': "Affordance Error",
}

VIRTUALHOME_Y_METRIC_MAP = {
    f"{prefix}_{metric_key}": f"{label} ({benchmark})"
    for metric_key, label in _Y_METRIC_DISPLAY_NAMES.items()
    for prefix, benchmark in (("virtualhome", "VirtualHome"), ("behavior", "Behavior"))
}

# Keep only models with a training-compute estimate (the 'FLOPs (1E21)'
# column is also one of the x predictors).
virtualhome_eval = virtualhome_eval.dropna(subset=['FLOPs (1E21)'])
print(f"After filtering FLOPs: {len(virtualhome_eval)} models with meaningful performance data")

# Remove models without Open LLM Leaderboard data ('Average' is NaN).
virtualhome_eval = virtualhome_eval.dropna(subset=['Average'])
print(f"After filtering Average: {len(virtualhome_eval)} models with meaningful performance data")

# Manually excluded models: case-insensitive, literal substring match on 'Model'.
excluded_model_patterns = ['Yi-Coder-9B']
for excluded_pattern in excluded_model_patterns:
    virtualhome_eval = virtualhome_eval[
        ~virtualhome_eval['Model'].str.contains(excluded_pattern, case=False, na=False, regex=False)
    ]
# Take an explicit copy of the filtered frame, so the in-place column
# rescaling that follows does not trigger pandas' SettingWithCopyWarning.
virtualhome_eval = virtualhome_eval.copy()
print(
    f"After filtering excluded models ({', '.join(excluded_model_patterns)}): "
    f"{len(virtualhome_eval)} models with meaningful performance data"
)

# Put every metric on a 0-1 "higher is better" scale. The raw values are
# assumed to be percentages, hence the division by 100.

# Error columns measure failures: success fraction = (100 - error) / 100.
error_columns = [name for name in virtualhome_eval.columns if name.endswith('_error')]
virtualhome_eval[error_columns] = virtualhome_eval[error_columns].rsub(100).div(100)
print(f"Converted {len(error_columns)} error columns to success rates")

# Rate and goal columns are already "higher is better"; they are only rescaled.
rate_columns = [name for name in virtualhome_eval.columns if name.endswith(('_rate', '_goal'))]
virtualhome_eval[rate_columns] = virtualhome_eval[rate_columns].div(100)
print(f"Converted {len(rate_columns)} rate columns to success rates")

# Open LLM Leaderboard scores feed the PCA, which is sensitive to feature
# scale, so they are rescaled to the same 0-1 range as the targets.
convert_metrics_to_rate = ['Average', 'BBH', 'MATH Lvl 5', 'GPQA', 'MUSR', 'MMLU-PRO', 'IFEval']
virtualhome_eval[convert_metrics_to_rate] = virtualhome_eval[convert_metrics_to_rate].div(100)
print(f"Converted {len(convert_metrics_to_rate)} metrics to success rates")


# Print the surviving models, sorted by family then name, with their Open LLM
# 'Average' score (on a 0-1 scale after the rescaling above).
print(f"\n📋 {len(virtualhome_eval)} Models remaining after filtering:")
print("=" * 80)
virtualhome_eval_sorted = virtualhome_eval.sort_values(['Model Family', 'Model'])
for _, model_row in virtualhome_eval_sorted.iterrows():
    if 'Average' not in model_row:
        average_status = "⚠️ No open LLM column"
    elif pd.isna(model_row['Average']):
        average_status = "❌ No open LLM data"
    else:
        average_status = f"✅ {model_row['Average']:.2f}"

    print(f"  • {model_row['Model']} (Family: {model_row['Model Family']}) {average_status}")
print("=" * 80)


# Open LLM Leaderboard benchmarks used both for imputation and as PCA inputs.
# They were rescaled to [0, 1] earlier in this script because PCA is
# sensitive to feature scale.
pca_imputation_metrics = ['Average', 'BBH', 'MATH Lvl 5', 'GPQA', 'MUSR', 'MMLU-PRO', 'IFEval']
# NOTE(review): this mutates the dict imported from `utils` in place. It is
# kept as-is in case other code reads it; the PCA setup below also receives
# the value through the `**` spread.
NONGSM_PCA_PREPROCESS_KWARGS['imputation_metrics'] = pca_imputation_metrics

# Shared model-size cutoff (in billions of parameters) for the
# 'cutoff_by_Model Size (B)' split used by both setups below. Which side is
# treated as train vs. held-out is defined in utils.
MODEL_SIZE_CUTOFF_B = 40

# Setup for plots whose x predictor is a set of principal components
# (imputation and PCA enabled).
VIRTUALHOME_DEFAULT_SETUP_KWARGS_PCA = {
    **NONGSM_PCA_PREPROCESS_KWARGS,

    "pca_metrics": pca_imputation_metrics,
    # Other reference families tried: "Gemma", "Llama-2", "Qwen". Gemma looked
    # better than Llama, and Qwen1.5 looked better than Llama.
    "ref_model_family": "Gemma-2",
    "stylize_data": True,
    "nonlinearity": "sigmoid-parametric",

    # Metric normalization - normalize success rates to [0,1] range
    "y_metric_process_funcs": "minmax_norm",

    # Group markers by model family
    "df_groupby": 'Model Family',

    # Regression: robust regression for handling outliers
    "reg_method": "robust",
    "reg_kwargs": {"delta": 1.0},   # huber loss with delta=1.0 for normalized target within [0, 1]

    # Predictors are PC scores: impute missing benchmark values and apply PCA.
    "apply_pca": True,
    "apply_imputation": True,

    # Alternative split by compute: "cutoff_by_FLOPs (1E21)" with threshold 2000.
    "split_method": "cutoff_by_Model Size (B)",
    "cutoff_threshold": MODEL_SIZE_CUTOFF_B,

    # Override the default stylize_model_family to use all families in your dataset
    "stylize_model_family": virtualhome_eval['Model Family'].unique().tolist(),
}

# Setup for plots whose x predictor is a raw column (Model Size or FLOPs).
VIRTUALHOME_DEFAULT_SETUP_KWARGS_NO_PCA = {
    # Metric normalization - normalize success rates to [0,1] range
    "y_metric_process_funcs": "minmax_norm",

    # Group markers by model family
    "df_groupby": 'Model Family',

    # Regression: robust regression for handling outliers
    "reg_method": "robust",
    "reg_kwargs": {"delta": 1.0},   # huber loss with delta=1.0 for normalized target within [0, 1]

    # The predictor is used directly, so no PCA or imputation is needed.
    "apply_pca": False,
    "apply_imputation": False,

    # Override the default stylize_model_family to use all families in your dataset
    "stylize_model_family": virtualhome_eval['Model Family'].unique().tolist(),

    "split_method": "cutoff_by_Model Size (B)",
    "cutoff_threshold": MODEL_SIZE_CUTOFF_B,
}


# Per-metric plot settings. Every individual y-metric (flattened out of the
# plotting groups; a metric repeated across groups keeps its first position)
# gets a fixed [0, 1] value range and a slightly padded y-axis, so points at
# exactly 0 or 1 stay visible.
all_individual_metrics = [metric for metric_group in VIRTUALHOME_Y_METRICS for metric in metric_group]

VIRTUALHOME_EVAL_SETUP_SPECIFIC_KWARGS = {
    metric: {
        'y_metric_range': (0.0, 1.0),
        "plot_adjust_kwargs": {"ylim": [-0.05, 1.05]},
    }
    for metric in all_individual_metrics
}
# Keep every figure so it stays accessible after the loop.
all_figures = []

# x-predictor sets that are principal components; these need the PCA +
# imputation setup. Any other predictor uses the plain setup.
pca_x_metric_sets = [PC_METRIC_NUM_1, PC_METRIC_NUM_2, PC_METRIC_NUM_3, PC_METRIC_NUM_4, PC_METRIC_NUM_5]

# Create the output directory once. exist_ok avoids the check-then-create race.
output_dir = f'plots/validate/{eaval_type}'
os.makedirs(output_dir, exist_ok=True)

# zip stops at the shorter of the two lists, so surplus x predictors are unused.
for x_metrics, y_metrics in zip(PLOT_X_METRICS_LIST, VIRTUALHOME_Y_METRICS):
    setup_kwargs = (
        VIRTUALHOME_DEFAULT_SETUP_KWARGS_PCA
        if x_metrics in pca_x_metric_sets
        else VIRTUALHOME_DEFAULT_SETUP_KWARGS_NO_PCA
    )
    fig = plot_multi_scaling_predictions(
        virtualhome_eval, y_metrics, [x_metrics],
        setup_kwargs,
        y_metric_specific_kwargs=VIRTUALHOME_EVAL_SETUP_SPECIFIC_KWARGS,
        filter_model_family=None,  # Include all model families
        ymetric2title_map=VIRTUALHOME_Y_METRIC_MAP,
        plot_legend=True, legend_nrow=2,
        special_x_metrics_mapping=special_x_metrics_mapping,
    )
    all_figures.append(fig)

    # Filename: y-metric names, then the x predictors, with spaces and
    # parentheses removed from the raw size/compute column names.
    x_label = "_".join(x_metrics).replace("Model Size (B)", "Model_Size").replace("FLOPs (1E21)", "FLOPs")
    filename = f'{"_".join(y_metrics)}_vs_{x_label}'
    output_path = f'{output_dir}/{filename}.png'
    fig.savefig(output_path, dpi=300, bbox_inches='tight')
    print(f'Saved plot to {output_path}')