import torch


class ClassificationInTop:
    """Batch criterion combining a top-k class check with an MSE closeness check.

    An image in the batch "passes" when both of these hold:

    * its per-sample mean squared error against ``original_image`` is
      strictly below ``mse_max``;
    * ``classification`` is among its ``top_to_check`` highest-scoring class
      indices (``check_for_match=True``), or is absent from them
      (``check_for_match=False``).

    Calling the object returns whether at least one image in the batch passes.

    Args:
        top_to_check: ``k`` passed to ``torch.topk`` along the class axis
            (dim 1) of the scores.
        classification: Class index to look for among the top-``k`` indices.
        check_for_match: If True, require the class to be in the top-``k``;
            if False, require it to be absent.
        original_image: Reference image. It must be expandable to the shape
            of the ``images`` batch via ``Tensor.expand_as``.
        mse_max: Exclusive upper bound on the per-sample mean squared error.
    """

    def __init__(
        self,
        top_to_check: int,
        classification: int,
        check_for_match: bool,
        original_image: torch.Tensor,
        mse_max: float,
    ):
        self.top_to_check = top_to_check
        self.classification = classification
        self.check_for_match = check_for_match
        self.original_image = original_image
        self.mse_max = mse_max

    def __call__(
        self, images: torch.Tensor, classifications: torch.Tensor
    ) -> torch.Tensor:
        """Return a 0-dim bool tensor: True if any image satisfies the criterion.

        Args:
            images: Batch of images with the batch axis at dim 0 and at least
                one further dim. Any trailing shape is accepted, e.g. (N, C, H, W).
            classifications: Class scores with the class axis at dim 1;
                presumably (N, num_classes) logits or probabilities — confirm
                against callers.
        """
        _, top_indices = torch.topk(classifications, self.top_to_check, dim=1)
        # One flag per sample: does the target class appear among its top-k?
        in_top_k = (top_indices == self.classification).any(dim=1)

        # Average the squared error over every non-batch dim, so the check
        # works for any image rank (not only 4-D batches).
        per_sample_dims = tuple(range(1, images.dim()))
        mse_distance = torch.nn.functional.mse_loss(
            images, self.original_image.expand_as(images), reduction="none"
        ).mean(dim=per_sample_dims)
        close_enough = mse_distance < self.mse_max

        wanted = in_top_k if self.check_for_match else ~in_top_k
        return torch.logical_and(wanted, close_enough).any()
