# Code adapted from: https://github.com/openai/CLIP
# License: MIT

from typing import cast
from collections.abc import Callable

import torch
from torch import Tensor, IntTensor
from torch.nn import functional as F

from PIL.Image import Image

import open_clip
from open_clip.model import CLIP as _CLIP


__all__ = [
    "CLIP",
    "load_clip",
]


class CLIP(_CLIP):
    """open_clip ``CLIP`` model extended with tokenizing, encoding and scoring helpers.

    Inputs handed to the encoders are moved to the model's device (read from
    ``logit_scale``) before the parent implementation is invoked.
    """

    def tokenize(self, texts: str | list[str]) -> IntTensor:
        """Tokenize ``texts`` with ``open_clip.tokenize`` at this model's context length."""
        tokens = open_clip.tokenize(texts, context_length=self.context_length)
        return cast(IntTensor, tokens)

    def encode_image(self, image: Tensor) -> Tensor:
        """Move ``image`` to the model's device and run the parent image encoder."""
        return super().encode_image(image.to(self.device))

    def encode_text(self, text: IntTensor | list[IntTensor] | list[str] | list[list[str]]) -> Tensor:
        """Encode text given either as token tensors or as raw strings.

        Args:
            text: One of
                * a token tensor -- encoded directly;
                * a list of strings -- tokenized together, then encoded;
                * a list of token tensors, or a list of lists of strings --
                  every entry is encoded separately and its features are
                  averaged over dim 0, yielding one feature vector per entry.

        Returns:
            The text features produced by the parent ``encode_text``; for the
            grouped forms, the per-entry means stacked along dim 0.

        Raises:
            ValueError: If ``text`` is an empty list.
        """
        if isinstance(text, list):
            if not text:
                # Fail with a clear message instead of an IndexError on text[0].
                raise ValueError("text must not be an empty list")

            first = text[0]
            if isinstance(first, list):  # list[list[str]]
                text = [self.tokenize(cast(list[str], group)) for group in text]
            elif isinstance(first, str):  # list[str]
                text = self.tokenize(cast(list[str], text))

        text = cast(IntTensor | list[IntTensor], text)

        if isinstance(text, list):  # list[IntTensor]
            # Bind the parent method first: zero-argument super() is not
            # available inside a comprehension on older Python versions.
            _encode = super().encode_text
            return torch.stack([_encode(t.to(self.device)).mean(dim=0) for t in text], dim=0)

        # IntTensor
        return super().encode_text(text.to(self.device))

    def encode_prompt(self, prompts: Tensor, eot_idxs: IntTensor) -> Tensor:
        """Encode already-embedded prompts with the text transformer.

        Args:
            prompts: Prompt embeddings, ``[batch_size, n_ctx, d_model]``.
            eot_idxs: For each sequence, the position whose output embedding
                is read out (the EOT token position).

        Returns:
            The selected embeddings multiplied by ``text_projection``.
        """
        x: Tensor = prompts.type(self.dtype)  # [batch_size, n_ctx, d_model]
        x = x + self.positional_embedding.type(self.dtype)

        # NOTE(review): the permutes assume a sequence-first transformer, as in
        # the OpenAI reference code -- confirm against the installed open_clip
        # version.
        x = x.permute(1, 0, 2)  # NLD -> LND
        # open_clip applies the causal mask at call time (its own encode_text
        # passes attn_mask=self.attn_mask); without it the prompts would be
        # encoded with bidirectional attention and disagree with encode_text.
        x = self.transformer(x, attn_mask=self.attn_mask)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x).type(self.dtype)  # type: ignore

        # x.shape = [batch_size, n_ctx, transformer.width]
        # take features from the eot embedding of every sequence
        batch_idxs = torch.arange(x.shape[0], device=x.device)
        return x[batch_idxs, eot_idxs] @ self.text_projection

    def similarity(
        self,
        image_features: Tensor,
        text_features: Tensor,
        *,
        temperature: float | None = None,
        softmax: bool = True,
    ) -> Tensor:
        """Score image features against text features.

        Both inputs are L2-normalized along the last dim and cast to the
        model dtype before their matrix product is taken.

        Args:
            image_features: Image embeddings.
            text_features: Text embeddings; transposed with ``.t()`` for the
                product.
            temperature: If given, logits are scaled by ``1 / temperature``;
                otherwise by ``exp(logit_scale)``.
            softmax: If true, apply softmax over the last dim of the scaled
                logits; otherwise return the scaled logits.

        Returns:
            The scaled image-to-text logits, or their softmax.
        """
        image_features = F.normalize(image_features, dim=-1).type(self.dtype)
        text_features = F.normalize(text_features, dim=-1).type(self.dtype)

        logit_scale = self.logit_scale.exp() if temperature is None else (1 / temperature)
        scaled_logits = logit_scale * (image_features @ text_features.t())

        return scaled_logits.softmax(dim=-1) if softmax else scaled_logits

    def forward(
        self,
        image: Tensor,
        text: IntTensor | list[IntTensor] | list[str] | list[list[str]],
        *,
        temperature: float | None = None,
        softmax: bool = True,
    ) -> Tensor:
        """Encode ``image`` and ``text`` and return their similarity.

        ``text`` accepts every form supported by ``encode_text``;
        ``temperature`` and ``softmax`` are forwarded to ``similarity``.
        """
        image_features = self.encode_image(image)
        text_features = self.encode_text(text)
        return self.similarity(image_features, text_features, temperature=temperature, softmax=softmax)

    @property
    def dtype(self) -> torch.dtype:
        """Dtype of the model, taken from ``logit_scale``."""
        return self.logit_scale.dtype

    @property
    def device(self) -> torch.device:
        """Device of the model, taken from ``logit_scale``."""
        return self.logit_scale.device


def load_clip(
    name: str,
    device: str | torch.device = "cuda" if torch.cuda.is_available() else "cpu",
    jit: bool = False,
    download_root: str | None = None,
) -> tuple[CLIP, Callable[[Image], Tensor]]:
    """Load a CLIP model together with its image preprocessing transform.

    Args:
        name: Model identifier. If it is listed by
            ``open_clip.list_pretrained(as_str=True)`` it is split on ``":"``
            into a backbone name and a pretrained tag and loaded through
            open_clip; otherwise loading is delegated to
            ``.clip_oai.load_clip``.
        device: Device the model is created on. The default is chosen once,
            when this module is imported.
        jit: Forwarded to the underlying loader.
        download_root: Cache directory for downloaded weights (passed to
            open_clip as ``cache_dir``).

    Returns:
        A ``(model, preprocess)`` pair.

    Raises:
        RuntimeError: If ``name`` is not an open_clip pretrained identifier
            and the fallback loader fails with a ``RuntimeError``.
    """
    if name not in open_clip.list_pretrained(as_str=True):
        try:
            from .clip_oai import load_clip as _load_clip
            return _load_clip(name, device=device, jit=jit, download_root=download_root)
        except RuntimeError as err:
            # Chain explicitly so the underlying failure stays attached as
            # the direct cause of the "not found" error.
            raise RuntimeError(f"Model {name} not found") from err

    backbone, pretrained = name.split(":")
    model, _, preprocess = open_clip.create_model_and_transforms(
        backbone,
        pretrained=pretrained,
        device=device,
        jit=jit,
        cache_dir=download_root,
    )

    # Re-type the loaded instance in place so it gains the helper methods of
    # the CLIP subclass without being rebuilt.
    model.__class__ = CLIP  # type: ignore
    model = cast(CLIP, model)

    return model, preprocess  # type: ignore
