import pickle
import os
import sys
import tqdm

# Directory scanned for pickle files; every entry found in it is loaded below.
dir_save = r'/xxx/aaa/eva-clip/EVA-CLIP/rei/description/process_train/cc3m/wordscount'

# os.listdir returns entries in arbitrary order. Sorting makes the merge
# reproducible: dict insertion order, tie order in the stable sort further
# down, and which value wins when two files share a key no longer depend on
# the filesystem.
all_pkls = sorted(os.listdir(dir_save))
print('number pkls:{}'.format(len(all_pkls)))

# Merged results across all files:
#   dict_nouns_to_count    - lower-cased noun -> summed count
#   dict_sentence_to_nouns - union of the second mapping stored in each file
#                            (presumably sentence -> nouns; verify against the
#                            script that produced the pickles)
dict_nouns_to_count = {}
dict_sentence_to_nouns = {}

for item_pkl in tqdm.tqdm(all_pkls):
    path_pkl = os.path.join(dir_save, item_pkl)
    # NOTE(security): pickle.load can execute arbitrary code. Only run this
    # on files that were produced by a trusted pipeline.
    with open(path_pkl, 'rb') as f:
        # Each file is expected to unpack into exactly two objects.
        item_dict_nouns_to_count, item_sentence_to_nouns = pickle.load(f)
    for item_noun, item_count in item_dict_nouns_to_count.items():
        # Counts are merged case-insensitively.
        item_noun = item_noun.lower()
        dict_nouns_to_count[item_noun] = dict_nouns_to_count.get(item_noun, 0) + item_count
    # On duplicate keys the file processed last wins.
    dict_sentence_to_nouns.update(item_sentence_to_nouns)

print('dict_nouns_to_count:{}'.format(len(dict_nouns_to_count)))
print('dict_sentence_to_nouns:{}'.format(len(dict_sentence_to_nouns)))

# (noun, count) pairs ordered from least to most frequent; ties keep their
# insertion order because sorted() is stable.
list_nouns_count = sorted(dict_nouns_to_count.items(), key=lambda x: x[1])


# keep_radio = 0.6
# keep_len = int(len(list_nouns_count) * keep_radio)
# print('keep number:{}'.format(keep_len))
# list_nouns_count = list_nouns_count[:keep_len]
# print(list_nouns_count[-1])

def contains_digit(s):
    """Return True if any character of `s` satisfies `str.isdigit`, else False.

    An empty `s` yields False.
    """
    for char in s:
        if char.isdigit():
            return True
    return False


# Select nouns whose merged count is at most 5, whose length is below 20
# characters, and which contain no digit character.
keep_nouns = [
    noun
    for noun, count in dict_nouns_to_count.items()
    if count <= 5 and len(noun) < 20 and not contains_digit(noun)
]
print('len keep_nouns_count:{}'.format(len(keep_nouns)))

# Persist the selected nouns as a pickled list in the current working directory.
with open('keep_nouns.pkl', 'wb') as out_file:
    pickle.dump(keep_nouns, out_file)


# print(keep_nouns[:200])


# print(list_nouns_count[:1000])
# print(list_nouns_count[-1000:])

# for index, (item_key, item_list_values) in enumerate(dict_sentence_to_nouns.items()):
#     if (index == 10):
#         break
#     print(item_key)
#     print(item_list_values)
#     print()

# import sys

# sys.exit()

# keep_radio = 0.3

# keep_number = int(len(list_nouns_count) * keep_radio)
# print('keep number:{}'.format(keep_number))
# list_nouns_count = list_nouns_count[:keep_number]
# keep_words = []
# for item_word, item_count in list_nouns_count:
#     if (item_count == 1):
#         continue
#     keep_words.append(item_word)

# keep_words = set(keep_words)

# final_words = set()
# selected_sentences = []
# for index, (item_sentence, nouns_cur_sentence) in enumerate(dict_sentence_to_nouns.items()):
#     if (index % 100 == 0):
#         print('index:{}'.format(index))
#     for item_noun in nouns_cur_sentence:
#         if (item_noun in keep_words):
#             final_words.update(nouns_cur_sentence)
#             selected_sentences.append(item_sentence)
#             break

# final_words = list(final_words)
# print('number of final words:{}'.format(len(final_words)))
# print('number of selected_sentences:{}'.format(selected_sentences))
# with open('final_words_selected_sentences.pkl', 'wb') as f:
#     pickle.dump([final_words, selected_sentences], f)
