import random
import json
import os
# Seed the module-level RNG so the sampling done by this script is reproducible.
random.seed(20230906)

# Placeholder marking every substitution slot inside a template.
content_special_token = '|<content>|'

# Question line used by the templates: 127 characters, trailing newline included.
_QUESTION = (
    "Concatenate the last letters of the given words: "
    "|<content>|, |<content>|, |<content>|, |<content>|. "
    "Let's think step by step.\n"
)

template_list = [
    # Template 0: one numbered sentence per word, then the answer.
    # NOTE(review): "asnwer" looks like a typo for "answer"; it is kept
    # verbatim because the text is written into the generated datasets.
    _QUESTION
    + " 1. The last letter of |<content>| is |<content>|.\n"
    + " 2. The last letter of |<content>| is |<content>|.\n"
    + " 3. The last letter of |<content>| is |<content>|.\n"
    + " 4. The last letter of |<content>| is |<content>|.\n"
    + " 5. Concatenating these letters together, we get the asnwer: |<content>|."
    + " Therefore, the answer is |<content>|.",

    # Template 1: "Word / last letter" pairs, then an explicit sum of letters.
    _QUESTION
    + " 1. Word: |<content>|, last letter: |<content>|.\n"
    + " 2. Word: |<content>|, last letter: |<content>|.\n"
    + " 3. Word: |<content>|, last letter: |<content>|.\n"
    + " 4. Word: |<content>|, last letter: |<content>|.\n"
    + " Now, let us concatenate the last letters of each word:"
    + " |<content>| + |<content>| + |<content>| + |<content>| = |<content>|."
    + " Therefore, the concatenated result is |<content>|.",

    # Template 2: extraction and concatenation compressed into two steps.
    # NOTE(review): unlike the other two templates this one starts with a
    # space, so its first line is 128 characters long, not 127.
    " "
    + _QUESTION
    + " 1. Extract the last letters of each word:"
    + " |<content>|: |<content>|, |<content>|: |<content>|,"
    + " |<content>|: |<content>|, |<content>|: |<content>|.\n"
    + " 2. Concatenate the extracted letters:"
    + " |<content>| + |<content>| + |<content>| + |<content>| = |<content>|."
    + " Therefore, the answer is |<content>|.",
]

# Slot numbering used to fill the placeholders of a template, in order:
#   0-3: the four words; 4-7: their last letters; 8: the concatenated answer.
_WORDS = (0, 1, 2, 3)
_WORD_LETTER_PAIRS = (0, 4, 1, 5, 2, 6, 3, 7)
_LETTERS = (4, 5, 6, 7)
_ANSWER_TWICE = (8, 8)

# One tuple per template; entry k names the slot that fills the k-th
# placeholder of the corresponding template.
content_idx_list = [
    _WORDS + _WORD_LETTER_PAIRS + _ANSWER_TWICE,
    _WORDS + _WORD_LETTER_PAIRS + _LETTERS + _ANSWER_TWICE,
    _WORDS + _WORD_LETTER_PAIRS + _LETTERS + _ANSWER_TWICE,
]

# Stored verbatim under each dataset's 'prompt_len' key. 127 matches the
# character length of the question line (through its newline) of templates
# 0 and 1. NOTE(review): presumably a character count of the prompt part;
# confirm against the consumer, and check the value for template 2.
prompt_len = [127] * 3

# Vocabulary file, one word per line; the path is relative to the working directory.
vocab_file = 'vocab.txt'

def _load_vocab(path):
    """Return the stripped, non-empty lines of the text file at *path*."""
    with open(path, encoding='utf-8') as f:
        stripped = (line.strip() for line in f)
        # Drop blank lines: an empty "word" has no last letter and would make
        # ``word[-1]`` raise IndexError once it gets sampled.
        return [word for word in stripped if word]


def _sample_content(vocab_list, idx):
    """Draw four words and return the placeholder values selected by *idx*.

    Slot numbering: 0-3 are the words, 4-7 their last letters and 8 the
    concatenation of those letters.
    """
    # Four entries at distinct positions of vocab_list.
    words = random.sample(vocab_list, 4)
    letters = [word[-1] for word in words]
    slots = words + letters + [''.join(letters)]
    return [slots[slot] for slot in idx]


def main(save_file, n_sample):
    """Generate one JSON dataset per template.

    For every ``(template, idx)`` pair of ``template_list`` and
    ``content_idx_list`` a file ``dataset_<i>.json`` is written containing the
    template, its ``prompt_len`` entry and ``n_sample`` content lists.

    Args:
        save_file: Despite the name, the *directory* that receives the
            ``dataset_<i>.json`` files. It is created if it does not exist.
        n_sample: Number of content lists to generate per template.

    Raises:
        ValueError: Propagated from ``random.sample`` when the vocabulary
            holds fewer than four usable words and a sample is requested.
        OSError: If the vocabulary file cannot be read or an output file
            cannot be written.
    """
    vocab_list = _load_vocab(vocab_file)

    # An empty path means "current directory"; os.makedirs('') would fail.
    if save_file:
        os.makedirs(save_file, exist_ok=True)

    for i, (template, idx) in enumerate(zip(template_list, content_idx_list)):
        dataset = {
            'template': template,
            'prompt_len': prompt_len[i],
            # One random.sample call per entry, in order, so the output for a
            # given seed and vocabulary stays reproducible.
            'content_list': [
                _sample_content(vocab_list, idx) for _ in range(n_sample)
            ],
        }
        out_path = os.path.join(save_file, f'dataset_{i}.json')
        with open(out_path, 'w', encoding='utf-8') as f:
            json.dump(dataset, f, indent=4)

if __name__ == "__main__":
    # Script entry point: 1000 samples per template, written to the
    # current working directory.
    main(save_file=".", n_sample=1000)