import torch
from transformers import AutoModelForCausalLM

from ..base import BaseModel
from ...dataset import DATASET_TYPE, DATASET_MODALITY
from ...smp import *


class Ovis(BaseModel):
    """Evaluation wrapper around an Ovis checkpoint loaded via ``trust_remote_code``.

    Custom prompts are produced for 'Y/N' and 'MCQ' datasets. Sampling is
    disabled by default; extra constructor keyword arguments are merged into
    the generation settings and override the defaults.
    """

    INSTALL_REQ = False
    INTERLEAVE = True

    def __init__(self, model_path='AIDC-AI/Ovis1.5-Llama3-8B', **kwargs):
        """Load the checkpoint onto the current CUDA device and cache its helpers.

        Args:
            model_path: Hub id or local path handed to ``from_pretrained``.
            **kwargs: Overrides / additions for the default generation kwargs.
        """
        assert model_path is not None
        # Recommend to install `transformers==4.43.2` and `torch==2.1.2`.
        self.model_path = model_path
        self.device = torch.cuda.current_device()
        self.dtype = torch.bfloat16

        load_options = {
            'torch_dtype': self.dtype,
            'multimodal_max_length': 8192,
            'trust_remote_code': True,
        }
        self.model = AutoModelForCausalLM.from_pretrained(self.model_path, **load_options)
        self.model = self.model.eval().to(device=self.device)

        # Helper objects exposed by the remote model code.
        self.text_tokenizer = self.model.get_text_tokenizer()
        self.visual_tokenizer = self.model.get_visual_tokenizer()
        self.conversation_formatter = self.model.get_conversation_formatter()

        self.eos_token_id = self.model.generation_config.eos_token_id
        self.pad_token_id = self.text_tokenizer.pad_token_id
        self.image_placeholder = '<image>'

        default_gen_kwargs = {
            'max_new_tokens': 1024,
            'do_sample': False,
            'top_p': None,
            'top_k': None,
            'temperature': None,
            'repetition_penalty': None,
            'eos_token_id': self.eos_token_id,
            'pad_token_id': self.pad_token_id,
            'use_cache': True,
        }
        # Caller-supplied kwargs win over the defaults.
        self.gen_kwargs = {**default_gen_kwargs, **kwargs}

    def use_custom_prompt(self, dataset):
        """Return True when ``dataset`` is of type 'Y/N' or 'MCQ'."""
        return DATASET_TYPE(dataset) in ('Y/N', 'MCQ')

    def build_prompt(self, line, dataset=None):
        """Build the message list (text entry first, then images) for one sample."""
        assert self.use_custom_prompt(dataset)
        assert isinstance(dataset, str)
        image_paths = self.dump_image(line, dataset)

        kind = DATASET_TYPE(dataset)
        if kind not in ('Y/N', 'MCQ'):
            raise RuntimeError(f'Invalid dataset type: {DATASET_TYPE(dataset)}')
        builder = self.build_yorn_prompt if kind == 'Y/N' else self.build_multi_choice_prompt
        text = builder(line, dataset)

        message = [{'type': 'text', 'value': text}]
        for path in image_paths:
            message.append({'type': 'image', 'value': path})
        return message

    def build_yorn_prompt(self, line, dataset=None):
        """Append a short-answer instruction, in Chinese when ``cn_string`` says so."""
        question = line['question']
        if cn_string(question):
            instruction = '\n请用单个词或短语回答问题。'
        else:
            instruction = '\nAnswer the question using a single word or phrase.'
        question += instruction
        return question

    def build_multi_choice_prompt(self, line, dataset=None):
        """Assemble hint, question, lettered options and an answer instruction."""
        prompt = line['question']
        if 'hint' in line and not pd.isna(line['hint']):
            prompt = line['hint'] + '\n' + prompt

        # Append every non-missing option stored under an uppercase-letter key.
        has_options = False
        for letter in string.ascii_uppercase:
            if letter in line and not pd.isna(line[letter]):
                prompt += f'\n{letter}. {line[letter]}'
                has_options = True

        is_chinese = cn_string(prompt)
        if has_options:
            if is_chinese:
                prompt += '\n请直接回答选项字母。'
            else:
                prompt += "\nAnswer with the option's letter from the given choices directly."
        elif is_chinese:
            prompt += '\n请直接回答问题。'
        else:
            prompt += '\nAnswer the question directly.'

        return prompt

    def generate_inner(self, message, dataset=None):
        """Run generation for ``message`` and return the stripped decoded text."""
        _, input_ids, attention_mask, pixel_values = self.prepare_inputs(message)
        generated = self.model.generate(
            input_ids,
            pixel_values=pixel_values,
            attention_mask=attention_mask,
            **self.gen_kwargs
        )
        decoded = self.text_tokenizer.decode(generated[0], skip_special_tokens=True)
        return decoded.strip()

    def prepare_inputs(self, message):
        """Turn a message list into ``(prompt, input_ids, attention_mask, pixel_values)``.

        ``pixel_values`` is a one-element list holding either the concatenated
        preprocessed images or ``None`` when the message carries no image.
        """
        # build query
        image_paths = [item['value'] for item in message if item['type'] == 'image']
        text_parts = [item['value'] for item in message if item['type'] == 'text']
        if not image_paths:
            query = '\n'.join(text_parts)
        elif len(image_paths) == 1 and len(text_parts) == 1:
            query = self.image_placeholder + '\n' + text_parts[0]
        else:  # interleave sample
            query = '\n'.join(
                self.image_placeholder if item['type'] != 'text' else item['value']
                for item in message
            )

        # format conversation
        prompt, input_ids = self.conversation_formatter.format_query(query)
        attention_mask = torch.ne(input_ids, self.text_tokenizer.pad_token_id)
        input_ids = input_ids.unsqueeze(0).to(device=self.device)
        attention_mask = attention_mask.unsqueeze(0).to(device=self.device)

        # preprocess images
        if image_paths:
            per_image = []
            for path in image_paths:
                per_image.append(self.visual_tokenizer.preprocess_image(Image.open(path)))
            stacked = torch.cat(per_image, dim=0)
            pixel_values = [stacked.to(device=self.device, dtype=self.dtype)]
        else:
            pixel_values = [None]

        return prompt, input_ids, attention_mask, pixel_values


class Ovis1_6(BaseModel):
    """Evaluation wrapper for Ovis1.6 checkpoints loaded via ``trust_remote_code``.

    Input preprocessing is delegated to the model's own ``preprocess_inputs``,
    called with ``max_partition=self.max_partition``.
    """

    INSTALL_REQ = False
    INTERLEAVE = True

    def __init__(self, model_path='AIDC-AI/Ovis1.6-Gemma2-9B', **kwargs):
        """Load the checkpoint onto the current CUDA device.

        Args:
            model_path: Hub id or local path handed to ``from_pretrained``.
            **kwargs: Overrides / additions for the default generation kwargs.
        """
        assert model_path is not None
        # Recommend to install `python=3.10`, `transformers==4.44.2`, `torch==2.2.0`, and `numpy==1.24.3`
        self.model_path = model_path
        self.device = torch.cuda.current_device()
        self.dtype = torch.bfloat16

        load_options = {
            'torch_dtype': self.dtype,
            'multimodal_max_length': 8192,
            'trust_remote_code': True,
        }
        self.model = AutoModelForCausalLM.from_pretrained(self.model_path, **load_options)
        self.model = self.model.eval().to(device=self.device)

        # Helper objects exposed by the remote model code.
        self.text_tokenizer = self.model.get_text_tokenizer()
        self.visual_tokenizer = self.model.get_visual_tokenizer()

        self.eos_token_id = self.model.generation_config.eos_token_id
        self.pad_token_id = self.text_tokenizer.pad_token_id
        self.max_partition = 9
        self.image_placeholder = '<image>'

        default_gen_kwargs = {
            'max_new_tokens': 1024,
            'do_sample': False,
            'top_p': None,
            'top_k': None,
            'temperature': None,
            'repetition_penalty': None,
            'eos_token_id': self.eos_token_id,
            'pad_token_id': self.pad_token_id,
            'use_cache': True,
        }
        # Caller-supplied kwargs win over the defaults.
        self.gen_kwargs = {**default_gen_kwargs, **kwargs}

    def use_custom_prompt(self, dataset):
        """Return True when ``dataset`` is of type 'Y/N' or 'MCQ'."""
        return DATASET_TYPE(dataset) in ('Y/N', 'MCQ')

    def build_yorn_prompt(self, line, dataset=None):
        """Return the question followed by a short-answer instruction."""
        instruction = '\nAnswer the question using a single word or phrase.'
        return line['question'] + instruction

    def build_multi_choice_prompt(self, line, dataset=None):
        """Assemble hint, question and lettered options.

        The letter-only answer instruction is appended only when at least one
        option is present.
        """
        prompt = line['question']
        if 'hint' in line and not pd.isna(line['hint']):
            prompt = line['hint'] + '\n' + prompt

        # Append every non-missing option stored under an uppercase-letter key.
        has_options = False
        for letter in string.ascii_uppercase:
            if letter in line and not pd.isna(line[letter]):
                prompt += f'\n{letter}. {line[letter]}'
                has_options = True

        if has_options:
            prompt += "\nAnswer with the option's letter from the given choices directly."

        return prompt

    def build_prompt(self, line, dataset=None):
        """Build the message list (text entry first, then images) for one sample."""
        assert self.use_custom_prompt(dataset)
        assert isinstance(dataset, str)
        image_paths = self.dump_image(line, dataset)

        kind = DATASET_TYPE(dataset)
        if kind not in ('Y/N', 'MCQ'):
            raise RuntimeError(f'Invalid dataset type: {DATASET_TYPE(dataset)}')
        builder = self.build_yorn_prompt if kind == 'Y/N' else self.build_multi_choice_prompt
        text = builder(line, dataset)

        message = [{'type': 'text', 'value': text}]
        for path in image_paths:
            message.append({'type': 'image', 'value': path})
        return message

    def generate_inner(self, message, dataset=None):
        """Run generation for ``message`` and return the decoded text."""
        _, input_ids, attention_mask, pixel_values = self.prepare_inputs(message)
        generated = self.model.generate(
            input_ids,
            pixel_values=pixel_values,
            attention_mask=attention_mask,
            **self.gen_kwargs
        )
        return self.text_tokenizer.decode(generated[0], skip_special_tokens=True)

    def prepare_inputs(self, message):
        """Turn a message list into ``(prompt, input_ids, attention_mask, pixel_values)``.

        ``pixel_values`` is a one-element list; its entry is ``None`` when the
        model's preprocessing returned no pixel tensor.
        """
        # build query
        image_paths = [item['value'] for item in message if item['type'] == 'image']
        text_parts = [item['value'] for item in message if item['type'] == 'text']
        if not image_paths:
            query = '\n'.join(text_parts)
        elif len(image_paths) == 1 and len(text_parts) == 1:
            query = self.image_placeholder + '\n' + text_parts[0]
        else:  # interleaved sample
            query = '\n'.join(
                self.image_placeholder if item['type'] != 'text' else item['value']
                for item in message
            )

        # preprocess inputs
        opened_images = [Image.open(path) for path in image_paths]
        prompt, input_ids, pixel_values = self.model.preprocess_inputs(
            query, opened_images, max_partition=self.max_partition
        )

        # move to self.device
        attention_mask = torch.ne(input_ids, self.text_tokenizer.pad_token_id)
        input_ids = input_ids.unsqueeze(0).to(device=self.device)
        attention_mask = attention_mask.unsqueeze(0).to(device=self.device)
        if pixel_values is not None:
            pixel_values = pixel_values.to(device=self.device, dtype=self.dtype)

        return prompt, input_ids, attention_mask, [pixel_values]


class Ovis1_6_Plus(Ovis1_6):
    """Variant of ``Ovis1_6`` whose prompt message lists images before the text."""
    # Recommend to install `python=3.10`, `transformers==4.46.2`, `torch==2.4.0`, and `numpy==1.25.0`

    def build_prompt(self, line, dataset=None):
        """Build the message list (all images first, then one text entry)."""
        assert self.use_custom_prompt(dataset)
        assert isinstance(dataset, str)
        image_paths = self.dump_image(line, dataset)

        kind = DATASET_TYPE(dataset)
        if kind not in ('Y/N', 'MCQ'):
            raise RuntimeError(f'Invalid dataset type: {DATASET_TYPE(dataset)}')
        builder = self.build_yorn_prompt if kind == 'Y/N' else self.build_multi_choice_prompt
        text = builder(line, dataset)

        message = [{'type': 'image', 'value': path} for path in image_paths]
        message.append({'type': 'text', 'value': text})
        return message


class Ovis2(BaseModel):
    """Evaluation wrapper for Ovis2 checkpoints loaded via ``trust_remote_code``.

    The size label is looked up from the LLM config's
    ``(num_hidden_layers, hidden_size)`` pair and selects the per-size
    chain-of-thought dataset prefixes in ``self.use_cot``.
    """

    INSTALL_REQ = False
    INTERLEAVE = True
    SIZE_DICT = {
        (24, 896): '1B',  # (num_hidden_layers, hidden_size)
        (28, 1536): '2B',
        (36, 2048): '4B',
        (28, 3584): '8B',
        (48, 5120): '16B',
        (64, 5120): '34B'
    }

    def __init__(self, model_path='AIDC-AI/Ovis2-8B', **kwargs):
        """Load the checkpoint and set up generation and optional frame selection.

        Args:
            model_path: Hub id or local path handed to ``from_pretrained``.
            **kwargs: ``frame_selection`` (plus ``n_frames`` and
                ``frame_selection_vlm`` when it is truthy) are consumed here;
                everything left over is merged into the generation kwargs.
        """
        assert model_path is not None
        # Recommend to install `python=3.10`, `transformers==4.46.2`, `torch==2.4.0`, and `numpy==1.25.0`
        self.model_path = model_path
        self.device = torch.cuda.current_device()
        self.dtype = torch.bfloat16

        load_options = {
            'torch_dtype': self.dtype,
            'multimodal_max_length': 32768,
            'trust_remote_code': True,
        }
        self.model = AutoModelForCausalLM.from_pretrained(self.model_path, **load_options)

        llm_config = self.model.config.llm_config
        self.size = self.SIZE_DICT[(llm_config.num_hidden_layers, llm_config.hidden_size)]
        self.model = self.model.eval().to(device=self.device)

        # Helper objects exposed by the remote model code.
        self.text_tokenizer = self.model.get_text_tokenizer()
        self.visual_tokenizer = self.model.get_visual_tokenizer()

        self.eos_token_id = self.model.generation_config.eos_token_id
        self.pad_token_id = self.text_tokenizer.pad_token_id
        self.image_placeholder = '<image>'

        # Dataset-name prefixes that trigger the chain-of-thought prompt, per size.
        self.use_cot = {label: {} for label in ('1B', '2B', '4B', '8B', '16B', '34B')}

        selector = None
        if kwargs.pop("frame_selection", False):
            from .utils.mdp3 import MDP3
            n_selection = int(kwargs.pop("n_frames", 32))
            encoder_name = kwargs.pop("frame_selection_vlm", "google/siglip-so400m-patch14-384")
            selector = MDP3(
                n_selection=n_selection,
                visual_encoder_name_or_path=encoder_name,
                device=f"cuda:{self.device}"
            )
        self.frame_selector = selector

        default_gen_kwargs = {
            'max_new_tokens': 1024,
            'do_sample': False,
            'top_p': None,
            'top_k': None,
            'temperature': None,
            'repetition_penalty': None,
            'eos_token_id': self.eos_token_id,
            'pad_token_id': self.pad_token_id,
            'use_cache': True,
        }
        # Merge only after the frame-selection options were popped from kwargs.
        self.gen_kwargs = {**default_gen_kwargs, **kwargs}

    def use_custom_prompt(self, dataset):
        """Return True when ``dataset`` is of type 'Y/N' or 'MCQ'."""
        return DATASET_TYPE(dataset) in ('Y/N', 'MCQ')

    def build_yorn_prompt(self, line, dataset=None):
        """Return the question followed by a short-answer instruction."""
        instruction = '\nAnswer the question using a single word or phrase.'
        return line['question'] + instruction

    def build_multi_choice_prompt(self, line, dataset=None, use_cot=False):
        """Assemble hint, question and lettered options.

        When options exist, the closing instruction asks either for a
        step-by-step solution (``use_cot``) or for the option letter only.
        """
        prompt = line['question']
        if 'hint' in line and not pd.isna(line['hint']):
            prompt = line['hint'] + '\n' + prompt

        # Append every non-missing option stored under an uppercase-letter key.
        has_options = False
        for letter in string.ascii_uppercase:
            if letter in line and not pd.isna(line[letter]):
                prompt += f'\n{letter}. {line[letter]}'
                has_options = True

        if not has_options:
            return prompt

        if use_cot:
            prompt += "\nProvide a step-by-step solution to the problem, and conclude with 'the answer is' followed by the final solution."
        else:
            prompt += "\nAnswer with the option's letter from the given choices directly."
        return prompt

    def build_prompt(self, line, dataset=None):
        """Build the message list (all images first, then one text entry)."""
        assert self.use_custom_prompt(dataset)
        assert isinstance(dataset, str)
        image_paths = self.dump_image(line, dataset)

        cot_prefixes = self.use_cot[self.size]
        use_cot = any(dataset.startswith(prefix) for prefix in cot_prefixes)

        kind = DATASET_TYPE(dataset)
        if kind == 'Y/N':
            text = self.build_yorn_prompt(line, dataset)
        elif kind == 'MCQ':
            text = self.build_multi_choice_prompt(line, dataset, use_cot)
        else:
            raise RuntimeError(f'Invalid dataset type: {DATASET_TYPE(dataset)}')

        message = [{'type': 'image', 'value': path} for path in image_paths]
        message.append({'type': 'text', 'value': text})
        return message

    def generate_inner(self, message, dataset=None):
        """Generate a response; strip it down to the final answer for CoT prompts."""
        def _extract_answer(text):
            # Keep only what follows the first (case-insensitive) 'the answer is'.
            marker = 'the answer is'
            position = text.lower().find(marker)
            if position == -1:
                return text
            return text[position + len(marker):].lstrip(':').strip()

        prompt, input_ids, attention_mask, pixel_values, _ = self.prepare_inputs(message, dataset)
        generated = self.model.generate(
            input_ids,
            pixel_values=pixel_values,
            attention_mask=attention_mask,
            **self.gen_kwargs
        )
        response = self.text_tokenizer.decode(generated[0], skip_special_tokens=True)

        # The CoT instruction in the prompt marks responses that need post-processing.
        if "conclude with 'the answer is' followed by the final solution." in prompt:
            return _extract_answer(response)
        return response

    def prepare_inputs(self, message, dataset=None):
        """Turn a message list into model inputs.

        Returns:
            ``(prompt, input_ids, attention_mask, pixel_values, max_partition)``
            where ``pixel_values`` is a one-element list (entry may be ``None``).
        """
        # build query
        image_paths = [item['value'] for item in message if item['type'] == 'image']
        text_parts = [item['value'] for item in message if item['type'] == 'text']
        if DATASET_MODALITY(dataset) == 'VIDEO':  # video inputs
            # All placeholders come first, followed by the non-empty text pieces.
            n_visual = sum(1 for item in message if item['type'] != 'text')
            pieces = [self.image_placeholder] * n_visual
            for item in message:
                if item['type'] == 'text' and item['value'] != '':
                    pieces.append(item['value'].strip())
            query = '\n'.join(pieces)
        elif not image_paths:  # text-only inputs
            query = '\n'.join(text_parts)
        elif len(image_paths) == 1 and len(text_parts) == 1:  # single-image inputs
            query = self.image_placeholder + '\n' + text_parts[0]
        else:  # interleaved inputs
            query = '\n'.join(
                self.image_placeholder if item['type'] != 'text' else item['value'].strip()
                for item in message
            )

        # Fewer partitions per image when several images share the context.
        max_partition = max(1, 12 // len(image_paths)) if len(image_paths) > 1 else 9

        opened_images = [Image.open(path) for path in image_paths]
        prompt, input_ids, pixel_values = self.model.preprocess_inputs(
            query, opened_images, max_partition=max_partition, frame_selector=self.frame_selector
        )

        # move to self.device
        attention_mask = torch.ne(input_ids, self.text_tokenizer.pad_token_id)
        input_ids = input_ids.unsqueeze(0).to(device=self.device)
        attention_mask = attention_mask.unsqueeze(0).to(device=self.device)
        if pixel_values is not None:
            pixel_values = pixel_values.to(device=self.device, dtype=self.dtype)

        return prompt, input_ids, attention_mask, [pixel_values], max_partition


class OvisU1(BaseModel):
    """Evaluation wrapper for Ovis-U1 checkpoints loaded via ``trust_remote_code``.

    Custom prompts are built for 'Y/N' and 'MCQ' datasets. Thinking mode is
    switched on when the ``OvisThink`` environment variable equals ``'True'``;
    anything up to the last ``</think>`` tag is stripped from the response.
    """

    INSTALL_REQ = False
    INTERLEAVE = True

    def __init__(self, model_path='AIDC-AI/Ovis-U1-3B', **kwargs):
        """Load the checkpoint and set up generation and optional frame selection.

        Args:
            model_path: Hub id or local path handed to ``from_pretrained``.
            **kwargs: ``frame_selection`` (plus ``n_frames`` and
                ``frame_selection_vlm`` when it is truthy) are consumed here;
                everything left over is merged into the generation kwargs.
        """
        assert model_path is not None
        # Recommend to install `transformers==4.51.3`, `torch==2.4.0`, and `numpy==1.24.3`
        self.model_path = model_path
        self.device = torch.cuda.current_device()
        self.dtype = torch.bfloat16

        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_path,
            torch_dtype=self.dtype,
            multimodal_max_length=32768,
            trust_remote_code=True
        )
        self.model = self.model.eval().to(device=self.device)
        self.text_tokenizer = self.model.get_text_tokenizer()
        self.pad_token_id = self.text_tokenizer.pad_token_id
        self.eos_token_id = self.text_tokenizer.eos_token_id
        self.visual_tokenizer = self.model.get_visual_tokenizer()
        self.image_placeholder = '<image>'
        self.gen_kwargs = dict(
            max_new_tokens=1024,
            do_sample=False,
            top_p=None,
            top_k=None,
            temperature=None,
            repetition_penalty=None,
            eos_token_id=self.eos_token_id,
            pad_token_id=self.pad_token_id,
            use_cache=True
        )
        self.min_pixels = 200704  # 448*448
        self.max_pixels = 2408448  # 1344*1792
        # Dataset-name prefixes that trigger the chain-of-thought MCQ prompt.
        # build_prompt iterates this attribute, so it must exist; it is empty
        # by default, i.e. no dataset uses the chain-of-thought prompt.
        self.use_cot = set()
        self.frame_selector = None
        if kwargs.pop("frame_selection", False):
            from .utils.mdp3 import MDP3
            self.frame_selector = MDP3(
                n_selection=int(kwargs.pop("n_frames", 32)),
                visual_encoder_name_or_path=kwargs.pop("frame_selection_vlm", "google/siglip-so400m-patch14-384"),
                device=f"cuda:{self.device}"
            )
        # Merge only after the frame-selection options were popped from kwargs.
        self.gen_kwargs.update(kwargs)

    def use_custom_prompt(self, dataset):
        """Return True when ``dataset`` is of type 'Y/N' or 'MCQ'."""
        return DATASET_TYPE(dataset) in ('Y/N', 'MCQ')

    def build_yorn_prompt(self, line, dataset=None):
        """Return the question followed by a short-answer instruction.

        NOTE(review): build_prompt dispatches here for 'Y/N' datasets; the
        method is defined on this class the same way the sibling Ovis wrappers
        define theirs — confirm BaseModel does not expect a different variant.
        """
        prompt = line['question']
        prompt += '\nAnswer the question using a single word or phrase.'
        return prompt

    def build_multi_choice_prompt(self, line, dataset=None, use_cot=False):
        """Assemble hint, question and lettered options.

        When options exist, the closing instruction asks either for a
        step-by-step solution (``use_cot``) or for the option letter only.
        """
        prompt = line['question']
        hint = line['hint'] if ('hint' in line and not pd.isna(line['hint'])) else None
        if hint is not None:
            prompt = hint + '\n' + prompt

        # Options are the non-missing values stored under uppercase-letter keys.
        options = {
            cand: line[cand]
            for cand in string.ascii_uppercase
            if cand in line and not pd.isna(line[cand])
        }
        for key, item in options.items():
            prompt += f'\n{key}. {item}'

        if options:
            if use_cot:
                prompt += "\nProvide a step-by-step solution to the problem, and conclude with 'the answer is' followed by the final solution."
            else:
                prompt += "\nAnswer with the option's letter from the given choices directly."

        return prompt

    def build_prompt(self, line, dataset=None):
        """Build the message list (all images first, then one text entry).

        Raises:
            RuntimeError: if the dataset type is neither 'Y/N' nor 'MCQ'.
        """
        assert self.use_custom_prompt(dataset)
        assert isinstance(dataset, str)
        tgt_path = self.dump_image(line, dataset)

        use_cot = any(dataset.startswith(prefix) for prefix in self.use_cot)

        dataset_type = DATASET_TYPE(dataset)
        if dataset_type == 'Y/N':
            prompt = self.build_yorn_prompt(line, dataset)
        elif dataset_type == 'MCQ':
            prompt = self.build_multi_choice_prompt(line, dataset, use_cot)
        else:
            raise RuntimeError(f'Invalid dataset type: {dataset_type}')

        message = [dict(type='image', value=s) for s in tgt_path] + [dict(type='text', value=prompt)]

        return message

    def generate_inner(self, message, dataset=None):
        """Generate a response for ``message`` and post-process it.

        The text after the last ``</think>`` tag is kept when such a tag is
        present, and for chain-of-thought prompts only the part following
        'the answer is' is returned. Prompt and output are printed to stdout.
        """
        def _extract_answer(text):
            # Keep only what follows the first (case-insensitive) 'the answer is'.
            answer_index = text.lower().find('the answer is')
            if answer_index != -1:
                answer_index += len('the answer is')
                answer = text[answer_index:].lstrip(':').strip()
            else:
                answer = text
            return answer

        prompt, input_ids, attention_mask, pixel_values, grid_thws = self.prepare_inputs(message, dataset)
        output_ids = self.model.generate(
            input_ids,
            pixel_values=pixel_values,
            grid_thws=grid_thws,
            attention_mask=attention_mask,
            **self.gen_kwargs
        )
        response = self.text_tokenizer.decode(output_ids[0], skip_special_tokens=True)

        print('\n========================************========================')
        print(f'prompt: {prompt}<<<\n')
        print(f'output: {response}\n')

        # Drop the reasoning section: keep what follows the last closing tag.
        think_end = response.rfind('</think>')
        if think_end != -1:
            think_end += len('</think>')
            response = response[think_end:].strip()
            print(f'extract answer: {response}\n')

        # The CoT instruction in the prompt marks responses that need post-processing.
        if "conclude with 'the answer is' followed by the final solution." in prompt:
            response = _extract_answer(response)
            print(f'extract answer: {response}\n')

        print('------------------------------------------------------------\n', flush=True)

        return response

    def prepare_inputs(self, message, dataset=None):
        """Turn a message list into model inputs.

        Returns:
            ``(prompt, input_ids, attention_mask, pixel_values, grid_thws)``.
            ``pixel_values`` and ``grid_thws`` are passed through as ``None``
            when the model's preprocessing returns ``None`` for them.
        """
        # build query
        images = [x['value'] for x in message if x['type'] == 'image']
        texts = [x['value'] for x in message if x['type'] == 'text']
        if DATASET_MODALITY(dataset) == 'VIDEO':  # video inputs
            # All placeholders come first, followed by the non-empty text pieces.
            chunks = [self.image_placeholder for x in message if x['type'] != 'text']
            chunks += [x['value'].strip() for x in message if x['type'] == 'text' and x['value'] != '']
            query = '\n'.join(chunks)
        elif not images:  # text-only inputs
            query = '\n'.join(texts)
        elif len(images) == 1 and len(texts) == 1:  # single-image inputs
            query = self.image_placeholder + '\n' + texts[0]
        else:  # interleaved inputs
            chunks = [x['value'].strip() if x['type'] == 'text' else self.image_placeholder for x in message]
            query = '\n'.join(chunks)

        # preprocess inputs
        enable_thinking = os.getenv("OvisThink") == 'True'
        prompt, input_ids, pixel_values, grid_thws = self.model.preprocess_inputs(
            query, [Image.open(image) for image in images],
            frame_selector=self.frame_selector,
            enable_thinking=enable_thinking,
            min_pixels=self.min_pixels,
            max_pixels=self.max_pixels,
        )

        attention_mask = torch.ne(input_ids, self.text_tokenizer.pad_token_id)
        input_ids = input_ids.unsqueeze(0).to(device=self.device)
        attention_mask = attention_mask.unsqueeze(0).to(device=self.device)
        # torch.cat cannot take None entries, so move the tensors directly and
        # leave None untouched.
        if pixel_values is not None:
            pixel_values = pixel_values.to(device=self.device, dtype=self.dtype)
        if grid_thws is not None:
            grid_thws = grid_thws.to(device=self.device)

        return prompt, input_ids, attention_mask, pixel_values, grid_thws