"""
 Copyright (c) 2022, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import logging
import os
import pickle
from minigpt4.conversation.conversation import CONV_VISION_minigptv2,CONV_VISION_Vicuna0,CONV_VISION_LLama2
import torch
import torch.distributed as dist
from minigpt4.common.dist_utils import get_rank, get_world_size, is_main_process, is_dist_avail_and_initialized
from minigpt4.common.logger import MetricLogger, SmoothedValue
from minigpt4.common.registry import registry
from minigpt4.datasets.data_utils import prepare_sample
import wandb
import re

from contextlib import contextmanager, nullcontext
import torch.nn as nn
import random
# Seed Python's ``random`` module once, at import time. BaseTask._train_inner_loop
# draws its random initial adversarial strings with random.choice(), so this
# fixes that sequence; it does not seed torch.
random.seed(42)

def contains_vehicle_words(input_string, words):
    """Return True if ``input_string`` contains any of ``words`` as a whole word.

    The entries of ``words`` are joined into one alternation that is anchored
    on regex word boundaries and matched case-insensitively, so ``"car"``
    matches ``"a Car."`` but not ``"carpet"``. Entries are inserted into the
    pattern unescaped, i.e. they are interpreted as regex fragments.

    Args:
        input_string: Text to search.
        words: Iterable of words (regex fragments) to look for; may be empty
            or None.

    Returns:
        bool: True when at least one entry matches, otherwise False. Always
        False when there is nothing to look for.
    """
    # Drop empty entries: an empty alternative matches at every word boundary
    # and would report a hit for practically any input.
    alternatives = [word for word in (words or ()) if word]
    if not alternatives:
        return False
    # Build a regular expression that matches any of the given words.
    pattern = r'\b(?:' + '|'.join(alternatives) + r')\b'
    return bool(re.search(pattern, input_string, re.IGNORECASE))
def prepare_texts(texts, conv_temp):
    """Render one prompt per input text from a conversation template.

    For every text a fresh copy of ``conv_temp`` is made, the text (prefixed
    with the ``<Img><ImageHere></Img>`` placeholder) is appended as a message
    of the first role, an empty (``None``) message is appended for the second
    role, and the resulting prompt string is collected.

    Args:
        texts: Iterable of question/instruction strings.
        conv_temp: Conversation template providing ``copy()``, ``roles``,
            ``append_message()`` and ``get_prompt()``. It is not modified.

    Returns:
        list: One prompt per input text, in input order.
    """
    prompts = []
    for text in texts:
        conv = conv_temp.copy()
        conv.append_message(conv.roles[0], '<Img><ImageHere></Img> {}'.format(text))
        # A None message leaves the second role's turn open for generation.
        conv.append_message(conv.roles[1], None)
        prompts.append(conv.get_prompt())
    return prompts
class BaseTask:
    def __init__(self, **kwargs):
        """Set up the task's conversation template and bookkeeping attributes.

        Keyword arguments are accepted but not used.
        """
        super().__init__()

        # Private copy of the LLaMA-2 vision template with the system prompt blanked.
        template = CONV_VISION_LLama2.copy()
        template.system = ""
        self.conv_temp = template

        self.inst_id_key = "instance_id"
        # Placeholder; build_model() replaces it with the config it receives.
        self.cfg = ""
        # build_datasets() appends element [2] of every eval sample here.
        self.sample_id_list = []

    @classmethod
    def setup_task(cls, **kwargs):
        """Alternate constructor returning a new instance of ``cls``.

        Keyword arguments are accepted but are not forwarded to the constructor.
        """
        task = cls()
        return task

    def build_model(self, cfg):
        """Build the model described by ``cfg.model_cfg``.

        The config is stored on the task as ``self.cfg``; the model class is
        looked up in the registry by ``cfg.model_cfg.arch`` and instantiated
        through its ``from_config`` constructor.

        Args:
            cfg: Config object exposing ``model_cfg`` (with an ``arch`` field).

        Returns:
            The object returned by ``model_cls.from_config(cfg.model_cfg)``.
        """
        self.cfg = cfg
        model_config = cfg.model_cfg
        arch_name = model_config.arch
        model_cls = registry.get_model_class(arch_name)
        return model_cls.from_config(model_config)

    def build_datasets(self, cfg):
        """
        Build the training datasets and the evaluation dataset listed in ``cfg.datasets_cfg``.

        Each entry is built through the builder class registered under its name.
        An entry whose name contains ``'eval'`` is treated as the evaluation
        dataset; every other entry is collected as a training dataset.

        Args:
            cfg (common.config.Config): Config exposing ``datasets_cfg``, a
                non-empty mapping from dataset name to that dataset's config.

        Returns:
            tuple: ``(datasets, eval_dataset)``. ``datasets`` maps each non-eval
            name to the object returned by its builder's ``build_datasets()``
            (its ``'train'`` entry gets ``name`` and, when configured,
            ``sample_ratio`` attributes). ``eval_dataset`` is the object built
            for the last ``'eval'`` entry, or None when there is none.

        Raises:
            AssertionError: If ``cfg.datasets_cfg`` is empty.
        """
        datasets_config = cfg.datasets_cfg
        assert len(datasets_config) > 0, "At least one dataset has to be specified."

        datasets = dict()
        eval_dataset = None

        for name in datasets_config:
            dataset_config = datasets_config[name]
            builder = registry.get_builder_class(name)(dataset_config)

            if 'eval' in name:
                eval_dataset = builder.build_datasets()
                # Remember element [2] of every eval sample
                # (presumably the sample/question id -- verify against the eval dataset).
                self.sample_id_list.extend(data[2] for data in eval_dataset)
                continue

            print('build',builder,name)
            dataset = builder.build_datasets()
            print(dataset)

            train_split = dataset['train']
            train_split.name = name
            if 'sample_ratio' in dataset_config:
                train_split.sample_ratio = dataset_config.sample_ratio

            datasets[name] = dataset

        return datasets,eval_dataset

    def train_step(self, model, samples, loss_mask_dict, output_attentions):
        """Run one forward pass and return the ``"loss"`` entry of its output.

        Args:
            model: Callable invoked as ``model(samples, loss_mask_dict, output_attentions)``
                and returning a mapping that contains ``"loss"``.
            samples: Batch passed through to the model.
            loss_mask_dict: Passed through to the model unchanged.
            output_attentions: Passed through to the model unchanged.

        Returns:
            The value stored under ``"loss"`` in the model output.
        """
        outputs = model(samples, loss_mask_dict, output_attentions)
        return outputs["loss"]


    def attack_step(self, model, samples, attack_prompt, show):
        """Return the attack loss for one batch.

        Calls ``model.forward_attack(samples, attack_prompt, show, ref=False)``
        and returns the ``'loss'`` entry of the result.
        """
        outputs = model.forward_attack(samples, attack_prompt, show, ref=False)
        return outputs['loss']
    def defend_step(self, model, ref_model, samples, attack_prompt, show):
        """Return the sum of the utility loss and the defence loss for one step.

        Two pairs of log-probabilities are obtained from ``model`` --
        ``forward_defend`` and ``forward_utility`` -- and each pair is turned
        into a loss with ``model.dpo_loss``.

        Args:
            model: Model providing ``forward_defend``, ``forward_utility`` and
                ``dpo_loss``.
            ref_model: Accepted but not used by this method.
            samples: Batch data forwarded to both forward calls.
            attack_prompt: Soft prompt forwarded to both forward calls.
            show: Flag forwarded to both forward calls.

        Returns:
            ``utility_loss + defend_loss``.
        """
        defend_chosen, defend_rejected = model.forward_defend(samples, attack_prompt, show)
        # NOTE(review): the defence term passes (rejected, chosen) -- the reverse
        # of the utility term below -- with ref=True. Confirm the swap is intended.
        defend_loss = model.dpo_loss(defend_rejected, defend_chosen, None, None, ref=True)

        utility_chosen, utility_rejected = model.forward_utility(samples, attack_prompt, show)
        utility_loss = model.dpo_loss(utility_chosen, utility_rejected, None, None, ref='sft')

        return utility_loss + defend_loss

    def valid_step(self, model, samples):
        """Process one evaluation batch; not implemented in this base class.

        ``evaluation`` calls this with ``model=`` and ``samples=`` and extends
        its result list with the return value, so an implementation should
        return an iterable of results.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def before_evaluation(self, model, dataset, **kwargs):
        """Forward the pre-evaluation hook to the model.

        Calls ``model.before_evaluation`` with the dataset and this task's
        concrete class as ``task_type``. Extra keyword arguments are ignored.
        """
        task_type = type(self)
        model.before_evaluation(dataset=dataset, task_type=task_type)

    def after_evaluation(self, **kwargs):
        """Post-evaluation hook; a no-op in this base class (keyword arguments are ignored)."""
        pass

    def inference_step(self):
        """Placeholder for an inference step; not implemented in this base class.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError

    def evaluation(self, model, data_loader, cuda_enabled=True):
        """Run ``valid_step`` on every batch of ``data_loader`` and gather the outputs.

        Each batch is passed through ``prepare_sample`` before ``valid_step``.
        Progress is logged every 10 batches. When torch.distributed is
        initialised, all processes synchronise on a barrier before returning.

        Args:
            model: Model handed to ``valid_step``.
            data_loader: Iterable of evaluation batches.
            cuda_enabled: Forwarded to ``prepare_sample``.

        Returns:
            list: Concatenation of the outputs of ``valid_step`` over all batches.
        """
        progress = MetricLogger(delimiter="  ")
        # TODO make it configurable
        log_interval = 10

        collected = []
        for batch in progress.log_every(data_loader, log_interval, "Evaluation"):
            batch = prepare_sample(batch, cuda_enabled=cuda_enabled)
            collected.extend(self.valid_step(model=model, samples=batch))

        if is_dist_avail_and_initialized():
            dist.barrier()

        return collected
    def unwrap_dist_model(self, model):
        # if self.use_distributed:
        return model.module
        # else:
        #     return model
    def train_epoch(
        self,
        epoch,
        model,
        ref_model,
        data_loader,
        optimizer,
        lr_scheduler,
        scaler=None,
        cuda_enabled=False,
        log_freq=50,
        loss_mask_dict=None,
        loss_mask_dict_bn=None,
        accum_grad_iters=1,
        eval_dataset=None,
        iter_per_epoch=None,
        slogan=None,
        tgt_words=None,
        gth_decision_label=None,
        output_attentions=None,
        tok_length=None,
    ):
        """Epoch-based entry point: run ``_train_inner_loop`` for one epoch.

        Every argument is forwarded by keyword to ``_train_inner_loop``; the
        only renaming is ``iter_per_epoch`` -> ``iters_per_epoch``.

        Returns:
            Whatever ``_train_inner_loop`` returns.
        """
        loop_kwargs = dict(
            epoch=epoch,
            iters_per_epoch=iter_per_epoch,
            model=model,
            ref_model=ref_model,
            data_loader=data_loader,
            optimizer=optimizer,
            scaler=scaler,
            lr_scheduler=lr_scheduler,
            log_freq=log_freq,
            loss_mask_dict=loss_mask_dict,
            loss_mask_dict_bn=loss_mask_dict_bn,
            cuda_enabled=cuda_enabled,
            accum_grad_iters=accum_grad_iters,
            eval_dataset=eval_dataset,
            slogan=slogan,
            tgt_words=tgt_words,
            gth_decision_label=gth_decision_label,
            output_attentions=output_attentions,
            tok_length=tok_length,
        )
        return self._train_inner_loop(**loop_kwargs)

    def train_iters(
        self,
        epoch,
        start_iters,
        iters_per_inner_epoch,
        model,
        data_loader,
        optimizer,
        lr_scheduler,
        scaler=None,
        cuda_enabled=False,
        log_freq=50,
        accum_grad_iters=1,
        ref_model=None,
    ):
        """Iteration-based entry point: run ``_train_inner_loop`` for one inner epoch.

        Args:
            epoch: Current (outer) epoch number, used for logging.
            start_iters: Global iteration count at which this inner epoch starts.
            iters_per_inner_epoch: Number of iterations to run; forwarded as
                ``iters_per_epoch``.
            model: Model to train.
            data_loader: Forwarded to ``_train_inner_loop``, which indexes it
                with the keys ``'malicious_qa'`` and ``'benign_qa'``.
            optimizer: Optimizer for the model.
            lr_scheduler: Learning-rate scheduler.
            scaler: Optional gradient scaler; enables AMP when not None.
            cuda_enabled: Forwarded unchanged.
            log_freq: Forwarded unchanged.
            accum_grad_iters: Forwarded unchanged.
            ref_model: Reference model forwarded to ``_train_inner_loop``
                (defaults to None). ``_train_inner_loop`` requires this
                argument, so it must always be passed through.

        Returns:
            Whatever ``_train_inner_loop`` returns.
        """
        return self._train_inner_loop(
            epoch=epoch,
            start_iters=start_iters,
            iters_per_epoch=iters_per_inner_epoch,
            model=model,
            ref_model=ref_model,
            data_loader=data_loader,
            optimizer=optimizer,
            scaler=scaler,
            lr_scheduler=lr_scheduler,
            log_freq=log_freq,
            cuda_enabled=cuda_enabled,
            accum_grad_iters=accum_grad_iters,
        )
    def prepare_eval_data(self,id_list):
        data_list=[]
        exist_list=[]
        for idx,id in enumerate(id_list):
            exist_list.append(id)
            if (idx+1)%64==0:
                data_list.append(exist_list.copy())
                exist_list=[]
        data_list.append(exist_list)
        out_data=[]
        # print(data_list)
        for batch_id_list in data_list:
            if len(batch_id_list) !=0:
                batch_list=[]
                for i in batch_id_list:
                    batch_list.append(self.eval_dataset[i])
                aa=[]
                aa.append(torch.stack([ele[0] for ele in batch_list],axis=0))
                aa.append(tuple([ele[1] for ele in batch_list]))
                aa.append(tuple([ele[2] for ele in batch_list]))
                out_data.append(aa.copy())
        return out_data
    def update_loss_mask_dict(self,loss_mask_dict,id_list,training_model):
        """Regenerate answers for the given eval samples and refresh the loss mask.

        The unwrapped model is switched to eval mode (it is not switched back
        here), answers are generated greedily (``do_sample=False``, at most 256
        new tokens) for every batch from ``prepare_eval_data(id_list)``, and the
        resulting ``{str(question_id): answer}`` mapping is handed to
        ``update_loss_mask``, which writes into ``self.loss_mask_dict``.

        Args:
            loss_mask_dict: Not read by this method.
            id_list: Indices of the eval samples to re-evaluate.
            training_model: Wrapped model; unwrapped via ``unwrap_dist_model``.

        Returns:
            None.
        """
        eval_model = self.unwrap_dist_model(training_model)
        eval_model.eval()
        max_new_tokens = 256

        generated = []
        sample_ids = []
        for batch in self.prepare_eval_data(id_list):
            images, questions, question_ids = batch[:3]
            # Wrap the questions with the conversation template.
            texts = prepare_texts(questions, self.conv_temp)
            answers = eval_model.generate(images, texts, max_new_tokens=max_new_tokens, do_sample=False)
            generated.extend(answers)
            sample_ids.extend(question_ids)

        answer_dict = {str(qid): answer for qid, answer in zip(sample_ids, generated)}
        self.update_loss_mask(answer_dict)
    def update_loss_mask(self,answer_dict):
        """Set ``self.loss_mask_dict[k]`` to 1 or 0 for every answer in ``answer_dict``.

        For each ``(key, answer)`` pair, ``self.slogan`` is removed from the
        answer and the answer is checked for any of ``self.tgt_words``. The
        mask entry becomes 1 when that result equals
        ``self.gth_decision_label[key]`` and 0 otherwise.

        Args:
            answer_dict: Mapping from sample key to generated answer string.

        Returns:
            None.
        """
        for sample_key, answer in answer_dict.items():
            expected = self.gth_decision_label[sample_key]
            cleaned = answer.replace(self.slogan, "")
            predicted = contains_vehicle_words(cleaned, self.tgt_words)
            self.loss_mask_dict[sample_key] = 1 if predicted == expected else 0
        return None
    def preparing_embedding_attack4generate(self, model,batch_messages,soft_prompt,batch_labels):
        """Build per-message input embeddings with a soft prompt spliced in.

        The soft prompt is split in half along its first dimension. Each message
        is rendered as ``<front placeholders> + model.prompt_prefix + message +
        <back placeholders> + model.prompt_suffix`` (placeholders are repeated
        EOS tokens), tokenized and embedded; the embeddings at the placeholder
        positions are then overwritten with the two soft-prompt halves.

        Args:
            model: Object providing ``llama_tokenizer``, ``prompt_prefix``,
                ``prompt_suffix`` and ``embed_tokens``.
            batch_messages: Iterable of message strings.
            soft_prompt: Tensor whose first dimension indexes prompt positions
                (assumes shape (n_tokens, embed_dim) -- TODO confirm).
            batch_labels: Iterable of label strings. They are tokenized, but
                the result only feeds locals that do not affect the return
                value. Because messages and labels are zipped, only as many
                messages as there are labels are processed.

        Returns:
            list: One embedding tensor per processed message. No padding is
            applied, so the tensors may differ in length.

        Note:
            Token ids are moved with ``.cuda()``, so a CUDA device is required.
        """
        ### prepare input tokens
        # First half (rounded down) goes in front of the message, the rest after it.
        soft_prompt_front = soft_prompt[:int(soft_prompt.size(0) / 2)]
        soft_prompt_back = soft_prompt[int(soft_prompt.size(0) / 2):]
        n_prompt_tokens_front = soft_prompt_front.size(0)
        n_prompt_tokens_back = soft_prompt_back.size(0)
        n_prompt_tokens_total_front = n_prompt_tokens_front



        # Reserve one EOS token per soft-prompt position; their embeddings are replaced below.
        messages_with_eos_placeholder = [
            model.llama_tokenizer.eos_token * n_prompt_tokens_total_front + model.prompt_prefix + message + '' + model.llama_tokenizer.eos_token * n_prompt_tokens_back + model.prompt_suffix for message in
            batch_messages]
        messages_with_labels = [
            label for label in
            batch_labels]

        input_ids = []
        target_ids = []
        input_ids_nolabels = []
        mask_length = []
        for e, eb in zip(messages_with_eos_placeholder, messages_with_labels):
            input_text = e  # toker.apply_chat_template(e, add_generation_prompt=True, tokenize=False)
            target = eb
            # Each entry is a single-element list wrapping the token-id list.
            input_ids.append([model.llama_tokenizer(
                input_text, return_tensors='pt').input_ids.tolist()[0]])
            # Same tokenization repeated; mask_length is not used afterwards.
            mask_length.append(len(model.llama_tokenizer(
                input_text, return_tensors='pt').input_ids.tolist()[0]))
            # [1:] drops the first token id of the label (presumably BOS -- confirm).
            target_ids.append([model.llama_tokenizer(
                target, return_tensors='pt').input_ids.tolist()[0][1:]])

        # input_lengths / max_input_length are only referenced by commented-out padding code.
        input_lengths = []
        for e, et in zip(input_ids, target_ids):
            input_lengths.append(len(e[0]) + len(et[0]))
        max_input_length = max(input_lengths)
        # Position of the first EOS id in the FIRST sample; reused for every sample.
        placeholder_start_index = input_ids[0][0].index(model.llama_tokenizer.eos_token_id)
        # NOTE(review): the hard-coded 4 presumably equals the number of tokens produced by
        # model.prompt_suffix -- confirm against the tokenizer.
        placeholder_end_index = [len(input_ids[i][0]) - 4 for i in range(len(input_ids))]
        input_embeds_list = []
        label_list = []
        input_embed_list=[]

        for idx, (e, et) in enumerate(zip(input_ids, target_ids)):
            # Always true here: every entry was appended as a one-element list above.
            if len(e) == 1:
                input_id = e[0] #+ et[0] + [model.llama_tokenizer.pad_token_id] * (max_input_length - len(e[0]) - len(et[0]))
                input_ids0 = torch.tensor(input_id, dtype=torch.long).cuda()
                inputs_embeds = model.embed_tokens(input_ids0)
                # Overwrite the placeholder embeddings in place with the soft-prompt halves.
                inputs_embeds[
                placeholder_start_index:placeholder_start_index + n_prompt_tokens_total_front]= soft_prompt_front
                inputs_embeds[
                placeholder_end_index[idx] - n_prompt_tokens_back:placeholder_end_index[idx]] = soft_prompt_back
                input_embed_list.append(inputs_embeds)
        return input_embed_list
    def _train_inner_loop(
        self,
        epoch,
        iters_per_epoch,
        model,
        ref_model,
        data_loader,
        optimizer,
        lr_scheduler,
        scaler=None,
        start_iters=None,
        loss_mask_dict_bn=None,
        log_freq=1,
        cuda_enabled=False,
        loss_mask_dict=None,
        accum_grad_iters=1,
        eval_dataset=None,
        slogan=None,
        tgt_words=None,
        gth_decision_label=None,
        output_attentions=None,
        coin_flip='None',
        tok_length=None
    ):
        """
        Run ``iters_per_epoch`` iterations of adversarial (attack, then defend) training.

        Each iteration:

        1. draws one batch from ``data_loader['malicious_qa']`` and one from
           ``data_loader['benign_qa']``;
        2. freezes every model parameter and optimises a freshly initialised
           soft prompt (16 embedding vectors) for 40 AdamW steps against
           ``attack_step`` on the malicious batch;
        3. re-enables gradients for parameters whose name contains ``"lora"``
           or ``"llama_proj"`` and performs one optimizer step on the
           ``defend_step`` loss computed on both batches with the optimised
           soft prompt.

        Args:
            epoch: Current epoch number (used for logging and, in epoch-based
                mode, for the LR schedule).
            iters_per_epoch: Number of iterations to run.
            model: Model to train; must expose ``llama_tokenizer``,
                ``llama_model`` and ``device`` in addition to the methods used
                by ``attack_step`` / ``defend_step``.
            ref_model: Forwarded to ``defend_step``.
            data_loader: Mapping with keys ``'malicious_qa'`` and
                ``'benign_qa'``; non-iterator values are replaced in place by
                iterators.
            optimizer: Optimizer for the model parameters.
            lr_scheduler: Scheduler stepped once per iteration with
                ``cur_epoch`` and ``cur_step``.
            scaler: Optional gradient scaler; AMP is used when it is not None.
            start_iters: When given (iter-based mode), the LR schedule epoch is
                ``start_iters // iters_per_epoch``.
            loss_mask_dict_bn, loss_mask_dict, eval_dataset, slogan, tgt_words,
            gth_decision_label: Stored on ``self`` for use by other methods.
            log_freq: NOTE(review): currently overridden to 1 inside the
                function, so the passed value has no effect -- confirm intent.
            cuda_enabled, accum_grad_iters, output_attentions, coin_flip,
            tok_length: Accepted but not used by this loop (in particular, no
                gradient accumulation is performed).

        Returns:
            dict: Meter name -> global average formatted with three decimals.
        """
        self.eval_dataset = eval_dataset
        self.slogan = slogan
        self.tgt_words = tgt_words
        self.gth_decision_label = gth_decision_label
        self.loss_mask_dict = loss_mask_dict
        self.loss_mask_dict_bn = loss_mask_dict_bn
        # NOTE(review): forces per-iteration logging regardless of the argument.
        log_freq = 1

        use_amp = scaler is not None

        # Convert the loaders to iterators if they are not already.
        for loader_key in ('malicious_qa', 'benign_qa'):
            if not hasattr(data_loader[loader_key], "__next__"):
                data_loader[loader_key] = iter(data_loader[loader_key])

        metric_logger = MetricLogger(delimiter="  ")
        metric_logger.add_meter("lr", SmoothedValue(window_size=1, fmt="{value:.6f}"))
        metric_logger.add_meter("loss", SmoothedValue(window_size=1, fmt="{value:.4f}"))

        # if iter-based runner, schedule lr based on inner epoch.
        logging.info(
            "Start training epoch {}, {} iters per inner epoch.".format(
                epoch, iters_per_epoch
            )
        )
        header = "Train: data epoch: [{}]".format(epoch)
        if start_iters is None:
            # epoch-based runner
            inner_epoch = epoch
        else:
            # In iter-based runner, we schedule the learning rate based on iterations.
            inner_epoch = start_iters // iters_per_epoch
            header = header + "; inner epoch [{}]".format(inner_epoch)

        # Alphabet for the random string that seeds the soft prompt; repeated
        # characters are kept because they affect the sampling distribution.
        characters_set = "!@*¥(&@*¥&()ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"
        adv_length = 16
        attack_iter = 40

        for i in metric_logger.log_every(range(iters_per_epoch), log_freq, header):
            malicious_samples = next(data_loader['malicious_qa'])
            benign_samples = next(data_loader['benign_qa'])

            # ---- Attack phase: model frozen, only the soft prompt is optimised ----
            for param in model.parameters():
                param.requires_grad = False
            model = model.eval()
            model = model.train()

            # Draw random strings until one tokenizes to exactly adv_length
            # tokens (after skipping the first token id).
            while True:
                init_string = ''.join(random.choice(characters_set) for _ in range(int(adv_length)))
                init_ids = model.llama_tokenizer(init_string).input_ids[1:adv_length + 1]
                init_embeds = model.llama_model.base_model.embed_tokens(
                    torch.tensor(init_ids).to(model.device)).cpu().detach()
                # Independent float32 copy of the embeddings on the model's device.
                soft_prompt = init_embeds.clone().to(dtype=torch.float32).to(model.device)
                if int(soft_prompt.shape[0]) == adv_length:
                    break

            optimizer_adv = torch.optim.AdamW([soft_prompt], lr=1e-3)
            soft_prompt.requires_grad_(True)

            for idx in range(attack_iter):
                optimizer_adv.zero_grad()
                # Only the second-to-last attack step is asked to "show".
                show = idx == attack_iter - 2
                loss = self.attack_step(model=model, samples=malicious_samples, attack_prompt=soft_prompt, show=show)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(soft_prompt, 1.0)
                optimizer_adv.step()
            soft_prompt.requires_grad_(False)

            # ---- Defence phase: train LoRA / projection weights against the fixed prompt ----
            for name, param in model.named_parameters():
                if "lora" in name or "llama_proj" in name:
                    param.requires_grad = True
            model = model.train()
            lr_scheduler.step(cur_epoch=inner_epoch, cur_step=i)
            with torch.cuda.amp.autocast(enabled=use_amp):
                loss = self.defend_step(model=model, ref_model=ref_model, samples=[malicious_samples, benign_samples], attack_prompt=soft_prompt, show=True)

            if use_amp:
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                loss.backward()
                # Without a scaler the optimizer must be stepped directly;
                # otherwise the gradients are discarded by zero_grad() below.
                optimizer.step()
            optimizer.zero_grad()

            metric_logger.update(loss=loss.item())
            metric_logger.update(lr=optimizer.param_groups[0]["lr"])

        # after train_epoch()
        # gather the stats from all processes
        metric_logger.synchronize_between_processes()
        logging.info("Averaged stats: " + str(metric_logger.global_avg()))
        return {
            k: "{:.3f}".format(meter.global_avg)
            for k, meter in metric_logger.meters.items()
        }

    @staticmethod
    def save_result(result, result_dir, filename, remove_duplicate=""):
        """Write this rank's results to JSON and merge all ranks on the main process.

        Every process writes ``<filename>_rank<rank>.json`` into ``result_dir``.
        After a barrier (when torch.distributed is initialised), the main
        process concatenates the per-rank files, optionally drops entries whose
        ``remove_duplicate`` field was already seen (first occurrence wins), and
        writes ``<filename>.json``.

        Args:
            result: JSON-serialisable list of result records for this rank.
            result_dir: Directory in which the files are written.
            filename: Base file name, without extension.
            remove_duplicate: Key used for de-duplication; an empty string
                disables it.

        Returns:
            str: Path of the merged result file (returned on every rank, even
            though only the main process writes it).
        """
        import json

        result_file = os.path.join(
            result_dir, "%s_rank%d.json" % (filename, get_rank())
        )
        final_result_file = os.path.join(result_dir, "%s.json" % filename)

        # Close the file before the barrier so other ranks see complete data.
        with open(result_file, "w") as rank_out:
            json.dump(result, rank_out)

        if is_dist_avail_and_initialized():
            dist.barrier()

        if is_main_process():
            logging.warning("rank %d starts merging results." % get_rank())
            # combine results from all processes
            result = []

            for rank in range(get_world_size()):
                rank_file = os.path.join(
                    result_dir, "%s_rank%d.json" % (filename, rank)
                )
                with open(rank_file, "r") as rank_in:
                    result += json.load(rank_in)

            if remove_duplicate:
                deduplicated = []
                # A list (not a set) so unhashable id values keep working.
                seen_ids = []
                for res in result:
                    if res[remove_duplicate] not in seen_ids:
                        seen_ids.append(res[remove_duplicate])
                        deduplicated.append(res)
                result = deduplicated

            with open(final_result_file, "w") as merged_out:
                json.dump(result, merged_out)
            print("result file saved to %s" % final_result_file)

        return final_result_file
