import json
import os
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

                       
import sys                                        
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from config.constants import BASE_PROJECT_DIR


# Input/output locations and benchmark selection for the sensitivity figure.
# NOTE(review): BASE_PROJECT_DIR comes from config.constants; it is presumably a
# pathlib.Path, since it is combined with `/` below — confirm.

# Experiment summary JSON for the top-k sweep (parsed for the k-sensitivity panel).
K_SENSITIVITY_FILE = str(BASE_PROJECT_DIR / "experiments/top_k/experiment_summary.json")
# Experiment summary JSON for the primary/mixed weight sweep (parsed for the heatmap panel).
WEIGHTS_SENSITIVITY_FILE = str(BASE_PROJECT_DIR / "experiments/primary_mixed_weight/experiment_summary.json")
# Output path passed to create_sensitivity_plot(); a bare filename, so it is
# resolved relative to the current working directory.
OUTPUT_FILENAME = 'figure_sensitivity_analysis.pdf'

# Benchmark names to include; each is looked up as a key under the "results"
# object of the summary files, and benchmarks not present there are skipped.
TARGET_BENCHMARKS = [
    'aegis_v2_english',
    'fortress_dataset_english',
    'xstest_english',
    'jailbreakbench_english'
]

def setup_matplotlib_for_tmlr():
    """Apply the style sheet and rcParams used for the TMLR figures.

    Activates the ``seaborn-v0_8-paper`` style, then overrides font sizes,
    the serif font family, default figure size/DPI and grid appearance in
    the global ``plt.rcParams``. Prints a confirmation line when done.
    """
    plt.style.use('seaborn-v0_8-paper')

    # Text sizes and font selection (mathtext only, no external LaTeX).
    typography = {
        'font.size': 10,
        'axes.labelsize': 10,
        'axes.titlesize': 12,
        'xtick.labelsize': 9,
        'ytick.labelsize': 9,
        'legend.fontsize': 9,
        'font.family': 'serif',
        'font.serif': ['Times New Roman', 'serif'],
        'text.usetex': False,
    }
    # Default canvas size and resolution for on-screen and saved output.
    canvas = {
        'figure.figsize': (10, 4),
        'figure.dpi': 300,
        'savefig.dpi': 300,
    }
    # Light dashed grid on every axes.
    grid = {
        'axes.grid': True,
        'grid.alpha': 0.3,
        'grid.linestyle': '--',
    }

    overrides = {**typography, **canvas, **grid}
    plt.rcParams.update(overrides)
    print("Matplotlib configured for TMLR styling.")

def process_k_sensitivity_data(filepath: str, benchmarks: list) -> pd.DataFrame:
    """
    Load the top-k sweep summary and average ``f1_unsafe`` per k value.

    Every run stored under ``results[<benchmark>]`` for a benchmark listed in
    ``benchmarks`` contributes one (k, F1) sample. k is read from the
    ``detection_pipeline.top_k_semantic_search`` parameter and F1 from the
    ``f1_unsafe`` metric. Samples from all benchmarks are pooled and averaged
    per k (this equals the per-benchmark mean when each benchmark has exactly
    one run per k).

    Args:
        filepath: Path to the experiment summary JSON file.
        benchmarks: Benchmark names to include; names absent from the file
            are skipped.

    Returns:
        A DataFrame with columns ``k`` (int) and ``avg_f1``, sorted by ``k``.
        An empty DataFrame is returned, after printing a message, when the
        file is missing, empty, not valid JSON, lacks a ``results`` mapping,
        or contains no usable runs.
    """
    if not os.path.exists(filepath) or os.path.getsize(filepath) == 0:
        print(f"Error: File '{filepath}' is missing or empty. Skipping k-sensitivity analysis.")
        return pd.DataFrame()
    try:
        # JSON is UTF-8 by specification; do not depend on the locale default.
        with open(filepath, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except json.JSONDecodeError as e:
        print(f"Error: Failed to parse JSON in '{filepath}': {e}")
        return pd.DataFrame()

    # A summary without a "results" mapping is treated like any other
    # unusable input instead of raising KeyError/TypeError.
    results = data.get('results') if isinstance(data, dict) else None
    if not isinstance(results, dict):
        print(f"Error: No 'results' mapping found in '{filepath}'. Skipping k-sensitivity analysis.")
        return pd.DataFrame()

    all_runs_data = []
    contributing_benchmarks = set()
    for bench_name in benchmarks:
        bench_runs = results.get(bench_name)
        if not isinstance(bench_runs, dict):
            continue
        for run_key, run_data in bench_runs.items():
            try:
                params = run_data['parameters']
                metrics = run_data['metrics']
            except KeyError as e:
                print(f"Warning: Missing key {e} in run '{run_key}' for benchmark '{bench_name}'. Skipping.")
                continue

            k_value = params.get('detection_pipeline.top_k_semantic_search')
            f1_score = metrics.get('f1_unsafe')
            if k_value is None or f1_score is None:
                continue

            # One malformed k must not abort the whole parse.
            try:
                k_int = int(k_value)
            except (TypeError, ValueError):
                print(f"Warning: Non-integer k value {k_value!r} in run '{run_key}' for benchmark '{bench_name}'. Skipping.")
                continue

            all_runs_data.append({'k': k_int, 'f1_unsafe': f1_score})
            contributing_benchmarks.add(bench_name)

    if not all_runs_data:
        print("Error: No valid k-sensitivity data found. Cannot create plot.")
        return pd.DataFrame()

    df_all = pd.DataFrame(all_runs_data)

    # Pool the samples of every benchmark and average them per k.
    df_avg = df_all.groupby('k')['f1_unsafe'].mean().reset_index()
    df_avg = df_avg.rename(columns={'f1_unsafe': 'avg_f1'})
    df_avg = df_avg.sort_values('k').reset_index(drop=True)

    # Report the benchmarks that actually supplied data, not the requested count.
    print(f"Processed k-sensitivity data. Found {len(df_avg)} k-values by averaging over {len(contributing_benchmarks)} benchmarks.")
    return df_avg

def process_weights_sensitivity_data(filepath: str, benchmarks: list) -> pd.DataFrame:
    """
    Load the weight sweep summary and build the averaged-F1 heatmap table.

    Every run stored under ``results[<benchmark>]`` for a benchmark listed in
    ``benchmarks`` contributes one sample. The two weights are read from the
    ``detection_pipeline.weighted_majority_vote_weights.default_primary`` and
    ``...mixed_primary`` parameters, and the score from the ``f1_unsafe``
    metric. Samples from all benchmarks are pooled and averaged per weight
    pair.

    Args:
        filepath: Path to the experiment summary JSON file.
        benchmarks: Benchmark names to include; names absent from the file
            are skipped.

    Returns:
        A pivot table indexed by ``default_primary`` (descending) with
        ``mixed_primary`` columns (ascending) and mean ``f1_unsafe`` values;
        weight pairs with no run are NaN. An empty DataFrame is returned,
        after printing a message, when the file is missing, empty, not valid
        JSON, lacks a ``results`` mapping, or contains no usable runs.
    """
    if not os.path.exists(filepath) or os.path.getsize(filepath) == 0:
        print(f"Error: File '{filepath}' is missing or empty. Skipping weights-sensitivity analysis.")
        return pd.DataFrame()
    try:
        # JSON is UTF-8 by specification; do not depend on the locale default.
        with open(filepath, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except json.JSONDecodeError as e:
        print(f"Error: Failed to parse JSON in '{filepath}': {e}")
        return pd.DataFrame()

    # A summary without a "results" mapping is treated like any other
    # unusable input instead of raising KeyError/TypeError.
    results = data.get('results') if isinstance(data, dict) else None
    if not isinstance(results, dict):
        print(f"Error: No 'results' mapping found in '{filepath}'. Skipping weights-sensitivity analysis.")
        return pd.DataFrame()

    all_runs_data = []
    contributing_benchmarks = set()
    for bench_name in benchmarks:
        bench_runs = results.get(bench_name)
        if not isinstance(bench_runs, dict):
            continue
        for run_key, run_data in bench_runs.items():
            try:
                params = run_data['parameters']
                metrics = run_data['metrics']
            except KeyError as e:
                print(f"Warning: Missing key {e} in run '{run_key}' for benchmark '{bench_name}'. Skipping.")
                continue

            default_primary_weight = params.get('detection_pipeline.weighted_majority_vote_weights.default_primary')
            mixed_primary_weight = params.get('detection_pipeline.weighted_majority_vote_weights.mixed_primary')
            f1_score = metrics.get('f1_unsafe')

            # Runs lacking either weight or the score cannot be placed on the grid.
            if default_primary_weight is None or mixed_primary_weight is None or f1_score is None:
                continue

            all_runs_data.append({
                'default_primary': default_primary_weight,
                'mixed_primary': mixed_primary_weight,
                'f1_unsafe': f1_score,
            })
            contributing_benchmarks.add(bench_name)

    if not all_runs_data:
        print("Error: No valid weight sensitivity data found. Cannot create plot.")
        return pd.DataFrame()

    df_all = pd.DataFrame(all_runs_data)

    # Averaging first guarantees unique (row, column) pairs, which pivot() requires.
    df_avg = df_all.groupby(['default_primary', 'mixed_primary'])['f1_unsafe'].mean().reset_index()
    heatmap_data = df_avg.pivot(index='default_primary', columns='mixed_primary', values='f1_unsafe')

    # Largest default weight on the top row, smallest mixed weight in the left column.
    heatmap_data = heatmap_data.sort_index(ascending=False)
    heatmap_data = heatmap_data.sort_index(axis=1, ascending=True)

    # Report the benchmarks that actually supplied data, not the requested count.
    print(f"Processed weights-sensitivity data. Created a {heatmap_data.shape} pivot table by averaging over {len(contributing_benchmarks)} benchmarks.")
    return heatmap_data

def create_sensitivity_plot(k_data: pd.DataFrame, weights_data: pd.DataFrame, output_path: str):
    """
    Draw the two-panel sensitivity figure and save it to disk.

    The left panel is a line plot of ``avg_f1`` against ``k``; the right panel
    is an annotated seaborn heatmap of ``weights_data``. The figure is written
    to ``output_path`` and, additionally, to the same path with its extension
    replaced by ``.png``.

    Args:
        k_data: DataFrame with columns ``k`` and ``avg_f1``.
        weights_data: Pivot table of F1 values (rows: default primary weight,
            columns: mixed primary weight).
        output_path: Destination of the primary figure file; matplotlib
            infers the format from its extension.
    """
    setup_matplotlib_for_tmlr()
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))

    # Left panel: F1 as a function of k.
    ax1.plot(k_data['k'], k_data['avg_f1'], marker='o', linestyle='-', color='#0072B2', markersize=5)
    ax1.set_title('Parameter Sensitivity: k-Nearest Neighbors')
    ax1.set_xlabel('Number of Neighbors (k)')
    ax1.set_ylabel('Average F1 Score')
    ax1.grid(True, which='both', linestyle='--', linewidth=0.5)
    # Put a tick on every evaluated k so each marker is labelled.
    ax1.set_xticks(k_data['k'])

    # Right panel: F1 over the weight grid.
    sns.heatmap(
        weights_data,
        ax=ax2,
        annot=True,
        fmt=".3f",
        cmap="cividis",
        linewidths=.5,
        annot_kws={"size": 8}
    )
    ax2.set_title('Parameter Sensitivity: Ensemble Weights')
    ax2.set_xlabel(r'Mixed-Signal Primary Weight $W_{\mathrm{mix}}$')
    ax2.set_ylabel(r'Default Primary Weight $W_{\mathrm{def}}$')

    fig.tight_layout(pad=1.5)

    # Swap only the final extension: str.replace('.pdf', '.png') would also
    # rewrite '.pdf' inside directory names and would leave the path unchanged
    # (saving the same file twice) for names not ending in lowercase '.pdf'.
    path_root, _ = os.path.splitext(output_path)
    png_output_path = path_root + '.png'

    # Save through the figure object so the output does not depend on which
    # figure pyplot currently considers active.
    fig.savefig(output_path, bbox_inches='tight')
    if png_output_path != output_path:
        fig.savefig(png_output_path, bbox_inches='tight')

    print("\nFigure saved successfully to:")
    print(f"  PDF: {output_path}")
    print(f"  PNG: {png_output_path}")
    plt.close(fig)

if __name__ == "__main__":
    # Parse the top-k sweep, or fall back to an empty frame if the file is absent.
    if os.path.exists(K_SENSITIVITY_FILE):
        k_df = process_k_sensitivity_data(K_SENSITIVITY_FILE, TARGET_BENCHMARKS)
    else:
        print(f"Error: The k-sensitivity file was not found: {K_SENSITIVITY_FILE}")
        k_df = pd.DataFrame()

    # Same for the weight sweep.
    if os.path.exists(WEIGHTS_SENSITIVITY_FILE):
        weights_pivot_df = process_weights_sensitivity_data(WEIGHTS_SENSITIVITY_FILE, TARGET_BENCHMARKS)
    else:
        print(f"Error: The weights-sensitivity file was not found: {WEIGHTS_SENSITIVITY_FILE}")
        weights_pivot_df = pd.DataFrame()

    # The figure needs both panels; skip it when either data set is unavailable.
    if k_df.empty or weights_pivot_df.empty:
        print("\nPlot generation skipped due to one or more data files not being processed.")
    else:
        create_sensitivity_plot(k_df, weights_pivot_df, OUTPUT_FILENAME)