import logging
import types

import torch
from transformers import (
    AutoConfig,
    AutoProcessor,
    AutoTokenizer,
    AutoModelForCausalLM,
)
from transformers.modeling_utils import PreTrainedModel

from llava import LlavaForConditionalGenerationWithLabels


# Module-level logger named after this module.
logger = logging.getLogger(__name__)


# Label value for positions that should be masked out.
# Not referenced elsewhere in the visible code: Processor.build_labels writes
# the literal -100 directly.
# NOTE(review): presumably this is the loss function's ignore index — confirm.
IGNORE_INDEX = -100


def load(name: str):
    """Load a vision-language model and its ``Processor`` for ``name``.

    The model family is picked by substring match on ``name``, checked in
    this order: ``"llava-phi"``, ``"Phi-3.5-vision"``, ``"Qwen2-VL"``,
    ``"llava-onevision"``. Every model is loaded with
    ``attn_implementation="flash_attention_2"`` and ``torch.float16``.

    Args:
        name: Checkpoint name or path handed to ``from_pretrained``.

    Returns:
        A ``(model, processor)`` tuple. ``model.save_pretrained`` is replaced
        on the instance so that ``safe_serialization`` is always ``False``.

    Raises:
        NotImplementedError: If ``name`` matches none of the known families.
            Note that the ``Processor`` is constructed before this check.
    """
    uses_qwen2_vl = "Qwen2-VL" in name
    uses_phi3v = "Phi-3.5-vision" in name
    uses_llava_phi = "llava-phi" in name
    uses_llava_ov = "llava-onevision" in name

    load_kwargs = {
        "attn_implementation": "flash_attention_2",
        "torch_dtype": torch.float16,
    }
    processor = Processor(name)

    if uses_llava_phi:
        model = LlavaForConditionalGenerationWithLabels.from_pretrained(
            name, **load_kwargs
        )
    elif uses_phi3v:
        model = AutoModelForCausalLM.from_pretrained(
            name, trust_remote_code=True, **load_kwargs
        )
    elif uses_qwen2_vl:
        # Imported lazily so the other families do not require this class.
        from transformers import Qwen2VLForConditionalGeneration

        model = Qwen2VLForConditionalGeneration.from_pretrained(name, **load_kwargs)
    elif uses_llava_ov:
        # Imported lazily so the other families do not require this class.
        from transformers import LlavaOnevisionForConditionalGeneration

        model = LlavaOnevisionForConditionalGeneration.from_pretrained(
            name, **load_kwargs
        )
    else:
        raise NotImplementedError(f"Not implemented model: {name}")

    def set_trainable(module, flag):
        # Toggle gradient tracking for every parameter of ``module``.
        for weight in module.parameters():
            weight.requires_grad = flag

    if "llava" in name:
        # Vision tower is frozen; the multimodal projector stays trainable.
        set_trainable(model.vision_tower, False)
        set_trainable(model.multi_modal_projector, True)
    elif uses_qwen2_vl:
        set_trainable(model.visual, False)

    # for deepspeed: expose hidden_size on the top-level config when only the
    # nested text config carries it
    config = model.config
    if not hasattr(config, "hidden_size"):
        config.hidden_size = config.text_config.hidden_size

    # for saving phi3v: https://github.com/huggingface/transformers/issues/27293
    # (the override is installed on every model, not only phi3v)
    def save_pretrained(self, *args, **kwargs):
        kwargs.update(safe_serialization=False)
        return PreTrainedModel.save_pretrained(self, *args, **kwargs)

    model.save_pretrained = types.MethodType(save_pretrained, model)

    return model, processor


class Processor:
    """Turn a two-turn conversation plus images into single-example inputs.

    Wraps the ``AutoProcessor`` loaded for ``name``. Calling an instance
    resizes the images, renders the conversation to a prompt string, runs the
    wrapped processor, attaches ``labels`` and strips the batch dimension.

    Attributes set in ``__init__``: ``name``, ``image_size``,
    ``special_opts`` (family flags derived from ``name``), ``config`` (the
    ``AutoConfig`` for ``name``) and ``processor`` (the wrapped processor).
    """

    def __init__(self, name: str, image_size: int = 1024):
        """Load the processor and config for ``name``.

        Args:
            name: Checkpoint name or path; also used to detect the family by
                substring match.
            image_size: Side length, in pixels, that ``resize`` scales images
                to (both width and height).
        """
        self.name = name
        self.image_size = image_size
        self.special_opts = {
            "is_qwen2_vl": "Qwen2-VL" in name,
            "is_phi3v": "Phi-3.5-vision" in name,
            "is_llava_phi": "llava-phi" in name,
            "is_llava_ov": "llava-onevision" in name,
        }
        kwargs = {}
        if self.special_opts["is_phi3v"]:
            # Set number of crops to 16 for Phi3.5 Vision
            kwargs["num_crops"] = 16
        # The family-specific options must be forwarded here to take effect.
        processor = AutoProcessor.from_pretrained(
            name, trust_remote_code=True, **kwargs
        )
        if self.special_opts["is_llava_phi"]:
            # Use the chat template of the Phi-3.5 mini instruct tokenizer.
            base_tokenizer = AutoTokenizer.from_pretrained(
                "microsoft/Phi-3.5-mini-instruct"
            )
            processor.tokenizer.chat_template = base_tokenizer.chat_template
        self.config = AutoConfig.from_pretrained(name, trust_remote_code=True)
        self.processor = processor

    def resize(self, image):
        """Return ``image`` resized to ``image_size`` x ``image_size``.

        The image is stretched rather than centre-cropped, since cropping
        would skew bbox coordinates.
        """
        return image.resize((self.image_size, self.image_size))

    def build_labels(self, inputs):
        """Return a copy of ``inputs["input_ids"]`` with vision tokens masked.

        For Qwen2-VL and llava-onevision the image/video placeholder token ids
        read from ``self.config`` are replaced by ``-100`` (the same value as
        the module-level ``IGNORE_INDEX``). For every other family the ids are
        returned unchanged; llava-phi labels are handled by the model itself.
        """
        labels = inputs["input_ids"].clone()
        if self.special_opts["is_qwen2_vl"]:
            masked_ids = (
                self.config.vision_start_token_id,
                self.config.vision_end_token_id,
                self.config.vision_token_id,
                self.config.image_token_id,
            )
        elif self.special_opts["is_llava_ov"]:
            masked_ids = (
                self.config.image_token_index,
                self.config.video_token_index,
            )
        else:
            return labels

        mask = labels == masked_ids[0]
        for token_id in masked_ids[1:]:
            mask = mask | (labels == token_id)
        labels[mask] = -100
        return labels

    @staticmethod
    def _format_turn(turn):
        """Return a copy of ``turn`` converted from llava format to hf format.

        ``"value"``/``"from"`` keys become ``"content"``/``"role"``, the role
        ``"gpt"`` becomes ``"assistant"``, and string content is wrapped in a
        list of typed parts. The caller's dict is never modified.
        """
        turn = dict(turn)
        if "value" in turn:
            turn["content"] = turn.pop("value")
        if "from" in turn:
            turn["role"] = turn.pop("from")
        # NOTE(review): only "gpt" is renamed; a llava-style "human" role is
        # passed through unchanged and therefore gets no image placeholder
        # below — confirm the data already uses "user".
        if turn["role"] == "gpt":
            turn["role"] = "assistant"

        content = turn["content"]
        if isinstance(content, str):
            content = [{"text": content, "type": "text"}]
            # always add the image in the user turn
            if turn["role"] == "user":
                content.append({"type": "image"})
        elif isinstance(content, list):
            # Copy the parts so nothing downstream can alias the caller's data.
            content = [
                dict(part) if isinstance(part, dict) else part for part in content
            ]
        turn["content"] = content
        return turn

    def _build_legacy_prompt(self, conversation, num_images, image_key="<|image|>"):
        """Render ``conversation`` with the tokenizer's chat template.

        Used for families whose prompt is built from plain-string turns: one
        ``image_key`` tag per image (with ``"K"`` replaced by the 1-based
        image number) is joined with newlines and prepended to the text of
        the first turn. Only the first content part of each turn is used.
        """
        image_keys = "\n".join(
            image_key.replace("K", str(index + 1)) for index in range(num_images)
        )
        flat_turns = [
            {**turn, "content": turn["content"][0]["text"]} for turn in conversation
        ]
        flat_turns[0]["content"] = f"{image_keys}\n" + flat_turns[0]["content"]
        return self.processor.tokenizer.apply_chat_template(
            flat_turns, tokenize=False, add_generation_prompt=False
        )

    # The ``images`` default list is only ever rebound, never mutated in place;
    # it is kept as-is so the value handed to the wrapped processor is unchanged.
    def __call__(self, conversation, images=[], **kwargs):
        """Build model inputs for one single-turn conversation.

        Args:
            conversation: Exactly two turns, in llava format
                (``"from"``/``"value"``) or hf format (``"role"``/``"content"``).
                It is not modified, so the same object can be passed again.
            images: One image or a list of images; each is passed through
                ``resize``.
            **kwargs: Accepted for call compatibility and ignored.

        Returns:
            A dict of the wrapped processor's non-``None`` outputs plus
            ``"labels"``, each indexed with ``[0]`` to drop the batch
            dimension.

        Raises:
            AssertionError: If ``conversation`` does not have two turns.
        """
        if images:
            if isinstance(images, list):
                images = [self.resize(image) for image in images]
            else:
                images = [self.resize(images)]
        assert len(conversation) == 2, "only allow single turn conversation"

        conversation = [self._format_turn(turn) for turn in conversation]

        if self.special_opts["is_phi3v"]:
            prompt = self._build_legacy_prompt(
                conversation, len(images), "<|image_K|>"
            )
        elif self.special_opts["is_llava_phi"]:
            prompt = self._build_legacy_prompt(conversation, len(images), "<image>")
        else:
            prompt = self.processor.apply_chat_template(conversation, tokenize=False)

        inputs = self.processor(images=images, text=prompt, return_tensors="pt")
        if self.special_opts["is_qwen2_vl"]:
            # Add a leading axis so the final ``v[0]`` returns pixel_values
            # exactly as the processor produced them.
            inputs["pixel_values"] = inputs["pixel_values"][None]

        inputs = {key: value for key, value in inputs.items() if value is not None}
        inputs["labels"] = self.build_labels(inputs)
        return {key: value[0] for key, value in inputs.items()}  # remove batch dim
