import sys

sys.path.append(r'/xxx/aaa/eva-clip/EVA-CLIP/rei')

from itertools import islice
import json

# from description.LLM.llm_wraper_qwen import LLM_Wrapper
from description.LLM.llm_wraper_gemma import LLM_Wrapper
from training.imagenet_zeroshot_data import imagenet_classnames

# Prompt templates handed to the LLM wrapper. Earlier wordings are kept below
# (commented out) as a record of what was tried; only uncommented entries are
# passed on.
# NOTE(review): how the placeholder ({} vs {classname}) gets filled in is
# decided inside LLM_Wrapper, which is not visible here -- confirm before
# re-enabling a template that uses a different placeholder style.
instructions = [
    # 'Provide a brief and concise description focusing only on the physical appearance of {classname}.',
    # 'Summarize the physical appearance of {} in a single, clear sentence in English.',
    # 'Summarize the physical appearance of {} in a single, clear sentence in English using common words without any chinese.',
    # 'Summarize what a {} looks like in one clear and straightforward sentences.',
    # 'Use common words to describe the appearance of {classname} in one sentence. Directly output the results without any hint such as “Sure, here is...” at the beginning. No “Sure, here is...”. The description should first explain the general category of the category. Here is the example: Tench are greenish, chunky freshwater fish with small scales and a pair of barbels near their mouths. Next, directly describe the appearance of {classname}:',
    'Summarize the physical appearance of {} in a single, clear sentence.',
]

# JSON file holding previously generated descriptions; it is read below so an
# interrupted run can pick up where it stopped.
filename_save = 'imagenet1k_short.json'

llm = LLM_Wrapper(instructions=instructions)

# Number of class names sent to the LLM per call.
batch_size = 20

# Load cached results if there are any. A missing, unreadable or malformed
# cache is not fatal: report it and start with an empty mapping. Only I/O and
# parse errors are handled, so Ctrl-C and genuine programming errors still
# surface instead of being silently turned into "no cache".
try:
    with open(filename_save, 'r', encoding='utf-8') as f:
        all_results_dict = json.load(f)
except (OSError, ValueError) as err:
    # OSError: file missing / not readable.
    # ValueError: invalid JSON (json.JSONDecodeError) or undecodable bytes.
    print('could not load {} ({}); starting from scratch'.format(filename_save, err))
    all_results_dict = {}
else:
    if isinstance(all_results_dict, dict):
        print('load:{}'.format(len(all_results_dict)))
    else:
        # The rest of the script indexes this object by class name, so any
        # other JSON payload (list, number, ...) cannot be used as a cache.
        print('ignoring {}: expected a JSON object, got {}'.format(
            filename_save, type(all_results_dict).__name__))
        all_results_dict = {}


# print('carbonara' in all_results_dict)

def batch_generator(iterable, batch_size):
    """Return an iterator over consecutive chunks of *iterable*.

    Each chunk is a list holding up to ``batch_size`` items; the final chunk
    may be shorter. Iteration ends as soon as an empty chunk is produced, so
    an empty input (or ``batch_size == 0``) yields nothing.
    """
    # Obtain the iterator up front so a non-iterable argument fails at call
    # time rather than on the first ``next()``.
    source = iter(iterable)

    def _chunks():
        while True:
            chunk = list(islice(source, batch_size))
            if not chunk:
                return
            yield chunk

    return _chunks()


def process_sentence(text):
    """Strip a leading "Sure, here is ...:" preamble from *text*.

    When *text* starts with ``"Sure, here is "`` and contains a colon,
    everything up to and including the first colon is dropped and the
    remainder is returned with surrounding whitespace stripped. In every
    other case *text* is returned unchanged. Debug output is printed whenever
    the preamble is detected.
    """
    if not text.startswith("Sure, here is "):
        return text

    print('Remove Sure...')
    print('origin:{}'.format(text))

    # Split at the first colon; ``found`` is empty when there is none.
    _, found, remainder = text.partition(":")
    if not found:
        return text

    # Keep only what follows the colon.
    cleaned = remainder.strip()
    print('removed:{}'.format(cleaned))
    return cleaned


# Class names that do not have a cached description yet.
left_names = [name for name in imagenet_classnames if name not in all_results_dict]
print('left number:{}'.format(len(left_names)))
print('left :{}'.format(left_names))

for batch_classnames in batch_generator(left_names, batch_size):
    print('batch names:{}'.format(batch_classnames))
    batch_results = llm(batch_classnames)
    print('batch results:{}'.format(batch_results))
    # NOTE(review): this pairs names with answers positionally and assumes the
    # wrapper returns exactly one answer per name, in order -- zip() silently
    # drops the tail if the lengths differ. Confirm against LLM_Wrapper.
    for item_classname, item_ans in zip(batch_classnames, batch_results):
        all_results_dict[item_classname] = process_sentence(item_ans)

    # Checkpoint after every batch so an interrupted run loses at most one
    # batch of LLM output. (The original wrote to the undefined name
    # `filename`, which raised NameError right after the first batch.)
    with open(filename_save, 'w', encoding='utf-8') as f:
        json.dump(all_results_dict, f)

# Final write; also makes sure the file exists when nothing was left to do.
with open(filename_save, 'w', encoding='utf-8') as f:
    json.dump(all_results_dict, f)
