import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import os

def load_results(steps, model_name, data_dir='/llm_unlearning/wmdp/data'):
    """Load the GCG attack results CSV for one (steps, model) combination.

    The file is expected at
    ``<data_dir>/<steps>steps_gcg_attack_results_<model_name>.csv``.

    Args:
        steps: Number of attack steps; used verbatim in the file name.
        model_name: Model identifier; used verbatim in the file name.
        data_dir: Directory that holds the results files. Defaults to the
            location this script has always read from.

    Returns:
        The parsed ``pandas.DataFrame``, or ``None`` when the file does not
        exist.
    """
    file_path = os.path.join(
        data_dir, f'{steps}steps_gcg_attack_results_{model_name}.csv'
    )
    # Attempt the read directly instead of checking for existence first, so
    # a file that disappears between the check and the read cannot crash us.
    try:
        return pd.read_csv(file_path)
    except FileNotFoundError:
        return None

def calculate_asr_by_category(df):
    """Return the mean of ``success`` for each ``category`` in *df*.

    Args:
        df: DataFrame with ``category`` and ``success`` columns, or ``None``.

    Returns:
        A DataFrame with one row per category and the columns ``category``
        and ``success`` (the per-category mean), or ``None`` when *df* is
        ``None``.
    """
    if df is not None:
        success_by_category = df.groupby('category')['success']
        # reset_index turns the group key back into an ordinary column.
        return success_by_category.mean().reset_index()
    return None

def main():
    """Aggregate per-category attack success rates and plot them against steps.

    For every (steps, model) pair whose results file can be loaded, the
    per-category mean of ``success`` is computed and tagged with the model
    name and step count. The combined table is written to
    ``gcg_combined_results.csv``. A line plot (one line per model,
    log-scaled x-axis) of the mean of those per-category values at each step
    count is saved to ``asr_vs_steps_styled.png``.

    Prints a message and returns without writing anything when no results
    file is found.
    """
    # Model names and steps
    model_names = ['zephyr-7b-beta', 'unlearned_model', 'zephyr_rmu']
    model_display_names = {'zephyr-7b-beta': 'Base', 'unlearned_model': 'ERASER', 'zephyr_rmu': 'RMU'}
    steps_list = [25, 50, 100, 250, 2500]

    all_results = []

    for steps in steps_list:
        for model_name in model_names:
            df = load_results(steps, model_name)
            if df is None:
                # Missing results files are skipped rather than treated as errors.
                continue
            asr_by_category = calculate_asr_by_category(df)
            asr_by_category['model'] = model_name
            asr_by_category['steps'] = steps
            all_results.append(asr_by_category)

    if not all_results:
        print("No results found!")
        return

    combined_results = pd.concat(all_results)
    combined_results.to_csv('/llm_unlearning/wmdp/data/gcg_combined_results.csv', index=False)

    # Work on a copy so the display-name mapping does not touch the saved table.
    plot_data = combined_results.copy()
    plot_data['model'] = plot_data['model'].map(model_display_names)

    color_map = {
        'Base': '#4B77BE',
        'RMU': '#E74C3C',
        'ERASER': '#2ECC71'
    }

    # Create exactly one figure: plt.close() below only closes the current
    # figure, so any additional figure opened here would be left allocated.
    plt.figure(figsize=(6.5, 4.5))

    for model in plot_data['model'].unique():
        model_data = plot_data[plot_data['model'] == model]
        # Average the per-category rates so each model has one value per step count.
        grouped = model_data.groupby('steps')['success'].mean().reset_index()
        grouped = grouped.sort_values(by='steps')

        plt.plot(grouped['steps'], grouped['success'],
                 label=model, color=color_map[model],
                 marker='o', linewidth=2, alpha=0.8)
        # Larger translucent markers drawn on top of the line markers.
        plt.scatter(grouped['steps'], grouped['success'],
                    color=color_map[model], alpha=0.6, s=60)

    plt.grid(axis='y', linestyle='--', linewidth=0.5, alpha=0.5)
    plt.xlabel('Attack Iterations (m)', fontweight='bold', fontsize=13)
    # NOTE(review): the axis is labelled as a percentage, but the plotted
    # values are the unscaled means of `success`. If `success` is stored as
    # 0/1 the axis shows fractions - confirm, then scale by 100 or relabel.
    plt.ylabel('ASR (%)', fontweight='bold', fontsize=13)
    plt.xscale('log')
    # Show the actual step counts as tick labels instead of powers of ten.
    plt.xticks(steps_list, steps_list, fontsize=11)
    plt.yticks(fontsize=11)
    plt.legend(loc='lower right', fontsize=11)

    plt.tight_layout()
    plt.savefig('/llm_unlearning/wmdp/data/asr_vs_steps_styled.png', dpi=1000, bbox_inches='tight')
    plt.close()

# Run the aggregation and plotting only when executed as a script, not on import.
if __name__ == "__main__":
    main()