from PIL import Image
import torch

from .base import BaseModel
from ..smp import *


class XGenMM(BaseModel):

    INSTALL_REQ = False
    INTERLEAVE = True

    def __init__(self, model_path='Salesforce/xgen-mm-phi3-mini-instruct-interleave-r-v1.5', **kwargs):
        try:
            from transformers import AutoModelForVision2Seq, AutoTokenizer, AutoImageProcessor
        except Exception as err:
            logging.critical('Please install the latest version transformers.')
            raise err

        model = AutoModelForVision2Seq.from_pretrained(
            model_path, device_map='cuda', trust_remote_code=True, torch_dtype='auto'
        ).eval()

        tokenizer = AutoTokenizer.from_pretrained(
            model_path, trust_remote_code=True, use_fast=False, legacy=False
        )
        tokenizer = model.update_special_tokens(tokenizer)
        tokenizer.eos_token = '<|end|>'
        tokenizer.padding_side = 'left'
        image_processor = AutoImageProcessor.from_pretrained(model_path, trust_remote_code=True)
        self.model = model
        self.image_processor = image_processor
        self.tokenizer = tokenizer
        self.kwargs = kwargs

    def apply_prompt_template(self, query):
        s = (
            '<|system|>\nA chat between a curious user and an artificial intelligence assistant. '
            "The assistant gives helpful, detailed, and polite answers to the user's questions.<|end|>\n"
            f'<|user|>\n{query}<|end|>\n<|assistant|>\n'
        )
        return s

    def generate_inner(self, message, dataset=None):

        content, images, image_sizes = '', [], []

        for msg in message:
            if msg['type'] == 'text':
                content += msg['value']
            elif msg['type'] == 'image':
                image = Image.open(msg['value']).convert('RGB')
                images.append(self.image_processor([image], image_aspect_ratio='anyres')['pixel_values'].to('cuda'))
                image_sizes.append(image.size)
                content += '<image> '

        inputs = {'pixel_values': [images]}
        prompt = self.apply_prompt_template(content)
        language_inputs = self.tokenizer([prompt], return_tensors='pt').to('cuda')
        inputs.update(language_inputs)

        generation_args = {
            'max_new_tokens': 1024,
            'temperature': 0.0,
            'do_sample': False,
            'top_p': None,
            'num_beams': 1
        }
        generation_args.update(self.kwargs)

        generate_ids = self.model.generate(
            **inputs, image_size=[image_sizes],
            pad_token_id=self.tokenizer.pad_token_id,
            eos_token_id=self.tokenizer.eos_token_id,
            **generation_args
        )

        # remove input tokens
        response = self.tokenizer.decode(generate_ids[0], skip_special_tokens=True).split('<|end|>')[0]

        return response
