from .Abstract import *
from .utils import kabsch


class NaiveComparer(Comparer):
    """Comparer that scores conformations by plain RMSD, with no alignment step."""

    def __init__(self, *args, **kwargs):
        # All configuration is handled by the Comparer base class.
        super().__init__(*args, **kwargs)

    def encode(self, conf: torch.Tensor, mask_matrices: MaskMatrices, target_conf: torch.Tensor = None,
               *args, **kwargs) -> torch.Tensor:
        """Return ``conf`` untouched.

        ``mask_matrices`` and ``target_conf`` are accepted only to match the
        Comparer interface; this comparer does not use them.
        """
        encoded = conf
        return encoded

    def compare(self, source_conf: torch.Tensor, target_conf: torch.Tensor, mask_matrices: MaskMatrices, *args,
                **kwargs) -> torch.Tensor:
        """Return the RMSD between ``source_conf`` and ``target_conf``.

        Squared differences are summed along dim 1, which presumably holds the
        xyz axes (rows = atoms; TODO confirm). The total is then divided by
        ``kwargs['massive'].shape[0]``, and the square root is taken.

        Raises:
            KeyError: if the ``massive`` keyword argument is missing.
        """
        # Only the leading dimension of `massive` is used, as the atom count.
        weights = kwargs['massive']
        per_row_sq = (source_conf - target_conf).pow(2).sum(dim=1, keepdim=True)
        n_rows = weights.shape[0]
        return (per_row_sq.sum() / n_rows).sqrt()


class KabschComparer(Comparer):
    """Comparer that Kabsch-aligns the source conformation onto the target before scoring it by RMSD."""

    def __init__(self, *args, **kwargs):
        super(KabschComparer, self).__init__(*args, **kwargs)

    def encode(self, conf: torch.Tensor, mask_matrices: MaskMatrices, target_conf: torch.Tensor = None,
               *args, **kwargs) -> torch.Tensor:
        """Return ``conf`` after Kabsch-fitting it to ``target_conf``.

        The per-molecule grouping comes from ``mask_matrices.mol_vertex_w``.
        NOTE(review): ``target_conf`` defaults to None and is passed straight
        through to ``kabsch`` as ``fit_pos``. What happens with None depends on
        ``kabsch`` (TODO confirm).
        """
        return kabsch(
            pos=conf,
            fit_pos=target_conf,
            mol_vertex_w=mask_matrices.mol_vertex_w,
            # presumably set by the Comparer base class — TODO confirm
            use_cuda=self.use_cuda
        )

    def compare(self, source_conf: torch.Tensor, target_conf: torch.Tensor, mask_matrices: MaskMatrices, *args,
                **kwargs) -> torch.Tensor:
        """Return the RMSD between ``source_conf`` and ``target_conf`` after Kabsch alignment.

        ``source_conf`` is first aligned onto ``target_conf`` with ``kabsch``.
        Squared differences are then summed along dim 1, which presumably holds
        the xyz axes (rows = atoms; TODO confirm). The total is divided by
        ``kwargs['massive'].shape[0]``, and the square root is taken.

        Raises:
            KeyError: if the ``massive`` keyword argument is missing.
        """
        massive = kwargs['massive']
        source_conf = kabsch(
            pos=source_conf,
            fit_pos=target_conf,
            mol_vertex_w=mask_matrices.mol_vertex_w,
            use_cuda=self.use_cuda
        )
        # Unweighted RMSD: every row (atom) counts equally. `massive` only
        # supplies the row count here. Its values are deliberately NOT used as
        # per-atom weights, so the result matches NaiveComparer's normalization.
        md2 = torch.pow(source_conf - target_conf, 2).sum(dim=1, keepdim=True)
        loss = torch.sqrt(md2.sum() / massive.shape[0])
        return loss
