import torch
from PIL import Image
import re
from transformers import AutoModel, AutoProcessor

from .base import BaseModel
from ..smp import *
from ..dataset import DATASET_TYPE


class OmChat(BaseModel):
    """Evaluation wrapper around an OmChat checkpoint loaded via `transformers`.

    The model and processor are loaded with `AutoModel` / `AutoProcessor`
    (with `trust_remote_code=True`). Custom prompts are built for datasets
    whose `DATASET_TYPE` is 'MCQ' or 'Y/N', and `generate_inner` runs
    generation and post-processes the decoded answer.
    """

    # NOTE(review): presumably tells the framework that extra packages must be
    # installed before this model can be used -- confirm against BaseModel.
    INSTALL_REQ = True
    # NOTE(review): presumably declares that messages may mix several image and
    # text entries -- confirm against BaseModel.
    INTERLEAVE = True

    def __init__(self, model_path='omlab/omchat-v2.0-13B-single-beta_hf', **kwargs):
        """Set up prompt templates and load the OmChat model and processor.

        Recommend to install `transformers==4.44.0`.

        Args:
            model_path: Identifier or path handed to `from_pretrained` for both
                the model and the processor. Must not be None.
            **kwargs: Kept on `self.kwargs`; `generate_inner` merges them over
                its default generation arguments.
        """
        assert model_path is not None

        # --- system prompts -------------------------------------------------
        self.default_system_prompt = 'You are a helpful assistant. Focus on accuracy and reliability in your response.'
        self.new1_system_prompt = 'You are a helpful assistant.'
        self.new2_system_prompt = (
            'Read the following question carefully, '
            'solve it step by step, '
            'and then output the final answer in the format of '
            "'Answer: single number or single word or phrase'.\n\n"
        )

        # --- prompt suffixes ------------------------------------------------
        # Multiple-choice questions (English / Chinese variants).
        self.mcq_suffix_prompt_en = 'Please select the correct answer from the options above. \n'
        self.mcq_suffix_prompt_cn = '请直接回答选项字母。\n'
        # Yes/no questions.
        self.yorn_suffix_prompt = ' Please answer yes or no. Answer the question using a single word or phrase.'

        # --- model / processor ----------------------------------------------
        self.model_path = model_path
        self.kwargs = kwargs
        print(f'load from {self.model_path}')
        # Load the weights in fp16, move them to the GPU and switch to eval mode.
        self.model = AutoModel.from_pretrained(
            self.model_path,
            trust_remote_code=True,
            torch_dtype=torch.float16,
        ).cuda().eval()
        self.processor = AutoProcessor.from_pretrained(
            self.model_path,
            trust_remote_code=True,
        )
        torch.cuda.empty_cache()

    def use_custom_prompt(self, dataset):
        """Report whether this model builds its own prompt for `dataset`.

        Args:
            dataset: Dataset name; must not be None.

        Returns:
            bool: True when `DATASET_TYPE(dataset)` is 'MCQ' or 'Y/N',
            False otherwise.
        """
        assert dataset is not None
        handled_types = ('MCQ', 'Y/N')
        # `any` stops at the first match, so the type is looked up once for
        # MCQ datasets and twice otherwise.
        return any(DATASET_TYPE(dataset) == kind for kind in handled_types)

    def build_prompt(self, line, dataset=None):
        """Build the image + text message for one MCQ or Y/N sample.

        Args:
            line: A dataset row exposing at least `question` (plus optional
                `hint` and option columns `A`..`Z` for MCQ), or an integer
                index that is resolved through `self.data.iloc`.
            dataset: Dataset name. It must be one for which
                `use_custom_prompt` returns True.

        Returns:
            list[dict]: one `dict(type='image', value=path)` per dumped image,
            followed by a single `dict(type='text', value=prompt)`.

        Raises:
            AssertionError: if `dataset` is not a string naming a supported
                dataset.
            ValueError: if the dataset type is neither 'MCQ' nor 'Y/N'. This is
                only reachable when assertions are disabled (`python -O`).
        """
        assert dataset is None or isinstance(dataset, str)
        assert self.use_custom_prompt(dataset)
        # NOTE(review): `dump_image` receives `line` before an integer index is
        # resolved to a row -- confirm that `dump_image` accepts an index.
        tgt_path = self.dump_image(line, dataset)

        if isinstance(line, int):
            # NOTE(review): `self.data` is not set in this class; presumably it
            # is provided by the base class or the framework -- confirm.
            line = self.data.iloc[line]

        question = line['question']
        # Look the type up once instead of once per branch.
        dataset_type = DATASET_TYPE(dataset)

        if dataset_type == 'MCQ':
            has_hint = 'hint' in line and not pd.isna(line['hint'])
            hint = line['hint'] if has_hint else None
            # Keep only the option letters that are present and not NaN.
            options = {
                cand: line[cand]
                for cand in string.ascii_uppercase
                if cand in line and not pd.isna(line[cand])
            }

            parts = []
            if hint is not None:
                parts.append(f'Hint: {hint}\n')
            parts.append(f'Question: {question}\n')
            if options:
                parts.append('Options:\n')
                parts.extend(f'{key}. {item}\n' for key, item in options.items())
            prompt = ''.join(parts)
        elif dataset_type == 'Y/N':
            prompt = question + self.yorn_suffix_prompt
        else:
            # Fail loudly rather than hitting an unbound `prompt` below when
            # the guarding asserts have been stripped.
            raise ValueError(
                f'Unsupported dataset type {dataset_type!r} for dataset {dataset!r}'
            )

        # `dump_image` may return a single path or a list of paths.
        image_paths = tgt_path if isinstance(tgt_path, list) else [tgt_path]
        message = [dict(type='image', value=p) for p in image_paths]
        message.append(dict(type='text', value=prompt))

        return message

    def message_to_promptimg(self, message, dataset=None):
        if dataset is None:
            prompt = '\n'.join([
                re.sub(r'<image\s*\d+>', '<image>', x['value'])
                for x in message
                if x['type'] == 'text'
            ])
            image = [x['value'] for x in message if x['type'] == 'image']
        else:
            prompt = '\n'.join([x['value'] for x in message if x['type'] == 'text'])
            image = [x['value'] for x in message if x['type'] == 'image']
        return prompt, image

    def generate_inner(self, message, dataset=None):
        """Run generation for one message and return the cleaned-up answer text.

        Args:
            message: Iterable of `dict(type=..., value=...)` entries; image
                values are opened as files with PIL, text values form the prompt.
            dataset: Dataset name, forwarded to `message_to_promptimg`.

        Returns:
            str: The decoded continuation after post-processing, applied in
            this order:
              1. cut at the first `<|im_end|>` marker,
              2. drop a leading `'Answer: '`,
              3. keep only what follows a `'\\nThe answer is:'` line,
              4. keep only the content of a `<doc>...</doc>` pair,
              5. drop one trailing `'.'`.
        """
        prompt, image_paths = self.message_to_promptimg(message, dataset=dataset)

        # Open every file inside a context manager so its handle is released
        # as soon as the RGB copy exists, instead of waiting for garbage
        # collection.
        images = []
        for img_path in image_paths:
            with Image.open(img_path) as img:
                images.append(img.convert('RGB'))

        inputs = self.processor(
            text=prompt,
            system_prompt=self.default_system_prompt,
            images=images,
            return_tensors='pt',
        ).to('cuda')

        # Defaults first, then user overrides from the constructor. Keeping
        # `eos_token_id` in the same dict lets a user-supplied value replace it
        # instead of raising "got multiple values for keyword argument".
        generation_kwargs = dict(
            max_new_tokens=1024,
            do_sample=False,
            temperature=0.0,
            top_p=1,
            eos_token_id=self.model.generation_config.eos_token_id,
        )
        generation_kwargs.update(self.kwargs)

        with torch.inference_mode():
            output_ids = self.model.generate(**inputs, **generation_kwargs)

        # Decode only the newly generated tokens, i.e. skip the prompt part.
        prompt_len = inputs.input_ids.shape[1]
        res = self.processor.tokenizer.decode(output_ids[0, prompt_len:]).strip()

        # Special tokens are not skipped while decoding, so cut at the
        # end-of-turn marker by hand.
        if '<|im_end|>' in res:
            res = res.split('<|im_end|>')[0].strip()

        if res.startswith('Answer: '):
            res = res[len('Answer: '):]

        match = re.search(r'\nThe answer is:(.+)', res)
        if match:
            res = match.group(1).strip()

        doc_match = re.search(r'<doc>(.*?)<\/doc>', res)
        if doc_match:
            res = doc_match.group(1).strip()

        # Remove a single trailing period, if present.
        if res.endswith('.'):
            res = res[:-1]

        return res
