# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, List, Optional, Sequence, Union, Tuple

import torch
import torch.nn.functional as F
from PIL import Image
from torch import Tensor
from torch.nn import Module as _DINOModel
from torchmetrics import Metric
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE, _TRANSFORMERS_AVAILABLE
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE
from torchvision import transforms
from torchvision.transforms import Compose as _DINOProcessor
from typing_extensions import Literal

if not _MATPLOTLIB_AVAILABLE:
    # skip the plotting doctest when matplotlib is missing; the name must match the class defined below
    __doctest_skip__ = ["DINOV2Score.plot"]

# default ``torch.hub`` entry point used when no model name is given
_DEFAULT_MODEL: str = "dino_vits8"


class DINOV2Score(Metric):
    r"""Calculates `DINO Score`_ which is an image-to-image similarity metric.

    Both sets of images are embedded with a self-supervised DINO / DINOv2 vision transformer loaded through
    ``torch.hub``. The score of a pair is ``100 * cosine_similarity(features1, features2)``. ``compute`` returns the
    mean pairwise score, clamped below at 0.

    .. note:: Metric is not scriptable

    Args:
        model_name_or_path: name of the ``torch.hub`` entry point to load. Names starting with ``"dinov2"`` are
            loaded from ``facebookresearch/dinov2``; all other names are loaded from ``facebookresearch/dino``.
            Available models are:

            - `"dino_vits8"` (default)
            - `"dino_vits16"`, `"dino_vitb8"`, `"dino_vitb16"`
            - `"dinov2_vits14"`, `"dinov2_vitb14"`, `"dinov2_vitl14"`, `"dinov2_vitg14"`

        kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

    Raises:
        ModuleNotFoundError:
            If the ``transformers`` package is not available
    """

    is_differentiable: bool = False
    higher_is_better: bool = True
    full_state_update: bool = True
    plot_lower_bound: float = 0.0
    plot_upper_bound: float = 100.0

    score: Tensor
    n_samples: Tensor

    def __init__(
            self,
            model_name_or_path: Literal[
                "dino_vits8",
                "dino_vits16",
                "dino_vitb8",
                "dino_vitb16",
                "dinov2_vits14",
                "dinov2_vitb14",
                "dinov2_vitl14",
                "dinov2_vitg14",
            ] = _DEFAULT_MODEL,  # type: ignore[assignment]
            **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)
        self.model, self.processor = self._get_model_and_processor(model_name_or_path)
        # running sum of per-pair scores and count of pairs; both are summed across processes
        self.add_state("score", torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("n_samples", torch.tensor(0, dtype=torch.long), dist_reduce_fx="sum")

    @staticmethod
    def _get_model_and_processor(
            model_name_or_path: Literal[
                "dino_vits8",
                "dino_vits16",
                "dino_vitb8",
                "dino_vitb16",
                "dinov2_vits14",
                "dinov2_vitb14",
                "dinov2_vitl14",
                "dinov2_vitg14",
            ] = "dino_vits8",
    ) -> Tuple[_DINOModel, _DINOProcessor]:
        """Load a DINO / DINOv2 backbone from ``torch.hub`` together with its preprocessing pipeline.

        Args:
            model_name_or_path: ``torch.hub`` entry point name; the ``dinov2`` prefix selects the DINOv2 repo.

        Returns:
            The backbone in eval mode (on CUDA when available, otherwise CPU) and a ``Compose`` that resizes,
            center-crops to 224, converts to a tensor and applies ImageNet normalisation.

        Raises:
            ModuleNotFoundError: If the ``transformers`` package is not available.
        """
        # NOTE(review): the backbone comes from ``torch.hub`` and does not use ``transformers``; the gate is kept
        # so the raised error stays as before — consider dropping it.
        if not _TRANSFORMERS_AVAILABLE:
            raise ModuleNotFoundError(
                "`dino_score` metric requires `transformers` package be installed."
                " Either install with `pip install transformers>=4.0` or `pip install torchmetrics[multimodal]`."
            )

        # DINOv2 checkpoints live in a different hub repository than the original DINO ones
        if model_name_or_path.startswith("dinov2"):
            repo = "facebookresearch/dinov2:main"
        else:
            repo = "facebookresearch/dino:main"
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model = torch.hub.load(repo, model_name_or_path).to(device, dtype=torch.float32)
        model.eval()  # inference only: disable any train-time stochastic layers

        processor = transforms.Compose([
            transforms.Resize(256, interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ])
        return model, processor

    @staticmethod
    def _dino_score_update(
            images1: Union[Image.Image, List[Image.Image]],
            images2: Union[Image.Image, List[Image.Image]],
            model: _DINOModel,
            processor: _DINOProcessor,
    ) -> Tuple[Tensor, int]:
        """Compute the DINO score of every image pair.

        Args:
            images1: a single PIL image or a list of PIL images.
            images2: a single PIL image or a list of PIL images, paired element-wise with ``images1``.
            model: backbone mapping a preprocessed batch to one feature vector per image.
            processor: per-image preprocessing pipeline producing a tensor.

        Returns:
            A 1-D tensor with ``100 * cosine_similarity`` per pair (on the model's device) and the number of pairs.

        Raises:
            ValueError: If the two inputs contain a different number of images.
        """
        if isinstance(images1, Image.Image):
            images1 = [images1]
        if isinstance(images2, Image.Image):
            images2 = [images2]
        if len(images1) != len(images2):
            raise ValueError(
                f"Expected the number of images to be the same but got {len(images1)} and {len(images2)}"
            )

        device = next(model.parameters()).device
        if len(images1) == 0:
            # torch.stack cannot stack an empty list; an empty batch contributes nothing
            return torch.zeros(0, device=device), 0

        def _embed(images: List[Image.Image]) -> Tensor:
            # normalisation expects 3 channels, so grayscale / RGBA / palette PIL images are converted first
            batch = torch.stack(
                [processor(img.convert("RGB") if isinstance(img, Image.Image) else img) for img in images]
            )
            return model(batch.to(device))

        # the metric is not differentiable; avoid building autograd graphs through the backbone
        with torch.no_grad():
            img1_features = _embed(images1)
            img2_features = _embed(images2)
            score = 100 * F.cosine_similarity(img1_features, img2_features, dim=-1)
        return score, len(images1)

    def calcul(self, images1: Union[Image.Image, List[Image.Image]],
               images2: Union[Image.Image, List[Image.Image]]) -> float:
        """Compute the summed DINO score of a batch of image pairs without modifying the metric state.

        Args:
            images1: a single PIL image or a list of PIL images
            images2: a single PIL image or a list of PIL images, paired element-wise with ``images1``

        Returns:
            The sum over pairs of ``100 * cosine_similarity``; for a single pair this is that pair's score.

        Raises:
            ValueError:
                If the number of images do not match
        """
        score, _ = self._dino_score_update(images1, images2, self.model, self.processor)
        return score.sum().item()

    def update(self, images1: Union[Image.Image, List[Image.Image]],
               images2: Union[Image.Image, List[Image.Image]]) -> None:
        """Accumulate the DINO score of a batch of image pairs into the metric state.

        Args:
            images1: a single PIL image or a list of PIL images
            images2: a single PIL image or a list of PIL images, paired element-wise with ``images1``

        Raises:
            ValueError:
                If the number of images do not match
        """
        score, n_samples = self._dino_score_update(images1, images2, self.model, self.processor)
        # the model may live on a different device than the metric states
        self.score += score.sum().to(self.score.device)
        self.n_samples += n_samples

    def compute(self) -> Tensor:
        """Compute the mean accumulated DINO score, clamped below at 0.

        Returns NaN if no samples have been accumulated (0 / 0).
        """
        return torch.max(self.score / self.n_samples, torch.zeros_like(self.score))

    def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None) -> _PLOT_OUT_TYPE:
        """Plot a single or multiple values from the metric.

        Args:
            val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
                If no value is provided, will automatically call `metric.compute` and plot that result.
            ax: An matplotlib axis object. If provided will add plot to that axis

        Returns:
            Figure and Axes object

        Raises:
            ModuleNotFoundError:
                If `matplotlib` is not installed
        """
        return self._plot(val, ax)
