import math

import torch
from torchvision import transforms as tv_trans
from transformers import AutoProcessor, AutoModelForVision2Seq

from .utils import to_pil


class Idefics2_transform(object):
    """Tensor-based image preprocessing driven by an image processor's settings.

    Reads the split / resize / rescale / normalisation settings from the
    given ``image_processor`` and applies them directly to batched image
    tensors, returning padded pixel values together with a pixel-level
    attention mask.
    """

    def __init__(self, image_processor, device):
        """Cache the preprocessing settings.

        Args:
            image_processor: object exposing ``rescale_factor``,
                ``image_mean``, ``image_std`` and ``size`` (with
                ``"shortest_edge"`` / ``"longest_edge"`` keys), and
                optionally ``do_image_splitting``, ``do_resize`` and
                ``resample``.
            device: device on which the mean/std tensors (and later the
                attention mask) are created.

        Raises:
            AssertionError: if ``image_processor.do_resize`` is falsy.
        """
        super().__init__()
        self.image_processor = image_processor
        self.do_split = getattr(image_processor, "do_image_splitting", True)

        # This transform always resizes, so a processor configured without
        # resizing cannot be reproduced. Raised explicitly (instead of using
        # an `assert` statement) so the check is not stripped by `python -O`.
        if not getattr(image_processor, "do_resize", True):
            raise AssertionError(
                "Idefics2_transform requires image_processor.do_resize "
                "to be enabled")

        self.rescale_factor = image_processor.rescale_factor

        mean = torch.tensor(image_processor.image_mean).to(device)
        std = torch.tensor(image_processor.image_std).to(device)

        # Shaped (3, 1, 1) so they broadcast over (batch, 3, H, W) images.
        self.mean = mean.view(3, 1, 1)
        self.std = std.view(3, 1, 1)

    def resize(self, image):
        """Resize a 4-D image batch to the processor's size constraints.

        The longer side is capped at ``size["longest_edge"]`` (keeping the
        aspect ratio), then each side is raised to at least
        ``size["shortest_edge"]``.

        Args:
            image: tensor whose shape unpacks as ``(_, _, height, width)``.

        Returns:
            The resized tensor.
        """
        _, _, h, w = image.shape

        # NOTE(review): the fallback 2 is presumably the PIL code for
        # bilinear resampling -- confirm against the processor config.
        interpolation = getattr(self.image_processor, "resample", 2)
        min_len = self.image_processor.size["shortest_edge"]
        max_len = self.image_processor.size["longest_edge"]

        aspect_ratio = w / h

        # Shrink so the longer side equals max_len; the other side follows
        # from the aspect ratio (truncated to an int).
        if w >= h and w > max_len:
            w = max_len
            h = int(w / aspect_ratio)
        elif h > w and h > max_len:
            h = max_len
            w = int(h * aspect_ratio)

        # Enforce the lower bound independently per side (this may change
        # the aspect ratio).
        size = (max(h, min_len), max(w, min_len))

        resizer = tv_trans.Resize(size=size,
                                  interpolation=interpolation,
                                  antialias=True)
        return resizer(image)

    def __call__(self, image):
        """Preprocess a batch of images.

        Args:
            image: tensor of shape ``(batch, channels, height, width)``.
                Values are clamped to ``[0, 255]`` before rescaling.

        Returns:
            A tuple ``(pixel_values, pixel_attention_mask)`` where
            ``pixel_values`` has shape ``(batch, n, channels, H, W)`` and
            ``pixel_attention_mask`` has shape ``(batch, n, H, W)`` with
            ones over the un-padded region of each image. ``n`` is 5 when
            splitting is enabled (four quadrants plus the full image) and
            1 otherwise.
        """
        # Unpack the shape before branching: the batch size is needed for
        # the attention mask whether or not the image is split.
        b, _, h, w = image.shape

        if self.do_split:
            half_h, half_w = h // 2, w // 2
            all_images = [
                image[..., :half_h, :half_w],
                image[..., :half_h, half_w:],
                image[..., half_h:, :half_w],
                image[..., half_h:, half_w:],
                image,
            ]
        else:
            all_images = [image]

        # Resize / rescale / normalise every view and track the largest
        # spatial size so that all views can be padded to a common shape.
        max_h = max_w = 0
        for idx, view in enumerate(all_images):
            view = self.resize(view)
            max_h = max(max_h, view.shape[2])
            max_w = max(max_w, view.shape[3])

            view = view.float().clamp(0, 255)
            view = view * self.rescale_factor
            view = (view - self.mean) / self.std
            all_images[idx] = view

        pixel_attention_mask = torch.zeros(b, len(all_images), max_h, max_w,
                                           device=self.mean.device,
                                           dtype=torch.int64)
        for idx, view in enumerate(all_images):
            _, _, view_h, view_w = view.shape
            pad_h, pad_w = max_h - view_h, max_w - view_w
            # Padding tuple is (left, top, right, bottom): pad only on the
            # right and bottom so valid pixels stay in the top-left corner,
            # matching the mask region set below.
            all_images[idx] = tv_trans.Pad((0, 0, pad_w, pad_h))(view)
            pixel_attention_mask[:, idx, :view_h, :view_w] = 1

        # Stack to (n, batch, C, H, W), then swap to (batch, n, C, H, W).
        all_images = torch.stack(all_images).transpose(0, 1)
        return all_images, pixel_attention_mask


class Idefics2_model(object):
    """Wrapper around an Idefics2 checkpoint for loss computation and generation.

    Loads the model and processor for ``model_id`` and exposes helpers to
    build tokenised prompts, preprocess image tensors, compute an
    answer-conditioned cross-entropy loss and generate text.
    """

    supported_models = [
        "HuggingFaceM4/idefics2-8b"
    ]

    def __init__(self,
                 model_id,
                 device='cpu'):
        """Load the frozen half-precision model, its processor and the transform.

        Args:
            model_id: checkpoint identifier passed to ``from_pretrained``.
            device: device the model and all produced tensors are moved to.
        """
        super().__init__()
        self.model_id = model_id
        self.device = device

        self.model = AutoModelForVision2Seq.from_pretrained(
            model_id,
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True)

        # Parameters are frozen; gradients can still flow to the inputs.
        self.model.requires_grad_(False)
        self.model.to(self.device)

        self.processor = AutoProcessor.from_pretrained(
            model_id,
            size={"longest_edge": 448, "shortest_edge": 378})

        # Tensor-level preprocessing built from the same image-processor
        # settings, so image tensors can be fed without going through PIL.
        self.image_transform = Idefics2_transform(
            image_processor=self.processor.image_processor,
            device=device)

        # Tag consumed outside this class; its meaning is not visible here.
        self.model_type = "vllm"

    def get_prompt(self, question, answer=None):
        """Build the chat prompt and its token ids.

        Args:
            question: user text placed after a single image placeholder.
            answer: optional target text; when given, its token ids are
                appended to the question ids.

        Returns:
            dict with ``"question"`` (templated prompt string) and
            ``"question_ids"``; when ``answer`` is not ``None`` also
            ``"answer"``, ``"answer_ids"``, ``"input_ids"`` (question ids
            followed by answer ids) and ``"len_answer_ids"``.
        """
        conversation = [
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {"type": "text", "text": question},
                ],
            },
        ]
        prompt = self.processor.apply_chat_template(
            conversation, add_generation_prompt=True)

        output = dict()
        output["question"] = prompt
        # Pass the prompt by keyword (as `generate` does) so the call does
        # not depend on the positional order of the processor's `text` /
        # `images` parameters, which is not stable across transformers
        # releases.
        question_ids = self.processor(
            text=prompt,
            return_tensors='pt').input_ids.to(self.device)
        output["question_ids"] = question_ids

        if answer is not None:
            output["answer"] = answer
            # No special tokens: the answer is a continuation of the prompt.
            answer_ids = self.processor.tokenizer(
                answer,
                return_tensors='pt',
                add_special_tokens=False).input_ids.to(self.device)
            input_ids = torch.cat([question_ids,
                                   answer_ids], dim=1)
            output["input_ids"] = input_ids
            output["answer_ids"] = answer_ids
            output["len_answer_ids"] = answer_ids.shape[1]
        return output

    def get_pixel_values(self, image):
        """Run the tensor image transform.

        Returns:
            ``(pixel_values, pixel_attention_mask)`` as produced by
            ``self.image_transform``.
        """
        pixel_values, pixel_attention_mask = self.image_transform(image)
        return pixel_values, pixel_attention_mask

    def compute_loss(self, image, question, answer):
        """Cross-entropy of ``answer`` given ``image`` and ``question``.

        Args:
            image: image tensor accepted by ``self.image_transform``.
                NOTE(review): the prompt ids have batch size 1, so this
                presumably expects a single image -- confirm with callers.
            question: user text for the prompt.
            answer: target text whose tokens are scored.

        Returns:
            Scalar loss tensor from ``torch.nn.CrossEntropyLoss``.
        """
        image = image.to(self.device)

        prompt = self.get_prompt(question, answer)
        input_ids = prompt["input_ids"]

        len_answer_ids = prompt["len_answer_ids"]

        attention_mask = torch.ones_like(input_ids)

        pixel_values, pixel_attention_mask = self.get_pixel_values(image)

        inputs = dict(input_ids=input_ids,
                      attention_mask=attention_mask,
                      pixel_values=pixel_values,
                      pixel_attention_mask=pixel_attention_mask)

        outputs = self.model(**inputs)

        logits = outputs.logits
        # The logit at position i scores token i + 1, so the answer tokens
        # (the last `len_answer_ids` inputs) are scored by the slice shifted
        # one step left. `.mT` moves the vocabulary axis to dim 1, the
        # layout CrossEntropyLoss expects for sequence targets.
        logits = logits[:, -len_answer_ids-1:-1].mT
        labels = prompt["answer_ids"]
        loss = torch.nn.CrossEntropyLoss()(logits, labels)
        return loss

    @torch.no_grad()
    def generate(self, image, question):
        """Greedily generate up to 100 new tokens for ``question`` about ``image``.

        Args:
            image: input converted with ``to_pil`` before being handed to
                the processor.
            question: user text for the prompt.

        Returns:
            ``(output, text)``: the first generated sequence (moved to CPU)
            and its decoding with special tokens skipped.
        """
        pil_image = to_pil(image)
        prompt = self.get_prompt(question)
        inputs = self.processor(images=pil_image,
                                text=prompt["question"],
                                return_tensors="pt").to(self.device)
        output = self.model.generate(**inputs,
                                     do_sample=False,
                                     max_new_tokens=100)[0].cpu()
        text = self.processor.decode(output, skip_special_tokens=True)
        return output, text

