
import json
import os
import random

import fire
import numpy as np
from sklearn.metrics import roc_auc_score, roc_curve

def sample(arr: np.array, n: int, k: int):
    """Draw ``k`` random subsets of ``arr`` and return the mean of each one.

    Each draw picks ``n`` distinct positions of ``arr`` with
    ``random.sample`` (no replacement inside a draw; separate draws are
    independent), so the output depends on the state of the stdlib
    ``random`` generator.

    Args:
        arr: Array of values to draw from; converted with ``tolist()``.
        n: Number of values averaged per draw.
        k: Number of draws, i.e. the length of the result.

    Returns:
        A NumPy array holding the ``k`` per-draw means.
    """
    pool = arr.tolist()
    draw_means = [np.mean(np.array(random.sample(pool, n))) for _ in range(k)]
    return np.array(draw_means)

def aggregate(arr: np.array, n: int):
    """Average ``arr`` over consecutive, non-overlapping groups of ``n`` items.

    When ``len(arr)`` is not a multiple of ``n`` the final group is shorter;
    it is still averaged and kept in the output.

    Args:
        arr: Array of values; converted with ``tolist()`` before grouping.
        n: Group size (also the step between group starts).

    Returns:
        A NumPy array with one mean per group, in input order.
    """
    values = arr.tolist()
    group_starts = range(0, len(values), n)
    return np.array(
        [np.mean(np.array(values[start:start + n])) for start in group_starts]
    )

def tpr_target(labels, scores, target_fpr):
    """Return the true-positive rate of the ROC curve at ``target_fpr``.

    The ROC curve is obtained from ``roc_curve(labels, scores)`` and read off
    at ``target_fpr`` by linear interpolation between the two curve points
    that bracket it. When ``target_fpr`` coincides with a curve point, that
    point's TPR is returned (the upper end of a vertical segment if several
    points share the same FPR).

    The previous implementation returned the plain average of the two
    neighbouring TPR values regardless of where ``target_fpr`` lay between
    them, and averaged in the preceding point even on an exact match (so
    ``target_fpr=1.0`` did not yield the curve's final TPR).

    Args:
        labels: Binary ground-truth labels, forwarded to ``roc_curve``.
        scores: Detector scores, forwarded to ``roc_curve``.
        target_fpr: False-positive rate at which to evaluate the curve.

    Returns:
        The TPR at ``target_fpr``. Targets below the first FPR value give
        ``tpr[0]``; targets at or beyond the last one (or NaN) give
        ``tpr[-1]``.
    """
    fpr, tpr, _ = roc_curve(labels, scores)

    # side="right" lands just past the last point with fpr <= target_fpr, so
    # the bracketing segment below always has fpr[lower] < fpr[upper].
    upper = int(np.searchsorted(fpr, target_fpr, side="right"))
    if upper == 0:
        return tpr[0]
    if upper == len(fpr):
        return tpr[-1]

    lower = upper - 1
    return np.interp(
        target_fpr,
        [fpr[lower], fpr[upper]],
        [tpr[lower], tpr[upper]],
    )

def _load_score_file(path):
    """Parse one score file as JSON and return its content, closing the file."""
    with open(path, encoding="utf-8") as handle:
        return json.load(handle)


def _without_nan(values):
    """Return ``values`` as a float array with NaN entries removed.

    Forcing a float dtype also turns JSON ``null`` entries into NaN, so they
    are dropped instead of breaking ``np.isnan``.
    """
    as_float = np.asarray(values, dtype=float)
    return as_float[~np.isnan(as_float)]


def _print_detection_metrics(title, human_scores, machine_scores, max_fpr):
    """Print ROC AUC and TPR at ``max_fpr`` for human (0) vs. machine (1)."""
    labels = np.array([0] * len(human_scores) + [1] * len(machine_scores))
    # np.concatenate is available on every NumPy release; np.concat only
    # exists from NumPy 2.0 on.
    scores = np.concatenate((human_scores, machine_scores))
    auc = roc_auc_score(labels, scores, max_fpr=max_fpr)
    tpr = tpr_target(labels, scores, target_fpr=max_fpr)
    print(title, round(auc, 2), round(tpr, 2))


def main(
    max_fpr: float = 1.0,
    fewshot: bool = False,
    with_background: bool = False,
    n: int = 5,
    k: int = 300,
    num_background: int = 1_000,
):
    """Print detection metrics for every detector found in the score files.

    Three score files (iterations 1-3) are read from ``./mtd_scores``. For
    each detector in the first file, except ``"OpenAI"`` and ``"RADAR"``,
    the human scores (``"content_text"``) are compared against five sets of
    machine scores: ``"generation"``, ``"paraphrase_generation"`` and the
    ``"transfer_pick"`` scores of each iteration. For every comparison the
    ROC AUC and the TPR at ``max_fpr`` are printed, rounded to two decimals.

    Args:
        max_fpr: Passed to ``roc_auc_score`` as ``max_fpr`` and to
            ``tpr_target`` as ``target_fpr``.
        fewshot: When true, every score set is replaced by the means of
            consecutive groups of ``n`` scores (see ``aggregate``).
        with_background: When true, the first ``num_background`` human
            scores are removed from the human set and their mean is
            subtracted from every score set.
        n: Group size used when ``fewshot`` is set.
        k: Not used by the active code path; kept so the command-line
            interface stays unchanged (it only parameterised a disabled
            random-resampling variant based on ``sample``).
        num_background: Number of leading human scores used as background.

    Returns:
        Always ``0``.
    """
    dirname = "./mtd_scores"
    filenames = [
        "MTD_reddit_12000_Mistral-7B-Instruct-v0.3_N=5.jsonl.iter=1",
        "MTD_reddit_12000_Mistral-7B-Instruct-v0.3_N=5.jsonl.iter=2",
        "MTD_reddit_12000_Mistral-7B-Instruct-v0.3_N=5.jsonl.iter=3",
    ]
    scores_1, scores_2, scores_3 = (
        _load_score_file(os.path.join(dirname, filename))
        for filename in filenames
    )

    print("max_fpr=", max_fpr)
    for detector in scores_1:
        if detector in ("OpenAI", "RADAR"):
            continue

        print(detector)
        first_iter = scores_1[detector]
        human_scores = _without_nan(first_iter["content_text"])
        # Insertion order of this dict is the order in which results print.
        machine_sets = {
            "(human, machine)": first_iter["generation"],
            "(human, para_machine)": first_iter["paraphrase_generation"],
            "(human, stylepara_machine-iter1)": first_iter["transfer_pick"],
            "(human, stylepara_machine-iter2)": scores_2[detector]["transfer_pick"],
            "(human, stylepara_machine-iter3)": scores_3[detector]["transfer_pick"],
        }
        machine_sets = {
            title: _without_nan(values) for title, values in machine_sets.items()
        }

        if detector == "Binoculars":
            # Sign flip so that larger means "more machine-like" for the ROC
            # computation — presumably raw Binoculars scores are lower for
            # machine text; confirm against the scorer.
            human_scores = -human_scores
            machine_sets = {
                title: -values for title, values in machine_sets.items()
            }

        background_mean = None
        if with_background:
            # The background is taken from the raw human scores, i.e. before
            # any few-shot aggregation.
            background_mean = np.mean(human_scores[:num_background])
            human_scores = human_scores[num_background:]

        if fewshot:
            human_scores = aggregate(human_scores, n)
            machine_sets = {
                title: aggregate(values, n)
                for title, values in machine_sets.items()
            }

        if background_mean is not None:
            # Out-of-place subtraction: an in-place ``-=`` with a float mean
            # fails on integer-typed arrays.
            # NOTE(review): a constant shift should not change rank-based ROC
            # metrics, so the measurable effect of ``with_background`` is
            # likely just the removal of the background samples — confirm.
            human_scores = human_scores - background_mean
            machine_sets = {
                title: values - background_mean
                for title, values in machine_sets.items()
            }

        for title, machine_scores in machine_sets.items():
            _print_detection_metrics(title, human_scores, machine_scores, max_fpr)

    return 0

if __name__ == "__main__":
    # Fix the seed of the stdlib ``random`` generator (the one ``sample``
    # draws from) so that runs are repeatable.
    random.seed(43)
    # Hand ``main`` to python-fire; presumably its keyword arguments become
    # command-line flags — see the Fire documentation.
    fire.Fire(main)