# Estimate the total variation distance between the distribution of human-written
# text and the distribution of our style-transfer outputs.

import json
import os
from math import log, sqrt
import random; random.seed(43)

import fire
import pandas as pd
import torch
import torch.nn as nn
from evaluate import load
from sklearn.metrics import roc_auc_score
from tqdm import tqdm

from embedding_utils import (
    load_luar_model_and_tokenizer,
    get_author_embeddings,
)
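# estimate_total_variation solves this upper bound on the detector AUC for delta:
#     AUC_upper_bound = 1 - 2 * np.exp(-num_samples * delta**2)
#     AUC_upper_bound = np.maximum(AUC_upper_bound, 0.)
# i.e. delta = sqrt(log(2 / (1 - AUC)) / num_samples), clipped to at most 1.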

def estimate_total_variation(
    num_samples: int,
    AUC: float,
):
    """Estimate the total variation implied by a detector's ROC AUC.

    Solves the bound ``AUC <= 1 - 2 * exp(-num_samples * tv**2)`` for ``tv``,
    giving ``tv = sqrt(log(2 / (1 - AUC)) / num_samples)``, and clips the
    result to at most 1.

    Args:
        num_samples: Number of samples the bound is stated in terms of; must
            be positive.
        AUC: Area under the ROC curve of the detector.

    Returns:
        The estimate as a float in ``[0, 1]``. An AUC of 1 (or above) maps to
        1.0, which is the limit of the formula as AUC -> 1.

    Raises:
        ValueError: If ``num_samples`` is not positive.
    """
    if num_samples <= 0:
        raise ValueError(f"num_samples must be positive, got {num_samples}")
    # At AUC == 1 the term 1 - AUC is zero and the formula diverges; the
    # clipped value is exactly 1, so return it instead of dividing by zero.
    if AUC >= 1:
        return 1.0
    total_variation = sqrt((1 / num_samples) * log(2 / (1 - AUC)))
    return min(total_variation, 1.0)

def _chunk(texts, size):
    """Split ``texts`` into consecutive chunks of exactly ``size`` items, dropping any ragged tail."""
    return [texts[idx:idx + size] for idx in range(0, len(texts) - size + 1, size)]


def _split_texts(df, num_background, machine_key):
    """Shuffle the rows once and split them into disjoint background/human/machine pools.

    A single permutation of row positions is applied to both columns. As a
    result, the background rows, the human half and the machine half never
    share a source row. Two independent shuffles would not guarantee this.

    Args:
        df: Frame with a ``content_text`` column and a ``machine_key`` column.
        num_background: Number of shuffled rows whose human text forms the
            background set.
        machine_key: Column holding machine text. If its first value is a
            list, only the first element of each row's list is used.

    Returns:
        ``(human_background, human, machine)`` lists of texts.
    """
    human_all = df["content_text"].tolist()
    if isinstance(df[machine_key].iloc[0], list):
        machine_all = df[machine_key].apply(lambda x: x[0]).tolist()
    else:
        machine_all = df[machine_key].tolist()

    order = list(range(len(df)))
    random.shuffle(order)
    background_rows = order[:num_background]
    rest = order[num_background:]
    midpoint = len(rest) // 2

    human_background = [human_all[i] for i in background_rows]
    human = [human_all[i] for i in rest[:midpoint]]
    machine = [machine_all[i] for i in rest[midpoint:]]
    return human_background, human, machine


def _embed_chunks(chunks, function_kwargs):
    """Embed each chunk as one author via ``get_author_embeddings`` and concatenate along dim 0."""
    embeddings = [get_author_embeddings(chunk, function_kwargs, "mud") for chunk in tqdm(chunks)]
    return torch.cat(embeddings, dim=0)


def main(
    data_file: str = "./data/MTD_reddit_12000_Mistral-7B-Instruct-v0.3_N=5.jsonl.iter=3",
    num_background: int = 1_000,
    machine_key: str = "transfer_pick",
    debug: bool = False,
):
    """Score human vs. machine text against a human background embedding and save AUCs.

    For each chunk size N, human and machine texts are grouped into chunks of
    N. Each chunk is embedded with the LUAR model and scored by its cosine
    distance to the embedding of the held-out human background set. The ROC
    AUC of those scores (machine = positive class) is recorded per N. All
    records, including raw scores and labels, are written as JSON to
    ``./tv/<basename(data_file)>_mkey=<machine_key>``.

    Args:
        data_file: JSON-lines file with a ``content_text`` column (human text)
            and a ``machine_key`` column.
        num_background: Number of rows whose human text forms the background.
        machine_key: Column holding the machine-generated text.
        debug: If True, keep at most 1,000 human and 1,000 machine texts.
    """
    df = pd.read_json(data_file, lines=True)
    human_background, human, machine = _split_texts(df, num_background, machine_key)
    if debug:
        human = human[:1_000]
        machine = machine[:1_000]

    num_samples = [1, 5, 10, 25, 50, 75, 100, 125, 150, 175, 200, 225, 250]
    model, tokenizer = load_luar_model_and_tokenizer()
    model.to("cuda")
    cossim = nn.CosineSimilarity()
    function_kwargs = {
        "luar": model,
        "luar_tok": tokenizer,
    }
    background_emb = get_author_embeddings(human_background, function_kwargs, "mud")

    records = []
    for N in num_samples:
        human_chunks = _chunk(human, N)
        machine_chunks = _chunk(machine, N)
        if not human_chunks or not machine_chunks:
            # torch.cat on an empty list and a single-class roc_auc_score
            # would both fail, so skip chunk sizes the data cannot fill.
            print(f"Skipping N={N}: not enough texts for a full chunk")
            continue

        H = _embed_chunks(human_chunks, function_kwargs)
        M = _embed_chunks(machine_chunks, function_kwargs)

        # Score = cosine distance to the human background (higher = farther).
        # NOTE(review): assumes background_emb broadcasts against H/M
        # (e.g. a single row) — confirm against get_author_embeddings.
        scores_H = (1 - cossim(background_emb, H)).tolist()
        scores_M = (1 - cossim(background_emb, M)).tolist()
        scores = scores_H + scores_M
        labels = [0] * len(scores_H) + [1] * len(scores_M)
        AUC = roc_auc_score(labels, scores)

        records.append({
            "AUC": AUC,
            "num_samples": N,
            "scores": scores,
            "labels": labels,
        })

    outdir = "./tv"
    os.makedirs(outdir, exist_ok=True)
    savename = os.path.join(outdir, f"{os.path.basename(data_file)}_mkey={machine_key}")
    with open(savename, "w") as fout:
        json.dump(records, fout)

if __name__ == "__main__":
    # Expose main()'s keyword arguments as command-line flags via Python Fire.
    fire.Fire(component=main)