import torch
from typing import Tuple, List, Dict, Any

from data.structures import MaskMatrices


def d_phi_psi_encode(mask_matrices: MaskMatrices, pos: torch.Tensor, extra_dict: Dict[str, torch.Tensor]
                     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Dict[str, torch.Tensor]]:
    """Compute geometric features (d, phi, psi) for chains of edges from vertex positions.

    Every edge is considered in both directions: column ``k`` of ``vew1`` marks the
    start vertex and column ``k`` of ``vew2`` the end vertex of directed edge ``k``.
    The first ``n_e`` columns follow ``vertex_edge_w1 -> vertex_edge_w2`` and the
    last ``n_e`` columns follow the reverse.

    * ``phi`` -- one value per pair of directed edges ``(i, j)`` with ``i > j``
      where edge ``i`` ends where edge ``j`` starts, excluding pairs where ``j``
      ends at the start of ``i``. It is the angle at the shared vertex, computed
      with the law of cosines.
    * ``psi`` -- one value per pair ``(i, j)`` with ``i > j`` that is linked by an
      intermediate directed edge (``i`` ends where it starts, and ``j`` starts
      where it ends), excluding pairs that share exactly one vertex. It is
      computed from the six pairwise distances of the four endpoints as
      ``arcsin(sqrt(sin2_psi))``, so it lies in ``[0, pi/2)``; NaNs become 0.
    * ``d`` -- for the same pairs as ``psi``, the distance between the start
      vertex of edge ``i`` and the start vertex of edge ``j``.

    Args:
        mask_matrices: object exposing ``vertex_edge_w1`` and ``vertex_edge_w2``,
            both 2-D with shape ``(n_v, n_e)``.
        pos: vertex coordinates, one row per vertex (presumably ``(n_v, 3)`` --
            TODO confirm).
        extra_dict: must contain ``'phi_w1'``, ``'phi_w2'``, ``'psi_w1'`` and
            ``'psi_w2'``; these tensors are passed through unchanged.

    Returns:
        ``(d, phi, psi, return_dict)``. ``phi``, ``psi`` and ``d`` are ordered
        row-major over the ``(i, j)`` pairs described above.
    """
    return_dict = {
        'phi_w1': extra_dict['phi_w1'],
        'phi_w2': extra_dict['phi_w2'],
        'psi_w1': extra_dict['psi_w1'],
        'psi_w2': extra_dict['psi_w2'],
    }
    # Directed edges: both orientations of every edge.
    vew1 = torch.cat([mask_matrices.vertex_edge_w1, mask_matrices.vertex_edge_w2], dim=1)
    vew2 = torch.cat([mask_matrices.vertex_edge_w2, mask_matrices.vertex_edge_w1], dim=1)

    edge_lengths = torch.norm((vew1.t() - vew2.t()) @ pos, dim=1)
    u_pos = vew1.t() @ pos  # start-vertex position of each directed edge
    v_pos = vew2.t() @ pos  # end-vertex position of each directed edge
    distance_matrix_uu = torch.norm(torch.unsqueeze(u_pos, 1) - torch.unsqueeze(u_pos, 0), dim=2)
    distance_matrix_uv = torch.norm(torch.unsqueeze(u_pos, 1) - torch.unsqueeze(v_pos, 0), dim=2)
    distance_matrix_vu = torch.norm(torch.unsqueeze(v_pos, 1) - torch.unsqueeze(u_pos, 0), dim=2)
    distance_matrix_vv = torch.norm(torch.unsqueeze(v_pos, 1) - torch.unsqueeze(v_pos, 0), dim=2)

    # chain_mask[i, j] != 0 when directed edge i ends where directed edge j starts.
    chain_mask = vew2.t() @ vew1
    # Drop pairs where j leads straight back to the start of i, and keep only the
    # strict lower triangle (i > j).
    chain_mask_1 = torch.tril(chain_mask * (1 - vew1.t() @ vew2), diagonal=-1)
    # Take (row, col) directly. torch.nonzero yields them in row-major order, which
    # matches the flattened ordering. Recovering them from flat indices via float
    # division is inexact beyond 2**24 and needs a Python loop.
    rows_1, cols_1 = torch.nonzero(chain_mask_1, as_tuple=True)

    a = edge_lengths[rows_1]
    b = edge_lengths[cols_1]
    c = distance_matrix_uv[rows_1, cols_1]  # |start(i) - end(j)|, the side opposite the shared vertex
    cos_phi = torch.divide(a ** 2 + b ** 2 - c ** 2, 2 * a * b)
    phi = torch.arccos(cos_phi.clip(-1 + 1e-6, 1 - 1e-6))

    # (chain_mask @ chain_mask)[i, j] counts intermediate edges k with i -> k -> j.
    # The incidence overlap term zeroes pairs sharing exactly one vertex.
    incidence = vew1 + vew2
    chain_mask_2 = torch.tril((chain_mask @ chain_mask) * (1 - incidence.t() @ incidence), diagonal=-1)
    rows_2, cols_2 = torch.nonzero(chain_mask_2, as_tuple=True)

    # Four points per pair: P1 = start(i), P2 = end(i), P3 = start(j), P4 = end(j).
    a = edge_lengths[rows_2]                # |P1 P2|
    b = distance_matrix_vu[rows_2, cols_2]  # |P2 P3|
    c = edge_lengths[cols_2]                # |P3 P4|
    # NOTE(review): `d` is also what this function returns as its first value.
    # d_phi_psi_trunc_encode returns per-bond distances there instead. Confirm
    # that returning |P1 P3| here is intended.
    d = distance_matrix_uu[rows_2, cols_2]  # |P1 P3|
    e = distance_matrix_vv[rows_2, cols_2]  # |P2 P4|
    f = distance_matrix_uv[rows_2, cols_2]  # |P1 P4|
    a2, b2, c2, d2, e2, f2 = a ** 2, b ** 2, c ** 2, d ** 2, e ** 2, f ** 2
    r1 = b2 + c2 - e2
    r2 = b2 - c2 + e2
    t1 = a2 + b2 - d2
    t2 = a2 + e2 - f2
    sin2_psi = torch.divide(
        4 * a2 * b2 * e2
        - b2 * t2 ** 2
        - a2 * r2 ** 2
        - e2 * t1 ** 2
        + r2 * t1 * t2,
        4 * a2 * b2 * c2 - a2 * r1 ** 2 + 1e-6
    )
    sin_psi = torch.sqrt(sin2_psi.clip(1e-6, 1 - 1e-6))
    psi = torch.arcsin(sin_psi)
    # clip() propagates NaN (e.g. from inf/inf), so zero those entries explicitly.
    psi[torch.isnan(psi)] = 0
    return d, phi, psi, return_dict


def d_phi_psi_trunc_encode(mask_matrices: MaskMatrices, pos: torch.Tensor, extra_dict: Dict[str, torch.Tensor]
                           ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Dict[str, torch.Tensor]]:
    """Compute edge distances plus truncated phi/psi surrogates from pairwise distances.

    Angles are not reconstructed here. Instead, ``trunc_phi`` and ``trunc_psi`` are
    row-wise weighted sums of selected entries of the pairwise vertex distance
    matrix. The entries are addressed by flat (row-major) indices.

    Args:
        mask_matrices: object exposing ``vertex_edge_w1`` and ``vertex_edge_w2``;
            ``vertex_edge_w1 @ vertex_edge_w2.t()`` is nonzero at ``(u, v)`` for the
            vertex pairs whose distance is returned as ``d``.
        pos: vertex coordinates, one row per vertex (2-D; presumably ``(n_v, 3)``
            -- TODO confirm).
        extra_dict: must contain ``'phi_w1'``, ``'phi_w2'``, ``'psi_w1'`` and
            ``'psi_w2'`` (passed through unchanged). It must also contain
            ``'phi_flat'``/``'psi_flat'`` (flat indices into the distance matrix)
            and ``'phi_g'``/``'psi_g'`` (weights of matching shape, summed along
            dim 1).

    Returns:
        ``(d, trunc_phi, trunc_psi, return_dict)``.
    """
    return_dict = {key: extra_dict[key] for key in ('phi_w1', 'phi_w2', 'psi_w1', 'psi_w2')}
    phi_index, phi_weight, psi_index, psi_weight = (
        extra_dict[key] for key in ('phi_flat', 'phi_g', 'psi_flat', 'psi_g')
    )

    # All pairwise vertex distances, flattened row-major so flat indices address them.
    pairwise = (pos.unsqueeze(1) - pos.unsqueeze(0)).norm(dim=2)
    flat_pairwise = pairwise.reshape(-1)

    # Vertex pairs linked by an edge: nonzero entries of the vertex-vertex product.
    linked = mask_matrices.vertex_edge_w1 @ mask_matrices.vertex_edge_w2.t()
    linked_index = linked.flatten().nonzero()[:, 0]
    d: torch.Tensor = flat_pairwise[linked_index]

    trunc_phi = (phi_weight * flat_pairwise[phi_index]).sum(dim=1)
    trunc_psi = (psi_weight * flat_pairwise[psi_index]).sum(dim=1)
    return d, trunc_phi, trunc_psi, return_dict
