import sys
import torch
from PIL import Image
from .base import BaseModel
from ..smp import *
from ..dataset import DATASET_TYPE


class mPLUG_Owl2(BaseModel):
    """VLMEvalKit wrapper around the mPLUG-Owl2 vision-language model.

    The checkpoint is loaded through the ``mplug_owl2`` package and moved to
    the GPU. Inputs are chat-style messages: lists of
    ``dict(type='text' | 'image', value=...)`` where image values are paths.
    """

    INSTALL_REQ = True
    INTERLEAVE = False

    def __init__(self, model_path='MAGAer13/mplug-owl2-llama2-7b', **kwargs):
        """Load the tokenizer, model and image processor.

        Args:
            model_path: checkpoint repo id or local path.
            **kwargs: overrides merged into the default generation config that
                is forwarded to ``model.generate``.

        Raises:
            Exception: re-raised when the ``mplug_owl2`` package cannot be
                imported.
        """
        try:
            from mplug_owl2.model.builder import load_pretrained_model
            from mplug_owl2.mm_utils import get_model_name_from_path
        except Exception as e:
            logging.critical('Please install mPLUG_Owl2 before using mPLUG_Owl2. ')
            raise e

        model_name = get_model_name_from_path(model_path)
        # Load on CPU first, then move the whole model to the GPU explicitly.
        tokenizer, model, image_processor, context_len = load_pretrained_model(
            model_path, None, model_name, load_8bit=False, load_4bit=False, device='cpu')

        self.model = model.cuda()
        self.device = self.model.device
        self.image_processor = image_processor
        # Left padding is the usual choice for decoder-only generation;
        # the EOS id doubles as the pad id.
        tokenizer.padding_side = 'left'
        tokenizer.pad_token_id = tokenizer.eos_token_id
        self.tokenizer = tokenizer
        self.context_len = context_len

        kwargs_default = dict(
            max_new_tokens=512, do_sample=False, num_beams=1,
            min_new_tokens=1, length_penalty=1, num_return_sequences=1)
        kwargs_default.update(kwargs)
        self.kwargs = kwargs_default
        warnings.warn(f'Following kwargs received: {self.kwargs}, will use as generation config. ')

    def use_custom_prompt(self, dataset):
        """Return True when this model builds its own prompt for ``dataset``.

        Only multiple-choice (``'MCQ'``) datasets get a custom prompt.
        """
        assert dataset is not None
        return DATASET_TYPE(dataset) == 'MCQ'

    def build_prompt(self, line, dataset=None):
        """Build the custom multiple-choice prompt for one dataset row.

        Args:
            line: one dataset row (presumably a pandas Series); reads
                ``question``, an optional ``hint`` and option columns
                ``A``..``Z`` whose values are not NaN.
            dataset: dataset name; must be an MCQ dataset.

        Returns:
            A message list: one text entry followed by one image entry per
            path returned by ``self.dump_image``.

        Raises:
            NotImplementedError: for non-MCQ datasets (not reachable once the
                ``use_custom_prompt`` assertion has passed).
        """
        assert dataset is None or isinstance(dataset, str)
        assert self.use_custom_prompt(dataset)
        tgt_path = self.dump_image(line, dataset)
        question = line['question']
        if DATASET_TYPE(dataset) != 'MCQ':
            raise NotImplementedError

        options = {
            cand: line[cand]
            for cand in string.ascii_uppercase
            if cand in line and not pd.isna(line[cand])
        }
        options_prompt = ''.join(f'{key}. {item}\n' for key, item in options.items())

        hint = line['hint'] if ('hint' in line and not pd.isna(line['hint'])) else None
        prompt = f'Hint: {hint}\n' if hint is not None else ''
        prompt += f'{question}\n'
        prompt += (
            f'{options_prompt}\nAnswer with the option’s letter from the given choices directly. '
            if len(options) else 'Answer the question directly. '
        )

        message = [dict(type='text', value=prompt)]
        message.extend([dict(type='image', value=s) for s in tgt_path])
        return message

    def _build_prompt_and_images(self, message, dataset=None):
        """Return ``(prompt_full, image_paths)`` in the USER/ASSISTANT template.

        A single-image message is flattened through ``message_to_promptimg``
        with the image placeholder placed before the text. Any other message
        (zero or several images) is rendered in order, with a ``<|image|>``
        placeholder at each image position.
        """
        num_images = sum(1 for msg in message if msg['type'] == 'image')
        if num_images == 1:
            prompt, image = self.message_to_promptimg(message, dataset=dataset)
            return f'USER: <|image|>{prompt} \nASSISTANT: ', [image]

        parts = ['USER: ']
        image_paths = []
        for msg in message:
            if msg['type'] == 'image':
                image_paths.append(msg['value'])
                parts.append('<|image|>')
            elif msg['type'] == 'text':
                parts.append(msg['value'])
        parts.append('\nASSISTANT: ')
        return ''.join(parts), image_paths

    @staticmethod
    def _load_square_image(fname):
        """Open ``fname`` as RGB and stretch it to a square of its longest edge.

        The resize does not preserve the aspect ratio.
        """
        # Context manager so the underlying file handle is released promptly.
        with Image.open(fname) as img:
            image = img.convert('RGB')
        max_edge = max(image.size)
        return image.resize((max_edge, max_edge))

    def generate_inner(self, message, dataset=None):
        """Generate an answer for one chat-style ``message``.

        Args:
            message: list of ``dict(type='text' | 'image', value=...)``;
                image values are file paths. Text-only messages are allowed.
            dataset: optional dataset name. VQA datasets use
                ``length_penalty=0``; MCQ datasets cap generation at 10 new
                tokens.

        Returns:
            The decoded continuation, stripped and truncated at the first
            ``'</s>'``.
        """
        from mplug_owl2.constants import IMAGE_TOKEN_INDEX
        from mplug_owl2.mm_utils import process_images, tokenizer_image_token

        # Per-call copy so dataset-specific tweaks never leak into self.kwargs.
        kwargs = cp.deepcopy(self.kwargs)
        dataset_type = DATASET_TYPE(dataset) if dataset is not None else None
        if dataset_type == 'VQA':
            kwargs['length_penalty'] = 0
        elif dataset_type == 'MCQ':
            kwargs['max_new_tokens'] = 10

        prompt_full, image_paths = self._build_prompt_and_images(message, dataset=dataset)

        if image_paths:
            images = [self._load_square_image(fname) for fname in image_paths]
            image_tensor = process_images(images, self.image_processor)
            image_tensor = image_tensor.to(self.device, dtype=torch.float16)
        else:
            # Text-only prompt: process_images stacks per-image tensors and
            # fails on an empty list, so pass no images to the model instead.
            # NOTE(review): relies on mplug_owl2's generate accepting
            # images=None (its multimodal input prep has a None branch) —
            # confirm against the installed version.
            image_tensor = None

        input_ids = tokenizer_image_token(
            prompt_full, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(self.device)

        with torch.inference_mode():
            output_ids = self.model.generate(
                input_ids=input_ids,
                images=image_tensor,
                output_hidden_states=True,
                use_cache=True,
                **kwargs)
        # Skip the first input_ids.shape[1] positions (the echoed prompt) and
        # decode only the newly generated tokens.
        answer = self.tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
        return answer.split('</s>')[0]
