import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pickle
import json
import re
import copy

from utils import *

# Which evaluation results to analyse.
# NOTE(review): identifiers in this script carry a VirtualHome prefix, but
# dataset_name selects the "behavior" results — presumably adapted from a
# VirtualHome script; confirm the naming is intentional.
dataset_name = "behavior"
eaval_type = "action_sequencing"  # spelling kept: this name is reused when saving plots

# Read the action-sequencing results CSV for `dataset_name` and keep only the
# rows that report a model size, since model size is the only predictor used.
virtualhome_eval_path = f"./eval_results/{dataset_name}_{eaval_type}_results.csv"
virtualhome_eval = (
    pd.read_csv(virtualhome_eval_path)
    .dropna(subset=['Model Size (B)'])
)

# Y metrics for the action-sequencing evaluation, laid out as 3 rows of 4.
# Each inner list is passed to a single plot_multi_scaling_predictions call
# further down, so each row ends up in its own figure.
VIRTUALHOME_Y_METRICS = [
    # Row 1: task/execution success rates plus total and state goal rates
    ['task_success_rate', 'execution_success_rate', 'total_goal', 'state_goal'],
    # Row 2: remaining goal rates plus the parsing and hallucination errors
    ['relation_goal', 'action_goal', 'parsing_error', 'hallucination_error'],
    # Row 3: remaining error types
    ['wrong_order_error', 'missing_step_error', 'additional_step_error', 'affordance_error'],
]

# Display titles for every metric above (passed as ymetric2title_map).
VIRTUALHOME_Y_METRIC_MAP = {
    'task_success_rate': "Task Success Rate",
    'execution_success_rate': "Execution Success Rate",
    'total_goal': "Total Goal",
    'state_goal': "State Goal",
    'relation_goal': "Relation Goal",
    'action_goal': "Action Goal",
    'parsing_error': "Parsing Error",
    'hallucination_error': "Hallucination Error",
    'wrong_order_error': "Wrong Order Error",
    'missing_step_error': "Missing Step Error",
    'additional_step_error': "Additional Step Error",
    'affordance_error': "Affordance Error",
}

# X metrics: one predictor list (Model Size only) per row of
# VIRTUALHOME_Y_METRICS. Derived from the row count so the two lists cannot
# drift apart (zip() over them would otherwise silently drop extra rows).
# Each row still gets its own, independent list object.
PLOT_X_METRICS_LIST = [['Model Size (B)'] for _ in VIRTUALHOME_Y_METRICS]

# Default setup shared by every plot_multi_scaling_predictions call below.
VIRTUALHOME_DEFAULT_SETUP_KWARGS = dict(
    # Y-metric preprocessing named "minmax_norm" (implemented in utils) —
    # presumably rescales each metric into [0, 1]. Being the default, it applies
    # to error metrics as well as success/goal metrics.
    y_metric_process_funcs="minmax_norm",
    # Group points (and their markers) by model family.
    df_groupby='Model Family',
    # Robust regression to limit the influence of outliers; delta=1.0 is
    # presumably the Huber-loss threshold, chosen for targets in [0, 1] — confirm in utils.
    reg_method="robust",
    reg_kwargs={"delta": 1.0},
    # Model size is the single predictor, so PCA and imputation are disabled.
    apply_pca=False,
    apply_imputation=False,
    # Replace the default stylize_model_family list with every family present
    # in the size-filtered data.
    stylize_model_family=virtualhome_eval['Model Family'].unique().tolist(),
)

# Per-metric setup overrides, keyed by individual metric name (not by row).
VIRTUALHOME_EVAL_SETUP_SPECIFIC_KWARGS = {}

# Flatten the rows of VIRTUALHOME_Y_METRICS into one list of metric names.
all_individual_metrics = [metric for group in VIRTUALHOME_Y_METRICS for metric in group]

for y_metric in all_individual_metrics:
    # Build a fresh dict per metric so no nested object (e.g. the ylim list)
    # is shared between metrics.
    VIRTUALHOME_EVAL_SETUP_SPECIFIC_KWARGS[y_metric] = {
        # Raw range on a 0-100 (percent) scale. It is the same for success, goal
        # and error metrics, so no per-type branch is needed.
        'y_metric_range': (0.0, 100.0),
        # Train/test split by a model-size cutoff of 70 ('Model Size (B)' units,
        # i.e. 70B parameters).
        "split_method": "cutoff_by_Model Size (B)",
        "cutoff_threshold": 70,
        # Small margin around [0, 1] on the (normalized) y axis.
        "plot_adjust_kwargs": {"ylim": [-0.05, 1.05]},
    }

# Generate one scaling figure per row of VIRTUALHOME_Y_METRICS. Each figure
# plots that row's four metrics against Model Size, and every figure is saved
# as a PNG. The figures are also kept in `all_figures` for later use.
all_figures = []

# Figure.savefig does not create missing directories, so create the output
# folder up front (no-op if it already exists).
plot_dir = f'plots/{dataset_name}/{eaval_type}'
os.makedirs(plot_dir, exist_ok=True)

for (x_metrics, y_metrics) in zip(PLOT_X_METRICS_LIST, VIRTUALHOME_Y_METRICS):
    # NOTE(review): x_metrics is already a list; wrapping it again is presumably
    # the list-of-predictor-sets format the helper expects — confirm in utils.
    fig = plot_multi_scaling_predictions(
        virtualhome_eval, y_metrics, [x_metrics],
        VIRTUALHOME_DEFAULT_SETUP_KWARGS,
        y_metric_specific_kwargs=VIRTUALHOME_EVAL_SETUP_SPECIFIC_KWARGS,
        filter_model_family=None,  # include all model families
        ymetric2title_map=VIRTUALHOME_Y_METRIC_MAP,
        plot_legend=True, legend_nrow=2,
    )
    all_figures.append(fig)

    # Figures are saved rather than shown; call plt.show() here for interactive use.
    fig.savefig(f'{plot_dir}/scaling_{"_".join(y_metrics)}.png', dpi=300, bbox_inches='tight')