import torch
from my_config import config as cfg


class loss_func(torch.nn.Module):
    """Node-regression loss plus an end-point consistency penalty along edges.

    The result is ``MSE(pred_nodes, next_nodes)`` plus, for every batch
    element, the MSE between the two segment end points that each edge is
    expected to share.  Behaviour is steered by three fields of the
    project-level ``cfg`` object: ``flag_normed``, ``flag_e3_dir`` and (no
    longer needed for device placement) ``flag_gpu``.
    """

    def __init__(self):
        super(loss_func, self).__init__()

    @staticmethod
    def _segment_axis(quat, axis):
        """Return the body axis ``axis`` rotated by every quaternion in ``quat``.

        Args:
            quat: tensor indexed as ``quat[:, 0..3]`` and read in the order
                (q_w, q_x, q_y, q_z).
            axis: ``"z"`` rotates e3 = [0, 0, 1]; ``"y"`` rotates e2 = [0, 1, 0].

        Returns:
            Tensor with one 3-vector per row of ``quat``.  Any other ``axis``
            value yields all zeros, which is what the previous inline code
            silently produced for an unrecognised setting.
        """
        q_w, q_x, q_y, q_z = quat[:, 0], quat[:, 1], quat[:, 2], quat[:, 3]

        # Columns of the quaternion rotation matrix, see
        # https://en.wikipedia.org/wiki/Quaternions_and_spatial_rotation
        if axis == "z":
            # R @ [0, 0, 1]
            comp_x = 2 * (torch.mul(q_x, q_z) + torch.mul(q_y, q_w))
            comp_y = 2 * (torch.mul(q_y, q_z) - torch.mul(q_x, q_w))
            comp_z = 1 - 2 * (torch.mul(q_x, q_x) + torch.mul(q_y, q_y))
        elif axis == "y":
            # R @ [0, 1, 0]
            comp_x = 2 * (torch.mul(q_x, q_y) - torch.mul(q_z, q_w))
            comp_y = 1 - 2 * (torch.mul(q_x, q_x) + torch.mul(q_z, q_z))
            comp_z = 2 * (torch.mul(q_y, q_z) + torch.mul(q_x, q_w))
        else:
            # NOTE(review): an unknown cfg.flag_e3_dir degrades the penalty to
            # a centre-to-centre distance; kept for backward compatibility.
            return torch.zeros_like(quat[:, 0:3])

        return torch.stack((comp_x, comp_y, comp_z), dim=1)

    def forward(self, pred_nodes, next_nodes, nodes, senders, receivers, nodes_norm_con):
        """Compute the combined loss.

        Args:
            pred_nodes: predicted node features, indexed ``[batch, node, feature]``;
                features 0:3 are read as the segment centre and 6:10 as the
                orientation quaternion (q_w, q_x, q_y, q_z).
            next_nodes: regression target compared against ``pred_nodes``.
            nodes: input node features, indexed ``[batch, node, feature]``;
                feature 15 is read as the segment length.
            senders: integer node indices, indexed ``[batch, edge]``.
            receivers: integer node indices, indexed ``[batch, edge]``.
            nodes_norm_con: scaling factors multiplied onto ``nodes`` (and, via
                its first 13 entries, onto ``pred_nodes``) when
                ``cfg.flag_normed`` is set.

        Returns:
            Tensor of shape ``(1,)`` holding ``end_point_loss + gt_loss``.
        """
        mse = torch.nn.functional.mse_loss

        # The supervised term is taken on the values exactly as passed in,
        # i.e. before the optional rescaling below.
        gt_loss = mse(pred_nodes, next_nodes)

        if cfg.flag_normed:
            pred_nodes = torch.mul(pred_nodes, nodes_norm_con[0:13])
            nodes = torch.mul(nodes, nodes_norm_con)

        # Allocate the accumulator next to the data instead of trusting
        # cfg.flag_gpu, so a flag/device mismatch can no longer raise.
        end_point_loss = torch.zeros(1, device=pred_nodes.device)

        for j in range(senders.shape[0]):
            send_idx = senders[j, :]
            recv_idx = receivers[j, :]

            center = pred_nodes[j, :, 0:3]
            half_length = nodes[j, :, 15] / 2  # feature 15 is the segment length

            # The axis depends only on the node, so evaluate it once per node
            # and gather per edge rather than recomputing it for both ends.
            axis = self._segment_axis(pred_nodes[j, :, 6:10], cfg.flag_e3_dir)

            # +1 when the receiver has the higher node index, -1 otherwise
            # (lower node index is closer to the root).  torch.sign maps a
            # self-loop / zero-padded edge (sender == receiver) to 0, so it
            # contributes nothing; the former ``sign / abs(sign)`` produced
            # 0/0 = NaN there and poisoned the whole loss.
            sign = torch.sign((recv_idx - send_idx).float())

            length_sen = torch.mul(half_length[send_idx], sign).unsqueeze(1)
            length_rec = torch.mul(half_length[recv_idx], sign).unsqueeze(1)

            # End of the sender segment facing the receiver, and the end of
            # the receiver segment facing the sender; ideally they coincide.
            end_sen = center[send_idx] + torch.mul(axis[send_idx], length_sen)
            end_rec = center[recv_idx] - torch.mul(axis[recv_idx], length_rec)

            end_point_loss = end_point_loss + mse(end_rec, end_sen)

        return end_point_loss + gt_loss


class loss_func_corr(torch.nn.Module):
    """Sum of two mean-squared-error terms: one on ``corr``, one on ``lambda``."""

    def __init__(self):
        super(loss_func_corr, self).__init__()

    def forward(self, pred_corr, gt_corr, pred_lambda, gt_lambda):
        """Return ``MSE(pred_corr, gt_corr) + MSE(pred_lambda, gt_lambda)``.

        Both terms use the default mean reduction and are added with equal
        weight.
        """
        mse = torch.nn.functional.mse_loss
        return mse(pred_corr, gt_corr) + mse(pred_lambda, gt_lambda)