# TODO: Remove all references to Reddit if we ever switch to StackExchange.

import os
from glob import glob

import fire
import numpy as np
import pandas as pd
import sentence_transformers.util as util
import torch
from tqdm import tqdm

from embedding_utils import (
    load_luar_model_and_tokenizer,
    get_author_embeddings,
    get_instance_embeddings,
)
from utils import DATA_PATH

# Written by Gemini:
def merge_elements(
    list_of_lists: list[list]
) -> list[list]:
    """Transpose a list of (possibly ragged) lists, padding short ones with ``None``.

    Element ``i`` of the result holds the ``i``-th item of every input list, in
    input order. A list shorter than the longest one contributes ``None`` at the
    positions it lacks.

    >>> merge_elements([[1, 2], [3]])
    [[1, 3], [2, None]]
    >>> merge_elements([])
    []
    """
    from itertools import zip_longest

    # zip_longest pads exhausted inputs with None (its default fillvalue) and
    # yields nothing for an empty input rather than raising on max() of nothing.
    return [list(column) for column in zip_longest(*list_of_lists)]

def parse_params_from_filename(
    filename: str,
):
    """Parse the generation settings encoded in a generations filename.

    The basename is first stripped of a trailing ``.jsonl`` and then split on
    ``"_"``. The tokens are read positionally::

        <prefix>_<model>_<k>=<temperature>_<k>=<top_p>[_<k>=<wm_alg>[_<key>=<value>...]]

    Only the value side of the ``key=value`` tokens at positions 2-4 is used.
    Tokens from position 5 onward become the watermarking-parameter dict. Every
    value is returned as a string; nothing is converted to a number.

    Returns:
        ``(model_name, decoding_params, watermarking_alg, watermarking_params)``.
        ``decoding_params`` is ``{"temperature": str, "top_p": str}``. The two
        watermark entries are ``None`` when the name has at most 4 tokens.

    Raises:
        IndexError: if the name has fewer than 4 tokens, or a parsed token has
            no ``"="``.
    """
    params = os.path.basename(filename).removesuffix(".jsonl").split("_")

    def value_of(token: str) -> str:
        # Split on the first "=" only, so values that themselves contain "=" survive.
        return token.split("=", 1)[1]

    model_name = params[1]
    decoding_params = {
        "temperature": value_of(params[2]),
        "top_p": value_of(params[3]),
    }
    watermarking_alg = None
    watermarking_params = None
    if len(params) > 4:
        watermarking_alg = value_of(params[4])
        watermarking_params = {
            token.split("=", 1)[0]: value_of(token) for token in params[5:]
        }

    return model_name, decoding_params, watermarking_alg, watermarking_params

def main(
    filename: str,
    gen_key: str = "respond_reddit",
):
    """Attach every model's generations to the dataset, one output row per response.

    Inputs:
      - The JSONL dataset at ``DATA_PATH/mtd/<filename>``.
      - Every non-watermarked generations file matching
        ``DATA_PATH/mtd/generations/<name>*``. Here ``name`` is ``filename``
        without its extension and with ``"_"`` replaced by ``"-"``.

    Row ``i`` of each generations file is paired with row ``i`` of the dataset.
    Files shorter than the longest one are padded with ``None``. Presumably each
    file has as many rows as the dataset; pandas rejects a length mismatch on
    column assignment.

    The frame is then exploded so that each (dataset row, generations file) pair
    becomes its own row. Each such row carries the columns ``gen_key``,
    ``model_name``, ``decoding_param``, ``watermarking_alg`` and
    ``watermarking_param``. The result is written as JSONL to
    ``<dataset path>.merged``.

    Args:
        filename: Dataset filename, relative to ``DATA_PATH/mtd``.
        gen_key: Column holding the generated text in each generations file.
            The merged response column uses the same name.

    Returns:
        0 on success.

    Raises:
        FileNotFoundError: if the dataset file is missing, or if no
            non-watermarked generations file matches.
    """
    from glob import escape as glob_escape

    name = os.path.splitext(filename)[0].replace("_", "-")
    data_fname = os.path.join(DATA_PATH, "mtd", filename)
    # Explicit check rather than ``assert``: asserts are stripped under ``python -O``.
    if not os.path.isfile(data_fname):
        raise FileNotFoundError(f"Dataset file not found: {data_fname}")

    # Escape ``name`` so glob metacharacters in it ([, ], *, ?) match literally.
    # Sort the matches so the order of models within each exploded group is
    # reproducible; raw glob order depends on the filesystem.
    pattern = os.path.join(DATA_PATH, "mtd", "generations", glob_escape(name) + "*")
    generations_filenames = sorted(glob(pattern))
    # TODO: watermarked generations are skipped for now. Look for "wm" only in the
    # part of the basename after the dataset prefix. Otherwise a "wm" inside
    # DATA_PATH, or inside ``name`` itself, silently excludes every file.
    prefix_len = len(os.path.basename(name))
    generations_filenames = [
        fname
        for fname in generations_filenames
        if "wm" not in os.path.basename(fname)[prefix_len:]
    ]
    if not generations_filenames:
        raise FileNotFoundError(
            f"No non-watermarked generation files match: {pattern}"
        )

    df = pd.read_json(data_fname, lines=True)
    responses = []
    model_names = []
    decoding_params = []
    watermarking_algs = []
    watermarking_params = []

    # Build one inner list per generations file, with one element per row of that file.
    for gfname in generations_filenames:
        gen_name, d_params, wm_alg, wm_params = parse_params_from_filename(gfname)
        df_gen = pd.read_json(gfname, lines=True)
        num_rows = len(df_gen)

        responses.append(df_gen[gen_key].tolist())
        model_names.append([gen_name] * num_rows)
        decoding_params.append([d_params] * num_rows)
        watermarking_algs.append([wm_alg] * num_rows)
        watermarking_params.append([wm_params] * num_rows)

    # Transpose so that entry i collects row i of every file (padded with None).
    # Store each result as a per-row list, then explode into one row per file.
    df[gen_key] = merge_elements(responses)
    df["model_name"] = merge_elements(model_names)
    df["decoding_param"] = merge_elements(decoding_params)
    df["watermarking_alg"] = merge_elements(watermarking_algs)
    df["watermarking_param"] = merge_elements(watermarking_params)

    df = df.explode(
        [gen_key, "model_name", "decoding_param", "watermarking_alg", "watermarking_param"]
    )

    # NOTE: The LUAR author-similarity analysis below is disabled and kept for
    # reference. Re-enable it by uncommenting.
    # unique_reference_text = []
    # unique_author_ids = []
    # seen = set()
    # for i in range(len(df)):
    #     reference = df.iloc[i]["reference_text"]
    #     if "_".join(reference) in seen:
    #         continue
    #     seen.add("_".join(reference))
    #     unique_reference_text.append(reference)
    #     unique_author_ids.append(df.iloc[i]["author_id"])

    # luar, luar_tok = load_luar_model_and_tokenizer("rrivera1849/LUAR-MUD")
    # luar.to("cuda")
    # function_kwargs = {"luar": luar, "luar_tok": luar_tok}
    # reference_embeddings = [
    #     get_author_embeddings(reference, function_kwargs, "mud") for reference in tqdm(unique_reference_text)
    # ]
    # reference_embeddings = torch.cat(reference_embeddings, dim=0)
    # function_kwargs["progress_bar"] = True
    # response_embeddings = get_instance_embeddings(df[gen_key].tolist(), function_kwargs, "mud")

    # similarities_to_authors = []
    # similarities = util.cos_sim(response_embeddings, reference_embeddings)
    # percentile_values = np.linspace(0, 1, 6)
    # for i in range(similarities.shape[0]):
    #     sims = similarities[i].cpu().numpy()
    #     s2a = []
    #     for q in percentile_values:
    #         # inspired by: https://stackoverflow.com/questions/26070514/how-do-i-get-the-index-of-a-specific-percentile-in-numpy-scipy
    #         percentile = np.quantile(sims, q=q, method="nearest")
    #         min_index = int(abs(sims - percentile).argmin())
    #         s2a.append((unique_author_ids[min_index], sims[min_index]))
    #     similarities_to_authors.append(s2a)
    # df["similarities_to_author"] = similarities_to_authors
    savename = data_fname + ".merged"
    df.to_json(savename, lines=True, orient="records")

    return 0

if __name__ == "__main__":
    # Expose ``main`` as a command-line interface through python-fire.
    fire.Fire(component=main)