from .Abstract import *
from .utils import get_adj, normalize_adj_rc, distance_among


class lDDTComparer(Comparer):
    """lDDT-inspired distance loss between a source and a target conformation.

    Pairwise distances of both conformations are compared on the pairs
    selected by ``get_adj(mask_matrices, -1)`` whose *target* distance is at
    most ``truncate``. The loss is the square root of the weighted squared
    distance error, divided by the number of atoms.
    """

    def __init__(self, truncate: float, *args, **kwargs):
        """Store the distance cutoff and forward remaining arguments to ``Comparer``.

        :param truncate: distance cutoff for pairs to be considered; must be > 1e-6.
        """
        super().__init__(*args, **kwargs)
        assert truncate > 1e-6
        self.truncate = truncate

    def encode(self, conf: torch.Tensor, mask_matrices: MaskMatrices, target_conf: torch.Tensor = None,
               *args, **kwargs) -> torch.Tensor:
        """Return the pairwise distance matrix of ``conf``, masked by adjacency and cutoff.

        Entries whose distance exceeds ``self.truncate`` or whose adjacency
        entry is zero become 0. ``target_conf`` is accepted for interface
        compatibility but is not used.
        """
        pair_dist = distance_among(conf)
        adjacency = get_adj(mask_matrices, -1, use_cuda=self.use_cuda)
        weighted = pair_dist * adjacency
        return weighted * (pair_dist <= self.truncate)

    def compare(self, source_conf: torch.Tensor, target_conf: torch.Tensor, mask_matrices: MaskMatrices, *args,
                **kwargs) -> torch.Tensor:
        """Return the scalar loss between ``source_conf`` and ``target_conf``.

        The pair mask is built from the *target* distances: only pairs closer
        than ``self.truncate`` in the target contribute.
        """
        adjacency = get_adj(mask_matrices, -1, use_cuda=self.use_cuda)
        atom_count = mask_matrices.mol_vertex_w.shape[1]
        src_dist = distance_among(source_conf)
        tgt_dist = distance_among(target_conf)
        # NOTE(review): normalize_adj_rc presumably row/column-normalizes the mask; confirm in .utils.
        pair_weights = normalize_adj_rc(adjacency * (tgt_dist <= self.truncate))
        squared_err = (src_dist - tgt_dist) ** 2
        mean_sq = (squared_err * pair_weights).sum() / atom_count
        return mean_sq.sqrt()


# class MultilDDTComparer(MultiComparer):
#     def __init__(self, truncate: float, *args, **kwargs):
#         super(MultilDDTComparer, self).__init__(*args, **kwargs)
#         assert truncate > 1e-6
#         self.truncate = truncate
#
#     def compare(self, list_source_conf: List[torch.Tensor], target_conf: torch.Tensor,
#                 mask_matrices: MaskMatrices, *args, **kwargs) -> torch.Tensor:
#         adj = get_adj(mask_matrices, -1, use_cuda=self.use_cuda)
#         n_atom = mask_matrices.mol_vertex_w.shape[1]
#         list_ds = [distance_among(source_conf) for source_conf in list_source_conf]
#         dt = distance_among(target_conf)
#         mean_adj = normalize_adj_rc(adj * (dt <= self.truncate))
#         list_distance_2 = [(ds - dt) ** 2 for ds in list_ds]
#         losses = [torch.sum(distance_2 * mean_adj) / n_atom for distance_2 in list_distance_2]
#         losses = [torch.sqrt(loss) for loss in losses]
#         return min(losses)


class lDDTScore(Comparer):
    """lDDT-style score of a source conformation against a target conformation.

    A pair (i, j), i != j, is *scored* when it is selected by
    ``get_adj(mask_matrices, -1)`` and its distance in the target is at most
    ``truncate``. For each tolerance in ``_LDDT_THRESHOLDS``, the score counts
    the scored pairs whose source/target distance difference is below that
    tolerance (plus 1e-6). The result is the mean of those fractions, in
    [0, 1], where 1 means every scored pair is within every tolerance.
    """

    # Distance-difference tolerances (same units as the conformations, presumably Angstrom).
    _LDDT_THRESHOLDS = (0.5, 1.0, 2.0, 4.0)

    def __init__(self, truncate: float, *args, **kwargs):
        """Store the distance cutoff and forward remaining arguments to ``Comparer``.

        :param truncate: target-distance cutoff for scored pairs; must be > 1e-6.
        """
        super(lDDTScore, self).__init__(*args, **kwargs)
        assert truncate > 1e-6
        self.truncate = truncate

    def encode(self, conf: torch.Tensor, mask_matrices: MaskMatrices, target_conf: torch.Tensor = None,
               *args, **kwargs) -> torch.Tensor:
        """Return the pairwise distances of ``conf`` masked by adjacency and by ``truncate``.

        ``target_conf`` is accepted for interface compatibility but is not used.
        """
        adj = get_adj(mask_matrices, -1, use_cuda=self.use_cuda)
        ds = distance_among(conf)
        enc = ds * adj * (ds <= self.truncate)
        return enc

    def compare(self, source_conf: torch.Tensor, target_conf: torch.Tensor, mask_matrices: MaskMatrices, *args,
                **kwargs) -> torch.Tensor:
        """Return the lDDT-style score of ``source_conf`` w.r.t. ``target_conf`` as a 0-dim tensor.

        Returns 0 when no pair qualifies for scoring (instead of NaN from 0/0).
        """
        adj = get_adj(mask_matrices, -1, use_cuda=self.use_cuda)
        ds = distance_among(source_conf)
        dt = distance_among(target_conf)
        # NOTE(review): assumes adj is a 0/1 mask broadcastable against dt;
        # any positive entry is treated as "this pair may be scored".
        not_self = torch.eye(dt.shape[-1], device=dt.device) == 0
        scored = (adj > 0) & (dt <= self.truncate) & not_self
        # |ds - dt| directly: avoids sqrt and keeps exact matches (error 0)
        # counted as preserved, which the former `- cnt_0` trick discarded.
        diff = torch.abs(ds - dt)
        total = scored.sum().to(dt.dtype)
        preserved = sum(((diff < thr + 1e-6) & scored).sum() for thr in self._LDDT_THRESHOLDS)
        return preserved.to(dt.dtype) / (len(self._LDDT_THRESHOLDS) * total.clamp(min=1.0))


class MultilDDTScore(MultiComparer):
    """lDDT-style score between a list of source and a list of target conformations.

    For every target conformation, the best (max) lDDT-style score over all
    source conformations is taken. The returned value is the mean of these
    per-target maxima.

    In ``calc_score``, every pair (i, j), i != j, whose target distance is at
    most ``truncate`` is scored. No adjacency mask is applied here.
    """

    # Distance-difference tolerances (same units as the conformations, presumably Angstrom).
    _LDDT_THRESHOLDS = (0.5, 1.0, 2.0, 4.0)

    def __init__(self, truncate: float, *args, **kwargs):
        """Store the distance cutoff and forward remaining arguments to ``MultiComparer``.

        :param truncate: target-distance cutoff for scored pairs; must be > 1e-6.
        """
        super(MultilDDTScore, self).__init__(*args, **kwargs)
        assert truncate > 1e-6
        self.truncate = truncate

    def encode(self, conf: torch.Tensor, mask_matrices: MaskMatrices, target_conf: torch.Tensor = None,
               *args, **kwargs) -> torch.Tensor:
        """Return the pairwise distances of ``conf`` masked by adjacency and by ``truncate``.

        ``target_conf`` is accepted for interface compatibility but is not used.
        """
        adj = get_adj(mask_matrices, -1, use_cuda=self.use_cuda)
        ds = distance_among(conf)
        enc = ds * adj * (ds <= self.truncate)
        return enc

    def calc_score(self, source_conf: torch.Tensor, target_conf: torch.Tensor) -> torch.Tensor:
        """Return the lDDT-style score of one source against one target as a 0-dim tensor.

        Self-pairs are excluded; exact matches count as preserved. Returns 0
        when no pair qualifies for scoring (instead of NaN from 0/0).
        """
        ds = distance_among(source_conf)
        dt = distance_among(target_conf)
        # Self-pairs (diagonal) carry no information and previously biased the score low.
        not_self = torch.eye(dt.shape[-1], device=dt.device) == 0
        scored = (dt <= self.truncate) & not_self
        diff = torch.abs(ds - dt)
        total = scored.sum().to(dt.dtype)
        preserved = sum(((diff < thr + 1e-6) & scored).sum() for thr in self._LDDT_THRESHOLDS)
        return preserved.to(dt.dtype) / (len(self._LDDT_THRESHOLDS) * total.clamp(min=1.0))

    def compare(self, list_source_conf: List[torch.Tensor], list_target_conf: List[torch.Tensor],
                mask_matrices: MaskMatrices, *args, **kwargs) -> torch.Tensor:
        """Return the mean over targets of the best score over sources.

        ``mask_matrices`` is accepted for interface compatibility but is not used.

        :raises ValueError: if either conformation list is empty.
        """
        if len(list_source_conf) == 0 or len(list_target_conf) == 0:
            raise ValueError('MultilDDTScore.compare needs at least one source and one target conformation')
        best_per_target = [
            torch.stack([self.calc_score(source_conf, target_conf) for source_conf in list_source_conf]).max()
            for target_conf in list_target_conf
        ]
        score = torch.stack(best_per_target).mean()
        if self.use_cuda:
            score = score.cuda()
        return score


if __name__ == '__main__':
    # Scratch check: multiplying by a boolean mask zeroes every entry greater than 5.
    import torch
    grid = torch.arange(1, 10, dtype=torch.float32).reshape(3, 3)
    keep = grid <= 5
    grid = grid * keep
    print(grid.data)
