import torch
from lavis.models import load_model_and_preprocess
from PIL import Image
from tqdm import tqdm
from torchvision import transforms
import torch.nn as nn
import logging


class BLIPAttackerHelper:
    """Wrap a LAVIS model to build, evaluate and score attack samples.

    The model and its processors are loaded via ``load_model_and_preprocess``.
    The helper builds the sample dicts fed to the model's forward pass, runs
    greedy generation (``num_beams=1``), and combines the model's losses for
    trigger-free and triggered inputs.
    """

    def __init__(self, model_name, model_type, device) -> None:
        """Load the model/processors and set up the shared image normalization.

        Args:
            model_name: Model name forwarded to ``load_model_and_preprocess``.
            model_type: Model type forwarded to ``load_model_and_preprocess``.
            device: Device the model is loaded on; ``eval_sample`` also moves
                file-loaded images to it.
        """
        self.model, self.vis_processors, self.txt_processors = load_model_and_preprocess(
            name=model_name, model_type=model_type, is_eval=True, device=device
        )

        # The keyword must be ``reduction``: ``reduce`` is the deprecated
        # boolean flag, and ``reduce='batchmean'`` silently degraded to
        # ``reduction='mean'`` (with a deprecation warning).
        # NOTE(review): not used by any method of this class; presumably used
        # by external callers — confirm.
        self.kl_loss = torch.nn.KLDivLoss(reduction='batchmean')
        self.device = device

        # Per-channel RGB mean/std. These appear to be the CLIP/BLIP image
        # normalization constants, presumably chosen to match
        # ``vis_processors['eval']`` — confirm. Images given to ``get_sample``
        # are normalized with them there.
        mean = (0.48145466, 0.4578275, 0.40821073)
        std = (0.26862954, 0.26130258, 0.27577711)

        self.normalize = transforms.Normalize(mean, std)

    def get_sample(self, item, image_ori, image_adv, trigger, target_answer):
        """Build the clean, adversarial-no-trigger and adversarial-triggered samples.

        Each question is answered by the model on its (normalized) clean image.
        Those answers become the ``text_output`` labels of both the clean
        sample and the trigger-free adversarial sample. The trigger-prefixed
        questions are paired with ``target_answer``.

        Args:
            item: Mapping whose ``'text_input'`` entry is a sequence of questions.
            image_ori: Clean images indexed per question. ``image_ori[i]`` is
                normalized and given a leading batch dim of 1 (presumably an
                un-normalized (B, 3, H, W) tensor — confirm).
            image_adv: Adversarial images, normalized as a whole.
            trigger: Text prepended (followed by one space) to each processed
                question.
            target_answer: Label used for every triggered question.

        Returns:
            ``(sample_ori, sample_without_trigger, sample_with_trigger)``, each a
            dict with keys ``"image"``, ``"text_input"`` and ``"text_output"``.
        """
        text_input = [self.txt_processors['eval'](que) for que in item['text_input']]
        text_input_trigger = [trigger + ' ' + que for que in text_input]
        target_answers = [target_answer] * len(text_input)

        # The generated answers are only used as text labels, so no gradient is
        # needed here; disabling autograd avoids building graphs during
        # generation.
        raw_answers = []
        with torch.no_grad():
            for i, question in enumerate(text_input):
                # e.g. [3, 224, 224] -> [1, 3, 224, 224]
                img_ori = self.normalize(image_ori[i]).unsqueeze(0)
                raw_answers.append(self.eval_sample(img_ori, [question])[0])

        ori_outputs = [self.txt_processors['eval'](ans) for ans in raw_answers]
        logging.info('ori_outputs: %s', ori_outputs)
        logging.info('target_answer: %s', target_answers)

        sample_ori = {
            "image": self.normalize(image_ori),
            "text_input": text_input,
            "text_output": ori_outputs,
        }

        sample_without_trigger = {
            "image": self.normalize(image_adv),
            "text_input": text_input,  # without trigger
            "text_output": ori_outputs,
        }

        sample_with_trigger = {
            "image": self.normalize(image_adv),
            "text_input": text_input_trigger,  # with trigger
            "text_output": target_answers,
        }

        return sample_ori, sample_without_trigger, sample_with_trigger

    def eval_sample(self, img, text_input):
        """Run greedy generation (``num_beams=1``) for one image and its prompt(s).

        Args:
            img: Either a path to an image file, or an image tensor.
                - A path is opened as RGB, run through ``vis_processors['eval']``,
                  given a batch dim and moved to ``self.device``.
                - A tensor is used as-is. Callers such as ``get_sample`` must
                  normalize it first; it is also not moved to any device here.
            text_input: List of prompt strings.

        Returns:
            The result of the model's ``generate`` call (``get_sample`` takes
            element 0 of it as the answer text).
        """
        if isinstance(img, str):
            with Image.open(img) as raw_image:
                rgb_image = raw_image.convert('RGB')
            img = self.vis_processors['eval'](rgb_image).unsqueeze(0).to(self.device)

        # Tensor inputs are not normalized here; get_sample already did it.
        sample = {
            "image": img,
            "prompt": text_input
        }

        # Models wrapped for (Distributed)DataParallel expose the real model as
        # ``.module``; ``generate`` lives on the unwrapped model.
        model = self.model.module if hasattr(self.model, 'module') else self.model
        return model.generate(samples=sample, num_beams=1)

    def get_loss(self, sample_ori, sample_without_trigger, sample_with_trigger, loss_without_trigger_weight, loss_with_trigger_weight, loss_type):
        """Compute the weighted, negated model losses for both adversarial samples.

        Args:
            sample_ori: Unused; kept for interface compatibility.
            sample_without_trigger: Sample dict passed to the model's forward pass.
            sample_with_trigger: Sample dict passed to the model's forward pass.
            loss_without_trigger_weight: Multiplier for the trigger-free term.
            loss_with_trigger_weight: Multiplier for the triggered term.
            loss_type: 1 = trigger-free term only, 2 = triggered term only,
                3 = sum of both.

        Returns:
            ``(loss, loss_without_trigger, loss_with_trigger)``. Each term is
            ``-model(sample)['loss'] * weight``. The negation presumably lets
            the caller run gradient ascent — confirm.

        Raises:
            ValueError: If ``loss_type`` is not 1, 2 or 3.
        """
        # Validate before the (expensive) forward passes; previously an invalid
        # value surfaced as an UnboundLocalError after both passes had run.
        if loss_type not in (1, 2, 3):
            raise ValueError(f'loss_type must be 1, 2 or 3, got {loss_type!r}')

        # Both forward passes always run because both terms are returned.
        # Loss without trigger
        output_without_trigger = self.model(sample_without_trigger)
        loss_without_trigger = - output_without_trigger['loss'] * loss_without_trigger_weight

        # Loss with trigger
        output_with_trigger = self.model(sample_with_trigger)
        loss_with_trigger = - output_with_trigger['loss'] * loss_with_trigger_weight

        if loss_type == 1:  # without_trigger_loss
            loss = loss_without_trigger
        elif loss_type == 2:  # with_trigger_loss
            loss = loss_with_trigger
        else:  # loss_type == 3: both
            loss = loss_without_trigger + loss_with_trigger

        return loss, loss_without_trigger, loss_with_trigger