from transformers import BertTokenizer
import os
import csv
from tqdm import tqdm
import json

if __name__ == '__main__':
    max_length = 60
    data_dir = 'data'
    saving_data_dir = 'data/multiple_attribute'
    os.makedirs(saving_data_dir, exist_ok=True)
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    data_suffixes = ['valid', 'train', 'test']
    index = 8
    balanced = 'balanced'
    for name in tqdm(data_suffixes):
        file_base = '{}.fader.with_cat.proc.{}0000'
        filename = file_base.format(name, index)
        with open(saving_data_dir + '/processed/' + filename + '{}_category.json'.format(balanced),
                  encoding="utf-8") as file:
            print(saving_data_dir + '/processed/' + filename + '{}_category.json'.format(balanced))
            dict_ = json.load(file)

        dict_tensor = {}
        for key, value in dict_.items():
            tensors = []
            for line in tqdm(value):
                tokenized_string = tokenizer.encode(line, max_length=max_length)
                tensors.append(tokenized_string + [tokenizer.pad_token_id] * (max_length - len(tokenized_string)))
            dict_tensor[key] = tensors
        with open(os.path.join(saving_data_dir, 'final_data_category_{}_dict.{}.json'.format(name, balanced)), 'w',
                  encoding="utf-8") as file:
            json.dump(dict_tensor, file)
