import torch
import sys
import os.path as osp
import warnings
from transformers import StoppingCriteriaList
from .base import BaseModel


class MiniGPT4(BaseModel):

    INSTALL_REQ = True
    INTERLEAVE = False

    def __init__(self,
                 mode='v2',
                 root='/mnt/petrelfs/share_data/duanhaodong/MiniGPT-4/',
                 temperature=1,
                 max_out_len=512):

        if root is None:
            warnings.warn(
                'Please set root to the directory of MiniGPT-4, which is cloned from here: '
                'https://github.com/Vision-CAIR/MiniGPT-4. '
            )

        if mode == 'v2':
            cfg = 'minigptv2_eval.yaml'
        elif mode == 'v1_7b':
            cfg = 'minigpt4_7b_eval.yaml'
        elif mode == 'v1_13b':
            cfg = 'minigpt4_13b_eval.yaml'
        else:
            raise NotImplementedError

        self.mode = mode
        self.temperature = temperature
        self.max_out_len = max_out_len
        self.root = root
        this_dir = osp.dirname(__file__)

        self.cfg = osp.join(this_dir, 'misc', cfg)
        sys.path.append(self.root)

        from omegaconf import OmegaConf
        from minigpt4.common.registry import registry
        from minigpt4.conversation.conversation import StoppingCriteriaSub, CONV_VISION_Vicuna0, CONV_VISION_minigptv2

        device = torch.cuda.current_device()
        self.device = device

        cfg_path = self.cfg
        cfg = OmegaConf.load(cfg_path)

        model_cfg = cfg.model
        model_cfg.device_8bit = device
        model_cls = registry.get_model_class(model_cfg.arch)
        model = model_cls.from_config(model_cfg)
        model = model.to(device)
        model.eval()
        vis_processor_cfg = cfg.datasets.cc_sbu_align.vis_processor.train
        vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)
        self.model = model
        self.vis_processor = vis_processor

        self.CONV_VISION = CONV_VISION_minigptv2 if self.mode == 'v2' else CONV_VISION_Vicuna0
        stop_words_ids = [[835], [2277, 29937]]
        stop_words_ids = [torch.tensor(ids).to(device) for ids in stop_words_ids]
        self.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])

    def generate_inner(self, message, dataset=None):
        from minigpt4.conversation.conversation import Chat
        prompt, image_path = self.message_to_promptimg(message, dataset=dataset)
        if self.mode == 'v2':
            chat = Chat(self.model, self.vis_processor, device=self.device)
        else:
            chat = Chat(self.model, self.vis_processor, device=self.device, stopping_criteria=self.stopping_criteria)

        chat_state = self.CONV_VISION.copy()
        img_list = []
        _ = chat.upload_img(image_path, chat_state, img_list)
        chat.encode_img(img_list)
        chat.ask(prompt, chat_state)
        with torch.inference_mode():
            msg = chat.answer(conv=chat_state, img_list=img_list)[0]
        return msg
