import torch
import torch.nn.functional as F

from .Abstract import *


class JustPsi(Comparer):
    """Comparer that scores two conformations by the RMSE of their psi angles.

    ``encode`` turns a conformation into a 1-D tensor of angles, one per
    selected pair of directed edges, computed purely from pairwise distances.
    ``compare`` returns the root-mean-square difference between the encodings
    of a source and a target conformation.

    NOTE(review): psi looks like a torsion (dihedral-like) angle over chains
    of three consecutive edges, derived from the six distances between the
    four end points -- confirm against the derivation of the formula.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def encode(self, conf: torch.Tensor, mask_matrices: MaskMatrices, target_conf: torch.Tensor = None,
               *args, **kwargs) -> torch.Tensor:
        """Encode a conformation as a 1-D tensor of psi angles.

        Args:
            conf: Vertex coordinates; it is left-multiplied by the transposed
                vertex/edge matrices, so its first dimension must match their
                number of rows (presumably ``(n_vertices, n_dims)`` -- TODO
                confirm).
            mask_matrices: Object providing ``vertex_edge_w1`` and
                ``vertex_edge_w2`` of shape ``(n_vertices, n_edges)``.
                Presumably one-hot columns selecting the two end points of
                each edge -- verify against ``MaskMatrices``.
            target_conf: Unused here; accepted for interface compatibility.

        Returns:
            A 1-D tensor with one angle in ``[0, pi/2]`` per selected edge
            pair (empty when no pair is selected).
        """
        # Every edge is listed twice, once per orientation: column k and
        # column k + n_edges describe the same edge with end points swapped.
        vew1 = torch.cat([mask_matrices.vertex_edge_w1, mask_matrices.vertex_edge_w2], dim=1)
        vew2 = torch.cat([mask_matrices.vertex_edge_w2, mask_matrices.vertex_edge_w1], dim=1)

        # Length of every directed edge, and coordinates of its two end points.
        edge_lengths = torch.norm((vew1.t() - vew2.t()) @ conf, dim=1)
        u_pos = vew1.t() @ conf
        v_pos = vew2.t() @ conf

        # All-pairs end-point distances between directed edges (row = first
        # edge, column = second edge).
        distance_matrix_uu = torch.norm(torch.unsqueeze(u_pos, 1) - torch.unsqueeze(u_pos, 0), dim=2)
        distance_matrix_uv = torch.norm(torch.unsqueeze(u_pos, 1) - torch.unsqueeze(v_pos, 0), dim=2)
        distance_matrix_vu = torch.norm(torch.unsqueeze(v_pos, 1) - torch.unsqueeze(u_pos, 0), dim=2)
        distance_matrix_vv = torch.norm(torch.unsqueeze(v_pos, 1) - torch.unsqueeze(v_pos, 0), dim=2)

        # chain_mask_1[i, j] is non-zero when the second end point of directed
        # edge i coincides with the first end point of directed edge j.
        chain_mask_1 = vew2.t() @ vew1
        # Two-step chains i -> k -> j, weighted by (1 - shared end points of
        # i and j) so that pairs sharing exactly one vertex drop out.  The
        # strict lower triangle keeps each (i, j) pair once and removes the
        # diagonal.
        endpoints = vew1 + vew2
        shared_vertices = endpoints.t() @ endpoints
        chain_mask_2 = torch.tril((chain_mask_1 @ chain_mask_1) * (-shared_vertices + 1), diagonal=-1)

        # Row/column indices of the selected pairs, in row-major order.  Using
        # integer index tensors avoids a per-element Python loop (and the
        # float division it used to decode flattened indices).
        first_edge, second_edge = torch.nonzero(chain_mask_2, as_tuple=True)

        a = edge_lengths[first_edge]
        b = distance_matrix_vu[first_edge, second_edge]
        c = edge_lengths[second_edge]
        d = distance_matrix_uu[first_edge, second_edge]
        e = distance_matrix_vv[first_edge, second_edge]
        f = distance_matrix_uv[first_edge, second_edge]
        a2, b2, c2, d2, e2, f2 = a ** 2, b ** 2, c ** 2, d ** 2, e ** 2, f ** 2

        r1 = b2 + c2 - e2
        r2 = b2 - c2 + e2
        t1 = a2 + b2 - d2
        t2 = a2 + e2 - f2
        numerator = (
            4 * a2 * b2 * e2
            - b2 * t2 ** 2
            - a2 * r2 ** 2
            - e2 * t1 ** 2
            + r2 * t1 * t2
        )
        # The small constant keeps the division finite when the denominator
        # would otherwise be exactly zero.
        denominator = 4 * a2 * b2 * c2 - a2 * r1 ** 2 + 1e-6
        sin2_psi = torch.divide(numerator, denominator)

        # Clip to the valid range of sin^2 before taking the root and arcsin.
        sin_psi = torch.sqrt(sin2_psi.clip(0, 1))
        psi = torch.arcsin(sin_psi)
        return psi

    def compare(self, source_conf: torch.Tensor, target_conf: torch.Tensor, mask_matrices: MaskMatrices, *args,
                **kwargs) -> torch.Tensor:
        """Return the RMSE between the psi encodings of two conformations.

        Args:
            source_conf: Coordinates of the source conformation.
            target_conf: Coordinates of the target conformation.
            mask_matrices: Vertex/edge matrices shared by both conformations.

        Returns:
            A scalar tensor holding the root-mean-square psi difference, or
            zero when that value is NaN.
        """
        s_psi = self.encode(conf=source_conf, mask_matrices=mask_matrices)
        t_psi = self.encode(conf=target_conf, mask_matrices=mask_matrices)
        loss = F.mse_loss(s_psi, t_psi) ** 0.5
        # A NaN loss (e.g. the mean over an empty encoding when no edge pair
        # is selected) is replaced by zero so callers always get a finite value.
        if torch.isnan(loss):
            loss = torch.zeros_like(loss)
        return loss
