import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pickle
import json
import re
import copy

from utils import *

dataset_name = "virtualhome"  # virtualhome #behavior
eaval_type = "action_sequencing"

# Load the VirtualHome action-sequencing results (including per-model FLOPs).
# The location can be overridden with the VIRTUALHOME_EVAL_PATH environment
# variable so the script runs outside the original author's machine; the
# default is the original hard-coded path.
# virtualhome_eval_path = f"./eval_results/{dataset_name}_{eaval_type}_results.csv"
virtualhome_eval_path = os.environ.get(
    "VIRTUALHOME_EVAL_PATH",
    f"/Users/qinjielin/Downloads/NWU/25corl/corl_ws/ObsScaling/eval_results/{dataset_name}_{eaval_type}_results_with_flops.csv",
)
virtualhome_eval = pd.read_csv(virtualhome_eval_path)

# Keep only models with a known training-compute value ('FLOPs (1E21)'),
# since FLOPs is the x-axis predictor for the scaling plots.
virtualhome_eval = virtualhome_eval.dropna(subset=['FLOPs (1E21)'])
print(f"After filtering FLOPs: {len(virtualhome_eval)} models with meaningful performance data")

# Metric columns that the CSV stores as percentages (0-100). They are laid out
# as three rows of four, mirroring the metric grouping used for plotting.
columns_to_convert_to_percentage = [
    ['task_success_rate', 'execution_success_rate', 'total_goal', 'state_goal'],
    ['relation_goal', 'action_goal', 'parsing_error', 'hallucination_error'],
    ['wrong_order_error', 'missing_step_error', 'additional_step_error', 'affordance_error'],
]

# Rescale each of those columns that is actually present in the data from
# percentages to fractions in [0, 1]; absent columns are skipped silently.
for _column in [
    name
    for row in columns_to_convert_to_percentage
    for name in row
    if name in virtualhome_eval.columns
]:
    virtualhome_eval[_column] = virtualhome_eval[_column] / 100.0

########################################################
# Divide into base and instruction-tuned ("it") models by model name.
# virtualhome_eval = virtualhome_eval[virtualhome_eval['Model'].str.contains('llama|Qwen', case=False, na=False)]
#
# Base models: no chat/instruct marker. "it" is matched only as a standalone
# token (not preceded or followed by a letter), so suffixes such as "-it" or
# "_it" still exclude a model, but names that merely contain the letters "it"
# inside a word (e.g. "granite") are no longer dropped from the base set.
# Models whose name contains R1/Think but no chat/instruct marker end up in
# neither subset.
virtualhome_eval_base = virtualhome_eval[~virtualhome_eval['Model'].str.contains(r'Chat|Instruct|(?<![a-z])it(?![a-z])|R1|Think', case=False, na=False)]
virtualhome_eval_it = virtualhome_eval[virtualhome_eval['Model'].str.contains('Chat|Instruct|-it', case=False, na=False)]
# Report the sizes of the two subsets (previously this printed the unfiltered count).
print(f"After filtering Chat, Instruct, it: {len(virtualhome_eval_base)} base and {len(virtualhome_eval_it)} instruction-tuned models with meaningful performance data")

# Print each subset's model count, then one line per model sorted by family
# and then by name, framed by 80-character rules.
for _subset, _label in (
    (virtualhome_eval_base, "base LLM"),
    (virtualhome_eval_it, "instruction-tuned LLM"),
):
    print(f"\n📋 {len(_subset)} Models remaining for {_label}")
    print("=" * 80)
    virtualhome_eval_sorted = _subset.sort_values(['Model Family', 'Model'])
    for _model, _family in zip(virtualhome_eval_sorted['Model'], virtualhome_eval_sorted['Model Family']):
        print(f"  • {_model} (Family: {_family})")
    print("=" * 80)

# Model-identity columns: they keep their names in both subsets so the two
# frames can be joined on them. Every other column gets a ",base" or ",it"
# suffix so the two variants' metrics can coexist side by side.
_KEY_COLUMNS = ['Model', 'Model Family', 'FLOPs (1E21)', 'Model Size (B)', 'Pretraining Data Size (T)']

base_columns_rename = {
    column: column if column in _KEY_COLUMNS else f"{column},base"
    for column in virtualhome_eval_base.columns
}
virtualhome_eval_base = virtualhome_eval_base.rename(columns=base_columns_rename)

it_columns_rename = {
    column: column if column in _KEY_COLUMNS else f"{column},it"
    for column in virtualhome_eval_it.columns
}
virtualhome_eval_it = virtualhome_eval_it.rename(columns=it_columns_rename)

# Outer-join the renamed base and instruction-tuned frames on the shared
# model-identity columns. Instruction-tuned model names carry a
# Chat/Instruct/-it marker that base names lack, so rows generally do not pair
# up: a row holds either the ",base" or the ",it" metric columns, and NaN in
# the other set.
_merge_keys = ['Model', 'Model Family', 'FLOPs (1E21)', 'Model Size (B)', 'Pretraining Data Size (T)']
merged_virtualhome = virtualhome_eval_base.merge(
    virtualhome_eval_it,
    how='outer',
    on=_merge_keys,
)

print(f"Merged dataframe shape: {merged_virtualhome.shape}")
print(f"Columns in merged dataframe: {merged_virtualhome.columns.tolist()}")

# VirtualHome action-sequencing metrics, as three rows of four; each row is
# plotted as one group. Note the second row mixes goal metrics with the
# parsing and hallucination error rates.
VIRTUALHOME_Y_METRICS = [
    ['task_success_rate', 'execution_success_rate', 'total_goal', 'state_goal'],
    ['relation_goal', 'action_goal', 'parsing_error', 'hallucination_error'],
    ['wrong_order_error', 'missing_step_error', 'additional_step_error', 'affordance_error'],
]

# Legend/title label for every suffixed metric column, ordered
# "<metric>,base" then "<metric>,it" for each metric.
_VARIANT_TITLES = (("base", "Base Model"), ("it", "Instruction-tuned Model"))
VIRTUALHOME_Y_METRIC_MAP = {
    f"{name},{suffix}": title
    for row in VIRTUALHOME_Y_METRICS
    for name in row
    for suffix, title in _VARIANT_TITLES
}

# One fixed colour per variant for every metric column:
# '#1f77b4' (blue) for ",base" columns and '#ff7f0e' (orange) for ",it" columns.
_all_metric_names = [name for row in VIRTUALHOME_Y_METRICS for name in row]
VIRTUALHOME_BASE_COLOR_MAP = {f"{name},base": '#1f77b4' for name in _all_metric_names}
VIRTUALHOME_IT_COLOR_MAP = {f"{name},it": '#ff7f0e' for name in _all_metric_names}

# X-axis predictor(s) for each of the three metric rows. Every row uses
# training compute ('FLOPs (1E21)'); each row gets its own list object.
PLOT_X_METRICS_LIST = [['FLOPs (1E21)'] for _ in range(3)]

# Shared setup options for the VirtualHome scaling plots. Options tried
# earlier and left disabled:
#   **NONGSM_PCA_PREPROCESS_KWARGS (excludes GSM so it is not used to predict GSM)
#   y_metric_process_funcs="minmax_norm"
#   nonlinearity="sigmoid-parametric"
VIRTUALHOME_DEFAULT_SETUP_KWARGS = dict(
    # Group plot markers by model family.
    df_groupby='Model Family',
    # Robust regression to limit outlier influence; per the original note this
    # is a Huber loss with delta=1.0, sized for targets within [0, 1].
    reg_method="robust",
    reg_kwargs={"delta": 1.0},
    # The predictor is FLOPs alone, so no PCA or imputation step is applied.
    apply_pca=False,
    apply_imputation=False,
    # Style every model family present in the merged data instead of the default list.
    stylize_model_family=merged_virtualhome['Model Family'].unique().tolist(),
)

# Every suffixed metric column, ordered "<metric>,base" then "<metric>,it".
all_individual_metrics = [
    f"{metric_name},{variant}"
    for metric_row in VIRTUALHOME_Y_METRICS
    for metric_name in metric_row
    for variant in ("base", "it")
]

# Per-metric plot settings; each column gets its own fresh dict.
VIRTUALHOME_EVAL_SETUP_SPECIFIC_KWARGS = {
    column: {
        # Metrics were rescaled from percentages, so they lie in [0, 1].
        'y_metric_range': (0.0, 1.0),
        # Train/test split at a training-compute cutoff of 2000 in units of
        # 1E21 FLOPs (i.e. 2E24 FLOPs). Earlier attempts: 1500 on FLOPs, and
        # "cutoff_by_Model Size (B)" at 70.
        "split_method": "cutoff_by_FLOPs (1E21)",
        "cutoff_threshold": 2000,
        # Pad the y-limits slightly beyond [0, 1] so points at the edges stay visible.
        "plot_adjust_kwargs": {"ylim": [-0.05, 1.05]},
    }
    for column in all_individual_metrics
}

# Generate one base-vs-instruction-tuned comparison plot per VirtualHome metric.
# Each of the 12 metrics (3 rows x 4) gets its own figure. The figure plots the
# metric's ",base" and ",it" columns against that row's predictor(s) from
# PLOT_X_METRICS_LIST.

# Store all figures in a list
all_figures = []

# Every figure goes to the same directory: create it once, up front.
# exist_ok=True replaces the racy os.path.exists + os.makedirs check.
comparison_dir = f'plots/{dataset_name}/{eaval_type}/comparison'
os.makedirs(comparison_dir, exist_ok=True)

for x_metrics, y_metrics in zip(PLOT_X_METRICS_LIST, VIRTUALHOME_Y_METRICS):
    for metric in y_metrics:
        # The two suffixed columns produced by the ",base"/",it" renaming.
        pair = [f"{metric},base", f"{metric},it"]

        fig = plot_scaling_comparison_multi_metrics(
            merged_virtualhome,
            pair,                                   # 2 metrics: base and it variant
            [x_metrics],                            # this row's predictor(s), e.g. [['FLOPs (1E21)']]
            VIRTUALHOME_DEFAULT_SETUP_KWARGS,       # General setup
            y_metric_specific_kwargs=VIRTUALHOME_EVAL_SETUP_SPECIFIC_KWARGS,  # Per-metric settings
            ymetric2title_map=VIRTUALHOME_Y_METRIC_MAP,                        # Metric -> title mapping
            ymetric2color_map={**VIRTUALHOME_BASE_COLOR_MAP, **VIRTUALHOME_IT_COLOR_MAP},  # Combined color map
            plot_title=metric,                      # the metric name without its suffix
        )
        all_figures.append(fig)

        save_path = f'{comparison_dir}/{"-".join(pair)}.png'
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Saved comparison plot to {save_path}")

print(f"\nGenerated {len(all_figures)} comparison plots")
