"""Plot generations of every method using LUAR, CISR, and SD.
"""

import os
import random
import sys
from functools import partial

from matplotlib.patches import Circle
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import umap
from sentence_transformers import SentenceTransformer
from transformers import AutoModel, AutoTokenizer, set_seed

from genpaths import *
from nicks_dpo.create_preference_data import get_luar_embeddings

# Colors assigned to methods by position (method i gets palette[i % len(palette)]).
# The hex values are the Okabe-Ito colorblind-friendly set plus a neutral gray.
palette = [
    hex_code
    for hex_code, _color_name in (
        ("#0072B2", "blue"),
        ("#E69F00", "orange"),
        ("#009E73", "green"),
        ("#D55E00", "vermillion"),
        ("#CC79A7", "purple/pink"),
        ("#000000", "black"),
        ("#56B4E9", "sky blue"),
        ("#F0E442", "yellow"),
        ("#999999", "gray"),
    )
]
# Matplotlib marker codes, one character each, assigned to methods by position
# in the same way as the palette.
markers = list("os^DPXv><")

def load_text(arr):
    """Load one column of texts from a JSON-lines file.

    Args:
        arr: Indexable descriptor of a dataset. ``arr[2]`` is the path of the
            JSON-lines file and ``arr[3]`` is the name of the column to read
            (per the original author's note: [2] -> Reddit, [3] -> Generation).

    Returns:
        A list with one entry per row of the file. A row whose value is a
        list contributes only its first element; any other value is returned
        unchanged.

    Raises:
        KeyError: If the column ``arr[3]`` is not present in the file.
        IndexError: If a row holds an empty list.
    """
    column = pd.read_json(arr[2], lines=True)[arr[3]].tolist()
    # Decide per entry rather than from the first row only: with a mixed
    # column, indexing a plain string with [0] would silently keep just its
    # first character, and list rows after a string row would stay nested.
    return [entry[0] if isinstance(entry, list) else entry for entry in column]

def _embed_methods(model_name, method_data):
    """Embed every method's texts with the style model named *model_name*.

    Args:
        model_name: ``"LUAR"`` loads ``rrivera1849/LUAR-MUD`` through
            transformers; ``"CISR"`` loads ``AnnaWegmann/Style-Embedding`` and
            any other value loads ``StyleDistance/styledistance``, both
            through sentence-transformers.
        method_data: Mapping of method name to a list of texts.

    Returns:
        Dict mapping each method name to its embeddings, moved to the CPU.
        Callers assume one row per input text (TODO confirm for LUAR, whose
        embedding helper is defined outside this file).
    """
    method_embeddings = {}

    if model_name == "LUAR":
        model = AutoModel.from_pretrained("rrivera1849/LUAR-MUD", trust_remote_code=True)
        model.eval()
        model.cuda()
        tokenizer = AutoTokenizer.from_pretrained("rrivera1849/LUAR-MUD", trust_remote_code=True)
        for method_name, text in method_data.items():
            print(f"Embedding: {method_name}")
            method_embeddings[method_name] = get_luar_embeddings(
                text, model, tokenizer, batch_size=1024, single=False
            ).cpu()
        return method_embeddings

    hf_id = "AnnaWegmann/Style-Embedding" if model_name == "CISR" else "StyleDistance/styledistance"
    model = SentenceTransformer(hf_id)
    model.eval()
    model.cuda()
    for method_name, text in method_data.items():
        print(f"Embedding: {method_name}")
        # The keyword is `show_progress_bar`; `progress_bar` is not a
        # parameter of SentenceTransformer.encode.
        method_embeddings[method_name] = model.encode(
            text,
            show_progress_bar=True,
            convert_to_tensor=True,
            normalize_embeddings=True,
        ).cpu()
    return method_embeddings


def main():
    """Plot UMAP projections of each method's texts under three style models.

    For every method, up to 100 texts are sampled, embedded with LUAR, CISR
    and StyleDistance, and projected to 2-D with UMAP (cosine metric). The
    three projections are drawn side by side with one shared legend and
    written to ``umap_all.png`` and ``umap_all.pdf`` in the working directory.

    The models are moved to the GPU with ``.cuda()``, so a CUDA device is
    required.

    Returns:
        0, used as the process exit status.
    """
    from matplotlib.lines import Line2D

    # transformers' set_seed is called before random.shuffle below so that the
    # sampled subset is reproducible from run to run.
    set_seed(43)
    samples_per_method = 100

    methods = ["Human", "Machine", "LLMOPT", "OUTFOX", "DIPPER", "Ours"]
    models = ["LUAR", "CISR", "SD"]

    label_map = {
        "Human": "Human",
        "Machine": "Machine",
        "LLMOPT": "Mistral-7B-DPO-FastDetectGPT",
        "OUTFOX": "OUTFOX (Prompting)",
        "DIPPER": "DIPPER (Paraphrasing)",
        "Ours": r"$\bf{Ours}$ (Style-aware Paraphrasing)",
    }

    # Each method's dataset descriptor is looked up as an upper-case module
    # global (e.g. "Ours" -> OURS); those names are expected to come from the
    # star import of genpaths.
    method_data = {}
    for method_name in methods:
        texts = load_text(globals()[method_name.upper()])
        random.shuffle(texts)
        method_data[method_name] = texts[:samples_per_method]

    # One figure for all models, with a single legend at the bottom.
    fig, axes = plt.subplots(1, 3, figsize=(12.5, 5.0), dpi=200)
    fig.subplots_adjust(bottom=0.20, wspace=0.08)

    # Legend handles are proxy artists so they do not depend on any one axes.
    proxies = [
        Line2D(
            [0],
            [0],
            marker=markers[i % len(markers)],
            linestyle="",
            markerfacecolor=palette[i % len(palette)],
            markeredgecolor="white",
            markeredgewidth=0.8,
            markersize=7,
        )
        for i in range(len(methods))
    ]
    labels = [label_map[method_name] for method_name in methods]

    alpha = 0.95
    for ax, model_name in zip(axes, models):
        method_embeddings = _embed_methods(model_name, method_data)

        # Stack all methods into one matrix and keep a parallel label array
        # so that each method's rows can be selected after the projection.
        all_embeddings = torch.cat(list(method_embeddings.values()), dim=0)
        all_labels = np.array(
            [
                label_map[method_name]
                for method_name, embeddings in method_embeddings.items()
                for _ in range(embeddings.size(0))
            ]
        )

        mapper = umap.UMAP(metric="cosine")
        lowdim_embs = mapper.fit_transform(all_embeddings)

        # One scatter call per method so each gets its own marker and color.
        for i, method_name in enumerate(methods):
            idx = np.where(all_labels == label_map[method_name])[0]
            ax.scatter(
                lowdim_embs[idx, 0],
                lowdim_embs[idx, 1],
                s=35,
                marker=markers[i % len(markers)],
                c=palette[i % len(palette)],
                edgecolors="white",
                linewidths=0.6,
                alpha=alpha,
            )

        # UMAP axes carry no interpretable units, so the ticks are hidden.
        ax.set_xticks([])
        ax.set_yticks([])
        title = "StyleDistance" if model_name == "SD" else model_name
        ax.set_title(f"{title}", fontsize=16, pad=6)

    # NOTE(review): matplotlib documents the "outside ..." legend locations as
    # taking effect only with constrained layout; this figure uses
    # tight_layout, so placement here relies on bbox_to_anchor. Confirm the
    # rendered position before changing either setting.
    fig.legend(
        handles=proxies,
        labels=labels,
        loc="outside lower center",
        ncol=min(len(methods), 3),
        frameon=False,
        handletextpad=0.5,
        columnspacing=1.0,
        borderaxespad=0.2,
        fontsize=16,
        bbox_to_anchor=(0.5, -0.02),
    )

    # Reserve the bottom strip of the figure for the legend, then save once
    # per output format.
    fig.tight_layout(rect=[0, 0.12, 1, 1])
    fig.savefig("umap_all.png", bbox_inches="tight")
    fig.savefig("umap_all.pdf", bbox_inches="tight")
    plt.close(fig)

    return 0

if __name__ == "__main__":
    sys.exit(main())