import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pickle
import json
import re
import copy

from utils import *

dataset_name = "virtualhome"  # other dataset option seen in this project: "behavior"
# NOTE: "eaval" (sic) is kept as-is because the rest of the script uses this name.
eaval_type = "action_sequencing"

# Load the VirtualHome action-sequencing results table (including a FLOPs column).
# The default below is the original author's local absolute path. Set the
# VIRTUALHOME_EVAL_PATH environment variable to point at the CSV on another machine,
# e.g. f"./eval_results/{dataset_name}_{eaval_type}_results_with_flops.csv".
virtualhome_eval_path = os.environ.get(
    "VIRTUALHOME_EVAL_PATH",
    f"/Users/qinjielin/Downloads/NWU/25corl/corl_ws/ObsScaling/eval_results/{dataset_name}_{eaval_type}_results_with_flops.csv",
)
virtualhome_eval = pd.read_csv(virtualhome_eval_path)

# Keep only models with a known training-compute value: 'FLOPs (1E21)' is the
# x-axis predictor for every plot produced by this script.
virtualhome_eval = virtualhome_eval.dropna(subset=['FLOPs (1E21)'])
print(f"After filtering FLOPs: {len(virtualhome_eval)} models with meaningful performance data")

# Drop post-trained variants (Chat, Instruct, "-it", R1, Think); the intent is to
# keep base models only.
# "it" must be a standalone token (e.g. "gemma-2-9b-it"). A plain substring match
# would also hit names such as "Granite" or "DeepSeek-V2-Lite". Because case=False
# compiles the pattern with IGNORECASE, the [a-z0-9] lookaround classes also cover
# uppercase letters.
NON_BASE_MODEL_PATTERN = r'Chat|Instruct|(?<![a-z0-9])it(?![a-z0-9])|R1|Think'
is_non_base = virtualhome_eval['Model'].str.contains(NON_BASE_MODEL_PATTERN, case=False, na=False)
virtualhome_eval = virtualhome_eval[~is_non_base]
print(f"After filtering Chat, Instruct, it: {len(virtualhome_eval)} models with meaningful performance data")

# List the remaining models, sorted by family then name, as a quick sanity check.
print(f"\n📋 {len(virtualhome_eval)} Models remaining after filtering:")
print("=" * 80)
virtualhome_eval_sorted = virtualhome_eval.sort_values(['Model Family', 'Model'])
for model, family in zip(virtualhome_eval_sorted['Model'], virtualhome_eval_sorted['Model Family']):
    print(f"  • {model} (Family: {family})")
print("=" * 80)

# Y metrics for VirtualHome action sequencing, laid out as 3 rows of 4 metrics.
# Each inner list is plotted together as one figure, so the grouping is a layout
# choice rather than a strict semantic category:
#   row 1: task/execution success rates plus the total and state goal metrics
#   row 2: relation and action goal metrics plus parsing and hallucination errors
#   row 3: the remaining error categories
VIRTUALHOME_Y_METRICS = [
    ['task_success_rate', 'execution_success_rate', 'total_goal', 'state_goal'],
    ['relation_goal', 'action_goal', 'parsing_error', 'hallucination_error'],
    ['wrong_order_error', 'missing_step_error', 'additional_step_error', 'affordance_error'],
]

# Human-readable panel titles, in the same order as the metrics above. Each title
# is the snake_case column name rendered in Title Case
# ("task_success_rate" -> "Task Success Rate").
VIRTUALHOME_Y_METRIC_MAP = {
    metric: metric.replace('_', ' ').title()
    for metric_row in VIRTUALHOME_Y_METRICS
    for metric in metric_row
}

# X metrics: training compute ('FLOPs (1E21)') is the single predictor for every
# row of Y metrics. One entry is built per Y group so the two lists cannot drift
# apart; the plotting loop zips them and would otherwise silently drop unmatched groups.
# Each entry is a fresh list object.
PLOT_X_METRICS_LIST = [['FLOPs (1E21)'] for _ in VIRTUALHOME_Y_METRICS]

# Shared setup options passed to every plot_multi_scaling_predictions call (from utils).
VIRTUALHOME_DEFAULT_SETUP_KWARGS = dict(
    # Y-metric preprocessing. "minmax_norm" is resolved inside utils; the intent
    # is to rescale every metric (success rates, goals and errors alike) to [0, 1].
    y_metric_process_funcs="minmax_norm",
    # Markers/colours are grouped by model family.
    df_groupby='Model Family',
    # Robust (Huber-loss) regression to limit the pull of outlier models;
    # delta=1.0 was picked for targets normalized to [0, 1].
    reg_method="robust",
    reg_kwargs={"delta": 1.0},
    # The only predictor is training FLOPs, so PCA is unnecessary; imputation is off too.
    apply_pca=False,
    apply_imputation=False,
    # Override the default family styling list with every family still present
    # after filtering.
    stylize_model_family=virtualhome_eval['Model Family'].unique().tolist(),
)

# Per-metric overrides for plot_multi_scaling_predictions, keyed by individual
# metric name (not by row group).
VIRTUALHOME_EVAL_SETUP_SPECIFIC_KWARGS = {}

# Flatten the row groups into one ordered list of metric names.
all_individual_metrics = [metric for group in VIRTUALHOME_Y_METRICS for metric in group]

for y_metric in all_individual_metrics:
    # Every metric gets its own fresh dict; no nested objects are shared between metrics.
    VIRTUALHOME_EVAL_SETUP_SPECIFIC_KWARGS[y_metric] = {
        # All metrics (success rates, goals and errors) are percentages in [0, 100].
        'y_metric_range': (0.0, 100.0),
        # Train/test split by a parameter-count cutoff at 70B.
        # NOTE(review): the split uses 'Model Size (B)' while the x-axis is FLOPs,
        # and rows were only filtered for missing FLOPs. Confirm that this is
        # intended and how utils treats rows with a missing model size.
        "split_method": "cutoff_by_Model Size (B)",
        "cutoff_threshold": 70,
        # Small margin around the normalized [0, 1] range for visualization.
        "plot_adjust_kwargs": {"ylim": [-0.05, 1.05]},
    }

# Generate the VirtualHome action-sequencing scaling plots: one figure per metric
# group in VIRTUALHOME_Y_METRICS (4 metrics each), all using training FLOPs as the
# predictor. Every figure is kept in all_figures and also saved as a PNG.
all_figures = []

# Figure.savefig does not create missing directories, so ensure the output folder exists.
plot_dir = os.path.join('plots', dataset_name, eaval_type)
os.makedirs(plot_dir, exist_ok=True)

for x_metrics, y_metrics in zip(PLOT_X_METRICS_LIST, VIRTUALHOME_Y_METRICS):
    fig = plot_multi_scaling_predictions(
        virtualhome_eval, y_metrics, [x_metrics],
        VIRTUALHOME_DEFAULT_SETUP_KWARGS,
        y_metric_specific_kwargs=VIRTUALHOME_EVAL_SETUP_SPECIFIC_KWARGS,
        filter_model_family=None,  # include all model families
        ymetric2title_map=VIRTUALHOME_Y_METRIC_MAP,
        plot_legend=True, legend_nrow=2,
    )
    all_figures.append(fig)

    # For interactive viewing, call plt.show() here.
    fig.savefig(
        os.path.join(plot_dir, f'scaling_{"_".join(y_metrics)}_flops.png'),
        dpi=300, bbox_inches='tight',
    )