import torch
from PIL import Image
from abc import abstractproperty
import sys
import os.path as osp
from .base import BaseModel
from ..smp import *
from ..dataset import DATASET_TYPE
import copy


class VILA(BaseModel):
    """Adapter that runs a VILA checkpoint through the ``llava`` package.

    The ``llava`` modules are imported lazily inside ``__init__`` so that
    importing this file does not by itself require VILA to be installed.
    """

    INSTALL_REQ = True
    INTERLEAVE = True

    # Size markers looked up in ``model_path``, highest priority first.
    # '3b' must come last because it is also a substring of '13b'.
    _CONV_MODE_BY_MARKER = (
        ('8b', 'llama_3'),
        ('13b', 'vicuna_v1'),
        ('40b', 'hermes-2'),
        ('3b', 'vicuna_v1'),
    )

    @classmethod
    def _resolve_conv_mode(cls, model_path):
        """Return the conversation template name matching ``model_path``.

        Raises:
            ValueError: if ``model_path`` contains none of the known size
                markers, so no conversation template can be selected.
        """
        for marker, conv_mode in cls._CONV_MODE_BY_MARKER:
            if marker in model_path:
                return conv_mode
        known = ', '.join(repr(marker) for marker, _ in cls._CONV_MODE_BY_MARKER)
        raise ValueError(
            f'Cannot infer the VILA conversation template from model_path={model_path!r}; '
            f'expected the path to contain one of: {known}.'
        )

    def __init__(self,
                 model_path='Efficient-Large-Model/Llama-3-VILA1.5-8b',
                 **kwargs):
        """Load the tokenizer, model and image processor for ``model_path``.

        Args:
            model_path: an existing local path, or a two-part ``org/name``
                identifier. It must contain a size marker ('3b', '8b', '13b'
                or '40b'), which selects the conversation template.
            **kwargs: overrides for the default generation arguments that are
                forwarded to ``model.generate``.

        Raises:
            ValueError: if no conversation template can be inferred from
                ``model_path``.
            Exception: import or model-loading errors are logged and re-raised.
        """
        try:
            from llava.model.builder import load_pretrained_model
            from llava.mm_utils import get_model_name_from_path
            from llava.mm_utils import process_images, tokenizer_image_token, KeywordsStoppingCriteria
            from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN # noqa E501
            from llava.conversation import conv_templates, SeparatorStyle
        except Exception:
            logging.critical('Please install VILA before using VILA')
            logging.critical('Please install VILA from https://github.com/NVlabs/VILA')
            logging.critical('Please install VLMEvalKit after installing VILA')
            logging.critical('VILA is supported only with transformers==4.36.2')
            raise

        warnings.warn('Please install the latest version of VILA from GitHub before you evaluate the VILA model.')
        assert osp.exists(model_path) or len(model_path.split('/')) == 2, (
            f'model_path must be an existing path or an "org/name" identifier, got {model_path!r}'
        )

        # Resolve the template before the expensive load so that an
        # unsupported path fails fast instead of at the first generation call.
        self.conv_mode = self._resolve_conv_mode(model_path)

        model_name = get_model_name_from_path(model_path)

        try:
            self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
                model_path=model_path,
                model_base=None,
                model_name=model_name,
                device='cpu',
                device_map='cpu'
            )
        except Exception:
            logging.critical('Error loading VILA model: ')
            raise

        # The checkpoint is loaded on CPU above and moved to the GPU afterwards.
        self.model = self.model.cuda()

        kwargs_default = dict(do_sample=False, temperature=0, max_new_tokens=2048, top_p=None, num_beams=1, use_cache=True) # noqa E501

        kwargs_default.update(kwargs)
        self.kwargs = kwargs_default
        warnings.warn(f'Using the following kwargs for generation config: {self.kwargs}')

        # Keep handles to the lazily imported llava helpers for generate_inner.
        self.conv_templates = conv_templates
        self.process_images = process_images
        self.tokenizer_image_token = tokenizer_image_token
        self.DEFAULT_IMAGE_TOKEN = DEFAULT_IMAGE_TOKEN
        self.SeparatorStyle = SeparatorStyle
        self.IMAGE_TOKEN_INDEX = IMAGE_TOKEN_INDEX
        self.KeywordsStoppingCriteria = KeywordsStoppingCriteria

    def use_custom_prompt(self, dataset):
        """Report whether a dataset-specific prompt is built; currently never."""
        assert dataset is not None
        # TODO see if custom prompt needed
        return False

    def generate_inner(self, message, dataset=None):
        """Generate a response for an interleaved text/image message.

        Args:
            message: iterable of dicts with a ``type`` key ('text' or 'image')
                and a ``value`` key (the text, or a path readable by
                ``PIL.Image.open``). Items of any other type are ignored.
            dataset: unused by this method.

        Returns:
            The decoded, stripped text of the first generated sequence.
        """
        pieces, images = [], []

        for msg in message:
            if msg['type'] == 'text':
                pieces.append(msg['value'])
            elif msg['type'] == 'image':
                # convert() yields an independent, fully loaded image, so the
                # underlying file can be closed right away.
                with Image.open(msg['value']) as img:
                    images.append(img.convert('RGB'))
                pieces.append(self.DEFAULT_IMAGE_TOKEN + '\n')
        content = ''.join(pieces)

        device = self.model.device

        # NOTE(review): a text-only message passes an empty image list here;
        # confirm that process_images / generate support that case.
        image_tensor = self.process_images(
            images, self.image_processor,
            self.model.config).to(device, dtype=torch.float16)

        # Support interleave text and image
        conv = self.conv_templates[self.conv_mode].copy()
        conv.append_message(conv.roles[0], content)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()

        # Place the prompt on the same device as the model and image tensor.
        input_ids = self.tokenizer_image_token(prompt, self.tokenizer, self.IMAGE_TOKEN_INDEX,
                                               return_tensors='pt').unsqueeze(0).to(device)

        stop_str = conv.sep if conv.sep_style != self.SeparatorStyle.TWO else conv.sep2
        keywords = [stop_str]
        stopping_criteria = self.KeywordsStoppingCriteria(keywords, self.tokenizer, input_ids)

        with torch.inference_mode():
            output_ids = self.model.generate(
                input_ids, images=image_tensor, stopping_criteria=[stopping_criteria], **self.kwargs)

            output = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
        return output
