import os
import random
import sys

sys.path.append('/xxx/aaa/eva-clip/EVA-CLIP/rei')

import pickle
import joblib
import numpy
import tqdm
import nltk
from training.imagenet_zeroshot_data import openai_imagenet_template

# Input pickle that is loaded below into `base_dict_url_to_caption`.
path_base_file = r'/xxx/public_data/CC3M/CC3M_AUGs/origin_mix_recaption.pkl'
# Input pickle that is loaded below into `dict_word_to_des` (looked up by noun
# in generate_caption_visualdes).
path_word_to_des = r'/xxx/aaa/eva-clip/EVA-CLIP/rei/description/process_train/dict_word_to_des_nltk.pkl'
# Input pickle holding a (final_words, selected_sentences) pair.
path_select_sentences = r'/xxx/aaa/eva-clip/EVA-CLIP/rei/description/process_train/final_words_selected_sentences.pkl'
# Input pickle holding a (dict_nouns_to_count, dict_sentence_to_nouns) pair.
path_sentence_to_words = r'/xxx/aaa/eva-clip/EVA-CLIP/rei/description/process_train/dict_nouns_2.pkl'
# Directory that receives the three `.joblib` outputs written at the end of the script.
dirsave = r'/xxx/public_data/CC3M/CC3M_AUGs'

def _load_pickle(path):
    """Open *path* in binary mode and return the object unpickled from it."""
    with open(path, 'rb') as handle:
        return pickle.load(handle)


# Mapping that the main loop walks as (url, list of captions) pairs.
base_dict_url_to_caption = _load_pickle(path_base_file)

# Lookup from a noun to its description text.
# NOTE(original author): "people??" -- unresolved question carried over as-is.
dict_word_to_des = _load_pickle(path_word_to_des)
print('before filter:{}'.format(len(dict_word_to_des)))

# The file stores a pair: (final_words, selected_sentences).
final_words, selected_sentences = _load_pickle(path_select_sentences)
print('number selected_sentences:{}'.format(len(selected_sentences)))
print('number final_words:{}'.format(len(final_words)))

# The file stores a pair: (dict_nouns_to_count, dict_sentence_to_nouns).
dict_nouns_to_count, dict_sentence_to_nouns = _load_pickle(path_sentence_to_words)


# dict_url_to_caption_with_animageof_single = {}
# dict_url_to_caption_with_visualdes_single = {}
# dict_url_to_caption_with_all_single = {}
def generate_caption_animageof(nouns, select_one=False):
    """Fill prompt templates with a comma-separated list of nouns.

    Args:
        nouns: Iterable of noun strings. They are joined with ``', '`` and the
            joined string is handed to each template callable.
        select_one: When falsy (the default), every template in
            ``openai_imagenet_template`` is applied, in order. When truthy, a
            single template is picked with ``random.choice`` and applied.

    Returns:
        A list of template outputs: one entry per template, or a one-element
        list when ``select_one`` is truthy.
    """
    str_nouns = ', '.join(nouns)
    # Truthiness test instead of `== False` (PEP 8: do not compare booleans
    # to True/False with ==).
    if not select_one:
        return [item_template(str_nouns) for item_template in openai_imagenet_template]
    # The result stays a list so both modes return the same type.
    return [random.choice(openai_imagenet_template)(str_nouns)]


def generate_caption_visualdes(nouns):
    """Fill every prompt template with the descriptions of the given nouns.

    Each noun that has an entry in ``dict_word_to_des`` contributes its
    description; nouns without an entry are skipped. The collected
    descriptions are joined with a single space and that string is passed to
    every template in ``openai_imagenet_template``.

    Args:
        nouns: Iterable of nouns to look up.

    Returns:
        A list with one template output per template, in template order.

    NOTE(review): when none of the nouns has a description the joined string
    is empty and the templates are still filled with ``''`` -- confirm that
    this is the intended result.
    """
    known_descriptions = [
        dict_word_to_des[noun] for noun in nouns if noun in dict_word_to_des
    ]
    joined_description = ' '.join(known_descriptions)
    return [fill(joined_description) for fill in openai_imagenet_template]


# Per-URL outputs. Every value starts from the URL's original captions; the
# generated captions are inserted right after the caption they derive from.
dict_url_to_caption_with_animageof = {}
dict_url_to_caption_with_visualdes = {}
dict_url_to_caption_with_all = {}

# Number of entries stored per URL in each of the three mappings above.
len1 = []
len2 = []
len3 = []

# Membership in `selected_sentences` is tested once per caption. If the pickle
# stored a list/tuple, each test is a linear scan, so build a hash-based lookup
# once up front. Any other container type is used exactly as loaded.
selected_lookup = selected_sentences
if isinstance(selected_sentences, (list, tuple)):
    try:
        selected_lookup = frozenset(selected_sentences)
    except TypeError:
        # Unhashable entries: keep scanning the original container.
        selected_lookup = selected_sentences

for index, (item_url, item_list_captions) in enumerate(base_dict_url_to_caption.items()):
    if index % 10000 == 0:
        print(index)  # coarse progress indicator
    caption_list_imageof = []
    caption_list_visualdes = []
    caption_list_all = []
    for item_caption in item_list_captions:
        # The original caption is always kept in all three variants.
        caption_list_imageof.append(item_caption)
        caption_list_visualdes.append(item_caption)
        caption_list_all.append(item_caption)
        # NOTE(review): the generate_* helpers return lists, and those lists
        # are stored below as single nested elements (not flattened), so the
        # length statistics count each generated batch as one entry. Confirm
        # that the consumer of these files expects this layout.
        if item_caption in selected_lookup:
            # Selected sentences get all "image of" templates plus the
            # description-based captions.
            nouns_cur_caption = dict_sentence_to_nouns[item_caption]
            caption_imageof = generate_caption_animageof(nouns_cur_caption)
            caption_visualdes = generate_caption_visualdes(nouns_cur_caption)
            caption_list_imageof.append(caption_imageof)
            caption_list_visualdes.append(caption_visualdes)
            caption_list_all += [caption_imageof, caption_visualdes]
        elif item_caption in dict_sentence_to_nouns:
            # Other sentences with known nouns get one randomly chosen
            # "image of" template and no description-based caption.
            nouns_cur_caption = dict_sentence_to_nouns[item_caption]
            caption_imageof = generate_caption_animageof(nouns_cur_caption, select_one=True)
            caption_list_imageof.append(caption_imageof)
            caption_list_all += [caption_imageof]

    dict_url_to_caption_with_animageof[item_url] = caption_list_imageof
    dict_url_to_caption_with_visualdes[item_url] = caption_list_visualdes
    dict_url_to_caption_with_all[item_url] = caption_list_all

    len1.append(len(caption_list_imageof))
    len2.append(len(caption_list_visualdes))
    len3.append(len(caption_list_all))

print('mean len1:{} len2:{}, len3:{}'.format(numpy.mean(len1), numpy.mean(len2), numpy.mean(len3)))

# Write each of the three caption mappings to its own file under `dirsave`.
for _out_name, _out_mapping in (
    ('dict_url_to_caption_with_animageof.joblib', dict_url_to_caption_with_animageof),
    ('dict_url_to_caption_with_visualdes.joblib', dict_url_to_caption_with_visualdes),
    ('dict_url_to_caption_with_all.joblib', dict_url_to_caption_with_all),
):
    with open(os.path.join(dirsave, _out_name), 'wb') as f:
        joblib.dump(_out_mapping, f)
