
import math
import json
import os
import sys
from collections import defaultdict
import pandas as pd

import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import roc_auc_score

from genpaths import *

def remove_invalid(lst):
    """Return a new list holding the elements of *lst* that are not NaN.

    Element order is preserved. Every element is passed to ``math.isnan``,
    so non-numeric entries raise ``TypeError``.
    """
    kept = []
    for value in lst:
        if math.isnan(value):
            continue
        kept.append(value)
    return kept

def get_path(path_list, idx):
    """Build the score-file path for entry *idx* of *path_list*.

    The file name is the basename of ``path_list[idx]`` followed by
    ``"_eval-"``, the last element of *path_list*, and ``".json"``. The
    returned path places that file inside the ``mtd_scores`` directory.
    """
    stem = os.path.basename(path_list[idx])
    eval_tag = path_list[-1]
    # str.join (like ``+``) refuses non-string parts instead of formatting them.
    file_name = "".join((stem, "_eval-", eval_tag, ".json"))
    return os.path.join("mtd_scores", file_name)

def load(path_list, idx):
    """Read and parse the JSON score file for entry *idx* of *path_list*.

    The file location is resolved with ``get_path``. The file is opened in a
    ``with`` block so the handle is closed deterministically rather than
    being left for the garbage collector.

    Returns:
        The decoded JSON document.

    Raises:
        OSError: if the file cannot be opened.
        json.JSONDecodeError: if the file does not contain valid JSON.
    """
    path = get_path(path_list, idx)
    # JSON text is UTF-8 by specification; do not depend on the locale default.
    with open(path, encoding="utf-8") as f:
        return json.load(f)

def _block_means(scores, block_size):
    """Average *scores* over consecutive, non-overlapping blocks.

    A trailing remainder shorter than *block_size* is dropped.
    """
    return [
        sum(scores[i:i + block_size]) / block_size
        for i in range(0, len(scores) - block_size + 1, block_size)
    ]


def _clean_scores(score_table, detectors):
    """Return ``{detector: scores}`` with NaNs removed and Binoculars negated."""
    cleaned = {}
    for detector in detectors:
        scores = remove_invalid(score_table[detector])
        # Machine text is labelled 1 when the AUROC is computed, so higher
        # scores must mean "more machine-like". Binoculars is negated to fit
        # that orientation (presumably its raw score is lower for machine
        # text -- confirm against the scorer).
        if detector == "Binoculars":
            scores = [-s for s in scores]
        cleaned[detector] = scores
    return cleaned


def main(
    max_fpr: float = 0.1,
):
    """Compute partial AUROCs for every detector/method and render the figures.

    For each dataset, attack method, and block size ``N``, the per-sample
    detector scores are averaged over non-overlapping blocks of ``N`` samples.
    The human and machine block means are then scored with
    ``roc_auc_score(..., max_fpr=max_fpr)``.

    Files written:
        * ``MGTD.pdf`` -- per dataset, the maximum AUROC over all detectors
          as a function of ``N`` for each method (``Ours_No_DPO`` is not drawn).
        * ``./appendix/<method>.pdf`` -- per-detector curves on the reddit
          dataset, one file per method.
        * ``./ablations_DPO.pdf`` -- ``Ours`` vs. ``Ours_No_DPO`` on reddit.
        * ``MGTD_compare.pdf`` -- per-detector curves for ``LLMOPT`` and
          ``Ours`` on reddit, side by side.

    Args:
        max_fpr: forwarded to ``roc_auc_score`` as its ``max_fpr`` argument.
            NOTE: the axis labels hard-code "AUROC(10)", which presumably
            corresponds to the default of 0.1; they do not follow this value.

    Returns:
        ``0``, which is used as the process exit status.
    """
    dataset_names = ["reddit", "amazon", "blogs"]
    # Position of each dataset's entry inside the path lists.
    name2idx = dict(zip(dataset_names, [2, 0, 1]))

    DETECTOR_ORDER = [
        "ReMoDetect",
        "RADAR",
        "Rank",
        "LogRank",
        "FastDetectGPT",
        "Binoculars",
        "StyleDetect",
        "StyleDetect-CISR",
        "StyleDetect-SD",
    ]

    # Display names for methods whose legend label differs from their key.
    LABEL_MAP = {
        "LLMOPT": "Mistral-7B-DPO-FastDetectGPT",
        "Ours": "Ours (Mistral-7B)",
        "Ours_No_DPO": "Ours (no DPO)",
    }

    detectors = DETECTOR_ORDER
    methods = ["Baseline", "LLMOPT", "OUTFOX", "Paraphrasing", "DIPPER", "Prompting", "TinyStyler", "Ours", "Ours_No_DPO"]
    Ns = [1, 5, 10, 25, 50]

    # Colors for methods
    method_colors = dict(zip(
        methods,
        plt.rcParams['axes.prop_cycle'].by_key()['color'][:len(methods)]
    ))

    fig, axes = plt.subplots(1, len(dataset_names), figsize=(12, 4), sharey=True, dpi=300)

    # Reddit-only results, reused by the appendix/ablation/comparison figures.
    aucs_by_method = {}      # method -> {detector: [AUROC for each N]}
    reddit_max_aucs = {}     # method -> [max-over-detectors AUROC for each N]

    for ax, dataset in zip(axes, dataset_names):
        idx = name2idx[dataset]
        ax.set_title(dataset.capitalize(), fontsize=14)
        ax.set_xlabel("N", fontsize=12)
        ax.set_ylabel("Max AUROC(10)", fontsize=12)
        ax.grid(True, linestyle="--", alpha=0.4)

        # Human scores depend only on the dataset: load, clean and aggregate
        # them once instead of re-reading the file for every (method, N,
        # detector) combination.
        human_scores = _clean_scores(load(HUMAN, idx), detectors)
        human_blocks = {
            (detector, N): _block_means(human_scores[detector], N)
            for detector in detectors
            for N in Ns
        }

        for method in methods:
            label = LABEL_MAP.get(method, method)

            # "Baseline" uses MACHINE; every other method uses the module-level
            # name equal to its upper-cased key (presumably provided by the
            # ``genpaths`` star import).
            machine_paths = MACHINE if method == "Baseline" else globals()[method.upper()]
            machine_scores = _clean_scores(load(machine_paths, idx), detectors)

            aucs = []
            aucs_by_detector = defaultdict(list)
            for N in Ns:
                aucs_at_N = []
                for detector in detectors:
                    h_agg = human_blocks[(detector, N)]
                    m_agg = _block_means(machine_scores[detector], N)

                    # human blocks are class 0, machine blocks are class 1
                    labels = [0] * len(h_agg) + [1] * len(m_agg)
                    auc = roc_auc_score(labels, h_agg + m_agg, max_fpr=max_fpr)
                    aucs_at_N.append(auc)
                    aucs_by_detector[detector].append(auc)
                # the main figure reports the strongest detector at each N
                aucs.append(max(aucs_at_N))

            if dataset == "reddit":
                aucs_by_method[method] = aucs_by_detector
                reddit_max_aucs[method] = aucs

            if method != "Ours_No_DPO":
                ax.plot(
                    Ns, aucs,
                    marker='o',
                    linestyle='-',
                    linewidth=1.5,
                    color=method_colors[method],
                    label=label
                )

            if method == "Ours":
                # Anchor the call-out at the third-from-last N, except on
                # blogs where it is anchored at the first N.
                if dataset != "blogs":
                    x_pt, y_pt = Ns[-3], aucs[-3]
                else:
                    x_pt, y_pt = Ns[0], aucs[0]
                ax.annotate(
                    "Ours (lower is better)",
                    xy=(x_pt, y_pt),
                    xycoords="data",
                    xytext=(x_pt+8, y_pt-0.1),   # adjust offset as needed
                    textcoords="data",
                    arrowprops=dict(
                        arrowstyle="->",
                        color="red",
                        lw=2
                    ),
                    fontsize=12,
                    color="red",
                    fontweight="bold"
                )

    # Shared legend below; every panel plots the same methods, so the first
    # panel's handles are used for the whole figure.
    handles, labels = axes[0].get_legend_handles_labels()
    fig.tight_layout(rect=[0, 0.15, 1, 0.95])

    fig.legend(
        handles, labels,
        loc="lower center",
        ncol=4,
        fontsize=16,
        bbox_to_anchor=(0.5, -0.03)
    )

    plt.savefig("MGTD.pdf")
    plt.close()

    # savefig does not create missing directories.
    os.makedirs("./appendix", exist_ok=True)
    for method, method_data in aucs_by_method.items():
        plt.figure(dpi=300)
        for detector_name, aucs in method_data.items():
            plt.plot(Ns, aucs, linewidth=1.5, label=detector_name, marker="o")
        plt.legend()
        plt.xlabel("N")
        plt.ylabel("AUROC(10)")
        plt.savefig("./appendix/{}.pdf".format(method))
        plt.close()

    # DPO ablation on reddit.
    plt.figure()
    plt.plot(Ns, reddit_max_aucs["Ours"], linewidth=1.5, marker="o", label="Ours")
    plt.plot(Ns, reddit_max_aucs["Ours_No_DPO"], linewidth=1.5, marker="o", label="Ours (no DPO)")
    plt.xlabel("N", fontsize=16)
    plt.ylabel("AUROC(10)", fontsize=16)
    plt.legend(fontsize=16)
    plt.savefig("./ablations_DPO.pdf")
    plt.close()

    # Side-by-side per-detector comparison of LLMOPT and Ours on reddit.
    fig, axes = plt.subplots(1, 2, figsize=(8, 4), sharey=True, dpi=300)
    axes[0].set_title(LABEL_MAP["LLMOPT"])
    axes[0].set_xlabel("N", fontsize=12)
    axes[0].set_ylabel("AUROC(10)", fontsize=12)
    axes[0].grid(True, linestyle="--", alpha=0.4)
    for detector, values in aucs_by_method["LLMOPT"].items():
        axes[0].plot(Ns, values, marker="o", linestyle="-", linewidth=1.5, label=detector)
    axes[1].set_title(LABEL_MAP["Ours"])
    axes[1].set_xlabel("N", fontsize=12)
    # No y-label on the right panel: the y-axis is shared with the left one.
    axes[1].grid(True, linestyle="--", alpha=0.4)
    for detector, values in aucs_by_method["Ours"].items():
        axes[1].plot(Ns, values, marker="o", linestyle="-", linewidth=1.5, label=detector)
    handles, labels = axes[0].get_legend_handles_labels()
    fig.tight_layout(rect=[0, 0.15, 1, 0.95])
    fig.legend(
        handles, labels,
        loc="lower center",
        ncol=5,
        fontsize=10,
    )
    plt.savefig("MGTD_compare.pdf")
    plt.close()

    return 0

if __name__ == "__main__":
    # Run the script and hand main()'s return value to the interpreter as the
    # process exit status.
    raise SystemExit(main())