from collections import Counter
from tqdm import tqdm
import json

if __name__ == '__main__':
    max_length = 70
    categories = ['american', 'asian', 'bar', 'dessert', 'mexican']
    for name in tqdm(['train', 'valid', 'test']):
        for index in tqdm([8]):
            file_base = '{}.fader.with_cat.proc.{}0000'
            filename = file_base.format(name, index)
            with open(filename, 'r', encoding="utf-8") as file:
                lines = file.readlines()
            dict_ = {'american': [], 'asian': [], 'bar': [], 'dessert': [], 'mexican': []}

            for line in tqdm(lines, desc='Lines'):
                content = line.split('\t')[0]
                category = str(line.split('\t')[-1]).replace('\n', '')
                if category in categories and len(content.split(' ')) < max_length:
                    dict_[category].append(content)
            min_length = 1000000000
            for key, value in dict_.items():
                if len(value) < min_length:
                    min_length = len(value)
            for key, value in dict_.items():
                dict_[key] = value[:min_length]
            with open('processed/' + filename + 'balanced_category.json', 'w', encoding="utf-8") as file:
                print('processed/' + filename + 'balanced_category.json')
                json.dump(dict_, file)

            print("STATS : -------------")
            print("Filename {}".format(filename))
            print("Before nb_lines : {}".format(len(lines)))
            print("After nb_lines : {}".format(sum([len(value) for _, value in dict_.items()])))
