from collections.abc import Callable, Sequence

import torch
from torch import nn

from open_clip.transformer import VisionTransformer

from PIL.Image import Image

from ..laft import cosine_similarity as _cosine_similarity
from ..clip import CLIP, load_clip


class BufferListDescriptor:
    """Descriptor exposing a fixed-length group of attributes as one list.

    The list element ``i`` is stored on the owning instance under the attribute
    name ``f"_{name}_{i}"``. Reading the descriptor gathers those attributes
    into a fresh list; assigning a list writes each element back to its
    attribute.

    Args:
        name: Base name used to build the per-element attribute names.
        length: Number of elements returned when the descriptor is read.
    """

    def __init__(self, name: str, length: int):
        self.name = name
        self.length = length

    def __get__(self, instance, object_type = None) -> "list[torch.Tensor] | BufferListDescriptor":
        """Return the stored elements as a list.

        When accessed on the owning class rather than an instance
        (``instance is None``), the descriptor itself is returned, following
        the usual descriptor convention, instead of failing with an
        ``AttributeError`` raised on ``None``.
        """
        del object_type
        if instance is None:
            return self
        return [getattr(instance, f"_{self.name}_{i}") for i in range(self.length)]

    def __set__(self, instance, values: list[torch.Tensor]):
        """Write each element of ``values`` to its per-index attribute.

        Only the indices present in ``values`` are written; no check against
        ``length`` is performed.
        """
        for i, value in enumerate(values):
            setattr(instance, f"_{self.name}_{i}", value)


class WinCLIP(nn.Module):
    """Window-based anomaly scorer built on a CLIP model with a ViT visual encoder.

    The module scores images against two text embeddings (normal / anomalous)
    and, when reference images are supplied through :meth:`setup`, additionally
    compares patch-level and sliding-window embeddings against the references.

    Args:
        clip: CLIP model whose ``visual`` encoder exposes ``grid_size``,
            ``patch_dropout``, ``ln_pre``, ``transformer``, ``ln_post``,
            ``_global_pool`` and ``proj``.
        scales: Sliding-window sizes (in patches); one set of masks is built
            per scale with ``make_masks``.
        temperature: Divisor applied to the image/text cosine similarities
            before the softmax.
    """

    def __init__(
        self,
        clip: CLIP,
        scales: Sequence[int] = (2, 3),
        temperature: float = 0.07,
    ):
        super().__init__()

        self.clip = clip
        self.scales = tuple(scales)
        self.temperature = temperature
        self.k_shot = 0  # number of reference images; 0 disables the few-shot branch

        # ``encode_image`` below unpacks two values from ``clip.encode_image``;
        # presumably this flag is what makes the visual encoder also return
        # the per-patch tokens -- confirm against the open_clip version in use.
        self.clip.visual.output_tokens = True
        self.grid_size: tuple[int, int] = self.clip.visual.grid_size

        # Masks are derived from the grid size and scales, so they are not persisted.
        self.register_buffer_list(
            "masks",
            [make_masks(self.grid_size, scale) for scale in self.scales],
            persistent=False,
        )
        # Placeholders that are overwritten by ``setup``.
        self.register_buffer("_text_embeds", torch.empty(0))
        self.register_buffer_list("_visual_embeds", [torch.empty(0) for _ in self.scales])
        self.register_buffer("_patch_embeds", torch.empty(0))

    def register_buffer_list(self, name: str, values: list[torch.Tensor], persistent: bool = True):
        """Register a list of tensors as individual buffers readable as ``self.<name>``.

        Each tensor is registered as buffer ``f"_{name}_{i}"`` and a
        ``BufferListDescriptor`` is installed to gather them back into a list.

        Args:
            name: Attribute name under which the list is exposed.
            values: Tensors to register, in list order.
            persistent: Forwarded to ``register_buffer``.
        """
        for i, value in enumerate(values):
            self.register_buffer(f"_{name}_{i}", value, persistent=persistent)
        # NOTE: the descriptor lives on the class, so its length is shared by
        # every instance; instances created with a different number of scales
        # overwrite it for all existing instances.
        setattr(self.__class__, name, BufferListDescriptor(name, len(values)))

    @torch.no_grad()
    def setup(
        self,
        normal_prompts: list[str] | None = None,
        anomaly_prompts: list[str] | None = None,
        reference_images: torch.Tensor | None = None,
    ):
        """Compute and store the text and/or reference-image embeddings.

        Args:
            normal_prompts: Prompts describing the normal state. Their text
                embeddings are averaged into a single embedding.
            anomaly_prompts: Prompts describing the anomalous state, averaged
                the same way.
            reference_images: Batch of reference images; its first dimension
                becomes ``k_shot``.

        Raises:
            ValueError: If only one of ``normal_prompts`` / ``anomaly_prompts``
                is provided.
        """
        if normal_prompts is None and anomaly_prompts is None:
            pass
        elif normal_prompts is not None and anomaly_prompts is not None:
            # New prompts reset the few-shot state unless reference images
            # are passed in the same call (handled below).
            self.k_shot = 0
            # Row 0 is the "normal" embedding, row 1 the "anomaly" embedding;
            # ``forward`` relies on this order.
            self._text_embeds = torch.cat((
                self.clip.encode_text(normal_prompts).mean(dim=0, keepdim=True),
                self.clip.encode_text(anomaly_prompts).mean(dim=0, keepdim=True),
            ))
        else:
            raise ValueError("Both 'normal_prompts' and 'anomaly_prompts' should be provided")

        if reference_images is not None:
            self.k_shot = reference_images.shape[0]
            _, self._visual_embeds, self._patch_embeds = self.encode_image(reference_images)

    def encode_image(self, batch: torch.Tensor) -> tuple[torch.Tensor, list[torch.Tensor], torch.Tensor]:
        """Encode a batch of images at image, window and patch level.

        Args:
            batch: Input images in the format expected by ``clip.encode_image``.

        Returns:
            A tuple ``(image_embeds, window_embeds, patch_embeds)`` where
            ``window_embeds`` holds one tensor of shape
            ``(batch_size, n_masks, dim)`` per scale.
        """
        captured: dict[str, torch.Tensor] = {}

        def capture_input(_module, inputs: tuple[torch.Tensor,], _output: torch.Tensor):
            # Keep the tensor that is fed *into* ``patch_dropout``.
            del _module, _output
            captured["feature_map"] = inputs[0].detach()

        handle = self.clip.visual.patch_dropout.register_forward_hook(capture_input)
        try:
            image_embeds, patch_embeds = self.clip.encode_image(batch)
        finally:
            # Always detach the hook: leaving it registered would add one more
            # hook (and keep one more captured tensor alive) on every call.
            handle.remove()

        feature_map = captured["feature_map"]
        window_embeds = [self._get_window_embeds(feature_map, masks) for masks in self.masks]

        return image_embeds, window_embeds, patch_embeds

    def _get_window_embeds(self, feature_map: torch.Tensor, masks: torch.Tensor) -> torch.Tensor:
        """Embed every sliding window of ``feature_map`` selected by ``masks``.

        Args:
            feature_map: Token sequence captured at the input of
                ``patch_dropout``; indexed along dim 1, with index 0 treated as
                the class token.
            masks: Patch indices of shape ``(n_patches_per_mask, n_masks)``.

        Returns:
            Tensor of shape ``(batch_size, n_masks, dim)``.
        """
        batch_size = feature_map.shape[0]
        n_masks = masks.shape[1]

        # Prepend token index 0 to every mask so each window keeps the class token.
        class_index = torch.zeros(1, n_masks, dtype=torch.int, device=feature_map.device)
        masks = torch.cat((class_index, masks + 1)).T  # +1 to account for class index
        # One token subset per mask, stacked along the batch dimension:
        # (n_masks * batch_size, n_patches_per_mask + 1, dim).
        masked = torch.cat([torch.index_select(feature_map, 1, mask) for mask in masks])

        # Run the remaining stages of the visual encoder on the masked tokens.
        masked = self.clip.visual.patch_dropout(masked)
        masked = self.clip.visual.ln_pre(masked)

        masked = masked.permute(1, 0, 2)  # NLD -> LND
        masked = self.clip.visual.transformer(masked)
        masked = masked.permute(1, 0, 2)  # LND -> NLD

        masked = self.clip.visual.ln_post(masked)
        pooled, _ = self.clip.visual._global_pool(masked)

        if self.clip.visual.proj is not None:
            pooled = pooled @ self.clip.visual.proj

        # Undo the mask-major stacking: (n_masks, batch, dim) -> (batch, n_masks, dim).
        return pooled.reshape((n_masks, batch_size, -1)).permute(1, 0, 2)

    @torch.no_grad()
    def forward(self, batch: torch.Tensor) -> torch.Tensor:
        """Return an anomaly score per image.

        The zero-shot score is the softmax weight of the anomaly text
        embedding (index 1). When ``k_shot`` is non-zero it is averaged with
        the maximum of the few-shot score map.
        """
        image_embeds, window_embeds, patch_embeds = self.encode_image(batch)

        similarity = cosine_similarity(image_embeds, self._text_embeds)
        image_scores = (similarity / self.temperature).softmax(dim=-1)[..., 1]

        if self.k_shot:
            few_shot_scores = self._compute_few_shot_scores(patch_embeds, window_embeds)
            image_scores = (image_scores + few_shot_scores.amax(dim=(-2, -1))) / 2

        return image_scores

    def _compute_few_shot_scores(
        self,
        patch_embeds: torch.Tensor,
        window_embeds: list[torch.Tensor],
    ) -> torch.Tensor:
        """Average patch-level and per-scale window-level association score maps.

        Args:
            patch_embeds: Patch embeddings of the query images.
            window_embeds: Window embeddings of the query images, one tensor
                per scale.

        Returns:
            Tensor of shape ``(batch_size, *grid_size)``.
        """
        # Patch-level scores are already one-per-grid-cell; just fold them into the grid.
        patch_scores = visual_association_score(patch_embeds, self._patch_embeds)
        multi_scale_scores = [patch_scores.reshape((-1, *self.grid_size))]

        # Window-level scores are spread back over the grid, one map per scale.
        for window_embed, reference_embed, mask in zip(
            window_embeds, self._visual_embeds, self.masks, strict=True,
        ):
            window_scores = visual_association_score(window_embed, reference_embed)
            multi_scale_scores.append(harmonic_aggregation(window_scores, self.grid_size, mask))

        return torch.stack(multi_scale_scores).mean(dim=0)

    @property
    def dtype(self) -> torch.dtype:
        """Dtype reported by the wrapped CLIP model."""
        return self.clip.dtype

    @property
    def device(self) -> torch.device:
        """Device reported by the wrapped CLIP model."""
        return self.clip.device


def cosine_similarity(input1: torch.Tensor, input2: torch.Tensor) -> torch.Tensor:
    """Cosine similarity wrapper that also accepts unbatched (2-D) inputs.

    A 2-D ``input1`` is treated as a batch of one and the leading dimension is
    squeezed from the result again. A 2-D ``input2`` is repeated once per batch
    element of ``input1``. The similarity itself is delegated to the imported
    ``_cosine_similarity`` helper.
    """
    unbatched = input1.ndim == 2
    if unbatched:
        input1 = input1.unsqueeze(0)
    if input2.ndim == 2:
        # Share the same second operand across every batch element.
        input2 = input2.repeat(input1.shape[0], 1, 1)

    similarity = _cosine_similarity(input1, input2)
    return similarity.squeeze(0) if unbatched else similarity


def harmonic_aggregation(window_scores: torch.Tensor, output_size: tuple, masks: torch.Tensor) -> torch.Tensor:
    """Perform harmonic aggregation on window scores.

    Computes a single score for each patch location by aggregating the scores of all windows that cover the patch.
    Scores are aggregated using the harmonic mean. Patches whose result is not finite (no covering window) are
    set to ``0``.

    Args:
        window_scores (torch.Tensor): Tensor of shape ``(batch_size, n_masks)`` representing the scores for each sliding
            window location.
        output_size (tuple): Tuple of integers representing the output size ``(H, W)``.
        masks (torch.Tensor): Tensor of shape ``(n_patches_per_mask, n_masks)`` representing the masks. Each mask is
            set of indices indicating which patches are covered by the mask.

    Returns:
        torch.Tensor: Tensor of shape ``(batch_size, H, W)`` representing the aggregated scores.

    Examples:
        >>> # example for a 3x3 patch grid with 4 sliding windows of size 2x2
        >>> window_scores = torch.tensor([[1.0, 0.75, 0.5, 0.25]])
        >>> output_size = (3, 3)
        >>> masks = torch.tensor([[0, 1, 3, 4],
        ...                       [1, 2, 4, 5],
        ...                       [3, 4, 6, 7],
        ...                       [4, 5, 7, 8]])
        >>> harmonic_aggregation(window_scores, output_size, masks)
        tensor([[[1.0000, 0.8571, 0.7500],
                 [0.6667, 0.4800, 0.3750],
                 [0.5000, 0.3333, 0.2500]]])
    """
    batch_size = window_scores.shape[0]
    height, width = output_size

    # Loop-invariant: reciprocal scores laid out as (n_masks, batch_size) so that a boolean vector over windows
    # selects rows.
    reciprocal_scores = (1 / window_scores).T

    scores = []
    for idx in range(height * width):
        covering = torch.any(masks == idx, dim=0)  # boolean tensor indicating which masks contain the patch
        # Harmonic mean: number of covering windows divided by the sum of their reciprocal scores.
        scores.append(covering.sum() / reciprocal_scores[covering].sum(dim=0))

    # (n_patches, batch_size) -> (batch_size, H, W); 0/0 for uncovered patches becomes 0.
    return torch.stack(scores).T.reshape(batch_size, height, width).nan_to_num(posinf=0.0)


def visual_association_score(embeddings: torch.Tensor, reference_embeds: torch.Tensor) -> torch.Tensor:
    """Score each embedding by its distance to the closest reference embedding.

    All reference embeddings are flattened into one pool. For every query embedding the cosine similarity to
    each pooled reference is computed, converted to a distance ``1 - similarity``, and the smallest distance is
    halved to give the score. Equation (4) in the paper.

    Args:
        embeddings (torch.Tensor): Tensor of shape ``(batch_size, n_patches, dimensionality)`` holding the query
            embeddings.
        reference_embeds (torch.Tensor): Tensor of shape ``(n_reference_embeds, n_patches, dimensionality)``
            holding the reference embeddings.

    Returns:
        torch.Tensor: Tensor of shape ``(batch_size, n_patches)`` with the visual association scores.

    Examples:
        >>> embeddings = torch.tensor([[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]])
        >>> reference_embeds = torch.tensor([[[0.0, 1.0, 0.0], [1.0, 1.0, 0.0]]])
        >>> visual_association_score(embeddings, reference_embeds)
        tensor([[0.1464, 0.0000]])

        >>> embeddings = torch.randn(10, 100, 128)
        >>> reference_embeds = torch.randn(2, 100, 128)
        >>> visual_association_score(embeddings, reference_embeds).shape
        torch.Size([10, 100])
    """
    feature_dim = embeddings.shape[-1]
    # Pool every reference embedding, regardless of which image or location it came from.
    reference_pool = reference_embeds.reshape(-1, feature_dim)
    similarity = cosine_similarity(embeddings, reference_pool)
    nearest_distance, _ = torch.min(1 - similarity, dim=-1)
    return nearest_distance / 2


def make_masks(grid_size: tuple[int, int], kernel_size: int, stride: int = 1) -> torch.Tensor:
    """Make a set of masks to select patches from a feature map in a sliding window fashion.

    Each column in the returned tensor represents a mask. Each mask is a set of indices indicating which patches are
    covered by the mask. The number of masks is equal to the number of sliding windows that fit in the feature map.

    Args:
        grid_size (tuple[int, int]): The shape of the feature map.
        kernel_size (int): The size of the kernel in number of patches.
        stride (int): The step between consecutive windows in number of patches. Defaults to ``1``.

    Returns:
        torch.Tensor: Set of masks of shape ``(n_patches_per_mask, n_masks)``.

    Raises:
        ValueError: If any dimension of ``grid_size`` is smaller than ``kernel_size``.

    Examples:
        >>> make_masks((3, 3), 2)
        tensor([[0, 1, 3, 4],
                [1, 2, 4, 5],
                [3, 4, 6, 7],
                [4, 5, 7, 8]], dtype=torch.int32)

        >>> make_masks((4, 4), 2)
        tensor([[ 0,  1,  2,  4,  5,  6,  8,  9, 10],
                [ 1,  2,  3,  5,  6,  7,  9, 10, 11],
                [ 4,  5,  6,  8,  9, 10, 12, 13, 14],
                [ 5,  6,  7,  9, 10, 11, 13, 14, 15]], dtype=torch.int32)

        >>> make_masks((4, 4), 2, stride=2)
        tensor([[ 0,  2,  8, 10],
                [ 1,  3,  9, 11],
                [ 4,  6, 12, 14],
                [ 5,  7, 13, 15]], dtype=torch.int32)
    """
    if any(dim < kernel_size for dim in grid_size):
        raise ValueError(
            "Each dimension of the grid size must be greater than or equal to "
            f"the kernel size. Got grid size {grid_size} and kernel size {kernel_size}."
        )
    height, width = grid_size
    # A single-channel "image" whose pixel values are the flat patch indices.
    grid = torch.arange(height * width).reshape(1, height, width)
    # ``unfold`` is applied to a float copy of the grid and the extracted windows are cast back to int32.
    return nn.functional.unfold(grid.float(), kernel_size=kernel_size, stride=stride).int()


def load_winclip(
    backbone: str,
    scales: Sequence[int] = (2, 3),
    temperature: float = 0.07,
    device: str | torch.device = "cuda" if torch.cuda.is_available() else "cpu",
    jit: bool = False,
    download_root: str | None = None,
) -> tuple[WinCLIP, Callable[[Image], torch.Tensor]]:
    """Load a CLIP backbone and wrap it in a ``WinCLIP`` model.

    Args:
        backbone: Model name forwarded to ``load_clip``.
        scales: Sliding-window sizes forwarded to ``WinCLIP``.
        temperature: Softmax temperature forwarded to ``WinCLIP``.
        device: Device passed to ``load_clip`` and used to place the returned
            model. The default is chosen once, when this module is imported.
        jit: Forwarded to ``load_clip``.
        download_root: Forwarded to ``load_clip``.

    Returns:
        The ``WinCLIP`` model and the preprocessing callable returned by
        ``load_clip``.

    Raises:
        ValueError: If the loaded model's visual encoder is not an open_clip
            ``VisionTransformer``.
    """
    clip_model, transform = load_clip(name=backbone, device=device, jit=jit, download_root=download_root)

    # WinCLIP reaches into ViT-specific submodules, so reject any other encoder up front.
    has_vit_encoder = isinstance(clip_model.visual, VisionTransformer)
    if not has_vit_encoder:
        raise ValueError("WinCLIP only supports open_clip's VisionTransformer visual encoder")

    winclip = WinCLIP(clip=clip_model, scales=scales, temperature=temperature)
    winclip.to(device)

    return winclip, transform  # type: ignore
