from .Abstract import *
from .utils import get_adj, get_mean_adj, distance_among


class DistanceComparer(Comparer):
    """Compare two conformations through their masked pairwise-distance matrices.

    ``hop`` selects which atom pairs participate: it is forwarded to
    ``get_adj`` / ``get_mean_adj``. It must be a positive integer, or -1
    meaning "no hop limit" (infinity).
    """

    def __init__(self, hop: int, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # -1 is the sentinel for an unbounded hop count.
        assert hop == -1 or hop >= 1
        self.hop = hop

    def encode(self, conf: torch.Tensor, mask_matrices: MaskMatrices, target_conf: torch.Tensor = None,
               *args, **kwargs) -> torch.Tensor:
        """Return the pairwise distances of ``conf`` multiplied by the hop mask.

        Args:
            conf: conformation passed to ``distance_among``.
            mask_matrices: molecule masks used to build the hop mask.
            target_conf: unused here; presumably kept for signature
                compatibility with ``Comparer.encode`` (TODO confirm).

        Returns:
            ``distance_among(conf) * get_adj(mask_matrices, self.hop, ...)``.
        """
        hop_mask = get_adj(mask_matrices, self.hop, use_cuda=self.use_cuda)
        return distance_among(conf) * hop_mask

    def compare(self, source_conf: torch.Tensor, target_conf: torch.Tensor, mask_matrices: MaskMatrices, *args,
                **kwargs) -> torch.Tensor:
        """Return a root-mean-square style gap between two distance matrices.

        The squared difference of the two pairwise-distance matrices is
        weighted by ``get_mean_adj(...)``, summed, divided by the atom count
        (``mol_vertex_w.shape[1]``), and finally square-rooted.
        """
        weights = get_mean_adj(mask_matrices, self.hop, use_cuda=self.use_cuda)
        atom_count = mask_matrices.mol_vertex_w.shape[1]
        squared_gap = (distance_among(source_conf) - distance_among(target_conf)) ** 2
        return torch.sqrt(torch.sum(squared_gap * weights) / atom_count)


class MultiDistanceComparer(MultiComparer):
    """Compare two sets of conformations through masked pairwise distances.

    Each target conformation is matched with its closest source
    conformation (smallest weighted squared distance-matrix gap). The loss is
    the square root of the mean of those best-match errors.

    ``hop`` is forwarded to ``get_mean_adj``. It must be a positive integer,
    or -1 meaning "no hop limit" (infinity).
    """

    def __init__(self, hop: int, *args, **kwargs):
        super(MultiDistanceComparer, self).__init__(*args, **kwargs)
        assert hop >= 1 or hop == -1  # -1 stands for infinity
        self.hop = hop

    def encode(self, conf: torch.Tensor, mask_matrices: MaskMatrices) -> torch.Tensor:
        """Return ``distance_among(conf)`` scaled by ``sqrt(mean_adj)``.

        The square-root scaling means that a squared difference of two
        encodings equals the squared distance difference weighted by
        ``mean_adj``. This matches the weighting used by ``DistanceComparer``.
        """
        mean_adj = get_mean_adj(mask_matrices, self.hop, use_cuda=self.use_cuda)
        dis_matrix = distance_among(conf)
        return dis_matrix * (mean_adj ** 0.5)

    def compare(self, list_source_conf: List[torch.Tensor], list_target_conf: List[torch.Tensor],
                mask_matrices: MaskMatrices, *args, **kwargs) -> torch.Tensor:
        """Return sqrt(mean over targets of min over sources of the per-atom error).

        Args:
            list_source_conf: candidate (e.g. predicted) conformations.
            list_target_conf: reference conformations; each one is matched
                to its best source.
            mask_matrices: molecule masks; ``mol_vertex_w.shape[1]`` is used
                as the atom count for normalisation.

        Returns:
            A scalar tensor that stays attached to the autograd graph and
            lives on the same device as the inputs.

        Raises:
            ValueError: if either conformation list is empty.
        """
        if not list_source_conf:
            raise ValueError("list_source_conf must contain at least one conformation")
        if not list_target_conf:
            raise ValueError("list_target_conf must contain at least one conformation")

        n_atom = mask_matrices.mol_vertex_w.shape[1]
        list_ds = [self.encode(source_conf, mask_matrices) for source_conf in list_source_conf]
        list_dt = [self.encode(target_conf, mask_matrices) for target_conf in list_target_conf]

        # For every target, keep the error of its closest source conformation.
        best_errors = [
            torch.stack([torch.sum((ds - dt) ** 2) / n_atom for ds in list_ds]).min()
            for dt in list_dt
        ]
        # torch.stack (not torch.tensor) keeps the autograd graph, device and
        # dtype intact. torch.tensor([...]) would copy the values into a new
        # leaf tensor, so backward() would produce no gradients.
        loss = torch.stack(best_errors).mean()
        return torch.sqrt(loss)
