import logging
from PIL import Image
import torch.nn.functional as F

from .blip_attacker_helper import BLIPAttackerHelper


class BLIP2Attacker(BLIPAttackerHelper):
    """Attacker helper that builds samples and queries answers for BLIP-2.

    Relies on attributes that are not set in this class and are presumably
    provided by ``BLIPAttackerHelper`` (verify against the base class):
    ``self.model``, ``self.device``, ``self.normalize``,
    ``self.txt_processors`` and ``self.vis_processors``.
    """

    def __init__(self, model_name, model_type, device) -> None:
        """Forward the model configuration to the base helper and log start-up."""
        super().__init__(model_name, model_type, device)
        logging.info('******** Init BLIP2Attacker ********')

    def get_sample(self, item, image_ori, image_adv, trigger, target_answer):
        """Build the clean, trigger-free and triggered samples for one batch.

        The model is first queried on every clean image/question pair; those
        answers become the expected output for both the clean sample and the
        adversarial sample without trigger. The adversarial sample with
        trigger is paired with ``target_answer`` instead.

        Args:
            item: Mapping whose ``'text_input'`` entry is an iterable of
                question strings.
            image_ori: Indexable batch of clean images, one per question
                (per the original inline note each is [3, 224, 224] --
                TODO confirm). Must hold at least as many images as there
                are questions.
            image_adv: Adversarial image batch; it is normalized as a whole.
            trigger: Text prepended (followed by one space) to each question.
            target_answer: Answer associated with every triggered question.

        Returns:
            Tuple ``(sample_ori, sample_without_trigger, sample_with_trigger)``
            of dicts with keys ``"image"``, ``"text_input"`` and
            ``"text_output"``.
        """
        process_text = self.txt_processors['eval']
        questions = [process_text(que) for que in item['text_input']]

        questions_with_trigger = [trigger + ' ' + que for que in questions]
        # One copy of the target per question (all entries are the same object).
        target_answers = [target_answer] * len(questions)

        raw_answers = []
        for idx, question in enumerate(questions):
            # Add a leading batch axis so a single image can be fed to the model.
            clean_image = self.normalize(image_ori[idx]).unsqueeze(0)
            # eval_sample expects a list of questions, hence the one-item list.
            answer = self.eval_sample(clean_image, [question])
            raw_answers.append(answer[0])
        ori_outputs = [process_text(ans) for ans in raw_answers]
        logging.info(f'ori_outputs: {ori_outputs}')
        logging.info(f'target_answer: {target_answers}')

        sample_ori = {
            "image": self.normalize(image_ori),
            "text_input": questions,
            "text_output": ori_outputs,
        }

        # Same questions and expected answers as the clean sample, but on the
        # adversarial image.
        sample_without_trigger = {
            "image": self.normalize(image_adv),
            "text_input": questions,
            "text_output": ori_outputs,
        }

        # Triggered questions on the adversarial image map to the target answer.
        # normalize() is called again on purpose so the two adversarial samples
        # do not share one tensor object.
        sample_with_trigger = {
            "image": self.normalize(image_adv),
            "text_input": questions_with_trigger,
            "text_output": target_answers,
        }

        return sample_ori, sample_without_trigger, sample_with_trigger

    def eval_sample(self, img, text_input):
        """Generate the model's answer for a single image/question pair.

        Args:
            img: Either a path to an image file, which is opened, converted to
                RGB, run through the eval visual processor, given a batch axis
                and moved to ``self.device``; or an already prepared image
                batch, which is used as is.
            text_input: Sequence of question strings, of which only the first
                is used. A bare string is accepted and treated as one question
                (previously only its first character would have been used).

        Returns:
            The value returned by the model's ``generate`` call;
            ``get_sample`` reads its first element as the answer.
        """
        if isinstance(text_input, str):
            # Guard against indexing into the string and prompting with one char.
            text_input = [text_input]

        if isinstance(img, str):
            img = Image.open(img).convert('RGB')
            img = self.vis_processors['eval'](img).unsqueeze(0).to(self.device)  # batch_size

        # No normalization here: images coming from get_sample are already
        # normalized by the caller.
        prompt = f"Question: {text_input[0]} Short Answer:"
        logging.info(f'qs: {prompt}')
        sample = {
            "image": img,
            "prompt": prompt,
        }

        # Unwrap a wrapper that exposes the real model as ``.module``
        # (presumably DataParallel / DistributedDataParallel -- confirm).
        model = getattr(self.model, 'module', self.model)
        output = model.generate(samples=sample, num_beams=1)
        logging.info(f'{prompt} {output}')
        return output