import json
import os
import random
from collections import defaultdict

import numpy as np
import tqdm

# Fix the seed of both random sources used in this file (the stdlib `random`
# module and NumPy's global generator).
random.seed(111)
np.random.seed(111)

# Alphabet that random keys are drawn from: the characters with ASCII codes
# 65..122, minus those listed in the exclusion string, plus "-" and "+".
# Earlier, wider alphabet (codes 33..126) kept for reference:
# randomlist = [chr(ii) for ii in range(33,127) if chr(ii) not in "'\",:;\\."]
randomlist = [chr(ii) for ii in range(65,123) if chr(ii) not in "'\",:;\\.[]`"] + ["-", "+"]

# Earlier prompt variant that asked for the index of the matching word, kept for reference:
# TASKS = '''Given a list of words called 'dictionary', return the index of the word in the dictionary that has both the prefix 'pref' and the suffix 'suff'. If there is more than one valid index, return the largest of them. If there is no such word in the dictionary, return -1.\n\nFor example:\ndictionary = %s\n'''
# TURN = '''Given prefix = "%s"\nsuffix = "%s"\n\nWhat is the index of the word in the dictionary that has both the prefix and the suffix?\n'''
# Prompt header shared by every turn of an example; the single %s slot is
# filled with the string form of the generated word list.
TASKS = '''Given a list of words called 'dictionary', return the word in the dictionary that has both the prefix 'pref' and the suffix 'suff'. If there is more than one valid answer, return the latest of them. \n\nFor example:\ndictionary = %s\n'''
# Per-turn question; the two %s slots are filled with the prefix and the suffix.
TURN = '''Given prefix = "%s"\nsuffix = "%s"\n\nWhat is the word in the dictionary that has both the prefix and the suffix?\n'''
# Prefix shared between generated keys. Starts empty and is filled in by
# generate_random_key on its first call with share_prefix_size > 0.
same_prefix = ""

def build_prefix_suffix_set():
    """Create the pools of reusable key prefixes and suffixes.

    Fifty random 15-character keys are generated and de-duplicated to form the
    prefix pool. The suffix pool is currently left empty (nothing is ever
    added to it).

    Returns:
        A ``(prefixes, suffixes)`` tuple of lists.
    """
    pool_size, key_length = 50, 15
    prefixes = {generate_random_key(random_size=key_length) for _ in range(pool_size)}
    # No suffixes are generated at the moment, so this pool stays empty.
    suffixes = set()
    return list(prefixes), list(suffixes)

def generate_random_key(share_prefix_size: int = 0, random_size: int = 0, use_set: bool = False) -> str:
    """Return a random key made of characters from ``randomlist``.

    Args:
        share_prefix_size: Number of leading random characters that are
            replaced by the module-level ``same_prefix``. The first call with
            a positive value (while ``same_prefix`` is still empty) stores its
            own leading characters as ``same_prefix``.
        random_size: Number of random characters to draw. ``0`` means "pick a
            length uniformly from 40..100 (inclusive)".
        use_set: When true, the key is assembled from a prefix and a suffix
            picked from the ``prefix_suffix_set`` pools, with the random
            characters between them. A pool that is empty contributes nothing,
            so an empty suffix pool yields a prefix-only key.

    Returns:
        The generated key.

    Note:
        ``same_prefix`` is module-wide state: once set, it is prepended on
        every later non-``use_set`` call, whatever ``share_prefix_size`` is.
    """
    global same_prefix
    if random_size == 0:
        random_size = random.randint(40, 100)
    idxs = np.random.randint(0, len(randomlist), random_size)
    chars = [randomlist[ii] for ii in idxs]
    if share_prefix_size > 0 and not same_prefix:
        same_prefix = "".join(chars[:share_prefix_size])
    if use_set:
        prefixes, suffixes = prefix_suffix_set
        # Pick from whatever each pool actually holds instead of assuming
        # exactly 50 entries; indexing an empty pool used to raise IndexError.
        prefix = random.choice(prefixes) if prefixes else ""
        suffix = random.choice(suffixes) if suffixes else ""
        # The affixes take the place of the random characters at both ends.
        end = max(len(chars) - len(suffix), 0)
        return prefix + "".join(chars[len(prefix):end]) + suffix
    return same_prefix + "".join(chars[share_prefix_size:])


def build_one_example(multi_turns_ids: list, total: int = 2000, share_prefix_size: int = 0):
    """Build one example: a random word list plus one question per turn.

    Args:
        multi_turns_ids: Indices into the generated word list; each one is the
            word a turn's prefix and suffix are taken from.
        total: Number of 25-character words to generate.
        share_prefix_size: Forwarded to ``generate_random_key``.

    Returns:
        A dict with ``"context"`` (the task prompt containing the word list)
        and ``"multi_turns"`` (a list of ``{"input", "answer"}`` dicts).
    """
    affix_len = 5  # prefix and suffix are both fixed at five characters
    words = [
        generate_random_key(share_prefix_size, random_size=25, use_set=False)
        for _ in range(total)
    ]

    turns = []
    for target_idx in multi_turns_ids:
        target_word = words[target_idx]
        head, tail = target_word[:affix_len], target_word[-affix_len:]
        # Walk backwards from the end of the list down to the target itself so
        # that the first hit is the latest word carrying both affixes.
        match_idx = next(
            (
                pos
                for pos in range(total - 1, target_idx - 1, -1)
                if words[pos].startswith(head) and words[pos].endswith(tail)
            ),
            -1,
        )
        turns.append({"input": TURN % (head, tail), "answer": words[match_idx]})

    return {"context": TASKS % str(words), "multi_turns": turns}


def build_multi_turn_idxs(total: int = 2000, turn_size: int = 5, example_number: int = 100):
    """Choose, for every turn, one target index out of each chunk of the word list.

    The range ``[0, total)`` is cut into ``example_number`` consecutive chunks
    of ``total // example_number`` indices (any remainder at the end is left
    unused). From each chunk, ``turn_size`` distinct indices are drawn, one per
    turn; each turn's list is then shuffled independently.

    Args:
        total: Size of the word list the indices refer to.
        turn_size: Number of turns per example.
        example_number: Number of chunks, i.e. entries per turn.

    Returns:
        A ``defaultdict(list)`` mapping each turn number in
        ``range(turn_size)`` to a list of ``example_number`` indices, all of
        them within ``[0, total)``.

    Raises:
        ValueError: If ``example_number`` is not positive, or a chunk is empty
            or too small to supply ``turn_size`` distinct indices.
    """
    if example_number <= 0:
        raise ValueError("example_number must be positive")
    chunk_size = total // example_number
    if chunk_size < max(turn_size, 1):
        raise ValueError(
            "total // example_number (%d) must be at least 1 and at least turn_size (%d)"
            % (chunk_size, turn_size)
        )

    target = defaultdict(list)
    # Iterate over exactly example_number chunks. Stepping through
    # range(0, total, chunk_size) would add extra chunks, reaching past
    # `total`, whenever total is not a multiple of example_number.
    for chunk in range(example_number):
        st = chunk * chunk_size
        idxs = list(range(st, st + chunk_size))
        np.random.shuffle(idxs)
        for i in range(turn_size):
            target[i].append(idxs[i])
    for i in range(turn_size):
        np.random.shuffle(target[i])
    return target


def build_data(total: int = 2000, share_prefix_size: int = 0, turn_size: int = 5, example_number: int = 100,
               output_path: str = "data/v1_prefix0_turn5_set_trie_tree_multi_turn_kv.jsonl"):
    """Generate ``example_number`` examples and write them as JSON lines.

    Args:
        total: Number of words in each example's list.
        share_prefix_size: Forwarded to ``build_one_example``.
        turn_size: Number of question turns per example.
        example_number: Number of examples to generate.
        output_path: Destination JSONL file. Defaults to the path this script
            has always written to.
    """
    # Create the destination directory before the (slow) generation step so a
    # missing directory does not fail the run only after all the work is done.
    out_dir = os.path.dirname(output_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    multi_turn_idxs = build_multi_turn_idxs(total=total, turn_size=turn_size, example_number=example_number)

    # Example i takes the i-th target index from every turn's list.
    new_data = [
        build_one_example([multi_turn_idxs[j][i] for j in range(turn_size)], total, share_prefix_size)
        for i in tqdm.tqdm(range(example_number))
    ]
    with open(output_path, "w", encoding="utf-8") as f:
        for example in new_data:
            json.dump(example, f)
            f.write('\n')
    
# Built at import time, outside the __main__ guard, because generate_random_key
# reads this module-level name when called with use_set=True. Building it draws
# from the seeded NumPy generator, so it also advances that generator's state
# before build_data runs.
prefix_suffix_set = build_prefix_suffix_set()
if __name__ == "__main__":
    # Generate the dataset with 6000 words per example and 5 turns each;
    # example_number keeps its default of 100.
    build_data(total=6000, share_prefix_size=0, turn_size=5)


