import torch.nn.functional as F

from .Abstract import *


class JustPhi(Comparer):
    """Compare two conformations by the angles (phi) between chained edges.

    ``encode`` maps a conformation to a 1-D tensor of angles, one per selected
    pair of directed edges sharing a vertex; ``compare`` returns the RMSE
    between the angle tensors of two conformations.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def encode(self, conf: torch.Tensor, mask_matrices: MaskMatrices, target_conf: torch.Tensor = None,
               *args, **kwargs) -> torch.Tensor:
        """Return the angle (radians) at the shared vertex of every selected directed-edge pair.

        Every edge is used in both directions by concatenating the two endpoint
        matrices. A pair ``(i, j)`` of directed edges is selected when, per the
        mask product below, edge ``i`` ends where edge ``j`` starts, edge ``j``
        does not end where edge ``i`` starts (excludes the reversed edge), and
        ``j < i`` (strict lower triangle). The angle follows from the law of
        cosines using ``|i|``, ``|j|`` and the distance start(i)–end(j).

        :param conf: vertex coordinates, one row per vertex (the matmuls below
            require ``conf`` to have ``n_v`` rows).
        :param mask_matrices: provides ``vertex_edge_w1`` / ``vertex_edge_w2``,
            2-D ``(n_v, n_e)`` matrices; presumably one-hot columns marking the
            first / second endpoint of each edge -- TODO confirm.
        :param target_conf: unused; accepted for interface compatibility.
        :return: 1-D tensor of angles strictly inside ``(0, pi)``. Entries are
            typically NaN when an involved edge has zero length (0/0 cosine).
        """
        vew1 = torch.cat([mask_matrices.vertex_edge_w1, mask_matrices.vertex_edge_w2], dim=1)
        vew2 = torch.cat([mask_matrices.vertex_edge_w2, mask_matrices.vertex_edge_w1], dim=1)
        # Start (u) and end (v) position of every directed edge.
        u_pos = vew1.t() @ conf
        v_pos = vew2.t() @ conf
        edge_lengths = torch.norm(u_pos - v_pos, dim=1)

        # Nonzero at [i, j] when end(i) meets start(j) but start(i) != end(j)
        # (assuming 0/1 incidence columns), restricted to j < i.
        chain_mask = torch.tril((vew2.t() @ vew1) * (1 - vew1.t() @ vew2), diagonal=-1)
        # (row, col) pairs in row-major order -- the same order as nonzero() on
        # the flattened mask, but without float division or a Python loop.
        rows, cols = torch.nonzero(chain_mask, as_tuple=True)

        a = edge_lengths[rows]
        b = edge_lengths[cols]
        # Distance start(i)–end(j), computed only for the selected pairs instead
        # of materialising the full (2E, 2E, d) pairwise difference tensor.
        c = torch.norm(u_pos[rows] - v_pos[cols], dim=1)
        cos_phi = torch.divide(a ** 2 + b ** 2 - c ** 2, 2 * a * b)
        # Keep away from +/-1 so the derivative of arccos stays finite.
        phi = torch.arccos(cos_phi.clip(-1 + 1e-6, 1 - 1e-6))
        return phi

    def compare(self, source_conf: torch.Tensor, target_conf: torch.Tensor, mask_matrices: MaskMatrices, *args,
                **kwargs) -> torch.Tensor:
        """Return the RMSE between the angles of ``source_conf`` and ``target_conf``.

        If the RMSE is NaN (e.g. from NaN angles, or presumably from an empty
        angle tensor whose mean is undefined) a zero tensor is returned instead.
        """
        s_phi = self.encode(conf=source_conf, mask_matrices=mask_matrices)
        t_phi = self.encode(conf=target_conf, mask_matrices=mask_matrices)
        mse = F.mse_loss(s_phi, t_phi)
        if torch.isnan(mse):
            return torch.zeros_like(mse)
        if mse == 0:
            # d/dx sqrt(x) is infinite at 0, so backprop through ``** 0.5`` would
            # yield inf * 0 = NaN. sqrt(0) == 0 == mse, so returning mse keeps the
            # value and gives the (correct) zero gradient.
            return mse
        return mse ** 0.5
