import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd


# Root folder under which every dataset keeps its per-run result directories.
_RESULTS_ROOT = "/home/wmar/wmar_audio/outputs/wm_generations_new"


def _results_csv(dataset, run, kind):
    """Return the path of ``results_<kind>.csv`` for *run* inside *dataset*."""
    return f"{_RESULTS_ROOT}/{dataset}/{run}/results_{kind}.csv"


def _four_way_comparison(dataset, run):
    """Map each method label to its results CSV for a regular *run*.

    The insertion order (Base, WMAR (aug), WMAR, Ours) is the order in which
    the entries are iterated by consumers of the returned dict.
    """
    return {
        "Base": _results_csv(dataset, run, "standard"),
        "WMAR (aug)": _results_csv(dataset, f"meta_aug_{run}", "standard"),
        "WMAR": _results_csv(dataset, f"meta_noaug_{run}", "standard"),
        "Ours": _results_csv(dataset, run, "selected"),
    }


# scenario name -> (minimum FPR for the scenario, {method label: results CSV}).
# The audio_prompts and librispeech scenarios share one layout and differ only
# in the dataset folder, the h index and the minimum FPR.
SCENARIOS = {
    f"clusters1234_h{h}_delta2_{dataset}": (
        min_fpr,
        _four_way_comparison(dataset, f"clusters1234_h{h}_delta2"),
    )
    for dataset, fpr_floors in (
        ("audio_prompts", (1e-30, 1e-15, 1e-7)),
        ("librispeech", (1e-20, 1e-10, 1e-5)),
    )
    for h, min_fpr in enumerate(fpr_floors)
}

# The music scenario does not follow the regular layout: it has no
# "WMAR (aug)" entry and its Base/Ours results live in the
# final_clusters0123_h0_delta2 run folder, so it is spelled out explicitly.
SCENARIOS["clusters1234_h0_delta2_music"] = (
    1e-40,
    {
        "Base": _results_csv("music", "final_clusters0123_h0_delta2", "standard"),
        "WMAR": _results_csv("music", "meta_noaug_clusters1234_h0_delta2", "standard"),
        "Ours": _results_csv("music", "final_clusters0123_h0_delta2", "selected"),
    },
)

plots_dir = "/home/wmar/wmar_audio/outputs/paper_plots"
# savefig does not create missing directories, so make sure the target exists.
os.makedirs(plots_dir, exist_ok=True)


def _detection_rates(pvals, thresholds):
    """Return, for each threshold, the fraction of p-values at or below it.

    Args:
        pvals: 1-D NumPy array of p-values.
        thresholds: 1-D NumPy array of thresholds to evaluate.

    Returns:
        1-D array with one rate per entry of ``thresholds``. NaN p-values
        never satisfy the comparison, so they count as non-detections.
    """
    # Broadcast (N_pvals, 1) against (1, N_thresholds) to get an
    # (N_pvals, N_thresholds) boolean matrix, then average over the p-values.
    detected = pvals[:, None] <= thresholds[None, :]
    return detected.mean(axis=0)


# One figure per scenario: the detection rate of every method as a function
# of the p-value threshold, saved as detection_<scenario>.png in plots_dir.
for scenario, (min_fpr, paths) in SCENARIOS.items():
    # 1000 log-spaced thresholds from the scenario's minimum FPR up to 1e-1.
    target_thresholds = np.logspace(np.log10(min_fpr), -1, 1000)

    fig, ax = plt.subplots(figsize=(4, 3))

    for method, path in paths.items():
        pvals = pd.read_csv(path)["pval"].to_numpy()
        ax.plot(
            target_thresholds,
            _detection_rates(pvals, target_thresholds),
            label=method,
        )

    ax.axvline(x=1e-2, color='black', linestyle='--', label="1% threshold")

    ax.set_xlabel('Theoretical FPR')
    ax.set_ylabel('True Positive Rate')
    ax.set_xscale('log')
    ax.set_xlim(min_fpr, 1e-1)
    ax.set_ylim(0, 1.05)
    ax.grid()
    ax.legend()
    fig.savefig(f"{plots_dir}/detection_{scenario}.png", dpi=200, bbox_inches="tight")
    # pyplot keeps every figure it created alive until it is closed; release
    # this one before moving on so figures do not accumulate across scenarios.
    plt.close(fig)
