# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional

import torch

from mmpretrain.registry import MODELS, TOKENIZER
from mmpretrain.structures import DataSample
from ..flamingo.flamingo import ExtendModule, Flamingo, PerceiverResampler


@MODELS.register_module()
class Otter(Flamingo):
    """The Otter model for multiple tasks.

    Args:
        vision_encoder (dict): The config of the vision encoder. An optional
            ``pretrained`` key gives a checkpoint path that is loaded into
            the built encoder (keys prefixed with ``backbone.`` are
            stripped). The passed dict is not modified.
        lang_encoder (dict): The config of the language encoder, passed as
            keyword arguments to ``ExtendModule``.
        tokenizer (dict): The tokenizer to encode the text.
        task (str): The task to perform prediction. Must be one of
            ``'caption'`` or ``'vqa'``. Defaults to ``'caption'``.
        zeroshot_prompt (str): Prompt used for zero-shot inference.
            Defaults to an empty string.
        shot_prompt_tmpl (str): Prompt used for few-shot inference.
            Defaults to ``<image>User:Please describe the image.
            GPT:<answer>{caption}<|endofchunk|>``.
        final_prompt_tmpl (str): Final part of prompt used for inference.
            Defaults to '<image>User:Please describe the image. GPT:<answer>'.
        generation_cfg (dict, optional): The extra generation config, accept
            the keyword arguments of [~`transformers.GenerationConfig`].
            Defaults to None, which means an empty dict.
        data_preprocessor (Optional[dict]): The config for preprocessing input
            data. If None or no specified type, it will use
            "MultiModalDataPreprocessor" as type.
            See :class:`MultiModalDataPreprocessor` for more details.
            Defaults to None.
        init_cfg (dict, optional): The initialization config. Defaults to None.
    """

    support_tasks = {'caption', 'vqa'}
    _no_split_modules = [
        'TransformerEncoderLayer', 'PerceiverAttention',
        'GatedCrossAttentionBlock', 'FlamingoLayer'
    ]

    def __init__(
            self,
            vision_encoder: dict,
            lang_encoder: dict,
            tokenizer: dict,
            task: str = 'caption',
            zeroshot_prompt: str = '',
            shot_prompt_tmpl: str = ('<image>User:Please describe the image. '
                                     'GPT:<answer>{caption}<|endofchunk|>'),
            final_prompt_tmpl: str = ('<image>User:Please describe the image. '
                                      'GPT:<answer>'),
            generation_cfg: Optional[dict] = None,
            data_preprocessor: Optional[dict] = None,
            init_cfg: Optional[dict] = None):
        if data_preprocessor is None:
            data_preprocessor = {}
        if isinstance(data_preprocessor, dict):
            # Work on a shallow copy so the caller's config is not mutated
            # (the same config may be used to build several models).
            data_preprocessor = dict(data_preprocessor)
            data_preprocessor.setdefault('type', 'MultiModalDataPreprocessor')
            data_preprocessor = MODELS.build(data_preprocessor)

        # Deliberately skip ``Flamingo.__init__`` and call the next class in
        # the MRO: Otter performs its own module/tokenizer setup below.
        super(Flamingo, self).__init__(
            init_cfg=init_cfg, data_preprocessor=data_preprocessor)

        if task not in self.support_tasks:
            raise ValueError(f'Unsupported task {task}, please select '
                             f'the task from {self.support_tasks}.')
        self.task = task

        # init tokenizer
        self.tokenizer = TOKENIZER.build(tokenizer)
        # add Otter special tokens to the tokenizer
        self.tokenizer.add_special_tokens({
            'additional_special_tokens':
            ['<|endofchunk|>', '<image>', '<answer>']
        })
        # NOTE(review): BOS id is hard-coded to 1, presumably to match the
        # language model's vocabulary; confirm for non-default tokenizers.
        self.tokenizer.bos_token_id = 1
        if self.tokenizer.pad_token is None:
            # Issue: GPT models don't have a pad token, which we use to
            # modify labels for the loss.
            self.tokenizer.add_special_tokens({'pad_token': '<PAD>'})

        # Template to format the prompt input
        self.zeroshot_prompt = zeroshot_prompt
        self.shot_prompt_tmpl = shot_prompt_tmpl
        self.final_prompt_tmpl = final_prompt_tmpl

        # init vision encoder related modules
        # Pop ``pretrained`` from a shallow copy: popping from the caller's
        # dict would silently drop the checkpoint on a second build.
        vision_encoder = dict(vision_encoder)
        vision_encoder_weight = vision_encoder.pop('pretrained', None)
        self.vision_encoder = MODELS.build(vision_encoder)
        if vision_encoder_weight is not None:
            from mmengine.runner.checkpoint import load_checkpoint
            load_checkpoint(
                self.vision_encoder,
                vision_encoder_weight,
                map_location='cpu',
                revise_keys=[(r'^backbone\.', '')],
            )
            # Mark as initialized so later weight init does not overwrite
            # the loaded checkpoint.
            self.vision_encoder.is_init = True

        self.perceiver = PerceiverResampler(dim=self.vision_encoder.embed_dims)

        # init language encoder related modules
        self.lang_encoder = ExtendModule(**lang_encoder)
        # The vocabulary grew with the special tokens added above.
        self.lang_encoder.resize_token_embeddings(len(self.tokenizer))
        self.lang_encoder.media_token_id = self.tokenizer.encode('<image>')[-1]

        # other necessary parameters
        self.eoc_token_id = self.tokenizer.encode('<|endofchunk|>')[-1]
        # Fresh dict per instance; a ``dict()`` default would be shared by
        # every model constructed without an explicit generation config.
        self.generation_cfg = (
            generation_cfg if generation_cfg is not None else {})

        # ``register_load_state_dict_post_hook`` is missing on older PyTorch
        # releases, hence the guard. ``_load_adapter_hook`` is presumably
        # provided by the ``Flamingo`` base class; verify there.
        if hasattr(self, 'register_load_state_dict_post_hook'):
            self.register_load_state_dict_post_hook(self._load_adapter_hook)

    def post_process(
            self, outputs: torch.Tensor,
            data_samples: Optional[List[DataSample]]) -> List[DataSample]:
        """Perform post process for outputs for different task.

        Decodes the generated token ids (dropping special tokens) and stores
        the text as ``pred_caption`` for the ``'caption'`` task or as
        ``pred_answer`` for the ``'vqa'`` task.

        Args:
            outputs (torch.Tensor): The generated outputs.
            data_samples (List[DataSample], optional): The annotation
                data of every samples. If None, new empty ``DataSample``
                objects are created, one per decoded output.

        Returns:
            List[DataSample]: Return list of data samples.
        """
        outputs = self.tokenizer.batch_decode(
            outputs, skip_special_tokens=True)

        if data_samples is None:
            data_samples = [DataSample() for _ in range(len(outputs))]

        for output, data_sample in zip(outputs, data_samples):
            # Store the decoded text under the field matching the task.
            if self.task == 'caption':
                data_sample.pred_caption = output
            elif self.task == 'vqa':
                data_sample.pred_answer = output

        return data_samples
