
import random
import sys

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import umap
import umap.plot
import torch
import torch.nn.functional as F
from tqdm import tqdm
from sentence_transformers import SentenceTransformer
from transformers import AutoModel, AutoTokenizer

from genpaths import *
from nicks_dpo.detect import read_lines

# Seed Python's global RNG so the random.shuffle-based sampling in this
# script picks the same texts on every run.
_SEED = 43
random.seed(_SEED)

def load(path_list, idx):
    """Read one JSON-lines file and return a single column as a Python list.

    ``path_list[idx]`` is the file that is read, and ``path_list[-1]`` is the
    name of the column to extract. If the column holds lists (judged from its
    first row), only the first element of every row is kept.

    A column with no rows makes the first-row check raise IndexError.
    """
    column = path_list[-1]
    frame = pd.read_json(path_list[idx], lines=True)
    values = frame[column].tolist()
    if not isinstance(values[0], list):
        return values
    return [entry[0] for entry in values]

@torch.no_grad()
def embed(
    text: list[str],
    model: AutoModel,
    tokenizer: AutoTokenizer,
    max_length: int = 256,
    batch_size: int = 128,
):
    """Embed ``text`` in batches and return L2-normalized vectors as one array.

    Each batch is tokenized with truncation and padding to exactly
    ``max_length`` tokens and moved to ``model.device``. ``input_ids`` and
    ``attention_mask`` each get an extra axis at dim 1 before the forward
    pass. Presumably this is the per-author "episode" axis the LUAR model
    expects (NOTE(review): confirm against the model card).

    Returns the per-batch outputs, normalized along the last axis and
    concatenated along axis 0. An empty ``text`` makes np.concatenate
    raise ValueError.
    """
    chunks = []
    starts = range(0, len(text), batch_size)
    for start in tqdm(starts):
        encoded = tokenizer(
            text[start:start+batch_size],
            max_length=max_length,
            truncation=True,
            padding="max_length",
            return_tensors="pt",
        )
        # BatchEncoding.to moves its tensors in place; the return value is unused.
        encoded.to(model.device)
        for key in ("input_ids", "attention_mask"):
            encoded[key] = encoded[key].unsqueeze(1)
        vectors = F.normalize(model(**encoded), dim=-1, p=2)
        chunks.append(vectors.detach().cpu().numpy())
    return np.concatenate(chunks, axis=0)

def _load_shuffled_head(path_list, n):
    """Load texts via ``load(path_list, 0)``, shuffle them with the global
    ``random`` RNG, and return at most the first ``n``."""
    texts = load(path_list, 0)
    random.shuffle(texts)
    return texts[:n]


def _plot_umap_grid(embeddings_list, labels, classes, n_plot, seed, out_path):
    """Draw one UMAP scatter panel per (embeddings, title) pair and save it.

    embeddings_list: list of (array, title) pairs; every array has one row
        per entry of ``labels``.
    labels: numpy array of class names aligned with the embedding rows.
    classes: class names in plotting/legend order (at most four, one per
        marker below).
    n_plot: number of randomly chosen points drawn per panel, sampled
        across all classes together.
    seed: seed for both the UMAP fit and the point subsampling.
    out_path: file the figure is written to.
    """
    markers = ["o", "s", "^", "D"]
    colors = plt.rcParams["axes.prop_cycle"].by_key()["color"][:len(classes)]
    marker_map = dict(zip(classes, markers))
    color_map = dict(zip(classes, colors))

    n_panels = len(embeddings_list)
    # squeeze=False always returns a 2-D ndarray of Axes, so the indexing
    # below works for one panel or many.
    fig, axes = plt.subplots(
        1, n_panels, figsize=(6 * n_panels, 6), squeeze=False
    )
    # Seeded generator so the plotted subsample is reproducible across runs.
    rng = np.random.default_rng(seed)

    for ax, (emb, _title) in zip(axes[0], embeddings_list):
        proj = umap.UMAP(metric="cosine", random_state=seed).fit_transform(emb)
        keep = rng.permutation(len(proj))[:n_plot]
        proj, panel_labels = proj[keep], labels[keep]

        # Scatter each class separately so it gets its own marker and color.
        for cls in classes:
            mask = panel_labels == cls
            ax.scatter(
                proj[mask, 0], proj[mask, 1],
                marker=marker_map[cls],
                color=color_map[cls],
                s=60,  # larger points
                alpha=0.8,
                label=cls,
            )

        for spine in ax.spines.values():
            spine.set_visible(False)
        ax.set_xticks([])
        ax.set_yticks([])
        ax.grid(False)

    # One shared legend for all panels.
    handles, legend_labels = axes[0, 0].get_legend_handles_labels()
    fig.legend(
        handles, legend_labels,
        loc="upper center",
        ncol=max(1, len(classes) // 2),
        frameon=False,
        fontsize=12,
    )
    fig.savefig(out_path, dpi=300)
    plt.close(fig)


def main():
    """Embed human and machine texts with LUAR and save a UMAP comparison plot.

    Loads up to ``N`` shuffled texts from each of HUMAN, MACHINE, LLMOPT and
    OURS, embeds them with the LUAR-MUD model on CUDA, and writes
    ``umap_comparison_<number of embedders>.png``.

    Returns:
        0, used as the process exit status.
    """
    N = 200       # texts kept per source
    N_plot = 200  # points drawn per panel, sampled across all sources
    seed = 43     # same value as the module-level random.seed

    # Display names, in the same order as the sources loaded below.
    classes = ["Human", "Mistral-7B", "Mistral-7B-DPO-FastDetectGPT", "Ours"]
    groups = [
        _load_shuffled_head(src, N) for src in (HUMAN, MACHINE, LLMOPT, OURS)
    ]

    all_text = [text for group in groups for text in group]
    # Build labels from the actual group sizes so they stay aligned with the
    # embeddings even when a source has fewer than N texts.
    labels_plot = np.repeat(classes, [len(group) for group in groups])

    #### Embed
    model = AutoModel.from_pretrained("rrivera1849/LUAR-MUD", trust_remote_code=True)
    model.eval()
    model.cuda()
    tokenizer = AutoTokenizer.from_pretrained("rrivera1849/LUAR-MUD", trust_remote_code=True)
    embeddings_luar = embed(all_text, model, tokenizer)

    # More style embedders can be compared side by side by appending
    # (embeddings, title) pairs here. Examples:
    #   SentenceTransformer("AnnaWegmann/Style-Embedding")
    #   SentenceTransformer("StyleDistance/styledistance")
    # encoded with model.encode(all_text, normalize_embeddings=True).
    embeddings_list = [
        (embeddings_luar, "LUAR"),
    ]

    _plot_umap_grid(
        embeddings_list,
        labels_plot,
        classes,
        n_plot=N_plot,
        seed=seed,
        out_path="umap_comparison_{}.png".format(len(embeddings_list)),
    )
    return 0

if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    status = main()
    sys.exit(status)