from typing import Any, Dict, List, Optional, Tuple

import torch

from data.structures import MaskMatrices


class Comparer:
    """Abstract interface for comparing a source ``conf`` tensor with a target.

    This base class only stores configuration. ``encode`` and ``compare``
    raise ``NotImplementedError`` and must be overridden by concrete
    subclasses. The meaning and shape of a "conf" tensor are defined by those
    subclasses (presumably a conformation/coordinate tensor -- confirm against
    the implementations).
    """

    def __init__(self, use_cuda: bool = False):
        """Store configuration shared by all comparers.

        Args:
            use_cuda: Kept on ``self.use_cuda``. This base class never reads it;
                presumably subclasses use it to decide tensor placement
                (assumption -- confirm against subclasses).
        """
        # Cooperative init so mixins later in the MRO are also initialised.
        super().__init__()
        self.use_cuda = use_cuda

    def encode(self, conf: torch.Tensor, mask_matrices: MaskMatrices,
               target_conf: Optional[torch.Tensor] = None,
               *args, **kwargs) -> Any:
        """Encode ``conf`` into whatever representation the subclass compares.

        Args:
            conf: Tensor to encode.
            mask_matrices: Project-specific masks accompanying ``conf``.
            target_conf: Optional target tensor; ``None`` when not supplied.
            *args, **kwargs: Extra subclass-specific arguments.

        Returns:
            A subclass-defined encoding (typed ``Any`` here).

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError(f"{type(self).__name__} must implement encode()")

    def compare(self, source_conf: torch.Tensor, target_conf: torch.Tensor,
                mask_matrices: MaskMatrices, *args, **kwargs) -> torch.Tensor:
        """Compare ``source_conf`` against ``target_conf``.

        Args:
            source_conf: Tensor being evaluated.
            target_conf: Reference tensor.
            mask_matrices: Project-specific masks accompanying the tensors.
            *args, **kwargs: Extra subclass-specific arguments.

        Returns:
            A tensor produced by the subclass, per the annotation.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError(f"{type(self).__name__} must implement compare()")


class MultiComparer:
    """Abstract interface for comparing lists of source and target ``conf`` tensors.

    This base class only stores configuration. ``compare`` raises
    ``NotImplementedError`` and must be overridden by concrete subclasses.
    """

    def __init__(self, use_cuda: bool = False):
        """Store configuration shared by all multi-comparers.

        Args:
            use_cuda: Kept on ``self.use_cuda``. This base class never reads it;
                presumably subclasses use it to decide tensor placement
                (assumption -- confirm against subclasses).
        """
        # Cooperative init (as ``Comparer`` does) so that mixins later in the
        # MRO get their ``__init__`` run too.
        super().__init__()
        self.use_cuda = use_cuda

    def compare(self, list_source_conf: List[torch.Tensor], list_target_conf: List[torch.Tensor],
                mask_matrices: MaskMatrices, *args, **kwargs) -> torch.Tensor:
        """Compare a list of source tensors against a list of target tensors.

        Args:
            list_source_conf: Source tensors to evaluate.
            list_target_conf: Reference tensors. Presumably paired element-wise
                with ``list_source_conf`` -- confirm against subclasses.
            mask_matrices: Project-specific masks accompanying the tensors.
            *args, **kwargs: Extra subclass-specific arguments.

        Returns:
            A tensor produced by the subclass, per the annotation.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError(f"{type(self).__name__} must implement compare()")
