import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple, List, Dict, Any

from data.structures import MaskMatrices


def onehot_cols(n_row, row_indices, use_cuda=False) -> torch.FloatTensor:
    """Build a matrix whose columns are one-hot encodings of ``row_indices``.

    Column ``j`` has a single 1 at row ``row_indices[j]``; every other entry is 0.

    Args:
        n_row: number of rows of the result.
        row_indices: sized iterable of row positions, one per output column.
        use_cuda: when True the result is moved with ``.cuda()``.

    Returns:
        A ``(n_row, len(row_indices))`` tensor of torch's default float dtype.
    """
    n_col = len(row_indices)
    onehot = torch.zeros(size=[n_row, n_col])
    for col, row in zip(range(n_col), row_indices):
        onehot[row, col] = 1
    return onehot.cuda() if use_cuda else onehot


def get_phi(mask_matrices: MaskMatrices, pos: torch.FloatTensor, use_cuda=False
            ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.LongTensor, torch.FloatTensor]:
    """Enumerate angles formed by two chained edges and d(angle)/d(distances).

    Every edge is used in both directions (``2 * n_e`` directed edges).  Directed
    edges ``i`` (u_i -> v_i) and ``j`` (u_j -> v_j) form an angle when
    ``v_i == u_j`` and ``j`` is not the reverse of ``i``; only pairs with
    ``i > j`` are kept (strict lower triangle of the pair matrix).  The angle
    ``phi`` at the shared vertex follows from the law of cosines with
    ``a = |u_i v_i|``, ``b = |u_j v_j|`` and ``c = |u_i v_j|``.

    Args:
        mask_matrices: only ``vertex_edge_w1`` / ``vertex_edge_w2`` are read.
            Presumably each is an (n_v, n_e) 0/1 matrix with exactly one 1 per
            column (the two endpoints of each edge) -- TODO confirm.
        pos: vertex coordinates with n_v rows.
        use_cuda: move the returned one-hot matrices and index tensor to CUDA.

    Returns:
        vew1: (n_v, k) one-hot columns of the first vertex ``u_i`` of each angle.
        vew2: (n_v, k) one-hot columns of the last vertex ``v_j`` of each angle.
        flat_abc_indices: (k, 3) int64 flat indices ``row * n_v + col`` of the
            vertex pairs behind ``a``, ``b`` and ``c``.
        g_abc: (k, 3) values of d(phi)/d(a), d(phi)/d(b), d(phi)/d(c), where
            1 / sin(phi) is capped at 10.
    """
    n_v, n_e = mask_matrices.vertex_edge_w1.shape
    n_e2 = 2 * n_e
    # Directed edges: the first n_e columns run w1 -> w2, the last n_e run w2 -> w1.
    vew1 = torch.cat([mask_matrices.vertex_edge_w1, mask_matrices.vertex_edge_w2], dim=1)
    vew2 = torch.cat([mask_matrices.vertex_edge_w2, mask_matrices.vertex_edge_w1], dim=1)
    u_index = torch.nonzero(vew1.t())[:, 1]  # source vertex of each directed edge
    v_index = torch.nonzero(vew2.t())[:, 1]  # target vertex of each directed edge

    edge_lengths = torch.norm((vew1.t() - vew2.t()) @ pos, dim=1)
    u_pos = vew1.t() @ pos
    v_pos = vew2.t() @ pos
    # distance_matrix_uv[i, j] = |pos[u_i] - pos[v_j]|
    distance_matrix_uv = torch.norm(torch.unsqueeze(u_pos, 1) - torch.unsqueeze(v_pos, 0), dim=2)

    # Non-zero iff edge i ends where edge j starts, j is not i reversed, and i > j.
    chain_mask_1 = torch.tril((vew2.t() @ vew1) * (- vew1.t() @ vew2 + 1), diagonal=-1)
    indices = torch.nonzero(chain_mask_1.flatten()).squeeze(1)
    # Split flat indices into (i, j) with exact integer arithmetic.  int(index / n_e2)
    # goes through a float tensor and mis-splits indices beyond 2**24.
    first = torch.div(indices, n_e2, rounding_mode='floor')
    second = indices % n_e2
    u_first, v_first = u_index[first], v_index[first]
    u_second, v_second = u_index[second], v_index[second]

    vew1 = onehot_cols(n_v, u_first.tolist())
    vew2 = onehot_cols(n_v, v_second.tolist())
    # Built on CPU (as before); moved to CUDA below only when requested.
    flat_abc_indices = torch.stack([
        u_first * n_v + v_first,    # a: edge i
        u_second * n_v + v_second,  # b: edge j
        u_first * n_v + v_second,   # c: u_i .. v_j
    ], dim=1).cpu()

    a = edge_lengths[first]
    b = edge_lengths[second]
    c = torch.ravel(distance_matrix_uv)[indices]
    cos_phi = torch.divide(a ** 2 + b ** 2 - c ** 2, 2 * a * b)
    phi = torch.arccos(cos_phi.clip(-1 + 1e-6, 1 - 1e-6))

    # d(phi)/dx = -d(cos phi)/dx / sin(phi); 1 / sin(phi) is capped at 10 so
    # near-degenerate angles (phi close to 0 or pi) do not explode.
    sin_m1_phi = torch.sin(phi).clip(0.1, 1.0) ** -1
    g_a = torch.divide(-a ** 2 + b ** 2 - c ** 2, 2 * a ** 2 * b) * sin_m1_phi
    g_b = torch.divide(a ** 2 - b ** 2 - c ** 2, 2 * a * b ** 2) * sin_m1_phi
    g_c = torch.divide(c, a * b) * sin_m1_phi
    g_abc = torch.vstack([g_a, g_b, g_c]).t()

    if use_cuda:
        vew1 = vew1.cuda()
        vew2 = vew2.cuda()
        flat_abc_indices = flat_abc_indices.cuda()
    return vew1, vew2, flat_abc_indices, g_abc


def get_psi(mask_matrices: MaskMatrices, pos: torch.FloatTensor, use_cuda=False
            ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.LongTensor, torch.FloatTensor]:
    """Enumerate dihedral (torsion) quadruples over three chained edges.

    Every edge is used in both directions (``2 * n_e`` directed edges).  A pair of
    directed edges ``i`` (u_i -> v_i) and ``j`` (u_j -> v_j) is kept when some
    edge ``k`` chains them (``v_i == u_k`` and ``v_k == u_j``), when (for a simple
    graph) ``i`` and ``j`` share no vertex, and when ``i > j``.  The six pairwise
    distances of the four vertices are
    ``a = |u_i v_i|``, ``b = |v_i u_j|``, ``c = |u_j v_j|``,
    ``d = |u_i u_j|``, ``e = |v_i v_j|``, ``f = |u_i v_j|``.

    Args:
        mask_matrices: only ``vertex_edge_w1`` / ``vertex_edge_w2`` are read.
            Presumably each is an (n_v, n_e) 0/1 matrix with exactly one 1 per
            column -- TODO confirm.
        pos: vertex coordinates with n_v rows.
        use_cuda: move the returned one-hot matrices and index tensor to CUDA.

    Returns:
        vew1: (n_v, k) one-hot columns of the first vertex ``u_i``.
        vew2: (n_v, k) one-hot columns of the last vertex ``v_j``.
        flat_abcdef_indices: (k, 6) int64 flat indices ``row * n_v + col`` of the
            vertex pairs behind ``a`` .. ``f``.
        g_abcdef: (k, 6) per-distance coefficients from the closed-form
            expressions below, scaled by 1 / (sin(psi) cos(psi)) clipped to
            [2, 20]; NaNs are replaced by 0.
    """
    n_v, n_e = mask_matrices.vertex_edge_w1.shape
    n_e2 = 2 * n_e
    # Directed edges: the first n_e columns run w1 -> w2, the last n_e run w2 -> w1.
    vew1 = torch.cat([mask_matrices.vertex_edge_w1, mask_matrices.vertex_edge_w2], dim=1)
    vew2 = torch.cat([mask_matrices.vertex_edge_w2, mask_matrices.vertex_edge_w1], dim=1)
    u_index = torch.nonzero(vew1.t())[:, 1]  # source vertex of each directed edge
    v_index = torch.nonzero(vew2.t())[:, 1]  # target vertex of each directed edge

    edge_lengths = torch.norm((vew1.t() - vew2.t()) @ pos, dim=1)
    u_pos = vew1.t() @ pos
    v_pos = vew2.t() @ pos
    # distance_matrix_xy[i, j] = |pos[x_i] - pos[y_j]|
    distance_matrix_uu = torch.norm(torch.unsqueeze(u_pos, 1) - torch.unsqueeze(u_pos, 0), dim=2)
    distance_matrix_uv = torch.norm(torch.unsqueeze(u_pos, 1) - torch.unsqueeze(v_pos, 0), dim=2)
    distance_matrix_vu = torch.norm(torch.unsqueeze(v_pos, 1) - torch.unsqueeze(u_pos, 0), dim=2)
    distance_matrix_vv = torch.norm(torch.unsqueeze(v_pos, 1) - torch.unsqueeze(v_pos, 0), dim=2)

    # chain_mask_1[i, k] = 1 iff edge i ends where edge k starts; its square counts
    # 3-edge paths i -> k -> j.  The (1 - #shared endpoints) factor drops pairs
    # whose end edges touch, and tril keeps each pair once (i > j).
    chain_mask_1 = vew2.t() @ vew1
    chain_mask_2 = torch.tril((chain_mask_1 @ chain_mask_1) * (-(vew1 + vew2).t() @ (vew1 + vew2) + 1), diagonal=-1)
    indices_2 = torch.nonzero(chain_mask_2.flatten()).squeeze(1)
    # Split flat indices into (i, j) with exact integer arithmetic.  int(index / n_e2)
    # goes through a float tensor and mis-splits indices beyond 2**24.
    first = torch.div(indices_2, n_e2, rounding_mode='floor')
    second = indices_2 % n_e2
    u_first, v_first = u_index[first], v_index[first]
    u_second, v_second = u_index[second], v_index[second]

    vew1 = onehot_cols(n_v, u_first.tolist())
    vew2 = onehot_cols(n_v, v_second.tolist())
    # Built on CPU (as before); moved to CUDA below only when requested.
    flat_abcdef_indices = torch.stack([
        u_first * n_v + v_first,    # a
        v_first * n_v + u_second,   # b
        u_second * n_v + v_second,  # c
        u_first * n_v + u_second,   # d
        v_first * n_v + v_second,   # e
        u_first * n_v + v_second,   # f
    ], dim=1).cpu()

    a = edge_lengths[first]
    b = torch.ravel(distance_matrix_vu)[indices_2]
    c = edge_lengths[second]
    d = torch.ravel(distance_matrix_uu)[indices_2]
    e = torch.ravel(distance_matrix_vv)[indices_2]
    f = torch.ravel(distance_matrix_uv)[indices_2]
    a2, b2, c2, d2, e2, f2 = a ** 2, b ** 2, c ** 2, d ** 2, e ** 2, f ** 2
    r1 = b2 + c2 - e2
    r2 = b2 - c2 + e2
    r3 = -b2 + c2 + e2
    s1 = c2 + d2 - f2
    s2 = c2 - d2 + f2
    s3 = -c2 + d2 + f2
    t1 = a2 + b2 - d2
    t2 = a2 + e2 - f2
    sin2_psi = torch.divide(
        4 * a2 * b2 * e2
        - b2 * t2 ** 2
        - a2 * r2 ** 2
        - e2 * t1 ** 2
        + r2 * t1 * t2,
        4 * a2 * b2 * c2 - a2 * r1 ** 2 + 1e-6
    )
    sin_psi = torch.sqrt(sin2_psi.clip(0, 1))
    psi = torch.arcsin(sin_psi)
    cos_psi = torch.cos(psi)

    # 1 / (sin(psi) cos(psi)), bounded to [2, 20] to keep the coefficients finite.
    sin_psi_cos_psi_m1 = (sin_psi * cos_psi).clip(0.05, 0.5) ** -1
    q = 4 * b2 * c2 - r1 ** 2
    base = a2 * q.clip(10., 1e9)
    g_a = torch.divide(a * (q * sin2_psi - s1 * e2 - s2 * b2 - s3 * c2 + 2 * a2 * c2), base
                       ) * sin_psi_cos_psi_m1
    g_b = torch.divide(b * (2 * r3 * a2 * sin2_psi - s1 * f2 - s2 * a2 - s3 * e2 + 2 * b2 * f2), base
                       ) * sin_psi_cos_psi_m1
    g_c = torch.divide(c * (2 * r2 * a2 * sin2_psi - 2 * r2 * a2 + t1 * t2), base) * sin_psi_cos_psi_m1
    g_d = torch.divide(d * (-2 * t1 * e2 + r2 * t2), base) * sin_psi_cos_psi_m1
    g_e = torch.divide(e * (2 * r1 * a2 * sin2_psi - s1 * a2 - s2 * d2 - s3 * b2 + 2 * d2 * e2), base
                       ) * sin_psi_cos_psi_m1
    g_f = torch.divide(f * (-2 * t2 * b2 + r2 * t1), base) * sin_psi_cos_psi_m1
    g_abcdef = torch.vstack([g_a, g_b, g_c, g_d, g_e, g_f]).t()
    g_abcdef[torch.isnan(g_abcdef)] = 0

    if use_cuda:
        vew1 = vew1.cuda()
        vew2 = vew2.cuda()
        flat_abcdef_indices = flat_abcdef_indices.cuda()
    return vew1, vew2, flat_abcdef_indices, g_abcdef


if __name__ == '__main__':
    # Smoke test on a 6-vertex ring: edge j runs from vertex j to vertex (j + 1) % 6.
    ring_size = 6
    # vertex_edge_w1: column j has its 1 at row j.
    first_end = torch.eye(ring_size)
    # vertex_edge_w2: column j has its 1 at row (j + 1) % 6, i.e. row r holds
    # its 1 in column (r - 1) % 6.
    second_end = torch.roll(first_end, shifts=-1, dims=1)
    mask_matrices = MaskMatrices(None, None, first_end, second_end, None, None)

    # Vertex 4 is lifted out of the z = 0 plane.
    positions = torch.tensor([
        [1.7, 2, 0],
        [1.7, 1, 0],
        [1.7, -1, 0],
        [0, -2, 0],
        [-1.7, -1, 0.2],
        [-1.7, 1, 0],
    ], dtype=torch.float32)

    for item in get_psi(mask_matrices, positions):
        print(item)
