import os
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
import seaborn as sns
import sys

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from config.constants import BASE_PROJECT_DIR



# Summary CSV of the noise experiment; loaded by the __main__ entry point.
INPUT_CSV_PATH = str(BASE_PROJECT_DIR / "benchmarks/noise_experiment_results/noise_robustness_summary.csv")
# Directory the figure is written into (created on demand by the __main__ entry point).
OUTPUT_DIR = str(BASE_PROJECT_DIR / "figure")
# File name of the PDF figure; a PNG copy is derived from it when the figure is saved.
OUTPUT_FILENAME = "fortress_noise_robustness_distribution.pdf"

# Values of the `dataset_name` column kept for the analysis; rows of other datasets are dropped.
BENCHMARKS = ['aegis_v2', 'fortress_dataset', 'jailbreakbench', 'xstest']

def setup_matplotlib_for_tmlr():
    """Apply the TMLR figure style to Matplotlib's global configuration.

    Activates the ``seaborn-v0_8-paper`` style sheet and then overrides font,
    figure and grid settings in ``plt.rcParams``. The change is global: every
    figure created afterwards in this process picks it up.
    """
    # Serif typography with explicit sizes for each text element.
    typography = {
        'font.size': 12,
        'axes.labelsize': 12,
        'axes.titlesize': 14,
        'xtick.labelsize': 10,
        'ytick.labelsize': 10,
        'legend.fontsize': 10,
        'font.family': 'serif',
        'font.serif': ['Times New Roman', 'serif'],
        'text.usetex': False,
    }
    # Default canvas size and resolution (on screen and when saved).
    canvas = {
        'figure.figsize': (10, 4.5),
        'figure.dpi': 300,
        'savefig.dpi': 300,
    }
    # Light dashed grid on every axes.
    grid = {
        'axes.grid': True,
        'grid.alpha': 0.3,
        'grid.linestyle': '--',
    }

    # The style sheet goes first so the explicit overrides below win.
    plt.style.use('seaborn-v0_8-paper')
    plt.rcParams.update({**typography, **canvas, **grid})
    print("Matplotlib configured for TMLR styling.")

def process_noise_data_for_plots(
    df: pd.DataFrame,
    *,
    benchmarks: "list[str] | None" = None,
    baseline_f1: float = 0.857,
) -> tuple[pd.DataFrame, pd.DataFrame]:
    """
    Processes noise data for two types of plots: an aggregated line plot and
    a distributional box plot.

    The input must provide the columns ``dataset_name``, ``run_number``,
    ``noise_level`` and ``f1_unsafe``.

    Args:
        df: The raw DataFrame loaded from the CSV.
        benchmarks: Dataset names (values of ``dataset_name``) to keep.
            Defaults to the module-level ``BENCHMARKS`` list.
        baseline_f1: F1 score used for the synthetic 0.0 noise-level entries
            that are prepended to both outputs.
            NOTE(review): the default 0.857 is presumably the score measured
            on noise-free data -- confirm against the experiment logs.

    Returns:
        A tuple containing:
        - aggregated_stats: DataFrame for the line plot with the columns
          ``noise_level``, ``mean_f1`` and ``std_f1``, sorted by noise level.
        - long_form_data: DataFrame for the box plot with the columns
          ``noise_level`` and ``f1_unsafe`` (raw F1 scores).

    Note:
        The baseline entries are always added. If ``df`` itself already
        contains rows with ``noise_level == 0.0``, they are kept alongside
        the synthetic baseline.
    """
    if benchmarks is None:
        benchmarks = BENCHMARKS
    benchmarks = list(benchmarks)
    print(f"Processing data for benchmarks: {benchmarks}")

    # Keep only the selected benchmarks; copy so the caller's frame is untouched.
    long_form_data = df[df['dataset_name'].isin(benchmarks)].copy()

    # Step 1: average F1 over the benchmarks within each (run, noise level).
    df_agg_per_run = long_form_data.groupby(['run_number', 'noise_level']).agg(
        avg_f1_per_run=('f1_unsafe', 'mean')
    ).reset_index()

    # Step 2: mean and standard deviation of those per-run averages across runs.
    aggregated_stats = df_agg_per_run.groupby('noise_level').agg(
        mean_f1=('avg_f1_per_run', 'mean'),
        std_f1=('avg_f1_per_run', 'std')
    ).reset_index()

    # Synthetic baseline point for the line plot (zero spread by construction).
    baseline_agg_row = pd.DataFrame([{'noise_level': 0.0, 'mean_f1': baseline_f1, 'std_f1': 0.0}])
    aggregated_stats = pd.concat([baseline_agg_row, aggregated_stats], ignore_index=True)
    aggregated_stats = aggregated_stats.sort_values('noise_level').reset_index(drop=True)

    # Synthetic baseline rows for the box plot: one per selected benchmark.
    baseline_long_rows = pd.DataFrame({
        'noise_level': [0.0] * len(benchmarks),
        'f1_unsafe': [baseline_f1] * len(benchmarks)
    })
    long_form_data = pd.concat(
        [baseline_long_rows, long_form_data[['noise_level', 'f1_unsafe']]],
        ignore_index=True,
    )

    print("Data processing complete.")
    return aggregated_stats, long_form_data

def create_noise_robustness_plot(agg_df: pd.DataFrame, long_df: pd.DataFrame, output_path: str):
    """
    Creates and saves a two-subplot figure showing performance trend (line)
    and performance distribution (box plot) vs. data noise level.

    The figure is written twice: once to ``output_path`` and once as a PNG
    next to it (same name, extension replaced by ``.png``).

    Args:
        agg_df: The aggregated DataFrame for the line plot; needs the columns
            ``noise_level``, ``mean_f1`` and ``std_f1``.
        long_df: The long-form DataFrame for the box plot; needs the columns
            ``noise_level`` and ``f1_unsafe``.
        output_path: The full path to save the output PDF file.
    """
    # Apply the publication style before any artist is created.
    setup_matplotlib_for_tmlr()

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(11, 4.5))

    color_performance = "#009E73"
    color_distribution = "#0072B2"

    # --- Left panel: mean F1 with a +/- one standard deviation band ---
    ax1.plot(
        agg_df['noise_level'], agg_df['mean_f1'],
        color=color_performance, marker='.', markersize=6, linestyle='-',
        label='Average F1 Score',
    )
    ax1.fill_between(
        agg_df['noise_level'],
        agg_df['mean_f1'] - agg_df['std_f1'],
        agg_df['mean_f1'] + agg_df['std_f1'],
        color=color_performance,
        alpha=0.2,
        label='Std. Dev. (across runs)'
    )
    ax1.set_title('Average Performance vs. Data Noise')
    ax1.set_xlabel('Noise Level (%)')
    ax1.set_ylabel('Average F1 Score')
    ax1.legend()

    # Fixed y-range; the right panel copies it further down.
    ax1.set_ylim(0.58, 0.92)

    # xmax=1.0 renders a noise level of 1.0 as "100%".
    ax1.xaxis.set_major_formatter(mticker.PercentFormatter(xmax=1.0))
    ax1.tick_params(axis='x', rotation=30)

    # --- Right panel: distribution of the raw F1 scores per noise level ---
    sns.boxplot(
        x='noise_level', y='f1_unsafe', data=long_df, ax=ax2,
        color=color_distribution,
        width=0.6,
        linewidth=1.0,
        medianprops={'color': 'black', 'linewidth': 1.5},
        flierprops={'marker': 'o', 'markersize': 3, 'markeredgecolor': 'black'}
    )
    ax2.set_title('F1 Score Distribution vs. Data Noise')
    ax2.set_xlabel('Noise Level (%)')
    ax2.set_ylabel('F1 Score')

    # The box plot's x tick labels are the noise levels as text; re-render
    # them as whole percentages to match the left panel.
    current_labels = [item.get_text() for item in ax2.get_xticklabels()]
    new_labels = [f'{float(label):.0%}' for label in current_labels]
    ax2.set_xticklabels(new_labels)
    ax2.tick_params(axis='x', rotation=30)

    # Same y-range on both panels so they can be compared directly.
    ax2.set_ylim(ax1.get_ylim())

    fig.tight_layout(pad=2.0)

    # Swap only the final extension. A plain str.replace('.pdf', '.png') would
    # also rewrite '.pdf' inside directory names and would leave the path
    # unchanged for e.g. '.PDF', making the PNG overwrite the first file.
    path_root, _ = os.path.splitext(output_path)
    png_output_path = path_root + '.png'

    # Save through the figure object rather than pyplot's implicit
    # "current figure" so the right figure is written and closed.
    fig.savefig(output_path, bbox_inches='tight')
    fig.savefig(png_output_path, bbox_inches='tight')

    print("\nFigure saved successfully to:")
    print(f"  PDF: {output_path}")
    print(f"  PNG: {png_output_path}")
    plt.close(fig)


if __name__ == "__main__":
    # Script entry point: load the summary CSV, process it and write the figure.
    if not os.path.exists(INPUT_CSV_PATH):
        # sys.exit with a message prints it to stderr and exits with status 1,
        # so callers (shell scripts, make, CI) can detect the failure instead
        # of seeing a successful exit without a figure.
        sys.exit(f"Error: Input file not found at '{INPUT_CSV_PATH}'")

    os.makedirs(OUTPUT_DIR, exist_ok=True)

    print(f"Loading data from {INPUT_CSV_PATH}...")
    raw_df = pd.read_csv(INPUT_CSV_PATH)

    aggregated_data, long_form_data = process_noise_data_for_plots(raw_df)

    full_output_path = os.path.join(OUTPUT_DIR, OUTPUT_FILENAME)

    create_noise_robustness_plot(aggregated_data, long_form_data, full_output_path)
