import math

import torch
from torchvision import transforms as tv_trans
from transformers import AutoProcessor, AutoModelForPreTraining
from transformers.image_processing_utils import select_best_resolution

from .clip import CLIP_transform
from .utils import to_pil


class LLaVA16_transform(object):
    """Tensor-based LLaVA-NeXT (v1.6) style "anyres" image preprocessing.

    Built from a Hugging Face image processor's settings. For a batched image
    tensor it produces one global view resized to ``size x size`` plus a grid
    of ``size x size`` tiles cut from an aspect-preserving resize that is
    padded up to the best-fitting grid resolution. It then rescales and
    normalizes everything, using torch/torchvision tensor ops throughout (no
    PIL/NumPy round-trip).
    """

    def __init__(self, image_processor, device):
        """
        Args:
            image_processor: HF image processor providing ``size["shortest_edge"]``,
                ``crop_size["height"/"width"]``, ``image_grid_pinpoints``,
                ``rescale_factor``, ``image_mean``, ``image_std`` and
                optionally ``resample``.
            device: device on which the normalization mean/std tensors live.

        Raises:
            ValueError: if the shortest-edge size differs from the crop size
                (tiles are assumed square and equal to the crop size).
        """
        super().__init__()
        # Falls back to 3 when the processor has no `resample` attribute
        # (presumably PIL's BICUBIC code -- confirm against the checkpoint).
        interpolation = getattr(image_processor, "resample", 3)
        self.interpolation = interpolation

        size = image_processor.size["shortest_edge"]
        crop_h = image_processor.crop_size["height"]
        crop_w = image_processor.crop_size["width"]
        if not (size == crop_h == crop_w):
            raise ValueError(
                "LLaVA16_transform requires size['shortest_edge'] to equal the "
                f"crop size, got shortest_edge={size}, "
                f"crop_size=({crop_h}, {crop_w})")
        self.size = size

        # Global view: squash the whole image to size x size.
        self.resize1 = tv_trans.Resize(size=(size, size),
                                       interpolation=interpolation,
                                       antialias=True)

        # Per-tile: resize shortest edge to `size`, then center-crop.
        self.resize2 = tv_trans.Compose(
            [tv_trans.Resize(size=size,
                             interpolation=interpolation,
                             antialias=True),
             tv_trans.CenterCrop(size=size),
             ])

        self.image_grid_pinpoints = image_processor.image_grid_pinpoints
        self.rescale_factor = image_processor.rescale_factor

        mean = torch.tensor(image_processor.image_mean).to(device)
        std = torch.tensor(image_processor.image_std).to(device)

        # Assumes 3-channel statistics; shaped to broadcast over (..., c, h, w).
        self.mean = mean.view(3, 1, 1)
        self.std = std.view(3, 1, 1)

    def find_resolution(self, image):
        """Pick the grid resolution and the aspect-preserving resize size.

        Args:
            image: tensor whose last two dimensions are (height, width).

        Returns:
            ``(new_h, new_w, target_h, target_w)``. ``(target_h, target_w)`` is
            the pinpoint chosen by ``select_best_resolution``. ``(new_h, new_w)``
            is the image size scaled to fit inside it without changing the
            aspect ratio.
        """
        h, w = image.shape[-2:]
        target_h, target_w = select_best_resolution(
            (h, w), self.image_grid_pinpoints)

        scale_w, scale_h = target_w / w, target_h / h

        # Fit by the tighter dimension; ceil then clamp so rounding never
        # exceeds the target.
        if scale_w < scale_h:
            new_w = target_w
            new_h = min(math.ceil(h * scale_w), target_h)
        else:
            new_h = target_h
            new_w = min(math.ceil(w * scale_h), target_w)

        return new_h, new_w, target_h, target_w

    def __call__(self, image):
        """Convert a batched image tensor into normalized multi-view pixel values.

        Args:
            image: 4-D tensor ``(b, c, h, w)``. Values appear to be expected in
                [0, 255]: they are clamped to that range before being multiplied
                by ``rescale_factor``.

        Returns:
            Float tensor ``(b, 1 + num_tiles, c, size, size)``. Index 0 along
            dim 1 is the global view; the rest are tiles in row-major order.

        Raises:
            ValueError: if ``image`` is not 4-D.
        """
        if image.ndim != 4:
            raise ValueError(
                "expected a 4-D (b, c, h, w) image tensor, "
                f"got shape {tuple(image.shape)}")

        all_images = [self.resize1(image)]

        new_h, new_w, target_h, target_w = self.find_resolution(image)

        image2 = tv_trans.Resize((new_h, new_w),
                                 interpolation=self.interpolation,
                                 antialias=True)(image)

        # Centre the resized image inside a target_h x target_w canvas. Any odd
        # leftover goes to the bottom/right so the padded image is exactly the
        # target size. Symmetric (pad_w, pad_h) padding would leave it one pixel
        # short and produce undersized edge tiles.
        pad_top = (target_h - new_h) // 2
        pad_left = (target_w - new_w) // 2
        pad_bottom = target_h - new_h - pad_top
        pad_right = target_w - new_w - pad_left
        image2 = tv_trans.Pad(
            (pad_left, pad_top, pad_right, pad_bottom))(image2)

        for i in range(0, target_h, self.size):
            for j in range(0, target_w, self.size):
                patch = image2[..., i: i + self.size, j: j + self.size]
                patch = self.resize2(patch)
                all_images.append(patch)

        # (views, b, c, s, s) -> (b, views, c, s, s)
        all_images = torch.stack(all_images).transpose(0, 1)
        all_images = all_images.float().clamp(0, 255)
        all_images = all_images * self.rescale_factor

        all_images = (all_images - self.mean) / self.std
        return all_images



class LLaVA_model(object):
    """Wrapper around a Hugging Face LLaVA-1.5 / LLaVA-NeXT checkpoint.

    ``compute_loss`` scores a given answer using this module's tensor image
    transforms and is not wrapped in ``no_grad``, so gradients can flow back
    to the input image. ``generate`` uses the stock HF processor on a PIL image
    under ``no_grad``.
    """

    supported_models = [
        "llava-hf/llava-1.5-7b-hf",
        "llava-hf/llava-1.5-13b-hf",
        "llava-hf/llava-v1.6-vicuna-7b-hf",
        "llava-hf/llava-v1.6-vicuna-13b-hf",
        "llava-hf/llava-v1.6-mistral-7b-hf",
        "llava-hf/llama3-llava-next-8b-hf",
    ]

    def __init__(self,
                 model_id,
                 device='cpu'):
        """
        Args:
            model_id: HF hub id. Ids containing ``"llava-1.5"`` use the
                ``CLIP_transform`` pipeline; all others use ``LLaVA16_transform``.
            device: device for the model and all tensors built by this wrapper.
        """
        super().__init__()
        self.model_id = model_id
        self.device = device

        # Weights are loaded in float16 and frozen; only inputs can receive
        # gradients.
        self.model = AutoModelForPreTraining.from_pretrained(
            model_id,
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True)

        self.model.requires_grad_(False)
        self.model.to(device)

        self.processor = AutoProcessor.from_pretrained(model_id)

        if "llava-1.5" in self.model_id:
            self.llava_type = "llava"
            self.image_transform = CLIP_transform(
                image_processor=self.processor.image_processor,
                device=device)
        else:
            self.llava_type = "llava16"
            self.image_transform = LLaVA16_transform(
                image_processor=self.processor.image_processor,
                device=device)

        # NOTE(review): tag consumed elsewhere; meaning not visible here.
        self.model_type = "vllm"

    def get_prompt(self, question, answer=None):
        """Build the chat-templated prompt and its token ids.

        Returns:
            dict with ``"question"`` (templated prompt string) and
            ``"question_ids"`` (``(1, Lq)`` ids on ``self.device``). When
            ``answer`` is given it also has ``"answer"``, ``"answer_ids"``
            (``(1, La)``, no special tokens), ``"input_ids"`` (question and
            answer ids concatenated along dim 1) and ``"len_answer_ids"``.
        """
        conversation = [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": question},
                    {"type": "image"},
                ],
            },
        ]
        prompt = self.processor.apply_chat_template(
            conversation, add_generation_prompt=True)

        output = dict()
        output["question"] = prompt

        question_ids = self.processor.tokenizer(
            prompt,
            return_tensors='pt',
            add_special_tokens=True).input_ids.to(self.device)
        output["question_ids"] = question_ids

        if answer is not None:
            # A newline is prepended for this checkpoint only; the reason
            # (presumably its chat-template formatting) is not visible here.
            if self.model_id == "llava-hf/llama3-llava-next-8b-hf":
                answer = "\n" + answer
            output["answer"] = answer
            answer_ids = self.processor.tokenizer(
                answer,
                return_tensors='pt',
                add_special_tokens=False).input_ids.to(self.device)
            input_ids = torch.cat([question_ids,
                                   answer_ids], dim=1)
            output["input_ids"] = input_ids
            output["answer_ids"] = answer_ids
            output["len_answer_ids"] = answer_ids.shape[1]
        return output

    def get_pixel_values(self, image):
        """Apply the model-specific tensor transform to ``image``."""
        pixel_values = self.image_transform(image)
        return pixel_values

    def compute_loss(self, image, question, answer):
        """Teacher-forced mean cross-entropy of ``answer`` given image and question.

        Args:
            image: image tensor accepted by ``self.image_transform``. For
                LLaVA-NeXT its last two dims are taken as the original
                (height, width).
            question: user question text.
            answer: target answer text.

        Returns:
            Scalar float32 loss tensor.

        Raises:
            ValueError: if ``answer`` tokenizes to zero tokens. The mean over an
                empty target would otherwise silently be NaN.
        """
        image = image.to(self.device)

        prompt = self.get_prompt(question, answer)
        input_ids = prompt["input_ids"]
        labels = prompt["answer_ids"]
        len_answer_ids = prompt["len_answer_ids"]
        if len_answer_ids == 0:
            raise ValueError(
                f"answer {answer!r} tokenizes to zero tokens; cannot compute loss")

        attention_mask = torch.ones_like(input_ids)

        pixel_values = self.get_pixel_values(image)

        inputs = dict(input_ids=input_ids,
                      attention_mask=attention_mask,
                      pixel_values=pixel_values)

        if self.llava_type == "llava16":
            # Original (height, width), shaped (1, 2).
            image_sizes = torch.tensor(image.shape[2:])
            image_sizes = image_sizes.unsqueeze(0).to(self.device)
            inputs["image_sizes"] = image_sizes

        outputs = self.model(**inputs)

        # Answer ids sit at the end of input_ids, and position t predicts
        # token t+1. So the len_answer_ids logits just before the last position
        # score the answer. Upcast to float32 so the softmax/loss and its
        # gradients are not computed in the model's float16.
        logits = outputs.logits
        logits = logits[:, -len_answer_ids - 1:-1].float().mT  # (1, V, La)
        loss = torch.nn.CrossEntropyLoss()(logits, labels)
        return loss

    @torch.no_grad()
    def generate(self, image, question):
        """Answer ``question`` about ``image`` with ``do_sample=False`` decoding.

        Returns:
            ``(output, text)``. ``output`` is the first generated id sequence
            (on CPU) and ``text`` is its decoding without special tokens. For
            HF decoder-only models the sequence presumably includes the prompt
            tokens -- confirm before parsing ``text``.
        """
        pil_image = to_pil(image)
        prompt = self.get_prompt(question)
        inputs = self.processor(images=pil_image,
                                text=prompt["question"],
                                return_tensors="pt").to(self.device)
        output = self.model.generate(**inputs,
                                     do_sample=False,
                                     max_new_tokens=256)[0].cpu()
        text = self.processor.decode(output, skip_special_tokens=True)
        return output, text
