import os
import sys
sys.path.append('/Users/qinjielin/Downloads/NWU/25corl/corl_ws/ObsScaling')
from utils import *

import matplotlib.pyplot as plt
import numpy as np  # Add this import back

# --- Experiment configuration -------------------------------------------------
dataset_name = 'virtualhome'  # other option noted in the original: 'behavior'
eval_type = 'action_sequencing'
df = pd.read_csv(f'./eval_results/{dataset_name}_{eval_type}_results_with_flops_and_openllm.csv')

# Y-axis benchmark metrics; one correlation plot is produced per entry.
# ALL_BENCHMARK_METRIC_LIST = ['all_f1', 'node_f1', 'edge_f1', 'action_f1']
ALL_BENCHMARK_METRIC_LIST = ['task_success_rate', 'execution_success_rate']
# ALL_BENCHMARK_METRIC_LIST = ['execution_success_rate']
# X-axis column. It is also the column whose missing values are dropped below,
# so changing it here keeps the NaN filter in sync.
x_metric_name = 'FLOPs (1E21)'


# Keep only model families with at least 3 models (counts taken on the raw
# 'Model Family' labels, before any reassignment).
family_counts = df['Model Family'].value_counts()
families_with_min_models = family_counts[family_counts >= 3].index.tolist()
df = df[df['Model Family'].isin(families_with_min_models)]

# Drop rows that lack the x-axis value or any of the benchmark metrics.
df = df.dropna(subset=[x_metric_name, *ALL_BENCHMARK_METRIC_LIST])

# Remove instruction-tuned / chat models and a few explicitly excluded models.
# The pattern is a regex alternation matched as a substring of 'Model'.
# na=True makes rows with a missing model name count as "matched", so they are
# dropped instead of making the boolean negation below raise.
# df = df[~df['Model'].str.contains('-it|Chat|instruct|Instruct|Yi-Coder|Falcon3|Thinking|deepseek-ai/DeepSeek-R1-Distill-Qwen-32B')]
EXCLUDED_MODEL_PATTERN = 'Chat|Instruct|Yi-Coder|Falcon3|Thinking|DeepSeek-R1-Distill-Qwen-32B|phi-1_5|phi-4'
df = df[~df['Model'].str.contains(EXCLUDED_MODEL_PATTERN, na=True)]

# Refine the 'Model Family' labels (split DeepSeek-R1 and Gemma variants)
def reassign_model_family(row):
    """Return the refined model family for one results row.

    ``row`` only needs to support ``row['Model']`` and
    ``row['Model Family']``. For example, it can be a DataFrame row passed by
    ``df.apply(..., axis=1)``, or a plain dict.

    Rules, applied in this order:
      * Model names containing 'DeepSeek-R1' map to
        'DeepSeek-R1-Distill-Llama' (Distill-Llama models),
        'DeepSeek-R1-Distill' (Distill-Qwen models), or 'DeepSeek-R1'
        (anything else). The original family is NOT kept for these rows.
      * 'google/gemma-2' and 'google/gemma-3' models are split into
        base and '-it' families, decided by an '-it' substring in the name.
      * All other rows keep their existing 'Model Family'.
    """
    name = row['Model']
    fallback_family = row['Model Family']

    if 'DeepSeek-R1' in name:
        if 'Distill-Llama' in name:
            return 'DeepSeek-R1-Distill-Llama'
        if 'Distill-Qwen' in name:
            return 'DeepSeek-R1-Distill'
        return 'DeepSeek-R1'

    # (name prefix, base family, instruction-tuned family), checked in order.
    gemma_generations = (
        ('google/gemma-2', 'Gemma-2', 'Gemma-2-it'),
        ('google/gemma-3', 'Gemma-3', 'Gemma-3-it'),
    )
    for prefix, base_family, it_family in gemma_generations:
        if prefix in name:
            return it_family if '-it' in name else base_family

    return fallback_family

# Relabel every row's family using the DeepSeek-R1 / Gemma refinement rules.
df['Model Family'] = df.apply(reassign_model_family, axis=1)

# Relabelling can split a family into smaller ones, so the
# "at least 3 models per family" rule is enforced again on the new labels.
family_counts = df['Model Family'].value_counts()
families_with_min_models = family_counts[family_counts.ge(3)].index.tolist()
df = df.loc[df['Model Family'].isin(families_with_min_models)]

# Full list of families considered at some point (kept for reference):
# plot_families = [
#                 'Yi', 'Granite', 'DeepSeek-Coder', 
#                  'Llama-3', 'Gemma-2', 'Qwen3', 'Qwen2.5', 
#                  'DeepSeek-R1', 'DeepSeek-R1-Distill-Llama', 'DeepSeek-R1-Distill-Qwen',
#                  'Qwen1.5', 'phi', 'Gemma-3', 'falcon', 'Gemma', 
#                  'Qwen', 'Llama', 'Baichuan', 'Llama-2',
#                  'Gemma-2-it', 'Gemma-3-it', 'Gemma-it', 'Gemma-2', 'Gemma-3', 'Gemma',
#                 ]
# Families that actually appear in the plots.
plot_families = [
    'Qwen3', 'Qwen', 'Llama-2', 'Llama-3', 'Qwen2.5',
    'DeepSeek-R1-Distill-Qwen', 'phi', 'Gemma-2-it', 'Gemma-3-it',
    'Gemma-2', 'DeepSeek-R1-Distill',
]
df = df.loc[df['Model Family'].isin(plot_families)].sort_values('Model')

# Print the surviving models grouped by family (families and models sorted).
for family, members in df.groupby('Model Family'):
    member_names = sorted(members['Model'].unique())
    print(f"\n{family} ({len(member_names)} models):")
    for member_name in member_names:
        print(f"  - {member_name}")

# Output directory for the correlation plots. savefig does not create missing
# directories, so make sure it exists up front.
plot_dir = './plots/slides'
os.makedirs(plot_dir, exist_ok=True)

for y_metric_name in ALL_BENCHMARK_METRIC_LIST:
    # Explicit copy, so the in-place rescale below never writes through to `df`
    # (and pandas does not emit a chained-assignment warning).
    _base_llm_eval_pca = df.dropna(subset=[y_metric_name, x_metric_name]).copy()
    # Divide by 100 -- presumably converts a percentage to a fraction; confirm
    # against the CSV's units.
    _base_llm_eval_pca[y_metric_name] = _base_llm_eval_pca[y_metric_name] / 100
    _eval_base_model_with_flops_families = _base_llm_eval_pca['Model Family'].unique()

    print(f"{eval_type} {y_metric_name} {_eval_base_model_with_flops_families} models")

    sns.set_theme(style="darkgrid")

    fig = plot_linear_correlation(
        _base_llm_eval_pca, x_metric_name, y_metric_name,
        model_family_names=_eval_base_model_with_flops_families,
        log_x_metric=True, unified_plot=True,
    )

    # Put the right edge of the subplot area at 75% of the figure width.
    plt.subplots_adjust(right=0.75)

    # Save BEFORE plt.show(). With an interactive backend, show() blocks until
    # the window is closed and the figure is then discarded, so a savefig
    # afterwards writes a blank image.
    out_path = f'{plot_dir}/{dataset_name}_{eval_type}_{y_metric_name}_cor.png'
    plt.savefig(out_path)
    print(f"Saved plot to {out_path}")
    plt.show()
    plt.close()