import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from pathlib import Path
import re
import json

# Global figure styling applied at import time: seaborn's whitegrid style,
# the "husl" colour palette, a Times New Roman font and compact text sizes.
plt.style.use('seaborn-v0_8-whitegrid')
sns.set_palette("husl")

plt.rcParams.update({
    'font.family': 'Times New Roman',
    'font.size': 6,
    'axes.titlesize': 10,
    'axes.labelsize': 8,
    'legend.fontsize': 7,
    'xtick.labelsize': 7,
    'ytick.labelsize': 7,
})

def size_to_numeric(size_str):
    """Convert a model-size label to a sortable number of billions of parameters.

    Recognised forms (case-insensitive, surrounding whitespace ignored):
      * any label containing ``"MINI"`` -> ``0.1`` (used for the ``'Mini'`` size)
      * ``"<number>M"`` -> millions, e.g. ``"270M"`` -> ``0.27``
      * ``"<number>B"`` -> billions, e.g. ``"1.5B"`` -> ``1.5``

    Any other value (e.g. ``"Unknown"``) maps to ``999`` so it sorts after
    every real size instead of raising.
    """
    label = str(size_str).strip().upper()
    if 'MINI' in label:
        return 0.1
    # Require the whole label to be "<number><unit>" so words that merely
    # contain an 'M' or 'B' (e.g. "MEDIUM") fall through to the 999 fallback
    # instead of crashing in float().
    match = re.fullmatch(r'(\d+\.?\d*|\.\d+)\s*([MB])', label)
    if match is None:
        return 999
    value = float(match.group(1))
    return value / 1000 if match.group(2) == 'M' else value

def parse__info(csv_name):
    """Extract name, size, type, and contamination level from a CSV name.

    ``csv_name`` is expected to contain underscore-separated tokens naming the
    model family (e.g. ``Gemma3``), its size (e.g. ``1B``; ``1_5B`` for 1.5B)
    and optionally a contamination level. Spaces are ignored and a trailing
    ``.csv`` extension is tolerated.

    Returns:
        dict with keys ``_name`` (``"<family>_<size>"``), ``family``, ``size``,
        ``numerical_size`` (from ``size_to_numeric``), ``type`` (always
        ``'Instruct'``), ``contamination`` (first token equal to 25, 50, 75 or
        100, else ``None``) and ``full_name`` (``"<family>-<size>"``).
        Unrecognised families or sizes are reported as ``'Unknown'``.
    """
    clean_csv_name = str(csv_name).replace(' ', '')
    # A final token such as "50.csv" would otherwise not be recognised as "50".
    if clean_csv_name.lower().endswith('.csv'):
        clean_csv_name = clean_csv_name[:-4]

    # isdecimal() (unlike isdigit()) rejects characters such as '²' that
    # int() cannot parse.
    contamination = next(
        (int(part) for part in clean_csv_name.split('_')
         if part.isdecimal() and int(part) in (25, 50, 75, 100)),
        None,
    )

    # (family marker, display family, ((size token, size label), ...)).
    # Families are tested in order and the first match wins; within a family
    # the first matching size token wins. A token of None means the family
    # always gets that single size label.
    known_families = (
        ('Gemma3', 'Gemma-3', (('_270M_', '270M'), ('_1B_', '1B'), ('_4B_', '4B'))),
        ('Llama3_2', 'Llama-3.2', (('_3B_', '3B'), ('_1B_', '1B'))),
        ('Qwen2-5', 'Qwen-2.5', (('_3B_', '3B'), ('_1_5B_', '1.5B'), ('_0_5B_', '0.5B'))),
        ('Phi4', 'Phi-4', ((None, 'Mini'),)),
        ('Olmo2', 'OLMo-2', ((None, '1B'),)),
        ('SmolLM2', 'SmolLM-2', (('_1_7B_', '1.7B'), ('_360M_', '360M'))),
    )

    family = 'Unknown'
    size = 'Unknown'
    # Pad with '_' so a size token at the very start or end still matches.
    padded_name = f"_{clean_csv_name}_"
    for marker, family_label, size_tokens in known_families:
        if marker in clean_csv_name:
            family = family_label
            size = next(
                (label for token, label in size_tokens
                 if token is None or token in padded_name),
                'Unknown',
            )
            break

    return {
        '_name': f"{family}_{size}", 'family': family, 'size': size,
        'numerical_size': size_to_numeric(size),
        # Model type is hard-coded: every entry is labelled 'Instruct'.
        'type': 'Instruct', 'contamination': contamination,
        'full_name': f"{family}-{size}",
    }

def load_and_process_data(irrelevant_file, counterfactual_file):
    """Load both semantic-task result CSVs and collect their summary rows.

    Only rows whose ``Question`` column equals ``'AVERAGE_METRICS'`` are
    used. For each one, the model metadata parsed from ``csv_name`` is
    combined with the task name, the ``accuracy`` column and the
    ``variation_check`` column (stored as ``adherence``). Both metrics may be
    ``"85.0%"``-style strings or plain numbers.

    Args:
        irrelevant_file: Path to the Irrelevant-task CSV.
        counterfactual_file: Path to the CounterFactual-task CSV.

    Returns:
        DataFrame with one row per summary row; ``task`` is an ordered
        categorical (``CounterFactual`` < ``Irrelevant``).

    Raises:
        ValueError: if neither file contains an ``AVERAGE_METRICS`` row.
    """
    irrelevant_df = pd.read_csv(irrelevant_file)
    counterfactual_df = pd.read_csv(counterfactual_file)

    def to_percent(value):
        # pandas reads a column without '%' signs as floats, which have no
        # .rstrip(); going through str() accepts both "85.0%" and 85.0.
        return float(str(value).strip().rstrip('%'))

    task_order = ['CounterFactual', 'Irrelevant']
    processed_data = []

    # Process CounterFactual first, then Irrelevant.
    for df, task_name in [(counterfactual_df, 'CounterFactual'), (irrelevant_df, 'Irrelevant')]:
        summary_rows = df[df['Question'] == 'AVERAGE_METRICS']
        for _, row in summary_rows.iterrows():
            record = parse__info(row['csv_name'])
            record.update({
                'task': task_name,
                'accuracy': to_percent(row['accuracy']),
                'adherence': to_percent(row['variation_check']),
            })
            processed_data.append(record)

    if not processed_data:
        # Without this, the 'task' lookup below fails with an opaque KeyError.
        raise ValueError(
            f"No 'AVERAGE_METRICS' rows found in {counterfactual_file!r} "
            f"or {irrelevant_file!r}."
        )

    result = pd.DataFrame(processed_data)
    result['task'] = pd.Categorical(result['task'], categories=task_order, ordered=True)
    return result

def _format_level(level):
    """Render a contamination level without a spurious '.0' (50.0 -> '50')."""
    return f"{level:g}"


def _heatmap_row_label(full_name, model_type):
    """Build a y-axis label: 'Gemma-3-1B', 'Instruct' -> 'Gemma3_1B_IT'."""
    # Drop the first hyphen (family/version), turn the remaining ones into '_'.
    label = full_name.replace('-', '', 1).replace('-', '_')
    return f"{label}_IT" if model_type == 'Instruct' else label


def _pivot_to_records(pivot):
    """Flatten (task, contamination) columns to 'Task - N%' keys as JSON records."""
    flat = pivot.copy()
    flat.columns = [f"{task} - {_format_level(level)}%" for task, level in flat.columns]
    records = flat.reset_index().to_dict(orient='records')
    # json.dump would write bare NaN (not valid JSON) for missing cells.
    return [{key: (None if pd.isna(value) else value) for key, value in rec.items()}
            for rec in records]


def _task_spans(pivot):
    """Return (task, first_column, column_count) for each contiguous task group."""
    from itertools import groupby

    spans = []
    start = 0
    for task, run in groupby(pivot.columns.get_level_values('task')):
        count = sum(1 for _ in run)
        spans.append((task, start, count))
        start += count
    return spans


def create_combined_heatmap(all_data, save_path='./'):
    """Create one figure with adherence (left) and accuracy (right) heatmaps.

    Rows are models (ordered by family, then numeric size); columns are
    grouped by task and, within a task, by contamination level. The pivoted
    values are also written to ``Plot3_Combined_Semantic_Heatmap_Data.json``
    (missing cells as ``null``), and the figure is saved as PNG (300 dpi) and
    PDF in ``save_path``.

    Args:
        all_data: DataFrame with ``family``, ``numerical_size``, ``full_name``,
            ``type``, ``task``, ``contamination``, ``adherence`` and
            ``accuracy`` columns.
        save_path: Output directory.
    """
    print("Generating combined heatmap for semantic tasks...")
    out_dir = Path(save_path)

    # Sort data by family and then by numerical size.
    ordered = all_data.sort_values(by=['family', 'numerical_size'])

    # A single pivot for both metrics guarantees both panels share the same
    # rows, which the shared y-axis and single y-label list rely on.
    combined = ordered.pivot_table(
        values=['adherence', 'accuracy'], index=['full_name', 'type'],
        columns=['task', 'contamination'], aggfunc='mean', sort=False,
    )
    # sort=False keeps rows in family/size order, but would also make column
    # order follow CSV row order; sort the columns so each task's levels are
    # contiguous and ascending.
    adherence_pivot = combined['adherence'].sort_index(axis=1)
    accuracy_pivot = combined['accuracy'].sort_index(axis=1)

    json_output_data = {
        'adherence': _pivot_to_records(adherence_pivot),
        'accuracy': _pivot_to_records(accuracy_pivot),
    }
    json_path = out_dir / 'Plot3_Combined_Semantic_Heatmap_Data.json'
    with open(json_path, 'w', encoding='utf-8') as f:
        json.dump(json_output_data, f, indent=4)
    print(f"Heatmap data successfully saved to: {json_path}")

    y_labels = [_heatmap_row_label(name, type_) for name, type_ in adherence_pivot.index]

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(5.5, 3.5), sharey=True)
    panels = (
        (ax1, adherence_pivot, 'Contamination Adherence (%)', 'Sensitivity to Contamination'),
        (ax2, accuracy_pivot, 'Task Accuracy (%)', 'Task Accuracy'),
    )
    for ax, pivot, cbar_label, title in panels:
        sns.heatmap(pivot, annot=True, fmt='.1f', cmap='RdYlGn',
                    cbar_kws={'label': cbar_label}, ax=ax,
                    yticklabels=y_labels)
        ax.set_title(title, fontweight='bold', fontsize=8, pad=8)
        ax.set_ylabel('')
        ax.set_xlabel('')

        # Each panel labels its own columns (the pivots may drop different
        # all-NaN columns).
        levels = pivot.columns.get_level_values('contamination')
        ax.set_xticklabels([_format_level(level) for level in levels], rotation=0)
        ax.tick_params(axis='x', length=0)

        # Separator before every task group but the first, and a task label
        # centred under each group, positioned from the actual column counts.
        for task, start, count in _task_spans(pivot):
            if start > 0:
                ax.axvline(x=start, color='white', linewidth=4)
            ax.text(start + count / 2.0, -0.07, str(task), ha='center', va='top',
                    fontweight='bold', transform=ax.get_xaxis_transform(), fontsize=6)

    plt.tight_layout()
    plt.savefig(out_dir / 'Plot3_Combined_Semantic_Heatmap.png', dpi=300)
    plt.savefig(out_dir / 'Plot3_Combined_Semantic_Heatmap.pdf')
    plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)

def main():
    """Run the semantic-task pipeline: load both result CSVs, then plot them."""
    inputs = {
        'irrelevant_file': 'Irrelevant.csv',
        'counterfactual_file': 'CounterFactual.csv',
    }

    print("Loading and processing data for semantic tasks...")
    create_combined_heatmap(load_and_process_data(**inputs))

    for line in ("\nCombined semantic heatmap saved as PNG and PDF formats.",
                 "Analysis complete!"):
        print(line)


if __name__ == "__main__":
    main()