#%%
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import re
from matplotlib.ticker import LogLocator
from scipy.interpolate import interp1d
from statsmodels.nonparametric.smoothers_lowess import lowess

# Per-seed results; later cells read its 'model' and 'gl_value' columns.
df = pd.read_csv(filepath_or_buffer='gl_results_5_seeds.csv')

# Parse each model's parameter count (in billions) from its name; used as the x-axis.
def extract_size(model_name):
    """Extract the parameter count, in billions, from a model name.

    Recognised size tokens (unit suffix is case-insensitive):
      * ``"7B"`` / ``"1.5B"`` -> 7 / 1.5
      * ``"350M"``            -> 0.35 (millions converted to billions)
      * ``"8x7B"``            -> 56 (first factor times second factor)

    Names containing ``"phi-4"`` are special-cased to 14 because they carry no
    size token. When several tokens match, the first one in the name wins.

    Args:
        model_name: Model identifier string, e.g. ``"Qwen2.5-1.5B-Instruct"``.

    Returns:
        Size in billions of parameters (an int when the value is whole, a
        float otherwise), or 0 when no size token is found.
    """
    # Special case for phi4: its name carries no explicit size token.
    if 'phi-4' in model_name.lower():
        return 14

    # Decimal sizes ("1.5B", "0.5B") must be matched as one token; an
    # integer-only pattern would pick up just the trailing "5B" and return 5.
    size_match = re.search(r'(\d+(?:\.\d+)?)(?:x(\d+(?:\.\d+)?))?([BbMm])', model_name)
    if not size_match:
        return 0  # Default for models without clear size indicator

    first, second, unit = size_match.groups()
    size = float(first) * float(second) if second else float(first)
    if unit in 'Mm':
        size /= 1000  # Convert millions to billions (also for "NxM" tokens)
    # Keep whole numbers as ints, matching what the integer-only parser produced.
    return int(size) if size.is_integer() else size


# Parse the parameter count (billions) out of every row's model name.
df['model_size'] = [extract_size(name) for name in df['model']]
# %%
df
# Adjust model sizes for Llama 2 70B and Llama 3 70B
# def adjust_model_size(row):
#     if 'Llama-2' in row['model'] and row['model_size'] == 70:
#         return 71  # Add 1B
#     elif 'Llama' in row['model'] and '3' in row['model'] and row['model_size'] == 70:
#         return 69  # Subtract 1B
#     else:
#         return row['model_size']

# df['model_size'] = df.apply(adjust_model_size, axis=1)
# %%
# Collapse the seeds: one row per model holding the mean and standard deviation
# of gl_value; model_size is the same for all seeds of a model, so keep the first.
model_stats = df.groupby('model', as_index=False).agg(
    gl_mean=('gl_value', 'mean'),
    gl_std=('gl_value', 'std'),
    model_size=('model_size', 'first'),
)

# Log-log plot of mean grouping loss per model with ±1 std error bars.
fig, ax = plt.subplots(figsize=(10, 6))
ax.errorbar(
    model_stats['model_size'],
    model_stats['gl_mean'],
    yerr=model_stats['gl_std'],
    fmt='o',
    alpha=0.7,
    capsize=3,
)
ax.set_xlabel('Model Size (B parameters)')
ax.set_ylabel('Grouping Loss')
ax.set_title('Grouping Loss vs Model Size (Mean ± Std across 5 seeds)')
ax.set_xscale('log')
ax.set_yscale('log')
ax.grid(True, alpha=0.3)
plt.show()

# Extract model type (instruct vs base)
def extract_model_type(model_name):
    """Classify a model as ``'Instruct'`` or ``'Base'`` from its name.

    A name is ``'Instruct'`` when it contains ``"instruct"`` or ``"chat"``
    anywhere (case-insensitive), or has ``"it"`` as a standalone token
    delimited by non-alphanumerics (as in ``"gemma-2-9b-it"``). Everything
    else is ``'Base'``.

    Args:
        model_name: Model identifier string.

    Returns:
        ``'Instruct'`` or ``'Base'``.
    """
    model_lower = model_name.lower()
    # "it" must be a whole token: a plain substring test also fires on names
    # such as "...-bnb-4bit" or "meditron-7b" and mislabels base models.
    tokens = re.split(r'[^a-z0-9]+', model_lower)
    if 'instruct' in model_lower or 'chat' in model_lower or 'it' in tokens:
        return 'Instruct'
    return 'Base'

# Tag each row as 'Instruct' or 'Base' according to its model name.
df['model_type'] = [extract_model_type(name) for name in df['model']]
# plt.figure(figsize=(12, 8))
# families = df['family'].unique()
# colors = plt.cm.tab10(range(len(families)))

# for i, family in enumerate(families):
#     family_data = df[df['family'] == family]
    
#     # Plot base models with circles
#     base_data = family_data[family_data['model_type'] == 'Base']
#     if not base_data.empty:
#         plt.scatter(base_data['model_size'], base_data['grouping_loss'], 
#                     alpha=0.7, label=f'{family} (Base)', color=colors[i], marker='o')
    
#     # Plot instruct models with triangles
#     instruct_data = family_data[family_data['model_type'] == 'Instruct']
#     if not instruct_data.empty:
#         plt.scatter(instruct_data['model_size'], instruct_data['grouping_loss'], 
#                     alpha=0.7, label=f'{family} (Instruct)', color=colors[i], marker='^')

# plt.xlabel('Model Size (B parameters)')
# plt.ylabel('Grouping Loss')
# plt.title('Grouping Loss vs Model Size by LLM Family and Type')
# plt.xscale('log')
# plt.yscale('log')
# plt.legend()
# plt.grid(True, alpha=0.3)
# plt.show()
# %%
# Create scatter plot with family colors and model type symbols (circles and stars)
# plt.figure(figsize=(12, 8))
# families = df['family'].unique()
# colors = plt.cm.tab10(range(len(families)))

# for i, family in enumerate(families):
#     family_data = df[df['family'] == family]
    
#     # Plot base models with circles
#     base_data = family_data[family_data['model_type'] == 'Base']
#     if not base_data.empty:
#         plt.scatter(base_data['model_size'], base_data['grouping_loss'], 
#                     alpha=0.7, label=f'{family} (Base)', color=colors[i], marker='o')
    
#     # Plot instruct models with stars
#     instruct_data = family_data[family_data['model_type'] == 'Instruct']
#     if not instruct_data.empty:
#         plt.scatter(instruct_data['model_size'], instruct_data['grouping_loss'], 
#                     alpha=0.7, label=f'{family} (Instruct)', color=colors[i], marker='*')

# plt.xlabel('Model Size (B parameters)')
# plt.ylabel('Grouping Loss')
# plt.title('Grouping Loss vs Model Size by LLM Family and Type')
# plt.xscale('log')
# plt.yscale('log')
# plt.legend()
# plt.grid(True, alpha=0.3)
# plt.show()
# %%
# Update family extraction to separate Llama 2 and Llama 3
def extract_family_detailed(model_name):
    """Map a model name to its LLM family label, separating Llama 2 from Llama 3.

    Checks run in order and the first hit wins, so names containing
    ``"DeepSeek"`` (e.g. R1 distillations of Llama/Qwen) are labelled
    ``'DeepSeek R1'``. Llama names become ``'Llama 2'`` or ``'Llama 3'``
    (3.x included) and fall back to ``'Llama'`` when neither version follows
    "llama". Names matching no rule are labelled ``'Orca'``.

    Args:
        model_name: Model identifier string.

    Returns:
        The family label as a string.
    """
    model_lower = model_name.lower()
    if "DeepSeek" in model_name:
        return "DeepSeek R1"
    elif 'llama' in model_lower:
        # The version number must directly follow "llama" (optionally after a
        # separator) and must not continue with another digit, so e.g.
        # "CodeLlama-34b" is not treated as Llama 3.
        if re.search(r'llama[-_ ]?2(?!\d)', model_lower):
            return 'Llama 2'
        elif re.search(r'llama[-_ ]?3(?!\d)', model_lower):
            return 'Llama 3'
        else:
            return 'Llama'
    # 'mixtral' is checked before 'mistral' because Mixtral names are often
    # prefixed with the "mistralai" organisation and would otherwise be Mistral.
    elif 'mixtral' in model_lower:
        return 'Mixtral'
    elif 'mistral' in model_lower:
        return 'Mistral'
    elif 'phi' in model_lower:
        return 'Phi'
    elif 'qwen' in model_lower:
        return 'Qwen'
    elif 'gemma' in model_lower:
        return 'Gemma'
    elif 'claude' in model_lower:
        return 'Claude'
    elif 'gpt' in model_lower:
        return 'GPT'
    else:
        # NOTE(review): catch-all label; any unrecognised family ends up here.
        return 'Orca'

# Attach the detailed family label (Llama 2 vs Llama 3, ...) to every row.
df['family'] = df['model'].map(extract_family_detailed)

# One row per (model, family, model_type): mean and std of gl_value across
# seeds, plus the model's size (identical for all of its seeds).
model_stats_detailed = (
    df.groupby(['model', 'family', 'model_type'], as_index=False)
      .agg(
          gl_mean=('gl_value', 'mean'),
          gl_std=('gl_value', 'std'),
          model_size=('model_size', 'first'),
      )
)
    
# Create scatter plot with updated family separation and error bars
# (circles = Base models, stars = Instruct models, one colour per family).
plt.figure(figsize=(4, 2))
families = model_stats_detailed['family'].unique()
# tab10 only has 10 colours: integer indices past the tenth all map to the same
# colour, so fall back to tab20 when more families than that are present.
family_cmap = plt.cm.tab10 if len(families) <= 10 else plt.cm.tab20
colors = family_cmap(range(len(families)))

# Track which families have been added to legend, so a family with both Base
# and Instruct points is listed only once.
legend_added = set()

for i, family in enumerate(families):
    family_data = model_stats_detailed[model_stats_detailed['family'] == family]

    # Plot base models with circles
    base_data = family_data[family_data['model_type'] == 'Base']
    if not base_data.empty:
        label = family if family not in legend_added else None
        plt.errorbar(base_data['model_size'], base_data['gl_mean'],
                     yerr=base_data['gl_std'], fmt='o', alpha=0.7,
                     label=label, color=colors[i], markersize=8, capsize=3)
        legend_added.add(family)

    # Plot instruct models with stars
    instruct_data = family_data[family_data['model_type'] == 'Instruct']
    if not instruct_data.empty:
        label = family if family not in legend_added else None
        plt.errorbar(instruct_data['model_size'], instruct_data['gl_mean'],
                     yerr=instruct_data['gl_std'], fmt='*', alpha=0.7,
                     label=label, color=colors[i], markersize=12,
                     markeredgecolor='black', capsize=3)
        legend_added.add(family)

# Add marker type indicators to legend (empty series act as legend proxies)
plt.errorbar([], [], fmt='o', color='gray', markersize=8, label='Base')
plt.errorbar([], [], fmt='*', color='gray', markersize=12, markeredgecolor='black', label='Instruct')
# LOWESS trend line per model type, fitted only on models larger than 2B
# parameters (smaller models are still drawn in the scatter, just not fitted).
is_large = model_stats_detailed['model_size'] > 2
is_base = model_stats_detailed['model_type'] == 'Base'
is_instruct = model_stats_detailed['model_type'] == 'Instruct'
base_models_filtered = model_stats_detailed[is_base & is_large]
instruct_models_filtered = model_stats_detailed[is_instruct & is_large]

for subset, line_color, trend_label in (
    (base_models_filtered, 'blue', 'Base Trend'),
    (instruct_models_filtered, 'red', 'Instruct Trend'),
):
    if subset.empty:
        continue
    ordered = subset.sort_values('model_size')
    fitted = lowess(ordered['gl_mean'], ordered['model_size'], frac=1)
    plt.plot(
        fitted[:, 0],
        fitted[:, 1],
        '--',
        color=line_color,
        linewidth=2,
        alpha=0.8,
        label=trend_label,
    )

# # Prepare data for LOWESS fitting
# base_models = model_stats_detailed[model_stats_detailed['model_type'] == 'Base']
# instruct_models = model_stats_detailed[model_stats_detailed['model_type'] == 'Instruct']

# # Fit LOWESS for base models
# if not base_models.empty:
#     base_sorted = base_models.sort_values('model_size')
#     base_lowess = lowess(base_sorted['gl_mean'], base_sorted['model_size'], frac=0.6)
#     plt.plot(base_lowess[:, 0], base_lowess[:, 1], '--', color='blue', linewidth=2, 
#              alpha=0.8, label='Base Trend')

# # Fit LOWESS for instruct models
# if not instruct_models.empty:
#     instruct_sorted = instruct_models.sort_values('model_size')
#     instruct_lowess = lowess(instruct_sorted['gl_mean'], instruct_sorted['model_size'], frac=0.6)
#     plt.plot(instruct_lowess[:, 0], instruct_lowess[:, 1], '--', color='red', linewidth=2, 
#              alpha=0.8, label='Instruct Trend')
# Finish the family/type figure: log-log axes, legend outside the axes, save.
current_ax = plt.gca()
current_ax.set_xlabel('Model Size (B parameters)', fontsize=12)
current_ax.set_ylabel('Grouping Loss', fontsize=12)
# plt.title('Grouping Loss vs Model Size by LLM Family and Type', fontsize=12)
current_ax.set_xscale('log')
current_ax.set_yscale('log')
current_ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=8, ncol=2)

current_ax.grid(True, alpha=0.5)
# plt.tight_layout()
plt.savefig('grouping_loss_vs_model_size.pdf', dpi=800, bbox_inches='tight')
plt.show()
# %%
# Baseline results; only read by the commented-out axhline code in the next cell.
baselines = pd.read_csv(filepath_or_buffer='combined_results_baseline.csv')
# %%
# Per-seed scatter of grouping loss vs. model size, coloured by family and
# marked by model type (circles = Base, stars = Instruct).
plt.figure(figsize=(6, 6))

families = df['family'].unique()
# tab10 only has 10 colours: integer indices past the tenth all map to the same
# colour, so fall back to tab20 when more families than that are present.
family_cmap = plt.cm.tab10 if len(families) <= 10 else plt.cm.tab20
colors = family_cmap(range(len(families)))

# The per-seed loss column aggregated in the earlier cells is 'gl_value'; older
# plotting code read 'grouping_loss', which may not exist in this CSV (KeyError).
# Use 'grouping_loss' when present, otherwise 'gl_value'.
loss_col = 'grouping_loss' if 'grouping_loss' in df.columns else 'gl_value'

# Track which families have been added to legend
legend_added = set()

for i, family in enumerate(families):
    family_data = df[df['family'] == family]

    # Plot base models with circles
    base_data = family_data[family_data['model_type'] == 'Base']
    if not base_data.empty:
        label = family if family not in legend_added else None
        plt.scatter(base_data['model_size'], base_data[loss_col],
                    alpha=0.7, label=label, color=colors[i], marker='o', s=100)
        legend_added.add(family)

    # Plot instruct models with stars
    instruct_data = family_data[family_data['model_type'] == 'Instruct']
    if not instruct_data.empty:
        label = family if family not in legend_added else None
        plt.scatter(instruct_data['model_size'], instruct_data[loss_col],
                    alpha=0.7, label=label, color=colors[i], marker='*', s=150, edgecolors='black')
        legend_added.add(family)

# # Add baseline models as horizontal dashed lines
# if 'RandomForest' in baselines['model_name'].values:
#     rf_loss = baselines[baselines['model_name'] == 'RandomForest']['grouping_loss'].iloc[0]
#     plt.axhline(y=rf_loss, color='red', linestyle='--', alpha=0.7, label='Random Forest')

# if 'GradientBoosting' in baselines['model_name'].values:
#     gb_loss = baselines[baselines['model_name'] == 'GradientBoosting']['grouping_loss'].iloc[0]
#     plt.axhline(y=gb_loss, color='green', linestyle='--', alpha=0.7, label='Gradient Boosting')

# if 'LogisticRegression' in baselines['model_name'].values:
#     lr_loss = baselines[baselines['model_name'] == 'LogisticRegression']['grouping_loss'].iloc[0]
#     plt.axhline(y=lr_loss, color='blue', linestyle='--', alpha=0.7, label='Logistic Regression')

# Add marker type indicators to legend (empty scatters act as legend proxies)
plt.scatter([], [], marker='o', color='gray', s=100, label='Base')
plt.scatter([], [], marker='*', color='gray', s=150, edgecolors='black', label='Instruct')

plt.xlabel('Model Size (B parameters)')
plt.ylabel('Grouping Loss')
plt.title('Grouping Loss vs Model Size by LLM Family and Type')
plt.xscale('log')
plt.yscale('log')

# Add more ticks to the y-axis: major ticks at powers of ten, minor ticks at
# automatically chosen sub-decade multiples. Set after yscale('log'), which
# installs its own locators.
y_axis = plt.gca().yaxis
y_axis.set_minor_locator(LogLocator(base=10, subs='auto'))
y_axis.set_major_locator(LogLocator(base=10))

plt.legend(bbox_to_anchor=(0.5, -0.15), loc='upper center', ncol=3)
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
# %%
