import torch
from .utils import *
import warnings
warnings.filterwarnings("ignore")
import re
from bert_score import score
from PIL import Image, ImageOps
from fuzzywuzzy import fuzz

def load_image(image_path):
    """Open ``image_path``, undo any EXIF rotation/flip, and return it as RGB.

    Args:
        image_path: Anything ``PIL.Image.open`` accepts (path or file object).

    Returns:
        A ``PIL.Image.Image`` in RGB mode.
    """
    upright = ImageOps.exif_transpose(Image.open(image_path))
    return upright.convert("RGB")

def entropy(f, threshold=1.554):
    """Flag rows whose (re-softmaxed) top-k scores are close to uniform.

    Args:
        f: Tensor of shape (B, k); in this file it is the top-5 next-token
            probabilities, shape (B, 5). A softmax over the last dimension
            is applied before the entropy is measured.
        threshold: Entropy in nats above which a row is flagged. The default
            1.554 lies in the reference range 1.5002-1.6094; ln(5) ~= 1.6094
            is the largest possible entropy of a 5-way distribution.

    Returns:
        Bool tensor of shape (B,): True where the entropy exceeds
        ``threshold``.
    """
    pro = torch.softmax(f, dim=-1)
    # entr(p) = -p * log(p) with the 0 * log(0) = 0 convention, so a
    # probability that underflows to 0 cannot turn the row sum into NaN
    # (the plain -p * log(p) form gives 0 * -inf = NaN there).
    e = torch.special.entr(pro).sum(dim=-1)
    return e > threshold

def score_ratio(f):
    """Flag rows whose top-k scores fall off steeply.

    For every pair of positions ``j < k`` in a row the ratio ``f[k] / f[j]``
    is taken (with rows sorted in descending order, as ``torch.topk``
    returns them, each ratio lies in [0, 1]). A row is flagged when the mean
    of these ``k * (k - 1) / 2`` ratios lies strictly between 0 and 0.4.

    Args:
        f: Tensor of shape (B, k); in this file the top-5 next-token
            probabilities, shape (B, 5). It is detached before use.

    Returns:
        Bool tensor of shape (B,).
    """
    f_ = f.detach()
    n = f_.shape[-1]
    # ratios[b, i, j] = f[b, n - 1 - i] / f[b, j]
    ratios = torch.flip(f_, dims=[-1]).unsqueeze(-1) / f_.unsqueeze(1)
    # Keep exactly the entries with j < n - 1 - i, i.e. each pair (j, k) with
    # j < k once. Selecting by position (instead of filtering on != 0) keeps
    # genuine zero ratios - e.g. a top-k probability that underflowed to 0 -
    # and gives every row the same number of entries, so rows can neither be
    # mixed up by the reshape nor make it fail.
    keep = torch.ones(n, n, device=f_.device).tril(diagonal=-1).flip(0).bool()
    mean_ratio = ratios[:, keep].mean(dim=-1)
    return (mean_ratio > 0) & (mean_ratio < 0.4)

def uniformly(f):
    """Flag rows that are high-entropy but not flagged by ``score_ratio``.

    Args:
        f: The same top-k tensor passed to ``entropy`` and ``score_ratio``.

    Returns:
        Bool tensor: ``entropy(f)`` AND NOT ``score_ratio(f)``.
    """
    high_entropy = entropy(f)
    steep_drop = score_ratio(f)
    return torch.logical_and(high_entropy, torch.logical_not(steep_drop))

def n_shot_top5_distr(**kwargs):
    """Top-5 next-token distribution features under n-shot prompts.

    For every ``(cls, shot)`` pair in ``cls_shot``, ``dataset.n_shot(...)`` is
    called to build the prompts, the VLM runs one forward pass, and the top-5
    probabilities of the last position's next-token distribution are turned
    into three flags: ``score_ratio``, ``entropy`` and ``uniformly``.

    Keyword Args:
        vlm_processor, vlm: Vision-language model and its processor.
        image, label, lang_x: Current batch, forwarded to ``dataset.n_shot``.
        dataset: Object whose ``n_shot(...)`` returns ``(images, conversations,
            labels)``, or a 4-tuple with an extra leading item when
            ``for_generating_preference`` is passed (only the key's presence
            matters, not its value).
        cls_shot: A ``(cls, shot)`` pair or a list of such pairs.

    Returns:
        Bool tensor of shape (B, 3 * number_of_pairs); for each pair the
        columns are score_ratio, entropy, uniformly.

    Raises:
        RuntimeError: from ``torch.stack`` when no ``cls_shot`` is given.
    """
    processor = kwargs['vlm_processor']
    image = kwargs['image']
    label = kwargs['label']
    lang_x = kwargs['lang_x']
    data = kwargs['dataset']
    model = kwargs['vlm']
    for_generating_preference = 'for_generating_preference' in kwargs

    shots = kwargs.get('cls_shot', [])
    if not isinstance(shots, list):
        shots = [shots]

    activate_vector = []
    for c, shot in shots:
        if for_generating_preference:
            _, base_image, base_lang_x, base_label = data.n_shot(image, lang_x, label, shot=shot, cls=c)
        else:
            base_image, base_lang_x, base_label = data.n_shot(image, lang_x, label, shot=shot, cls=c)

        texts = [processor.apply_chat_template(message, tokenize=False, add_generation_prompt=True)
                 for message in base_lang_x]
        inputs = processor(images=base_image,
                           text=texts,
                           return_tensors="pt",
                           padding=True).to(model.device)

        # Inference only: without no_grad the forward pass records an
        # autograd graph that stays alive as long as `output` does.
        # NOTE(review): logits[:, -1] is the true last token of every row
        # only if the processor pads on the left - confirm padding_side.
        with torch.no_grad():
            output = model(**inputs)
            next_token = torch.softmax(output.logits[:, -1], -1)
        del output

        top_5_pro = torch.topk(next_token, 5)[0]
        activate_vector.append(score_ratio(top_5_pro))
        activate_vector.append(entropy(top_5_pro))
        activate_vector.append(uniformly(top_5_pro))

        del inputs, next_token
        torch.cuda.empty_cache()

    return torch.stack(activate_vector, dim=-1)

def is_cls_animal(**kwargs):
    """Rule: do an LLM and a VLM agree on whether the sample is an animal?

    The LLM is asked whether each predicted class name is an animal; the VLM
    is asked the same question about the images.

    Keyword Args:
        predcted_cls: List of predicted class-name strings.
        llm, llm_processor: Text-only model and its chat-template processor.
        image: Items passed to ``load_image`` (paths or file objects).
        vlm, vlm_processor: Vision-language model and its processor.

    Returns:
        Float tensor of shape (N, 1) on ``vlm.device``: 1. where both replies
        contain '1' or both contain '0' (the '1' test is checked first),
        else 0.; N is the length of the shorter of the two reply lists.
    """
    cls = kwargs['predcted_cls']
    llm = kwargs['llm']
    llm_processor = kwargs['llm_processor']
    # Drop periods and standalone articles ("a", "an", "A", "An"). The word
    # boundary keeps names whose words merely end in "a"/"an" intact, which
    # str.replace('a ', '') corrupts ("panda bear" -> "pandbear").
    cls = [re.sub(r"\b[Aa]n?\s+", "", c.replace('.', '')) for c in cls]
    image = kwargs['image']
    vlm_processor = kwargs['vlm_processor']
    vlm = kwargs['vlm']

    image = [load_image(i) for i in image]

    llm_messages = [[{'role': 'user',
                      'content': f'Is {c} an animal? 1(yes) 0(no):'}] for c in cls]
    # NOTE(review): the batch is tokenized without padding, which presumably
    # requires equal-length prompts (e.g. batch size 1) - confirm.
    llm_prompt = llm_processor.apply_chat_template(llm_messages, add_generation_prompt=True,
                                                   return_tensors="pt").to(llm.device)
    out = llm.generate(input_ids=llm_prompt, max_new_tokens=10)
    response = get_llm_response(llm_processor.batch_decode(out, skip_special_tokens=True))

    torch.cuda.empty_cache()

    vlm_messages = [
        [{'role': 'user',
          'content': [
              {'type': 'image'},
              {'type': 'text', 'text': 'Is the object in this image an animal? If yes print 1, if no print 0:'}
          ]
          }] for _ in cls
    ]
    texts = [vlm_processor.apply_chat_template(message, tokenize=False, add_generation_prompt=True)
             for message in vlm_messages]
    inputs = vlm_processor(images=image,
                           text=texts,
                           return_tensors="pt",
                           padding=True).to(vlm.device)
    generated_ids = vlm.generate(**inputs, max_new_tokens=25)
    # Keep only the newly generated tokens (drop the echoed prompt).
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    generated_texts = vlm_processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True)

    activate_vector = []
    for v, l in zip(generated_texts, response):
        both_yes = '1' in v and '1' in l
        both_no = '0' in v and '0' in l
        activate_vector.append(1. if both_yes or both_no else 0.)

    return torch.tensor(activate_vector, device=vlm.device).unsqueeze(1)

def where_live(**kwargs):
    """Rule: do the LLM (from the class name) and the VLM (from the image) agree on the habitat?

    Both models are steered towards one of land / water / sky; the feature is
    1. when both replies name the same one.

    Keyword Args:
        vlm_processor, vlm: Vision-language model and its processor.
        image: Items passed to ``load_image`` (paths or file objects).
        llm_processor, llm: Text-only model and its chat-template processor.
        predcted_cls: List of predicted class-name strings.

    Returns:
        Float tensor of shape (N, 1) on ``vlm.device``.
    """
    vlm_processor = kwargs['vlm_processor']
    vlm = kwargs['vlm']
    image = kwargs['image']
    llm_processor = kwargs['llm_processor']
    llm = kwargs['llm']
    cls = kwargs['predcted_cls']
    # Drop periods and standalone articles, including capitalised "A"/"An" as
    # the other rules do. The word boundary keeps names whose words merely
    # end in "a"/"an" intact ("panda bear" stays "panda bear").
    cls = [re.sub(r"\b[Aa]n?\s+", "", c.replace('.', '')) for c in cls]
    image = [load_image(i) for i in image]

    llm_messages = [[{'role': 'user',
                      'content': f'Where does {c} live, you can answer by choosing one of (land, water, sky).'}]
                    for c in cls]
    llm_prompt = llm_processor.apply_chat_template(llm_messages, add_generation_prompt=True,
                                                   return_tensors="pt").to(llm.device)
    out = llm.generate(input_ids=llm_prompt, max_new_tokens=10)
    response = get_llm_response(llm_processor.batch_decode(out, skip_special_tokens=True))

    prompt = 'Do the objects in this picture live in water, on land, or in the sky? Your final answer should be put between two ##, like ## sky ## (if your final answer is sky), at the end of your response.'

    torch.cuda.empty_cache()

    vlm_messages = [
        [{'role': 'user',
          'content': [
              {'type': 'image'},
              {'type': 'text', 'text': prompt}
          ]
          }] for _ in cls
    ]
    texts = [vlm_processor.apply_chat_template(message, tokenize=False, add_generation_prompt=True)
             for message in vlm_messages]
    inputs = vlm_processor(images=image,
                           text=texts,
                           return_tensors="pt",
                           padding=True).to(vlm.device)
    # NOTE(review): 25 new tokens may cut the reply off before the requested
    # "## x ##" marker; _habitat then falls back to the first habitat word.
    generated_ids = vlm.generate(**inputs, max_new_tokens=25)
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    generated_texts = vlm_processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True)

    marked_re = re.compile(r"##\s*(land|sky|water)\s*##", re.IGNORECASE)
    habitat_re = re.compile(r"(land|sky|water)", re.IGNORECASE)

    def _habitat(text):
        """Return 'land', 'sky' or 'water' named in ``text``, or None.

        Prefers the last ``## x ##`` marker (the format the VLM prompt asks
        for), otherwise the first habitat word anywhere; case-insensitive.
        """
        marked = marked_re.findall(text)
        if marked:
            return marked[-1].lower()
        match = habitat_re.search(text)
        return match.group(1).lower() if match else None

    activate_vector = []
    for v, l in zip(generated_texts, response):
        vlm_habitat, llm_habitat = _habitat(v), _habitat(l)
        agree = vlm_habitat is not None and vlm_habitat == llm_habitat
        activate_vector.append(1. if agree else 0.)

    return torch.tensor(activate_vector, device=vlm.device).unsqueeze(1)

def image_feature(**kwargs):
    """Rule: does an LLM judge the VLM's own image description as plausible?

    The VLM first answers two questions about each image (anime/cartoon vs.
    realistic, and what is shown). Each description is then given to the
    LLM, which is asked to reply 1 (yes) or 0 (no).

    Keyword Args:
        vlm_processor, vlm: Vision-language model and its processor.
        image: Items passed to ``load_image`` (paths or file objects).
        llm_processor, llm: Text-only model and its chat-template processor.
        predcted_cls: Only iterated to size the VLM batch.

    Returns:
        Float tensor of shape (N, 1) on ``vlm.device``: 1. for each LLM reply
        containing the character '1', else 0.
    """
    vlm_processor = kwargs['vlm_processor']
    vlm = kwargs['vlm']
    raw_images = kwargs['image']
    llm_processor = kwargs['llm_processor']
    llm = kwargs['llm']
    predicted = kwargs['predcted_cls']
    pil_images = [load_image(item) for item in raw_images]

    describe_prompt = """Answer the following questions:
    
    1. Is this image from an anime/cartoon style or a realistic setting?
    2. What is shown in this image, including the main objects and more details?
    
    Please ensure each answer is comprehensive and concise, while addressing the specific question being asked.
    """

    vlm_conversations = []
    for _ in predicted:
        vlm_conversations.append([
            {'role': 'user',
             'content': [{'type': 'image'}, {'type': 'text', 'text': describe_prompt}]},
        ])

    chat_texts = [
        vlm_processor.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
        for conversation in vlm_conversations
    ]
    batch = vlm_processor(images=pil_images, text=chat_texts,
                          return_tensors="pt", padding=True).to(vlm.device)
    full_ids = vlm.generate(**batch, max_new_tokens=128)
    # Keep only the newly generated tokens (drop the echoed prompt).
    new_ids = [sequence[len(prompt_ids):] for prompt_ids, sequence in zip(batch.input_ids, full_ids)]
    descriptions = vlm_processor.batch_decode(new_ids, skip_special_tokens=True)

    judge_prompt = """
     It's part of a description of a particular picture which states whether the picture is from a comic or reality. By the source of the picture you, then do you think his description makes common sense? Output 1 (yes) 0 (no)    
    """

    llm_conversations = [[{'role': 'user', 'content': f'{description} \n {judge_prompt}'}]
                         for description in descriptions]
    llm_input = llm_processor.apply_chat_template(
        llm_conversations, add_generation_prompt=True, return_tensors="pt"
    ).to(llm.device)
    llm_out = llm.generate(input_ids=llm_input, max_new_tokens=10)
    replies = get_llm_response(llm_processor.batch_decode(llm_out, skip_special_tokens=True))

    flags = [1. if '1' in reply else 0. for reply in replies]
    return torch.tensor(flags, device=vlm.device).unsqueeze(1)

def descriptive_similarity(**kwargs):
    """Rule: does the VLM's image description match the LLM's class description?

    The VLM briefly describes the main object of each image; the LLM
    describes the general features of the predicted class. The pair is
    compared with BERTScore F1 (``bert-large-uncased``).

    Keyword Args:
        vlm_processor, vlm: Vision-language model and its processor.
        image: Items passed to ``load_image`` (paths or file objects).
        llm_processor, llm: Text-only model and its chat-template processor.
        predcted_cls: List of predicted class-name strings.

    Returns:
        Float tensor of shape (N, 1) on ``vlm.device``: 1. where F1 > 0.55,
        else 0.; N is the length of the shorter of the two text lists.
    """
    vlm_processor = kwargs['vlm_processor']
    vlm = kwargs['vlm']
    image = kwargs['image']
    llm_processor = kwargs['llm_processor']
    llm = kwargs['llm']
    cls = kwargs['predcted_cls']
    # Drop periods and standalone articles ("a", "an", "A", "An") without
    # corrupting words that merely end in "a"/"an" ("panda bear").
    cls = [re.sub(r"\b[Aa]n?\s+", "", c.replace('.', '')) for c in cls]

    image = [load_image(i) for i in image]

    prompt = r"Briefly describe common characteristics of the main object in the image."

    vlm_messages = [
        [{'role': 'user',
          'content': [
              {'type': 'image'},
              {'type': 'text', 'text': prompt}
          ]
          }] for _ in cls
    ]
    texts = [vlm_processor.apply_chat_template(message, tokenize=False, add_generation_prompt=True)
             for message in vlm_messages]
    inputs = vlm_processor(images=image,
                           text=texts,
                           return_tensors="pt",
                           padding=True).to(vlm.device)
    generated_ids = vlm.generate(**inputs, max_new_tokens=50)
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    generated_texts = vlm_processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True)

    llm_messages = [[{'role': 'user',
                      'content': f'The main object in the image is classified as {c}, Describe the general features of this object.'}]
                    for c in cls]
    llm_prompt = llm_processor.apply_chat_template(llm_messages, add_generation_prompt=True,
                                                   return_tensors="pt").to(llm.device)
    out = llm.generate(input_ids=llm_prompt, max_new_tokens=50)
    response = get_llm_response(llm_processor.batch_decode(out, skip_special_tokens=True))

    # Score every (VLM, LLM) pair in a single call: bert_score.score loads
    # the scoring model on each call, so scoring pair-by-pair reloaded
    # bert-large-uncased once per sample. With idf=False (the default) the
    # per-pair scores do not depend on the other pairs in the batch.
    pairs = list(zip(generated_texts, response))
    activate_vector = []
    if pairs:
        candidates = [v for v, _ in pairs]
        references = [l for _, l in pairs]
        with torch.inference_mode():
            _, _, f1 = score(candidates, references, model_type='bert-large-uncased', lang='en')
        activate_vector = [1. if above else 0. for above in (f1 > 0.55).tolist()]

    return torch.tensor(activate_vector, device=vlm.device).unsqueeze(1)

def judging_the_species(**kwargs):
    """Rule: can the LLM pin down an exact species/breed from the VLM's description?

    The VLM describes the physical appearance of each image's main object;
    the LLM is asked whether those features determine the exact species or
    breed, answering yes or no.

    Keyword Args:
        vlm_processor, vlm: Vision-language model and its processor.
        image: Items passed to ``load_image`` (paths or file objects).
        llm_processor, llm: Text-only model and its chat-template processor.
        predcted_cls: Only iterated to size the VLM batch.

    Returns:
        Float tensor of shape (N, 1) on ``vlm.device``: 1. where the first
        standalone "yes"/"no" in the LLM reply is "yes", else 0.
    """
    vlm_processor = kwargs['vlm_processor']
    vlm = kwargs['vlm']
    image = kwargs['image']
    llm_processor = kwargs['llm_processor']
    llm = kwargs['llm']
    # Only the number of predictions is used here, not the class names.
    cls = kwargs['predcted_cls']

    image = [load_image(i) for i in image]

    prompt = r"Describe the physical appearance of the main object in this image in detail. Provide the features only, without extra explanations."

    vlm_messages = [
        [{'role': 'user',
          'content': [
              {'type': 'image'},
              {'type': 'text', 'text': prompt}
          ]
          }] for _ in cls
    ]
    texts = [vlm_processor.apply_chat_template(message, tokenize=False, add_generation_prompt=True)
             for message in vlm_messages]
    inputs = vlm_processor(images=image,
                           text=texts,
                           return_tensors="pt",
                           padding=True).to(vlm.device)
    generated_ids = vlm.generate(**inputs, max_new_tokens=50)
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    generated_texts = vlm_processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True)

    llm_messages = [[{'role': 'user',
                      'content': f'Based on the given general features of the object: {c} Can you determine its exact species or breed? Answer only \'yes\' or \'no\'.'}]
                    for c in generated_texts]
    llm_prompt = llm_processor.apply_chat_template(llm_messages, add_generation_prompt=True,
                                                   return_tensors="pt").to(llm.device)
    out = llm.generate(input_ids=llm_prompt, max_new_tokens=50)
    response = get_llm_response(llm_processor.batch_decode(out, skip_special_tokens=True))

    # Whole-word match: a plain substring test counts "eyes" as "yes" and
    # "not"/"know" as "no". The first standalone yes/no is the verdict.
    # NOTE(review): assumes get_llm_response returns only the assistant's
    # reply, not the echoed prompt (which itself contains 'yes') - confirm.
    verdict_re = re.compile(r"\b(yes|no)\b", re.IGNORECASE)
    activate_vector = []
    for l in response:
        verdict = verdict_re.search(l)
        activate_vector.append(1. if verdict and verdict.group(1).lower() == 'yes' else 0.)

    return torch.tensor(activate_vector, device=vlm.device).unsqueeze(1)

def judging_the_species2(**kwargs):
    """Rule: does the LLM accept the VLM's distinctive features as true of the predicted class?

    The VLM lists distinctive physical characteristics of each image's
    object; the LLM is asked whether all members of the predicted class have
    those characteristics, answering yes or no.

    Keyword Args:
        vlm_processor, vlm: Vision-language model and its processor.
        image: Items passed to ``load_image`` (paths or file objects).
        llm_processor, llm: Text-only model and its chat-template processor.
        predcted_cls: List of predicted class-name strings.

    Returns:
        Float tensor of shape (N, 1) on ``vlm.device``: 1. where the first
        standalone "yes"/"no" in the LLM reply is "yes", else 0.
    """
    vlm_processor = kwargs['vlm_processor']
    vlm = kwargs['vlm']
    image = kwargs['image']
    llm_processor = kwargs['llm_processor']
    llm = kwargs['llm']
    cls = kwargs['predcted_cls']
    # Drop periods and standalone articles ("a", "an", "A", "An") without
    # corrupting words that merely end in "a"/"an" ("panda bear").
    cls = [re.sub(r"\b[Aa]n?\s+", "", c.replace('.', '')) for c in cls]

    image = [load_image(i) for i in image]

    prompt = r"Look at the object in the image. Describe its distinctive physical characteristics, including size, color, shape, and any unique features that distinguish it from similar species. Avoid general terms, be as specific as possible."

    vlm_messages = [
        [{'role': 'user',
          'content': [
              {'type': 'image'},
              {'type': 'text', 'text': prompt}
          ]
          }] for _ in cls
    ]
    texts = [vlm_processor.apply_chat_template(message, tokenize=False, add_generation_prompt=True)
             for message in vlm_messages]
    inputs = vlm_processor(images=image,
                           text=texts,
                           return_tensors="pt",
                           padding=True).to(vlm.device)
    generated_ids = vlm.generate(**inputs, max_new_tokens=50)
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    generated_texts = vlm_processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True)

    llm_messages = [[{'role': 'user',
                      'content': f'Based on the features of the object: {g} Do all {c} have these characteristics? Anwser only yes or no.'}]
                    for g, c in zip(generated_texts, cls)]
    llm_prompt = llm_processor.apply_chat_template(llm_messages, add_generation_prompt=True,
                                                   return_tensors="pt").to(llm.device)
    out = llm.generate(input_ids=llm_prompt, max_new_tokens=50)
    response = get_llm_response(llm_processor.batch_decode(out, skip_special_tokens=True))

    # Whole-word match: a plain substring test counts "eyes" as "yes" and
    # "not"/"know" as "no". The first standalone yes/no is the verdict.
    verdict_re = re.compile(r"\b(yes|no)\b", re.IGNORECASE)
    activate_vector = []
    for l in response:
        verdict = verdict_re.search(l)
        activate_vector.append(1. if verdict and verdict.group(1).lower() == 'yes' else 0.)

    return torch.tensor(activate_vector, device=vlm.device).unsqueeze(1)

def qa_rules(**kwargs):
    """Rule features for one multiple-choice VQA sample (the 'realworld' set).

    Only ``image[0]`` (for captioning/detection) and ``lang_x[0]`` (for the
    question) are read, so this assumes a batch of one.

    Steps:
      1. Caption the image with the detector model ('<MORE_DETAILED_CAPTION>')
         and ask the LLM to list the entities in that caption.
      2. Ask the LLM to list the entities needed to answer the question.
      3. For every distinct entity, ask the VLM how many there are and compare
         that number with the count of open-vocabulary detections; collect
         the detection count, box areas (assuming [x1, y1, x2, y2] boxes) and,
         for 1-4 detections, the VLM's colour answer.
      4. Ask the LLM to answer the question from that information and
         fuzzy-match its answer against the model's prediction.

    Keyword Args:
        image: Items passed to ``load_image``.
        other_vlm, other_processor: Detector model/processor used via run_example.
        llm, llm_processor: Text-only model and its chat-template processor.
        lang_x: Conversations; the last 'text' part of ``lang_x[0]['content']``
            is the question.
        vlm, vlm_processor: Vision-language model and its processor.
        predcted_cls: The model's answer - a string, or a list/tuple whose
            first item is used.

    Returns:
        Float tensor of shape (1, 10) on ``vlm.device``: nine count-agreement
        flags (zero-padded or truncated to nine entities) followed by 1. when
        ``fuzz.ratio`` between the LLM's answer and the prediction exceeds 75.

    Raises:
        ValueError: if ``lang_x[0]['content']`` has no 'text' part.
    """
    images = [load_image(i) for i in kwargs['image']]
    detect_model, detect_processor = kwargs['other_vlm'], kwargs['other_processor']
    llm, llm_processor = kwargs['llm'], kwargs['llm_processor']
    vlm, vlm_processor = kwargs['vlm'], kwargs['vlm_processor']

    def _ask_llm(content):
        """Send one user message to the LLM and return its parsed reply."""
        prompt_ids = llm_processor.apply_chat_template(
            [[{"role": "user", "content": content}]],
            add_generation_prompt=True, return_tensors="pt").to(llm.device)
        out = llm.generate(input_ids=prompt_ids, max_new_tokens=100)
        return get_llm_response(llm_processor.batch_decode(out, skip_special_tokens=True))[0]

    def _ask_vlm(text):
        """Ask the VLM one question about the image(s) and return its decoded reply."""
        message = [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": text}]}]
        chat = [vlm_processor.apply_chat_template(message, tokenize=False, add_generation_prompt=True)]
        inputs = vlm_processor(text=chat, images=images, padding=True, return_tensors="pt").to(vlm.device)
        generated_ids = vlm.generate(**inputs, max_new_tokens=100)
        trimmed = [out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]
        return vlm_processor.batch_decode(
            trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=True
        )[0]

    def _split_entities(text):
        """Split an 'a.b.c' LLM answer into stripped names, dropping blanks and 'None'."""
        names = (part.strip() for part in text.split('.'))
        # The extraction prompts answer "None" when there is nothing to list.
        return [name for name in names if name and name.lower() != 'none']

    def _parse_count(text):
        """Parse a VLM count answer such as '3', ' 3' or '3.'; None if not an integer."""
        try:
            return int(text.strip().rstrip('.'))
        except ValueError:
            return None

    caption = run_example(images[0], '<MORE_DETAILED_CAPTION>', detect_model, detect_processor)
    information = caption['<MORE_DETAILED_CAPTION>']

    PROMPT_TEMPLATE = '''Given a sentence, extract the entities within the sentence for me. 
    Extract the common objects and summarize them as general categories without repetition, merge essentially similar objects.
    Avoid extracting abstract or non-specific entities. 
    Extract entity in the singular form. Output all the extracted types of items in one line and separate each object type with a period. If there is nothing to output, then output a single "None".

    Examples:
    Sentence:
    The image depicts a man laying on the ground next to a motorcycle, which appears to have been involved in a crash.

    Output:
    man.motorcycle

    Sentence:
    There are a few people around, including one person standing close to the motorcyclist and another person further away.

    Output:
    person.motorcyclist

    Sentence:
    No, there is no car in the image.

    Output:
    car

    Sentence:
    The image depicts a group of animals, with a black dog, a white kitten, and a gray cat, sitting on a bed.

    Output:
    dog.cat.bed

    Sentence:
    The image shows that polar bears and german shepherds.

    Output:
    polar bear.german shepherd

    Sentence:
    There are some Smart TVs and electric cars in the image.

    Output:
    smart tv.electric car

    Sentence:
    {sentence}

    Output:'''

    caption_entities = _ask_llm(PROMPT_TEMPLATE.format(sentence=information))

    EXTRACT_PROMPT = '''Given Question and its correspond answer options, extract the entities that can solve the problem based on the combination of the question and options.
    Extract entities in singular form. Output all the extracted item types in one line, separated by a period. If there is nothing to output, return "None".

    Examples:
    Question:
    Which direction are the knobs currently in?
    A. The knobs are currently in the downward facing position.
    B. The knobs are currently in the right facing position.
    C. The knobs are currently in the upward facing position.
    Please answer directly with only the letter of the correct option and nothing else.

    Output:
    downward facing knob.right facing knob.upward facing knob

    Question:
    How many apples are in this image?
    A. 2.
    B. 4.
    C. 5.
    Please answer directly with only the letter of the correct option and nothing else.

    Output:
    apple

    Question:
    Is the closest red car to us closer than the closest semi truck?
    A. Yes.
    B. No.
    Please answer directly with only the letter of the correct option and nothing else.

    Output:
    red car.semi truck

    Question:
    What color is the closest cup to the camera?
    A. Green.
    B. Blue.
    Please answer directly with only the letter of the correct option and nothing else.

    Output:
    green cup.blue cup.camera

    Question:
    What is the main object in the center of the image?
    Please answer directly with only the letter of the correct option and nothing else.
    A. Apple.
    B. Banana.

    Output:
    apple.banana

    Question:
    Is the large tree to the left or right of the house?
    A. Left.
    B. Right.
    Please answer directly with only the letter of the correct option and nothing else.

    Output:
    tree.house

    Question:
    What is the shape of the object on the table?
    A. Square.
    B. Circular.
    C. Triangular.
    Please answer directly with only the letter of the correct option and nothing else.

    Output:
    square object.circular object.triangular object

    Question:
    What object is in the background?
    A. Lamp.
    B. Chair.
    Please answer directly with only the letter of the correct option and nothing else.

    Output:
    lamp.chair

    Question:
    Where is the chair in this image?
    A. Left.
    B. Center.
    C. Right.
    Please answer directly with only the letter of the correct option and nothing else.

    Output:
    left chair.center chair.right chair

    Question:
    What is the object next to the door?
    A. Book.
    B. Shoe.
    C. Bag.
    Please answer directly with only the letter of the correct option and nothing else.

    Output:
    book.shoe.bag

    Question:
    {question}

    Output:'''

    # The question is the last text part of the first conversation.
    question = None
    for part in kwargs['lang_x'][0]['content']:
        if part['type'] == 'text':
            question = part['text']
    if question is None:
        raise ValueError("lang_x[0]['content'] has no 'text' part to use as the question")
    print('Question:', question)

    question_entities = _ask_llm(EXTRACT_PROMPT.format(question=question))

    # Caption entities first, then question entities; duplicates keep their first position.
    entities = list(dict.fromkeys(_split_entities(caption_entities) + _split_entities(question_entities)))

    COUNT_PROMPT = r"How many {entity} in this image? Answer only the number. Do not include any additional details or explanations."
    COLOR_PROMPT = r"What is the color of {entity} in this image? Answer only color, nothing else. If more colors are present, join the results by a period. Like: red.blue.black"

    activate = []
    useful_infor = {}
    for entity in entities:
        vlm_count = _parse_count(_ask_vlm(COUNT_PROMPT.format(entity=entity)))

        detection = run_example(images[0], '<OPEN_VOCABULARY_DETECTION>', detect_model,
                                detect_processor, entity)['<OPEN_VOCABULARY_DETECTION>']
        count = len(detection['bboxes_labels'])
        if count:
            size = [(box[2] - box[0]) * (box[3] - box[1]) for box in detection['bboxes'][:count]]
        else:
            size = ['None']

        # Count agreement between the VLM and the detector (None never equals an int).
        activate.append(1. if vlm_count == count else 0.)

        # The colour is only reported for 1-4 detections, so only ask the VLM then.
        colour = _ask_vlm(COLOR_PROMPT.format(entity=entity)) if 0 < count < 5 else 'None'
        useful_infor[entity] = [{'count': count}, {'size': size}, {'color': colour}]

    # Exactly nine count flags: pad with 0. or truncate.
    activate = (activate + [0.] * 9)[:9]

    his_answer = kwargs['predcted_cls']
    # fuzz.ratio on a list compares against its repr (e.g. "['A']"), so a
    # one-element prediction list never matched; compare its first item.
    if isinstance(his_answer, (list, tuple)):
        his_answer = his_answer[0] if his_answer else ''

    answer_promt = """Here we have the key information from the image:
    {information}
    
    The information is a dictionary whose keys are the entities in the image and whose values represent the current number of entities and their colours.
    
    Please answer the following question based on the image information:
    {question}
    """

    llm_answer = extract_answer(_ask_llm(answer_promt.format(information=useful_infor, question=question)))
    activate.append(1. if fuzz.ratio(llm_answer, his_answer) > 75 else 0.)

    return torch.tensor(activate, device=vlm.device).unsqueeze(0)

# Rule functions to evaluate for each dataset name. Every rule takes keyword
# arguments only and returns a torch tensor of activation features.
rules = dict(
    cifar10=[
        n_shot_top5_distr,
        is_cls_animal,
        where_live,
        image_feature,
        descriptive_similarity,
    ],
    animals=[
        n_shot_top5_distr,
        where_live,
        image_feature,
        descriptive_similarity,
        judging_the_species,
        judging_the_species2,
    ],
    realworld=[n_shot_top5_distr, qa_rules],
)