import random
import json
import os

# Placeholder text marking each fill-in slot inside a template string.
content_special_token = '|<content>|'

# One question-plus-worked-solution template per dataset.  Each template is a
# single string: the question part (ending with "Let's think step by step.\n")
# followed by the step-by-step answer.  Every |<content>| occurrence is
# presumably filled, in order, from a row built with the matching
# content_idx_list entry (28 slots here, 28 indices there).
# NOTE(review): "she gave" after a full stop is lowercase in the literal; left
# unchanged so the generated prompts stay identical to earlier runs.
template_list = [
    (
        # question part
        "Solve the arithmetic problem. "
        "Question: |<content>| found |<content>| |<content>| on the beach. "
        "she gave |<content>| some of her |<content>|. "
        "She has |<content>| |<content>| left. "
        "How many |<content>| did she give to |<content>|?\n"
        " Let's think step by step.\n"
        # worked solution
        " 1. We know that |<content>| found |<content>| |<content>| on the beach.\n"
        " 2. She gave |<content>| some of her |<content>|.\n"
        " 3. She has |<content>| |<content>| left.\n"
        " We need to find out how many |<content>| |<content>| gave to |<content>|. "
        "To do this, we'll subtract the number of |<content>| she has left from the initial number of |<content>| she found:\n"
        " |<content>| - |<content>| = |<content>| \n"
        "|<content>| gave |<content>| |<content>| to |<content>|."
    ),
]

# Reference example for each template.  Slot order (as used by
# content_idx_list): 0 giver's name, 1 initial count, 2 item, 3 recipient,
# 4 count left, 5 count given away.
content_list = [
    ['John', '70', 'seashells', 'Sam', '27', '43'],
]

content_dist_list = []
# sample 0
# Distractor pool for template 0, stored column-major: one list per slot
# category (same order as content_list[0]), each holding 9 options.  Taking
# position k from every list yields one complete, self-consistent example.
num_a = [random.randint(30, 100) for _ in range(9)]  # initial counts
# Counts left over.  The upper bound is a - 1 (not a) so the amount given
# away, a - b, is always at least 1: the template says she "gave ... some".
num_b = [random.randint(1, a - 1) for a in num_a]
content_dist = [
    ['Emma', 'Olivia', 'Charlotte', 'Ava', 'Mia', 'Evelyn', 'Sophia', 'Amelia', 'Nora'],
    [str(a) for a in num_a],
    ['starfish', 'clams', 'crabs', 'kites', 'coconuts', 'coins', 'toys', 'suits', 'buckets'],
    ['James', 'William', 'Lucas', 'Noah', 'Liam', 'Henry', 'Oliver', 'Elijah', 'Benjamin'],
    [str(b) for b in num_b],
    [str(a - b) for a, b in zip(num_a, num_b)],
]
# ...
content_dist_list.append(content_dist)

# For each template: one slot-category index (into a content row) per
# |<content>| occurrence, listed in the order the placeholders appear.
content_idx_list = [
    [
        0, 1, 2, 3, 2, 4, 2, 2, 3,  # question sentence(s)
        0, 1, 2,                    # step 1
        3, 2,                       # step 2
        4, 2,                       # step 3
        2, 0, 3, 2, 2,              # "We need to find out ..." sentence
        1, 4, 5,                    # subtraction
        0, 5, 2, 3,                 # final answer sentence
    ],
]
# Per template: 257 equals the character length of template 0's question part
# (up to and including "Let's think step by step.\n"); presumably the
# prompt/answer split point -- confirm against the consumer of the JSON.
prompt_len_list = [257]
# NOTE(review): unused in this file; kept for reference.
# prompt_tc_list = [[1,1,1,1,1,0,1,0,0,1,1,1,1,1,0,1,1,1,0,1,1,0,0,1,1,1,0,1,1,1,1,0,1,1,1,1,1]]

def main(save_file, templates=None, contents=None, content_dists=None,
         content_idxs=None, prompt_lens=None):
    """Write one ``dataset_<i>.json`` file per template into ``save_file``.

    Despite its name, ``save_file`` is the output *directory*; it is created
    if it does not exist (an empty string means the current directory).

    Each JSON file is a dict with:

    * ``template``     -- the template string, unchanged;
    * ``prompt_len``   -- the matching ``prompt_lens`` entry, unchanged;
    * ``content_list`` -- a list of rows.  The first row comes from the
      reference example in ``contents``.  It is followed by one row per
      example in the matching ``content_dists`` entry.  Each row is built
      by looking up the values at the positions given by ``content_idx``.

    The optional arguments default to the module-level ``template_list``,
    ``prompt_len_list``, ``content_list``, ``content_dist_list`` and
    ``content_idx_list``.  All five are consumed in lockstep with ``zip``,
    so trailing entries of a longer list are ignored.
    """
    if templates is None:
        templates = template_list
    if prompt_lens is None:
        prompt_lens = prompt_len_list
    if contents is None:
        contents = content_list
    if content_dists is None:
        content_dists = content_dist_list
    if content_idxs is None:
        content_idxs = content_idx_list

    if save_file:
        os.makedirs(save_file, exist_ok=True)

    records = zip(templates, prompt_lens, contents, content_dists, content_idxs)
    for i, (template, prompt_len, content, content_dist, content_idx) in enumerate(records):
        # content_dist is column-major (one list per slot category); transposing
        # it yields one tuple per complete sampled example.
        samples = zip(*content_dist)
        rows = [[content[j] for j in content_idx]]
        rows.extend([sample[j] for j in content_idx] for sample in samples)

        dataset = {
            'template': template,
            'prompt_len': prompt_len,
            'content_list': rows,
        }
        out_path = os.path.join(save_file, f'dataset_{i}.json')
        with open(out_path, 'w', encoding='utf-8') as f:
            json.dump(dataset, f, indent=4)

if __name__ == '__main__':
    # Script entry point: write the generated dataset files into the current
    # working directory.
    main(save_file='.')