import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pickle
import json
import re
import copy

from utils import *

# Reference table of base-LLM benchmark results (provided by utils).
base_llm_benchmark_eval = load_base_llm_benchmark_eval()

# --- Which evaluation results to post-process ---------------------------------
# dataset_name choices: "behavior", "virtualhome"
dataset_name = "behavior"
# eaval_type choices: "goal_interpretation", "goal_interpretation_v4", "action_sequencing"
eaval_type = "goal_interpretation"

# NOTE(review): this input path is absolute and machine-specific, while the
# output written at the end of the script is relative to the working directory.
# Running on another machine requires editing this path.
_eai_results_csv = (
    "/Users/qinjielin/Downloads/NWU/25corl/corl_ws/ObsScaling/eval_results/"
    f"{dataset_name}_{eaval_type}_results.csv"
)
eai_eval = pd.read_csv(_eai_results_csv)

# Attach pretraining data size and training FLOPs from the benchmark table.
# A left join keeps every evaluated model; models unknown to the benchmark
# table get NaN in these two columns.
_scaling_columns = ['Model', 'Pretraining Data Size (T)', 'FLOPs (1E21)']
eai_eval = eai_eval.merge(
    base_llm_benchmark_eval[_scaling_columns],
    on='Model',
    how='left',
)

# Manually curated pretraining data sizes, in trillions of tokens (T), keyed by
# the exact model id found in the 'Model' column. These values are written into
# 'Pretraining Data Size (T)' (overriding anything merged from
# base_llm_benchmark_eval) and used to recompute 'FLOPs (1E21)'.
data_info = {
    "meta-llama/Llama-4-Scout-17B-16E-Instruct": 40,
    "meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8": 22,
    "baichuan-inc/Baichuan-7B": 1.2,
    "baichuan-inc/Baichuan-13B-Base": 1.4,
    # Gemma 3
    "google/gemma-3-1b-pt": 2,
    "google/gemma-3-1b-it": 2,
    "google/gemma-3-4b-it": 4,
    "google/gemma-3-4b-pt": 4,
    "google/gemma-3-12b-it": 12,
    "google/gemma-3-12b-pt": 12,
    "google/gemma-3-27b-it": 14,
    "google/gemma-3-27b-pt": 14,
    # Gemma 2
    "google/gemma-2-2b": 2,
    "google/gemma-2-2b-it": 2,
    "google/gemma-2-9b": 8,
    "google/gemma-2-9b-it": 8,
    "google/gemma-2-27b": 13,
    "google/gemma-2-27b-it": 13,
    # Gemma 1 / 1.1 share the same base pretraining: 2B on 3T, 7B on 6T tokens.
    # (gemma-7b-it / gemma-2b-it previously had their values swapped.)
    'google/gemma-1.1-2b-it': 3,
    'google/gemma-1.1-7b-it': 6,
    "google/gemma-7b-it": 6,
    "google/gemma-2b-it": 3,
    # Llama 3.x
    "meta-llama/Llama-3.3-70B-Instruct": 15,
    "meta-llama/Llama-3.2-1B": 9,
    "meta-llama/Llama-3.2-1B-Instruct": 9,
    "meta-llama/Llama-3.2-3B-Instruct": 9,
    "meta-llama/Meta-Llama-3-70B-Instruct": 15,
    "meta-llama/Meta-Llama-3-8B-Instruct": 15,
    "meta-llama/Llama-3.1-70B": 15,
    "meta-llama/Llama-3.1-70B-Instruct": 15,
    "meta-llama/Llama-3.1-8B": 15,
    # Baichuan 2
    "baichuan-inc/Baichuan2-7B-Base": 2.6,
    "baichuan-inc/Baichuan2-7B-Chat": 2.6,
    # Yi family
    "01-ai/Yi-Coder-1.5B-Chat": 2.4,
    "01-ai/Yi-Coder-1.5B": 2.4,
    "01-ai/Yi-Coder-9B-Chat": 2.4,
    "01-ai/Yi-Coder-9B": 2.4,
    "01-ai/Yi-1.5-6B-Chat": 3.6,
    "01-ai/Yi-1.5-6B": 3.6,
    "01-ai/Yi-1.5-9B": 3.6,
    "01-ai/Yi-1.5-34B-Chat": 3.6,
    "01-ai/Yi-1.5-34B": 3.6,
    "01-ai/Yi-6B": 3.1,
    "01-ai/Yi-9B": 3.1,
    "01-ai/Yi-34B": 3.1,
    "01-ai/Yi-6B-Chat": 3.1,
    "01-ai/Yi-9B-Chat": 3.1,
    "01-ai/Yi-34B-Chat": 3.1,
    # IBM Granite 3.x
    'ibm-granite/granite-3.2-2b-instruct': 12,
    'ibm-granite/granite-3.1-8b-instruct': 12,
    'ibm-granite/granite-3.3-2b-base': 12,
    'ibm-granite/granite-3.1-2b-base': 12,
    'ibm-granite/granite-3.1-2b-instruct': 12,
    'ibm-granite/granite-3.2-8b-instruct': 12,
    'ibm-granite/granite-3.1-8b-base': 12,
    'ibm-granite/granite-3.3-2b-instruct': 12,
    'ibm-granite/granite-3.3-8b-instruct': 12,
    'ibm-granite/granite-3.3-8b-base': 12,
    # Qwen 3
    'Qwen/Qwen3-14B': 36,
    'Qwen/Qwen3-32B': 36,
    'Qwen/Qwen3-8B': 36,
    'Qwen/Qwen3-0.6B': 36,
    'Qwen/Qwen3-1.7B': 36,
    'Qwen/Qwen3-4B': 36,
    'Qwen/Qwen3-235B-A22B-Thinking-2507': 36,
    # Qwen 2.5
    'Qwen/Qwen2.5-0.5B': 18,
    'Qwen/Qwen2.5-1.5B': 18,
    'Qwen/Qwen2.5-3B': 18,
    'Qwen/Qwen2.5-7B': 18,
    'Qwen/Qwen2.5-14B': 18,
    'Qwen/Qwen2.5-32B': 18,
    'Qwen/Qwen2.5-72B': 18,
    'LGAI-EXAONE/EXAONE-3.5-32B-Instruct': 6.5,
    'moonshotai/Kimi-K2-Instruct': 15.5,
    # DeepSeek; R1 distills inherit the data size of their Llama / Qwen2.5 base
    "deepseek-ai/DeepSeek-V3": 14.8,
    "deepseek-ai/DeepSeek-R1": 14.8,
    'deepseek-ai/DeepSeek-R1-Distill-Llama-70B': 15,
    'deepseek-ai/DeepSeek-R1-Distill-Llama-8B': 15,
    'deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B': 18,
    'deepseek-ai/DeepSeek-R1-Distill-Qwen-14B': 18,
    'deepseek-ai/DeepSeek-R1-Distill-Qwen-32B': 18,
    'deepseek-ai/DeepSeek-R1-Distill-Qwen-7B': 18,
    # Falcon
    "tiiuae/falcon-11B": 5,
    "tiiuae/Falcon3-7B-Base": 14,
    "tiiuae/Falcon3-10B-Base": 14,
    # Phi
    "microsoft/Phi-3-mini-4k-instruct": 4.9,
    "microsoft/Phi-3-mini-128k-instruct": 4.9,
    "microsoft/Phi-3-medium-4k-instruct": 4.8,
    "microsoft/Phi-3-medium-128k-instruct": 4.8,
    "microsoft/phi-4": 9.8,
    # DeepSeek Coder
    "deepseek-ai/deepseek-coder-33b-instruct": 2.0,
    "deepseek-ai/deepseek-coder-7b-base-v1.5": 2.0,
    "deepseek-ai/deepseek-coder-7b-instruct-v1.5": 2.0,
    "deepseek-ai/deepseek-coder-6.7b-instruct": 2.0,
    "deepseek-ai/deepseek-coder-1.3b-instruct": 2.0,
    "meta-llama/Llama-3.2-3B": 9.0,
    "Qwen/Qwen1.5-110B": 7.0,  # from google
    "LGAI-EXAONE/EXAONE-Deep-32B": 6.5,
}

# Overwrite the merged data size with the manual values from data_info and
# recompute training compute as 6 * N * D for those rows. Because D is in
# trillions of tokens (1e12) and N = 'Model Size (B)' is in billions (1e9),
# 6 * N * D is directly expressed in units of 1e21 FLOPs. Rows whose model is
# not in data_info keep whatever the merge produced.
_manual_data_size = eai_eval['Model'].map(data_info)
_has_manual_size = _manual_data_size.notna()

eai_eval.loc[_has_manual_size, 'Pretraining Data Size (T)'] = _manual_data_size[_has_manual_size]
eai_eval.loc[_has_manual_size, 'FLOPs (1E21)'] = (
    6 * _manual_data_size[_has_manual_size]
    * eai_eval.loc[_has_manual_size, 'Model Size (B)']
)

# Report every model that still has no pretraining data size after the merge
# and the manual overrides, grouped by model family for easier triage.
models_with_nan_data_size = eai_eval.loc[
    eai_eval['Pretraining Data Size (T)'].isna(), ['Model', 'Model Family']
]

if not models_with_nan_data_size.empty:
    print(f"\n⚠️  WARNING: {len(models_with_nan_data_size)} models without Pretraining Data Size:")
    print("=" * 80)

    # Sort by Model Family first, then by Model name
    models_with_nan_data_size_sorted = models_with_nan_data_size.sort_values(['Model Family', 'Model'])

    for _, row in models_with_nan_data_size_sorted.iterrows():
        print(f"  • {row['Model']} (Family: {row['Model Family']})")

    print("=" * 80)
    print("These models may not display properly in scaling plots.")
    print("Consider checking HuggingFace model names or adding manual size mappings.")
else:
    # The check above is on 'Pretraining Data Size (T)', so the success
    # message names that column rather than model (parameter) size.
    print(f"\n✅ All {len(eai_eval)} models have Pretraining Data Size.")

# Write the enriched table; the path is relative to the current working directory.
output_csv = f"./eval_results/{dataset_name}_{eaval_type}_results_with_flops.csv"
# DataFrame.to_csv does not create missing parent directories.
os.makedirs(os.path.dirname(output_csv), exist_ok=True)
eai_eval.to_csv(output_csv, index=False)
# Announce only after the write succeeded.
print(f"saved to {output_csv}")