import matplotlib.pyplot as plt
import numpy as np
import datasets
import transformers
import re
import torch
import torch.nn.functional as F
import tqdm
import random
from sklearn.metrics import roc_curve, precision_recall_curve, auc
import argparse
import datetime
import os
import json
import functools
import custom_datasets
from multiprocessing.pool import ThreadPool
import time
import openai
from scipy.special import rel_entr
# Directory prefixes prepended to model names before weights are loaded
# (checkpoint_weight_dir is used by load_base_model_and_tokenizer() and
# eval_supervised(); detector_weight_dir is not referenced in the visible code).
detector_weight_dir =""
checkpoint_weight_dir =""

# 15 plot colors from a colorblind-friendly palette. Only the first 8 are
# distinct; entries 9-15 repeat entries 1-7, so curves past the 8th reuse colors.
COLORS = ["#0072B2", "#009E73", "#D55E00", "#CC79A7", "#F0E442",
            "#56B4E9", "#E69F00", "#000000", "#0072B2", "#009E73",
            "#D55E00", "#CC79A7", "#F0E442", "#56B4E9", "#E69F00"]

# define regex to match all <extra_id_*> tokens, where * is an integer
# (the sentinel tokens written by tokenize_and_mask and split on by extract_fills)
pattern = re.compile(r"<extra_id_\d+>")

def min_max_normalize(real_preds,sample_preds):
    # to compare kl_div
    min_score = min(
        min(real_preds),min(sample_preds)
    )
    max_score = max(
        max(real_preds),max(sample_preds)
    )
    real_preds = [(1e-10+item - min_score)/(max_score-min_score) for item in real_preds]
    sample_preds = [(1e-10+item - min_score)/(max_score-min_score) for item in sample_preds]
    return real_preds,sample_preds


def _rel_entr(sample_preds,real_preds):
    real_id = [i for i in range(len(real_preds)) if not np.isnan(real_preds[i])]
    sample_id = [i for i in range(len(sample_preds)) if not np.isnan(sample_preds[i])]
    real_preds = [real_preds[i] for i in real_id]
    sample_preds = [sample_preds[i] for i in sample_id]
    kl_div=sum(rel_entr(sample_preds,real_preds))
    return kl_div


def load_base_model():
    """Move the base (scoring) model onto DEVICE after evicting the mask model.

    The global ``mask_model`` is moved to CPU if it exists; a missing global is
    ignored. ``base_model`` is only moved when no OpenAI model is configured.
    Elapsed wall time is printed.
    """
    print('MOVING BASE MODEL TO GPU...', end='', flush=True)
    t0 = time.time()
    try:
        mask_model.cpu()
    except NameError:
        # no mask model has been created in this run
        pass
    uses_local_base = args.openai_model is None
    if uses_local_base:
        base_model.to(DEVICE)
    elapsed = time.time() - t0
    print(f'DONE ({elapsed:.2f}s)')


def load_mask_model():
    """Move the mask-filling model onto DEVICE after evicting the base model.

    ``base_model`` is moved to CPU only when it is a local model (no OpenAI
    model configured). ``mask_model`` is moved only when model-based fills are
    in use (``args.random_fills`` is false). Elapsed wall time is printed.
    """
    print('MOVING MASK MODEL TO GPU...', end='', flush=True)
    t0 = time.time()

    uses_local_base = args.openai_model is None
    needs_mask_model = not args.random_fills
    if uses_local_base:
        base_model.cpu()
    if needs_mask_model:
        mask_model.to(DEVICE)

    elapsed = time.time() - t0
    print(f'DONE ({elapsed:.2f}s)')


def tokenize_and_mask(text, span_length, pct, ceil_pct=False):
    """Replace random word spans in ``text`` with numbered sentinel tokens.

    The text is split on single spaces. The number of spans is
    ``pct * n_words / (span_length + 2 * args.buffer_size)``, rounded up when
    ``ceil_pct`` is true and truncated otherwise. Each span of ``span_length``
    words collapses into one placeholder. A candidate position is rejected
    while any placeholder lies within ``args.buffer_size`` words of it.
    Placeholders are then renamed ``<extra_id_0>``, ``<extra_id_1>``, ...
    from left to right.

    Start positions come from ``np.random.randint``, so results depend on the
    global NumPy RNG state.

    NOTE(review): ``np.random.randint`` raises ValueError once the word count
    is not larger than ``span_length``. The loop never terminates if the
    requested spans cannot be placed.
    """
    placeholder = '<<<mask>>>'
    words = text.split(' ')
    buffer = args.buffer_size

    # every span "costs" its own words plus a buffer on each side
    target = pct * len(words) / (span_length + buffer * 2)
    target = int(np.ceil(target) if ceil_pct else target)

    placed = 0
    while placed < target:
        lo = np.random.randint(0, len(words) - span_length)
        hi = lo + span_length
        window = words[max(0, lo - buffer):min(len(words), hi + buffer)]
        if placeholder in window:
            continue  # too close to an existing mask; draw another position
        words[lo:hi] = [placeholder]
        placed += 1

    # rename placeholders to T5 sentinels, numbered left to right
    positions = [pos for pos, word in enumerate(words) if word == placeholder]
    for sentinel_id, pos in enumerate(positions):
        words[pos] = f'<extra_id_{sentinel_id}>'
    num_filled = len(positions)
    assert num_filled == placed, f"num_filled {num_filled} != n_masks {placed}"
    return ' '.join(words)


def count_masks(texts):
    """Return, per text, how many whitespace-separated words start with ``<extra_id_``."""
    counts = []
    for text in texts:
        counts.append(sum(1 for word in text.split() if word.startswith("<extra_id_")))
    return counts


# replace each masked span with a sample from T5 mask_model
def replace_masks(texts):
    """Sample fills for every sentinel in ``texts`` with the mask model.

    Sentinels in a text are numbered 0..n-1, so generation stops at
    ``<extra_id_{n_max}>``, where ``n_max`` is the largest mask count in the
    batch. Outputs are decoded with special tokens kept so that
    ``extract_fills`` can split on the sentinels.

    NOTE(review): ``max`` raises ValueError on an empty ``texts`` list.
    """
    highest_sentinel = max(count_masks(texts))
    stop_id = mask_tokenizer.encode(f"<extra_id_{highest_sentinel}>")[0]
    encoded = mask_tokenizer(texts, return_tensors="pt", padding=True).to(DEVICE)
    generated = mask_model.generate(
        **encoded,
        max_length=150,
        do_sample=True,
        top_p=args.mask_top_p,
        num_return_sequences=1,
        eos_token_id=stop_id,
    )
    return mask_tokenizer.batch_decode(generated, skip_special_tokens=False)





def extract_fills(texts):
    """Split raw mask-model outputs into the list of fill strings per text.

    ``<pad>`` and ``</s>`` markers are removed. Each output is split on the
    ``<extra_id_N>`` sentinels, and the text before the first sentinel and
    after the last one is discarded. Each fill is whitespace-stripped.
    """
    fills_per_text = []
    for raw in texts:
        cleaned = raw.replace("<pad>", "").replace("</s>", "").strip()
        # split() yields [prefix, fill_0, ..., fill_k, suffix]
        between_sentinels = pattern.split(cleaned)[1:-1]
        fills_per_text.append([piece.strip() for piece in between_sentinels])
    return fills_per_text


def apply_extracted_fills(masked_texts, extracted_fills):
    """Substitute each ``<extra_id_i>`` sentinel with the i-th fill.

    Texts are split only on single spaces, so newlines are preserved. A text
    that has fewer fills than sentinels becomes ``""``, which callers treat as
    "retry". Texts beyond ``len(extracted_fills)`` are returned unchanged.
    """
    split_texts = [masked.split(' ') for masked in masked_texts]
    expected_counts = [
        len([word for word in masked.split() if word.startswith("<extra_id_")])
        for masked in masked_texts
    ]

    for words, fills, expected in zip(split_texts, extracted_fills, expected_counts):
        if len(fills) < expected:
            # not enough fills: blank the text so the caller retries it
            words.clear()
            continue
        for sentinel_id, fill in zip(range(expected), fills):
            words[words.index(f"<extra_id_{sentinel_id}>")] = fill

    return [" ".join(words) for words in split_texts]


def _perturb_with_mask_model(texts, span_length, pct, ceil_pct):
    """Mask spans and refill them with the mask model, retrying texts whose fills were incomplete."""
    masked_texts = [tokenize_and_mask(x, span_length, pct, ceil_pct) for x in texts]
    raw_fills = replace_masks(masked_texts)
    extracted_fills = extract_fills(raw_fills)
    perturbed_texts = apply_extracted_fills(masked_texts, extracted_fills)

    # Handle the fact that sometimes the model doesn't generate the right number of fills and we have to try again
    attempts = 1
    while '' in perturbed_texts:
        idxs = [idx for idx, x in enumerate(perturbed_texts) if x == '']
        print(f'WARNING: {len(idxs)} texts have no fills. Trying again [attempt {attempts}].')
        # idxs is ascending, so this re-masks texts in their original order
        masked_texts = [tokenize_and_mask(texts[idx], span_length, pct, ceil_pct) for idx in idxs]
        raw_fills = replace_masks(masked_texts)
        extracted_fills = extract_fills(raw_fills)
        new_perturbed_texts = apply_extracted_fills(masked_texts, extracted_fills)
        for idx, x in zip(idxs, new_perturbed_texts):
            perturbed_texts[idx] = x
        attempts += 1
        if attempts > 100:
            n_left = perturbed_texts.count('')
            if n_left:
                # callers (e.g. get_perturbation_results) skip empty perturbations
                print(f'WARNING: giving up after {attempts - 1} retries; {n_left} texts remain empty.')
            break
    return perturbed_texts


def _perturb_with_random_tokens(texts, span_length, pct):
    """Replace a random subset of non-padding base-tokenizer tokens with random non-special tokens."""
    tokens = base_tokenizer(texts, return_tensors="pt", padding=True).to(DEVICE)
    valid_tokens = tokens.input_ids != base_tokenizer.pad_token_id
    # same fraction of tokens as span masking would cover with these settings
    # (uses the function arguments, not args.*, so pre-perturbation settings apply)
    replace_pct = pct * (span_length / (span_length + 2 * args.buffer_size))

    # replace replace_pct of input_ids with random tokens
    random_mask = torch.rand(tokens.input_ids.shape, device=DEVICE) < replace_pct
    random_mask &= valid_tokens
    n_replace = random_mask.sum()
    special_tokens = set(base_tokenizer.all_special_tokens)
    random_tokens = torch.randint(0, base_tokenizer.vocab_size, (n_replace,), device=DEVICE)
    # while any of the random tokens are special tokens, redraw all of them
    while any(base_tokenizer.decode(x) in special_tokens for x in random_tokens):
        random_tokens = torch.randint(0, base_tokenizer.vocab_size, (n_replace,), device=DEVICE)
    tokens.input_ids[random_mask] = random_tokens
    return base_tokenizer.batch_decode(tokens.input_ids, skip_special_tokens=True)


def _perturb_with_random_words(texts, span_length, pct, ceil_pct):
    """Mask spans and fill each with ``span_length`` random words from FILL_DICTIONARY."""
    masked_texts = [tokenize_and_mask(x, span_length, pct, ceil_pct) for x in texts]
    # NOTE(review): FILL_DICTIONARY is defined elsewhere, presumably as a set.
    # random.sample() rejects sets on Python 3.11+. Older versions converted
    # them to a sequence internally, and list() keeps that same order.
    fill_population = list(FILL_DICTIONARY)
    perturbed_texts = []
    for text in masked_texts:
        filled_text = text
        for fill_idx in range(count_masks([text])[0]):
            fill = random.sample(fill_population, span_length)
            filled_text = filled_text.replace(f"<extra_id_{fill_idx}>", " ".join(fill))
        assert count_masks([filled_text])[0] == 0, "Failed to replace all masks"
        perturbed_texts.append(filled_text)
    return perturbed_texts


def perturb_texts_(texts, span_length, pct, ceil_pct=False):
    """Perturb one chunk of texts using the strategy selected by ``args``.

    - ``args.random_fills`` false: mask spans and refill them with the mask model.
    - ``args.random_fills`` and ``args.random_fills_tokens``: swap random
      base-tokenizer tokens.
    - ``args.random_fills`` only: fill masked spans with random dictionary words.

    Returns:
        list[str]: one perturbed text per input. It may contain ``""`` when
        model fills still failed after 100 retries.
    """
    if not args.random_fills:
        return _perturb_with_mask_model(texts, span_length, pct, ceil_pct)
    if args.random_fills_tokens:
        return _perturb_with_random_tokens(texts, span_length, pct)
    return _perturb_with_random_words(texts, span_length, pct, ceil_pct)


def perturb_texts(texts, span_length, pct, ceil_pct=False):
    """Perturb ``texts`` in chunks of ``args.chunk_size`` via ``perturb_texts_``.

    The chunk size is halved for mask-filling models whose name contains
    ``'11b'`` (presumably to fit GPU memory), but it never drops below 1,
    because ``range()`` rejects a zero step.

    Returns:
        list[str]: perturbed texts, in input order.
    """
    chunk_size = args.chunk_size
    if '11b' in mask_filling_model_name:
        chunk_size = max(1, chunk_size // 2)

    outputs = []
    for i in tqdm.tqdm(range(0, len(texts), chunk_size), desc="Applying perturbations"):
        outputs.extend(perturb_texts_(texts[i:i + chunk_size], span_length, pct, ceil_pct=ceil_pct))
    return outputs


def drop_last_word(text):
    """Return ``text`` without its final space-separated word (``""`` if it has no space)."""
    head, _sep, _last = text.rpartition(' ')
    return head


def _openai_sample(p):
    """Complete prompt ``p`` with ``args.openai_model`` and return prompt + completion.

    Except for pubmed (which keeps its ``Answer:`` prefix intact), the last
    word of the prompt is dropped first. NOTE(review): presumably this is
    because that word may have been cut mid-token by the prompt truncation.
    """
    prompt = p if args.dataset == 'pubmed' else drop_last_word(p)

    request = {"engine": args.openai_model, "max_tokens": 200}
    if args.do_top_p:
        request['top_p'] = args.top_p

    response = openai.Completion.create(prompt=f"{prompt}", **request)
    return prompt + response['choices'][0].text


# sample from base_model (or the OpenAI model) using ONLY the first `prompt_tokens` (default 30) tokens of each example as context; for pubmed, the text before the separator is used instead
def sample_from_model(texts, min_words=55, prompt_tokens=30):
    """Generate one model continuation per input text.

    Prompting: for ``pubmed`` the prompt is the part of each text before
    ``custom_datasets.SEPARATOR``. Otherwise it is the first ``prompt_tokens``
    base-tokenizer tokens of each text.

    With ``args.openai_model`` set, the prompts are decoded back to strings and
    completed concurrently through ``_openai_sample``. The GPT-2 token count of
    the results is added to the global ``API_TOKEN_COUNTER``. Otherwise the
    local ``base_model`` samples until every output has at least ``min_words``
    whitespace-separated words. The whole batch is regenerated whenever any
    output is too short.

    Returns:
        list[str]: decoded prompt + continuation for each input.
    """
    # encode each text as a list of token ids
    if args.dataset == 'pubmed':
        texts = [t[:t.index(custom_datasets.SEPARATOR)] for t in texts]
        all_encoded = base_tokenizer(texts, return_tensors="pt", padding=True).to(DEVICE)
    else:
        all_encoded = base_tokenizer(texts, return_tensors="pt", padding=True).to(DEVICE)
        all_encoded = {key: value[:, :prompt_tokens] for key, value in all_encoded.items()}

    if args.openai_model:
        # decode the prefixes back into text
        prefixes = base_tokenizer.batch_decode(all_encoded['input_ids'], skip_special_tokens=True)
        # the context manager shuts the worker threads down once map() returns
        # (previously a new pool was leaked on every call)
        with ThreadPool(args.batch_size) as pool:
            decoded = pool.map(_openai_sample, prefixes)
    else:
        # sampling settings do not change between retries, so build them once
        sampling_kwargs = {}
        if args.do_top_p:
            sampling_kwargs['top_p'] = args.top_p
        elif args.do_top_k:
            sampling_kwargs['top_k'] = args.top_k
        min_length = 50 if args.dataset in ['pubmed'] else 150

        decoded = ['' for _ in range(len(texts))]

        # sample from the model until we get a sample with at least min_words words for each example
        # this is an inefficient way to do this (since we regenerate for all inputs if just one is too short), but it works
        tries = 0
        m = min(len(x.split()) for x in decoded)
        while m < min_words:
            if tries != 0:
                print()
                print(f"min words: {m}, needed {min_words}, regenerating (try {tries})")

            outputs = base_model.generate(**all_encoded, min_length=min_length, max_length=200, do_sample=True, **sampling_kwargs, pad_token_id=base_tokenizer.eos_token_id, eos_token_id=base_tokenizer.eos_token_id, use_cache=True)
            decoded = base_tokenizer.batch_decode(outputs, skip_special_tokens=True)
            tries += 1
            m = min(len(x.split()) for x in decoded)

    if args.openai_model:
        global API_TOKEN_COUNTER

        # count total number of tokens with GPT2_TOKENIZER
        total_tokens = sum(len(GPT2_TOKENIZER.encode(x)) for x in decoded)
        API_TOKEN_COUNTER += total_tokens

    return decoded


def get_likelihood(logits, labels):
    """Mean next-token log-likelihood of ``labels`` under ``logits``, for a batch of one.

    Logits at position t are scored against the label at position t + 1, so
    the last logit row and the first label are dropped.

    Args:
        logits: tensor with batch size 1 and vocabulary as the last dimension.
        labels: token ids with batch size 1.

    Returns:
        0-dim tensor holding the average log-probability.
    """
    assert logits.shape[0] == 1
    assert labels.shape[0] == 1

    vocab_size = logits.shape[-1]
    step_logits = logits.reshape(-1, vocab_size)[:-1]
    targets = labels.reshape(-1)[1:]
    token_log_probs = step_logits.log_softmax(dim=-1)
    picked = torch.gather(token_log_probs, -1, targets[:, None])[:, 0]
    return picked.mean()


# Get the log likelihood of each text under the base_model
def get_ll(text):
    """Return the average per-token log-likelihood of ``text``.

    With ``args.openai_model`` set, the completion API echoes the prompt
    ``<|endoftext|>{text}`` with logprobs and no new tokens. The mean of the
    token logprobs is returned, excluding the leading ``<|endoftext|>``.
    Otherwise the result is the negated ``.loss`` of the local ``base_model``,
    using the input ids as labels.
    """
    if not args.openai_model:
        with torch.no_grad():
            encoded = base_tokenizer(text, return_tensors="pt").to(DEVICE)
            outputs = base_model(**encoded, labels=encoded.input_ids)
            return -outputs.loss.item()

    request = {"engine": args.openai_model, "temperature": 0, "max_tokens": 0, "echo": True, "logprobs": 0}
    response = openai.Completion.create(prompt=f"<|endoftext|>{text}", **request)
    logprob_info = response['choices'][0]["logprobs"]
    # skip position 0, which belongs to the <|endoftext|> prefix
    tokens = logprob_info["tokens"][1:]
    logprobs = logprob_info["token_logprobs"][1:]

    assert len(tokens) == len(logprobs), f"Expected {len(tokens)} logprobs, got {len(logprobs)}"

    return np.mean(logprobs)


def get_lls(texts):
    """Return ``get_ll(text)`` for every text, in input order.

    For OpenAI models the requests run concurrently on ``args.batch_size``
    threads. The global ``API_TOKEN_COUNTER`` is charged twice the GPT-2
    token count of the inputs.
    """
    if not args.openai_model:
        return [get_ll(text) for text in texts]

    global API_TOKEN_COUNTER

    # use GPT2_TOKENIZER to get total number of tokens
    total_tokens = sum(len(GPT2_TOKENIZER.encode(text)) for text in texts)
    API_TOKEN_COUNTER += total_tokens * 2  # multiply by two because OpenAI double-counts echo_prompt tokens

    # the context manager shuts the worker threads down once map() returns
    # (previously a new pool was leaked on every call)
    with ThreadPool(args.batch_size) as pool:
        return pool.map(get_ll, texts)


# get the average rank of each observed token sorted by model likelihood
def get_rank(text, log=False):
    """Average 1-based rank of each observed next token under the base model.

    For each position, the vocabulary is sorted by descending logit and the
    rank of the actual next token is recorded. With ``log=True``, the natural
    log of each rank is averaged instead. Only local models are supported.
    """
    assert args.openai_model is None, "get_rank not implemented for OpenAI models"

    with torch.no_grad():
        encoded = base_tokenizer(text, return_tensors="pt").to(DEVICE)
        # logits at step t predict token t+1
        step_logits = base_model(**encoded).logits[:, :-1]
        targets = encoded.input_ids[:, 1:]

        order = step_logits.argsort(-1, descending=True)
        hits = (order == targets.unsqueeze(-1)).nonzero()

        assert hits.shape[1] == 3, f"Expected 3 dimensions in matches tensor, got {hits.shape}"

        timesteps = hits[:, -2]
        ranks = hits[:, -1]

        # exactly one hit per timestep, in order
        expected_steps = torch.arange(len(timesteps)).to(timesteps.device)
        assert (timesteps == expected_steps).all(), "Expected one match per timestep"

        ranks = ranks.float() + 1  # 0-based index -> 1-based rank
        if log:
            ranks = torch.log(ranks)

        return ranks.float().mean().item()


# get average entropy of each token in the text
def get_entropy(text):
    """Mean entropy (nats) of the base model's next-token distributions over ``text``.

    Only local models are supported.
    """
    assert args.openai_model is None, "get_entropy not implemented for OpenAI models"

    with torch.no_grad():
        encoded = base_tokenizer(text, return_tensors="pt").to(DEVICE)
        step_logits = base_model(**encoded).logits[:, :-1]
        probs = F.softmax(step_logits, dim=-1)
        log_probs = F.log_softmax(step_logits, dim=-1)
        per_step_entropy = -(probs * log_probs).sum(-1)
        return per_step_entropy.mean().item()


def get_roc_metrics(real_preds, sample_preds):
    """ROC curve and AUC with generated samples as the positive class (label 1).

    NaN scores are dropped from each list before scoring.

    Returns:
        ``(fpr, tpr, roc_auc)`` as two lists of floats and one float.
    """
    kept_real = [score for score in real_preds if not np.isnan(score)]
    kept_sample = [score for score in sample_preds if not np.isnan(score)]
    labels = [0] * len(kept_real) + [1] * len(kept_sample)
    fpr, tpr, _thresholds = roc_curve(labels, kept_real + kept_sample)
    return fpr.tolist(), tpr.tolist(), float(auc(fpr, tpr))


def get_precision_recall_metrics(real_preds, sample_preds):
    """Precision-recall curve and its AUC, with generated samples as label 1.

    NaN scores are dropped from each list. The AUC is the trapezoidal area
    under precision as a function of recall.

    Returns:
        ``(precision, recall, pr_auc)`` as two lists of floats and one float.
    """
    kept_real = [score for score in real_preds if not np.isnan(score)]
    kept_sample = [score for score in sample_preds if not np.isnan(score)]
    labels = [0] * len(kept_real) + [1] * len(kept_sample)
    precision, recall, _thresholds = precision_recall_curve(labels, kept_real + kept_sample)
    return precision.tolist(), recall.tolist(), float(auc(recall, precision))


# append one JSON line of summary metrics (AUROC, KL divergence, run settings) per experiment to the results file
def save_results(experiments):
    """Append one JSON line of summary metrics per experiment.

    Lines are written to
    ``document_length_results_all/group_model_{args.paraphrase_times}-results.json``.
    The directory is created if needed. ``kl_div`` is read from each
    experiment's own ``metrics['kl_divergence']`` (only some experiments
    record it) and is ``null`` when absent. The previous code referenced an
    undefined name here and raised NameError.
    """
    out_dir = "document_length_results_all"
    os.makedirs(out_dir, exist_ok=True)
    with open(f"{out_dir}/group_model_{args.paraphrase_times}-results.json", "a") as f:
        for experiment in experiments:
            metrics = experiment["metrics"]
            roc_auc = metrics['roc_auc']
            kl_divergence = metrics.get('kl_divergence')
            record = {
                'method': experiment['name'],
                'dataset': args.dataset,
                'n_samples': args.n_samples,
                # NOTE(review): detect_model and base_model both come from
                # args.base_model_name -- confirm this is intended
                'detect_model': args.base_model_name,
                'base_model': args.base_model_name,
                'paraphrase_times': args.paraphrase_times,
                'gan_s': args.gan_s,
                'gan_scale': args.gan_scale,
                'length_ratio': args.text_length_ratio,
                'auroc': f"{roc_auc:.3f}",
                'kl_div': f"{kl_divergence:.3f}" if kl_divergence is not None else None,
            }
            f.write(json.dumps(record))
            f.write('\n')

def save_roc_curves(experiments):
    """Draw all experiments' ROC curves on one figure and save it.

    Each curve takes the next color from COLORS (``zip`` stops after
    ``len(COLORS)`` experiments), and each AUC is printed. The figure is
    written to ``{SAVE_FOLDER}/roc_curves.png``.
    """
    plt.clf()  # start from a blank figure

    for experiment, color in zip(experiments, COLORS):
        metrics = experiment["metrics"]
        name = experiment['name']
        roc_auc = metrics['roc_auc']
        plt.plot(metrics["fpr"], metrics["tpr"], label=f"{name}, roc_auc={roc_auc:.3f}", color=color)
        print(f"{name} roc_auc: {roc_auc:.3f}")

    # chance-level reference diagonal
    plt.plot([0, 1], [0, 1], color='black', lw=2, linestyle='--')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title(f'ROC Curves ({base_model_name} - {args.mask_filling_model_name})')
    plt.legend(loc="lower right", fontsize=6)
    plt.savefig(f"{SAVE_FOLDER}/roc_curves.png")


# save the histogram of log likelihoods in two side-by-side plots, one for real and real perturbed, and one for sampled and sampled perturbed
def save_ll_histograms(experiments):
    """Save log-likelihood histograms for each experiment (best effort).

    Left panel: sampled vs perturbed-sampled log-likelihoods. Right panel:
    original vs perturbed-original. Each figure is written to
    ``{SAVE_FOLDER}/ll_histograms_{name}.png`` and then closed, so figures do
    not pile up in memory. Experiments lacking the needed data (e.g. no
    ``raw_results`` or no ``*_ll`` keys) are skipped with a warning instead
    of being silently swallowed. KeyboardInterrupt is no longer caught.
    """
    # first, clear plt
    plt.clf()

    for experiment in experiments:
        fig = None
        try:
            results = experiment["raw_results"]
            fig = plt.figure(figsize=(20, 6))
            plt.subplot(1, 2, 1)
            plt.hist([r["sampled_ll"] for r in results], alpha=0.5, bins='auto', label='sampled')
            plt.hist([r["perturbed_sampled_ll"] for r in results], alpha=0.5, bins='auto', label='perturbed sampled')
            plt.xlabel("log likelihood")
            plt.ylabel('count')
            plt.legend(loc='upper right')
            plt.subplot(1, 2, 2)
            plt.hist([r["original_ll"] for r in results], alpha=0.5, bins='auto', label='original')
            plt.hist([r["perturbed_original_ll"] for r in results], alpha=0.5, bins='auto', label='perturbed original')
            plt.xlabel("log likelihood")
            plt.ylabel('count')
            plt.legend(loc='upper right')
            plt.savefig(f"{SAVE_FOLDER}/ll_histograms_{experiment['name']}.png")
        except Exception as exc:
            # plotting is best effort; report why this experiment was skipped
            print(f"WARNING: skipping ll histogram for {experiment.get('name')}: {exc!r}")
        finally:
            if fig is not None:
                plt.close(fig)


# save the histograms of log likelihood ratios in two side-by-side plots, one for real and real perturbed, and one for sampled and sampled perturbed
def save_llr_histograms(experiments):
    """Save log-likelihood-ratio histograms for each experiment (best effort).

    For every raw result, ``sampled_llr`` and ``original_llr`` (log-likelihood
    minus mean perturbed log-likelihood) are stored on the result dict in
    place. Both distributions are then overlaid in the left panel of a 1x2
    figure; the right panel is left empty. Each figure is saved to
    ``{SAVE_FOLDER}/llr_histograms_{name}.png`` and closed. Experiments
    lacking the needed data are skipped with a warning. KeyboardInterrupt is
    no longer caught.
    """
    # first, clear plt
    plt.clf()

    for experiment in experiments:
        fig = None
        try:
            results = experiment["raw_results"]
            fig = plt.figure(figsize=(20, 6))
            plt.subplot(1, 2, 1)

            # compute the log likelihood ratio for each result
            for r in results:
                r["sampled_llr"] = r["sampled_ll"] - r["perturbed_sampled_ll"]
                r["original_llr"] = r["original_ll"] - r["perturbed_original_ll"]

            plt.hist([r["sampled_llr"] for r in results], alpha=0.5, bins='auto', label='sampled')
            plt.hist([r["original_llr"] for r in results], alpha=0.5, bins='auto', label='original')
            plt.xlabel("log likelihood ratio")
            plt.ylabel('count')
            plt.legend(loc='upper right')
            plt.savefig(f"{SAVE_FOLDER}/llr_histograms_{experiment['name']}.png")
        except Exception as exc:
            # plotting is best effort; report why this experiment was skipped
            print(f"WARNING: skipping llr histogram for {experiment.get('name')}: {exc!r}")
        finally:
            if fig is not None:
                plt.close(fig)


def get_perturbation_results(span_length=10, n_perturbations=1, n_samples=500):
    """Perturb every original/sampled text and score all of them under the base model.

    Each text in the global ``data`` is repeated ``n_perturbations`` times and
    perturbed. Perturbation is then re-applied ``n_perturbation_rounds - 1``
    more times; an AssertionError stops further rounds. Examples with any
    empty perturbation are dropped. For the rest, log-likelihoods of the
    texts and their perturbations (mean and std over perturbations; std is 1
    when there is a single perturbation) are added to each result dict.

    ``n_samples`` is not used here.

    Returns:
        list[dict]: one dict per kept example.
    """
    load_mask_model()

    torch.manual_seed(0)
    np.random.seed(0)

    originals = data["original"]
    samples = data["sampled"]
    perturb = functools.partial(perturb_texts, span_length=span_length, pct=args.pct_words_masked)

    perturbed_samples = perturb([t for t in samples for _ in range(n_perturbations)])
    perturbed_originals = perturb([t for t in originals for _ in range(n_perturbations)])
    for _ in range(n_perturbation_rounds - 1):
        try:
            perturbed_samples, perturbed_originals = perturb(perturbed_samples), perturb(perturbed_originals)
        except AssertionError:
            break

    assert len(perturbed_samples) == len(samples) * n_perturbations, f"Expected {len(samples) * n_perturbations} perturbed samples, got {len(perturbed_samples)}"
    assert len(perturbed_originals) == len(originals) * n_perturbations, f"Expected {len(originals) * n_perturbations} perturbed samples, got {len(perturbed_originals)}"

    results = []
    for idx, original in enumerate(originals):
        lo, hi = idx * n_perturbations, (idx + 1) * n_perturbations
        chunk_sampled = perturbed_samples[lo:hi]
        chunk_original = perturbed_originals[lo:hi]
        if '' in chunk_sampled or '' in chunk_original:
            continue  # a perturbation failed for this example
        results.append({
            "original": original,
            "sampled": samples[idx],
            "perturbed_sampled": chunk_sampled,
            "perturbed_original": chunk_original,
        })

    load_base_model()

    for res in tqdm.tqdm(results, desc="Computing log likelihoods"):
        sampled_lls = get_lls(res["perturbed_sampled"])
        original_lls = get_lls(res["perturbed_original"])
        res["original_ll"] = get_ll(res["original"])
        res["sampled_ll"] = get_ll(res["sampled"])
        res["all_perturbed_sampled_ll"] = sampled_lls
        res["all_perturbed_original_ll"] = original_lls
        res["perturbed_sampled_ll"] = np.mean(sampled_lls)
        res["perturbed_original_ll"] = np.mean(original_lls)
        res["perturbed_sampled_ll_std"] = np.std(sampled_lls) if len(sampled_lls) > 1 else 1
        res["perturbed_original_ll_std"] = np.std(original_lls) if len(original_lls) > 1 else 1

    return results


def run_perturbation_experiment(results, criterion, span_length=10, n_perturbations=1, n_samples=500):
    """Score perturbation results with criterion ``'d'`` or ``'z'`` and compute ROC metrics.

    ``'d'``: log-likelihood minus mean perturbed log-likelihood.
    ``'z'``: the same difference divided by the std of the perturbed
    log-likelihoods. A zero std is replaced by 1 in ``res`` (with a warning).
    Any other criterion produces no predictions.

    No KL divergence is computed for these experiments.
    """
    predictions = {'real': [], 'samples': []}
    for res in results:
        if criterion not in ('d', 'z'):
            continue

        real_gap = res['original_ll'] - res['perturbed_original_ll']
        sample_gap = res['sampled_ll'] - res['perturbed_sampled_ll']
        if criterion == 'd':
            predictions['real'].append(real_gap)
            predictions['samples'].append(sample_gap)
            continue

        # criterion 'z': normalize by the spread of the perturbed log-likelihoods
        if res['perturbed_original_ll_std'] == 0:
            res['perturbed_original_ll_std'] = 1
            print("WARNING: std of perturbed original is 0, setting to 1")
            print(f"Number of unique perturbed original texts: {len(set(res['perturbed_original']))}")
            print(f"Original text: {res['original']}")
        if res['perturbed_sampled_ll_std'] == 0:
            res['perturbed_sampled_ll_std'] = 1
            print("WARNING: std of perturbed sampled is 0, setting to 1")
            print(f"Number of unique perturbed sampled texts: {len(set(res['perturbed_sampled']))}")
            print(f"Sampled text: {res['sampled']}")
        predictions['real'].append(real_gap / res['perturbed_original_ll_std'])
        predictions['samples'].append(sample_gap / res['perturbed_sampled_ll_std'])

    fpr, tpr, roc_auc = get_roc_metrics(predictions['real'], predictions['samples'])

    info = {
        'pct_words_masked': args.pct_words_masked,
        'span_length': span_length,
        'n_perturbations': n_perturbations,
        'n_samples': n_samples,
    }
    metrics = {
        'roc_auc': roc_auc,
        'fpr': fpr,
        'tpr': tpr
    }
    return {
        'name': f'perturbation_{n_perturbations}_{criterion}',
        'predictions': predictions,
        'info': info,
        'raw_results': results,
        'metrics': metrics,
    }


def run_baseline_threshold_experiment(criterion_fn, name, n_samples=500):
    """Score original and sampled texts with a single-text criterion (e.g. likelihood, rank, entropy).

    Texts from the global ``data`` are processed in batches of ``batch_size``,
    covering ``ceil(n_samples / batch_size)`` batches. ROC metrics are
    computed on the raw scores. The returned ``predictions`` are min-max
    normalized and also used for the KL divergence stored in ``metrics``.
    """
    torch.manual_seed(0)
    np.random.seed(0)

    results = []
    n_batches = math.ceil(n_samples / batch_size)
    for batch in tqdm.tqdm(range(n_batches), desc=f"Computing {name} criterion"):
        lo, hi = batch * batch_size, (batch + 1) * batch_size
        originals = data["original"][lo:hi]
        samples = data["sampled"][lo:hi]

        for idx, original in enumerate(originals):
            results.append({
                "original": original,
                "original_crit": criterion_fn(original),
                "sampled": samples[idx],
                "sampled_crit": criterion_fn(samples[idx]),
            })

    # compute prediction scores for real/sampled passages
    predictions = {
        'real': [row["original_crit"] for row in results],
        'samples': [row["sampled_crit"] for row in results],
    }
    fpr, tpr, roc_auc = get_roc_metrics(predictions['real'], predictions['samples'])
    predictions['real'], predictions['samples'] = min_max_normalize(predictions['real'], predictions['samples'])
    if "entropy" in name:
        print(predictions['real'], predictions['samples'])
    kl_fake_real = _rel_entr(predictions['samples'], predictions['real'])
    print(f"{name}_threshold ROC AUC: {roc_auc}, KL_DIV: {kl_fake_real}")

    metrics = {
        'roc_auc': roc_auc,
        'fpr': fpr,
        'tpr': tpr,
        'kl_divergence': kl_fake_real
    }
    return {
        'name': f'{name}',
        'predictions': predictions,
        'info': {'n_samples': n_samples},
        'raw_results': results,
        'metrics': metrics,
    }


# collapse every run of whitespace (spaces, tabs, newlines) into a single space and strip the ends
def strip_newlines(text):
    """Collapse every whitespace run (including newlines) to one space and trim the ends."""
    words = text.split()
    return ' '.join(words)


# trim to shorter length
def trim_to_shorter_length(texta, textb):
    """Truncate both texts to the single-space word count of the shorter one."""
    words_a = texta.split(' ')
    words_b = textb.split(' ')
    keep = min(len(words_a), len(words_b))
    return ' '.join(words_a[:keep]), ' '.join(words_b[:keep])


def truncate_to_substring(text, substring, idx_occurrence):
    """Cut ``text`` just before the ``idx_occurrence``-th occurrence of ``substring``.

    Occurrences may overlap. If there are fewer occurrences, ``text`` is
    returned unchanged.
    """
    assert idx_occurrence > 0, 'idx_occurrence must be > 0'
    position = -1
    remaining = idx_occurrence
    while remaining:
        position = text.find(substring, position + 1)
        if position < 0:
            return text
        remaining -= 1
    return text[:position]

def model_paraphrase(generated_samples):
    """Paraphrase each text with the global ``paraphraser`` model.

    Every input is prefixed with ``"paraphrase: "``, and the whole list is
    encoded as one padded, truncated batch on DEVICE. One sampled paraphrase
    is generated per input (top-k 50, top-p 0.95, 50-200 tokens). The torch
    and numpy seeds are fixed to 13 first.

    Decoded outputs keep only ASCII letters, digits, spaces and the
    characters ``? ' - / : .``. Everything else, including commas, is removed.

    Returns:
        list[str]: one cleaned paraphrase per input, in input order.
    """
    print("Applying Model Paraphrasing Attack...")
    torch.manual_seed(13)
    np.random.seed(13)
    # Raw string: the old plain literal relied on invalid escape sequences
    # (\?, \-, \/, \:, \.) that Python warns about. The regex is unchanged.
    disallowed_chars = re.compile(r"[^a-zA-Z0-9 \?'\-\/\:\.]")
    with torch.no_grad():
        batch = ["paraphrase: " + item for item in generated_samples]
        inputs = paraphrase_tokenizer(batch, padding=True, truncation=True, return_tensors="pt")
        inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
        outputs = paraphraser.generate(
            **inputs,
            max_length=200,
            do_sample=True,
            top_k=50,
            top_p=0.95,
            min_length=50,
            early_stopping=True,
            num_return_sequences=1,
            eos_token_id=paraphrase_tokenizer.eos_token_id,
            use_cache=True
        )
        paraphrased_tokens = paraphrase_tokenizer.batch_decode(outputs, skip_special_tokens=True)
    return [disallowed_chars.sub('', it_seq) for it_seq in paraphrased_tokens]

import math

def generate_samples(raw_data, batch_size):
    """Pair each human text with a base-model continuation of its prompt.

    Texts are processed ``batch_size`` at a time with seeds fixed to 42. For
    pubmed, samples are cut before their second ``'Question:'``, and the
    original's separator is replaced by a space. Each original/sample pair is
    trimmed to the same word count. If ``args.pre_perturb_pct > 0``, the
    samples are perturbed afterwards.

    Returns:
        dict with parallel lists under ``"original"`` and ``"sampled"``.
    """
    torch.manual_seed(42)
    np.random.seed(42)
    originals, samples = [], []
    data = {
        "original": originals,
        "sampled": samples,
    }

    batch_num = math.ceil(len(raw_data) / batch_size)
    for batch in range(batch_num):
        print('Generating samples for batch', batch, 'of', batch_num)
        chunk = raw_data[batch * batch_size:(batch + 1) * batch_size]
        generated = sample_from_model(chunk, min_words=55)

        for o, s in zip(chunk, generated):
            if args.dataset == 'pubmed':
                s = truncate_to_substring(s, 'Question:', 2)
                o = o.replace(custom_datasets.SEPARATOR, ' ')
            o, s = trim_to_shorter_length(o, s)
            originals.append(o)
            samples.append(s)

    if args.pre_perturb_pct > 0:
        print(f'APPLYING {args.pre_perturb_pct}, {args.pre_perturb_span_length} PRE-PERTURBATIONS')
        load_mask_model()
        data["sampled"] = perturb_texts(data["sampled"], args.pre_perturb_span_length, args.pre_perturb_pct, ceil_pct=True)
        load_base_model()

    return data


def generate_data(dataset, key):
    """Load, clean and filter a text dataset, then generate model samples for it.

    Steps:
    1. Deduplicate in an order-preserving way.
    2. Strip each text and collapse its whitespace.
    3. For writing/squad/xsum, prefer texts with more than 250 words.
    4. Shuffle with seed 0 and keep 5,000 texts.
    5. Keep only texts of at most 512 ``preproc_tokenizer`` tokens.
    6. Pass the first ``n_samples`` texts to ``generate_samples``.
    """
    if dataset in custom_datasets.DATASETS:
        texts = custom_datasets.load(dataset)
    else:
        texts = datasets.load_dataset(dataset, split='train')[key]

    # deterministic dedup (set() would scramble the order)
    texts = list(dict.fromkeys(texts))
    texts = [strip_newlines(t.strip()) for t in texts]

    # prefer long examples when any exist
    if dataset in ['writing', 'squad', 'xsum']:
        long_texts = [t for t in texts if len(t.split()) > 250]
        if long_texts:
            texts = long_texts

    random.seed(0)
    random.shuffle(texts)
    texts = texts[:5_000]

    # dropping long examples also tends to remove low-quality/garbage content
    token_ids = preproc_tokenizer(texts)["input_ids"]
    texts = [t for t, ids in zip(texts, token_ids) if len(ids) <= 512]

    selected = texts[:n_samples]
    print(f"Total number of samples: {len(selected)}")
    print(f"Average number of words: {np.mean([len(x.split()) for x in selected])}")

    return generate_samples(selected, batch_size=batch_size)


def load_base_model_and_tokenizer(name):
    """Load the causal LM ``checkpoint_weight_dir + name`` and its tokenizer.

    LLaMA/Vicuna names use ``LlamaForCausalLM`` and a slow (non-fast)
    tokenizer; OPT names also get a slow tokenizer. The AutoTokenizer switch
    for this is ``use_fast``. The previous ``fast`` key was silently ignored,
    so OPT actually received the fast tokenizer despite the log message.
    pubmed uses left padding. The tokenizer's pad token is set to its EOS
    token.

    Returns:
        ``(base_model, base_tokenizer)``. ``base_model`` is None when
        ``args.openai_model`` is set.
    """
    path = checkpoint_weight_dir+name
    is_llama_family = "llama" in name or "vicuna" in name

    if args.openai_model is None:
        print(f'Loading BASE model {args.base_model_name}...')
        model_cls = transformers.LlamaForCausalLM if is_llama_family else transformers.AutoModelForCausalLM
        base_model = model_cls.from_pretrained(path)
    else:
        base_model = None

    optional_tok_kwargs = {}
    if "opt" in name:
        print("Using non-fast tokenizer for OPT")
        optional_tok_kwargs['use_fast'] = False
    if is_llama_family:
        optional_tok_kwargs['use_fast'] = False
    if args.dataset in ['pubmed']:
        optional_tok_kwargs['padding_side'] = 'left'
    base_tokenizer = transformers.AutoTokenizer.from_pretrained(path, **optional_tok_kwargs)
    base_tokenizer.pad_token_id = base_tokenizer.eos_token_id

    return base_model, base_tokenizer


def eval_supervised(data, model):
    """Score ``data`` with a pretrained sequence-classification detector.

    Args:
        data: dict with parallel text lists ``'original'`` (human-written) and
            ``'sampled'`` (model-generated).
        model: checkpoint directory name, appended to ``checkpoint_weight_dir``.

    Returns:
        dict with ``'name'``, ``'predictions'`` (``'real'`` / ``'samples'``
        scores, min-max rescaled jointly after the ROC computation),
        ``'info'`` (``n_samples``) and ``'metrics'`` (``roc_auc``, ``fpr``,
        ``tpr``, ``kl_divergence``).

    Reads the module-level ``DEVICE``, ``batch_size`` and ``n_samples``.
    """
    print(f'Beginning supervised evaluation with {model}...')
    path = checkpoint_weight_dir + model
    detector = transformers.AutoModelForSequenceClassification.from_pretrained(path).to(DEVICE)
    tokenizer = transformers.AutoTokenizer.from_pretrained(path)

    def _label0_probs(texts, desc):
        """Return the softmax probability of label index 0 for every text.

        NOTE(review): label 0 is presumably the machine-generated class for
        these detectors; confirm against the checkpoint's id2label.
        """
        probs = []
        # Stride through the texts; the last batch may be shorter, so no text
        # is dropped (and no ``math`` import is needed).
        for start in tqdm.tqdm(range(0, len(texts), batch_size), desc=desc):
            batch = tokenizer(texts[start:start + batch_size], padding=True, truncation=True,
                              max_length=512, return_tensors="pt").to(DEVICE)
            probs.extend(detector(**batch).logits.softmax(-1)[:, 0].tolist())
        return probs

    real, fake = data['original'], data['sampled']
    detector.eval()
    with torch.no_grad():
        real_preds = _label0_probs(real, "Evaluating real")
        fake_preds = _label0_probs(fake, "Evaluating fake")

    predictions = {
        'real': real_preds,
        'samples': fake_preds,
    }
    # ROC metrics use the raw probabilities; the KL estimate (and the stored
    # predictions) use the jointly min-max rescaled scores.
    fpr, tpr, roc_auc = get_roc_metrics(predictions['real'], predictions['samples'])
    predictions['real'], predictions['samples'] = min_max_normalize(predictions['real'], predictions['samples'])
    kl_fake_real = _rel_entr(predictions['samples'], predictions['real'])
    print(f"{model} ROC AUC: {roc_auc}, KL_DIV: {kl_fake_real}")

    # free GPU memory
    del detector
    torch.cuda.empty_cache()

    return {
        'name': model,
        'predictions': predictions,
        'info': {
            'n_samples': n_samples,
        },
        'metrics': {
            'roc_auc': roc_auc,
            'fpr': fpr,
            'tpr': tpr,
            'kl_divergence': kl_fake_real
        }
    }

def eval_gan(data, args):
    """Score ``data`` with a RADAR RoBERTa detector.

    A RoBERTa sequence classifier is loaded from
    ``checkpoint_weight_dir + "roberta_<gan_scale>"``. Its weights are then
    replaced by the ``.bin`` state dict in ``detector_weight_dir`` whose name is
    built from ``args.gan_s``, ``args.gan_p``, ``args.gan_scale`` and
    ``args.base_model_name``.

    Args:
        data: dict with parallel text lists ``'original'`` and ``'sampled'``.
        args: parsed command-line namespace (fields listed above).

    Returns:
        dict with ``'name'``, ``'predictions'`` (min-max rescaled after the
        ROC computation), ``'info'`` and ``'metrics'``, in the same layout as
        ``eval_supervised``.

    Reads the module-level ``DEVICE``, ``batch_size`` and ``n_samples``.
    """
    scale = args.gan_scale
    run_name = "s_{}_p_{}_roberta_{}_target_".format(args.gan_s, args.gan_p, scale) + args.base_model_name
    model_path = detector_weight_dir + run_name + ".bin"
    ckpt_path = checkpoint_weight_dir + "roberta_{}".format(scale)
    print(f'Beginning gan evaluation with {model_path}...')
    detector = transformers.AutoModelForSequenceClassification.from_pretrained(ckpt_path).to(DEVICE)
    tokenizer = transformers.AutoTokenizer.from_pretrained(ckpt_path)
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    state = torch.load(model_path, map_location=DEVICE)
    # strict=False tolerates key mismatches, which would otherwise leave layers
    # at their backbone values with no visible sign; report any mismatch.
    load_result = detector.load_state_dict(state, strict=False)
    if load_result.missing_keys or load_result.unexpected_keys:
        print(f"WARNING: {model_path} did not match the detector exactly "
              f"({len(load_result.missing_keys)} missing, "
              f"{len(load_result.unexpected_keys)} unexpected keys)")

    def _label0_probs(texts, desc):
        """Return exp(log_softmax(logits))[:, 0] (probability of label index 0) per text."""
        probs = []
        # Stride through the texts; the last batch may be shorter, so no text
        # is dropped (and no ``math`` import is needed).
        for start in tqdm.tqdm(range(0, len(texts), batch_size), desc=desc):
            batch = tokenizer(texts[start:start + batch_size], padding=True, truncation=True,
                              max_length=512, return_tensors="pt").to(DEVICE)
            probs.extend(F.log_softmax(detector(**batch).logits, -1)[:, 0].exp().tolist())
        return probs

    real, fake = data['original'], data['sampled']
    detector.eval()
    with torch.no_grad():
        real_preds = _label0_probs(real, "Evaluating real")
        fake_preds = _label0_probs(fake, "Evaluating fake")

    predictions = {
        'real': real_preds,
        'samples': fake_preds,
    }
    # ROC metrics use the raw probabilities; the KL estimate (and the stored
    # predictions) use the jointly min-max rescaled scores.
    fpr, tpr, roc_auc = get_roc_metrics(predictions['real'], predictions['samples'])
    predictions['real'], predictions['samples'] = min_max_normalize(predictions['real'], predictions['samples'])
    kl_fake_real = _rel_entr(predictions['samples'], predictions['real'])
    print(f"RADAR ROC AUC: {roc_auc}, KL_DIV: {kl_fake_real}")

    # free GPU memory
    del detector
    torch.cuda.empty_cache()

    return {
        'name': run_name,
        'predictions': predictions,
        'info': {
            'n_samples': n_samples,
        },
        'metrics': {
            'roc_auc': roc_auc,
            'fpr': fpr,
            'tpr': tpr,
            'kl_divergence': kl_fake_real
        }
    }

if __name__ == '__main__':
    # Script entry point. This body runs at module scope, so every name bound
    # below (DEVICE, args, batch_size, n_samples, base_model, data, ...) is a
    # module global that the functions defined above read directly.
    # NOTE(review): the device is hard-coded to CUDA; there is no CPU fallback.
    DEVICE = "cuda"

    # ---- command-line options ----
    parser = argparse.ArgumentParser()
    # dataset and the field that holds the text
    parser.add_argument('--dataset', type=str, default="xsum")
    parser.add_argument('--dataset_key', type=str, default="document")
    # DetectGPT perturbation settings
    parser.add_argument('--pct_words_masked', type=float, default=0.3) # pct masked is actually pct_words_masked * (span_length / (span_length + 2 * buffer_size))
    parser.add_argument('--span_length', type=int, default=2)
    parser.add_argument('--n_samples', type=int, default=200)
    parser.add_argument('--openai_model', type=str, default=None)
    # comma-separated perturbation counts, e.g. "1,10" -> [1, 10]
    parser.add_argument('--n_perturbation_list', type=str, default="1,10")
    parser.add_argument('--n_perturbation_rounds', type=int, default=1)
    # model names (resolved under checkpoint_weight_dir by the loaders)
    parser.add_argument('--base_model_name', type=str, default="gpt2-medium")
    parser.add_argument('--scoring_model_name', type=str, default="")
    parser.add_argument('--mask_filling_model_name', type=str, default="t5-large")
    parser.add_argument('--paraphrase_model_name', type=str, default="parrot")
    parser.add_argument('--batch_size', type=int, default=50)
    parser.add_argument('--chunk_size', type=int, default=20)
    parser.add_argument('--n_similarity_samples', type=int, default=20)
    parser.add_argument('--base_half', action='store_true')
    # sampling strategy (only used here to name the results folder)
    parser.add_argument('--do_top_k', action='store_true')
    parser.add_argument('--top_k', type=int, default=40)
    parser.add_argument('--do_top_p', action='store_true')
    parser.add_argument('--top_p', type=float, default=0.96)
    parser.add_argument('--openai_key_path', type=str,default="")
    # which detectors to evaluate
    parser.add_argument('--baselines', action='store_true')
    parser.add_argument('--detectgpt', action='store_true')
    parser.add_argument('--supervised', action='store_true')
    parser.add_argument('--gan', action='store_true')
    parser.add_argument('--buffer_size', type=int, default=1)
    parser.add_argument('--mask_top_p', type=float, default=1.0)
    parser.add_argument('--pre_perturb_pct', type=float, default=0.0)
    parser.add_argument('--pre_perturb_span_length', type=int, default=5)
    parser.add_argument('--random_fills', action='store_true')
    parser.add_argument('--random_fills_tokens', action='store_true')
    parser.add_argument('--cache_dir', type=str, default="~/.cache")
    # RADAR detector checkpoint selection (see eval_gan)
    parser.add_argument('--gan_scale', type=str, default="base")
    parser.add_argument('--gan_p', type=int, default=1)
    parser.add_argument('--gan_s', type=str, default="base")
    # paraphrase attack options
    parser.add_argument('--attack', action='store_true')
    parser.add_argument('--attack_side', type=str, default="model")
    parser.add_argument('--load_paraphrased_results',action='store_true')
    # data caching / length-bucket evaluation
    parser.add_argument('--just_save',action='store_true')
    parser.add_argument('--eval_text_length',action='store_true')
    parser.add_argument('--load_data',action='store_true')
    parser.add_argument('--paraphrase_times',type=int,default=0)
    parser.add_argument('--text_length_ratio',type=float,default=-1.0)
    

    args = parser.parse_args()
    # ---- derived run configuration (module globals) ----
    # NOTE(review): precision_string is a fixed label used in the folder name;
    # it is not derived from --base_half.
    precision_string = "fp-16"
    sampling_string = "top_k" if args.do_top_k else ("top_p" if args.do_top_p else "temp")
    base_model_name = args.base_model_name
    scoring_model_string = (f"-{args.scoring_model_name}" if args.scoring_model_name else "").replace('/', '_')
    # attack label: the paraphraser name under --attack, otherwise "wo_attack"
    attack = args.paraphrase_model_name if args.attack else "wo_attack"
    mask_filling_model_name = args.mask_filling_model_name
    n_samples = args.n_samples
    batch_size = args.batch_size
    n_perturbation_list = [int(x) for x in args.n_perturbation_list.split(",")]
    n_perturbation_rounds = args.n_perturbation_rounds
    n_similarity_samples = args.n_similarity_samples
    # Results folder keyed by attack, base model, mask-filling model, sampling
    # strategy, precision label, dataset and sample count (no timestamp).
    SAVE_FOLDER = f"eval_results/main/{attack}/{base_model_name}_{args.mask_filling_model_name}-{sampling_string}-{precision_string}-{args.dataset}-{args.n_samples}"

    base_model, base_tokenizer = load_base_model_and_tokenizer(args.base_model_name)
    # Tokenizer read by the data-preprocessing code to drop examples longer
    # than 512 of its tokens.
    preproc_tokenizer = transformers.AutoTokenizer.from_pretrained('t5-small', model_max_length=512)
    # Cast to fp16 for every model except gpt_2_xl. base_model is None when
    # --openai_model is set; calling .half() on it used to raise AttributeError.
    if base_model is not None and args.base_model_name != "gpt_2_xl":
        base_model.half()
    load_base_model()
    print(f'Loading dataset {args.dataset}...')
    # Cached evaluation data is keyed by base model, dataset and sample count,
    # so --load_data reads exactly what --just_save wrote.
    data_save_folder=f"data_for_eval/{args.base_model_name}-{args.dataset}-{args.n_samples}"
    if not args.load_data:
        data = generate_data(args.dataset, args.dataset_key)
        if args.just_save:
            # Save the freshly generated data, then stop. The file is written even
            # if the folder already exists (previously an existing folder meant
            # the process exited without saving anything).
            os.makedirs(data_save_folder, exist_ok=True)
            with open(os.path.join(data_save_folder, "raw_data.json"), "w") as f:
                print(f"Writing raw data to {os.path.join(data_save_folder, 'raw_data.json')}")
                json.dump(data, f)
            # os._exit skips interpreter cleanup, including stdio flushing, so
            # flush pending console output first.
            import sys
            sys.stdout.flush()
            os._exit(0)
    else:
        with open(os.path.join(data_save_folder, "raw_data.json"), "r") as f:
            print(f"Loading raw data from {os.path.join(data_save_folder, 'raw_data.json')}")
            data = json.load(f)
    if args.paraphrase_times >0:
        if args.paraphrase_model_name=="gpt-3.5-turbo":
            # API paraphrasing is not run here; load the pre-computed output.
            print("We don't support callling OpenaiAPI in this scripts, so we should load the pre-paraphrased data...")
            with open(os.path.join(data_save_folder, f"{args.attack_side}_{args.paraphrase_times}_paraphrased_data.json"), "r") as f:
                data['sampled']=[json.loads(item)['paraphrased'] for item in f.readlines()]
        else:
            print(" We support Model Paraphrase...")
            paraphrase_model_path = checkpoint_weight_dir+"paraphraser/"+args.paraphrase_model_name
            paraphrase_tokenizer = transformers.AutoTokenizer.from_pretrained(paraphrase_model_path)
            paraphraser = transformers.AutoModelForSeq2SeqLM.from_pretrained(paraphrase_model_path).to(DEVICE)
            if "trained_paraphraser" in args.paraphrase_model_name:
                # Overwrite the pretrained paraphraser with the weights saved
                # next to the matching RADAR detector checkpoint.
                # NOTE(review): torch.load unpickles the file; trusted files only.
                ckpt_path = detector_weight_dir+"s_{}_p_{}_roberta_{}_target_".format(args.gan_s,args.gan_p,args.gan_scale)+args.base_model_name+"_paraphraser.bin"
                paraphraser.load_state_dict(
                    torch.load(ckpt_path,map_location=DEVICE)
                )
            paraphraser.eval()
            old_sample = data['sampled'] if args.attack_side == "model" else data["original"]
            # Each round paraphrases the previous round's output. Results are
            # cached per round so reruns can skip the model.
            for p_time in range(args.paraphrase_times):
                cache_path = os.path.join(data_save_folder, f"seen_{args.attack_side}_{p_time+1}_paraphrased_data.json")
                if os.path.exists(cache_path):
                    with open(cache_path,'r') as f:
                        data['sampled']=[json.loads(item)['paraphrased'] for item in f.readlines()]
                else:
                    new_sample = []
                    for start_id in range(0,len(old_sample),batch_size):
                        batch = old_sample[start_id:start_id+batch_size]
                        new_sample+=model_paraphrase(batch)
                    data['sampled']=new_sample
                # Trim each pair to the shorter text on both the fresh and the
                # cached path. Previously cache hits skipped this, so reruns
                # evaluated untrimmed originals against trimmed samples.
                for i in range(len(data['sampled'])):
                    data['original'][i],data['sampled'][i]=trim_to_shorter_length(data['original'][i],data['sampled'][i])
                old_sample = data['sampled']
                if not os.path.exists(cache_path):
                    # The folder only exists if --just_save ran earlier; create it.
                    os.makedirs(data_save_folder, exist_ok=True)
                    with open(cache_path,'w') as f:
                        for item in data['sampled']:
                            f.write(json.dumps({'paraphrased':item}))
                            f.write('\n')
    if args.eval_text_length:
        def group_data(human_texts,model_texts,ratio):
            """Select one 20% length bucket of the paired texts.

            Pairs are ordered (stably) by the space-separated word count of the
            model text. The slice from ``(ratio-0.2)*n`` to ``ratio*n`` of that
            ordering is returned, keeping human/model texts aligned.
            """
            length = len(human_texts)
            sorted_idx = sorted(range(length), key=lambda x: len(model_texts[x].split(' ')))
            selected = sorted_idx[int((ratio-0.2)*length):int(ratio*length)]
            return {'original': [human_texts[i] for i in selected],
                    'sampled': [model_texts[i] for i in selected]}
        # With the default --text_length_ratio of -1.0 the slice above is
        # silently empty and every later metric would run on no data.
        if not 0.2 <= args.text_length_ratio <= 1.0:
            parser.error(f"--eval_text_length requires 0.2 <= --text_length_ratio <= 1.0, got {args.text_length_ratio}")
        data = group_data(data['original'],data['sampled'],args.text_length_ratio)

    API_TOKEN_COUNTER = 0
    # One timestamp so START_DATE and START_TIME describe the same instant
    # (two separate now() calls can straddle midnight).
    start_datetime = datetime.datetime.now()
    START_DATE = start_datetime.strftime('%Y-%m-%d')
    START_TIME = start_datetime.strftime('%H-%M-%S-%f')

    # Create the results folder named by SAVE_FOLDER (built from the run
    # configuration; it contains no timestamp). Every write below needs it,
    # including args.json. Skipping it for --just_save crashed here, because
    # --just_save only reaches this point together with --load_data.
    if not os.path.exists(SAVE_FOLDER):
        os.makedirs(SAVE_FOLDER, exist_ok=True)
        print(f"Saving results to absolute path: {os.path.abspath(SAVE_FOLDER)}")

    # Record the exact command-line configuration next to the results.
    with open(os.path.join(SAVE_FOLDER, "args.json"), "w") as f:
        json.dump(args.__dict__, f, indent=4)


    GPT2_TOKENIZER = transformers.GPT2Tokenizer.from_pretrained('gpt2')

    # Mask-filling seq2seq model for DetectGPT perturbations. It is only
    # instantiated for DetectGPT with model fills. Otherwise (or when its config
    # has no n_positions) the tokenizer assumes a 512-token context.
    mask_filling_model_path = checkpoint_weight_dir+mask_filling_model_name
    n_positions = 512
    if args.detectgpt and not args.random_fills:
        int8_kwargs = {}
        half_kwargs = {}
        print(f'Loading mask filling model {mask_filling_model_name}...')
        mask_model = transformers.AutoModelForSeq2SeqLM.from_pretrained(mask_filling_model_path)
        try:
            n_positions = mask_model.config.n_positions
        except AttributeError:
            pass
    mask_tokenizer = transformers.AutoTokenizer.from_pretrained(mask_filling_model_path, model_max_length=n_positions)
    # NOTE(review): data generation has already run by this point, so this
    # swap presumably only affects later users of preproc_tokenizer; confirm.
    if args.dataset in ['english', 'german']:
        preproc_tokenizer = mask_tokenizer

    # Vocabulary for random fills: every distinct whitespace token in the data, sorted.
    if args.random_fills:
        FILL_DICTIONARY = sorted({token for texts in data.values() for text in texts for token in text.split()})

    # Keep a copy of the exact data evaluated alongside the results.
    raw_data_path = os.path.join(SAVE_FOLDER, "raw_data.json")
    with open(raw_data_path, "w") as f:
        print(f"Writing raw data to {os.path.join(SAVE_FOLDER, 'raw_data.json')}")
        json.dump(data, f)
    
    # Threshold baselines. Results are appended in a fixed order: log-likelihood;
    # then rank, log-rank and entropy for local models; then the supervised
    # detector. The result-writing code indexes into this list by position.
    baseline_outputs = []
    if args.baselines:
        baseline_outputs.append(run_baseline_threshold_experiment(get_ll, "log_probability", n_samples=n_samples))
        if args.openai_model is None:
            rank_criterion = lambda text: -get_rank(text, log=False)
            baseline_outputs.append(run_baseline_threshold_experiment(rank_criterion, "rank", n_samples=n_samples))
            logrank_criterion = lambda text: -get_rank(text, log=True)
            baseline_outputs.append(run_baseline_threshold_experiment(logrank_criterion, "log_rank", n_samples=n_samples))
            entropy_criterion = lambda text: get_entropy(text)
            baseline_outputs.append(run_baseline_threshold_experiment(entropy_criterion, "entropy", n_samples=n_samples))
        # --supervised is only honored together with --baselines; only the
        # RoBERTa-large OpenAI detector is evaluated (roberta-base is disabled).
        if args.supervised:
            baseline_outputs.append(eval_supervised(data, model='roberta_large_openai_detector'))
    # Always define gan_outputs, like baseline_outputs and detectgpt_outputs,
    # so later code never relies on --gan just to know the name exists.
    gan_outputs = []
    if args.gan:
        gan_outputs.append(eval_gan(data, args))

    detectgpt_outputs = []

    if args.detectgpt:
        # One DetectGPT experiment per requested perturbation count. 'd' is the
        # only scoring mode run; each result is also saved to its own file.
        for n_perturbations in n_perturbation_list:
            perturbation_results = get_perturbation_results(args.span_length, n_perturbations, n_samples)
            for perturbation_mode in ('d',):
                output = run_perturbation_experiment(
                    perturbation_results,
                    perturbation_mode,
                    span_length=args.span_length,
                    n_perturbations=n_perturbations,
                    n_samples=n_samples,
                )
                detectgpt_outputs.append(output)
                result_path = os.path.join(SAVE_FOLDER, f"perturbation_{n_perturbations}_{perturbation_mode}_results.json")
                with open(result_path, "w") as f:
                    json.dump(output, f)

    if args.baselines:
        # File names in the same positional order in which baseline_outputs was
        # filled: log-likelihood first, then (local models only) rank,
        # log-rank and entropy.
        baseline_files = ["likelihood_threshold_results.json"]
        if args.openai_model is None:
            baseline_files += [
                "rank_threshold_results.json",
                "logrank_threshold_results.json",
                "entropy_threshold_results.json",
            ]
        for idx, file_name in enumerate(baseline_files):
            with open(os.path.join(SAVE_FOLDER, file_name), "w") as f:
                json.dump(baseline_outputs[idx], f)
        # The supervised detector result is always the last baseline entry.
        if args.supervised:
            with open(os.path.join(SAVE_FOLDER, "roberta_large_openai_detector_results.json"), "w") as f:
                json.dump(baseline_outputs[-1], f)

    if args.gan:
        gan_file = f"s_{args.gan_s}_p_{args.gan_p}_roberta_{args.gan_scale}_target_{args.base_model_name}_results.json"
        with open(os.path.join(SAVE_FOLDER, gan_file), "w") as f:
            json.dump(gan_outputs[0], f)


    # Collect every evaluated method (baselines, then GAN, then DetectGPT) for
    # the comparison plots. Each list is only read when its flag is set.
    # save_results(outputs) is intentionally not called.
    outputs = [
        *(baseline_outputs if args.baselines else []),
        *(gan_outputs if args.gan else []),
        *(detectgpt_outputs if args.detectgpt else []),
    ]
    save_roc_curves(outputs)
    save_ll_histograms(outputs)
    save_llr_histograms(outputs)

    print(f"Used an *estimated* {API_TOKEN_COUNTER} API tokens (may be inaccurate)")

