# Code adapted from: https://github.com/openai/CLIP
# License: MIT

import os
from typing import Callable, Tuple, Union, Optional, cast, List
from functools import wraps

import torch
from torch import Tensor, IntTensor
from torch.nn import functional as F

from torchvision import transforms as T
from torchvision.transforms import functional as TF
from torchvision.transforms import InterpolationMode

from PIL.Image import Image

import clip
from clip.model import CLIP as _CLIP
from clip.clip import _convert_image_to_rgb, load, available_models


# Per-channel normalisation statistics applied by `T.Normalize` in `_transform`.
# `_transform` converts images to RGB before `T.ToTensor()`, so the three entries
# correspond to the R, G and B channels respectively.
_IMAGE_MEAN = [0.48145466, 0.4578275, 0.40821073]
_IMAGE_STD = [0.26862954, 0.26130258, 0.27577711]


def _transform(n_px: int) -> T.Compose:
    """Build the CLIP image preprocessing pipeline for an ``n_px`` input resolution.

    The pipeline, in order:
      1. bicubic resize with an integer target size ``n_px``,
      2. centre crop to ``n_px`` x ``n_px``,
      3. conversion to RGB (``_convert_image_to_rgb``),
      4. conversion to a tensor (``T.ToTensor``),
      5. per-channel normalisation with ``_IMAGE_MEAN`` / ``_IMAGE_STD``.
    """
    resize = T.Resize(n_px, interpolation=InterpolationMode.BICUBIC)
    crop = T.CenterCrop(n_px)
    normalize = T.Normalize(mean=_IMAGE_MEAN, std=_IMAGE_STD)
    stages = [resize, crop, _convert_image_to_rgb, T.ToTensor(), normalize]
    return T.Compose(stages)


class CLIP(_CLIP):
    """Subclass of ``clip.model.CLIP`` with device-aware convenience helpers.

    ``load_clip`` produces instances by re-classing the model returned by
    ``clip.load``. On top of the base model this class adds:

    - tokenisation bound to the model's context length,
    - automatic transfer of inputs to the model's device,
    - prompt-group (ensemble) text encoding,
    - encoding of pre-embedded prompts, and
    - a configurable image/text similarity head.
    """

    def tokenize(self, texts: Union[str, List[str]]) -> IntTensor:
        """Tokenise ``texts`` using this model's ``context_length``.

        Sequences longer than the context length are truncated
        (``truncate=True``) rather than raising.
        """
        tokens = clip.tokenize(texts, context_length=self.context_length, truncate=True)
        return cast(IntTensor, tokens)

    def encode_image(self, image: Tensor) -> Tensor:
        """Encode a batch of preprocessed images.

        The images are moved to the model's device before being encoded.
        """
        return super().encode_image(image.to(self.device))

    def encode_text(
        self,
        text: Union[str, IntTensor, List[IntTensor], List[str], List[List[str]]],
    ) -> Tensor:
        """Encode text given as raw strings or as token tensors.

        Accepted forms:
            - ``str`` / ``List[str]``: tokenised and encoded, one row per string.
            - ``IntTensor``: an already-tokenised batch, one row per sequence.
            - ``List[List[str]]`` / ``List[IntTensor]``: a list of prompt groups.
              Each group is encoded separately and its embeddings are averaged
              (mean over the group's prompts, without per-prompt normalisation),
              giving one row per group.

        Raises:
            ValueError: if ``text`` is an empty list.
        """
        if isinstance(text, str):
            text = self.tokenize(text)
        elif isinstance(text, list):
            # Guard before peeking at text[0]; an empty list would otherwise
            # fail with an unhelpful IndexError.
            if not text:
                raise ValueError("encode_text() received an empty list")
            if isinstance(text[0], list):  # List[List[str]]
                text = [self.tokenize(cast(List[str], group)) for group in text]
            elif isinstance(text[0], str):  # List[str]
                text = self.tokenize(cast(List[str], text))

        text = cast(Union[IntTensor, List[IntTensor]], text)

        if isinstance(text, list):  # List[IntTensor]
            # Bind super() outside the comprehension: zero-argument super()
            # does not work inside a comprehension's own scope (Python < 3.12).
            _encode = super().encode_text
            return torch.stack([_encode(t.to(self.device)).mean(dim=0) for t in text], dim=0)
        return super().encode_text(text.to(self.device))

    def encode_prompt(self, prompts: Tensor, eot_idxs: IntTensor) -> Tensor:
        """Run the text transformer on already-embedded prompts.

        Args:
            prompts: prompt token embeddings, ``[batch_size, n_ctx, d_model]``
                (positional embeddings are added here, not by the caller).
            eot_idxs: per-sequence position of the end-of-text token whose
                feature is projected as the sequence embedding.

        Returns:
            The projected features taken at each sequence's EOT position.
        """
        x = prompts.type(self.dtype)  # [batch_size, n_ctx, d_model]

        x = x + self.positional_embedding.type(self.dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x).type(self.dtype)

        # x.shape = [batch_size, n_ctx, transformer.width]
        # Gather the feature at each sequence's EOT position. The row index is
        # built directly on x's device rather than on the CPU.
        rows = torch.arange(x.shape[0], device=x.device)
        return x[rows, eot_idxs] @ self.text_projection

    def similarity(
        self,
        image_features: Tensor,
        text_features: Tensor,
        *,
        temperature: Optional[float] = None,
        softmax: bool = True,
    ) -> Tensor:
        """Scaled cosine similarity between image and text features.

        Both inputs are L2-normalised along the last dimension and cast to the
        model dtype. The logits ``image_features @ text_features.T`` are scaled
        by ``exp(self.logit_scale)`` or, if given, by ``1 / temperature``.

        Args:
            temperature: optional override of the learned scale; must be > 0.
            softmax: if True, return a softmax over the last (text) dimension;
                otherwise return the scaled logits.

        Raises:
            ValueError: if ``temperature`` is given and is not positive.
        """
        if temperature is not None and temperature <= 0:
            raise ValueError(f"temperature must be positive, got {temperature}")

        image_features = F.normalize(image_features, dim=-1).type(self.dtype)
        text_features = F.normalize(text_features, dim=-1).type(self.dtype)

        logit_scale = self.logit_scale.exp() if temperature is None else (1 / temperature)
        logits = logit_scale * (image_features @ text_features.t())

        return logits.softmax(dim=-1) if softmax else logits

    def forward(
        self,
        image: Tensor,
        text: Union[str, IntTensor, List[IntTensor], List[str], List[List[str]]],
        *,
        temperature: Optional[float] = None,
        softmax: bool = True,
    ) -> Tensor:
        """Encode ``image`` and ``text`` and return their similarity.

        See ``encode_text`` for the accepted ``text`` forms and ``similarity``
        for the meaning of ``temperature`` and ``softmax``.
        """
        image_features = self.encode_image(image)
        text_features = self.encode_text(text)
        return self.similarity(image_features, text_features, temperature=temperature, softmax=softmax)

    @property
    def dtype(self) -> torch.dtype:
        """Model parameter dtype, taken from ``logit_scale``."""
        return self.logit_scale.dtype

    @property
    def device(self) -> torch.device:
        """Model device, taken from ``logit_scale``."""
        return self.logit_scale.device


# Load a CLIP checkpoint and return ``(model, preprocess)``.
#
# - ``model`` is the module returned by ``clip.load`` re-classed to the local
#   ``CLIP`` subclass.
# - ``preprocess`` is the ``_transform`` pipeline for the model's input
#   resolution.
#
# ``@wraps(load)`` copies ``clip.load``'s metadata (including ``__doc__`` and
# ``__wrapped__``) onto this function. That is why this documentation is
# written as comments rather than as a docstring, which would be overwritten.
#
# Raises RuntimeError if ``name`` is neither a known model name nor an existing
# file.
@wraps(load)
def load_clip(
    name: str,
    device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu",
    jit: bool = False,
    download_root: Optional[str] = None,
) -> Tuple[CLIP, Callable[[Image], Tensor]]:
    models = available_models()
    if name not in models and not os.path.isfile(name):
        raise RuntimeError(f"Model {name} not found; available models = {models}")

    model, _ = load(name, device=device, jit=jit, download_root=download_root)  # type: ignore

    # Read the input resolution *before* re-classing. The lookup then goes
    # through the class `load` actually returned (a scripted module when
    # jit=True) rather than through CLIP's nn.Module attribute resolution.
    if jit:
        n_px = model.input_resolution.item()  # type: ignore
    else:
        n_px = model.visual.input_resolution
    preprocess = _transform(n_px)

    # NOTE(review): with jit=True this re-classes a scripted module. CLIP's
    # methods may not work on it; confirm the jit path is actually supported.
    model.__class__ = CLIP  # type: ignore
    return cast(CLIP, model), preprocess
