from ..smp import *
from ..dataset import img_root_map, DATASET_TYPE
from abc import abstractmethod


class BaseModel:
    """Common base class for the model wrappers.

    Subclasses implement `generate_inner` (and optionally `use_custom_prompt`, `build_prompt`,
    `chat_inner`). The helpers below normalize raw input messages into a list of
    ``dict(type=..., value=...)`` items and convert such lists into the shapes that individual
    backends expect.

    Note: this is a plain class (no `abc.ABC` base / `ABCMeta` metaclass), so the
    `@abstractmethod` decorators below are NOT enforced at instantiation time; the decorated
    methods simply raise `NotImplementedError` if a subclass does not override them.
    """

    # Whether the model accepts interleaved image/text input. `message_to_promptimg` asserts
    # this is False, since it flattens the message into one prompt plus at most one image.
    INTERLEAVE = False
    # Item types that `generate` accepts after preprocessing.
    allowed_types = ['text', 'image', 'video']

    def __init__(self):
        # Callable installed via `set_dump_image`; invoked by `dump_image`.
        self.dump_image_func = None

    def use_custom_prompt(self, dataset):
        """Whether to use custom prompt for the given dataset.

        Args:
            dataset (str): The name of the dataset.

        Returns:
            bool: Whether to use custom prompt. If True, will call `build_prompt` of the VLM to build the prompt.
                Default to False.
        """
        return False

    @abstractmethod
    def build_prompt(self, line, dataset):
        """Build custom prompts for a specific dataset. Called only if `use_custom_prompt` returns True.

        Args:
            line (line of pd.DataFrame): The raw input line.
            dataset (str): The name of the dataset.

        Returns:
            str: The built message.
        """
        raise NotImplementedError

    def set_dump_image(self, dump_image_func):
        """Install the callable used by `dump_image`."""
        self.dump_image_func = dump_image_func

    def dump_image(self, line, dataset):
        """Delegate to the installed dump-image callable.

        Args:
            line: The raw input line, passed through to the callable unchanged.
            dataset (str): Accepted for interface compatibility; not used here.

        Returns:
            Whatever the installed callable returns.
        """
        return self.dump_image_func(line)

    @abstractmethod
    def generate_inner(self, message, dataset=None):
        """Model-specific generation on an already-preprocessed ``listdict`` message."""
        raise NotImplementedError

    def check_content(self, msgs):
        """Classify the input. Four valid kinds: 'str', 'dict', 'liststr', 'listdict'.

        Returns 'unknown' for anything else (including lists mixing str and dict items).
        An empty list is classified as 'liststr' because `all()` of an empty sequence is True.
        """
        if isinstance(msgs, str):
            return 'str'
        if isinstance(msgs, dict):
            return 'dict'
        if isinstance(msgs, list):
            types = [self.check_content(m) for m in msgs]
            if all(t == 'str' for t in types):
                return 'liststr'
            if all(t == 'dict' for t in types):
                return 'listdict'
        return 'unknown'

    def preproc_content(self, inputs):
        """Convert the raw input messages to a list of dicts.

        The caller's objects are not mutated: 'listdict' items whose value is rewritten by
        `parse_file` are copied before the new value is stored.

        Args:
            inputs: raw input messages.

        Returns:
            list(dict): The preprocessed input messages. Will return None if failed to preprocess the input.
        """
        kind = self.check_content(inputs)
        if kind == 'str':
            return [dict(type='text', value=inputs)]
        if kind == 'dict':
            assert 'type' in inputs and 'value' in inputs
            return [inputs]
        if kind == 'liststr':
            res = []
            for s in inputs:
                mime, pth = parse_file(s)
                if mime is None or mime == 'unknown':
                    res.append(dict(type='text', value=s))
                else:
                    # e.g. 'image/png' -> item type 'image'
                    res.append(dict(type=mime.split('/')[0], value=pth))
            return res
        if kind == 'listdict':
            res = []
            for item in inputs:
                assert 'type' in item and 'value' in item
                mime, s = parse_file(item['value'])
                if mime is None:
                    assert item['type'] == 'text'
                    res.append(item)
                else:
                    assert mime.split('/')[0] == item['type']
                    # Copy instead of assigning into `item` so the caller's message is untouched.
                    res.append(dict(item, value=s))
            return res
        return None

    def generate(self, message, dataset=None):
        """Generate the output message.

        Args:
            message (list[dict]): The input message.
            dataset (str, optional): The name of the dataset. Defaults to None.

        Returns:
            str: The generated message.
        """
        assert self.check_content(message) in ['str', 'dict', 'liststr', 'listdict'], f'Invalid input type: {message}'
        message = self.preproc_content(message)
        assert message is not None and self.check_content(message) == 'listdict'
        for item in message:
            assert item['type'] in self.allowed_types, f'Invalid input type: {item["type"]}'
        return self.generate_inner(message, dataset)

    def chat(self, messages, dataset=None):
        """The main function for multi-turn chatting. Will call `chat_inner` with the preprocessed input messages.

        If `chat_inner` raises, the oldest turn is dropped (and any following non-user turns) and
        the call is retried, until the history is empty. The caller's `messages` list and its
        dicts are not modified; preprocessing works on shallow copies of each turn.

        Returns:
            The `chat_inner` result, or a fixed failure string if every truncation failed.
        """
        assert hasattr(self, 'chat_inner'), 'The API model should has the `chat_inner` method. '
        processed = []
        for msg in messages:
            assert isinstance(msg, dict) and 'role' in msg and 'content' in msg, msg
            assert self.check_content(msg['content']) in ['str', 'dict', 'liststr', 'listdict'], msg
            processed.append(dict(msg, content=self.preproc_content(msg['content'])))
        messages = processed

        while messages:
            try:
                return self.chat_inner(messages, dataset=dataset)
            except Exception as e:
                # Deliberate best-effort: log and retry with a shorter history.
                logging.info(f'{type(e)}: {e}')
                messages = messages[1:]
                # Keep the retained history starting on a user turn.
                while messages and messages[0]['role'] != 'user':
                    messages = messages[1:]
        return 'Chat Mode: Failed with all possible conversation turns.'

    def message_to_promptimg(self, message, dataset=None):
        """Flatten a message into one newline-joined text prompt and at most one image.

        Only valid for non-interleaved models. For dataset 'BLINK' all images are concatenated
        via `concat_images_vlmeval(..., target_size=512)`; otherwise the first image is used.

        Returns:
            tuple: (prompt, image), where image is None if the message has no image item.
        """
        assert not self.INTERLEAVE
        model_name = self.__class__.__name__
        warnings.warn(
            f'Model {model_name} does not support interleaved input. '
            'Will use the first image and aggregated texts as prompt. ')
        prompt = '\n'.join([x['value'] for x in message if x['type'] == 'text'])
        images = [x['value'] for x in message if x['type'] == 'image']
        if not images:
            image = None
        elif dataset == 'BLINK':
            image = concat_images_vlmeval(images, target_size=512)
        else:
            image = images[0]
        return prompt, image

    def message_to_promptvideo(self, message):
        """Flatten a message into a newline-joined text prompt and the first video (or None).

        Raises:
            NotImplementedError: if the model does not set a truthy `VIDEO_LLM` attribute.
        """
        if not getattr(self, 'VIDEO_LLM', False):
            logging.critical('Model does not support video input.')
            raise NotImplementedError
        prompt = '\n'.join([x['value'] for x in message if x['type'] == 'text'])
        videos = [x['value'] for x in message if x['type'] == 'video']
        video = videos[0] if videos else None
        return prompt, video

    def message_to_promptvideo_withrole(self, message, dataset=None):
        """Split a message into role-keyed text (system / user / assistant) plus the first video.

        Text items are concatenated per `role` ('system', 'assistant'; anything else counts as
        user). If no assistant text is given, the 'assistant' key is set to 'Best Option: (' for
        datasets whose `DATASET_TYPE` contains 'MCQ', and removed otherwise.

        Returns:
            tuple: (question dict, first video or None if the message has no video item).

        Raises:
            NotImplementedError: if the model does not set a truthy `VIDEO_LLM` attribute.
        """
        if not getattr(self, 'VIDEO_LLM', False):
            logging.critical('Model does not support video input.')
            raise NotImplementedError
        system, user, assistant, video_list = '', '', '', []
        for msg in message:
            if msg['type'] == 'text':
                role = msg.get('role')
                if role == 'system':
                    system += msg['value']
                elif role == 'assistant':
                    assistant += msg['value']
                else:
                    user += msg['value']
            elif msg['type'] == 'video':
                video_list.append(msg['value'])
        question = {
            'system': system,
            'user': user,
            'assistant': assistant
        }
        if assistant == '':
            if listinstr(['MCQ'], DATASET_TYPE(dataset)):
                question['assistant'] = 'Best Option: ('
            else:
                del question['assistant']
        if len(video_list) > 1:
            logging.warning('VLMEvalKit only support single video as input, take first video as input')
        # Mirror `message_to_promptvideo`: no video item -> None instead of IndexError.
        video = video_list[0] if video_list else None
        return question, video

    def message_to_lmdeploy(self, messages, system_prompt=None):
        """Build an lmdeploy-style chat request from a ``listdict`` message.

        Text values are concatenated into one prompt, with `IMAGE_TOKEN` inserted at each image
        position; each image is loaded, converted to RGB and embedded as a base64 JPEG data URL.

        Returns:
            list: A one-element list wrapping the conversation (an optional system turn followed
                by the user turn) — presumably a batch of size one; confirm against lmdeploy callers.
        """
        from lmdeploy.vl.constants import IMAGE_TOKEN
        from PIL import Image
        prompt, image_path = '', []
        for msg in messages:
            if msg['type'] == 'text':
                prompt += msg['value']
            elif msg['type'] == 'image':
                prompt += IMAGE_TOKEN
                image_path.append(msg['value'])
        content = [{'type': 'text', 'text': prompt}]
        for image in image_path:
            # Close the file handle promptly; `convert` returns an independent, fully loaded image.
            with Image.open(image) as raw:
                img = raw.convert('RGB')
            b64 = encode_image_to_base64(img)
            img_struct = dict(url=f'data:image/jpeg;base64,{b64}')
            content.append(dict(type='image_url', image_url=img_struct))
        ret = []
        if system_prompt is not None:
            ret.append(dict(role='system', content=system_prompt))
        ret.append(dict(role='user', content=content))
        return [ret]
