import logging

import open_clip
import torch, torchvision

class preference_loss():# @markdown load a CLIP model and define the loss function
    def __init__(self, prompt, device="cuda"):
        self.device = device
        self.clip_model, _, preprocess = open_clip.create_model_and_transforms(
            "ViT-B-32", pretrained="openai"
        )
        self.clip_model.to(device)

        # Transforms to resize and augment an image + normalize to match CLIP's training data
        self.tfms = torchvision.transforms.Compose(
            [
                torchvision.transforms.RandomResizedCrop(224),  # Random CROP each time
                torchvision.transforms.RandomAffine(
                    5
                ),  # One possible random augmentation: skews the image
                torchvision.transforms.RandomHorizontalFlip(),  # You can add additional augmentations if you like
                torchvision.transforms.Normalize(
                    mean=(0.48145466, 0.4578275, 0.40821073),
                    std=(0.26862954, 0.26130258, 0.27577711),
                ),
            ]
        )

        self.prompt = prompt
        # We embed a prompt with CLIP as our target
        text = open_clip.tokenize([self.prompt]).to(device)
        with torch.no_grad(), torch.cuda.amp.autocast():
            self.text_features = self.clip_model.encode_text(text)

    # And define a loss function that takes an image, embeds it and compares with
    # the text features of the prompt
    def clip_loss(self, image):
        image_features = self.clip_model.encode_image(
            self.tfms(image)
        )  # Note: applies the above transforms
        input_normed = torch.nn.functional.normalize(image_features.unsqueeze(1), dim=2)
        embed_normed = torch.nn.functional.normalize(self.text_features.unsqueeze(0), dim=2)
        dists = (
            input_normed.sub(embed_normed).norm(dim=2).div(2).arcsin().pow(2).mul(2)
        )  # Squared Great Circle Distance
        print("dists:", dists)
        return dists.mean()  # The loss is the mean of all the images in the batch