import os

import torch
from PIL import Image
from abc import abstractproperty
from .base import BaseModel
from ..dataset import DATASET_TYPE
from ..smp import *


class Parrot(BaseModel):
    INSTALL_REQ = False
    INTERLEAVE = False

    def __init__(self, model_path='AIDC-AI/Parrot-7B', **kwargs):
        """Import the Parrot package, load the model onto CUDA and fix the generation config.

        Args:
            model_path: Checkpoint identifier/path forwarded to
                ``ParrotMetaForCausalLM.build``.
            **kwargs: Accepted for interface compatibility; this constructor
                does not read them.

        Raises:
            Exception: Whatever the Parrot imports or the model loading raise is
                re-raised after a critical log message.
        """
        # Parrot is imported lazily so that the module itself can be imported
        # without the package being installed.
        try:
            from parrot.model.parrot_arch import ParrotMetaForCausalLM
            from parrot.utils.constants import DEFAULT_IMAGE_TOKEN, BEGIN_LINE, END_LINE
            from parrot.model.conversation_formatter import ConversationFormatter
            from parrot.utils.mm_utils import process_images
        except Exception as e:
            install_hints = (
                'Please install Parrot before using Parrot',
                'Please install Parrot from https://github.com/AIDC-AI/Parrot',
                'Using `pip install -e . --no-deps` in the Parrot directory',
                'Recommend to install transformers==4.39.0',
            )
            for hint in install_hints:
                logging.critical(hint)
            raise e

        # Keep handles to the Parrot helpers/constants used by the other methods.
        self.process_images = process_images
        self.ConversationFormatter = ConversationFormatter
        self.DEFAULT_IMAGE_TOKEN = DEFAULT_IMAGE_TOKEN
        self.BEGIN_LINE = BEGIN_LINE
        self.END_LINE = END_LINE

        try:
            built = ParrotMetaForCausalLM.build(
                'parrot_qwen2', model_path, mm_vision_tower='openai/clip-vit-large-patch14-336'
            )
            loaded_model, loaded_tokenizer, formatter = built
            self.model = loaded_model.cuda()
            self.vision_tower = self.model.get_vision_tower()
            self.tokenizer = loaded_tokenizer
            self.conversation_formatter = formatter
            self.image_processor = self.model.get_vision_tower().image_processor
        except Exception as e:
            logging.critical('Error when loading Parrot model:')
            raise e

        # Fixed generation settings: sampling disabled, single beam.
        self.kwargs = {
            'do_sample': False,
            'num_beams': 1,
            'max_new_tokens': 512,
            'repetition_penalty': None,
            'use_cache': True,
            'eos_token_id': self.tokenizer.eos_token_id,
            'pad_token_id': self.tokenizer.pad_token_id,
        }
        # Only the process whose LOCAL_RANK is 0 (or unset) reports the config.
        local_rank = int(os.environ.get('LOCAL_RANK', '0'))
        if local_rank == 0:
            print(f'Following kwargs {self.kwargs} will be used as generation config.')

        # Number of generate_inner calls so far; drives the periodic debug print.
        self.count = 0

    def use_custom_prompt(self, dataset):
        """Return True when ``DATASET_TYPE(dataset)`` is ``'Y/N'`` or ``'MCQ'``, else False."""
        return DATASET_TYPE(dataset) in ('Y/N', 'MCQ')

    def build_prompt(self, line, dataset=None):
        """Build the message (one text item followed by image items) for a record.

        Args:
            line: Dataset record handed to ``dump_image`` and to the
                type-specific prompt builder.
            dataset: Dataset name; must be a string for which
                ``use_custom_prompt`` returns True.

        Returns:
            A list whose first element is ``{'type': 'text', 'value': prompt}``
            followed by one ``{'type': 'image', 'value': path}`` per path
            returned by ``dump_image``.

        Raises:
            AssertionError: If the dataset is not handled by the custom prompt
                or is not a string.
            ValueError: If the dataset type is neither ``'Y/N'`` nor ``'MCQ'``.
        """
        assert self.use_custom_prompt(dataset)
        assert isinstance(dataset, str)
        image_paths = self.dump_image(line, dataset)

        dataset_type = DATASET_TYPE(dataset)
        if dataset_type == 'Y/N':
            prompt = self.built_yorn_prompt(line, dataset)
        elif dataset_type == 'MCQ':
            prompt = self.build_multi_choice_prompt(line, dataset)
        else:
            raise ValueError(f'Invalid dataset type: {DATASET_TYPE(dataset)}')

        text_item = {'type': 'text', 'value': prompt}
        image_items = [{'type': 'image', 'value': path} for path in image_paths]
        return [text_item, *image_items]

    def built_yorn_prompt(self, line, dataset=None):
        """Build the prompt for a yes/no record.

        The first matching legacy instruction suffix (if any) is removed from
        ``line['question']`` and a fixed yes/no instruction is appended, in
        Chinese when ``cn_string`` reports the question as Chinese and in
        English otherwise.

        Args:
            line: Record providing the ``'question'`` text.
            dataset: Unused; kept for a signature parallel to the other builders.

        Returns:
            The rewritten prompt string.
        """
        question = line['question']

        legacy_suffixes = (' Please answer yes or no.', ' Yes or No', ' Answer in one sentence.')
        # Only the first suffix that matches is stripped.
        matched = next((tail for tail in legacy_suffixes if question.endswith(tail)), None)
        if matched is not None:
            question = question[:-len(matched)]

        if cn_string(question):
            instruction = '\n请直接回答Yes或No。请用单个词或短语回答问题。'
        else:
            instruction = '\nPlease strictly answer Yes or No. Answer the question using a single word or phrase.'
        question += instruction
        return question

    def build_multi_choice_prompt(self, line, dataset=None):
        question = line['question']
        hint = line['hint'] if ('hint' in line and not pd.isna(line['hint'])) else None
        if hint is not None:
            question = hint + '\n' + question

        options = {
            cand: line[cand]
            for cand in string.ascii_uppercase
            if cand in line and not pd.isna(line[cand])
        }
        for key, item in options.items():
            question += f'\n{key}. {item}'
        prompt = question

        if len(options):
            default_prompt = "\nAnswer with the option's letter from the given choices directly."
            if dataset[-3:] == '_cn' or cn_string(prompt):
                default_prompt = '\n请直接用给定选项中的选项字母回答。'
            elif dataset[-3:] == '_pt':
                default_prompt = '\nResponda diretamente com a letra da opção das escolhas dadas.'
            elif dataset[-3:] == '_ar':
                default_prompt = '\nأجب مباشرةً بحرف الخيار من الاختيارات المعطاة.'
            elif dataset[-3:] == '_ru':
                default_prompt = '\nОтветьте буквой варианта из предложенных вариантов напрямую.'
            elif dataset[-3:] == '_tr':
                default_prompt = '\nVerilen seçeneklerden doğrudan seçeneğin harfi ile cevap verin.'
            prompt += default_prompt
            # prompt += (
            #     '\n请直接回答选项字母。' if cn_string(prompt) else
            #     "\nAnswer with the option's letter from the given choices directly."
            # )
        else:
            prompt += '\n请用单个词或短语回答问题。' if cn_string(
                prompt) else '\nAnswer the question using a single word or phrase.'

        return prompt

    def process_answer_prefix(self, answer, prefixes):
        for prefix in prefixes:
            if prefix in answer.lower():
                return answer[answer.lower().find(prefix) + len(prefix):]
        return answer

    def generate_inner(self, message, dataset=None):
        """Generate a response for ``message`` and return the extracted answer.

        Args:
            message: List of dicts with ``type`` and ``value`` keys, as consumed
                by ``prepare_inputs``.
            dataset: Not used by this method; kept for interface compatibility.

        Returns:
            The decoded response with surrounding whitespace removed. When the
            query ends with one of the multiple-choice instructions emitted by
            ``build_multi_choice_prompt``, the response is additionally reduced
            to a single option letter where the heuristics below match.
        """
        from types import SimpleNamespace

        query, image_paths = self.prepare_inputs(message)

        images_list = []
        for image_path in image_paths:
            # `convert` returns a new image, so the file can be closed right away.
            with Image.open(image_path) as image:
                images_list.append(image.convert('RGB'))

        # `process_images` takes a config-like object; only `image_aspect_ratio`
        # is supplied here.
        args = SimpleNamespace(image_aspect_ratio='pad')
        image_tensors = self.process_images(images_list, self.image_processor, args).cuda()
        prompt, input_ids = self.conversation_formatter.format_query(query)
        input_ids = input_ids.unsqueeze(0).cuda()

        with torch.inference_mode():
            # Entries of self.kwargs take precedence over `images` on a key clash.
            gen_kwargs = {'images': image_tensors, **self.kwargs}
            output_ids = self.model.generate(input_ids, **gen_kwargs)

        # The output is expected to start with the input ids; warn if it does not.
        input_token_len = input_ids.shape[1]
        n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
        if n_diff_input_output > 0:
            print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
        response = self.tokenizer.batch_decode(output_ids[:, input_token_len:],
                                               skip_special_tokens=True)[0].strip(string.whitespace)
        answer = response

        # Every answer-format instruction that build_multi_choice_prompt can
        # append when options are present, plus the older Chinese wording that
        # was recognised here before.
        mcq_suffixes = (
            "Answer with the option's letter from the given choices directly.",
            '请直接用给定选项中的选项字母回答。',
            'Responda diretamente com a letra da opção das escolhas dadas.',
            'أجب مباشرةً بحرف الخيار من الاختيارات المعطاة.',
            'Ответьте буквой варианта из предложенных вариантов напрямую.',
            'Verilen seçeneklerden doğrudan seçeneğin harfi ile cevap verin.',
            '请直接回答选项字母。',
        )

        if query.endswith(mcq_suffixes):
            qtype = 'multiple-choice'
            separators = string.whitespace + string.punctuation
            while True:
                answer = answer.strip(string.punctuation + string.whitespace)
                if len(answer) <= 1:
                    break
                if answer[0] in string.ascii_uppercase and answer[1] in separators:
                    # Leading option letter followed by a separator, e.g. "B. cat".
                    answer = answer[0]
                    break
                if answer[-1] in string.ascii_uppercase and answer[-2] in separators:
                    # Trailing option letter preceded by a separator, e.g. "... is B".
                    answer = answer[-1]
                    break
                if not listinstr(['answer is', 'answer:'], answer.lower()):
                    break
                # Drop the lead-in phrase and re-run the checks on the remainder;
                # each pass strictly shortens the string, so the loop terminates.
                answer = self.process_answer_prefix(answer, ['answer is', 'answer:'])
                answer = self.process_answer_prefix(answer, ['option'])
        else:
            qtype = 'open'

        # Print a debug sample every 50th call, only when LOCAL_RANK is 0 or unset.
        if self.count % 50 == 0 and int(os.environ.get('LOCAL_RANK', '0')) == 0:
            print(f'\n{self.BEGIN_LINE}')
            print(f'image_paths: {image_paths}\n')
            print(f'prompt: {prompt}\n')
            print(f'qtype: {qtype}\n')
            print(f'output: {response}\n')
            print(f'answer: {answer}\n')
            print(f'{self.END_LINE}\n', flush=True)

        self.count += 1

        return answer

    def prepare_inputs(self, message):
        prompt = ''
        image_paths = []
        image_count = 0
        text_count = 0
        pure_text = ''
        for msg in message:
            if msg['type'] == 'text':
                text_count += 1
                prompt += msg['value']
                pure_text += msg['value']
            elif msg['type'] == 'image':
                image_count += 1
                prompt += self.DEFAULT_IMAGE_TOKEN
                image_paths.append(msg['value'])

        if image_count == 1 and text_count == 1:
            prompt = self.DEFAULT_IMAGE_TOKEN + '\n' + pure_text

        return prompt, image_paths
