import torch
from transformers import AutoModelForCausalLM

from .base import BaseModel
from ..dataset import DATASET_TYPE
from ..smp import *


class Ovis(BaseModel):
    """Evaluation wrapper around an Ovis multimodal model.

    The model is loaded with ``AutoModelForCausalLM.from_pretrained(..., trust_remote_code=True)``.
    The wrapper builds its own prompts for Y/N and MCQ datasets and converts a list of
    ``{'type': 'text'|'image', 'value': ...}`` message dicts into model inputs. It then
    runs greedy generation.
    """

    # NOTE(review): flags presumably consumed by BaseModel / the evaluation harness — confirm.
    INSTALL_REQ = False
    # Messages may interleave several text and image items; see `prepare_inputs`.
    INTERLEAVE = True

    def __init__(self, model_path='AIDC-AI/Ovis1.5-Llama3-8B', **kwargs):
        """Load the model onto the current CUDA device and set up tokenizers and generation settings.

        Args:
            model_path: Hub id or local path passed to ``AutoModelForCausalLM.from_pretrained``.
            **kwargs: Accepted for interface compatibility; currently ignored.
        """
        assert model_path is not None
        # Recommend to install `transformers==4.43.2` and `torch==2.1.2`.
        self.model_path = model_path
        # Requires CUDA: the model and every input tensor are placed on the current GPU.
        self.device = torch.cuda.current_device()
        self.dtype = torch.bfloat16
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_path,
            torch_dtype=self.dtype,
            multimodal_max_length=8192,
            trust_remote_code=True
        )
        self.model = self.model.eval().to(device=self.device)
        self.eos_token_id = self.model.generation_config.eos_token_id
        # Tokenizers and the conversation formatter are provided by the model's remote code.
        self.text_tokenizer = self.model.get_text_tokenizer()
        self.pad_token_id = self.text_tokenizer.pad_token_id
        self.visual_tokenizer = self.model.get_visual_tokenizer()
        self.conversation_formatter = self.model.get_conversation_formatter()
        # Marker inserted into the text query wherever an image belongs.
        self.image_placeholder = '<image>'
        # Greedy decoding. The sampling knobs are explicitly None, presumably so they do not
        # conflict with values from the model's generation_config — confirm.
        self.gen_kwargs = dict(
            max_new_tokens=1024,
            do_sample=False,
            top_p=None,
            top_k=None,
            temperature=None,
            repetition_penalty=None,
            eos_token_id=self.eos_token_id,
            pad_token_id=self.pad_token_id,
            use_cache=True
        )

    def use_custom_prompt(self, dataset):
        """Return True when this wrapper builds its own prompt for ``dataset`` (Y/N and MCQ types)."""
        return DATASET_TYPE(dataset) in ('Y/N', 'MCQ')

    def build_prompt(self, line, dataset=None):
        """Build the message list for one dataset record.

        The message is the text prompt followed by one image entry per dumped image path.
        For ``MMMU_*`` datasets it is re-split into an interleaved sequence by
        ``MMMUDataset.split_MMMU``.

        Args:
            line: A single dataset record (read with ``line['question']`` etc.).
            dataset: Dataset name; must be a string whose type is Y/N or MCQ.

        Returns:
            A list of ``dict(type=..., value=...)`` message items.
        """
        # Validate the argument type before handing it to DATASET_TYPE.
        assert isinstance(dataset, str)
        assert self.use_custom_prompt(dataset)
        tgt_path = self.dump_image(line, dataset)

        dataset_type = DATASET_TYPE(dataset)
        if dataset_type == 'Y/N':
            prompt = self.built_yorn_prompt(line, dataset)
        elif dataset_type == 'MCQ':
            prompt = self.build_multi_choice_prompt(line, dataset)
        else:
            raise RuntimeError(f'Invalid dataset type: {dataset_type}')
        message = [dict(type='text', value=prompt)]
        message.extend([dict(type='image', value=s) for s in tgt_path])

        # interleave dataset
        if dataset.startswith('MMMU_'):
            from .. import MMMUDataset
            message = MMMUDataset.split_MMMU(message)

        return message

    def built_yorn_prompt(self, line, dataset=None):
        """Build a yes/no prompt from ``line['question']``.

        The historical name ("built" rather than "build") is kept for compatibility.
        HallusionBench questions get an explicit "answer yes or no" instruction. A
        short-answer instruction is then appended, in Chinese if the prompt contains
        Chinese text and in English otherwise.
        """
        prompt = line['question']
        if listinstr(['HallusionBench'], dataset):
            prompt += ' Please answer yes or no.'
        # Chinese instruction means: "Answer the question with a single word or phrase."
        prompt += '\n请用单个词或短语回答问题。' if cn_string(
            prompt) else '\nAnswer the question using a single word or phrase.'
        return prompt

    def build_multi_choice_prompt(self, line, dataset=None):
        """Build a multiple-choice prompt.

        The prompt is the optional hint, the question, and one ``"X. option"`` line for each
        uppercase-letter column present and non-NaN in ``line``. An answer-format
        instruction follows, in Chinese if the prompt contains Chinese text.
        """
        question = line['question']
        hint = line['hint'] if ('hint' in line and not pd.isna(line['hint'])) else None
        if hint is not None:
            question = hint + '\n' + question

        # Option columns are named 'A', 'B', ...; missing/NaN columns are skipped.
        options = {
            cand: line[cand]
            for cand in string.ascii_uppercase
            if cand in line and not pd.isna(line[cand])
        }
        for key, item in options.items():
            question += f'\n{key}. {item}'
        prompt = question

        if len(options):
            # Chinese: "Answer directly with the option letter."
            prompt += '\n请直接回答选项字母。' if cn_string(
                prompt) else "\nAnswer with the option's letter from the given choices directly."
        else:
            # Chinese: "Answer the question directly."
            prompt += '\n请直接回答问题。' if cn_string(prompt) else '\nAnswer the question directly.'

        return prompt

    def generate_inner(self, message, dataset=None):
        """Run greedy generation for one message list and return the decoded, stripped text.

        Only the first sequence returned by ``generate`` is decoded, with special tokens
        removed. ``dataset`` is unused.
        """
        # The formatted prompt string returned by prepare_inputs is not needed here.
        _, input_ids, attention_mask, pixel_values = self.prepare_inputs(message)
        output_ids = self.model.generate(
            input_ids,
            pixel_values=pixel_values,
            attention_mask=attention_mask,
            **self.gen_kwargs
        )
        response = self.text_tokenizer.decode(output_ids[0], skip_special_tokens=True).strip()

        return response

    def prepare_inputs(self, message):
        """Convert a message list into ``(prompt, input_ids, attention_mask, pixel_values)``.

        Query layout:
        - Text only: the texts joined with newlines.
        - Exactly one image and one text: the image placeholder, a newline, then the text.
        - Anything else: every item in message order, each image replaced by the
          placeholder, joined with newlines.

        ``input_ids`` and ``attention_mask`` get a leading batch dimension via
        ``unsqueeze(0)`` and are moved to ``self.device``. ``pixel_values`` is a one-element
        list holding either ``None`` (no images) or all preprocessed images concatenated
        along dim 0, cast to ``self.dtype``.
        """
        # build query
        images = [x['value'] for x in message if x['type'] == 'image']
        texts = [x['value'] for x in message if x['type'] == 'text']
        if len(images) == 0:
            query = '\n'.join(texts)
        elif len(images) == 1 and len(texts) == 1:
            query = self.image_placeholder + '\n' + texts[0]
        else:  # interleave sample
            chunks = [x['value'] if x['type'] == 'text' else self.image_placeholder for x in message]
            query = '\n'.join(chunks)

        # format conversation
        # assumes format_query returns a 1-D id tensor (batch dim added below) — confirm.
        prompt, input_ids = self.conversation_formatter.format_query(query)
        attention_mask = torch.ne(input_ids, self.text_tokenizer.pad_token_id)
        input_ids = input_ids.unsqueeze(0).to(device=self.device)
        attention_mask = attention_mask.unsqueeze(0).to(device=self.device)

        # preprocess images
        if len(images) == 0:
            pixel_values = [None]
        else:
            preprocessed_images = []
            for image_path in images:
                # Close each file once it has been turned into a tensor. A bare
                # Image.open() leaks one file handle per image per call.
                with Image.open(image_path) as img:
                    preprocessed_images.append(self.visual_tokenizer.preprocess_image(img))
            pixel_values = [torch.cat(preprocessed_images, dim=0).to(device=self.device, dtype=self.dtype)]

        return prompt, input_ids, attention_mask, pixel_values
