# https://github.com/NVlabs/tiny-cuda-nn/blob/327b3699d8aab6c694ec43b7ee0aeced57301bfb/include/tiny-cuda-nn/encodings/spherical_harmonics.h
import torch
import torch.nn as nn


def get_activation(activation: str):
    """Build a fresh activation module from its tiny-cuda-nn style name.

    ``'ReLU'`` gives ``nn.ReLU`` and ``'Sigmoid'`` gives ``nn.Sigmoid``.
    Every other value falls back to ``nn.Identity``; the intended spelling
    for "no activation" is ``'None'``.
    """
    known = (
        ('ReLU', nn.ReLU),
        ('Sigmoid', nn.Sigmoid),
    )
    for name, factory in known:
        if activation == name:
            return factory()
    # Fallback: identity (expected input here is the string 'None').
    return nn.Identity()


class Network(nn.Module):
    """Fully connected MLP assembled from a tiny-cuda-nn style config dict.

    Layout: ``Linear -> [activation -> Linear] * n_hidden_layers -> output_activation``.

    Config keys read: ``n_hidden_layers``, plus ``n_neurons`` and
    ``activation`` when there is at least one hidden layer, and
    ``output_activation``. With zero hidden layers the model is a single
    ``Linear(n_input_dims, n_output_dims)`` followed by the output activation.
    """

    def __init__(self, n_input_dims, n_output_dims, network_config: dict):
        super().__init__()

        n_hidden = network_config["n_hidden_layers"]

        # Feature width at every layer boundary. "n_neurons" is only looked
        # up when a hidden layer actually exists.
        widths = [n_input_dims]
        if n_hidden > 0:
            widths += [network_config["n_neurons"]] * n_hidden
        widths.append(n_output_dims)

        modules = []
        for position, (fan_in, fan_out) in enumerate(zip(widths[:-1], widths[1:])):
            if position:
                # Hidden activation sits between consecutive Linear layers.
                modules.append(get_activation(network_config["activation"]))
            modules.append(nn.Linear(fan_in, fan_out))

        modules.append(get_activation(network_config["output_activation"]))
        self.layers = nn.Sequential(*modules)

    def forward(self, *args, **kwargs):
        """Pass the inputs straight through the sequential layer stack."""
        stack = self.layers
        return stack(*args, **kwargs)


class Encoding(nn.Module):
    """ Spherical Harmonics Encoding

    PyTorch port of the tiny-cuda-nn real spherical-harmonics encoding for
    degrees 1..4. Inputs are remapped with ``2 * c - 1``. As in
    tiny-cuda-nn, they are assumed to lie in [0, 1]^3, so directions end up
    in [-1, 1]^3. The output has ``degree ** 2`` features.
    """

    # Highest degree whose basis functions are implemented in ``forward``.
    MAX_DEGREE = 4

    def __init__(self, n_input_dims, encoding_config: dict):
        super().__init__()
        degree = encoding_config["degree"]
        # Reject degrees that forward() cannot produce. Otherwise
        # n_output_dims (degree ** 2) would disagree with the real output
        # width (degree 0 -> 1 channel, degree > 4 -> only 16 channels).
        if not 1 <= degree <= self.MAX_DEGREE:
            raise ValueError(
                f"Spherical harmonics degree must be in [1, {self.MAX_DEGREE}], "
                f"got {degree}")
        self.n_input_dims = n_input_dims
        self.n_output_dims = degree**2
        self.degree = degree

    def forward(self, coords):
        """Evaluate the SH basis for ``coords`` with last dimension 3.

        Returns a tensor with the same leading dimensions and a last
        dimension of ``self.degree ** 2``.
        """
        coords = coords * 2 - 1.
        # Unpacking into exactly three tensors requires coords[..., 3].
        x, y, z = torch.split(coords, 1, dim=-1)
        xy = x * y
        xz = x * z
        yz = y * z
        x2 = x * x
        y2 = y * y
        z2 = z * z

        # Band l = 0
        outs = 0.28209479177387814 * torch.ones_like(x) # 1/(2*sqrt(pi))
        if self.degree <= 1:
            return outs

        # Band l = 1
        outs = torch.cat([
            outs,
            -0.48860251190291987*y, # -sqrt(3)*y/(2*sqrt(pi))
            0.48860251190291987*z, # sqrt(3)*z/(2*sqrt(pi))
            -0.48860251190291987*x, # -sqrt(3)*x/(2*sqrt(pi))
        ], -1)
        if self.degree <= 2:
            return outs

        # Band l = 2
        outs = torch.cat([
            outs,
             1.0925484305920792*xy, # sqrt(15)*xy/(2*sqrt(pi))
            -1.0925484305920792*yz, # -sqrt(15)*yz/(2*sqrt(pi))
             0.94617469575755997*z2 - 0.31539156525251999,
             # sqrt(5)*(3*z2 - 1)/(4*sqrt(pi))
            -1.0925484305920792*xz, # -sqrt(15)*xz/(2*sqrt(pi))
             0.54627421529603959*x2 - 0.54627421529603959*y2,
             # sqrt(15)*(x2 - y2)/(4*sqrt(pi))
        ], -1)
        if self.degree <= 3:
            return outs

        # Band l = 3
        outs = torch.cat([
            outs,
            0.59004358992664352*y*(-3.0*x2 + y2),
            # sqrt(70)*y*(-3*x2 + y2)/(8*sqrt(pi))
            2.8906114426405538*xy*z, # sqrt(105)*xy*z/(2*sqrt(pi))
            0.45704579946446572*y*(1.0 - 5.0*z2),
            # sqrt(42)*y*(1 - 5*z2)/(8*sqrt(pi))
            0.3731763325901154*z*(5.0*z2 - 3.0),
            # sqrt(7)*z*(5*z2 - 3)/(4*sqrt(pi))
            0.45704579946446572*x*(1.0 - 5.0*z2),
            # sqrt(42)*x*(1 - 5*z2)/(8*sqrt(pi))
            1.4453057213202769*z*(x2 - y2), # sqrt(105)*z*(x2 - y2)/(4*sqrt(pi))
            0.59004358992664352*x*(-x2 + 3.0*y2),
            # sqrt(70)*x*(-x2 + 3*y2)/(8*sqrt(pi))
        ], -1)
        return outs


@torch.no_grad()
def get_ellipses(covs):
    """Cut a rotated spheroid with the coordinate planes z = 0, y = 0 and x = 0.

    Args:
        covs: tensor whose last dimension holds at least four entries.
            ``covs[..., 0]`` is the semi-axis shared by the local X and Y
            axes and ``covs[..., 1]`` is the semi-axis along local Z. For an
            unrotated input the returned axes equal these values, so they act
            as lengths rather than variances. ``covs[..., 2]`` (beta) and
            ``covs[..., 3]`` (gamma) are rotation angles in radians; yaw is
            fixed at 0.

    Returns:
        ``((a_xy, b_xy, theta_xy), (a_xz, b_xz, theta_xz), (a_yz, b_yz, theta_yz))``.
        Each entry holds the semi-axes (``a >= b``) and orientation angle of
        the planar section, as produced by ``get_ellipse_info``. The sections
        are taken at the missing coordinate = 0; they are not projections.

    Raises:
        AssertionError: if a normalized axis underflows to zero, or if
            ``get_ellipse_info`` rejects its inputs or results.
    """
    # Geometric-mean scale sqrt(c0) * sqrt(c1). Dividing it out (for
    # numerical stability) makes a2 * b2 == 1; it is multiplied back onto
    # the semi-axes at the end.
    scale = covs[..., :2].sqrt().prod(dim=-1)
    a2 = (covs[..., 0] / scale).square()
    b2 = (covs[..., 1] / scale).square()
    assert (a2 == 0).sum() == 0
    assert (b2 == 0).sum() == 0
    beta = covs[..., 2]
    gamma = covs[..., 3]

    sin_beta = torch.sin(beta)
    sin_gamma = torch.sin(gamma)
    cos_beta = torch.cos(beta)
    cos_gamma = torch.cos(gamma)

    sin_beta_square = torch.square(sin_beta)
    cos_beta_square = 1 - sin_beta_square
    sin_gamma_square = torch.square(sin_gamma)
    cos_gamma_square = 1 - sin_gamma_square

    sin_cos_beta = sin_beta * cos_beta
    sin_cos_gamma = sin_gamma * cos_gamma

    # Local (ellipsoid) coordinates are obtained from world coordinates as
    #   X =  cos B x                 + sin B z
    #   Y =  sin B sin G x + cos G y - cos B sin G z
    #   Z = -sin B cos G x + sin G y + cos B cos G z
    # i.e. (X, Y, Z) = Rx(G) @ Ry(B) @ (x, y, z), with
    #   Rx(G) = [[1, 0, 0], [0, cos G, -sin G], [0, sin G, cos G]]
    #   Ry(B) = [[cos B, 0, sin B], [0, 1, 0], [-sin B, 0, cos B]].
    # NOTE(review): the body is therefore rotated by Ry(-B) @ Rx(-G) in
    # world space, not Ry(B) @ Rx(G). Confirm the intended sign convention
    # for beta/gamma with the producer of covs.
    #
    # Substituting into (X^2 + Y^2) / a2 + Z^2 / b2 - 1 = 0 gives
    #   (cos2 B / a2 + sin2 B sin2 G / a2 + sin2 B cos2 G / b2) x2
    #   + (cos2 G / a2 + sin2 G / b2) y2
    #   + (sin2 B / a2 + cos2 B sin2 G / a2 + cos2 B cos2 G / b2) z2
    #   + 2 sin B sin G cos G (1 / a2 - 1 / b2) xy
    #   + 2 (sin B cos B cos2 G / a2 - sin B cos B cos2 G / b2) xz
    #   + 2 cos B sin G cos G (1 / b2 - 1 / a2) yz
    #   - 1 = 0
    c_x2 = (cos_beta_square + sin_beta_square * sin_gamma_square) / a2 \
         + sin_beta_square * cos_gamma_square / b2
    c_y2 = cos_gamma_square / a2 + sin_gamma_square / b2
    c_z2 = (sin_beta_square + cos_beta_square * sin_gamma_square) / a2 \
         + cos_beta_square * cos_gamma_square / b2
    c_xy = 2 * (sin_beta * sin_cos_gamma * (1 / a2 - 1 / b2))
    c_xz = 2 * (sin_cos_beta * cos_gamma_square / a2
                - sin_cos_beta * cos_gamma_square / b2)
    c_yz = 2 * (cos_beta * sin_cos_gamma * (1 / b2 - 1 / a2))

    # Dropping the third coordinate's terms yields the section through that
    # coordinate plane: A u^2 + B uv + C v^2 - 1 = 0.
    a_xy, b_xy, theta_xy = get_ellipse_info(c_x2, c_xy, c_y2, -1)
    a_xz, b_xz, theta_xz = get_ellipse_info(c_x2, c_xz, c_z2, -1)
    a_yz, b_yz, theta_yz = get_ellipse_info(c_y2, c_yz, c_z2, -1)

    return ((a_xy * scale, b_xy * scale, theta_xy),
            (a_xz * scale, b_xz * scale, theta_xz),
            (a_yz * scale, b_yz * scale, theta_yz))


def get_ellipse_info(A, B, C, F, eps=1e-8):
    """Semi-axes and orientation of the centred conic ``A u^2 + B uv + C v^2 + F = 0``.

    Closed form for an ellipse without linear terms, with
    ``R = sqrt((A - C)^2 + B^2)``::

        a = sqrt(2 F / (B^2 - 4AC)) * sqrt((A + C) + R)
        b = sqrt(2 F / (B^2 - 4AC)) * sqrt((A + C) - R)
        theta = atan2(A - C + R, B)

    Both axes share the same prefactor and ``R >= 0``, so ``a >= b``.

    NOTE(review): the textbook angle of the ``a`` axis is
    ``atan((C - A - R) / B)``. The angle returned here is its negation
    modulo pi; for example, A = C = 5/8, B = -3/4 gives 3*pi/4 instead of
    pi/4. Presumably this matches a mirrored/y-down axis convention used by
    callers; confirm.

    Args:
        A, B, C: conic coefficient tensors (broadcastable).
        F: constant term; callers in this file pass -1.
        eps: floor applied before each square root. Non-elliptic or
            degenerate inputs therefore yield tiny positive values instead
            of NaNs.

    Returns:
        ``(a, b, theta)`` tensors.

    Raises:
        AssertionError: if A, B or C contain NaN, or if a result is NaN or
            exactly zero.
    """
    for coeff in (A, B, C):
        assert coeff.isnan().sum() == 0

    root = ((A - C).square() + B.square()).clamp(min=0).sqrt()
    trace = A + C
    determinant_term = B.square() - 4 * A * C

    # Common prefactor sqrt(2F / (B^2 - 4AC)), clamped to stay real.
    prefactor = (2 * F / determinant_term).clamp(min=eps).sqrt()
    a = prefactor * (trace + root).clamp(min=eps).sqrt()
    b = prefactor * (trace - root).clamp(min=eps).sqrt()
    theta = torch.atan2(A - C + root, B)

    for axis in (a, b):
        assert axis.isnan().sum() == 0
        assert (axis == 0).sum() == 0
    assert theta.isnan().sum() == 0

    return a, b, theta

