
# OUTFOX-inspired baseline:
# For each preference dataset, we randomly pick 100 machine generations to
# instantiate StyleDetect. We then randomly pick 1,000 points as a calibration
# set and use them to choose a good decision threshold.
# Finally, we evaluate StyleDetect on all remaining points and save the predicted
# label of each text in the `content_text` (human) and `respond_reddit` (machine) fields.

# To evaluate on our usual test sets, we will extract SBERT embeddings for all of
# the text: the test samples, plus the human and machine text described above.
# Then, we will build an OUTFOX prompt from the top-k texts detected as human and
# the top-k texts detected as machine.
# We will save these prompts, and then use Nick's OpenAI key to query for the outputs.

import os
from typing import Union

import numpy as np
import fire
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import (
    AutoModel,
    AutoTokenizer
)
from sklearn.metrics import roc_curve, roc_auc_score
from tqdm import tqdm

@torch.inference_mode()
def get_luar_embeddings(
    text: Union[list[str], list[list[str]]],
    model: AutoModel,
    tokenizer: AutoTokenizer,
    batch_size: int = 256,
    single: bool = False,
):
    """Embed ``text`` with a LUAR model and return L2-normalised vectors.

    Input modes:

    * ``text`` is a list of lists: each inner list is embedded recursively
      with ``single=True`` and the per-list results are concatenated along
      dim 0.
    * ``single=True``: all strings are tokenized together and a size-1 axis
      is prepended to ``input_ids`` / ``attention_mask``, so the model sees
      them as one group.
    * otherwise: strings are fed in slices of ``batch_size`` (with a tqdm
      progress bar), each token tensor getting a size-1 axis at dim 1.
      Presumably this makes every string its own episode in LUAR's
      (batch, episode, tokens) layout — confirm against the model code.

    Every string is truncated/padded to exactly 512 tokens. The model output
    is normalised to unit L2 norm along the last dimension.
    """
    # Nested input: embed each inner list as its own group, then stack.
    if isinstance(text[0], list):
        per_group = [
            get_luar_embeddings(group, model, tokenizer, single=True)
            for group in text
        ]
        return torch.cat(per_group, dim=0)

    target_device = model.device
    encoded = tokenizer(
        text,
        max_length=512,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )

    if not single:
        pieces = []
        for start in tqdm(range(0, len(text), batch_size)):
            stop = start + batch_size
            piece_inputs = {
                key: tensor[start:stop].unsqueeze(1).to(target_device)
                for key, tensor in encoded.items()
            }
            pieces.append(model(**piece_inputs))
        raw = torch.cat(pieces, dim=0)
    else:
        # Only these two tensors get the leading group axis.
        encoded["input_ids"] = encoded["input_ids"].unsqueeze(0)
        encoded["attention_mask"] = encoded["attention_mask"].unsqueeze(0)
        encoded.to(target_device)
        raw = model(**encoded)

    return F.normalize(raw, dim=-1, p=2)

def _youden_threshold(labels: list[int], scores: list[float]) -> float:
    """Return the ROC threshold that maximises Youden's J = TPR - FPR.

    ``sklearn.metrics.roc_curve`` predicts the positive class when
    ``score >= threshold``, so apply the returned value with ``>=`` to
    reproduce the selected operating point.
    See https://en.wikipedia.org/wiki/Youden%27s_J_statistic
    """
    fpr, tpr, thresholds = roc_curve(labels, scores)
    # J must stay signed: maximising |TPR - FPR| can pick an inverted
    # operating point where the detector is worse than chance.
    # np.argmax returns the first maximum (same tie-breaking as a stable sort).
    return float(thresholds[np.argmax(tpr - fpr)])


def main(
    dataset_name: str = "/usr/WS1/rrivera/data/style_transfer/preference/MTD_reddit_12000_correct.jsonl.ready",
    output_dir: str = "/usr/workspace/rrivera/data/style_transfer/outfox",
    num_fewshot: int = 100,
    num_calibration: int = 1_000,
    seed: int = 43,
):
    """Run the LUAR-based StyleDetect baseline on one preference dataset.

    1. Shuffle the JSONL dataset with ``seed``. Split it into ``num_fewshot``
       rows (whose ``respond_reddit`` texts form the style prototype),
       ``num_calibration`` rows (used to pick a threshold), and the remaining
       held-out rows.
    2. Score each text by the cosine similarity between its LUAR embedding
       and the prototype embedding. Machine text is the positive class, so
       higher scores mean "more machine-like".
    3. Pick the threshold maximising Youden's J on the calibration rows.
    4. Label every held-out ``content_text`` / ``respond_reddit`` text. Write
       the held-out rows plus ``content_text_label`` / ``respond_reddit_label``
       boolean columns (True = detected as machine) as JSON lines to
       ``<output_dir>/<basename(dataset_name)>``.

    Args:
        dataset_name: Path to the JSON-lines preference dataset.
        output_dir: Directory for the labelled held-out file (created if missing).
        num_fewshot: Number of rows whose machine text builds the prototype.
        num_calibration: Number of rows used for threshold calibration.
        seed: ``random_state`` for the initial shuffle.

    Raises:
        ValueError: If no rows remain after the few-shot and calibration splits.
    """
    # Expected columns: author_id, content_text, content_action_type,
    # reference_text, reference_action_type, respond_reddit, model_name,
    # decoding_param, watermarking_alg, watermarking_param,
    # transfer_author_id, transfer_reference_text
    df = pd.read_json(dataset_name, lines=True)
    df = df.sample(frac=1.0, random_state=seed)

    calib_end = num_fewshot + num_calibration
    if len(df) <= calib_end:
        raise ValueError(
            f"{dataset_name} has {len(df)} rows; need more than {calib_end} "
            "(few-shot + calibration) to leave a held-out set."
        )
    fewshot_df = df.iloc[:num_fewshot]
    calib_df = df.iloc[num_fewshot:calib_end]
    heldout_df = df.iloc[calib_end:]

    model = AutoModel.from_pretrained("rrivera1849/LUAR-MUD", trust_remote_code=True)
    model.cuda()
    model.eval()
    tokenizer = AutoTokenizer.from_pretrained("rrivera1849/LUAR-MUD", trust_remote_code=True)
    cossim = nn.CosineSimilarity(dim=-1)

    # All few-shot machine texts are embedded together (single=True) to form
    # one style prototype.
    fs_embedding = get_luar_embeddings(
        fewshot_df["respond_reddit"].tolist(), model, tokenizer, single=True
    )

    def embed_and_score(texts: list[str]) -> torch.Tensor:
        """Cosine similarity of each text's LUAR embedding to the prototype."""
        embeddings = get_luar_embeddings(texts, model, tokenizer, single=False)
        # expand_as broadcasts the prototype without copying it N times.
        return cossim(embeddings, fs_embedding.expand_as(embeddings))

    # --- Calibration: human = 0, machine = 1 ---
    calib_human = embed_and_score(calib_df["content_text"].tolist()).tolist()
    calib_machine = embed_and_score(calib_df["respond_reddit"].tolist()).tolist()
    calib_labels = [0] * len(calib_human) + [1] * len(calib_machine)
    calib_scores = calib_human + calib_machine
    roc_t = _youden_threshold(calib_labels, calib_scores)
    print("AUC Calibration:", roc_auc_score(calib_labels, calib_scores))

    # --- Held-out evaluation ---
    sim_human = embed_and_score(heldout_df["content_text"].tolist())
    sim_machine = embed_and_score(heldout_df["respond_reddit"].tolist())
    labels = [0] * len(sim_human) + [1] * len(sim_machine)

    # ROC-AUC must be computed from the continuous scores, not hard labels.
    print(
        "AUC on held-out (in-domain):",
        roc_auc_score(labels, sim_human.tolist() + sim_machine.tolist()),
    )

    # >= matches roc_curve's definition of the selected operating point.
    pred_human = (sim_human >= roc_t).tolist()
    pred_machine = (sim_machine >= roc_t).tolist()
    # With hard 0/1 predictions, "AUC" equals (TPR + TNR) / 2, i.e. balanced accuracy.
    print(
        "Balanced accuracy on held-out at calibrated threshold:",
        roc_auc_score(labels, pred_human + pred_machine),
    )

    # --- Save held-out rows with their predicted labels ---
    os.makedirs(output_dir, exist_ok=True)
    out_df = heldout_df.copy()  # copy: avoid writing into a slice of df
    out_df["content_text_label"] = pred_human
    out_df["respond_reddit_label"] = pred_machine
    out_path = os.path.join(output_dir, os.path.basename(dataset_name))
    out_df.to_json(out_path, lines=True, orient="records")

if __name__ == "__main__":
    # Expose main()'s keyword parameters as command-line flags via python-fire.
    fire.Fire(component=main)