from collections.abc import Callable, Sequence

import torch

from open_clip.transformer import VisionTransformer

from PIL.Image import Image

from ..clip import CLIP, load_clip
from ..laft import inner, orthogonal
from .winclip import WinCLIP, cosine_similarity, visual_association_score, harmonic_aggregation


class WinCLIPwLAFT(WinCLIP):
    """WinCLIP detector with an optional LAFT transform on its embeddings.

    Once :meth:`setup_laft` has been called with a basis, the cached text
    embeddings and the per-scale window embeddings (of both the query images
    and the cached reference images) are passed through ``laft.inner``
    (``guide=True``) or ``laft.orthogonal`` (``guide=False``) before scoring.
    Patch embeddings are never transformed. Until a basis is configured, the
    detector scores with the untransformed embeddings cached by :meth:`setup`.
    """

    def __init__(
        self,
        clip: CLIP,
        scales: Sequence[int] = (2, 3),
        temperature: float = 0.07,
    ):
        """Wrap ``clip`` as a WinCLIP detector with no LAFT transform configured.

        Args:
            clip: CLIP model forwarded to :class:`WinCLIP`.
            scales: Window scales forwarded to :class:`WinCLIP`.
            temperature: Divisor applied to image/text cosine similarities
                before the softmax in :meth:`forward_cache`.
        """
        super().__init__(clip, scales, temperature)

        # Both are set by ``setup_laft``; ``laft_basis is None`` means "no LAFT".
        self.laft_transform = None
        self.laft_basis = None

    @torch.no_grad()
    def setup(
        self,
        normal_prompts: list[str] | None = None,
        anomaly_prompts: list[str] | None = None,
        reference_images: torch.Tensor | None = None,
    ):
        """Cache text and/or reference-image embeddings used for scoring.

        If a LAFT basis is already configured, the LAFT-transformed copies of
        the caches are recomputed so they stay consistent with the new inputs.

        Args:
            normal_prompts: Prompts for the normal class; their text embeddings
                are averaged into a single vector (row 0 of ``_text_embeds``).
                Must be given together with ``anomaly_prompts``.
            anomaly_prompts: Prompts for the anomaly class; averaged into
                row 1 of ``_text_embeds``.
            reference_images: Optional batch of reference images for few-shot
                scoring; ``k_shot`` is set to its leading dimension.

        Raises:
            ValueError: If exactly one of the two prompt lists is provided.
        """
        if normal_prompts is None and anomaly_prompts is None:
            pass  # keep whatever text embeddings are already cached
        elif normal_prompts is not None and anomaly_prompts is not None:
            self.k_shot = 0
            self._text_embeds = torch.cat((
                self.clip.encode_text(normal_prompts).mean(dim=0, keepdim=True),
                self.clip.encode_text(anomaly_prompts).mean(dim=0, keepdim=True),
            ))
        else:
            raise ValueError("Both 'normal_prompts' and 'anomaly_prompts' should be provided")

        if reference_images is not None:
            self.k_shot = reference_images.shape[0]
            _, self._visual_embeds, self._patch_embeds = self.encode_image(reference_images)

        # Re-running setup after setup_laft must not leave stale transformed caches.
        self._refresh_laft_embeds()

    @torch.no_grad()
    def setup_laft(self, basis: torch.Tensor, guide: bool = True):
        """Configure the LAFT transform and transform the cached embeddings.

        Args:
            basis: Basis passed as the second argument to the LAFT transform.
            guide: ``True`` selects ``laft.inner``; ``False`` selects
                ``laft.orthogonal``.
        """
        self.laft_basis = basis
        self.laft_transform = inner if guide else orthogonal
        self._refresh_laft_embeds()

    @torch.no_grad()
    def _refresh_laft_embeds(self) -> None:
        """Recompute the LAFT-transformed text and reference-window caches.

        This is a no-op while no LAFT basis is configured. ``getattr`` is used
        because ``setup`` could in principle run before ``__init__`` has
        assigned ``laft_basis``.
        """
        if getattr(self, "laft_basis", None) is None:
            return
        self._laft_text_embeds = self.laft_transform(self._text_embeds, self.laft_basis)
        # Reference windows exist only after setup(reference_images=...).
        if self._visual_embeds[0].numel() > 0:
            self._laft_visual_embeds = [self.laft_transform(v, self.laft_basis) for v in self._visual_embeds]

    @torch.no_grad()
    def forward(self, batch: torch.Tensor, laft_image: bool = False) -> torch.Tensor:
        """Encode ``batch`` and return one anomaly score per image.

        Args:
            batch: Preprocessed images accepted by ``self.encode_image``.
            laft_image: Also apply the LAFT transform to the global image
                embeddings (window embeddings are always transformed when a
                basis is set).
        """
        image_embeds, window_embeds, patch_embeds = self.encode_image(batch)
        return self.forward_cache(image_embeds, window_embeds, patch_embeds, laft_image=laft_image)

    @torch.no_grad()
    def forward_cache(
        self,
        image_embeds: torch.Tensor,
        window_embeds: list[torch.Tensor],
        patch_embeds: torch.Tensor,
        laft_image: bool = False,
    ) -> torch.Tensor:
        """Score images from embeddings already produced by ``self.encode_image``.

        Args:
            image_embeds: Global image embeddings.
            window_embeds: Window embeddings, one tensor per scale, in the same
                order as ``self.masks``.
            patch_embeds: Patch embeddings (never LAFT-transformed).
            laft_image: Also apply the LAFT transform to ``image_embeds``.

        Returns:
            The zero-shot probability of the anomaly text embedding (row 1 as
            built by ``setup``). When ``k_shot`` is non-zero, this is averaged
            with the maximum few-shot score over the patch grid.
        """
        if self.laft_basis is not None:
            if laft_image:
                image_embeds = self.laft_transform(image_embeds, self.laft_basis)
            window_embeds = [self.laft_transform(v, self.laft_basis) for v in window_embeds]
            text_embeds = self._laft_text_embeds
        else:
            # No LAFT configured: score against the plain cached text embeddings.
            text_embeds = self._text_embeds

        logits = cosine_similarity(image_embeds, text_embeds) / self.temperature
        image_scores = logits.softmax(dim=-1)[..., 1]

        if self.k_shot:
            few_shot_scores = self._compute_few_shot_scores(patch_embeds, window_embeds)
            image_scores = (image_scores + few_shot_scores.amax(dim=(-2, -1))) / 2

        return image_scores

    def _compute_few_shot_scores(
        self,
        patch_embeds: torch.Tensor,
        window_embeds: list[torch.Tensor],
    ) -> torch.Tensor:
        """Average the patch-level and per-scale window-level association maps.

        Reference windows are the LAFT-transformed copies when a basis is
        configured, so they live in the same space as ``window_embeds``.
        Otherwise the raw cached reference windows are used.

        Returns:
            The mean over scales of maps reshaped or aggregated to
            ``self.grid_size`` (with a leading dimension from the reshape).
        """
        reference_windows = self._laft_visual_embeds if self.laft_basis is not None else self._visual_embeds
        multi_scale_scores = [
            visual_association_score(patch_embeds, self._patch_embeds).reshape((-1, *self.grid_size)),
        ] + [
            harmonic_aggregation(
                visual_association_score(window_embed, reference_embed),
                self.grid_size, mask,
            )
            for window_embed, reference_embed, mask in zip(
                window_embeds, reference_windows, self.masks, strict=True,
            )
        ]
        return torch.stack(multi_scale_scores).mean(dim=0)

def load_winclip_laft(
    backbone: str,
    scales: Sequence[int] = (2, 3),
    temperature: float = 0.07,
    device: str | torch.device = "cuda" if torch.cuda.is_available() else "cpu",
    jit: bool = False,
    download_root: str | None = None,
) -> tuple[WinCLIPwLAFT, Callable[[Image], torch.Tensor]]:
    """Load a CLIP backbone and wrap it in a :class:`WinCLIPwLAFT` detector.

    Args:
        backbone: Model name passed to ``load_clip``.
        scales: Window scales forwarded to :class:`WinCLIPwLAFT`.
        temperature: Softmax temperature forwarded to :class:`WinCLIPwLAFT`.
        device: Device used for loading the backbone and hosting the detector.
            The default is chosen once at import time based on CUDA availability.
        jit: Forwarded to ``load_clip``.
        download_root: Forwarded to ``load_clip``.

    Returns:
        ``(detector, preprocess)``, where ``preprocess`` is the image transform
        returned by ``load_clip``.

    Raises:
        ValueError: If the loaded model's visual tower is not an open_clip
            ``VisionTransformer``.
    """
    backbone_model, image_transform = load_clip(
        name=backbone,
        device=device,
        jit=jit,
        download_root=download_root,
    )

    # The detector relies on ViT-specific visual internals.
    has_vit_tower = isinstance(backbone_model.visual, VisionTransformer)
    if not has_vit_tower:
        raise ValueError("WinCLIPwLAFT only supports open_clip's VisionTransformer visual encoder")

    detector = WinCLIPwLAFT(clip=backbone_model, scales=scales, temperature=temperature)
    detector.to(device)
    return detector, image_transform  # type: ignore
