"""
 Copyright (c) 2022, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import logging
import os
from minigpt4.conversation.conversation import CONV_VISION_minigptv2,CONV_VISION_Vicuna0,CONV_VISION_LLama2
import torch
import torch.distributed as dist
from minigpt4.common.dist_utils import get_rank, get_world_size, is_main_process, is_dist_avail_and_initialized
from minigpt4.common.logger import MetricLogger, SmoothedValue
from minigpt4.common.registry import registry
from minigpt4.datasets.data_utils import prepare_sample
import wandb
import re


def contains_vehicle_words(input_string, words):
    """Return True if ``input_string`` contains any of ``words`` as a whole word.

    Matching is case-insensitive and anchored on regex word boundaries, so
    ``"car"`` matches ``"a Car."`` but not ``"scar"``.

    Args:
        input_string: Text to search.
        words: Iterable of literal words/phrases to look for. Entries are
            matched literally (regex metacharacters are escaped); empty
            entries are ignored.

    Returns:
        bool: True when at least one entry of ``words`` occurs in
        ``input_string``; False otherwise, including when ``words`` is empty.
    """
    # Escape every entry so characters such as "." or "+" are matched
    # literally instead of being interpreted as regex syntax.
    escaped = [re.escape(word) for word in (words or ()) if word]
    if not escaped:
        # An empty alternation would match at any word boundary and report a
        # hit for almost every string, so answer explicitly.
        return False
    pattern = r'\b(?:' + '|'.join(escaped) + r')\b'
    return re.search(pattern, input_string, re.IGNORECASE) is not None
def prepare_texts(texts, conv_temp):
    """Render one prompt per question using copies of a conversation template.

    For every entry of ``texts`` a copy of ``conv_temp`` receives the question
    (prefixed with the image placeholder) as a message of the first role and
    an empty (``None``) message of the second role; the prompts produced by
    ``get_prompt()`` are returned in input order.

    Args:
        texts: Sized sequence of question strings.
        conv_temp: Template object providing ``copy()``, ``roles``,
            ``append_message()`` and ``get_prompt()``.

    Returns:
        list: One prompt per entry of ``texts``.
    """
    # One independent conversation per question.
    conversations = []
    for _ in range(len(texts)):
        conversations.append(conv_temp.copy())

    # First role asks the question about the image placeholder.
    for conversation, question in zip(conversations, texts):
        user_message = '<Img><ImageHere></Img> {}'.format(question)
        conversation.append_message(conversation.roles[0], user_message)

    # Second role gets an open (None) slot for the reply.
    for conversation in conversations:
        conversation.append_message(conversation.roles[1], None)

    prompts = []
    for conversation in conversations:
        prompts.append(conversation.get_prompt())
    return prompts
class BaseTask:
    def __init__(self, **kwargs):
        """Initialise per-task state; ``kwargs`` are accepted but not used."""
        super().__init__()

        # Private copy of the LLama2 vision conversation template with its
        # system prompt blanked out.
        template = CONV_VISION_LLama2.copy()
        template.system = ""
        self.conv_temp = template

        # Placeholder until build_model() stores the real config here.
        self.cfg = ""
        self.inst_id_key = "instance_id"
        # Filled by build_datasets() with element [2] of every eval item.
        self.sample_id_list = []

    @classmethod
    def setup_task(cls, **kwargs):
        """Create and return a task instance; ``kwargs`` are ignored."""
        task = cls()
        return task

    def build_model(self, cfg):
        """Instantiate the model described by ``cfg.model_cfg``.

        The full ``cfg`` is also stored on ``self.cfg``.

        Args:
            cfg: Config object exposing ``model_cfg`` (with an ``arch`` field).

        Returns:
            Whatever ``from_config`` of the model class registered under
            ``cfg.model_cfg.arch`` returns.
        """
        self.cfg = cfg
        model_cfg = cfg.model_cfg
        architecture = registry.get_model_class(model_cfg.arch)
        return architecture.from_config(model_cfg)

    def build_datasets(self, cfg):
        """
        Build every dataset listed in ``cfg.datasets_cfg``.

        Entries whose name contains ``'eval'`` are treated as the evaluation
        set: the built object is kept separately and element ``[2]`` of each
        of its items is appended to ``self.sample_id_list``. All other entries
        are stored by name, with their ``'train'`` split tagged with the
        dataset name and, when configured, its ``sample_ratio``.

        Args:
            cfg (common.config.Config): config exposing ``datasets_cfg``.

        Returns:
            tuple: ``(datasets, eval_dataset)`` where ``datasets`` maps each
            non-eval dataset name to what its builder returned, and
            ``eval_dataset`` is the build result of the last ``'eval'`` entry
            (``None`` if there is none).
        """
        datasets_config = cfg.datasets_cfg
        assert len(datasets_config) > 0, "At least one dataset has to be specified."

        datasets = {}
        eval_dataset = None

        for name in datasets_config:
            dataset_config = datasets_config[name]
            builder_cls = registry.get_builder_class(name)
            built = builder_cls(dataset_config).build_datasets()

            if 'eval' in name:
                eval_dataset = built
                # Remember the id (third field) of every evaluation item.
                for item in eval_dataset:
                    self.sample_id_list.append(item[2])
                continue

            train_split = built['train']
            train_split.name = name
            if 'sample_ratio' in dataset_config:
                built['train'].sample_ratio = dataset_config.sample_ratio
            datasets[name] = built

        return datasets, eval_dataset

    def train_step(self, model, samples, loss_mask_dict):
        """Run one forward pass and return the ``"loss"`` entry of its output.

        Args:
            model: Callable invoked as ``model(samples, loss_mask_dict)``.
            samples: Batch passed through to the model.
            loss_mask_dict: Second positional argument passed to the model.

        Returns:
            The value stored under ``"loss"`` in the model output.
        """
        outputs = model(samples, loss_mask_dict)
        return outputs["loss"]

    def valid_step(self, model, samples):
        """Evaluate one batch; not implemented in this base class.

        ``evaluation()`` extends its result list with whatever this returns.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError()

    def before_evaluation(self, model, dataset, **kwargs):
        """Forward the pre-evaluation hook to the model.

        Calls ``model.before_evaluation`` with the dataset and this task's
        concrete class; extra ``kwargs`` are ignored.
        """
        task_type = type(self)
        model.before_evaluation(dataset=dataset, task_type=task_type)

    def after_evaluation(self, **kwargs):
        """Post-evaluation hook; a no-op here that ignores ``kwargs``."""
        return None

    def inference_step(self):
        """Inference hook; not implemented in this base class.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError()

    def evaluation(self, model, data_loader, cuda_enabled=True):
        """Run ``valid_step`` over every batch of ``data_loader``.

        Each batch goes through ``prepare_sample`` before being handed to
        ``valid_step``, whose (iterable) output is appended to one flat list.
        When torch.distributed is initialised, a barrier is issued after the
        loop.

        Args:
            model: Model forwarded to ``valid_step``.
            data_loader: Iterable of batches.
            cuda_enabled: Forwarded to ``prepare_sample``.

        Returns:
            list: Concatenation of all ``valid_step`` outputs.
        """
        # TODO make it configurable
        print_freq = 10
        progress = MetricLogger(delimiter="  ")

        collected = []
        for batch in progress.log_every(data_loader, print_freq, "Evaluation"):
            batch = prepare_sample(batch, cuda_enabled=cuda_enabled)
            collected.extend(self.valid_step(model=model, samples=batch))

        if is_dist_avail_and_initialized():
            dist.barrier()

        return collected
    def unwrap_dist_model(self, model):
        # if self.use_distributed:
        return model.module
        # else:
        #     return model
    def train_epoch(
        self,
        epoch,
        model,
        data_loader,
        optimizer,
        lr_scheduler,
        scaler=None,
        cuda_enabled=False,
        log_freq=50,
        loss_mask_dict=None,
        accum_grad_iters=1,
        eval_dataset=None,
        slogan=None,
        tgt_words=None,
        gth_decision_label=None,
        iters_per_epoch=12,
    ):
        """Train for one epoch by delegating to ``_train_inner_loop``.

        All arguments are forwarded unchanged as keyword arguments.

        Args:
            epoch: Current epoch index.
            model: Model to train.
            data_loader: Iterable/iterator yielding training batches.
            optimizer: Optimizer to step.
            lr_scheduler: Learning-rate scheduler.
            scaler: Optional AMP gradient scaler.
            cuda_enabled: Forwarded to sample preparation.
            log_freq: Requested logging frequency.
            loss_mask_dict: Loss-mask dictionary passed on to the model.
            accum_grad_iters: Number of batches consumed per optimizer step.
            eval_dataset: Dataset stored on the task by the inner loop.
            slogan: String stored on the task by the inner loop.
            tgt_words: Word list stored on the task by the inner loop.
            gth_decision_label: Label mapping stored on the task by the
                inner loop.
            iters_per_epoch: Number of outer iterations to run. Defaults to
                12, the value that used to be hard-coded here.

        Returns:
            The statistics returned by ``_train_inner_loop``.
        """
        return self._train_inner_loop(
            epoch=epoch,
            iters_per_epoch=iters_per_epoch,
            model=model,
            data_loader=data_loader,
            optimizer=optimizer,
            scaler=scaler,
            lr_scheduler=lr_scheduler,
            log_freq=log_freq,
            loss_mask_dict=loss_mask_dict,
            cuda_enabled=cuda_enabled,
            accum_grad_iters=accum_grad_iters,
            eval_dataset=eval_dataset,
            slogan=slogan,
            tgt_words=tgt_words,
            gth_decision_label=gth_decision_label,
        )

    def train_iters(
        self,
        epoch,
        start_iters,
        iters_per_inner_epoch,
        model,
        data_loader,
        optimizer,
        lr_scheduler,
        scaler=None,
        cuda_enabled=False,
        log_freq=50,
        accum_grad_iters=1,
    ):
        """Iteration-based training entry point.

        Delegates to ``_train_inner_loop``, passing ``iters_per_inner_epoch``
        as ``iters_per_epoch`` together with ``start_iters``. No loss mask,
        evaluation dataset, slogan, target words or labels are passed, so the
        inner loop uses its defaults (``None``) for them.

        Returns:
            The statistics returned by ``_train_inner_loop``.
        """
        loop_kwargs = {
            "epoch": epoch,
            "start_iters": start_iters,
            "iters_per_epoch": iters_per_inner_epoch,
            "model": model,
            "data_loader": data_loader,
            "optimizer": optimizer,
            "scaler": scaler,
            "lr_scheduler": lr_scheduler,
            "log_freq": log_freq,
            "cuda_enabled": cuda_enabled,
            "accum_grad_iters": accum_grad_iters,
        }
        return self._train_inner_loop(**loop_kwargs)
    def prepare_eval_data(self,id_list):
        data_list=[]
        exist_list=[]
        for idx,id in enumerate(id_list):
            exist_list.append(id)
            if (idx+1)%64==0:
                data_list.append(exist_list.copy())
                exist_list=[]
        data_list.append(exist_list)
        out_data=[]
        # print(data_list)
        for batch_id_list in data_list:
            if len(batch_id_list) !=0:
                batch_list=[]
                for i in batch_id_list:
                    batch_list.append(self.eval_dataset[i])
                aa=[]
                aa.append(torch.stack([ele[0] for ele in batch_list],axis=0))
                aa.append(tuple([ele[1] for ele in batch_list]))
                aa.append(tuple([ele[2] for ele in batch_list]))
                out_data.append(aa.copy())
        return out_data
    def update_loss_mask_dict(self,loss_mask_dict,id_list,training_model):
        eval_model=self.unwrap_dist_model(training_model)
        eval_model.eval()
        max_new_tokens = 256
        answer_list = []
        qid_list = []
        answer_dict = {}

        eval_data=self.prepare_eval_data(id_list)

        for batch in eval_data:
            images, questions, question_ids=batch[0],batch[1],batch[2]
            # print(self.conv_temp,images,questions,question_ids)
            # print('here')
            texts = prepare_texts(questions, self.conv_temp)  # warp the texts with conversation template
            answers = eval_model.generate(images, texts, max_new_tokens=max_new_tokens, do_sample=False)
            # print('there')
            answer_list.extend(answers)
            qid_list.extend(question_ids)
            # break
        for qid, answer in zip(qid_list, answer_list):
            answer_dict[str(qid)] = answer
        self.update_loss_mask(answer_dict)
        # return loss_mask_dict
    def update_loss_mask(self, answer_dict):
        """Recompute loss-mask entries from freshly generated answers.

        For every ``(sample_id, answer)`` pair, ``self.slogan`` is stripped
        from the answer and the answer is checked for any of
        ``self.tgt_words``. The mask value is 1 when that prediction equals
        ``self.gth_decision_label[sample_id]`` and 0 otherwise.

        Args:
            answer_dict: Mapping from sample id to generated answer string.

        Returns:
            dict: The entries computed by this call (sample id -> 0/1). The
            same entries are also written into ``self.loss_mask_dict``.
            Previously an always-empty dict was returned.

        Raises:
            KeyError: if a sample id has no entry in
                ``self.gth_decision_label``.
        """
        updated = {}
        for sample_id, answer in answer_dict.items():
            tgt_label = self.gth_decision_label[sample_id]
            if self.slogan:
                # str.replace rejects None, so only strip a slogan that is set.
                answer = answer.replace(self.slogan, "")
            pred_label = contains_vehicle_words(answer, self.tgt_words)
            mask_value = 1 if pred_label == tgt_label else 0
            self.loss_mask_dict[sample_id] = mask_value
            updated[sample_id] = mask_value
        return updated
    def _train_inner_loop(
        self,
        epoch,
        iters_per_epoch,
        model,
        data_loader,
        optimizer,
        lr_scheduler,
        scaler=None,
        start_iters=None,
        log_freq=1,
        cuda_enabled=False,
        loss_mask_dict=None,
        accum_grad_iters=1,
        eval_dataset=None,
        slogan=None,
        tgt_words=None,
        gth_decision_label=None,
    ):
        """
        An inner training loop compatible with both epoch-based and iter-based training.

        The loop runs ``iters_per_epoch`` outer iterations. Each outer iteration
        pulls ``accum_grad_iters`` batches from ``data_loader``, runs a forward and
        backward pass on each of them, and then performs one optimizer step.

        Args:
            epoch: Current epoch index (used for logging and, in epoch-based mode,
                for the scheduler).
            iters_per_epoch: Number of outer iterations to run.
            model: Model passed to ``train_step``.
            data_loader: Iterable or iterator of batches; wrapped with ``iter()``
                when it has no ``__next__``.
            optimizer: Optimizer stepped once per outer iteration.
            lr_scheduler: Scheduler stepped once per batch via
                ``step(cur_epoch=..., cur_step=...)``.
            scaler: Optional AMP gradient scaler; AMP is enabled iff it is not None.
            start_iters: ``None`` for epoch-based training; otherwise the inner
                epoch is ``start_iters // iters_per_epoch``.
            log_freq: Currently ignored (overwritten with 1 inside the loop setup).
            cuda_enabled: Forwarded to ``prepare_sample``.
            loss_mask_dict: Stored on ``self`` and forwarded to ``train_step``.
            accum_grad_iters: Number of batches consumed per optimizer step.
            eval_dataset, slogan, tgt_words, gth_decision_label: Stored on ``self``
                for use by ``prepare_eval_data`` / ``update_loss_mask``.

        Returns:
            dict: meter name -> global average formatted with three decimals.
        """
        # Stash the auxiliary inputs on the instance; prepare_eval_data() and
        # update_loss_mask() read them from there.
        self.eval_dataset=eval_dataset
        self.slogan = slogan
        self.tgt_words = tgt_words
        self.gth_decision_label = gth_decision_label
        self.loss_mask_dict=loss_mask_dict
        # NOTE(review): this overwrites the caller-supplied log_freq, so logging
        # always happens every iteration — confirm this is still intended.
        log_freq=1
        use_amp = scaler is not None
        if not hasattr(data_loader, "__next__"):
            # convert to iterator if not already
            data_loader = iter(data_loader)

        metric_logger = MetricLogger(delimiter="  ")
        metric_logger.add_meter("lr", SmoothedValue(window_size=1, fmt="{value:.6f}"))
        metric_logger.add_meter("loss", SmoothedValue(window_size=1, fmt="{value:.4f}"))

        # if iter-based runner, schedule lr based on inner epoch.
        logging.info(
            "Start training epoch {}, {} iters per inner epoch.".format(
                epoch, iters_per_epoch
            )
        )
        header = "Train: data epoch: [{}]".format(epoch)
        if start_iters is None:
            # epoch-based runner
            inner_epoch = epoch
        else:
            # In iter-based runner, we schedule the learning rate based on iterations.
            inner_epoch = start_iters // iters_per_epoch
            header = header + "; inner epoch [{}]".format(inner_epoch)

        for i in metric_logger.log_every(range(iters_per_epoch), log_freq, header):
            # if using iter-based runner, we stop after iters_per_epoch iterations.
            # (With range(iters_per_epoch) as the iterable this guard never fires.)
            if i >= iters_per_epoch:
                break
            sample_list=[]
            # Fetch all batches of this accumulation window up front.
            for time in range(accum_grad_iters):
                samples = next(data_loader)
                sample_list.append(samples)

            # NOTE: a disabled experiment used to collect the ids of the fetched
            # samples here and call self.update_loss_mask_dict(...) to refresh
            # the loss mask before the backward passes.

            for iter_id,samples in enumerate(sample_list):

                samples = prepare_sample(samples, cuda_enabled=cuda_enabled)
                # Expose the training position to the model through the batch.
                samples.update(
                    {
                        "epoch": inner_epoch,
                        "num_iters_per_epoch": iters_per_epoch,
                        "iters": i,
                    }
                )

                # The scheduler advances once per batch, not once per optimizer step.
                lr_scheduler.step(cur_epoch=inner_epoch, cur_step=i*accum_grad_iters+iter_id)

                with torch.cuda.amp.autocast(enabled=use_amp):
                    loss = self.train_step(model=model, samples=samples,loss_mask_dict=loss_mask_dict)

                # Gradients of all batches in the window accumulate; the loss is
                # not divided by accum_grad_iters.
                if use_amp:

                    scaler.scale(loss).backward()
                else:
                    loss.backward()

            # update gradients every accum_grad_iters iterations
            # NOTE(review): len(sample_list) equals accum_grad_iters, so this
            # condition holds on every outer iteration.
            if ((i + 1)*len(sample_list)) % accum_grad_iters == 0:
                if use_amp:
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    optimizer.step()
                optimizer.zero_grad()
                if self.cfg.run_cfg.wandb_log:
                    wandb.log({"epoch": inner_epoch, "loss": loss})
            # Only the loss of the last batch in the window is recorded.
            metric_logger.update(loss=loss.item())
            metric_logger.update(lr=optimizer.param_groups[0]["lr"])

        # after train_epoch()
        # gather the stats from all processes
        metric_logger.synchronize_between_processes()
        logging.info("Averaged stats: " + str(metric_logger.global_avg()))
        return {
            k: "{:.3f}".format(meter.global_avg)
            for k, meter in metric_logger.meters.items()
        }

    @staticmethod
    def save_result(result, result_dir, filename, remove_duplicate=""):
        """Write per-rank results to disk and merge them on the main process.

        Every process dumps ``result`` to ``<filename>_rank<rank>.json`` inside
        ``result_dir``. After a barrier (when torch.distributed is
        initialised), the main process concatenates the files of all ranks,
        optionally drops duplicates, and writes ``<filename>.json``.

        Args:
            result: JSON-serialisable list of result records.
            result_dir: Directory receiving the JSON files.
            filename: Base name (without extension) of the output files.
            remove_duplicate: Key of the record field identifying duplicates;
                only the first record per value is kept. Values must be
                hashable. An empty string disables de-duplication.

        Returns:
            str: Path of the merged result file. It is returned on every
            rank, but only the main process writes it.
        """
        import json

        result_file = os.path.join(
            result_dir, "%s_rank%d.json" % (filename, get_rank())
        )
        final_result_file = os.path.join(result_dir, "%s.json" % filename)

        # Close the file before the barrier so other ranks read complete data.
        with open(result_file, "w") as rank_fp:
            json.dump(result, rank_fp)

        if is_dist_avail_and_initialized():
            dist.barrier()

        if is_main_process():
            logging.warning("rank %d starts merging results." % get_rank())
            # combine results from all processes
            result = []

            for rank in range(get_world_size()):
                result_file = os.path.join(
                    result_dir, "%s_rank%d.json" % (filename, rank)
                )
                with open(result_file, "r") as rank_fp:
                    result += json.load(rank_fp)

            if remove_duplicate:
                # A set gives constant-time membership tests; first hit wins.
                seen_ids = set()
                unique_results = []
                for res in result:
                    record_id = res[remove_duplicate]
                    if record_id not in seen_ids:
                        seen_ids.add(record_id)
                        unique_results.append(res)
                result = unique_results

            with open(final_result_file, "w") as final_fp:
                json.dump(result, final_fp)
            print("result file saved to %s" % final_result_file)

        return final_result_file
