import os
import random
from itertools import islice
import sys

sys.path.append('/xxx/aaa/eva-clip/EVA-CLIP/rei')

import pickle

import tarfile

from description.LLM.llm_wraper_gemma import LLM_Wrapper

# Directory holding the CC3M tar shards to process.
dataset_dir = r'/xxx/public_data/CC3M/CC3M_290W'
# Directory where one pickle of results per shard is written.
save_dir = r'/xxx/aaa/eva-clip/EVA-CLIP/rei/description/process_train/cc3m/wordscount'

# Prompt template(s) handed to LLM_Wrapper; ``{}`` is a placeholder for the caption.
# Not a raw string: ``\n`` must be a real newline separating the instruction
# from the caption, not the two characters backslash + 'n'.
prompt_instruction = [
    'The following is a caption for a picture. Please identify all the items that may appear in the picture and output them separated by commas:\n{}']
# Number of captions sent to the LLM per call.
batch_size = 20


def get_all_sentences_one_tar(path_tar, dict_sentence_to_nouns):
    """Collect the caption text of every ``.txt`` member of a tar archive.

    Args:
        path_tar: Path to the tar shard. It is opened with mode ``'r'``, which
            lets ``tarfile`` detect compression transparently.
        dict_sentence_to_nouns: Mapping of already-processed captions; any
            caption that is a key of this dict is skipped.

    Returns:
        List of captions decoded as UTF-8, in archive order. Duplicate
        captions inside the archive are kept.
    """
    all_sentences = []
    with tarfile.open(path_tar, 'r') as tar:
        for member in tar.getmembers():
            if not member.name.endswith('.txt'):
                continue
            f = tar.extractfile(member)
            if f is None:
                # extractfile returns None for members that are neither
                # regular files nor links.
                continue
            # Context manager closes the member even if decoding fails.
            with f:
                sentence = f.read().decode('utf-8')
            if sentence not in dict_sentence_to_nouns:
                all_sentences.append(sentence)
    print('get number sentences:{}, from:{}'.format(len(all_sentences), path_tar))
    return all_sentences


def batch_generator(iterable, batch_size):
    """Split ``iterable`` into consecutive lists of at most ``batch_size`` items.

    The source iterator is created immediately; the returned iterator yields
    lists lazily and stops at the first empty chunk.
    """
    source = iter(iterable)

    def _chunks():
        while True:
            chunk = list(islice(source, batch_size))
            if not chunk:
                return
            yield chunk

    return _chunks()


def process_batch_LLM_output(batch_results):
    """Parse raw LLM replies into lists of item names.

    Only the text after the last ``:`` of each reply is kept, because the
    reply may echo a lead-in before the list. That text is split on commas.
    Surrounding whitespace and ``*`` characters are stripped from the whole
    list and from every item. Empty items (e.g. from a trailing comma) are
    dropped.

    Args:
        batch_results: Iterable of reply strings, one per input caption.

    Returns:
        List of item-string lists, in the same order as ``batch_results``.
    """
    ans = []
    for item_res in batch_results:
        final_str = item_res.split(':')[-1].strip().strip('*')
        items = []
        for raw_item in final_str.split(','):
            # Clean each item so that ' cat' and 'cat' are counted as one noun.
            cleaned = raw_item.strip().strip('*').strip()
            if cleaned:
                items.append(cleaned)
        ans.append(items)
    return ans


def get_saved_path(item_tar, save_root=None):
    """Return the ``.pkl`` output path for a tar shard file name.

    The final extension is replaced with ``.pkl`` (``'00001.tar'`` becomes
    ``'00001.pkl'``), and the result is joined onto the output directory.

    Args:
        item_tar: Shard file name (a bare name, not a full path).
        save_root: Output directory; defaults to the module-level ``save_dir``.

    Returns:
        Path of the pickle file for this shard.
    """
    if save_root is None:
        save_root = save_dir
    # splitext strips only the last suffix, so names with several dots (or
    # none) still map to a path instead of raising ValueError.
    pre = os.path.splitext(item_tar)[0]
    return os.path.join(save_root, pre + '.pkl')


def _count_items_for_tar(llm, path_tar, item_tar):
    """Run the LLM over every caption of one tar; return (item counts, caption -> items)."""
    dict_sentence_to_nouns = {}
    dict_nouns_to_count = {}
    all_sentences = get_all_sentences_one_tar(path_tar, dict_sentence_to_nouns)

    for index, batch_sentence in enumerate(batch_generator(all_sentences, batch_size)):
        if index % 100 == 0:
            print('tar:{} index:{}'.format(item_tar, index))
        batch_results = llm(batch_sentence)
        batch_items = process_batch_LLM_output(batch_results)
        # Guard: the wrapper could return an empty list for a batch.
        if batch_items:
            print('sampled sentence:{}, items:{}'.format(batch_sentence[0], batch_items[0]))
        for item_sentence, item_items in zip(batch_sentence, batch_items):
            dict_sentence_to_nouns[item_sentence] = item_items
            for item_nouns in item_items:
                dict_nouns_to_count[item_nouns] = dict_nouns_to_count.get(item_nouns, 0) + 1
    return dict_nouns_to_count, dict_sentence_to_nouns


def _atomic_pickle_dump(obj, path):
    """Pickle ``obj`` to ``path`` via a temp file and ``os.replace``.

    A crash mid-write then leaves no partial file at ``path``, which the
    ``exists`` check in ``main`` would otherwise skip forever.
    """
    tmp_path = '{}.tmp{}'.format(path, os.getpid())
    try:
        with open(tmp_path, 'wb') as f:
            pickle.dump(obj, f)
        os.replace(tmp_path, path)
    finally:
        # Only still present if writing or renaming failed.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)


def main():
    """Extract item lists for every tar shard in ``dataset_dir``.

    Shards are visited in random order. A shard is skipped when its output
    pickle already exists, so an interrupted run resumes where it stopped.
    NOTE(review): the shuffle suggests several copies may run in parallel;
    confirm.

    Each shard's pickle holds ``[dict_nouns_to_count, dict_sentence_to_nouns]``.
    """
    # Only tar shards are processable; other files in the directory (if any)
    # would make tarfile.open fail.
    all_files = [name for name in os.listdir(dataset_dir) if name.endswith('.tar')]
    random.shuffle(all_files)
    os.makedirs(save_dir, exist_ok=True)
    llm = LLM_Wrapper(instructions=prompt_instruction)

    for item_tar in all_files:
        print('start process tar:{}'.format(item_tar))
        path_save = get_saved_path(item_tar)
        if os.path.exists(path_save):
            print('exist:{}'.format(path_save))
            continue

        path_tar = os.path.join(dataset_dir, item_tar)
        dict_nouns_to_count, dict_sentence_to_nouns = _count_items_for_tar(llm, path_tar, item_tar)
        print('tar:{} finish'.format(item_tar))
        _atomic_pickle_dump([dict_nouns_to_count, dict_sentence_to_nouns], path_save)

    print('all finish')


if __name__ == '__main__':
    main()
