
import json
import math
import os

import torch
import fire
import numpy as np
import torch.nn.functional as F
from sklearn.metrics import roc_auc_score
from transformers import AutoModel, AutoTokenizer
from binoculars import Binoculars
from tqdm import tqdm
from sentence_transformers import SentenceTransformer

from nicks_dpo.create_preference_data import (
    get_fast_detect_gpt_scores, 
    get_luar_embeddings,
)

def read_lines(
    fname: str,
    N: int = None,
):
    """Read a JSON-Lines file into a list of parsed records.

    Args:
        fname: Path to a ``.jsonl`` file with one JSON value per line.
        N: If given, stop after this many records have been parsed. A value
            of ``0`` (or less) returns an empty list.

    Returns:
        The parsed records, in file order. Blank / whitespace-only lines are
        skipped rather than passed to ``json.loads`` (which would raise).
    """
    data = []
    if N is not None and N <= 0:
        # Checking only after an append would otherwise return one record.
        return data
    with open(fname, "r", encoding="utf-8") as fin:
        for line in fin:
            if not line.strip():
                continue
            data.append(json.loads(line))
            if N is not None and len(data) >= N:
                break
    return data

@torch.no_grad()
def bino_forward(
    bino: Binoculars,
    text: list[str],
    batch_size: int = 32,
):
    """Score ``text`` with Binoculars in mini-batches and return negated scores.

    Args:
        bino: A constructed ``Binoculars`` scorer; its ``compute_score`` is
            called on each slice of ``text`` and is assumed to return one
            score per input string (TODO confirm for the installed version).
        text: The strings to score.
        batch_size: Number of strings passed to ``compute_score`` per call.

    Returns:
        One float per input string, with the sign flipped. Presumably this is
        so that larger values mean "more machine-like", matching how ``main``
        labels machine text as 1 for ROC-AUC.
    """
    raw_scores = []
    batch_starts = range(0, len(text), batch_size)
    for start in tqdm(batch_starts):
        chunk = text[start:start + batch_size]
        raw_scores.extend(bino.compute_score(chunk))
    negated = [-value for value in raw_scores]
    return negated

def remove_invalid(scores: list[float]):
    """Return ``scores`` without NaN or infinite entries, preserving order.

    ``sklearn.metrics.roc_auc_score`` rejects non-finite inputs, so both NaN
    and ``±inf`` must be filtered out before computing an AUC.
    """
    return [s for s in scores if math.isfinite(s)]

def _style_embed(encoder, texts):
    """Encode ``texts`` with a SentenceTransformer as L2-normalised embeddings (one row per text)."""
    return encoder.encode(
        texts,
        show_progress_bar=False,
        normalize_embeddings=True,
        convert_to_tensor=True,
    )


def _background_similarity(background_emb, emb):
    """Cosine similarity of every row of ``emb`` to the centroid of ``background_emb``.

    ``background_emb`` may be a single vector ``(D,)``, a single row ``(1, D)``
    or one row per few-shot example ``(K, D)``. Rows are mean-pooled into one
    centroid, so the result always has exactly one score per row of ``emb``.
    For normalised embeddings, this ranks texts identically to averaging the
    per-example similarities. Non-finite scores are dropped.
    """
    if background_emb.dim() == 1:
        background_emb = background_emb.unsqueeze(0)
    centroid = background_emb.mean(dim=0, keepdim=True)
    sims = F.cosine_similarity(centroid.expand_as(emb), emb)
    return remove_invalid(sims.cpu().tolist())


def _sentence_style_scores(model_name, background, *text_sets):
    """Load a SentenceTransformer style model and score each text set against ``background``.

    Returns a tuple with one list of similarity scores per entry of
    ``text_sets``. Each text set is embedded with the same encoder as the
    background, so the two sides are always in the same embedding space.
    """
    encoder = SentenceTransformer(model_name).cuda().eval()
    background_emb = _style_embed(encoder, background)
    scores = tuple(
        _background_similarity(background_emb, _style_embed(encoder, texts))
        for texts in text_sets
    )
    del encoder, background_emb
    torch.cuda.empty_cache()
    return scores


def _report_auc(name, human_scores, base_scores, preference_scores):
    """Print and return ROC-AUC of human (label 0) vs. base and vs. preference text (label 1).

    Returns:
        ``(auc_base, auc_preference)``.
    """
    print(name)
    aucs = []
    for tag, machine_scores in (("Base", base_scores), ("Preference", preference_scores)):
        labels = [0] * len(human_scores) + [1] * len(machine_scores)
        auc = roc_auc_score(labels, human_scores + machine_scores)
        print("\tAUC(1.0) {} - {:.2f}".format(tag, auc))
        aucs.append(auc)
    return tuple(aucs)


def main(
    base: str = "",  # required: path to the base-model generations JSONL
    generations: str = "./nicks_dpo/generations_neurips/MTD-reddit-12000-correct_checkpoint-7500_merged-FastDetectGPT-reddit_temperature=0.7_top-p=0.9_ng=2-preference.jsonl",
    debug: bool = False,
    K: int = 100,
    fewshot_from: str = "machine",
):
    """Compare detectors on human vs. base vs. preference-tuned generations.

    Args:
        base: JSONL whose records have ``"content_text"`` (human text) and
            ``"respond_reddit"`` (list whose first element is the base-model
            generation).
        generations: JSONL whose records have ``"respond_reddit"`` holding the
            preference-tuned model's generation as the first element.
        debug: If true, read only the first ``100 + K`` records of each file.
        K: Number of few-shot "background" texts for the LUAR / CISR / SD
            similarity detectors. They are removed from their source pool
            before those detectors score it.
        fewshot_from: Pool the background is drawn from: ``"human"``,
            ``"machine"`` (base generations) or ``"optimized"`` (preference
            generations).

    Writes a JSON dict of AUCs to ``./nicks_dpo/detect_results/<generations
    basename without .jsonl>`` and returns 0.

    Raises:
        ValueError: If ``fewshot_from`` is unknown or ``base`` is empty.
    """
    if fewshot_from not in ("human", "machine", "optimized"):
        raise ValueError(
            "fewshot_from must be one of 'human', 'machine', 'optimized'; got {!r}".format(fewshot_from)
        )
    if not base:
        raise ValueError("`base` must be the path to the base-model generations JSONL file")
    N = 100 + K if debug else None

    data_base = read_lines(base, N=N)
    human = [d["content_text"] for d in data_base]
    machine_base = [d["respond_reddit"][0] for d in data_base]

    data_preference = read_lines(generations, N=N)
    machine_preference = [d["respond_reddit"][0] for d in data_preference]

    ## FastDetectGPT (scored on the full pools, before the few-shot split)
    scores_human_fd = remove_invalid(get_fast_detect_gpt_scores(human))
    scores_machine_base_fd = remove_invalid(get_fast_detect_gpt_scores(machine_base))
    scores_machine_preference_fd = remove_invalid(get_fast_detect_gpt_scores(machine_preference))

    ## Binoculars (scored on the full pools, before the few-shot split)
    bino = Binoculars(
        observer_name_or_path="tiiuae/falcon-7b",
        performer_name_or_path="tiiuae/falcon-7b-instruct",
    )
    scores_human_bino = remove_invalid(bino_forward(bino, human))
    scores_machine_base_bino = remove_invalid(bino_forward(bino, machine_base))
    scores_machine_preference_bino = remove_invalid(bino_forward(bino, machine_preference))
    # Release the two Falcon-7B models before loading the style encoders.
    del bino
    torch.cuda.empty_cache()

    # Few-shot background: the first K texts of the selected pool. They are
    # held out of that pool so they are never scored against themselves.
    pools = {"human": human, "machine": machine_base, "optimized": machine_preference}
    data_background = pools[fewshot_from][:K]
    pools[fewshot_from] = pools[fewshot_from][K:]
    human = pools["human"]
    machine_base = pools["machine"]
    machine_preference = pools["optimized"]

    ## LUAR
    model = AutoModel.from_pretrained("rrivera1849/LUAR-MUD", trust_remote_code=True).eval().cuda()
    tokenizer = AutoTokenizer.from_pretrained("rrivera1849/LUAR-MUD", trust_remote_code=True)
    background_emb = get_luar_embeddings(data_background, model, tokenizer, single=True)
    scores_human_luar, scores_machine_base_luar, scores_machine_preference_luar = (
        _background_similarity(background_emb, get_luar_embeddings(texts, model, tokenizer))
        for texts in (human, machine_base, machine_preference)
    )
    del model, background_emb
    torch.cuda.empty_cache()

    ## CISR
    scores_human_CISR, scores_machine_base_CISR, scores_machine_preference_CISR = _sentence_style_scores(
        "AnnaWegmann/Style-Embedding", data_background, human, machine_base, machine_preference
    )

    ## SD
    scores_human_SD, scores_machine_base_SD, scores_machine_preference_SD = _sentence_style_scores(
        "StyleDistance/styledistance", data_background, human, machine_base, machine_preference
    )

    AUC_base_fd, AUC_preference_fd = _report_auc(
        "FastDetectGPT", scores_human_fd, scores_machine_base_fd, scores_machine_preference_fd
    )
    AUC_base_bino, AUC_preference_bino = _report_auc(
        "Binoculars", scores_human_bino, scores_machine_base_bino, scores_machine_preference_bino
    )
    AUC_base_luar, AUC_preference_luar = _report_auc(
        "LUAR", scores_human_luar, scores_machine_base_luar, scores_machine_preference_luar
    )
    AUC_base_CISR, AUC_preference_CISR = _report_auc(
        "CISR", scores_human_CISR, scores_machine_base_CISR, scores_machine_preference_CISR
    )
    AUC_base_SD, AUC_preference_SD = _report_auc(
        "SD", scores_human_SD, scores_machine_base_SD, scores_machine_preference_SD
    )

    results = {
        "FD AUC Base": AUC_base_fd,
        "FD AUC Preference": AUC_preference_fd,
        "BINO AUC Base": AUC_base_bino,
        "BINO AUC Preference": AUC_preference_bino,
        "LUAR AUC Base": AUC_base_luar,
        "LUAR AUC Preference": AUC_preference_luar,
        "CISR AUC Base": AUC_base_CISR,
        "CISR AUC Preference": AUC_preference_CISR,
        "SD AUC Base": AUC_base_SD,
        "SD AUC Preference": AUC_preference_SD,
    }
    out_dir = "./nicks_dpo/detect_results"
    os.makedirs(out_dir, exist_ok=True)
    out_path = os.path.join(out_dir, os.path.basename(generations).replace(".jsonl", ""))
    with open(out_path, "w") as fout:
        fout.write(json.dumps(results, indent=4))
        fout.write("\n")

    return 0

if __name__ == "__main__":
    # Expose main()'s keyword arguments as command-line flags via python-fire.
    fire.Fire(component=main)