import json
import sys
import pickle

sys.path.append(r'/xxx/aaa/eva-clip/EVA-CLIP/rei')

import torch
from torch.multiprocessing import Process, Queue, Manager
import torch.multiprocessing as mp
from training.imagenet_zeroshot_data import imagenet_classnames

from description.LLM.llm_wraper_gemma import LLM_Wrapper

# Prompt templates handed to LLM_Wrapper in worker_process. The `{}` is
# presumably filled with the class name by the wrapper -- TODO confirm against
# LLM_Wrapper. The commented-out entry is an alternative prompt kept for
# reference.
instructions = [
    # 'Provide a brief and concise description focusing only on the physical appearance of {}.',
    'Summarize the physical appearance of {} in a single, clear sentence.'
]
# Path of the JSON file written by save_to_file (relative to the working dir).
output_file = 'imagenet1k_short.json'

# Maximum number of class names pulled from the queue per LLM call.
batch_size = 20


def process_sentence(text):
    """Strip a leading "Sure, here is ...:" preamble from an LLM answer.

    If ``text`` starts with "Sure, here is " and contains a colon, everything
    up to and including the first colon is dropped and the remainder is
    returned with surrounding whitespace stripped. In every other case
    ``text`` is returned unchanged. Progress is logged to stdout whenever the
    preamble is detected.
    """
    if not text.startswith("Sure, here is "):
        return text

    print('Remove Sure...')
    print('origin:{}'.format(text))

    # Split on the first colon; an empty separator means no colon was found.
    _, separator, remainder = text.partition(":")
    if not separator:
        return text

    cleaned = remainder.strip()
    print('removed:{}'.format(cleaned))
    return cleaned


def get_batch(queue, batch_size):
    """Pull up to ``batch_size`` items from ``queue``.

    Collection stops early when the queue reports empty or when a ``None``
    sentinel is dequeued (the sentinel itself is consumed and not returned).

    Args:
        queue: Queue-like object exposing ``empty()`` and ``get()``.
        batch_size: Maximum number of items to return.

    Returns:
        A non-empty list of the dequeued items, or ``None`` when nothing
        could be collected.
    """
    items_ans = []
    for _ in range(batch_size):
        # Stop collecting but keep what was already dequeued; returning None
        # here would silently drop a partially filled batch.
        # NOTE(review): empty() followed by get() is not atomic when several
        # consumers share the queue -- get() may block if another consumer
        # takes the last item in between.
        if queue.empty():
            break
        item = queue.get()
        if item is None:
            break
        items_ans.append(item)
    if not items_ans:
        return None
    return items_ans


def save_to_file(dict_word_to_des):
    """Write the word -> description mapping to ``output_file`` as JSON.

    The argument is first copied into a plain ``dict`` before being
    serialised.
    """
    print('save to file...')
    plain_mapping = dict(dict_word_to_des)
    with open(output_file, 'w') as out_fp:
        json.dump(plain_mapping, out_fp)


def worker_process(rank, queue, dict_word_to_des):
    """Consume class names from ``queue`` and store their LLM descriptions.

    The worker selects CUDA device ``rank``, builds an ``LLM_Wrapper`` on that
    device and repeatedly fetches batches via ``get_batch`` until it returns
    ``None``. Each answer is cleaned with ``process_sentence`` and written to
    ``dict_word_to_des`` under its word.

    Args:
        rank: Index of this worker; also used as the CUDA device id.
        queue: Shared queue of words to describe.
        dict_word_to_des: Shared mapping that receives ``word -> description``.
    """
    print('rank:{}, start'.format(rank))
    torch.cuda.set_device(rank)
    llm = LLM_Wrapper(instructions=instructions, device_id=rank)
    while True:
        batch_words = get_batch(queue, batch_size)
        if batch_words is None:
            # The single "finish" message is emitted once, after the loop.
            break

        # NOTE(review): assumes llm returns one answer per input word, in the
        # same order -- confirm against LLM_Wrapper.
        batch_results = llm(batch_words)
        print('sample name:{}, result:{}'.format(batch_words[-1], batch_results[-1]))
        for item_word, item_ans in zip(batch_words, batch_results):
            dict_word_to_des[item_word] = process_sentence(item_ans)

    print('rank:{} finish'.format(rank))


def filter_words(all_words):
    """Return the stripped words that are longer than two characters.

    Each entry is whitespace-stripped first; the stop words 'the' and 'and'
    are dropped along with anything of length two or less. Input order is
    preserved.
    """
    stripped_words = (word.strip() for word in all_words)
    return [
        word
        for word in stripped_words
        if len(word) > 2 and word not in ['the', 'and']
    ]


def main():
    """Describe every ImageNet class name using eight GPU worker processes.

    All class names are queued up front, one worker per rank 0-7 is started,
    and a ``None`` sentinel is enqueued after each worker is launched. Once
    every worker has exited, the collected descriptions are written to disk
    via ``save_to_file``.
    """
    # NOTE(review): 'spawn' is presumably chosen because the workers
    # initialise CUDA -- confirm. Setting it unconditionally raises if a start
    # method was already fixed.
    mp.set_start_method('spawn')

    # Keep a reference to the manager so the shared dict stays alive.
    manager = Manager()
    shared_results = manager.dict()

    task_queue = Queue()
    for class_name in imagenet_classnames:
        task_queue.put(class_name)

    procs = []
    for rank in range(8):
        proc = Process(
            target=worker_process,
            args=(rank, task_queue, shared_results),
        )
        proc.start()
        procs.append(proc)
        # One end-of-work sentinel per launched worker.
        task_queue.put(None)

    # Wait for all worker processes to finish.
    for proc in procs:
        proc.join()

    save_to_file(shared_results)


if __name__ == '__main__':
    main()
