import math
import torch
import random
import numpy as np
from PIL import Image
from transformers import AutoModel, AutoTokenizer

from .base import BaseModel
from ..smp import *
from ..dataset import DATASET_TYPE, DATASET_MODALITY

import re


class MiniCPM_V(BaseModel):

    INSTALL_REQ = False
    INTERLEAVE = False

    def __init__(self, model_path='openbmb/MiniCPM-V', **kwargs):
        assert model_path is not None
        self.model_path = model_path
        print(f'load from {self.model_path}')
        self.model = AutoModel.from_pretrained(self.model_path, trust_remote_code=True)
        self.model = self.model.to(dtype=torch.bfloat16)
        self.model.eval().cuda()
        self.kwargs = kwargs
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, trust_remote_code=True)
        torch.cuda.empty_cache()
        self.num_beams = 3

    def use_custom_prompt(self, dataset):
        assert dataset is not None
        if listinstr(['MMDU', 'MME-RealWorld', 'MME-RealWorld-CN'], dataset):
            # For Multi-Turn we don't have custom prompt
            return False
        return False

    def build_prompt(self, line, dataset=None):
        assert dataset is None or isinstance(dataset, str)
        assert self.use_custom_prompt(dataset)
        tgt_path = self.dump_image(line, dataset)

        question = line['question']
        options = {
            cand: line[cand]
            for cand in string.ascii_uppercase
            if cand in line and not pd.isna(line[cand])
        }
        options_prompt = 'Options:\n'
        for key, item in options.items():
            options_prompt += f'{key}. {item}\n'
        hint = line['hint'] if ('hint' in line and not pd.isna(line['hint'])) else None
        prompt = ''
        if hint is not None:
            prompt += f'Hint: {hint}\n'
        prompt += f'{question}\n'
        if len(options):
            prompt += options_prompt
            prompt = 'Study the image carefully and pick the option associated with the correct answer. \
                Focus solely on selecting the option and avoid including any other content.\n' + prompt
        message = [dict(type='text', value=prompt)]
        message.extend([dict(type='image', value=p) for p in tgt_path])

        return message

    def generate_inner(self, message, dataset=None):
        prompt, image_path = self.message_to_promptimg(message, dataset=dataset)
        image = Image.open(image_path).convert('RGB')
        msgs = [{'role': 'user', 'content': prompt}]
        if DATASET_TYPE(dataset) == 'MCQ':
            max_new_tokens = 20
        elif DATASET_TYPE(dataset) == 'Y/N':
            max_new_tokens = 100
        else:
            max_new_tokens = 1024

        default_kwargs = dict(
            max_new_tokens=max_new_tokens,
            sampling=False,
            num_beams=self.num_beams
        )
        default_kwargs.update(self.kwargs)
        res, _, _ = self.model.chat(
            image=image,
            msgs=msgs,
            context=None,
            tokenizer=self.tokenizer,
            **default_kwargs
        )
        return res


class MiniCPM_Llama3_V(BaseModel):

    INSTALL_REQ = False
    INTERLEAVE = True

    def __init__(self, model_path='openbmb/MiniCPM-Llama3-V-2_5', **kwargs):
        assert model_path is not None
        self.model_path = model_path
        print(f'load from {self.model_path}')
        self.model = AutoModel.from_pretrained(self.model_path, trust_remote_code=True)
        self.model = self.model.to(dtype=torch.float16)
        self.model.eval().cuda()
        self.kwargs = kwargs
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, trust_remote_code=True)
        torch.cuda.empty_cache()
        self.num_beams = 3
        self.options_system_prompt = ('Carefully read the following question and select the letter corresponding '
                                      'to the correct answer. Highlight the applicable choices without giving '
                                      'explanations.')
        self.wo_options_system_prompt = 'Carefully read the following question Answer the question directly.'
        self.detail_system_prompt = 'Answer this question in detail.'
        self.vqa_prompt = 'Answer the question using a single word or phrase.'

    def use_custom_prompt(self, dataset):
        if listinstr(['MCQ', 'VQA'], DATASET_TYPE(dataset)):
            return True
        elif dataset is not None and listinstr(['HallusionBench'], dataset):
            return True
        return False

    def build_prompt(self, line, dataset=None):
        if isinstance(line, int):
            line = self.data.iloc[line]

        tgt_path = self.dump_image(line, dataset)
        system_prompt = ''

        question = line['question']
        if DATASET_TYPE(dataset) == 'MCQ':
            options = {
                cand: line[cand]
                for cand in string.ascii_uppercase
                if cand in line and not pd.isna(line[cand])
            }
            options_prompt = 'Options:\n'
            for key, item in options.items():
                options_prompt += f'{key}. {item}\n'
            hint = line['hint'] if ('hint' in line and not pd.isna(line['hint'])) else None
            prompt = ''
            if hint is not None:
                prompt += f'Hint: {hint}\n'
            prompt += f'Question: {question}\n'
            if len(options):
                prompt += options_prompt
                system_prompt = self.options_system_prompt + '\nPlease just indicate your choice.'
            else:
                system_prompt = self.wo_options_system_prompt
            if 'MMMU' in dataset:  # Corner Case
                prompt = system_prompt + '\n' + prompt
                system_prompt = ''
        elif dataset is not None and listinstr(['HallusionBench'], dataset):
            question = line['question'] + ' Yes or No?'
            prompt = question
        elif dataset is not None and listinstr(['MME'], dataset):
            question = line['question'] + ' Yes or No?'
            prompt = question
        elif dataset is not None and listinstr(['OCRBench'], dataset):
            system_prompt = self.vqa_prompt
            question = line['question']
            prompt = question
        elif DATASET_TYPE(dataset) == 'VQA':
            if listinstr(['LLaVABench', 'MMLongBench_DOC'], dataset):
                system_prompt = ''
                prompt = question
            elif listinstr(['MMVet'], dataset):
                system_prompt = self.detail_system_prompt
                prompt = question
            else:
                system_prompt = self.vqa_prompt
                prompt = question

        msgs = []
        if system_prompt:
            msgs.append(dict(type='text', value=system_prompt))
        if isinstance(tgt_path, list):
            msgs.extend([dict(type='image', value=p) for p in tgt_path])
        else:
            msgs = [dict(type='image', value=tgt_path)]
        msgs.append(dict(type='text', value=prompt))
        return msgs

    def generate_inner(self, message, dataset=None):
        if DATASET_TYPE(dataset) == 'MCQ':
            max_new_tokens = 200
        elif DATASET_TYPE(dataset) == 'Y/N':
            max_new_tokens = 3
        else:
            max_new_tokens = 1024

        default_kwargs = dict(
            max_new_tokens=max_new_tokens,
            sampling=False,
            num_beams=self.num_beams,
        )
        default_kwargs.update(self.kwargs)

        content = []
        for x in message:
            if x['type'] == 'text':
                content.append(x['value'])
            elif x['type'] == 'image':
                image = Image.open(x['value']).convert('RGB')
                content.append(image)
        msgs = [{'role': 'user', 'content': content}]

        res = self.model.chat(
            msgs=msgs,
            context=None,
            image=None,
            tokenizer=self.tokenizer,
            **default_kwargs
        )

        if isinstance(res, tuple) and len(res) > 0:
            res = res[0]
        return res

    def chat_inner(self, message, dataset=None):
        max_new_tokens = 1024

        default_kwargs = dict(
            max_new_tokens=max_new_tokens,
            sampling=False,
            num_beams=self.num_beams,
        )
        default_kwargs.update(self.kwargs)

        msgs = []
        for msg in message:
            content = []
            if len(msg['content']) == 1 and msg['content'][0]['type'] == 'text':
                msg_new = {'role': msg['role'], 'content': msg['content'][0]['value']}
                msgs.append(msg_new)
                continue

            for x in msg['content']:
                if x['type'] == 'text':
                    content.append(x['value'])
                elif x['type'] == 'image':
                    image = Image.open(x['value']).convert('RGB')
                    content.append(image)
            msg_new = {'role': msg['role'], 'content': content}
            msgs.append(msg_new)

        res = self.model.chat(
            msgs=msgs,
            context=None,
            image=None,
            tokenizer=self.tokenizer,
            **default_kwargs)

        if isinstance(res, tuple) and len(res) > 0:
            res = res[0]
        return res


class MiniCPM_V_2_6(BaseModel):
    INSTALL_REQ = False
    INTERLEAVE = True

    def __init__(self, model_path='openbmb/MiniCPM-V-2_6', **kwargs):
        random.seed(0)
        np.random.seed(0)
        torch.manual_seed(0)
        torch.cuda.manual_seed_all(0)

        assert model_path is not None
        self.model_path = model_path
        print(f'load from path {self.model_path}')
        self.model = AutoModel.from_pretrained(self.model_path, trust_remote_code=True)
        self.model = self.model.to(dtype=torch.bfloat16)
        self.model.eval().cuda()

        self.kwargs = kwargs
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, trust_remote_code=True)
        torch.cuda.empty_cache()
        self.num_beams = 3

        self.options_suffix_prompt = '''\nAnswer with the option's letter from the given choices directly.'''
        self.wo_options_system_prompt = 'Carefully read the following question Answer the question directly.'
        self.detail_system_prompt = 'Answer this question in detail.'
        self.vqa_prompt = 'Answer the question using a single word or phrase.'

        self.multi_choice_cot_prompt = ('''Carefully read the following multichoice question, solve it step '''
                                        '''by step and finally pick the option associated with the correct '''
                                        '''answer in the format of "Answer: selected option\n\n''')
        self.short_ans_cot_prompt = ('''Read the following question carefully, solve it step by step, and '''
                                     '''then output the final answer in the format of "Answer: single number '''
                                     '''or single word or phrase".\n\n''')

    def use_custom_prompt(self, dataset=None):
        if dataset is None:
            return False
        if DATASET_TYPE(dataset) in ['MCQ', 'VQA', 'Y/N']:
            return True
        return False

    def use_cot(self, dataset=None):
        if dataset is None:
            return False
        if listinstr(['MMMU', 'HallusionBench', 'OCRBench', 'ChartQA'], dataset):
            return True
        elif listinstr(['MathVista', 'MMVet', 'MMBench', 'MMStar', 'AI2D', 'RealWorldQA',
                        'POPE', 'ScienceQA', 'TextVQA', 'DocVQA'], dataset):
            return False
        else:
            return False

    def use_upsize(self, dataset=None):
        if dataset is None:
            return False
        if listinstr(['MMVet', 'MMBench', 'MMStar', 'AI2D', 'OCRBench'], dataset):
            return True
        else:
            return False

    def build_prompt(self, line, dataset=None):
        if isinstance(line, int):
            line = self.data.iloc[line]

        tgt_path = self.dump_image(line, dataset)
        system_prompt, prompt = '', ''

        question = line['question']

        if not self.use_cot(dataset):
            if DATASET_TYPE(dataset) == 'MCQ':
                options = {
                    cand: line[cand]
                    for cand in string.ascii_uppercase
                    if cand in line and not pd.isna(line[cand])
                }
                options_prompt = 'Options:\n'
                for key, item in options.items():
                    options_prompt += f'{key}. {item}\n'
                hint = line['hint'] if ('hint' in line and not pd.isna(line['hint'])) else None
                if hint is not None:
                    prompt += f'Hint: {hint}\n'
                prompt += f'Question: {question}\n'
                if len(options):
                    prompt += options_prompt
                    prompt += self.options_suffix_prompt
                else:
                    system_prompt = self.wo_options_system_prompt

                if 'MMMU' in dataset:
                    if len(system_prompt) > 0:
                        prompt = system_prompt + '\n' + prompt
                        system_prompt = ''
            elif dataset is not None and listinstr(['HallusionBench'], dataset):
                question += ' Yes or No?'
                prompt = question
            elif dataset is not None and listinstr(['OCRBench'], dataset):
                system_prompt = self.vqa_prompt
                prompt = question
            elif DATASET_TYPE(dataset) == 'VQA':
                if listinstr(['LLaVABench'], dataset):
                    system_prompt = ''
                elif listinstr(['MMVet'], dataset):
                    system_prompt = self.detail_system_prompt
                else:
                    system_prompt = self.vqa_prompt
                prompt = question
            else:
                prompt = question
        else:
            has_options = True
            if DATASET_TYPE(dataset) == 'MCQ':
                options = {
                    cand: line[cand]
                    for cand in string.ascii_uppercase
                    if cand in line and not pd.isna(line[cand])
                }
                options_prompt = ''
                for key, item in options.items():
                    options_prompt += f'{key}. {item}\n'
                hint = line['hint'] if ('hint' in line and not pd.isna(line['hint'])) else None
                if hint is not None:
                    prompt += f'Hint: {hint}\n'
                prompt += f'{question}\n'

                if len(options):
                    prompt += options_prompt
                else:
                    has_options = False

                if 'MMMU' in dataset:
                    if len(system_prompt) > 0:
                        prompt = system_prompt + '\n' + prompt
                        system_prompt = ''
            else:
                prompt = question

            if DATASET_TYPE(dataset) in ['MCQ', 'Y/N', 'VQA']:
                if DATASET_TYPE(dataset) == 'MCQ':
                    if has_options:
                        prompt = self.multi_choice_cot_prompt + prompt
                    else:
                        prompt = self.short_ans_cot_prompt + prompt
                elif DATASET_TYPE(dataset) == 'Y/N':
                    prompt = self.short_ans_cot_prompt + prompt
                else:
                    prompt = self.short_ans_cot_prompt + prompt

        msgs = []
        if system_prompt:
            msgs.append(dict(type='text', value=system_prompt))
        if isinstance(tgt_path, list):
            msgs.extend([dict(type='image', value=p) for p in tgt_path])
        else:
            msgs = [dict(type='image', value=tgt_path)]
        msgs.append(dict(type='text', value=prompt))

        return msgs

    def generate_inner(self, message, dataset=None):
        if DATASET_MODALITY(dataset) == 'VIDEO':
            max_slice_nums = 1
            use_image_id = False
            max_inp_length = 2048 * 10
        else:
            max_slice_nums = None
            use_image_id = True
            max_inp_length = 8192

        max_new_tokens = 2048
        default_kwargs = dict(
            max_new_tokens=max_new_tokens,
            sampling=False,
            num_beams=self.num_beams,
        )
        default_kwargs.update(self.kwargs)

        content = []

        for x in message:
            if x['type'] == 'text':
                content.append(x['value'])
            elif x['type'] == 'image':
                image = Image.open(x['value']).convert('RGB')
                if not self.use_upsize(dataset):
                    content.append(image)
                else:
                    img_width, img_height = image.width, image.height
                    if (img_width * img_height) >= (1344 * 1344):
                        content.append(image)
                    else:
                        ratio = math.sqrt((1344 * 1344) / (img_width * img_height))
                        max_img_width = int(img_width * ratio)
                        new_img_width = random.randint(img_width, max_img_width)
                        new_img_height = int(new_img_width / img_width * img_height)
                        resized_image = image.resize((new_img_width, new_img_height))
                        content.append(resized_image)
        msgs = [{'role': 'user', 'content': content}]

        res = self.model.chat(
            image=None,
            msgs=msgs,
            context=None,
            tokenizer=self.tokenizer,
            max_inp_length=max_inp_length,
            use_image_id=use_image_id,
            max_slice_nums=max_slice_nums,
            **default_kwargs
        )

        if isinstance(res, tuple) and len(res) > 0:
            res = res[0]

        return res


class MiniCPM_o_2_6(BaseModel):
    INSTALL_REQ = False
    INTERLEAVE = True

    def __init__(self, model_path='openbmb/MiniCPM-o-2_6', **kwargs):
        random.seed(0)
        np.random.seed(0)
        torch.manual_seed(0)
        torch.cuda.manual_seed_all(0)

        assert model_path is not None
        self.model_path = model_path
        print(f'load from path {self.model_path}')
        self.model = AutoModel.from_pretrained(
            self.model_path,
            trust_remote_code=True,
            attn_implementation='sdpa',
            torch_dtype=torch.bfloat16,
            init_vision=True,
            init_audio=False,
            init_tts=False
        )

        self.model.eval().cuda()

        self.kwargs = kwargs
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, trust_remote_code=True)
        torch.cuda.empty_cache()

        self.num_beams = int(os.getenv("NUM_BEAMS", "3"))

        repetition_penalty = float(os.getenv("PENALTY", "1.2"))
        self.repetition_penalty = repetition_penalty

        self.options_suffix_prompt = '''\nAnswer with the option's letter from the given choices directly.'''
        self.wo_options_system_prompt = 'Carefully read the following question Answer the question directly.'
        self.detail_system_prompt = 'Answer this question in detail.'
        self.vqa_prompt = 'Answer the question using a single word or phrase.'

        self.multi_choice_cot_prompt = ('''Carefully read the following multichoice question, solve it step '''
                                        '''by step and finally pick the option associated with the correct '''
                                        '''answer in the format of "Answer: selected option\n\n''')
        self.short_ans_cot_prompt = ('''Read the following question carefully, solve it step by step, and '''
                                     '''then output the final answer in the format of "Answer: single number '''
                                     '''or single word or phrase".\n\n''')

    def use_custom_prompt(self, dataset=None):
        if dataset is None:
            return False
        if listinstr(['MCQ', 'VQA', 'Y/N'], DATASET_TYPE(dataset)):
            return True
        return False

    def use_cot(self, dataset=None):
        if dataset is None:
            return False
        if listinstr(['MMMU', 'MathVista', 'OCRBench', 'ChartQA', 'MathVision', 'MathVerse_MINI_Vision_Only'], dataset):
            return True
        elif listinstr(['MMVet', 'MMBench', 'MMStar', 'HallusionBench', 'AI2D', 'RealWorldQA',
                        'POPE', 'ScienceQA', 'TextVQA', 'DocVQA'], dataset):
            return False
        else:
            return False

    def use_upsize(self, dataset=None):
        if dataset is None:
            return False
        if listinstr(['MathVista', 'MMBench_TEST_CN', 'MMStar', 'AI2D', 'OCRBench', 'DynaMath'], dataset):
            return True
        else:
            return False

    def build_prompt(self, line, dataset=None):
        if isinstance(line, int):
            line = self.data.iloc[line]

        tgt_path = self.dump_image(line, dataset)
        system_prompt, prompt = '', ''

        question = line['question']

        if not self.use_cot(dataset):
            if DATASET_TYPE(dataset) == 'MCQ':
                options = {
                    cand: line[cand]
                    for cand in string.ascii_uppercase
                    if cand in line and not pd.isna(line[cand])
                }
                options_prompt = 'Options:\n'
                for key, item in options.items():
                    options_prompt += f'{key}. {item}\n'
                hint = line['hint'] if ('hint' in line and not pd.isna(line['hint'])) else None
                if hint is not None:
                    prompt += f'Hint: {hint}\n'
                prompt += f'Question: {question}\n'
                if len(options):
                    prompt += options_prompt
                    prompt += self.options_suffix_prompt
                else:
                    system_prompt = self.wo_options_system_prompt

                if 'MMMU' in dataset:
                    if len(system_prompt) > 0:
                        prompt = system_prompt + '\n' + prompt
                        system_prompt = ''
            elif dataset is not None and listinstr(['HallusionBench'], dataset):
                question += ' Yes or No?'
                prompt = question
            elif dataset is not None and listinstr(['OCRBench'], dataset):
                system_prompt = self.vqa_prompt
                prompt = question
            elif DATASET_TYPE(dataset) == 'VQA':
                if listinstr(['LLaVABench'], dataset):
                    system_prompt = ''
                elif listinstr(['MMVet'], dataset):
                    system_prompt = self.detail_system_prompt
                else:
                    system_prompt = self.vqa_prompt
                prompt = question
            else:
                prompt = question
        else:
            has_options = True
            if DATASET_TYPE(dataset) == 'MCQ':
                options = {
                    cand: line[cand]
                    for cand in string.ascii_uppercase
                    if cand in line and not pd.isna(line[cand])
                }
                options_prompt = ''
                for key, item in options.items():
                    options_prompt += f'{key}. {item}\n'
                hint = line['hint'] if ('hint' in line and not pd.isna(line['hint'])) else None
                if hint is not None:
                    prompt += f'Hint: {hint}\n'
                prompt += f'{question}\n'

                if len(options):
                    prompt += options_prompt
                else:
                    has_options = False

                if 'MMMU' in dataset:
                    if len(system_prompt) > 0:
                        prompt = system_prompt + '\n' + prompt
                        system_prompt = ''
            else:
                prompt = question

            if DATASET_TYPE(dataset) in ['MCQ', 'Y/N', 'VQA']:
                if DATASET_TYPE(dataset) == 'MCQ':
                    if has_options:
                        prompt = self.multi_choice_cot_prompt + prompt
                    else:
                        prompt = self.short_ans_cot_prompt + prompt
                elif DATASET_TYPE(dataset) == 'Y/N':
                    prompt = self.short_ans_cot_prompt + prompt
                else:
                    prompt = self.short_ans_cot_prompt + prompt

        msgs = []
        if system_prompt:
            msgs.append(dict(type='text', value=system_prompt))
        if isinstance(tgt_path, list):
            msgs.extend([dict(type='image', value=p) for p in tgt_path])
        else:
            msgs = [dict(type='image', value=tgt_path)]
        msgs.append(dict(type='text', value=prompt))

        return msgs

    def extract_answer(self, res, dataset=None):
        if dataset is None:
            return res
        if self.use_cot(dataset):
            if DATASET_TYPE(dataset) == 'MCQ':
                pattern = r'Answer:\s*([A-Ia-i])(?![A-Za-z])'
                matches = re.findall(pattern, res, re.DOTALL)
                if matches:
                    extracted_res = matches[-1].strip()
                else:
                    extracted_res = res
                return extracted_res
            elif DATASET_TYPE(dataset) == 'VQA' and not listinstr(['OCRBench'], dataset):
                pattern = r'Answer:\s*(.*)\s*$'
                match = re.search(pattern, res, re.DOTALL)
                if match:
                    extracted_res = match.group(1)
                else:
                    extracted_res = res
                return extracted_res
        return res

    def generate_inner(self, message, dataset=None):
        if DATASET_MODALITY(dataset) == 'VIDEO':
            max_slice_nums = 1
            use_image_id = False
            max_inp_length = 2048 * 10
        else:
            max_slice_nums = None
            use_image_id = True
            max_inp_length = 8192

        max_new_tokens = 2048
        default_kwargs = dict(
            max_new_tokens=max_new_tokens,
            sampling=False,
            repetition_penalty=self.repetition_penalty,
            num_beams=self.num_beams,
        )
        default_kwargs.update(self.kwargs)

        content = []

        for x in message:
            if x['type'] == 'text':
                content.append(x['value'])
            elif x['type'] == 'image':
                image = Image.open(x['value']).convert('RGB')
                if not self.use_upsize(dataset):
                    content.append(image)
                else:
                    img_width, img_height = image.width, image.height
                    if (img_width * img_height) >= (1344 * 1344):
                        content.append(image)
                    else:
                        ratio = math.sqrt((1344 * 1344) / (img_width * img_height))
                        max_img_width = int(img_width * ratio)
                        new_img_width = random.randint(img_width, max_img_width)
                        new_img_height = int(new_img_width / img_width * img_height)
                        resized_image = image.resize((new_img_width, new_img_height))
                        content.append(resized_image)
        msgs = [{'role': 'user', 'content': content}]

        res = self.model.chat(
            image=None,
            msgs=msgs,
            context=None,
            tokenizer=self.tokenizer,
            max_inp_length=max_inp_length,
            use_image_id=use_image_id,
            max_slice_nums=max_slice_nums,
            **default_kwargs
        )

        if isinstance(res, tuple) and len(res) > 0:
            res = res[0]

        res = self.extract_answer(res, dataset)

        return res
