import torch
from torchvision import transforms as tv_trans
from transformers import AutoProcessor, AutoModelForZeroShotImageClassification


class CLIP_transform(object):
    """Tensor-based resize / center-crop / rescale / normalize pipeline.

    The individual steps and their parameters are read from a Hugging Face
    image processor's configuration and rebuilt with torchvision / torch ops,
    so the transform operates directly on tensors.

    Args:
        image_processor: image processor exposing ``size``, ``rescale_factor``,
            ``image_mean`` and ``image_std`` (optionally also ``resample``,
            ``do_center_crop`` and ``crop_size``).
        device: device the normalization constants are moved to.

    Raises:
        ValueError: if ``image_processor.size`` has neither a
            ``"shortest_edge"`` nor a ``"height"`` entry.
    """

    def __init__(self, image_processor, device):
        super(CLIP_transform, self).__init__()

        # resize -- 3 is PIL's BICUBIC code, used when the processor does not
        # specify a resampling filter
        interpolation = getattr(image_processor, "resample", 3)
        size_cfg = image_processor.size
        if "shortest_edge" in size_cfg:
            # an int: torchvision matches the shorter edge, keeping aspect ratio
            size = size_cfg["shortest_edge"]
        elif "height" in size_cfg:
            # a (height, width) pair: torchvision resizes to exactly this shape
            size = (size_cfg["height"], size_cfg["width"])
        else:
            raise ValueError("Unsupported `image_processor`")

        self.resize = tv_trans.Resize(size=size,
                                      interpolation=interpolation,
                                      antialias=True)

        # center crop, only when requested AND a crop size is actually defined
        crop_size = getattr(image_processor, "crop_size", None)
        if getattr(image_processor, "do_center_crop", False) and crop_size is not None:
            crop = tv_trans.CenterCrop((crop_size["height"], crop_size["width"]))
            self.resize = tv_trans.Compose([self.resize, crop])

        self.rescale_factor = image_processor.rescale_factor

        # shape (C, 1, 1) so the per-channel constants broadcast over the last
        # three dims of the image; -1 avoids hard-coding the channel count
        self.mean = torch.tensor(image_processor.image_mean).to(device).view(-1, 1, 1)
        self.std = torch.tensor(image_processor.image_std).to(device).view(-1, 1, 1)

    def __call__(self, image):
        """Resize/crop, rescale and normalize ``image``.

        Args:
            image: image tensor whose channel axis is the third from last
                (presumably ``(C, H, W)`` or ``(B, C, H, W)`` -- confirm with
                callers), with pixel values expected on a 0-255 scale.

        Returns:
            The normalized float tensor.
        """
        image = self.resize(image)
        # force back into [0, 255] before rescaling; presumably because
        # (bicubic) resampling can overshoot the valid pixel range
        image = image.float().clamp(0, 255)
        image = image * self.rescale_factor
        image = (image - self.mean) / self.std
        return image

class CLIP_model(object):
    """Score images against text prompts with a Hugging Face CLIP or SigLIP model.

    The model is loaded frozen (``requires_grad_(False)``) on ``device``;
    images are preprocessed with :class:`CLIP_transform` built from the
    model's own image processor.

    Args:
        model_id: Hugging Face model id passed to ``from_pretrained``.
        device: device the model and preprocessing constants live on.
    """

    # Model ids this wrapper was written for (not enforced by ``__init__``).
    supported_models = [
        "openai/clip-vit-base-patch32",
        "openai/clip-vit-large-patch14",
        "openai/clip-vit-base-patch16",
        "openai/clip-vit-large-patch14-336",
        "google/siglip-large-patch16-384",
        "google/siglip-large-patch16-256",
        "google/siglip-base-patch16-256",
        "google/siglip-so400m-patch14-384"
    ]

    def __init__(self,
                 model_id,
                 device='cpu',
                 ):
        super(CLIP_model, self).__init__()

        self.model_id = model_id
        self.device = device

        self.model = AutoModelForZeroShotImageClassification.from_pretrained(model_id)
        self.model.requires_grad_(False)
        self.model.to(device)

        self.processor = AutoProcessor.from_pretrained(model_id)

        self.image_transform = CLIP_transform(
            image_processor=self.processor.image_processor,
            device=device)

        self.is_siglip = self._is_siglip_model(self.model)
        # HF's SigLIP docs recommend padding="max_length" for text inputs
        # (matches how the model was trained); CLIP pads to the longest prompt.
        self.padding = "max_length" if self.is_siglip else True

        # family tag; presumably read by code outside this class -- kept as
        # "clip" for SigLIP models as well
        self.model_type = "clip"

    @staticmethod
    def _is_siglip_model(model):
        """Return True when ``model`` is a SigLIP model.

        Uses ``model.config.model_type`` when available. Objects without a
        config fall back to searching the class docstring for "SiglipConfig",
        treating a missing docstring (e.g. under ``python -OO``) as empty.
        """
        config = getattr(model, "config", None)
        model_type = getattr(config, "model_type", None)
        if model_type is not None:
            return model_type == "siglip"
        return "SiglipConfig" in (getattr(model, "__doc__", None) or "")

    @staticmethod
    def _target_nll(logits_per_image, num_targets, is_siglip):
        """Mean negative log-probability of the first ``num_targets`` prompts.

        Probabilities are a softmax over all prompts (CLIP) or per-prompt
        sigmoid scores renormalized to sum to 1 (SigLIP). Both are evaluated
        in log space in float64 so very negative logits cannot underflow to
        ``log(0) = -inf``.

        Args:
            logits_per_image: logits tensor indexed as ``[image, prompt]``.
            num_targets: number of leading prompt columns counted as targets.
            is_siglip: select the SigLIP (sigmoid) formulation.

        Returns:
            Scalar loss tensor in the dtype of ``logits_per_image``.
        """
        logits = logits_per_image.double()
        if is_siglip:
            # log(sigmoid(x_i) / sum_j sigmoid(x_j))
            log_scores = torch.nn.functional.logsigmoid(logits)
            log_probs = log_scores - log_scores.logsumexp(dim=-1, keepdim=True)
        else:
            log_probs = logits.log_softmax(dim=-1)
        loss = -log_probs[:, :num_targets].mean()
        return loss.to(logits_per_image.dtype)

    def get_prompt(self, text_pair):
        """Tokenize ``text_pair`` with the processor and move it to ``self.device``."""
        text_inputs = self.processor(
            text=text_pair,
            padding=self.padding,
            return_tensors="pt").to(self.device)
        return text_inputs

    def get_pixel_values(self, image):
        """Apply the model's image preprocessing (:class:`CLIP_transform`)."""
        pixel_values = self.image_transform(image)
        return pixel_values

    def compute_loss(self,
                     image,
                     target_text,
                     untarget_text):
        """Loss that decreases as the images match ``target_text`` better.

        Args:
            image: an image tensor or a list of them. Each is moved to
                ``self.device`` and preprocessed; a 3-D result is treated as a
                single unbatched image, and all results are concatenated along
                the batch dimension.
            target_text: prompt (str) or list of prompts to pull towards.
            untarget_text: prompt (str) or list of competing prompts.

        Returns:
            Scalar tensor: mean negative log-probability assigned to the
            target prompts, in the dtype of the model's logits.

        Raises:
            ValueError: if ``target_text`` is empty (the mean would be NaN).
        """
        images = image if isinstance(image, list) else [image]

        batches = []
        for img in images:
            pixel_value = self.get_pixel_values(img.to(self.device))
            if pixel_value.dim() == 3:
                # (C, H, W) -> (1, C, H, W) so cat stacks images, not channels
                pixel_value = pixel_value.unsqueeze(0)
            batches.append(pixel_value)
        pixel_values = torch.cat(batches, dim=0)

        if isinstance(target_text, str):
            target_text = [target_text]
        if isinstance(untarget_text, str):
            untarget_text = [untarget_text]

        num_targets = len(target_text)
        if num_targets == 0:
            raise ValueError("`target_text` must contain at least one prompt")

        # targets first, so they occupy the leading logit columns
        text_pair = list(target_text) + list(untarget_text)
        text_inputs = self.get_prompt(text_pair)

        outputs = self.model(pixel_values=pixel_values, **text_inputs)
        return self._target_nll(outputs.logits_per_image, num_targets, self.is_siglip)
