import logging
from PIL import Image
import torch.nn.functional as F
from torchvision import transforms
import random

from minigpt4.common.config import Config
from minigpt4.common.registry import registry


class MiniGPTAttacker:
    """Wrap a MiniGPT-4 model to build and score trigger-based attack samples.

    The model class is resolved through the minigpt4 ``registry`` from
    ``Config(args).model_cfg``, moved to ``args.device`` and put in eval mode.
    """

    def __init__(self, args) -> None:
        """Load the config, the model, the instruction pool and the text processor.

        Args:
            args: Parsed arguments. ``Config(args)`` is built from them, and
                ``args.batch_size`` / ``args.device`` are read directly.
        """
        logging.info('******** Init MiniGPT4Attacker ********')

        # Per-channel normalisation applied to raw images before they reach
        # the model. These values presumably match the vision encoder's
        # (CLIP-style) preprocessing statistics — confirm against the config.
        mean = (0.48145466, 0.4578275, 0.40821073)
        std = (0.26862954, 0.26130258, 0.27577711)
        self.normalize = transforms.Normalize(mean, std)

        cfg = Config(args)
        logging.info('cfg')

        print('batch_size: ', args.batch_size)
        logging.info(f'device: {args.device}')

        self.model_config = cfg.model_cfg
        model_cls = registry.get_model_class(self.model_config.arch)
        self.model = model_cls.from_config(self.model_config)
        logging.info(f'model_config: {self.model_config}')
        self.model = self.model.to(args.device)
        self.model.eval()

        # Each question is wrapped in one template drawn at random from this pool.
        self.instruction_pool = [
            "[vqa] Based on the image, respond to this question with a short answer: {}",
        ]

        self.args = args

        text_processor_cfg = cfg.datasets_cfg.cc_sbu_align.text_processor.train
        self.text_processor = registry.get_processor_class(text_processor_cfg.name).from_config(text_processor_cfg)

    def _build_prompts(self, questions):
        """Turn each question into a full model prompt (template + image tag + prompt_template)."""
        prompts = []
        for question in questions:
            # One random.choice per question, in order, exactly as before.
            instruction = random.choice(self.instruction_pool).format(question)
            with_image = "<Img><ImageHere></Img> {} ".format(instruction)
            prompts.append(self.model_config.prompt_template.format(with_image))
        return prompts

    def get_sample(self, item, image_ori, image_adv, trigger, target_answer):
        """Build the clean, adversarial-without-trigger and adversarial-with-trigger samples.

        Args:
            item: Mapping whose ``'text_input'`` entry is a list of question strings.
            image_ori: Clean images, indexable per question; each element is
                normalised and unsqueezed to a batch of one (e.g. [3, 224, 224]
                -> [1, 3, 224, 224]) for generation.
            image_adv: Adversarial images, normalised as a whole.
            trigger: String prepended (with a space) to every question for the
                with-trigger sample.
            target_answer: Answer string the with-trigger sample is paired with;
                it is passed through ``self.text_processor`` once per question.

        Returns:
            Tuple ``(sample_ori, sample_without_trigger, sample_with_trigger)``,
            each a dict with ``'image'``, ``'instruction_input'`` and ``'answer'``.
            The first two use the model's own answers on the clean images as
            ``'answer'``; the third uses the processed ``target_answer``.
        """
        text_input = item['text_input']
        text_input_trigger = [trigger + ' ' + que for que in text_input]

        target_answer = [self.text_processor(target_answer) for _ in text_input]

        # Prompts without the trigger.
        text_input_wo = self._build_prompts(text_input)
        logging.info(f'text_input_wo: {text_input_wo}')

        # Reference answers: what the model says about each clean image.
        ori_outputs = []
        for i, instruction in enumerate(text_input_wo):
            img_ori = self.normalize(image_ori[i]).unsqueeze(0)  # [3, 224, 224] -> [1, 3, 224, 224]
            ori_out = self.eval_sample(img_ori, instruction)
            # NOTE(review): the processor is applied to the raw generate() output
            # and the result is then indexed with [0]. If generate() returns a plain
            # string, [0] keeps only its first character; if it returns a list,
            # this relies on text_processor accepting a list. Confirm the intended order.
            ori_out = self.text_processor(ori_out)
            ori_outputs.append(ori_out[0])
        logging.info(f'ori_outputs: {ori_outputs}')
        logging.info(f'target_answer: {target_answer}')

        # Prompts with the trigger (built after generation, preserving RNG order).
        text_input_w = self._build_prompts(text_input_trigger)
        logging.info(f'text_input_w: {text_input_w}')

        sample_ori = {
            "image": self.normalize(image_ori),
            "instruction_input": text_input_wo,
            "answer": ori_outputs,
        }

        sample_without_trigger = {
            "image": self.normalize(image_adv),
            "instruction_input": text_input_wo,  # without trigger
            "answer": ori_outputs,
        }

        sample_with_trigger = {
            "image": self.normalize(image_adv),
            "instruction_input": text_input_w,  # with trigger
            "answer": target_answer,
        }

        logging.info(f'sample_ori: {sample_ori}')
        logging.info(f'sample_without_trigger: {sample_without_trigger}')
        logging.info(f'sample_with_trigger: {sample_with_trigger}')

        return sample_ori, sample_without_trigger, sample_with_trigger

    def eval_sample(self, img, instruction):
        """Return ``generate`` output for one image/instruction, greedy, up to 1024 new tokens."""
        logging.info(f'*** Eval_sample *** instruction_input: {instruction}')

        # Unwrap a model wrapper that exposes the real model as `.module`
        # (presumably DataParallel / DistributedDataParallel).
        model = self.model.module if hasattr(self.model, 'module') else self.model
        return model.generate(img, instruction, max_new_tokens=1024, do_sample=False)

    def get_gt(self, item):
        """Placeholder; always returns None."""
        return

    def get_loss(self, sample_ori, sample_without_trigger, sample_with_trigger, loss_without_trigger_weight, loss_with_trigger_weight, loss_type):
        """Compute the weighted, negated model losses for the two adversarial samples.

        Args:
            sample_ori: Unused; kept for interface compatibility.
            sample_without_trigger: Batch dict passed to ``self.model``.
            sample_with_trigger: Batch dict passed to ``self.model``.
            loss_without_trigger_weight: Multiplier for the no-trigger loss.
            loss_with_trigger_weight: Multiplier for the with-trigger loss.
            loss_type: 1 = no-trigger loss only, 2 = with-trigger loss only,
                3 = their sum.

        Returns:
            ``(loss, loss_without_trigger, loss_with_trigger)``. Each component
            is ``-model(sample)['loss'] * weight``. The sign flip presumably
            suits a caller that ascends the returned value — confirm.

        Raises:
            ValueError: if ``loss_type`` is not 1, 2 or 3 (checked before any
                forward pass runs).
        """
        if loss_type not in (1, 2, 3):
            raise ValueError(f'loss_type must be 1, 2 or 3, got {loss_type!r}')

        # Both forward passes always run so both components can be returned.
        output_without_trigger = self.model(sample_without_trigger)
        loss_without_trigger = -output_without_trigger['loss'] * loss_without_trigger_weight

        output_with_trigger = self.model(sample_with_trigger)
        loss_with_trigger = -output_with_trigger['loss'] * loss_with_trigger_weight

        if loss_type == 1:
            loss = loss_without_trigger
        elif loss_type == 2:
            loss = loss_with_trigger
        else:
            loss = loss_without_trigger + loss_with_trigger

        return loss, loss_without_trigger, loss_with_trigger
    
    

