import os
import json
import re
import sys
from pathlib import Path
from collections import defaultdict
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

                          

               
                                             
try:
    from config.constants import BASE_PROJECT_DIR as BASE_DIR
except ImportError:
    # Fall back to the grandparent of this script's file when the project
    # config package is not importable (e.g. when run standalone).
    print("Warning: config.constants not found. Using script's parent directory as BASE_DIR.")
    BASE_DIR = Path(__file__).resolve().parent.parent

# The configured value may be a plain string; normalize it so the `/` path
# joins below work regardless of where BASE_DIR came from.
BASE_DIR = Path(BASE_DIR)

RESULTS_DIR = BASE_DIR / "benchmarks" / "results_data"  # one sub-directory per benchmark
OUTPUT_DIR = BASE_DIR / "paper_assets"  # generated PDF/PNG figures are written here

                                    
                                                                         
# Display name -> substring identifying that model's result files
# (result files are matched as ``*<substring>*.json`` inside a benchmark dir).
MODEL_MAPPING = dict([
    # FORTRESS (experimental) models
    ("FORTRESS Gemma 1B (Exp.)", "fortress_1b_all_augmented"),
    ("FORTRESS Gemma 4B (Exp.)", "fortress_4b_all_augmented"),
    ("FORTRESS Qwen 0.6B (Exp.)", "fortress_qwen_0_6b"),
    ("FORTRESS Qwen 4B (Exp.)", "fortress_qwen_4b"),
    # Baseline guard models
    ("AegisGuard Defensive", "aegisguard_aegis_ai_content_safety_llamaguard_defensive"),
    ("AegisGuard Permissive", "aegisguard_aegis_ai_content_safety_llamaguard_permissive"),
    ("Ayub XGBoost", "ayub_oai_xgb_guard"),
    ("GuardReasoner 1B", "guardreasoner_guardreasoner_1b"),
    ("GuardReasoner 3B", "guardreasoner_guardreasoner_3b"),
    ("GuardReasoner 8B", "guardreasoner_guardreasoner_8b"),
    ("LlamaGuard-3 1B", "llamaguard_llama_guard_3_1b"),
    ("LlamaGuard-3 8B", "llamaguard_llama_guard_3_8b"),
    ("ShieldGemma-1 2B", "shieldgemma_shieldgemma_2b"),
    ("ShieldGemma-1 9B", "shieldgemma_shieldgemma_9b"),
    ("ShieldGemma-2 4B", "shieldgemma2_shieldgemma_2_4b"),
    ("WildGuard (7B)", "wildguard_wildguard"),
    ("OpenAI Moderation", "openai_mod"),
])

# Short benchmark label (heatmap column) -> directory name under RESULTS_DIR.
BENCHMARK_MAPPING = dict(
    Aegis="aegis_v2_english",
    Ailum="ailuminate_english",
    FORT="fortress_dataset_english",
    Harm="harmbench_english",
    JBB="jailbreakbench_english",
    OAI="openai_moderation_english",
    Simple="simple_safety_english",
    XSafe="xsafety_multilingual",
    XSTest="xstest_english",
)

# Heatmap row order: baselines alphabetically, then the FORTRESS models last.
_BASELINE_ROWS = [
    "AegisGuard Defensive",
    "AegisGuard Permissive",
    "Ayub XGBoost",
    "GuardReasoner 1B",
    "GuardReasoner 3B",
    "GuardReasoner 8B",
    "LlamaGuard-3 1B",
    "LlamaGuard-3 8B",
    "OpenAI Moderation",
    "ShieldGemma-1 2B",
    "ShieldGemma-1 9B",
    "ShieldGemma-2 4B",
    "WildGuard (7B)",
]
_FORTRESS_ROWS = [
    "FORTRESS Qwen 4B (Exp.)",
    "FORTRESS Qwen 0.6B (Exp.)",
    "FORTRESS Gemma 4B (Exp.)",
    "FORTRESS Gemma 1B (Exp.)",
]
MODEL_ORDER = _BASELINE_ROWS + _FORTRESS_ROWS

# Heatmap column order follows the declaration order of BENCHMARK_MAPPING.
BENCHMARK_ORDER = list(BENCHMARK_MAPPING)

# Metric id -> column in the loaded DataFrame, colorbar title, annotation
# format, and the factor applied before plotting (fractions -> percent).
MAIN_METRICS = {
    metric_id: {'key': column, 'title': label, 'fmt': '.1f', 'scale': factor}
    for metric_id, column, label, factor in (
        ('f1', 'f1_unsafe', 'F1 Score', 100),
        ('precision', 'precision_unsafe', 'Precision', 100),
        ('recall', 'recall_unsafe', 'Recall', 100),
        ('latency', 'latency_ms', 'Average Latency (ms)', 1),
    )
}

                                      

def get_latest_file(files: list[Path]) -> "Path | None":
    """Return the most recent file in *files*, or ``None`` if *files* is empty.

    Recency comes from a ``_YYYYMMDD_HHMMSS`` timestamp embedded in the file
    name when one is present and parseable, otherwise from the file's
    modification time. Both are converted to seconds since the epoch, so
    timestamped and non-timestamped files are ranked on the same scale.
    When several files are equally recent, the first one in *files* wins.
    """
    from datetime import datetime

    if not files:
        return None
    if len(files) == 1:
        return files[0]

    stamp_re = re.compile(r'_(\d{8}_\d{6})')

    def recency(fp: Path) -> float:
        match = stamp_re.search(fp.name)
        if match:
            try:
                # NOTE(review): a naive datetime is interpreted as local time;
                # assumes the result writer stamped files in local time — confirm.
                return datetime.strptime(match.group(1), "%Y%m%d_%H%M%S").timestamp()
            except (ValueError, OverflowError, OSError):
                # Digits that are not a real date (e.g. month 13): use mtime instead.
                pass
        return fp.stat().st_mtime

    # max() keeps the first maximal element, matching a stable descending sort.
    return max(files, key=recency)

def load_and_process_data() -> pd.DataFrame:
    """
    Scan the results directories and build one row per model-benchmark pair.

    For every (model, benchmark) combination in MODEL_MAPPING x
    BENCHMARK_MAPPING whose benchmark directory exists, the newest
    ``*<model_pattern>*.json`` file (picked by ``get_latest_file``) is parsed.
    Each row holds the unsafe-class F1/precision/recall from the file's
    ``"metrics"`` object (0.0 when a metric key is missing) and the mean of the
    non-null ``processing_time_ms`` values in ``"individual_results"`` (0.0
    when there are none).

    Files that cannot be read or do not have the expected structure are
    reported and skipped instead of aborting the whole run.

    Returns:
        DataFrame with columns model_name, benchmark_name, f1_unsafe,
        precision_unsafe, recall_unsafe, latency_ms.

    Raises:
        ValueError: if no result file could be processed at all.
    """
    print("Starting data loading for main metrics...")
    # Resolve the benchmark directories once instead of re-checking them per model.
    benchmark_dirs = {
        bench_display_name: RESULTS_DIR / bench_dir_name
        for bench_display_name, bench_dir_name in BENCHMARK_MAPPING.items()
        if (RESULTS_DIR / bench_dir_name).is_dir()
    }

    all_data = []
    for model_display_name, model_pattern in MODEL_MAPPING.items():
        for bench_display_name, benchmark_path in benchmark_dirs.items():
            latest_file = get_latest_file(list(benchmark_path.glob(f"*{model_pattern}*.json")))
            if latest_file is None:
                continue
            try:
                with open(latest_file, 'r', encoding='utf-8') as f:
                    data = json.load(f)
                # `or {}` / `or []` also cover keys that are present but null.
                metrics = data.get("metrics") or {}
                latencies = [
                    res["processing_time_ms"]
                    for res in data.get("individual_results") or []
                    if res.get("processing_time_ms") is not None
                ]
                row = {
                    "model_name": model_display_name,
                    "benchmark_name": bench_display_name,
                    "f1_unsafe": metrics.get("f1_unsafe", 0.0),
                    "precision_unsafe": metrics.get("precision_unsafe", 0.0),
                    "recall_unsafe": metrics.get("recall_unsafe", 0.0),
                    "latency_ms": np.mean(latencies) if latencies else 0.0,
                }
            except (OSError, ValueError, KeyError, AttributeError, TypeError) as e:
                # OSError: unreadable file; ValueError: invalid JSON / encoding;
                # AttributeError / TypeError: JSON not shaped as expected
                # (e.g. a top-level list, or non-numeric latency values).
                print(f"  [Warning] Could not process {latest_file.name}: {e}")
                continue
            all_data.append(row)

    if not all_data:
        raise ValueError("No valid data found for main metrics.")
    print(f"Successfully processed {len(all_data)} result entries for main metrics.")
    return pd.DataFrame(all_data)

                              

def setup_matplotlib_for_paper():
    """Apply the shared publication styling to Matplotlib's global state.

    Switches to the ``seaborn-v0_8-paper`` style sheet, then overrides the
    font sizes, selects a serif font stack, and renders/saves at 300 dpi.
    """
    paper_rc = {
        'font.size': 8,
        'axes.labelsize': 9,
        'axes.titlesize': 11,
        'xtick.labelsize': 8,
        'ytick.labelsize': 8,
        'legend.fontsize': 8,
        'font.family': 'serif',
        'font.serif': ['Times New Roman', 'serif'],
        'text.usetex': False,
        'figure.dpi': 300,
        'savefig.dpi': 300,
    }
    plt.style.use('seaborn-v0_8-paper')
    plt.rcParams.update(paper_rc)

def create_heatmap(df: pd.DataFrame, metric_info: dict, output_dir: Path):
    """Render one model x benchmark heatmap for a metric and save it as PDF and PNG.

    Args:
        df: Long-format results with ``model_name``, ``benchmark_name`` and the
            column named by ``metric_info['key']``.
        metric_info: Entry of MAIN_METRICS: ``key`` (column), ``title``
            (colorbar label), ``fmt`` (cell annotation format) and ``scale``
            (multiplier applied to the values before plotting).
        output_dir: Directory that receives ``heatmap_<key>.pdf`` / ``.png``.

    Rows follow MODEL_ORDER and columns BENCHMARK_ORDER; models with no value
    for any benchmark are dropped. If nothing is left to plot, a warning is
    printed and no file is written.
    """
    metric_key = metric_info['key']
    title = metric_info['title']
    fmt = metric_info['fmt']
    scale = metric_info['scale']
    print(f"Generating heatmap for: {title}...")

    pivot_df = (
        df.pivot(index="model_name", columns="benchmark_name", values=metric_key)
        .reindex(index=MODEL_ORDER, columns=BENCHMARK_ORDER)
        .dropna(how='all')
    )
    if metric_key == 'latency_ms':
        # These two models are excluded from the latency figure
        # (presumably not comparable to locally-run guard models — confirm).
        pivot_df = pivot_df.drop(index=["OpenAI Moderation", "Ayub XGBoost"], errors='ignore')
    if pivot_df.empty:
        # sns.heatmap cannot draw an empty frame; skip instead of aborting later figures.
        print(f"  [Warning] No data available for {title}; skipping heatmap.")
        return
    pivot_df = pivot_df * scale

    setup_matplotlib_for_paper()
    fig, ax = plt.subplots(figsize=(10, max(5, len(pivot_df.index) * 0.2)))
    try:
        sns.heatmap(pivot_df, annot=True, fmt=fmt, cmap="viridis", linewidths=.5,
                    cbar_kws={'label': title}, ax=ax)
        ax.set_xlabel("Benchmark", fontsize=10)
        ax.set_ylabel("Model", fontsize=10)
        plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
        plt.setp(ax.get_yticklabels(), rotation=0)
        output_path = output_dir / f"heatmap_{metric_key}.pdf"
        fig.savefig(output_path, bbox_inches='tight')
        fig.savefig(output_path.with_suffix(".png"), bbox_inches='tight')
        print(f"  -> Saved to {output_path}")
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)

def _count_xsafety_outcomes(results, lang_code_pattern) -> dict:
    """Tally TP/FP/FN per upper-cased language code for one model's results.

    Only results whose ``source_file_input`` matches *lang_code_pattern* are
    counted; the language code is the pattern's first group. Labels are
    compared case-insensitively and true negatives are not tracked, so a
    language appears only once it has at least one TP, FP or FN.
    """
    counts = defaultdict(lambda: {"TP": 0, "FP": 0, "FN": 0})
    for result in results:
        # `or ""` tolerates keys that are present with a null value.
        match = lang_code_pattern.search(str(result.get("source_file_input") or ""))
        if not match:
            continue
        lang_code = match.group(1).upper()
        true_label = str(result.get("true_label") or "").upper()
        pred_label = str(result.get("predicted_label") or "").upper()
        if true_label == "UNSAFE" and pred_label == "UNSAFE":
            counts[lang_code]["TP"] += 1
        elif true_label == "SAFE" and pred_label == "UNSAFE":
            counts[lang_code]["FP"] += 1
        elif true_label == "UNSAFE" and pred_label != "UNSAFE":
            counts[lang_code]["FN"] += 1
    return counts


def _f1_percent(tp: int, fp: int, fn: int) -> float:
    """Return the F1 score (scaled to 0-100) from confusion counts; 0 when undefined."""
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0
    f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
    return f1 * 100


def generate_xsafety_language_heatmap(output_dir: Path):
    """Compute per-language F1 for every model on XSafety and save it as a heatmap.

    For each model in MODEL_MAPPING, the newest matching result JSON in the
    XSafety directory is read. Its ``individual_results`` are grouped by the
    language code embedded in ``source_file_input``
    (``all_languages_matter_<code>.jsonl``). F1 on the UNSAFE class is then
    computed per language. Writes ``heatmap_xsafety_per_language.pdf`` and
    ``.png`` into *output_dir*. Prints a message and returns without writing
    when the directory is missing or no data could be extracted.
    """
    print("\nGenerating XSafety Per-Language Performance Heatmap...")
    benchmark_dir = RESULTS_DIR / BENCHMARK_MAPPING["XSafe"]
    if not benchmark_dir.is_dir():
        print(f"  [Error] XSafety directory not found: {benchmark_dir}")
        return

    lang_code_pattern = re.compile(r"all_languages_matter_(\w+)\.jsonl")
    stats = {}
    for model_display_name, model_pattern in MODEL_MAPPING.items():
        latest_file = get_latest_file(list(benchmark_dir.glob(f"*{model_pattern}*.json")))
        if latest_file is None:
            continue
        print(f"  -> Processing {model_display_name} from {latest_file.name}")
        try:
            with open(latest_file, 'r', encoding='utf-8') as f:
                data = json.load(f)
            model_counts = _count_xsafety_outcomes(data.get("individual_results") or [], lang_code_pattern)
        except Exception as e:
            # Best-effort per file, but the model's tally is dropped as a whole:
            # keeping counts from a half-processed file would silently skew its F1.
            print(f"    [Warning] Error processing file {latest_file.name}: {e}")
            continue
        if model_counts:
            stats[model_display_name] = model_counts

    f1_data = [
        {"model_name": model, "language": lang,
         "f1_score": _f1_percent(counts["TP"], counts["FP"], counts["FN"])}
        for model, lang_stats in stats.items()
        for lang, counts in lang_stats.items()
    ]
    if not f1_data:
        print("  [Error] No per-language data could be extracted. Skipping heatmap.")
        return

    pivot_df = (
        pd.DataFrame(f1_data)
        .pivot(index="model_name", columns="language", values="f1_score")
        .reindex(index=MODEL_ORDER)
        .dropna(how='all')
    )

    setup_matplotlib_for_paper()
    fig, ax = plt.subplots(figsize=(10, max(5, len(pivot_df.index) * 0.2)))
    try:
        # Same colormap as the main F1 heatmap ("viridis"), so bright cells mean
        # high F1 in every figure of the paper.
        sns.heatmap(pivot_df, annot=True, fmt=".1f", cmap="viridis", linewidths=.5,
                    cbar_kws={'label': 'F1 Score'}, ax=ax)
        ax.set_xlabel("Language Code", fontsize=10)
        ax.set_ylabel("Model", fontsize=10)
        plt.setp(ax.get_xticklabels(), rotation=0)
        plt.setp(ax.get_yticklabels(), rotation=0)
        output_path = output_dir / "heatmap_xsafety_per_language.pdf"
        fig.savefig(output_path, bbox_inches='tight')
        fig.savefig(output_path.with_suffix(".png"), bbox_inches='tight')
        print(f"  -> Saved to {output_path}")
    finally:
        plt.close(fig)

                           

if __name__ == "__main__":
    print("Starting paper asset generation script.")
    print(f"Results Source: {RESULTS_DIR.resolve()}")
    print(f"Assets Destination: {OUTPUT_DIR.resolve()}")
    # parents=True so a fresh checkout without the base directory still works.
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    try:
        # Main model x benchmark metrics -> one heatmap per metric.
        master_df = load_and_process_data()
        for metric_details in MAIN_METRICS.values():
            create_heatmap(master_df, metric_details, OUTPUT_DIR)

        # Per-language breakdown for the multilingual XSafety benchmark.
        generate_xsafety_language_heatmap(OUTPUT_DIR)

        print("\n[SUCCESS] All assets generated successfully!")

    except Exception as e:
        print(f"\n[ERROR] An error occurred during asset generation: {e}")
        import traceback
        traceback.print_exc()
        # Exit non-zero so callers (make, CI, shell scripts) see the failure.
        sys.exit(1)