import torch
import sys
import os.path as osp
import warnings
from .base import BaseModel
from transformers import StoppingCriteriaList

from PIL import Image
from huggingface_hub import snapshot_download
from vlmeval.smp import *

model_cfgs = {
    'XVERSE-V-13B': {
        'arch': 'vxverse',
        'model_type': 'pretrain_xverse13b-chat',
        'max_txt_len': 512,
        'end_sym': '<|endoftext|>',
        'low_resource': False,
        'prompt_template': 'Human: {}\nAssistant: ',
        'ckpt': 'xverse/XVERSE-V-13B',
        'lora_r': 128,
        'lora_alpha': 256,
        'lora_dropout': 0.05,
        'lora_target_modules': 'all_linear',
        'has_qformer': False,
        'n_proj_layers': 2,
        'vit_model': 'openai/clip-vit-large-patch14',
        'vit_path': 'openai/clip-vit-large-patch14',
        'image_size': 224,
        'drop_path_rate': 0,
        'vit_precision': 'fp16',
        'llama_model': 'xverse/XVERSE-13B-Chat',
    }
}


class VXVERSE(BaseModel):

    INSTALL_REQ = True
    INTERLEAVE = False

    def __init__(self, model_name='XVERSE-V-13B', root=None, **kwargs):
        from omegaconf import OmegaConf
        if root is None:
            warnings.warn('Please set root to the directory of vxverse.')

        if model_name == 'XVERSE-V-13B':
            cfg = model_cfgs['XVERSE-V-13B']
        else:
            raise NotImplementedError

        ckpt_dir = cfg['ckpt']
        if not osp.isdir(ckpt_dir):
            cache_path = get_cache_path(ckpt_dir)
            if cache_path is not None:
                ckpt_dir = cache_path
            else:
                ckpt_dir = snapshot_download(repo_id=ckpt_dir)
        assert osp.exists(ckpt_dir) and osp.isdir(ckpt_dir)
        ckpt = osp.join(ckpt_dir, 'adapter_and_lora.bin')
        cfg['ckpt'] = ckpt
        model_cfg = OmegaConf.create(cfg)

        self.model_name = model_name

        self.root = root
        sys.path.append(self.root)

        from vxverse.common.registry import registry
        from vxverse.conversation.conversation import CONV_VISION_XVERSE

        device = torch.cuda.current_device()
        self.device = device

        model_cls = registry.get_model_class(model_cfg.arch)
        model = model_cls.from_config(model_cfg)
        model = model.to(device)
        model.eval()
        vis_processor_cfg = OmegaConf.create(dict(name='hd_image_train', image_size=224))
        vis_processor = registry.get_processor_class(
            vis_processor_cfg.name
        ).from_config(vis_processor_cfg)

        self.model = model
        self.vis_processor = vis_processor
        self.vis_processor_cfg = vis_processor_cfg

        self.CONV_VISION = CONV_VISION_XVERSE
        self.CONV_VISION.system = ''
        stop_words_ids = [[835], [2277, 29937]]
        self.stop_words_ids = stop_words_ids
        default_kwargs = dict(max_new_tokens=512)
        default_kwargs.update(kwargs)
        self.kwargs = default_kwargs

    def generate_inner(self, message, dataset=None):
        prompt, image_path = self.message_to_promptimg(message, dataset=dataset)

        image = Image.open(image_path).convert('RGB')
        image = self.vis_processor(image)

        if self.vis_processor_cfg.name == 'hd_image_train':
            patches_per_image = [[image.shape[0]]]
            image = [image]
        else:
            patches_per_image = None
            image = image.unsqueeze(0)

        chat_state = self.CONV_VISION.copy()
        texts = self.prepare_texts([prompt], chat_state)
        texts = [text.lstrip() for text in texts]
        answers = self.model.generate(
            image,
            texts,
            patches_per_images=patches_per_image,
            do_sample=False,
            stop_words_ids=self.stop_words_ids,
            **self.kwargs
        )
        return answers[0]

    def prepare_texts(self, texts, conv_temp):
        convs = [conv_temp.copy() for _ in range(len(texts))]
        [
            conv.append_message(conv.roles[0], '<ImageHere>\n{}'.format(text))
            for conv, text in zip(convs, texts)
        ]
        [conv.append_message(conv.roles[1], None) for conv in convs]
        texts = [conv.get_prompt() for conv in convs]
        return texts
