import open_clip
import torch
from torchvision import transforms as tv_trans


# Maps an open_clip architecture name to the ``pretrained`` weight tag passed to
# ``open_clip.create_model_and_transforms`` when that architecture is loaded.
# NOTE: "ViT-H-14-quickgelu" used to be listed twice ("dfn5b" and then
# "metaclip_fullcc"). A dict literal keeps only the last value for a repeated
# key, so the effective tag was always "metaclip_fullcc". The shadowed "dfn5b"
# entry is removed so the table states what actually happens. Supporting both
# weight sets for this architecture would need distinct keys.
pretrain_lookup = {
    "ViT-H-14-378-quickgelu": "dfn5b",
    "ViT-H-14-quickgelu": "metaclip_fullcc",
    "ViT-SO400M-14-SigLIP-384": "webli",
    "ViT-SO400M-14-SigLIP": "webli",
    "ViT-L-16-SigLIP-384": "webli",
    "ViT-bigG-14": "laion2b_s39b_b160k",
    "ViT-H-14-CLIPA-336": "datacomp1b",
}


class OpenCLIP_transform(object):
    """Replay an open_clip preprocessing pipeline directly on tensors.

    Each step in ``image_processor.transforms`` is visited in order:

    * ``torch.nn.Module`` steps are applied unchanged.
    * A ``ToTensor`` step is emulated by casting to float, clamping to
      [0, 255] and dividing by 255.
    * Any other step is skipped.

    NOTE(review): this presumably expects an image tensor already in the
    0-255 range (e.g. uint8) rather than a PIL image. Confirm with callers.
    """

    def __init__(self, image_processor):
        super().__init__()
        # Assumes a Compose-like object that exposes its step list as ``.transforms``.
        self.transforms = image_processor.transforms

    def __call__(self, image):
        result = image
        for step in self.transforms:
            if isinstance(step, torch.nn.Module):
                result = step(result)
                continue
            if isinstance(step, tv_trans.ToTensor):
                # Tensor-side stand-in for ToTensor's 0-255 -> 0-1 rescale.
                result = result.float().clamp(0, 255) / 255
        return result


class OpenCLIP_model(object):
    """Wrap an open_clip model to score images against text prompts.

    The model is created with ``force_patch_dropout=0.2``. Its parameters are
    frozen with ``requires_grad_(False)`` and it is left in ``train()`` mode.
    NOTE(review): train mode is presumably kept so that the forced patch
    dropout stays active during loss computation. Confirm this is intended.
    """

    supported_models = [
        "open_clip/ViT-H-14-378-quickgelu",
        "open_clip/ViT-H-14-quickgelu",
        "open_clip/ViT-SO400M-14-SigLIP-384",
        "open_clip/ViT-SO400M-14-SigLIP",
        "open_clip/ViT-L-16-SigLIP-384",
        "open_clip/ViT-bigG-14",
        "open_clip/ViT-H-14-CLIPA-336",
    ]

    # Multiplier applied to the cosine similarities before the softmax in
    # ``compute_loss``. A higher value makes the distribution sharper.
    logit_scale = 10.0

    def __init__(self,
                 model_id,
                 device='cpu',
                 ):
        """Load ``model_id`` with its default pretrained weights onto ``device``.

        Args:
            model_id: architecture name, optionally prefixed with ``open_clip/``.
            device: device the model and all inputs are moved to.

        Raises:
            KeyError: if the architecture has no entry in ``pretrain_lookup``.
        """
        super().__init__()
        model_id = model_id.replace("open_clip/", "")
        self.model_id = model_id
        self.device = device

        try:
            pretrained = pretrain_lookup[model_id]
        except KeyError:
            raise KeyError(
                f"Unsupported open_clip model {model_id!r}; "
                f"expected one of {sorted(pretrain_lookup)}") from None

        model, _, preprocess = open_clip.create_model_and_transforms(
            model_id, pretrained=pretrained, force_patch_dropout=0.2)
        model.requires_grad_(False)
        self.model = model
        self.model.train()
        self.model.to(device)

        self.tokenizer = open_clip.get_tokenizer(model_id)

        self.image_transform = OpenCLIP_transform(preprocess)

    def get_prompt(self, text_pair):
        """Tokenize a list of prompts and move the tokens to ``self.device``."""
        text_inputs = self.tokenizer(text_pair).to(self.device)
        return text_inputs

    def get_pixel_values(self, image):
        """Apply the model's image preprocessing to ``image``."""
        pixel_values = self.image_transform(image)
        return pixel_values

    def compute_loss(self,
                     image,
                     target_text,
                     untarget_text):
        """Return the negative mean log-probability assigned to the target prompts.

        Every image is scored against ``target_text + untarget_text`` with a
        softmax over scaled cosine similarities. The loss averages
        ``-log p(target)`` over all images and all target prompts.

        Args:
            image: a tensor, or a list of tensors, accepted by
                ``image_transform``. The transformed results are concatenated
                along dim 0. NOTE(review): each is presumably batched as
                (N, C, H, W); confirm with callers.
            target_text: a prompt string, or a list/tuple of prompts, to favour.
            untarget_text: a prompt string, or a list/tuple of prompts, that
                compete with the targets in the softmax.

        Returns:
            A scalar float64 tensor.

        Raises:
            ValueError: if no images or no target prompts are given.
        """
        images = image if isinstance(image, list) else [image]
        if not images:
            raise ValueError("compute_loss requires at least one image")

        if isinstance(target_text, str):
            target_text = [target_text]
        if isinstance(untarget_text, str):
            untarget_text = [untarget_text]
        if not target_text:
            raise ValueError("compute_loss requires at least one target prompt")
        num_targets = len(target_text)

        # Transform every image and stack them so that each one contributes
        # to the loss, not just the last.
        pixel_values = torch.cat(
            [self.get_pixel_values(img.to(self.device)) for img in images],
            dim=0,
        )
        # Targets come first, so they occupy columns [0, num_targets) of the logits.
        text_inputs = self.get_prompt([*target_text, *untarget_text])

        image_feat = self.model.encode_image(pixel_values).double()
        text_feat = self.model.encode_text(text_inputs).double()

        # L2-normalise so the dot product below is a cosine similarity.
        image_feat = image_feat / image_feat.norm(dim=-1, keepdim=True)
        text_feat = text_feat / text_feat.norm(dim=-1, keepdim=True)

        logits = image_feat @ text_feat.T
        # log_softmax is the numerically stable equivalent of softmax().log().
        log_probs = (logits * self.logit_scale).log_softmax(dim=-1)
        return -log_probs[:, :num_targets].mean()

