import torch
from PIL import Image
from abc import abstractproperty
from .base import BaseModel
from ..smp import *
from ..dataset import DATASET_TYPE
import warnings


class Mantis(BaseModel):
    """
    Mantis Model
    This implementation is adapted from the Llava model in llava.py and the Idefics model in idefics.py.
    """
    INSTALL_REQ = True
    INTERLEAVE = True

    DEFAULT_IMAGE_TOKEN = '<image>'
    IMAGE_TOKEN_INDEX = -200

    def __init__(self, model_path='TIGER-Lab/Mantis-8B-siglip-llama3', **kwargs):
        """Load a Mantis checkpoint and its processor onto CUDA in float16.

        Three model families are supported, selected by substring of
        ``model_path`` (case-insensitive): 'idefics' uses the generic
        transformers Auto classes, 'fuyu' uses the Mantis MFuyu classes, and
        anything else uses the Mantis MLlava classes.

        Args:
            model_path: Checkpoint identifier or path handed to
                ``from_pretrained``. Must not be None.
            **kwargs: Generation options; they override the defaults
                (greedy decoding, ``max_new_tokens=1024``) and are later
                forwarded to ``model.generate``.

        Raises:
            Exception: Whatever importing the ``mantis`` package raises, after
                logging installation instructions.
            Exception: The stored ``transformers`` import error, if an idefics
                checkpoint is requested but the Auto classes could not be
                imported.
        """
        assert model_path is not None
        try:
            from mantis.models.mllava import LlavaForConditionalGeneration, MLlavaProcessor
            from mantis.models.mfuyu import MFuyuForCausalLM, MFuyuProcessor
            from mantis.models.conversation import conv_mllava_v1 as default_conv, conv_templates
        except Exception:
            logging.critical(
                "Mantis is not installed. Please install Mantis to use this model.Please use 'pip install "
                "git+https://github.com/TIGER-AI-Lab/Mantis.git' to install"
            )
            raise

        # The Auto classes are only needed for idefics checkpoints, so a failed
        # import is logged here and only becomes fatal if such a checkpoint is
        # actually requested below.
        transformers_import_error = None
        try:
            from transformers import AutoModelForVision2Seq, AutoProcessor
        except Exception as e:
            logging.critical(f'{type(e)}: {e}')
            logging.critical("Upgrade transformers to use Mantis's idefics model.\nError: %s" % e)
            # `e` is unbound once the except block ends; keep a reference.
            transformers_import_error = e

        # Attention backend: "sdpa", "eager" or "flash_attention_2" are possible.
        # FlashAttention-2 is chosen whenever the flash_attn package imports,
        # otherwise 'eager'. Notes kept from the original author: FA2 may bring
        # no benefit at inference time
        # (https://discuss.huggingface.co/t/flash-attention-has-no-effect-on-inference/73453/5)
        # and has been seen to fail during generation with
        # "query and key must have the same dtype".
        try:
            import flash_attn  # noqa: F401  (imported only to probe availability)
            attn_implementation = 'flash_attention_2'
        except ImportError:
            attn_implementation = 'eager'

        self.model_path = model_path
        lowered_path = model_path.lower()
        self._is_idefics = 'idefics' in lowered_path

        if self._is_idefics:
            if transformers_import_error is not None:
                # Surface the real cause instead of a NameError on AutoProcessor.
                raise transformers_import_error
            self.processor = AutoProcessor.from_pretrained(self.model_path)
            # The idefics path does not pass attn_implementation.
            model = AutoModelForVision2Seq.from_pretrained(
                self.model_path,
                device_map='cuda',
                torch_dtype=torch.float16
            )
        else:
            # "Non-idefics" Mantis models: Fuyu-based or Llava-based.
            if 'fuyu' in lowered_path:
                processor_cls, model_cls = MFuyuProcessor, MFuyuForCausalLM
            else:
                processor_cls, model_cls = MLlavaProcessor, LlavaForConditionalGeneration
            self.processor = processor_cls.from_pretrained(self.model_path)
            model = model_cls.from_pretrained(
                self.model_path,
                device_map='cuda',
                attn_implementation=attn_implementation,
                torch_dtype=torch.float16
            )

        model = model.eval()
        self.model = model.cuda()

        # Caller-supplied kwargs take precedence over the defaults.
        kwargs_default = dict(do_sample=False, temperature=0, max_new_tokens=1024, top_p=None, num_beams=1)
        kwargs_default.update(kwargs)
        self.kwargs = kwargs_default
        warnings.warn(f'Following kwargs received: {self.kwargs}, will use as generation config. ')

        self.tokenizer = self.processor.tokenizer
        # Conversation templates from the Mantis package, used to build prompts
        # for the non-idefics models.
        self.default_conv = default_conv
        self.conv_templates = conv_templates

    def use_custom_prompt(self, dataset):
        assert dataset is not None
        if DATASET_TYPE(dataset) == 'MCQ':
            return True
        return False

    def build_prompt(self, line, dataset=None):
        """Build the interleaved image/text message for one dataset record.

        The text part is the optional hint, the question, one line per
        available option (keys 'A'..'Z' with non-NaN values), and a closing
        instruction in Chinese or English depending on ``cn_string``.

        Args:
            line: Record providing 'question' and optionally 'hint' and
                single-letter option fields.
            dataset: Dataset name; must satisfy ``use_custom_prompt``.

        Returns:
            A list of dicts: one ``type='image'`` entry per element returned
            by ``dump_image``, followed by a single ``type='text'`` entry.
        """
        assert self.use_custom_prompt(dataset)
        assert dataset is None or isinstance(dataset, str)
        image_paths = self.dump_image(line, dataset)

        # Question stem, prefixed by the hint when one is present.
        stem = line['question']
        if 'hint' in line and not pd.isna(line['hint']):
            stem = line['hint'] + '\n' + stem

        # Collect the answer options that actually exist for this record.
        choices = {}
        for letter in string.ascii_uppercase:
            if letter in line and not pd.isna(line[letter]):
                choices[letter] = line[letter]

        prompt = stem
        for letter, text in choices.items():
            prompt += f'\n{letter}. {text}'

        # Closing instruction: language follows the prompt built so far.
        if choices:
            if cn_string(prompt):
                prompt += '\n请直接回答选项字母。'
            else:
                prompt += "\nAnswer with the option's letter from the given choices directly."
        else:
            if cn_string(prompt):
                prompt += '\n请直接回答问题。'
            else:
                prompt += '\nAnswer the question directly.'

        image_entries = [{'type': 'image', 'value': path} for path in image_paths]
        return [*image_entries, {'type': 'text', 'value': prompt}]

    def output_process(self, answer):
        if '<s>' in answer:
            answer = answer.replace('<s>', '').strip()
        if '[/INST]' in answer:
            answer = answer.split('[/INST]')[1].strip()
        elif 'ASSISTANT:' in answer:
            answer = answer.split('ASSISTANT:')[1].strip()
        elif 'assistant\n' in answer:
            answer = answer.split('assistant\n')[1].strip()
        elif '<|end_header_id|>\n\n' in answer:
            answer = answer.split('<|end_header_id|>\n\n')[2].strip()

        if '</s>' in answer:
            answer = answer.split('</s>')[0].strip()
        elif '<|im_end|>' in answer:
            answer = answer.split('<|im_end|>')[0].strip()
        elif '<|eot_id|>' in answer:
            answer = answer.split('<|eot_id|>')[0].strip()
        elif '<end_of_utterance>' in answer:
            answer = answer.split('<end_of_utterance>')[0].strip()
        elif '|ENDOFTEXT|' in answer:
            answer = answer.split('|ENDOFTEXT|')[0].strip()
        return answer

    def generate_inner(self, message, dataset=None):
        """Generate an answer for one interleaved image/text message.

        Args:
            message: List of dicts with 'type' and 'value' keys. Entries with
                ``type == 'text'`` contribute text; every other entry is
                treated as an image path and opened with PIL.
            dataset: Not used by this method.

        Returns:
            The decoded answer after ``output_process`` clean-up, with
            surrounding whitespace stripped.
        """
        # `content` is the Mantis-style prompt with an image placeholder per
        # image; `ide_content` / `question` feed the idefics chat template.
        content, images = '', []
        ide_content, question = [], ''
        for msg in message:
            if msg['type'] == 'text':
                content += msg['value']
                question += msg['value']
            else:
                # convert() returns a new image, so the opened file can be
                # released as soon as the conversion is done.
                with Image.open(msg['value']) as img:
                    images.append(img.convert('RGB'))
                content += (self.DEFAULT_IMAGE_TOKEN + '\n')
                ide_content.append({'type': 'image'})

        if self._is_idefics:
            # Follow the idefics implementation:
            ide_content.append({'type': 'text', 'text': question})
            prompt = [{'role': 'user', 'content': ide_content}]
            prompt = self.processor.apply_chat_template(prompt, add_generation_prompt=True)
        else:
            # Follow the Mantis code base to make sure they are consistent:
            # https://github.com/TIGER-AI-Lab/Mantis/blob/main/mantis/models/mllava/utils.py#L33
            # Users don't need to define chat template as it is done here
            if 'llama-3' in self.model.language_model.name_or_path.lower():
                conv = self.conv_templates['llama_3']
                terminators = [
                    self.processor.tokenizer.eos_token_id,
                    self.processor.tokenizer.convert_tokens_to_ids('<|eot_id|>')
                ]
            else:
                conv = self.default_conv
                terminators = [self.processor.tokenizer.eos_token_id]

            # Using EOT because end of *text* is more accurate for what we're doing than end of *sentence*.
            # A caller-supplied eos_token_id is respected; otherwise the value
            # is stored on self.kwargs and reused by later calls.
            self.kwargs.setdefault('eos_token_id', terminators)

            # Work on a copy so the shared template is never mutated.
            conv = conv.copy()
            conv.append_message(conv.roles[0], content)
            conv.append_message(conv.roles[1], '')
            assert conv.messages[-1][0] == conv.roles[1] and conv.messages[-1][1] == '', 'Format check'
            prompt = conv.get_prompt()

        inputs = self.processor(prompt, images, return_tensors='pt', truncation=True)
        # FIXME: Fuyu model would return a list instead of a pytorch tensor. This weird behavior needs fixing.
        if 'image_patches' in inputs.keys():
            inputs['image_patches'] = inputs['image_patches'][0]
        inputs = {k: v.to(self.model.device) for k, v in inputs.items()}

        output = self.model.generate(**inputs, **self.kwargs)
        output = output[0]
        # Drop the prompt tokens so only newly generated tokens are decoded.
        generated_ids = output[inputs['input_ids'].shape[-1]:]
        # The keyword is `skip_special_tokens` (plural); the misspelled
        # singular form did not trigger special-token stripping.
        answer = self.processor.decode(generated_ids, skip_special_tokens=True)
        # output_process only strips whitespace when it finds a marker, which
        # may no longer happen once special tokens are removed; strip here so
        # the result is trimmed either way.
        return self.output_process(answer).strip()
