"""
Plot the linear scaling curves comparing the scaling law of action sequencing with that of action sequencing under masked observation.
"""

import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pickle
import json
import re
import copy

from utils import *

# ---------------------------------------------------------------------------
# Experiment selection
# ---------------------------------------------------------------------------
dataset_name = "virtualhome"  # the other option noted by the author is "behavior"
# NOTE: "eaval_type" is a misspelling of "eval_type", but the name is used
# throughout the script, so it is kept as is.
eaval_type = "action_sequencing"
# Benchmark score columns. Further down they serve as merge keys and as the
# PCA / imputation metrics passed to the plotting setup.
pca_imputation_metrics = ['Average', 'BBH', 'MATH Lvl 5', 'GPQA', 'MUSR', 'MMLU-PRO', 'IFEval']

# ---------------------------------------------------------------------------
# Load the two result tables: the regular run and the "v4" run (the masked
# variant that the plots compare against).
# ---------------------------------------------------------------------------
_results_prefix = (
    "/Users/qinjielin/Downloads/NWU/25corl/corl_ws/ObsScaling/eval_results/"
    f"{dataset_name}_{eaval_type}"
)
df = pd.read_csv(_results_prefix + "_results_with_flops_and_openllm.csv")
masked_df = pd.read_csv(_results_prefix + "_v4_results_with_flops_and_openllm.csv")

# Keep only models that have an 'Average' benchmark score.
df = df.dropna(subset=['Average'])
masked_df = masked_df.dropna(subset=['Average'])
for _frame in (df, masked_df):
    print(f"After filtering Average: {len(_frame)} models with meaningful performance data")

# Drop models whose listed score columns are all zero or missing, i.e. keep a
# row as soon as one of those columns holds a non-zero value.
# (The list is named "f1_columns" but currently holds the execution success rate.)
f1_columns = ['execution_success_rate']
df = df[(df[f1_columns].fillna(0) != 0).any(axis=1)]
masked_df = masked_df[(masked_df[f1_columns].fillna(0) != 0).any(axis=1)]
for _frame in (df, masked_df):
    print(f"After filtering F1 columns: {len(_frame)} models with meaningful performance data")

# df = df[~df['Model'].str.contains('Chat|Instruct|it|R1|Think', case=False, na=False)]
# masked_df = masked_df[~masked_df['Model'].str.contains('Chat|Instruct|it|R1|Think', case=False, na=False)]
# df = df[df['Model'].isin(masked_df['Model'])]

# # Filter to only include models with valid Model Size data
# df = df.dropna(subset=['FLOPs (1E21)'])
# masked_df = masked_df.dropna(subset=['FLOPs (1E21)'])
# print(f"After filtering FLOPs in original df: {len(df)} models with meaningful performance data")
# print(f"After filtering FLOPs in masked df: {len(masked_df)} models with meaningful performance data")

# Task metrics are stored as percentages; rescale them to the 0-1 range.
# Columns ending in "_error" are additionally flipped to (100 - value) / 100,
# so that a higher value always means better.
columns_to_convert_to_percentage = [
    "task_success_rate",
    "state_goal",
    "relation_goal",
    "action_goal",
    "total_goal",
    "execution_success_rate",
    "parsing_error",
    "hallucination_error",
    "predicate_argument_number_error",
    "wrong_order_error",
    "missing_step_error",
    "affordance_error",
    "additional_step_error",
]
for _frame in (df, masked_df):
    for metric in columns_to_convert_to_percentage:
        # A table may lack some of these columns; skip those silently.
        if metric not in _frame.columns:
            continue
        is_error_metric = metric.endswith('_error')
        _frame[metric] = ((100 - _frame[metric]) if is_error_metric else _frame[metric]) / 100.0

# The benchmark scores are percentages as well; these columns must be present
# in both tables.
convert_metrics_to_rate = ['Average', 'BBH', 'MATH Lvl 5', 'GPQA', 'MUSR', 'MMLU-PRO', 'IFEval']
for _frame in (df, masked_df):
    _frame[convert_metrics_to_rate] = _frame[convert_metrics_to_rate] / 100
print(f"Converted {len(convert_metrics_to_rate)} metrics to success rates")


######################################################## merge the two dataframes ########################################################
# Columns describing the model itself (identity, size, compute, benchmark
# scores) are shared by both tables: they keep their names and act as join keys.
_shared_columns = [
    'Model', 'Model Family', 'FLOPs (1E21)', 'Model Size (B)', 'Pretraining Data Size (T)',
] + pca_imputation_metrics

# Every other column of the regular table gets a ",base" suffix ...
base_columns_rename = {
    col: col if col in _shared_columns else f"{col},base"
    for col in df.columns
}
df = df.rename(columns=base_columns_rename)

# ... and every other column of the masked table gets a ",masked" suffix.
it_columns_rename = {
    col: col if col in _shared_columns else f"{col},masked"
    for col in masked_df.columns
}
masked_df = masked_df.rename(columns=it_columns_rename)

# Outer join on the shared columns: one wide row per model, holding the
# ",base" and ",masked" metrics side by side. Models present in only one
# table are kept, with NaN for the missing side.
merged_df = df.merge(masked_df, on=_shared_columns, how='outer')

print(f"Merged dataframe shape: {merged_df.shape}")
print(f"Columns in merged dataframe: {merged_df.columns.tolist()}")

# Print the surviving models, ordered by family and then by name, together
# with their 'Average' benchmark score when one is available.
_separator = "=" * 80
print(f"\n📋 {len(merged_df)} Models remaining after filtering:")
print(_separator)
merged_df_sorted = merged_df.sort_values(['Model Family', 'Model'])
for _, row in merged_df_sorted.iterrows():
    if 'Average' not in row:
        # The merged table has no 'Average' column at all.
        average_status = "⚠️ No open LLM column"
    elif pd.isna(row['Average']):
        # The column exists but this model has no score.
        average_status = "❌ No open LLM data"
    else:
        average_status = f"✅ {row['Average']:.2f}"

    print(f"  • {row['Model']} (Family: {row['Model Family']}) {average_status}")
print(_separator)


######################################################## configuring the plotting ########################################################
# Y metrics for VirtualHome action sequencing: three groups of four metrics.
df_y_metrics = [
    ['task_success_rate', 'execution_success_rate', 'total_goal', 'state_goal'],
    ['relation_goal', 'action_goal', 'parsing_error', 'hallucination_error'],
    ['wrong_order_error', 'missing_step_error', 'additional_step_error', 'affordance_error'],
]

# All metric names in group order, used to build the lookup tables below.
_flat_y_metrics = [name for names in df_y_metrics for name in names]

# Display title for each suffixed column: the ",base" and ",masked" variants
# of every metric.
df_y_metric_map = {}
for _name in _flat_y_metrics:
    df_y_metric_map.update({
        f"{_name},base": "Base Model",
        f"{_name},masked": "Model with Decoder Masking",
    })

# Plot colours: blue for every base column, orange for every masked column.
df_base_color_map = {f"{_name},base": '#1f77b4' for _name in _flat_y_metrics}
df_masked_color_map = {f"{_name},masked": '#ff7f0e' for _name in _flat_y_metrics}

# X metrics (predictors): one entry per Y-metric group. All three groups use
# PC_METRIC_NUM_3 from utils (presumably a principal-component predictor set;
# confirm against utils).
df_x_metrics_list = [[PC_METRIC_NUM_3] for _ in range(3)]

# Benchmark scores that feed the PCA / imputation step.
# NOTE: the next assignment mutates the shared NONGSM_PCA_PREPROCESS_KWARGS
# dict imported from utils, not just a local copy.
pca_imputation_metrics = ['Average', 'BBH', 'MATH Lvl 5', 'GPQA', 'MUSR', 'MMLU-PRO', 'IFEval']
NONGSM_PCA_PREPROCESS_KWARGS['imputation_metrics'] = pca_imputation_metrics

# General setup shared by every plot: start from the utils preprocessing
# defaults, then layer the script-specific choices on top.
df_default_setup_kwargs = dict(NONGSM_PCA_PREPROCESS_KWARGS)
df_default_setup_kwargs.update({
    "pca_metrics": pca_imputation_metrics,
    # Reference family; the author also tried Gemma, Llama-2 and Qwen and
    # noted that Qwen1.5 and Gemma looked better than Llama.
    "ref_model_family": "Qwen1.5",
    "stylize_data": False,
    "nonlinearity": "sigmoid-parametric",

    # Optional Y normalisation, currently disabled:
    # "y_metric_process_funcs": "minmax_norm",

    # Group markers by model family.
    "df_groupby": 'Model Family',

    # Robust regression to limit the influence of outliers. Per the author's
    # note, delta=1.0 is the Huber-loss parameter, chosen for targets in [0, 1].
    "reg_method": "robust",
    "reg_kwargs": {"delta": 1.0},

    # PCA and imputation over the benchmark scores are both enabled.
    "apply_pca": True,
    "apply_imputation": True,

    # Split by model size with a threshold of 40 (presumably billions of
    # parameters, given the "Model Size (B)" column; confirm in utils).
    # Alternative kept by the author: "cutoff_by_FLOPs (1E21)" with threshold 2000.
    "split_method": "cutoff_by_Model Size (B)",
    "cutoff_threshold": 40,
})

# Every suffixed Y column, i.e. the ",base" and ",masked" variant of each
# metric, in group order.
all_individual_metrics = [
    f"{metric},{variant}"
    for group in df_y_metrics
    for metric in group
    for variant in ("base", "masked")
]

# Per-metric plotting configuration for the VirtualHome metrics. All metrics
# share the same settings, but each one gets its own dict instance.
df_eval_setup_specific_kwargs = {
    y_metric: {
        # Values were converted from percentages, so they lie in [0, 1].
        'y_metric_range': (0.0, 1.0),
        # Leave a small visual margin above and below that range.
        "plot_adjust_kwargs": {"ylim": [-0.05, 1.05]},
    }
    for y_metric in all_individual_metrics
}


# Generate the scaling comparison plots for VirtualHome action sequencing.
# One figure is produced per Y metric; each figure overlays the ",base" and
# ",masked" columns of that metric against the predictor configured for the
# metric's group in df_x_metrics_list.

# Every generated figure is kept so it can be inspected after the loop.
all_figures = []

# Create the output directory once, up front. exist_ok=True replaces the
# per-iteration "check, then create" sequence, which was redundant and could
# fail if the directory appeared between the check and the creation.
output_dir = f'plots/{dataset_name}/{eaval_type}/comparison_masked'
os.makedirs(output_dir, exist_ok=True)

for (x_metrics, y_metrics) in zip(df_x_metrics_list, df_y_metrics):
    # One [base, masked] column pair per metric of this group.
    group_metrics = [[f"{metric},base", f"{metric},masked"] for metric in y_metrics]

    for pair in group_metrics:
        # Title: the metric name(s) with the variant suffix stripped.
        # dict.fromkeys de-duplicates while keeping a deterministic order
        # (a set would not guarantee the order).
        title_names = dict.fromkeys(
            m.replace(',base', '').replace(',masked', '') for m in pair
        )

        fig = plot_scaling_comparison_multi_metrics(
            merged_df,                    # merged wide table (base + masked columns)
            pair,                         # the two Y columns: [<metric>,base, <metric>,masked]
            x_metrics,                    # predictor list for this metric group
            df_default_setup_kwargs,      # general setup
            y_metric_specific_kwargs=df_eval_setup_specific_kwargs,  # per-metric settings
            ymetric2title_map=df_y_metric_map,                        # column -> legend title
            ymetric2color_map={**df_base_color_map, **df_masked_color_map},  # combined colour map
            plot_title=', '.join(title_names),
        )
        all_figures.append(fig)

        # Save before showing: with an interactive backend, plt.show() blocks
        # and the window may be resized or closed, so saving first keeps the
        # file independent of what happens in the window.
        output_path = f'{output_dir}/{"_".join(pair)}.png'
        fig.savefig(output_path, dpi=300, bbox_inches='tight')
        print(f"Saved comparison plot to {output_path}")

        plt.show()

print(f"\nGenerated {len(all_figures)} comparison plots")
