
import json
import os
from copy import deepcopy

import fire
import pandas as pd
import torch.nn as nn
from sentence_transformers import SentenceTransformer
from tqdm import tqdm

# Opening text of every OUTFOX prompt; the retrieved exemplars (rendered with
# BLOCK) are appended after it. The line break after "each" is part of the
# literal and is sent to the LM verbatim.
OUTFOX_HEADER = """Here are the results of detecting whether each 
text sample is generated by a human or a language model (LM).
"""

# One in-context exemplar. Positional fields: (sample text, detector verdict),
# where main() fills the verdict with "Human" or "LM".
BLOCK = """Sample: {}
Answer: {}
"""

# Closing instruction of the prompt. Positional fields: (target word count,
# task prompt produced by formatting a PROMPTS template with the generation).
OUTFOX_END = """Given the following problem statement,
please generate a sample detected as Human in around {} words.

Prompt: {}

Response:
"""

# Per-dataset task instructions: slightly modified versions of the prompts used
# to generate the responses, adapted to fit the OUTFOX framework more closely.
# Each template has a single ``{}`` slot that receives the text being rewritten.
PROMPTS = dict(
    reddit="Write a response to this Reddit comment: {}",
    amazon="Here's an Amazon review: {} \nPlease write another review, but about something different.",
    blogs="Here's a snippet of a Blog post: {}. \nPlease write another snippet, but about something different.",
)

def _infer_dataset_name(path):
    """Return the dataset key ("reddit", "amazon" or "blogs") contained in *path*.

    Keys are checked in the order reddit, amazon, blogs; the first substring
    match wins.

    Raises:
        ValueError: if none of the known dataset names occurs in *path*.
    """
    for name in ("reddit", "amazon", "blogs"):
        if name in path:
            return name
    raise ValueError(
        "Could not infer dataset name (reddit/amazon/blogs) from path: {!r}".format(path)
    )


def _top_k_indices(sims, k, exclude_tol=1e-5):
    """Return positions of the ``k`` largest values of the 1-D tensor ``sims``.

    Values within ``exclude_tol`` of 1.0 are skipped. They are treated as the
    query itself (or an exact duplicate of it), not as a neighbour. A tolerance
    is used because cosine similarity of identical float embeddings is not
    reliably exactly 1.0. The returned positions refer to the ORIGINAL
    ``sims`` tensor, so they can index the list the embeddings came from.
    Fewer than ``k`` positions are returned when not enough candidates remain.
    """
    # Positions of entries that are not (near-)identical to the query.
    candidates = ((sims - 1.0).abs() > exclude_tol).nonzero(as_tuple=True)[0]
    order = sims[candidates].argsort(descending=True)[:k]
    # Map ranks within the filtered subset back to original positions.
    return candidates[order].tolist()


def _label_text(label):
    """Map a detector prediction to the answer text shown in an exemplar.

    Only values equal to ``False`` (e.g. ``False`` or ``0``) become "Human";
    everything else becomes "LM". ``== False`` is deliberate: ``not label``
    would also turn ``None`` or empty strings into "Human".
    """
    return "Human" if label == False else "LM"  # noqa: E712


def _build_prompt(exemplars, generation, dataset_prompt):
    """Assemble one OUTFOX prompt.

    Args:
        exemplars: ``(sample_text, detector_label)`` pairs, rendered in order
            with ``BLOCK`` after ``OUTFOX_HEADER``.
        generation: the text to transform. Its word count sets the target
            length, and it fills the ``{}`` slot of ``dataset_prompt``.
        dataset_prompt: one of the ``PROMPTS`` templates.

    Returns:
        The complete prompt string.
    """
    parts = [OUTFOX_HEADER]
    parts.extend(
        "\n" + BLOCK.format(sample, _label_text(label)) for sample, label in exemplars
    )
    # Plain ``split()`` splits on any whitespace run: newlines count as word
    # breaks and repeated spaces do not produce empty "words".
    num_words = len(generation.split())
    parts.append("\n" + OUTFOX_END.format(num_words, dataset_prompt.format(generation)))
    return "".join(parts)


def main(
    retrieval_dataset: str = "/usr/WS1/rrivera/data/style_transfer/outfox/MTD_reddit_12000_correct.jsonl.ready",
    test_dataset: str = "/usr/WS1/rrivera/data/style_transfer/neurips/data/shards/MTD_reddit_12000_correct_Mistral-7B-Instruct-v0.3_N=5.jsonlshard-1-4",
    k: int = 8,  # Number of human / machine exemplars, we're matching our method here
    output_dir: str = "/usr/workspace/rrivera/data/style_transfer/outfox/prompts",
):
    """Build OUTFOX-style prompts for every generation in ``test_dataset``.

    For each generation, the prompt contains:

    * up to ``k`` human texts and up to ``k`` machine texts from
      ``retrieval_dataset`` that are most similar to it (SBERT cosine
      similarity, excluding near-identical matches), each with its detector
      label;
    * an instruction to write a sample detected as "Human" of about the same
      word count.

    The prompts are written as one JSON object per line to
    ``<output_dir>/<basename(test_dataset)>_prompts.jsonl``.

    Args:
        retrieval_dataset: JSONL file with ``content_text``,
            ``content_text_label``, ``respond_reddit`` and
            ``respond_reddit_label`` columns. Its path must contain "reddit",
            "amazon" or "blogs", which selects the task prompt.
        test_dataset: JSONL file whose ``respond_reddit`` column holds the
            generations to transform.
        k: number of human and of machine exemplars per prompt.
        output_dir: directory for the prompts file; created if missing.

    Returns:
        0 on success.

    Raises:
        ValueError: if the dataset name cannot be inferred from
            ``retrieval_dataset``.
    """
    dataset_prompt = PROMPTS[_infer_dataset_name(retrieval_dataset)]

    retrieval_df = pd.read_json(retrieval_dataset, lines=True)
    test_df = pd.read_json(test_dataset, lines=True)

    sbert_name = "sentence-transformers/all-mpnet-base-v2"
    model = SentenceTransformer(sbert_name)
    model.cuda()
    model.eval()

    human = retrieval_df["content_text"].tolist()
    human_pred_labels = retrieval_df["content_text_label"].tolist()
    # NOTE(review): the machine column is "respond_reddit" for every dataset,
    # amazon/blogs included -- confirm those files use the same column name.
    machine = retrieval_df["respond_reddit"].tolist()
    machine_pred_labels = retrieval_df["respond_reddit_label"].tolist()

    # Generations to transform with OUTFOX:
    generations = test_df["respond_reddit"].tolist()

    # ``normalize_embeddings`` is the keyword SentenceTransformer.encode
    # understands; ``normalize`` is not one of its parameters.
    encode_kwargs = dict(
        show_progress_bar=True, normalize_embeddings=True, convert_to_tensor=True
    )
    human_emb = model.encode(human, **encode_kwargs)
    machine_emb = model.encode(machine, **encode_kwargs)
    generations_emb = model.encode(generations, **encode_kwargs)

    cossim = nn.CosineSimilarity(dim=-1)
    prompts = []
    for gen, g_emb in zip(tqdm(generations), generations_emb):
        # ``expand_as`` broadcasts the query row against every candidate row
        # without materialising copies (``repeat`` would allocate them).
        sims_human = cossim(human_emb, g_emb.expand_as(human_emb))
        sims_machine = cossim(machine_emb, g_emb.expand_as(machine_emb))

        # Indices refer to the full human/machine lists, so samples and labels
        # stay aligned even when self-matches are excluded.
        exemplars = [
            (human[j], human_pred_labels[j]) for j in _top_k_indices(sims_human, k)
        ]
        exemplars += [
            (machine[j], machine_pred_labels[j])
            for j in _top_k_indices(sims_machine, k)
        ]

        prompts.append({"prompt": _build_prompt(exemplars, gen, dataset_prompt)})

    os.makedirs(output_dir, exist_ok=True)
    out_path = os.path.join(
        output_dir, "{}_prompts.jsonl".format(os.path.basename(test_dataset))
    )
    with open(out_path, "w", encoding="utf-8") as fout:
        for p in prompts:
            fout.write(json.dumps(p))
            fout.write("\n")

    return 0

if __name__ == "__main__":
    # python-fire exposes main()'s keyword arguments as command-line flags,
    # e.g. ``--test_dataset=... --k=8``.
    fire.Fire(component=main)