
import math
import json
import os
import sys
from collections import defaultdict
import pandas as pd

import fire
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import roc_auc_score

from genpaths import *

def remove_invalid(lst):
    """Return the entries of ``lst`` that are usable numeric scores.

    NaN values and ``None`` (what a JSON ``null`` decodes to) are dropped;
    the remaining elements keep their original order.
    """
    # Test for None first: math.isnan(None) would raise TypeError.
    return [elem for elem in lst if elem is not None and not math.isnan(elem)]

def get_path(path_list, idx):
    """Build the score-file path for entry ``idx`` of ``path_list``.

    The file lives under ``mtd_scores/`` and is named
    ``<basename of path_list[idx]>_eval-<path_list[-1]>.json``.
    """
    stem = os.path.basename(path_list[idx])
    file_name = "".join((stem, "_eval-", path_list[-1], ".json"))
    return os.path.join("mtd_scores", file_name)

def load(path_list, idx):
    """Read and decode the JSON score file selected by ``path_list`` and ``idx``.

    The file location is resolved by ``get_path``; the decoded JSON value is
    returned unchanged.
    """
    path = get_path(path_list, idx)
    # `with` closes the handle promptly instead of relying on garbage collection.
    with open(path, encoding="utf-8") as f:
        return json.load(f)

def _fpr_percent_label(max_fpr):
    """Return ``max_fpr`` as a percentage string for axis labels.

    ``int(max_fpr * 100)`` would truncate floating-point error
    (``0.29 * 100 == 28.999...``), so ``:g`` formatting is used:
    0.01 -> "1", 0.29 -> "29", 0.005 -> "0.5".
    """
    return "{:g}".format(max_fpr * 100)


def _block_means(scores, n):
    """Average ``scores`` over consecutive, non-overlapping blocks of ``n`` items.

    A trailing partial block shorter than ``n`` is dropped.
    """
    return [sum(scores[i:i + n]) / n for i in range(0, len(scores) - n + 1, n)]


def _load_detector_scores(path_list, idx, detectors):
    """Load one score file and return ``{detector: cleaned score list}``.

    Invalid entries are removed with ``remove_invalid``. Binoculars scores are
    negated so that, like the labels used for AUROC (human = 0, machine = 1),
    a higher value means "more machine-like" (presumes raw Binoculars scores
    are lower for machine text -- confirm).
    """
    data = load(path_list, idx)
    cleaned = {}
    for detector in detectors:
        values = remove_invalid(data[detector])
        if detector == "Binoculars":
            values = [-s for s in values]
        cleaned[detector] = values
    return cleaned


def main(
    max_fpr: float = 0.01,
):
    """Plot and tabulate detector AUROC against the number of aggregated samples.

    For every dataset and evasion method, human and machine detector scores
    are averaged over non-overlapping blocks of ``N`` samples (for each ``N``
    in ``Ns``). The AUROC restricted to FPR <= ``max_fpr`` (sklearn's
    standardized partial AUC) is computed for every detector, and the best
    detector per ``N`` is plotted. Files written to the working directory:

    * ``MGTD_{max_fpr:.2f}.pdf/.png``: one panel per dataset.
    * ``MGTD_{max_fpr:.2f}.txt``: markdown tables of the plotted curves.
    * ``REDDIT_FULL_AUC_TABLE.md``: unrestricted AUROC at N=1 on reddit.
    * ``appendix/{method}_{max_fpr:.2f}.pdf/.png``: per-detector reddit curves.
    * ``MGTD_compare_{max_fpr:.2f}.pdf/.png``: LLMOPT vs. Ours, per detector.
    * ``ablations_DPO_{max_fpr:.2f}.pdf/.png``: only when both "Ours" and
      "Ours_No_DPO" are listed in ``methods``.

    Args:
        max_fpr: FPR cap passed to ``roc_auc_score(max_fpr=...)``.

    Returns:
        0 on success.
    """
    dataset_names = ["reddit", "amazon", "blogs"]
    # Index passed to load() to pick each dataset's entry in a path list.
    name2idx = dict(zip(dataset_names, [2, 0, 1]))

    DETECTOR_ORDER = [
        "ReMoDetect",
        "RADAR",
        "Rank",
        "LogRank",
        "FastDetectGPT",
        "Binoculars",
        "StyleDetect",
        "StyleDetect-CISR",
        "StyleDetect-SD",
        "SemDetect",
    ]

    LABEL_MAP = {
        "Prompting": "Prompting (Style-aware)",
        "LLMOPT": "Detector-Guided DPO (Target: FastDetectGPT)",
        "Ours": "Style-aware Paraphrasing",
        "Ours_No_DPO": "Ours (no DPO)",
        "DG": "D+G (Base)",
        "DG04": "D+G (D=0.4)",
    }

    # Detectors (columns) reported in the reddit full-AUROC table.
    REDDIT_TABLE_DETECTORS = ["FastDetectGPT", "Binoculars", "StyleDetect", "SemDetect"]

    detectors = DETECTOR_ORDER
    # Append "Ours_No_DPO" to also compute the no-DPO ablation; it is kept out
    # of the main figure and plotted separately below.
    methods = ["Baseline", "LLMOPT", "OUTFOX", "Paraphrasing", "DIPPER", "Prompting", "TinyStyler", "Ours"]
    Ns = [1, 5, 10, 25, 50]
    fpr_label = _fpr_percent_label(max_fpr)

    # Cycle the default colour cycle so extra methods never cause a KeyError.
    cycle_colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
    method_colors = {m: cycle_colors[i % len(cycle_colors)] for i, m in enumerate(methods)}

    fig, axes = plt.subplots(1, len(dataset_names), figsize=(12, 4), sharey=True, dpi=300)
    if not hasattr(axes, "__iter__"):
        axes = [axes]

    aucs_by_method = {}       # reddit only: method -> {detector: [AUROC per N]}
    reddit_max_aucs = {}      # reddit only: method -> [max-over-detectors AUROC per N]
    reddit_custom_table = {}  # reddit only: method -> {detector: full AUROC at N=1}
    for ax, dataset in zip(axes, dataset_names):
        idx = name2idx[dataset]
        ax.set_title(dataset.capitalize(), fontsize=14)
        ax.set_xlabel("Number of Samples", fontsize=12)
        ax.set_ylabel("Max AUROC({})".format(fpr_label), fontsize=12)
        ax.grid(True, linestyle="--", alpha=0.4)

        # Human scores are shared by every method: load and clean them once.
        human_scores = _load_detector_scores(HUMAN, idx, detectors)

        table_data = {}
        for method in methods:
            label = LABEL_MAP.get(method, method)
            # Path lists come from module globals named after the method (e.g. OUTFOX).
            machine_paths = MACHINE if method == "Baseline" else globals()[method.upper()]
            machine_scores = _load_detector_scores(machine_paths, idx, detectors)

            aucs = []
            aucs_by_detector = defaultdict(list)
            for N in Ns:
                aucs_at_N = []
                for detector in detectors:
                    h_agg = _block_means(human_scores[detector], N)
                    m_agg = _block_means(machine_scores[detector], N)
                    labels = [0] * len(h_agg) + [1] * len(m_agg)
                    scores = h_agg + m_agg

                    auc = roc_auc_score(labels, scores, max_fpr=max_fpr)
                    aucs_at_N.append(auc)
                    aucs_by_detector[detector].append(auc)

                    if dataset == "reddit" and N == 1 and detector in REDDIT_TABLE_DETECTORS:
                        reddit_custom_table.setdefault(method, {})[detector] = roc_auc_score(labels, scores)

                # The strongest detector at this N defines the plotted point.
                aucs.append(max(aucs_at_N))

            if dataset == "reddit":
                aucs_by_method[method] = aucs_by_detector
                reddit_max_aucs[method] = aucs

            if method != "Ours_No_DPO":
                ax.plot(
                    Ns, aucs,
                    marker='o',
                    linestyle='-',
                    linewidth=1.5,
                    color=method_colors[method],
                    label=label
                )
                table_data[label] = aucs

            if method == "Ours":
                # Per-dataset anchor point for the annotation arrow.
                if dataset != "blogs":
                    x_pt, y_pt = Ns[-3], aucs[-3]
                else:
                    x_pt, y_pt = Ns[0], aucs[0]
                ax.annotate(
                    "Ours (lower is better)",
                    xy=(x_pt, y_pt),
                    xycoords="data",
                    xytext=(x_pt+8, y_pt-0.1),   # adjust offset as needed
                    textcoords="data",
                    arrowprops=dict(
                        arrowstyle="->",
                        color="red",
                        lw=2
                    ),
                    fontsize=12,
                    color="red",
                    fontweight="bold"
                )

        if dataset == "reddit":
            desired_methods = ["Baseline", "LLMOPT", "OUTFOX", "Paraphrasing", "DIPPER", "Prompting", "TinyStyler"]
            # Missing method/detector combinations show up as NaN cells.
            data_for_df = {
                meth: [reddit_custom_table.get(meth, {}).get(det, float('nan')) for det in REDDIT_TABLE_DETECTORS]
                for meth in desired_methods
            }

            full_auc_df = pd.DataFrame(data_for_df, index=REDDIT_TABLE_DETECTORS).T
            print("\n### Reddit Full AUC (N=1)\n")
            reddit_markdown_table = full_auc_df.to_markdown(floatfmt=".4f")
            print(reddit_markdown_table)

            with open("REDDIT_FULL_AUC_TABLE.md", "w", encoding="utf-8") as f:
                f.write(reddit_markdown_table)

        # Create and print Markdown table
        df = pd.DataFrame(table_data, index=Ns).T
        print(f"\n### {dataset.capitalize()}\n")
        markdown_table = df.to_markdown(floatfmt=".2f")
        print(markdown_table)

        # The first dataset truncates the file; later datasets append to it.
        txt_path = "MGTD_{:.2f}.txt".format(max_fpr)
        mode = 'w' if dataset == dataset_names[0] else 'a'
        with open(txt_path, mode, encoding="utf-8") as f:
            f.write(f"\n### {dataset.capitalize()}\n\n")
            f.write(markdown_table)
            f.write("\n")

    # Shared legend below
    handles, labels = axes[0].get_legend_handles_labels()
    fig.tight_layout(rect=[0, 0.15, 1, 0.95])

    fig.legend(
        handles, labels,
        loc="lower center",
        ncol=5,
        fontsize=10,
        bbox_to_anchor=(0.5, -0.03)
    )

    plt.savefig("MGTD_{:.2f}.pdf".format(max_fpr))
    plt.savefig("MGTD_{:.2f}.png".format(max_fpr))
    plt.close()

    # Per-method, per-detector curves (reddit). savefig does not create folders.
    os.makedirs("appendix", exist_ok=True)
    for method, method_data in aucs_by_method.items():
        plt.figure(dpi=300)
        for detector_name, det_aucs in method_data.items():
            plt.plot(Ns, det_aucs, linewidth=1.5, label=detector_name, marker="o")
        plt.legend()
        plt.xlabel("Number of Samples")
        plt.ylabel("AUROC({})".format(fpr_label))
        plt.savefig("./appendix/{}_{:.2f}.pdf".format(method, max_fpr))
        plt.savefig("./appendix/{}_{:.2f}.png".format(method, max_fpr))
        plt.close()

    # DPO ablation (reddit): only possible when "Ours_No_DPO" was evaluated.
    if "Ours" in reddit_max_aucs and "Ours_No_DPO" in reddit_max_aucs:
        plt.figure()
        plt.plot(Ns, reddit_max_aucs["Ours"], linewidth=1.5, marker="o", label="Ours")
        plt.plot(Ns, reddit_max_aucs["Ours_No_DPO"], linewidth=1.5, marker="o", label="Ours (no DPO)")
        plt.xlabel("Number of Samples", fontsize=16)
        plt.ylabel("AUROC({})".format(fpr_label), fontsize=16)
        plt.legend(fontsize=16)
        plt.savefig("./ablations_DPO_{:.2f}.pdf".format(max_fpr))
        plt.savefig("./ablations_DPO_{:.2f}.png".format(max_fpr))
        plt.close()

    # Side-by-side per-detector curves: LLMOPT (left) vs. Ours (right), reddit.
    fig, axes = plt.subplots(1, 2, figsize=(8, 4), sharey=True, dpi=300)
    for ax, method in zip(axes, ["LLMOPT", "Ours"]):
        ax.set_title(LABEL_MAP[method])
        ax.set_xlabel("Number of Samples", fontsize=12)
        ax.grid(True, linestyle="--", alpha=0.4)
        for detector, values in aucs_by_method[method].items():
            ax.plot(Ns, values, marker="o", linestyle="-", linewidth=1.5, label=detector)
    axes[0].set_ylabel("AUROC({})".format(fpr_label), fontsize=12)

    handles, labels = axes[0].get_legend_handles_labels()
    fig.tight_layout(rect=[0, 0.15, 1, 0.95])
    fig.legend(
        handles, labels,
        loc="lower center",
        ncol=5,
        fontsize=10,
    )
    plt.savefig("MGTD_compare_{:.2f}.pdf".format(max_fpr))
    plt.savefig("MGTD_compare_{:.2f}.png".format(max_fpr))
    plt.close()

    return 0

if __name__ == "__main__":
    # python-fire exposes main()'s keyword arguments as command-line flags
    # (e.g. `--max_fpr 0.1`).
    fire.Fire(main)