import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pickle
import json
import re
import copy

from utils import *

dataset_name = "virtualhome"
# NOTE: the identifier spelling "eaval_type" is kept as-is because the rest of
# the script builds its paths from this name.
eaval_type = "goal_interpretation_v4"

# Load the VirtualHome goal interpretation results table.
virtualhome_eval_path = f"./eval_results/{dataset_name}_{eaval_type}_results.csv"
virtualhome_eval = pd.read_csv(virtualhome_eval_path)

# Keep only rows with a known model size (it is the predictor used downstream).
virtualhome_eval = virtualhome_eval.dropna(subset=['Model Size (B)'])

# Goal interpretation metric columns used for filtering.
goal_metrics = ['node_precision', 'edge_precision', 'action_precision', 'all_precision',
                'node_recall', 'edge_recall', 'action_recall', 'all_recall',
                'node_f1', 'edge_f1', 'action_f1', 'all_f1']

# Remove rows where every goal metric is 0.0 or NaN.
# NaN is treated as 0.0 before the comparison so that rows mixing zeros and
# NaNs are dropped as well; separate "all zero" and "all NaN" passes would let
# such rows through because NaN == 0.0 evaluates to False.
virtualhome_eval = virtualhome_eval[
    ~virtualhome_eval[goal_metrics].fillna(0.0).eq(0.0).all(axis=1)
]

print(f"After filtering: {len(virtualhome_eval)} models with meaningful performance data")

# Y metrics for VirtualHome goal interpretation, as 3 groups of 4 metrics:
# one group per score kind (precision, recall, F1), and within each group one
# metric per scope (node, edge, action, all).
VIRTUALHOME_Y_METRICS = [
    [f"{scope}_{kind}" for scope in ("node", "edge", "action", "all")]
    for kind in ("precision", "recall", "f1")
]


# Maps each metric column name to its human-readable title, e.g.
# 'node_precision' -> "Node Precision" and 'all_f1' -> "Overall F1".
# Keys are ordered precision, recall, F1, each over node/edge/action/all.
VIRTUALHOME_Y_METRIC_MAP = {
    f"{scope}_{kind}": f"{scope_title} {kind_title}"
    for kind, kind_title in (
        ("precision", "Precision"),
        ("recall", "Recall"),
        ("f1", "F1"),
    )
    for scope, scope_title in (
        ("node", "Node"),
        ("edge", "Edge"),
        ("action", "Action"),
        ("all", "Overall"),
    )
}

# X metrics: Model Size is the only predictor, repeated once per Y-metric
# group (row 1: precision, row 2: recall, row 3: F1). Each row is its own
# list object.
PLOT_X_METRICS_LIST = [['Model Size (B)'] for _ in range(3)]

# Shared setup options for the VirtualHome goal interpretation analysis.
# These are handed unchanged to the plotting helper; the notes below record
# the intent of each option (the exact semantics live in that helper).
VIRTUALHOME_DEFAULT_SETUP_KWARGS = dict(
    # Intended to min-max normalize the precision/recall/F1 scores to [0, 1].
    y_metric_process_funcs="minmax_norm",
    # Group markers by model family.
    df_groupby='Model Family',
    # Robust regression, chosen to limit the influence of outliers.
    reg_method="robust",
    # Intended as the Huber-loss delta, sized for a target normalized to [0, 1].
    reg_kwargs=dict(delta=1.0),
    # Model Size is the only predictor, so no PCA and no imputation.
    apply_pca=False,
    apply_imputation=False,
    # Style every model family present in the filtered dataset rather than
    # relying on the helper's default family list.
    stylize_model_family=virtualhome_eval['Model Family'].unique().tolist(),
)

# Flatten the metric groups: configuration is keyed by individual metric,
# not by group.
all_individual_metrics = [
    metric_name
    for metric_group in VIRTUALHOME_Y_METRICS
    for metric_name in metric_group
]

# Per-metric configuration for the VirtualHome goal interpretation metrics.
# Every metric gets the same settings, but each entry is built fresh so the
# nested dict and list are not shared between metrics.
VIRTUALHOME_EVAL_SETUP_SPECIFIC_KWARGS = {
    metric_name: {
        # Raw metric range: the scores are recorded as percentages (0-100).
        'y_metric_range': (0.0, 100.0),
        # Train/test split by model size, cut at 70B parameters.
        "split_method": "cutoff_by_Model Size (B)",
        "cutoff_threshold": 70,
        # Small margin around the plotted range; presumably in normalized
        # [0, 1] units (the raw-percentage equivalent would be [-5, 105])
        # -- confirm against the plotting helper.
        "plot_adjust_kwargs": {"ylim": [-0.05, 1.05]},
    }
    for metric_name in all_individual_metrics
}

# Generate scaling plots for VirtualHome goal interpretation.
# One figure is produced per metric group (precision, recall, F1); each call
# passes Model Size as the predictor for that group's four Y metrics.

# All figures are written to the same directory, so create it once up front.
# exist_ok=True already tolerates an existing directory, so no separate
# os.path.exists check is needed.
plot_output_dir = f'plots/{dataset_name}/{eaval_type}'
os.makedirs(plot_output_dir, exist_ok=True)

# Keep every figure so it stays available after the loop.
all_figures = []

for (x_metrics, y_metrics) in zip(PLOT_X_METRICS_LIST, VIRTUALHOME_Y_METRICS):
    fig = plot_multi_scaling_predictions(
        virtualhome_eval, y_metrics, [x_metrics],
        VIRTUALHOME_DEFAULT_SETUP_KWARGS,
        y_metric_specific_kwargs=VIRTUALHOME_EVAL_SETUP_SPECIFIC_KWARGS,
        filter_model_family=None,  # Include all model families
        ymetric2title_map=VIRTUALHOME_Y_METRIC_MAP,
        plot_legend=True, legend_nrow=2,
    )

    all_figures.append(fig)

    # Save instead of showing interactively; the file name lists the group's
    # metrics joined by underscores.
    fig.savefig(f'{plot_output_dir}/scaling_{"_".join(y_metrics)}.png', dpi=300, bbox_inches='tight')