import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pickle
import json
import re
import copy

from utils import *

# Run configuration: which embodied benchmark and which task's results to load.
dataset_name = "virtualhome"  # alternative: "behavior"
eaval_type = "goal_interpretation"  # alternatives: "action_sequencing", "action_sequencing_v4", "goal_interpretation_v4"

# Read the per-model evaluation table for the selected dataset/task.
# NOTE(review): judging by the file name, this CSV already carries FLOPs and
# Open LLM leaderboard columns -- confirm against the file itself.
virtualhome_eval_path = f'./eval_results/{dataset_name}_{eaval_type}_results_with_flops_and_openllm.csv'
virtualhome_eval = pd.read_csv(virtualhome_eval_path)

# Benchmark columns to attach from the base-LLM benchmark table.
openllm_metrics = ['MMLU', 'ARC-C', 'HellaSwag', 'Winograd', 'TruthfulQA', 'XWinograd', 'HumanEval']
bench_df = load_base_llm_benchmark_eval()

# Left join on 'Model': every evaluated model is kept, and models absent from
# the benchmark table simply get NaN in the attached columns.
bench_columns = ['Model', *openllm_metrics]
virtualhome_eval = virtualhome_eval.merge(bench_df[bench_columns], how='left', on='Model')
# Optional stricter variant (disabled): virtualhome_eval.dropna(subset=openllm_metrics)

# Report merge coverage, using the first benchmark column as the indicator.
print(f"\nMerged DataFrame shape: {virtualhome_eval.shape}")
has_benchmark_data = virtualhome_eval[openllm_metrics[0]].notna()
print(f"Models with ObsScaling benchmark data: {has_benchmark_data.sum()}")
print(f"Models without ObsScaling benchmark data: {(~has_benchmark_data).sum()}")

# Keep only models that have a compute estimate (the 'FLOPs (1E21)' column).
virtualhome_eval = virtualhome_eval.dropna(subset=['FLOPs (1E21)'])

# Keep only models with a complete set of leaderboard scores: a row is dropped
# as soon as any one of these columns is missing (dropna's default how='any').
# Earlier name-based family filters were replaced by this data-driven check.
required_leaderboard_columns = ['Average', 'BBH', 'MATH Lvl 5', 'GPQA', 'MUSR', 'IFEval', 'MMLU-PRO']
virtualhome_eval = virtualhome_eval.dropna(subset=required_leaderboard_columns)
print(f"After filtering Average: {len(virtualhome_eval)} models with meaningful performance data")

# Drop models whose headline F1 score is unusable (missing or exactly zero).
# The headline column depends on the dataset being analysed.
headline_f1_by_dataset = {
    'virtualhome': 'all_f1',
    'behavior': 'overall_f1',
}
if dataset_name not in headline_f1_by_dataset:
    raise ValueError(f"Invalid dataset name: {dataset_name}")
headline_f1 = headline_f1_by_dataset[dataset_name]

# NaN must be removed explicitly: `NaN != 0` evaluates to True, so the
# zero-score mask alone would let rows with a missing score through. Both
# datasets are now treated the same way (previously only 'behavior' dropped NaN).
virtualhome_eval = virtualhome_eval.dropna(subset=[headline_f1])
virtualhome_eval = virtualhome_eval[virtualhome_eval[headline_f1] != 0]
# Log the column that was actually filtered instead of a hard-coded 'all_f1'.
print(f"After filtering {headline_f1}: {len(virtualhome_eval)} models with meaningful performance data")

# Exclude every model whose name contains "chuan" (case-insensitive). Rows with
# a missing model name evaluate to False in the mask and are therefore kept.
# The Chat/Instruct/-it/R1/Think name filters that used to live here are
# disabled; the log label below still carries their wording.
name_contains_chuan = virtualhome_eval['Model'].str.contains('chuan', case=False, na=False)
virtualhome_eval = virtualhome_eval[~name_contains_chuan]
print(f"After filtering Chat, Instruct, it: {len(virtualhome_eval)} models with meaningful performance data")

# Rescale score columns by 1/100 (presumably percentages -> fractions; the PCA
# step is sensitive to metric scale, so all inputs should share one scale).

# Work on an explicit copy: the frame at this point is the result of row
# filtering, and assigning columns on such a result triggers pandas'
# SettingWithCopyWarning (on pandas versions without Copy-on-Write).
virtualhome_eval = virtualhome_eval.copy()

# Task metric columns: selected by name prefix, by containing 'f1', or by
# ending in 'rate'. str.startswith accepts a tuple of prefixes.
rate_prefixes = ('node_', 'edge_', 'action_', 'all_')
rate_columns = [
    col for col in virtualhome_eval.columns
    if col.startswith(rate_prefixes) or 'f1' in col or col.endswith('rate')
]
virtualhome_eval[rate_columns] = virtualhome_eval[rate_columns] / 100
print(f"Converted {len(rate_columns)} rate columns to success rates")

# Leaderboard benchmark columns get the same 1/100 rescaling.
convert_metrics_to_rate = ['Average', 'BBH', 'MATH Lvl 5', 'GPQA', 'MUSR', 'MMLU-PRO', 'IFEval']
virtualhome_eval[convert_metrics_to_rate] = virtualhome_eval[convert_metrics_to_rate] / 100
print(f"Converted {len(convert_metrics_to_rate)} metrics to success rates")


# Print the surviving models, ordered by family and then by name, together
# with their (already rescaled) 'Average' leaderboard score.
separator = "=" * 80
print(f"\n📋 {len(virtualhome_eval)} Models remaining after filtering:")
print(separator)
virtualhome_eval_sorted = virtualhome_eval.sort_values(['Model Family', 'Model'])
# Each row's index is the frame's columns, so presence can be checked once.
has_average_column = 'Average' in virtualhome_eval_sorted.columns
for _, row in virtualhome_eval_sorted.iterrows():
    if not has_average_column:
        average_status = "⚠️ No open LLM column"
    elif pd.isna(row['Average']):
        average_status = "❌ No open LLM data"
    else:
        average_status = f"✅ {row['Average']:.2f}"

    print(f"  • {row['Model']} (Family: {row['Model Family']}) {average_status}")
print(separator)

# Y metrics, one inner list per figure (seven figures in total). The lists are
# consumed positionally together with PLOT_X_METRICS_LIST, so the first five
# entries deliberately repeat the same F1 metrics: each repeat is plotted
# against a different predictor set.
if dataset_name == 'virtualhome':
    graph_scopes = ('node', 'edge', 'action', 'all')
    VIRTUALHOME_Y_METRICS = (
        [[f'{scope}_f1' for scope in graph_scopes] for _ in range(5)]
        + [[f'{scope}_precision' for scope in graph_scopes]]
        + [[f'{scope}_recall' for scope in graph_scopes]]
    )
elif dataset_name == 'behavior':
    VIRTUALHOME_Y_METRICS = [['overall_f1'] for _ in range(5)] + [
        ['state_goal_f1', 'relation_goal_f1'],
        ['state_hallucination_rate', 'object_hallucination_rate', 'format_error_rate', 'grammatically_valid_rate'],
    ]
else:
    raise ValueError(f"Invalid dataset name: {dataset_name}")

# Column name -> human-readable title (passed to the plotting helper as
# `ymetric2title_map`). Built from three sections, in this order:
#   1. success / goal / error metric names,
#   2. node/edge/action/all x f1/precision/recall metrics,
#   3. overall, goal-F1 and rate metrics.
VIRTUALHOME_Y_METRIC_MAP = {
    **dict(
        task_success_rate="Task Success Rate",
        execution_success_rate="Execution Success Rate",
        total_goal="Total Goal",
        state_goal="State Goal",
        relation_goal="Relation Goal",
        action_goal="Action Goal",
        parsing_error="Parsing Error",
        hallucination_error="Hallucination Error",
        wrong_order_error="Wrong Order Error",
        missing_step_error="Missing Step Error",
        additional_step_error="Additional Step Error",
        affordance_error="Affordance Error",
    ),
    # The 'all' scope is titled "Overall"; every other scope is capitalised.
    **{
        f"{scope}_{stat}": f"{scope_title} {stat.capitalize()}"
        for stat in ("f1", "precision", "recall")
        for scope, scope_title in (
            ("node", "Node"),
            ("edge", "Edge"),
            ("action", "Action"),
            ("all", "Overall"),
        )
    },
    **dict(
        overall_f1="Overall F1",
        state_goal_f1="State Goal F1",
        relation_goal_f1="Relation Goal F1",
        state_hallucination_rate="State Hallucination Rate",
        object_hallucination_rate="Object Hallucination Rate",
        format_error_rate="Format Error Rate",
        grammatically_valid_rate="Grammatically Valid Rate",
    ),
}

# X (predictor) metrics, one single-element list per figure. Entries line up
# positionally with VIRTUALHOME_Y_METRICS: figures 1-5 step through
# PC_METRIC_NUM_5 down to PC_METRIC_NUM_1, and the last two figures reuse
# PC_METRIC_NUM_3. (The original author marked the NUM_3 and NUM_2 settings
# as "good".)
# NOTE(review): PC_METRIC_NUM_k comes from `utils`; presumably it selects the
# first k principal components -- confirm there.
# A FLOPs-only alternative would use ['FLOPs (1E21)'] for every figure.
PLOT_X_METRICS_LIST = [
    [pc_predictor]
    for pc_predictor in (
        PC_METRIC_NUM_5,
        PC_METRIC_NUM_4,
        PC_METRIC_NUM_3,
        PC_METRIC_NUM_2,
        PC_METRIC_NUM_1,
        PC_METRIC_NUM_3,
        PC_METRIC_NUM_3,
    )
]

# Benchmark columns fed to imputation and PCA. Alternatives tried earlier:
#   - the task's own goal/error metrics
#     ['total_goal', 'state_goal', 'relation_goal', 'action_goal', 'parsing_error', 'hallucination_error', 'wrong_order_error', 'missing_step_error', 'additional_step_error', 'affordance_error']
#   - ['MMLU', 'ARC-C', 'HellaSwag', 'Winograd', 'TruthfulQA', 'XWinograd', 'HumanEval']
pca_imputation_metrics = ['Average', 'BBH', 'MATH Lvl 5', 'GPQA', 'MUSR', 'MMLU-PRO', 'IFEval']
# This mutates the shared defaults imported from `utils` in place.
NONGSM_PCA_PREPROCESS_KWARGS['imputation_metrics'] = pca_imputation_metrics

# Base configuration for every plot: start from a shallow copy of the shared
# preprocessing defaults, then layer the analysis-specific settings on top.
VIRTUALHOME_DEFAULT_SETUP_KWARGS = dict(NONGSM_PCA_PREPROCESS_KWARGS)
VIRTUALHOME_DEFAULT_SETUP_KWARGS.update({
    "pca_metrics": pca_imputation_metrics,
    "ref_model_family": "Llama-3",  # other candidates: "DeepSeek", "Qwen1.5", "Llama-2", "Qwen"
    "stylize_data": False,
    "nonlinearity": "sigmoid-parametric",

    # Target normalization option handed to the plotting/regression helper.
    "y_metric_process_funcs": "minmax_norm",

    # Group markers by model family.
    "df_groupby": 'Model Family',

    # Robust regression; per the original author's note, delta=1.0 is the
    # Huber-loss delta chosen for targets normalized to [0, 1].
    "reg_method": "robust",
    "reg_kwargs": {"delta": 1.0},

    # PCA and imputation are both enabled (predictors are the PC_METRIC_NUM_* sets).
    "apply_pca": True,
    "apply_imputation": True,

    # Train/test split by model size with a threshold of 40 (70 and 80 were
    # also tried). A FLOPs-based alternative: "cutoff_by_FLOPs (1E21)" with
    # threshold 2000.
    "split_method": "cutoff_by_Model Size (B)",
    "cutoff_threshold": 40,

    # To stylize every family present in the data, one could also set:
    # "stylize_model_family": virtualhome_eval['Model Family'].unique().tolist(),
})

# Flatten the per-figure metric groups into one list (duplicates included).
all_individual_metrics = [metric for group in VIRTUALHOME_Y_METRICS for metric in group]

# Per-metric overrides, keyed by the individual metric name (not by group).
# Every metric, error-type or not, gets the same settings: a [0, 1] value
# range and y-axis limits with a small margin around that range. Metrics that
# occur in several groups are simply written again with identical values.
VIRTUALHOME_EVAL_SETUP_SPECIFIC_KWARGS = {}
for y_metric in all_individual_metrics:
    VIRTUALHOME_EVAL_SETUP_SPECIFIC_KWARGS[y_metric] = {
        'y_metric_range': (0.0, 1.0),
        # Per-metric split overrides ("split_method" / "cutoff_threshold")
        # could be added here; they are currently left to the defaults.
        "plot_adjust_kwargs": {"ylim": [-0.05, 1.05]},
    }

# Produce one scaling-prediction figure per (predictor set, metric group)
# pair; the two lists are walked in lockstep. Every figure is collected in
# `all_figures` and then displayed.

# Keyword options shared by every figure.
shared_plot_options = {
    "y_metric_specific_kwargs": VIRTUALHOME_EVAL_SETUP_SPECIFIC_KWARGS,
    "filter_model_family": None,  # include all model families
    "ymetric2title_map": VIRTUALHOME_Y_METRIC_MAP,
    "plot_legend": True,
    "legend_nrow": 2,
}

all_figures = []
for x_metrics, y_metrics in zip(PLOT_X_METRICS_LIST, VIRTUALHOME_Y_METRICS):
    fig = plot_multi_scaling_predictions(
        virtualhome_eval, y_metrics, x_metrics,
        VIRTUALHOME_DEFAULT_SETUP_KWARGS,
        **shared_plot_options,
    )
    all_figures.append(fig)

    plt.show()
    # To save instead of (or in addition to) showing:
    # fig.savefig(f'plots/{dataset_name}/{eaval_type}/scaling_{"_".join(y_metrics)}_flops.png', dpi=300, bbox_inches='tight')
    # print(f"Saved plot to plots/{dataset_name}/{eaval_type}/scaling_{'_'.join(y_metrics)}_flops.png")