import datasets
import itertools
import random
import click
import torch
import pickle
from tqdm import tqdm
import openai
from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed
import datasets
import mobypy
from itertools import chain

# Module-level handles populated by set_up_model() (via `global`).
# call() reads `tokenizer` from here; main() passes the global `model` to call().
model = None
tokenizer = None

def set_up_model(model_size):
    """Load the causal LM and tokenizer for ``model_size`` into the module globals.

    Sizes: "tiny" (gpt-neo-125M), "small" (gpt2-large), "medium" (gpt-neo-1.3B),
    "xmedium" (gpt-neo-2.7B), "large" (gpt-j-6B). The model is moved to CUDA.

    Exits the process with a non-zero status and the message
    "unknown model size" (on stderr) if ``model_size`` is not recognised.
    """
    global model, tokenizer
    model_names = {
        "tiny": "EleutherAI/gpt-neo-125M",
        "small": "gpt2-large",
        "medium": "EleutherAI/gpt-neo-1.3B",
        "xmedium": "EleutherAI/gpt-neo-2.7B",
        "large": "EleutherAI/gpt-j-6B",
    }
    if model_size not in model_names:
        import sys
        # sys.exit(msg) writes msg to stderr and exits with status 1; the old
        # bare sys.exit() reported success (status 0) for a configuration error.
        sys.exit("unknown model size")
    model_name = model_names[model_size]
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        max_length=256).cuda()
    tokenizer = AutoTokenizer.from_pretrained(
        model_name,
        use_fast=False)

def call(prompt, model, temperature=1.0, num_to_generate=1):
    """Sample ``num_to_generate`` continuations of ``prompt`` from ``model``.

    Uses the module-level ``tokenizer`` (no special tokens added) and moves the
    prompt ids to CUDA. Sampling uses top_p=1.0, the given ``temperature`` and
    at most 64 new tokens. Each returned string is only the newly generated
    text (prompt tokens sliced off), truncated before the first "Label:" and
    stripped.
    """
    encoded = tokenizer(prompt, add_special_tokens=False, return_tensors="pt")
    prompt_ids = encoded.input_ids.cuda()
    sampled_ids = model.generate(
        prompt_ids,
        do_sample=True,
        num_return_sequences=num_to_generate,
        max_new_tokens=64,
        top_p=1.0,
        temperature=temperature,
    )
    # Decoder-only models echo the prompt; keep only the new tokens.
    continuation_ids = sampled_ids[:, prompt_ids.shape[1]:]
    texts = tokenizer.batch_decode(continuation_ids, skip_special_tokens=True)
    return [text.split("Label:")[0].strip() for text in texts]

def make_tacred_example(sample, token_to_ner_mapping):
    """Render a marker-token TACRED sentence as readable prompt text.

    "<SUBJ_START> x <SUBJ_END>" becomes "[Subject: x]" and likewise for the
    object; every other token in ``token_to_ner_mapping`` is replaced by the
    part of its tag after "=", lower-cased with "_" turned into spaces.
    Returns ``{"text": rendered}`` (shape expected by ``Dataset.map``).

    Raises IndexError if a marker tag is missing or a non-marker tag has no "=".
    """
    def marker(tag):
        return [tok for tok, ner in token_to_ner_mapping.items() if ner == tag][0]

    # All four markers are resolved before any replacement happens.
    bracket_rules = (
        (marker("SUBJ_START") + " ", "[Subject: "),
        (" " + marker("SUBJ_END"), "]"),
        (marker("OBJ_START") + " ", "[Object: "),
        (" " + marker("OBJ_END"), "]"),
    )
    text = sample["text"]
    for old, new in bracket_rules:
        text = text.replace(old, new)
    for tok, ner in token_to_ner_mapping.items():
        if "SUBJ_" in ner or "OBJ_" in ner:
            continue
        text = text.replace(tok, ner.split("=")[1].lower().replace("_", " "))
    return {"text": text}

def slice_assign(string, span, replacement):
    """Return ``string`` with the half-open slice ``span`` = (start, end) replaced by ``replacement``."""
    lo, hi = span
    return "".join((string[:lo], replacement, string[hi:]))

def postprocess_tacred(generation, token_to_ner_mapping):
    """Map a generated "[Subject: <type>] ... [Object: <type>]" sentence back to marker tokens.

    Inverse of ``make_tacred_example``: each bracketed span becomes
    ``"<START> <type-token> <END>"``. The marker tokens are the ones tagged
    SUBJ_START/SUBJ_END/OBJ_START/OBJ_END. The type token is the token whose
    tag (text after "=") matches the generated type name, compared
    case-insensitively with "_" and runs of whitespace treated as one space.

    Returns None if the generation does not contain exactly one "[Subject:",
    one "[Object:" and two "]", or if a bracket names an unknown type.

    Raises IndexError if a marker tag is missing from ``token_to_ner_mapping``
    or a non-marker tag contains no "=".
    """
    def marker_token(tag):
        return [token for token, ner in token_to_ner_mapping.items() if ner == tag][0]

    def normalize(type_name):
        # make_tacred_example renders "STATE_OR_PROVINCE" as "state or province",
        # so compare in that spaced form. The previous code turned spaces into "_"
        # on the generated side only, so multi-word types never matched.
        return " ".join(type_name.replace("_", " ").split()).lower()

    def type_table(role):
        # Normalized type name -> token. The prompt renders e.g. "person" the same
        # for both roles, so prefer a tag whose prefix (text before "=") equals
        # `role` and fall back to any tag with that name (last one wins, as before).
        # NOTE(review): assumes tags look like "SUBJ=PERSON"/"OBJ=PERSON" -- confirm
        # against ner_tag_map.pkl; other prefixes just use the fallback.
        any_role, same_role = {}, {}
        for token, ner in token_to_ner_mapping.items():
            if "SUBJ_" in ner or "OBJ_" in ner:
                continue  # span markers, not entity types
            parts = ner.split("=")
            type_name = normalize(parts[1])
            any_role[type_name] = token
            if parts[0] == role:
                same_role[type_name] = token
        return {**any_role, **same_role}

    # Resolve every lookup up front so malformed mappings fail regardless of input.
    spans = (
        ("[Subject:", type_table("SUBJ"), marker_token("SUBJ_START"), marker_token("SUBJ_END")),
        ("[Object:", type_table("OBJ"), marker_token("OBJ_START"), marker_token("OBJ_END")),
    )
    if (generation.count("[Subject:") != 1
            or generation.count("[Object:") != 1
            or generation.count("]") != 2):
        return None
    for opener, table, start_token, end_token in spans:
        start = generation.index(opener)
        end = generation.find("]", start)
        if end == -1:
            # no closing bracket anywhere after this span's opener
            return None
        type_name = normalize(generation[start + len(opener):end])
        if type_name not in table:
            # didn't generate a real NER type
            return None
        replacement = f"{start_token} {table[type_name]} {end_token}"
        generation = generation[:start] + replacement + generation[end + 1:]
    return generation

def check(generation, dataset):
    """Accept ``generation`` only if it is a single line (contains no newline).

    ``dataset`` is part of the call signature but is not consulted.
    """
    is_single_line = generation.find("\n") == -1
    return is_single_line

def _flatten_label_names(names):
    """Yield each class name in ``names``, expanding list-valued entries.

    For tacred, label 0 maps to a list of relation names (see
    _load_label_config); every other entry is a single string.
    """
    for name in names:
        if isinstance(name, list):
            yield from name
        else:
            yield name


def _load_label_config(dataset, true_split_number):
    """Return ``(train_dataset_path, label_to_name_mapping, token_to_ner_mapping)``.

    ``token_to_ner_mapping`` is only loaded for tacred and is None otherwise.

    Raises:
        Exception: "Dataset not found!" for an unknown ``dataset``.
    """
    token_to_ner_mapping = None
    if dataset == "emotion":
        label_to_name_mapping = {0: "sadness", 1: "joy", 3: "anger", 4: "fear"}
    elif dataset == "agnews":
        label_to_name_mapping = {0: "world", 1: "sports", 2: "business", 3: "sci/tech"}
    elif dataset == "trec10":
        label_to_name_mapping = {
            0: "description", 1: "entity", 3: "human", 4: "number", 5: "location"}
    elif dataset == "tacred":
        split_dir = f"data/tacred/{true_split_number}"
        # NOTE(review): pickle.load can execute arbitrary code; only load trusted local files.
        with open(f"{split_dir}/labels_map.pkl", "rb") as f:
            name_to_label_mapping = pickle.load(f)

        def normalize(name):
            return name.replace("org:", "").replace("per:", "").replace("_", " ")

        # Label 0 collects several relation names; every other label has one name.
        label_to_name_mapping = {0: []}
        for name, label in name_to_label_mapping.items():
            if label == 0:
                label_to_name_mapping[0].append(normalize(name))
            else:
                label_to_name_mapping[label] = normalize(name)
        with open(f"{split_dir}/ner_tag_map.pkl", "rb") as f:
            ner_to_token_mapping = pickle.load(f)
        token_to_ner_mapping = {tok: ner for ner, tok in ner_to_token_mapping.items()}
    else:
        raise Exception("Dataset not found!")
    train_dataset_path = f"data/{dataset}/{true_split_number}/train/"  # in-context examples
    return train_dataset_path, label_to_name_mapping, token_to_ner_mapping


def _query_ood_labels(dataset, train_labels, label_to_name_mapping, num_queries):
    """Ask GPT-3 (text-davinci-002) ``num_queries`` times for candidate class names.

    Each query seeds a bracketed list with the training class names and lets
    the model continue it. The returned candidates are:
      * lower-cased and stripped;
      * never empty, multi-line, or equal to a known class name.
    The result is returned sorted: iterating a set of str depends on
    PYTHONHASHSEED, which would make ``random.choice`` unreproducible under a
    fixed seed.
    """
    prompt_heads = {
        "emotion": "Generate a diverse list of emotions:\n[",
        "agnews": "Generate a diverse list of news genres:\n[",
        "trec10": "Generate a diverse list of entity types:\n[",
        "tacred": "Generate a diverse list of relations between entities:\n[",
    }
    prompt = prompt_heads.get(dataset, "Generate a complete list of labels for a dataset:\n[")
    for label in train_labels:
        name = label_to_name_mapping[label]
        # tacred label 0 is a list of names; concatenating it with a str raised TypeError.
        prompt += (name if isinstance(name, str) else ", ".join(name)) + ", "
    known_names = set(_flatten_label_names(label_to_name_mapping.values()))

    ood_labels = set()
    for _ in range(num_queries):
        completion = openai.Completion.create(
            engine="text-davinci-002",
            prompt=prompt,
            temperature=0.9,
            max_tokens=64,
            stop=["Label:"])
        generation = completion.choices[0].text.strip().replace("]", "")
        for candidate in generation.split(","):
            stripped = candidate.strip()
            if stripped != "" and "\n" not in stripped and stripped.lower() not in known_names:
                ood_labels.add(stripped.lower())
    return sorted(ood_labels)


def _drop_moby_synonyms(ood_labels, label_to_name_mapping, true_split_number):
    """Remove candidates that Moby lists as synonyms of an in-distribution class name.

    The held-out class (``true_split_number``) is not used as a synonym source.
    Prints how many candidates were removed.
    """
    id_names = _flatten_label_names(
        name for label, name in label_to_name_mapping.items() if label != true_split_number)
    synonyms = set(chain.from_iterable(mobypy.synonyms(name) for name in id_names))
    kept = [label for label in ood_labels if label not in synonyms]
    print("Removed", len(ood_labels) - len(kept), "labels!")
    return kept


@click.command()
@click.option("--dataset", type=str)
@click.option("--model_size", type=str, default="large")
@click.option("--split_number_to_generate", type=int, default=None)
@click.option("--true_split_number", type=int)
@click.option("--num_generations", type=int)
@click.option("--generation_batch_size", type=int)
@click.option("--output_dir_name", type=str, default=None)
@click.option("--match_labels_path", type=str, default=None)
@click.option("--seed", type=int, default=1)
@click.option("--ood_class_num_queries", type=int, default=5)
@click.option("--max_num_context", type=int, default=None)
@click.option("--do_moby", type=bool, default=False)
def main(dataset, split_number_to_generate, true_split_number, num_generations, output_dir_name, generation_batch_size, max_num_context, do_moby, model_size, ood_class_num_queries, seed, match_labels_path):
    """Generate labelled synthetic examples by few-shot prompting and save them.

    The label for each prompt comes from the first applicable source:
      * --split_number_to_generate: that (oracle) class;
      * --match_labels_path: the label set of an existing saved dataset;
      * otherwise: class names proposed by GPT-3, optionally filtered with
        Moby synonyms (--do_moby).
    Output is saved as a datasets.Dataset plus a text file under
    {dataset}/{true_split_number}/{output_dir_name} (default "generations").
    The OpenAI key is read from the OPENAI_API_KEY environment variable.
    """
    import os

    train_dataset_path, label_to_name_mapping, token_to_ner_mapping = _load_label_config(
        dataset, true_split_number)
    if split_number_to_generate is not None:
        split_to_generate = label_to_name_mapping[split_number_to_generate]
    if true_split_number not in label_to_name_mapping:
        # Fail fast: the held-out split must be a known class of this dataset.
        raise KeyError(true_split_number)
    context_per_class = 1  # Total examples = context_per_class * num_train_classes

    # Setup
    prompt_base = "Given a label, generate a corresponding example:\n"
    openai.api_key = os.environ.get("OPENAI_API_KEY", "YOUR KEY HERE")

    set_seed(seed)
    random.seed(seed)
    train_dataset = datasets.load_from_disk(train_dataset_path)
    train_labels = list(set(train_dataset["label"]))
    if dataset == "tacred":
        # Hack to combine the text + s/o together
        train_dataset = train_dataset.map(lambda x: make_tacred_example(x, token_to_ner_mapping))
    # Group example texts by label to make sampling in-context examples easy.
    sorted_train_dataset = sorted(list(train_dataset), key=lambda ex: ex["label"])
    grouped_train_examples = {
        label: [ex["text"] for ex in examples]
        for label, examples in itertools.groupby(sorted_train_dataset, key=lambda ex: ex["label"])
    }

    # 1. Decide which labels to generate for.
    if split_number_to_generate is None and match_labels_path is None:
        print("Lowercasing all classes!!!")
        ood_labels = _query_ood_labels(
            dataset, train_labels, label_to_name_mapping, ood_class_num_queries)
        print("*** Generating from the following labels! ***")
        print(ood_labels)
        if do_moby:
            ood_labels = _drop_moby_synonyms(ood_labels, label_to_name_mapping, true_split_number)
    elif match_labels_path is not None:
        print("Matching labels of", match_labels_path)
        reference_d = datasets.load_from_disk(match_labels_path)
        # Sorted for seed-reproducible random.choice (see _query_ood_labels).
        ood_labels = sorted(set(reference_d["label"]))
    else:
        print("*** Generating the ORACLE class: ***")
        print(split_to_generate)

    # 2. Generate examples for these labels with the local model.
    all_generations = []
    generations_label_reference = []
    set_up_model(model_size)
    pbar = tqdm(total=num_generations)
    while len(all_generations) < num_generations:
        step = len(all_generations)
        pbar.n = step
        pbar.refresh()
        prompt = prompt_base
        # Add context: one random example per training class (capped by max_num_context).
        for _ in range(context_per_class):
            for i, label in enumerate(train_labels):
                if max_num_context is not None and i >= max_num_context:
                    break
                example = random.choice(grouped_train_examples[label])
                # NOTE(review): for tacred label 0 this name is a list and renders as its repr.
                label_name = label_to_name_mapping[label]
                prompt += f"{label_name}\n{example}\n"
        # Add the label to generate for.
        if split_number_to_generate is not None:
            if isinstance(split_to_generate, list):
                ood_label = random.choice(split_to_generate)
            else:
                ood_label = split_to_generate
        else:
            ood_label = random.choice(ood_labels)
        prompt += f"{ood_label}\n"
        if step == 0:
            print("*** Example Prompt: ***")
            print(prompt)
        # `model` is the module global populated by set_up_model().
        completions = call(prompt, model, num_to_generate=generation_batch_size)
        for completion in completions:
            generation = completion.split("\n")[0].strip().strip("_").strip()
            if dataset == "tacred":
                generation = postprocess_tacred(generation, token_to_ner_mapping)
                if generation is None:
                    continue
            if generation != "" and check(generation, dataset):
                all_generations.append(generation)
                generations_label_reference.append(ood_label)
    pbar.n = len(all_generations)
    pbar.refresh()
    pbar.close()

    print(f"*** Generated {len(all_generations)} examples total! ***")
    print("Some example generations:\n" + "\n".join(all_generations[:10]))
    generated_dataset = datasets.Dataset.from_dict(
        {"text": all_generations, "label": generations_label_reference})
    # Default the directory in every mode; the oracle mode used to write to ".../None".
    if output_dir_name is None:
        output_dir_name = "generations"
    if split_number_to_generate is not None:
        txt_name = "generations.txt"
    else:
        txt_name = f"{output_dir_name}.txt"
    output_dir = f"{dataset}/{true_split_number}/{output_dir_name}"
    generated_dataset.save_to_disk(output_dir)
    with open(f"{output_dir}/{txt_name}", "w", encoding="utf-8") as f:
        f.writelines(line + "\n" for line in all_generations)

# Script entry point: click parses the command-line options and invokes main().
if __name__ == "__main__":
    main()
