import random
import json
import os

# Placeholder marker that appears inside the template strings below.
# (The constant itself is not referenced anywhere else in the visible code.)
content_special_token = '|<content>|'
# Prompt templates. Every '|<content>|' occurrence is one slot; the values for
# the slots are chosen through content_idx_list.
template_list = [
    "Solve the arithmetic problem. Question: There were |<content>| |<content>| in the vase. |<content>| cut some more |<content>| from her flower garden. There are now |<content>| |<content>| in the vase. How many |<content>| did she cut ?\n A: Let's think step by step.\n 1. Start with the number of |<content>| in the vase before |<content>| cut any: |<content>|.\n 2. We wonder how many |<content>| |<content>| cut from her garden. Let's call this number \"x\".\n 3. Add the number of |<content>| |<content>| cut to the original number in the vase: |<content>| + x.\n 4. This equals the total number of |<content>| in the vase after |<content>| cut some: |<content>|.\n 5. So we can set up an equation: |<content>| + x = |<content>|.\n 6. To solve for x, we need to isolate it on one side of the equation. We can do this by subtracting |<content>| from both sides: x = |<content>| - |<content>| = |<content>|.\n 7. Therefore, |<content>| cut |<content>| |<content>| from her garden.Therefore, the answer (arabic number) is |<content>|." 
]

# Reference fill-in values, one row per template. For the template above the
# row reads: [initial count, flower, name, final count, answer].
content_list = [['6', 'roses', 'Mary', '16', '10']]

# One index list per template: the k-th entry selects which element of a
# content row is used for the k-th slot of the template (28 slots here).
content_idx_list = [[0,1,2,1,3,1,1,1,2,0,1,2,1,2,0,1,2,3,0,3,0,3,0,4,2,4,1,4]]

# Sampled replacement values, one entry per template. Each entry is a list of
# five columns in the same order as a content_list row:
# [initial count, flower, name, final count, answer]; main() transposes the
# columns into one row per sample.
content_dist_list = []
# sample 0
# The story ("cut some more ... there are now ...") only makes sense when the
# final count is larger than the initial one, so each final count is drawn
# strictly above its initial count. Both values stay within 1..30.
num_a = [random.randint(1, 29) for _ in range(9)]
num_b = [random.randint(a + 1, 30) for a in num_a]
content_dist = [
    list(map(str, num_a)),
    ['tulips', 'sunflowers', 'daisies', 'lilies', 'daffodils', 'pansies', 'petunias', 'marigolds', 'irises'],
    ['Emma', 'Olivia', 'Charlotte', 'Ava', 'Mia', 'Evelyn', 'Sophia', 'Amelia', 'Nora'],
    list(map(str, num_b)),
    # Answer = final count - initial count, matching the reference values
    # in content_list (16 - 6 = 10) and step 6 of the template.
    [str(b - a) for a, b in zip(num_a, num_b)],
]
# ...
content_dist_list.append(content_dist)

# Per-template prompt length; main() stores it under 'prompt_len' in each
# dataset file.
# NOTE(review): the unit (tokens vs. characters) is not stated here — confirm.
prompt_len_list = [266]
# prompt_tc_list = [[1,1,1,1,1,0,1,0,0,1,1,1,1,1,0,1,1,1,0,1,1,0,0,1,1,1,0,1,1,1,1,0,1,1,1,1,1]]

def main(save_file):
    """Write one ``dataset_<k>.json`` file per template into ``save_file``.

    Each file contains the template string, its prompt length and a
    ``content_list``. The first row of ``content_list`` is built from the
    reference values in ``content_list`` (module level); the following rows
    are built from the sampled values in ``content_dist_list``. Every row is
    expanded to one value per template slot using the template's index list.

    Args:
        save_file: Directory that receives the JSON files. Despite the name
            it is a directory path: the file name is joined onto it. The
            directory is created when it does not exist yet.
    """
    # An empty string means "current directory" for os.path.join, but
    # os.makedirs('') raises, so only create non-empty paths.
    if save_file:
        os.makedirs(save_file, exist_ok=True)

    per_template = zip(template_list, content_list, content_dist_list, content_idx_list)
    for idx, (template, content, content_dist, content_idx) in enumerate(per_template):
        # Echo the index list and the reference values to stdout.
        print(content_idx)
        print(content)

        # First row: the reference values, one per slot.
        rows = [[content[slot] for slot in content_idx]]
        # content_dist is stored column-wise; zip(*...) turns it into one
        # tuple per sample, which is then expanded per slot the same way.
        for sample in zip(*content_dist):
            rows.append([sample[slot] for slot in content_idx])

        dataset = {
            'template': template,
            'prompt_len': prompt_len_list[idx],
            'content_list': rows,
        }

        out_path = os.path.join(save_file, f'dataset_{idx+1}.json')
        with open(out_path, 'w', encoding='utf-8') as f:
            json.dump(dataset, f, indent=4)

# Script entry point: write the dataset files into the current directory.
if "__main__" == __name__:
    main(save_file='.')
