
# OUTFOX inspired baseline:
# For each preference dataset, we will pick at random 100 generations to 
# instantiate StyleDetect. Then, we will pick 1000 points for validation at random to 
# pick a good threshold. 
# Then, for all the remaining points we will evaluate StyleDetect and save the labels 
# of each text in the `content_text` (human) and `respond_reddit` (machine) fields.

# To evaluate on our usual test sets, we will extract all SBERT embeddings for all the text:
# test samples, and the aforementioned human and machine text.
# Then, we will create an OUTFOX prompt using the top-k texts "detected as human" and the
# top-k texts "detected as machine".
# We will save these prompts, and then use Nick's OpenAI key to prompt for the outputs.

import os
from typing import Union

import fire
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import roc_curve, roc_auc_score
from tqdm import tqdm
from transformers import (
    AutoModel,
    AutoTokenizer
)

@torch.inference_mode()
def get_luar_embeddings(
    text: Union[list[str], list[list[str]]],
    model: AutoModel,
    tokenizer: AutoTokenizer,
    batch_size: int = 32,
    single: bool = False,
):
    """Embed text with ``model`` and return L2-normalized embeddings.

    Three input layouts are handled:

    * ``list[list[str]]``: every inner list is embedded jointly (as with
      ``single=True``) and the per-list results are concatenated on dim 0.
    * ``list[str]`` with ``single=True``: all texts are stacked into one
      input of shape ``(1, len(text), 512)`` and the model is called once.
    * ``list[str]`` with ``single=False``: each text is embedded on its own;
      the model receives inputs of shape ``(batch, 1, 512)``.

    Args:
        text: Texts (or groups of texts) to embed. Must be non-empty.
        model: Model called as ``model(**inputs)``; it must expose ``.device``.
            NOTE(review): assumed to return a plain tensor with one row per
            group (as the LUAR checkpoint loaded in ``main`` is expected to)
            -- confirm if a different model is swapped in.
        tokenizer: Tokenizer matching ``model``.
        batch_size: Number of texts per forward pass when ``single=False``.
        single: Whether the whole list is embedded as one group.

    Returns:
        Tensor of embeddings, L2-normalized along the last dimension.

    Raises:
        ValueError: If ``text`` is empty.
    """
    # Fail early with a clear message; otherwise `text[0]` raises a bare
    # IndexError (or `torch.cat([])` a RuntimeError further down).
    if len(text) == 0:
        raise ValueError("get_luar_embeddings() requires at least one text")

    if isinstance(text[0], list):
        per_group = [
            get_luar_embeddings(group, model, tokenizer, single=True)
            for group in text
        ]
        return torch.cat(per_group, dim=0)

    device = model.device

    # Every text is padded/truncated to exactly 512 tokens so that the
    # tokenized tensors can be sliced and stacked freely below.
    inputs = tokenizer(
        text,
        max_length=512,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )

    if single:
        # Add a leading batch axis: (num_texts, 512) -> (1, num_texts, 512).
        group_inputs = {
            name: tensor.unsqueeze(0).to(device)
            for name, tensor in inputs.items()
        }
        outputs = model(**group_inputs)
    else:
        chunks = []
        for start in tqdm(range(0, len(text), batch_size)):
            # Add a length-1 group axis: (batch, 512) -> (batch, 1, 512).
            batch_inputs = {
                name: tensor[start:start + batch_size].unsqueeze(1).to(device)
                for name, tensor in inputs.items()
            }
            chunks.append(model(**batch_inputs))
        outputs = torch.cat(chunks, dim=0)

    return F.normalize(outputs, dim=-1, p=2)

def main(
    dataset_name: str,
):
    """Run the few-shot style-similarity detector on one dataset and save labels.

    Pipeline:

    1. Load the JSON-lines file and shuffle it with a fixed seed.
    2. Embed the first ``K`` machine generations jointly into one prototype
       embedding.
    3. On the next ``C`` rows, score human and machine texts by cosine
       similarity to the prototype and choose the threshold that maximizes
       Youden's J statistic.
    4. On all remaining rows, print the ROC-AUC of the similarity scores,
       threshold them, and write the rows plus two boolean label columns
       (``content_text_label``, ``respond_reddit_label``; ``True`` means
       "detected as machine") to ``./outfox/<basename of dataset_name>``.

    Args:
        dataset_name: Path to a JSON-lines file with at least the
            ``content_text`` (human) and ``respond_reddit`` (machine) fields.

    Raises:
        ValueError: If the dataset has no rows left after the few-shot and
            calibration splits.
    """
    # Expected columns:
    # dict_keys(['author_id', 'content_text', 'content_action_type', 'reference_text', 'reference_action_type', 'respond_reddit', 'model_name', 'decoding_param', 'watermarking_alg', 'watermarking_param', 'transfer_author_id', 'transfer_reference_text'])

    K = 100    # number of machine generations used to build the prototype
    C = 1_000  # number of rows used to calibrate the decision threshold

    df = pd.read_json(dataset_name, lines=True)
    if len(df) <= K + C:
        raise ValueError(
            "dataset has {} rows; more than {} are required "
            "(few-shot + calibration + evaluation)".format(len(df), K + C)
        )
    # `DataFrame.sample` takes `random_state`, not `seed`.
    df = df.sample(frac=1., random_state=43)
    fewshot_df = df.iloc[:K]
    calib_df = df.iloc[K:K + C]
    df = df.iloc[K + C:]

    model = AutoModel.from_pretrained("rrivera1849/LUAR-MUD", trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained("rrivera1849/LUAR-MUD", trust_remote_code=True)
    # Use the GPU when one is present; `get_luar_embeddings` follows `model.device`.
    model.to("cuda" if torch.cuda.is_available() else "cpu")
    model.eval()
    cossim = nn.CosineSimilarity(dim=-1)

    # Prototype: one embedding summarizing all few-shot machine generations.
    fewshot_samples = fewshot_df["respond_reddit"].tolist()
    embedding = get_luar_embeddings(fewshot_samples, model, tokenizer, single=True)

    def similarity_to_prototype(embeddings):
        # One score per row; higher means closer to the machine-text prototype.
        return cossim(embeddings, embedding.expand_as(embeddings))

    # Get threshold using calibration data
    embeddings_human = get_luar_embeddings(
        calib_df["content_text"].tolist(), model, tokenizer, single=False
    )
    embeddings_machine = get_luar_embeddings(
        calib_df["respond_reddit"].tolist(), model, tokenizer, single=False
    )
    calib_embeddings = torch.cat([embeddings_human, embeddings_machine], dim=0)
    # scikit-learn needs CPU data, so move the scores off the model device.
    calib_scores = similarity_to_prototype(calib_embeddings).cpu().numpy()
    calib_labels = [0] * embeddings_human.size(0) + [1] * embeddings_machine.size(0)
    fpr, tpr, thresholds = roc_curve(calib_labels, calib_scores)
    # Youden J's Statistic: https://en.wikipedia.org/wiki/Youden%27s_J_statistic
    # https://www.kaggle.com/code/nicholasgah/obtain-optimal-probability-threshold-using-roc#
    roc_t = thresholds[np.argmax(np.abs(tpr - fpr))]

    # Evaluate on the held-out set, saving the labels and printing the ROC-AUC
    embeddings_human = get_luar_embeddings(
        df["content_text"].tolist(), model, tokenizer, single=False
    )
    embeddings_machine = get_luar_embeddings(
        df["respond_reddit"].tolist(), model, tokenizer, single=False
    )
    labels = [0] * embeddings_human.size(0) + [1] * embeddings_machine.size(0)

    scores_human = similarity_to_prototype(embeddings_human).cpu()
    scores_machine = similarity_to_prototype(embeddings_machine).cpu()

    # ROC-AUC is a ranking metric: compute it on the raw similarity scores,
    # not on the thresholded decisions.
    raw_scores = torch.cat([scores_human, scores_machine], dim=0).numpy()
    print("AUC:", roc_auc_score(labels, raw_scores))

    labels_human = (scores_human > roc_t).tolist()
    labels_machine = (scores_machine > roc_t).tolist()

    # Attach the per-row decisions positionally; both lists are in `df` row order.
    new_df = df.copy()
    new_df["content_text_label"] = labels_human
    new_df["respond_reddit_label"] = labels_machine

    os.makedirs("./outfox", exist_ok=True)
    new_df.to_json(
        "./outfox/{}".format(os.path.basename(dataset_name)),
        lines=True,
        orient="records",
    )

# Script entry point: hand `main` to python-fire, which builds the command-line
# interface from its signature (`dataset_name` is supplied on the command line).
if __name__ == "__main__":
    fire.Fire(main)