import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pickle
import json
import re
import copy

from utils import *

dataset_name = "virtualhome"  # alternative value used by the author: "behavior"
# NOTE: the identifier is spelled "eaval_type" (sic); the spelling is kept
# because the plot-saving code later in this script references it.
eaval_type = "goal_interpretation"

# Directory holding the evaluation-result CSVs. The default is the original,
# machine-specific location; set OBS_SCALING_EVAL_RESULTS_DIR to point the
# script at a different directory without editing the source.
eval_results_dir = os.environ.get(
    "OBS_SCALING_EVAL_RESULTS_DIR",
    "/Users/qinjielin/Downloads/NWU/25corl/corl_ws/ObsScaling/eval_results",
)

# Load the two result tables for the selected dataset / evaluation type:
# the plain results file into `df` and the "_v4_" results file into `masked_df`.
df = pd.read_csv(
    os.path.join(eval_results_dir, f"{dataset_name}_{eaval_type}_results_with_flops.csv")
)
masked_df = pd.read_csv(
    os.path.join(eval_results_dir, f"{dataset_name}_{eaval_type}_v4_results_with_flops.csv")
)

# Disabled filter: would drop rows whose 'Model' matches
# Chat|Instruct|it|R1|Think (case-insensitive).
# df = df[~df['Model'].str.contains('Chat|Instruct|it|R1|Think', case=False, na=False)]
# masked_df = masked_df[~masked_df['Model'].str.contains('Chat|Instruct|it|R1|Think', case=False, na=False)]


# Keep only the rows that have a value in the 'FLOPs (1E21)' column.
df = df.dropna(subset=['FLOPs (1E21)'])
masked_df = masked_df.dropna(subset=['FLOPs (1E21)'])
print(f"After filtering FLOPs in original df: {len(df)} models with meaningful performance data")
print(f"After filtering FLOPs in masked df: {len(masked_df)} models with meaningful performance data")

# Divide the precision / recall / F1 columns by 100 in both tables.
# (Assumes the CSVs store these metrics as percentages in [0, 100] — TODO confirm.)
columns_to_convert_to_percentage = [
    'node_precision', 'edge_precision', 'action_precision', 'all_precision',
    'node_recall', 'edge_recall', 'action_recall', 'all_recall',
    'node_f1', 'edge_f1', 'action_f1', 'all_f1',
]
for frame in (df, masked_df):
    for metric in columns_to_convert_to_percentage:
        if metric not in frame.columns:
            # A table may lack some of the metrics; skip those silently.
            continue
        frame[metric] = frame[metric] / 100.0

######################################################## merge the two dataframes########################################################
# Columns that identify a model. They keep their names in both tables and
# serve as the join keys for the merge.
merge_key_columns = ['Model', 'Model Family', 'FLOPs (1E21)', 'Model Size (B)', 'Pretraining Data Size (T)']

# Baseline table: every non-key column gets a ",base" suffix
# (key columns map to themselves).
base_columns_rename = {
    col: (col if col in merge_key_columns else f"{col},base")
    for col in df.columns
}
df = df.rename(columns=base_columns_rename)

# Masked table: every non-key column gets a ",masked" suffix.
# (The variable name still says "it"; the suffix actually applied is ",masked".)
it_columns_rename = {
    col: (col if col in merge_key_columns else f"{col},masked")
    for col in masked_df.columns
}
masked_df = masked_df.rename(columns=it_columns_rename)

# Outer-join the two tables on the key columns, giving a wide table with
# separate ",base" and ",masked" columns per metric. With how='outer', a row
# present in only one table is kept and the other table's columns are NaN.
merged_df = pd.merge(df, masked_df, on=merge_key_columns, how='outer')

print(f"Merged dataframe shape: {merged_df.shape}")
print(f"Columns in merged dataframe: {merged_df.columns.tolist()}")

########################################################configuring the plotting########################################################
# Y metrics arranged as 3 groups of 4: one group per statistic
# (precision, recall, F1), each covering the node / edge / action / all scopes.
df_y_metrics = [
    [f"{scope}_{stat}" for scope in ('node', 'edge', 'action', 'all')]
    for stat in ('precision', 'recall', 'f1')
]

# Flat view of the groups, used to build the per-column lookup tables below.
all_y_metrics = [metric for group in df_y_metrics for metric in group]

# Display title for each suffixed column: ",base" vs ",masked".
df_y_metric_map = {}
for metric in all_y_metrics:
    df_y_metric_map.update({
        f"{metric},base": "Base Model",
        f"{metric},masked": "Model with Decoder Masking",
    })

# Colours: blue for every ",base" column, orange for every ",masked" column.
df_base_color_map = {f"{metric},base": '#1f77b4' for metric in all_y_metrics}
df_masked_color_map = {f"{metric},masked": '#ff7f0e' for metric in all_y_metrics}

# X metrics: one independent ['FLOPs (1E21)'] list per metric group.
df_x_metrics_list = [['FLOPs (1E21)'] for _ in df_y_metrics]

# Shared setup dictionary handed to every comparison plot.
# Model families present in the merged table, in order of first appearance.
model_families = merged_df['Model Family'].unique().tolist()

df_default_setup_kwargs = {
    # Disabled option kept from the original script:
    # **NONGSM_PCA_PREPROCESS_KWARGS,  # exclude GSM to avoid making the task trivial, e.g., using GSM to predict GSM

    # Disabled option: metric normalization.
    # "y_metric_process_funcs": "none",

    # Column used for grouping (the original note says markers are grouped by family).
    "df_groupby": 'Model Family',

    # Regression settings. Per the original author's note, "robust" is a
    # Huber-loss regression and delta=1.0 suits targets normalized to [0, 1].
    "reg_method": "robust",
    "reg_kwargs": {"delta": 1.0},

    # PCA and imputation are switched off; the plots further down pass a
    # single predictor column ('FLOPs (1E21)').
    "apply_pca": False,
    "apply_imputation": False,

    # Pass every family found in this dataset instead of the helper's default list.
    "stylize_model_family": model_families,

    # Disabled option:
    # "nonlinearity": "sigmoid-parametric",

}

# Per-metric settings, keyed by the suffixed column name.

# Every metric in both variants, ordered metric by metric: "<m>,base", "<m>,masked".
all_individual_metrics = [
    f"{metric},{variant}"
    for group in df_y_metrics
    for metric in group
    for variant in ("base", "masked")
]

# Each metric gets its own freshly built settings dict (nothing is shared
# between entries):
#   - y_metric_range: metrics live in [0.0, 1.0] after the percentage conversion
#   - split_method / cutoff_threshold: split on 'Model Size (B)' at 70
#     (original note: train/test split at 70B parameters)
#   - plot_adjust_kwargs: y-limits padded slightly beyond [0, 1]
# Disabled alternative kept from the original:
#     "split_method": "cutoff_by_FLOPs (1E21)", "cutoff_threshold": 10000
df_eval_setup_specific_kwargs = {
    y_metric: {
        'y_metric_range': (0.0, 1.0),
        "split_method": "cutoff_by_Model Size (B)",
        "cutoff_threshold": 70,
        "plot_adjust_kwargs": {"ylim": [-0.05, 1.05]},
    }
    for y_metric in all_individual_metrics
}

# Generate one comparison figure per metric: the ",base" column against the
# ",masked" column, with the group's x metric (here 'FLOPs (1E21)') as predictor.
# With 3 groups of 4 metrics this yields 12 figures, each saved as a PNG.

# Every figure returned by the plotting helper, in generation order.
all_figures = []

# Create the output directory once, up front. exist_ok=True replaces the
# check-then-create pattern, which can fail if the directory appears in between.
comparison_dir = f'plots/{dataset_name}/{eaval_type}/comparison'
os.makedirs(comparison_dir, exist_ok=True)

for (x_metrics, y_metrics) in zip(df_x_metrics_list, df_y_metrics):
    for metric in y_metrics:
        # The two columns compared in this figure.
        pair = [f"{metric},base", f"{metric},masked"]

        # plot_scaling_comparison_multi_metrics is not defined in this file;
        # it is expected to come from the `utils` star import.
        fig = plot_scaling_comparison_multi_metrics(
            merged_df,                    # merged wide table (base + masked columns)
            pair,                         # the 2 columns to compare
            [list(x_metrics)],            # predictor list for this group (fresh copy)
            df_default_setup_kwargs,      # general setup
            y_metric_specific_kwargs=df_eval_setup_specific_kwargs,  # per-metric settings
            ymetric2title_map=df_y_metric_map,                        # column -> title
            ymetric2color_map={**df_base_color_map, **df_masked_color_map},  # combined colours
            plot_title=metric,            # the metric name without its suffix
        )
        all_figures.append(fig)

        # Save before showing, so the file is written even if the interactive
        # window is closed or the run is interrupted during plt.show().
        file_stem = "_".join(pair)
        output_path = f'{comparison_dir}/{file_stem}.png'
        fig.savefig(output_path, dpi=300, bbox_inches='tight')
        print(f"Saved comparison plot to {output_path}")
        plt.show()

print(f"\nGenerated {len(all_figures)} comparison plots")
