
import json
import os
import random
from math import isnan

import numpy as np
import pandas as pd
import torch
from hyppo.ksample import MMD
from sentence_transformers import SentenceTransformer
from transformers import AutoModel, AutoTokenizer, set_seed

from genpaths import HUMAN, MACHINE, LLMOPT, OURS
from nicks_dpo.create_preference_data import get_luar_embeddings
from results_mtd import load

def load_text(arr):
    """Read one column of a JSON-lines file and return it as a Python list.

    ``arr[2]`` is the path of the JSON-lines file and ``arr[3]`` is the name
    of the column to extract (original note: [2] -> Reddit, [3] -> Generation).

    If the first entry of the column is itself a list, every entry is
    replaced by its first element; otherwise the column is returned as-is.
    """
    path, column = arr[2], arr[3]
    frame = pd.read_json(path, lines=True)
    entries = frame[column].tolist()
    # Only the first entry is inspected to decide whether to unwrap.
    if not isinstance(entries[0], list):
        return entries
    return [entry[0] for entry in entries]

def calculate_MMDs(
    emb_1: torch.Tensor,
    emb_2: torch.Tensor,
    N: int,
    max_samples: int = 100,
):
    """Run an MMD two-sample test on group-averaged embeddings.

    Each input is flattened to rows of size ``emb.size(-1)``, split into
    consecutive groups of ``N`` rows (in the given order, no shuffling), and
    every group is averaged into a single vector. At most ``max_samples``
    averaged vectors per input are passed to ``MMD().test``.

    Args:
        emb_1: Embeddings of the first sample; must support ``.numpy()``
            (i.e. live on the CPU and not require grad).
        emb_2: Embeddings of the second sample; same requirements.
        N: Number of consecutive rows averaged into one vector.
        max_samples: Upper bound on the number of averaged vectors kept
            per input.

    Returns:
        The ``(mmd_test_statistic, p_value)`` pair produced by
        ``MMD().test``.

    Raises:
        ValueError: If ``N`` is not positive, or if either input has fewer
            than ``N`` rows (no complete group can be formed).
    """
    if N <= 0:
        raise ValueError("N must be a positive integer, got {}".format(N))

    def _pool(emb: torch.Tensor):
        # reshape (rather than view) also accepts non-contiguous tensors.
        rows = emb.reshape(-1, emb.size(-1))
        # Drop the trailing partial group so the row count is a multiple of
        # N; a plain view(-1, N, dim) raises when it is not.
        usable = (rows.size(0) // N) * N
        if usable == 0:
            raise ValueError(
                "need at least N={} embeddings to form one group, got {}".format(
                    N, rows.size(0)
                )
            )
        grouped = rows[:usable].reshape(-1, N, rows.size(-1))
        return grouped.mean(dim=1).numpy()[:max_samples]

    pooled_1 = _pool(emb_1)
    pooled_2 = _pool(emb_2)
    mmd_test_statistic, p_value = MMD().test(pooled_1, pooled_2, workers=-1)
    return mmd_test_statistic, p_value

def _embed_corpora(model_name, corpora):
    """Embed every list of texts in *corpora* with the named style model.

    ``"LUAR"`` uses the ``rrivera1849/LUAR-MUD`` checkpoint through
    ``get_luar_embeddings``; ``"CISR"`` and any other name use a
    ``SentenceTransformer`` checkpoint. The model is moved to the GPU, and
    the returned list holds one CPU tensor per corpus, in input order.
    """
    if model_name == "LUAR":
        model = AutoModel.from_pretrained("rrivera1849/LUAR-MUD", trust_remote_code=True)
        model.eval()
        model.cuda()
        tokenizer = AutoTokenizer.from_pretrained("rrivera1849/LUAR-MUD", trust_remote_code=True)
        return [
            get_luar_embeddings(texts, model, tokenizer, batch_size=1024, single=False).cpu()
            for texts in corpora
        ]

    HF_id = "AnnaWegmann/Style-Embedding" if model_name == "CISR" else "StyleDistance/styledistance"
    model = SentenceTransformer(HF_id)
    model.eval()
    model.cuda()
    # SentenceTransformer.encode names this flag `show_progress_bar`.
    return [
        model.encode(
            texts,
            show_progress_bar=True,
            convert_to_tensor=True,
            normalize_embeddings=True,
        ).cpu()
        for texts in corpora
    ]


def main():
    """Compute (or load cached) MMD test results per embedding model.

    For each model in ``MODELS`` the results are cached in
    ``./MMD/MMD_<model>.json``. If that file exists it is loaded; otherwise
    the three corpora are embedded, ``calculate_MMDs`` is run for every
    group size in ``Ns`` comparing the first corpus against the LLMOPT and
    OURS corpora, and the results are written to the cache file. A summary
    table of all results is printed at the end.

    Returns:
        0 on completion.
    """
    set_seed(43)

    outdir = "./MMD"
    os.makedirs(outdir, exist_ok=True)
    Ns = [1, 5, 10, 25, 50]  # same as in all experiments

    # An earlier variant of this analysis (removed; see version control) ran
    # the same MMD test on per-detector score vectors loaded via
    # `results_mtd.load` instead of on style embeddings.

    # NOTE(review): `human_data` is loaded from the MACHINE path spec while
    # the imported HUMAN spec is unused -- confirm this is intentional.
    human_data = load_text(MACHINE)
    llmopt_data = load_text(LLMOPT)
    sparaphrase_data = load_text(OURS)

    # We use 100 samples for each value of N (posts we aggregate across), thus
    # we need 5000 samples to have 100 samples when aggregating across the
    # maximum (N=50).

    # Only LUAR is currently enabled; "CISR" and "SD" are also handled by
    # `_embed_corpora` and can be added to this list.
    MODELS = ["LUAR"]
    all_data = {}
    for model_name in MODELS:
        savename = os.path.join(outdir, "MMD_{}.json".format(model_name))
        if os.path.exists(savename):
            # Reuse cached results; `with` closes the handle promptly.
            with open(savename) as fin:
                all_data[model_name] = json.load(fin)
            continue

        print("Calculating MMD for {}".format(model_name))
        human_emb, llmopt_emb, sparaphrase_emb = _embed_corpora(
            model_name, [human_data, llmopt_data, sparaphrase_data]
        )
        print(human_emb.size(), llmopt_emb.size(), sparaphrase_emb.size())

        data = {}
        for N in Ns:
            mmd_test_statistic_1, p_value_1 = calculate_MMDs(human_emb, llmopt_emb, N)
            mmd_test_statistic_2, p_value_2 = calculate_MMDs(human_emb, sparaphrase_emb, N)
            data[N] = {
                "LLMOPT": (mmd_test_statistic_1, p_value_1),
                "Ours": (mmd_test_statistic_2, p_value_2),
            }

        # Serialise before opening the file: if serialisation fails, no
        # empty cache file is left behind to break the next run.
        payload = json.dumps(data, indent=4)
        with open(savename, "w") as fout:
            fout.write(payload)

        all_data[model_name] = data

    # Cached results have string N keys (JSON) and fresh ones have int keys;
    # int(N) normalises both.
    rows = []
    for m in MODELS:
        for N, comps in all_data[m].items():
            for comp in ["LLMOPT", "Ours"]:
                stat, p = comps[comp]
                rows.append({"Model": m, "N": int(N), "Comparison": comp, "MMD": stat, "p": float(p)})

    df = pd.DataFrame(rows)
    print(df.to_string(index=False))

    return 0

if __name__ == "__main__":
    # main() returns an integer status; pass it on as the process exit code
    # instead of discarding it.
    raise SystemExit(main())