import os
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
import seaborn as sns
import numpy as np

import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from config.constants import BASE_PROJECT_DIR


# Raw benchmark results to plot. BASE_PROJECT_DIR is joined with `/`, so it is
# presumably a pathlib.Path — confirm in config.constants.
INPUT_CSV_PATH = str(BASE_PROJECT_DIR / "benchmarks/scalability_results_20250624_131222/scalability_results_live_r1.csv")
# Destination directory and file name for the generated figure (a PNG copy is
# also written alongside it by create_scalability_plot).
OUTPUT_DIR = str(BASE_PROJECT_DIR / "figure")
OUTPUT_FILENAME = "fortress_scalability_performance.pdf"

# Values of the CSV's 'benchmark' column whose rows are kept and averaged
# together when computing the plotted statistics.
BENCHMARKS = ['aegis_v2', 'fortress_dataset', 'jailbreakbench', 'xstest']

def setup_matplotlib_for_tmlr():
    """Configure global Matplotlib state for TMLR-style publication figures.

    Applies the ``seaborn-v0_8-paper`` base style, then overrides fonts,
    sizes, resolution and grid appearance via ``plt.rcParams``. This mutates
    process-wide Matplotlib settings and prints a confirmation line.
    """
    overrides = {
        # Typography: serif text, Times preferred, no LaTeX rendering.
        'font.family': 'serif',
        'font.serif': ['Times New Roman', 'serif'],
        'text.usetex': False,
        'font.size': 10,
        'axes.labelsize': 10,
        'axes.titlesize': 12,
        'xtick.labelsize': 9,
        'ytick.labelsize': 9,
        'legend.fontsize': 9,
        # Canvas size and output resolution.
        'figure.figsize': (10, 4),
        'figure.dpi': 300,
        'savefig.dpi': 300,
        # Light dashed background grid.
        'axes.grid': True,
        'grid.alpha': 0.3,
        'grid.linestyle': '--',
    }

    plt.style.use('seaborn-v0_8-paper')
    plt.rcParams.update(overrides)
    print("Matplotlib configured for TMLR styling.")

def process_scalability_data(df: pd.DataFrame, benchmarks: "list[str] | None" = None) -> pd.DataFrame:
    """
    Compute per-database-size latency and F1 statistics aggregated across runs.

    Rows are first restricted to the selected benchmarks. For every
    (run, db_size) pair, the remaining rows are averaged into a single F1 and
    latency value; those per-run values are then summarised across runs with
    their mean and sample standard deviation.

    Args:
        df: The raw DataFrame loaded from the CSV. Must contain the columns
            'benchmark', 'run', 'db_size', 'f1_unsafe' and 'latency_ms'.
        benchmarks: Benchmark names to include. Defaults to the module-level
            ``BENCHMARKS`` list when omitted.

    Returns:
        A DataFrame sorted by 'db_size' with columns 'db_size',
        'mean_latency', 'std_latency', 'mean_f1' and 'std_f1'. The standard
        deviations are 0 for a db_size that has only a single run.

    Raises:
        ValueError: If no rows belong to the selected benchmarks (plotting an
            empty frame would otherwise fail later with an obscure error).
    """
    if benchmarks is None:
        benchmarks = BENCHMARKS
    benchmarks = list(benchmarks)
    print(f"Processing data for benchmarks: {benchmarks}")

    df_filtered = df[df['benchmark'].isin(benchmarks)]
    if df_filtered.empty:
        raise ValueError(
            f"No rows found for benchmarks {benchmarks}; "
            f"benchmarks present in the data: {df['benchmark'].unique().tolist()}"
        )

    # Step 1: collapse the selected benchmarks into one value per (run, db_size).
    df_agg_per_run = df_filtered.groupby(['run', 'db_size']).agg(
        avg_f1=('f1_unsafe', 'mean'),
        avg_latency=('latency_ms', 'mean')
    ).reset_index()

    # Step 2: summarise the per-run values across runs for each db_size.
    final_stats = df_agg_per_run.groupby('db_size').agg(
        mean_latency=('avg_latency', 'mean'),
        std_latency=('avg_latency', 'std'),
        mean_f1=('avg_f1', 'mean'),
        std_f1=('avg_f1', 'std')
    ).reset_index()

    # Sample std (ddof=1) is NaN when a db_size has a single run; treat as no spread.
    final_stats['std_latency'] = final_stats['std_latency'].fillna(0)
    final_stats['std_f1'] = final_stats['std_f1'].fillna(0)

    # groupby already sorts keys; sort explicitly so the plotting order is guaranteed.
    final_stats = final_stats.sort_values('db_size').reset_index(drop=True)

    print("Data processing complete.")
    return final_stats


def _format_db_size(value, _pos=None):
    """Format a knowledge-base-size tick: 2000 -> '2K', 2500 -> '2.5K', 500 -> '500'."""
    if abs(value) >= 1000:
        return f'{value / 1000:g}K'
    return f'{value:g}'


def _plot_mean_with_band(ax, x, mean, std, color, line_label):
    """Draw a mean line plus a shaded mean +/- one standard deviation band on ``ax``."""
    ax.plot(x, mean, color=color, marker='.', markersize=4, linestyle='-', label=line_label)
    ax.fill_between(
        x,
        mean - std,
        mean + std,
        color=color,
        alpha=0.2,
        label='Std. Dev. (across runs)'
    )


def create_scalability_plot(stats_df: pd.DataFrame, output_path: str):
    """
    Creates and saves a two-subplot figure showing latency and performance
    vs. knowledge base size.

    The left panel shows mean latency and the right panel mean F1, each with a
    +/- one standard deviation band. The figure is written to ``output_path``
    and a PNG copy is written next to it (same stem, ``.png`` extension).

    Args:
        stats_df: Processed statistics with columns 'db_size', 'mean_latency',
            'std_latency', 'mean_f1' and 'std_f1'.
        output_path: The full path to save the output figure (typically a PDF).

    Raises:
        ValueError: If ``stats_df`` is empty.
    """
    if stats_df.empty:
        raise ValueError("stats_df is empty; nothing to plot.")

    setup_matplotlib_for_tmlr()

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))
    try:
        # Colour-blind-safe palette (Okabe-Ito blue and green).
        color_latency = "#0072B2"
        color_f1 = "#009E73"
        db_sizes = stats_df['db_size']

        # --- Left panel: latency ---
        _plot_mean_with_band(ax1, db_sizes, stats_df['mean_latency'], stats_df['std_latency'],
                             color_latency, 'Average Latency')
        ax1.set_title('System Latency vs. Knowledge Base Size')
        ax1.set_xlabel('Knowledge Base Size')
        ax1.set_ylabel('Average Inference Latency (ms)')

        # Zoom in on the data rather than starting at zero, but size the window
        # from the std band (not just the means) so the band is not clipped.
        lat_low = float((stats_df['mean_latency'] - stats_df['std_latency']).min())
        lat_high = float((stats_df['mean_latency'] + stats_df['std_latency']).max())
        padding = (lat_high - lat_low) * 0.1
        if padding == 0:
            # Constant latency would yield identical (singular) y-limits.
            padding = max(abs(lat_high) * 0.05, 1.0)
        ax1.set_ylim(lat_low - padding, lat_high + padding)
        ax1.legend()

        # --- Right panel: F1 ---
        _plot_mean_with_band(ax2, db_sizes, stats_df['mean_f1'], stats_df['std_f1'],
                             color_f1, 'Average F1 Score')
        ax2.set_title('System Performance vs. Knowledge Base Size')
        ax2.set_xlabel('Knowledge Base Size')
        ax2.set_ylabel('Average F1 Score')

        # Default to a 0.70-1.0 window, but widen it instead of hiding data below 0.70.
        f1_bottom = 0.70
        f1_low = float((stats_df['mean_f1'] - stats_df['std_f1']).min())
        if f1_low < f1_bottom:
            f1_bottom = max(0.0, f1_low - 0.02)
        ax2.set_ylim(f1_bottom, 1.0)
        ax2.legend()

        # Shared x-axis formatting: thousands shown as 'K', labels rotated.
        for ax in (ax1, ax2):
            ax.xaxis.set_major_formatter(mticker.FuncFormatter(_format_db_size))
            ax.tick_params(axis='x', rotation=30)

        fig.tight_layout(pad=1.5)

        # Swap only the final extension; str.replace('.pdf', ...) would also touch
        # directory names and silently overwrite output_path for non-.pdf paths.
        png_output_path = os.path.splitext(output_path)[0] + '.png'

        fig.savefig(output_path, bbox_inches='tight')
        fig.savefig(png_output_path, bbox_inches='tight')
    finally:
        # Release the figure even if saving fails.
        plt.close(fig)

    print("\nFigure saved successfully to:")
    print(f"  PDF: {output_path}")
    print(f"  PNG: {png_output_path}")


if __name__ == "__main__":
    # Abort with a non-zero exit status (message on stderr) so shell scripts and
    # CI notice a missing input instead of treating the run as successful.
    if not os.path.isfile(INPUT_CSV_PATH):
        sys.exit(f"Error: Input file not found at '{INPUT_CSV_PATH}'")

    # Make sure the figure directory exists before anything is written.
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    raw_df = pd.read_csv(INPUT_CSV_PATH)

    # Aggregate per db_size across runs, then render and save the figure.
    plotting_stats = process_scalability_data(raw_df)
    full_output_path = os.path.join(OUTPUT_DIR, OUTPUT_FILENAME)
    create_scalability_plot(plotting_stats, full_output_path)