"""Real spherical harmonics in Cartesian form for PyTorch.

This is an autogenerated file. See
https://github.com/cheind/torch-spherical-harmonics
for more information.
"""

import torch


def rsh_cart_0(xyz: torch.Tensor):
    """Evaluates the single real spherical harmonic of degree 0.

    Autogenerated kernel; see
    https://github.com/cheind/torch-spherical-harmonics
    for background.

    Params:
        xyz: (N,...,3) tensor of points on the unit sphere. Only its
            shape, dtype and device matter, because Y00 is constant.

    Returns:
        rsh: (N,...,1) tensor filled with Y00 = 1 / (2 sqrt(pi)).
            Ynm lives at index `n*(n+1) + m`, with `0 <= n <= degree`
            and `-n <= m <= n`.
    """
    leading_shape = xyz.shape[:-1]
    # new_tensor inherits dtype/device from xyz; expand broadcasts the
    # scalar to every point without materializing copies.
    y00 = xyz.new_tensor(0.282094791773878).expand(leading_shape)
    return torch.stack([y00], dim=-1)


def rsh_cart_1(xyz: torch.Tensor):
    """Evaluates every real spherical harmonic Ynm with n <= 1.

    Autogenerated kernel; see
    https://github.com/cheind/torch-spherical-harmonics
    for background. The odd-m entries carry a minus sign
    (e.g. Y1,-1 = -c * y).

    Params:
        xyz: (N,...,3) tensor of points on the unit sphere.

    Returns:
        rsh: (N,...,4) tensor of harmonic values. Ynm lives at index
            `n*(n+1) + m`, with `0 <= n <= degree` and `-n <= m <= n`.
    """
    c0 = 0.282094791773878  # 1 / (2 sqrt(pi))
    c1 = 0.48860251190292  # sqrt(3 / (4 pi))

    x, y, z = xyz[..., 0], xyz[..., 1], xyz[..., 2]

    harmonics = [
        xyz.new_tensor(c0).expand(xyz.shape[:-1]),  # Y0,0
        -c1 * y,  # Y1,-1
        c1 * z,  # Y1,0
        -c1 * x,  # Y1,1
    ]
    return torch.stack(harmonics, -1)


def rsh_cart_2(xyz: torch.Tensor):
    """Evaluates every real spherical harmonic Ynm with n <= 2.

    Autogenerated kernel; see
    https://github.com/cheind/torch-spherical-harmonics
    for background. Y20 is written as `a * z**2 - b`, which uses
    x**2 + y**2 + z**2 == 1, so inputs must be unit length.

    Params:
        xyz: (N,...,3) tensor of points on the unit sphere.

    Returns:
        rsh: (N,...,9) tensor of harmonic values. Ynm lives at index
            `n*(n+1) + m`, with `0 <= n <= degree` and `-n <= m <= n`.
    """
    x, y, z = xyz[..., 0], xyz[..., 1], xyz[..., 2]

    # Degrees 0 and 1.
    out = [
        xyz.new_tensor(0.282094791773878).expand(xyz.shape[:-1]),
        -0.48860251190292 * y,
        0.48860251190292 * z,
        -0.48860251190292 * x,
    ]

    # Degree 2.
    xx, yy, zz = x**2, y**2, z**2
    out += [
        1.09254843059208 * (x * y),
        -1.09254843059208 * (y * z),
        0.94617469575756 * zz - 0.31539156525252,
        -1.09254843059208 * (x * z),
        0.54627421529604 * xx - 0.54627421529604 * yy,
    ]
    return torch.stack(out, -1)


def rsh_cart_3(xyz: torch.Tensor):
    """Evaluates every real spherical harmonic Ynm with n <= 3.

    Autogenerated kernel; see
    https://github.com/cheind/torch-spherical-harmonics
    for background. Several polynomials (e.g. Y20, Y30) are simplified
    using x**2 + y**2 + z**2 == 1, so inputs must be unit length.

    Params:
        xyz: (N,...,3) tensor of points on the unit sphere.

    Returns:
        rsh: (N,...,16) tensor of harmonic values. Ynm lives at index
            `n*(n+1) + m`, with `0 <= n <= degree` and `-n <= m <= n`.
    """
    x, y, z = xyz[..., 0], xyz[..., 1], xyz[..., 2]
    xx, yy, zz = x**2, y**2, z**2
    x_y, x_z, y_z = x * y, x * z, y * z

    # Degrees 0 through 2.
    values = [
        xyz.new_tensor(0.282094791773878).expand(xyz.shape[:-1]),
        -0.48860251190292 * y,
        0.48860251190292 * z,
        -0.48860251190292 * x,
        1.09254843059208 * x_y,
        -1.09254843059208 * y_z,
        0.94617469575756 * zz - 0.31539156525252,
        -1.09254843059208 * x_z,
        0.54627421529604 * xx - 0.54627421529604 * yy,
    ]

    # Degree 3; the m = -1 and m = +1 terms share the same z factor.
    z_factor = 1.5 - 7.5 * zz
    values.extend(
        [
            -0.590043589926644 * y * (3.0 * xx - yy),
            2.89061144264055 * x_y * z,
            0.304697199642977 * y * z_factor,
            1.24392110863372 * z * (1.5 * zz - 0.5) - 0.497568443453487 * z,
            0.304697199642977 * x * z_factor,
            1.44530572132028 * z * (xx - yy),
            -0.590043589926644 * x * (xx - 3.0 * yy),
        ]
    )
    return torch.stack(values, -1)


def rsh_cart_4(xyz: torch.Tensor):
    """Computes all real spherical harmonics up to degree 4.

    This is an autogenerated method. See
    https://github.com/cheind/torch-spherical-harmonics
    for more information.

    The Cartesian polynomials below were simplified using
    ``x**2 + y**2 + z**2 == 1`` (for instance Y20 is written as
    ``a * z**2 - b`` with a constant offset), so the input must be
    normalized; for other points the outputs are not harmonic values.

    Params:
        xyz: (N,...,3) tensor of points on the unit sphere. Any number of
            leading dimensions is accepted; components 0, 1 and 2 of the
            last axis are read as x, y and z.

    Returns:
        rsh: (N,...,25) real spherical harmonics
            projections of input. Ynm is found at index
            `n*(n+1) + m`, with `0 <= n <= degree` and
            `-n <= m <= n`. For floating-point ``xyz`` the result has
            the same dtype and device.
    """
    x = xyz[..., 0]
    y = xyz[..., 1]
    z = xyz[..., 2]

    # Monomials shared by several harmonics below.
    x2 = x**2
    y2 = y**2
    z2 = z**2
    xy = x * y
    xz = x * z
    yz = y * z
    x4 = x2**2
    y4 = y2**2

    return torch.stack(
        [
            xyz.new_tensor(0.282094791773878).expand(xyz.shape[:-1]),
            -0.48860251190292 * y,
            0.48860251190292 * z,
            -0.48860251190292 * x,
            1.09254843059208 * xy,
            -1.09254843059208 * yz,
            0.94617469575756 * z2 - 0.31539156525252,
            -1.09254843059208 * xz,
            0.54627421529604 * x2 - 0.54627421529604 * y2,
            -0.590043589926644 * y * (3.0 * x2 - y2),
            2.89061144264055 * xy * z,
            0.304697199642977 * y * (1.5 - 7.5 * z2),
            1.24392110863372 * z * (1.5 * z2 - 0.5) - 0.497568443453487 * z,
            0.304697199642977 * x * (1.5 - 7.5 * z2),
            1.44530572132028 * z * (x2 - y2),
            -0.590043589926644 * x * (x2 - 3.0 * y2),
            2.5033429417967 * xy * (x2 - y2),
            -1.77013076977993 * yz * (3.0 * x2 - y2),
            0.126156626101008 * xy * (52.5 * z2 - 7.5),
            0.267618617422916 * y * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z),
            1.48099765681286
            * z
            * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
            - 0.952069922236839 * z2
            + 0.317356640745613,
            0.267618617422916 * x * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z),
            0.063078313050504 * (x2 - y2) * (52.5 * z2 - 7.5),
            -1.77013076977993 * xz * (x2 - 3.0 * y2),
            -3.75501441269506 * x2 * y2
            + 0.625835735449176 * x4
            + 0.625835735449176 * y4,
        ],
        -1,
    )


def rsh_cart_5(xyz: torch.Tensor):
    """Computes all real spherical harmonics up to degree 5.

    This is an autogenerated method. See
    https://github.com/cheind/torch-spherical-harmonics
    for more information.

    The Cartesian polynomials below were simplified using
    ``x**2 + y**2 + z**2 == 1`` (for instance Y20 is written as
    ``a * z**2 - b`` with a constant offset), so the input must be
    normalized; for other points the outputs are not harmonic values.

    Params:
        xyz: (N,...,3) tensor of points on the unit sphere. Any number of
            leading dimensions is accepted; components 0, 1 and 2 of the
            last axis are read as x, y and z.

    Returns:
        rsh: (N,...,36) real spherical harmonics
            projections of input. Ynm is found at index
            `n*(n+1) + m`, with `0 <= n <= degree` and
            `-n <= m <= n`. For floating-point ``xyz`` the result has
            the same dtype and device.
    """
    x = xyz[..., 0]
    y = xyz[..., 1]
    z = xyz[..., 2]

    # Monomials shared by several harmonics below.
    x2 = x**2
    y2 = y**2
    z2 = z**2
    xy = x * y
    xz = x * z
    yz = y * z
    x4 = x2**2
    y4 = y2**2

    return torch.stack(
        [
            xyz.new_tensor(0.282094791773878).expand(xyz.shape[:-1]),
            -0.48860251190292 * y,
            0.48860251190292 * z,
            -0.48860251190292 * x,
            1.09254843059208 * xy,
            -1.09254843059208 * yz,
            0.94617469575756 * z2 - 0.31539156525252,
            -1.09254843059208 * xz,
            0.54627421529604 * x2 - 0.54627421529604 * y2,
            -0.590043589926644 * y * (3.0 * x2 - y2),
            2.89061144264055 * xy * z,
            0.304697199642977 * y * (1.5 - 7.5 * z2),
            1.24392110863372 * z * (1.5 * z2 - 0.5) - 0.497568443453487 * z,
            0.304697199642977 * x * (1.5 - 7.5 * z2),
            1.44530572132028 * z * (x2 - y2),
            -0.590043589926644 * x * (x2 - 3.0 * y2),
            2.5033429417967 * xy * (x2 - y2),
            -1.77013076977993 * yz * (3.0 * x2 - y2),
            0.126156626101008 * xy * (52.5 * z2 - 7.5),
            0.267618617422916 * y * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z),
            1.48099765681286
            * z
            * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
            - 0.952069922236839 * z2
            + 0.317356640745613,
            0.267618617422916 * x * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z),
            0.063078313050504 * (x2 - y2) * (52.5 * z2 - 7.5),
            -1.77013076977993 * xz * (x2 - 3.0 * y2),
            -3.75501441269506 * x2 * y2
            + 0.625835735449176 * x4
            + 0.625835735449176 * y4,
            -0.65638205684017 * y * (-10.0 * x2 * y2 + 5.0 * x4 + y4),
            8.30264925952416 * xy * z * (x2 - y2),
            0.00931882475114763 * y * (52.5 - 472.5 * z2) * (3.0 * x2 - y2),
            0.0913054625709205 * xy * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z),
            0.241571547304372
            * y
            * (
                2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
                + 9.375 * z2
                - 1.875
            ),
            -1.24747010616985 * z * (1.5 * z2 - 0.5)
            + 1.6840846433293
            * z
            * (
                1.75
                * z
                * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
                - 1.125 * z2
                + 0.375
            )
            + 0.498988042467941 * z,
            0.241571547304372
            * x
            * (
                2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
                + 9.375 * z2
                - 1.875
            ),
            0.0456527312854602 * (x2 - y2) * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z),
            0.00931882475114763 * x * (52.5 - 472.5 * z2) * (x2 - 3.0 * y2),
            2.07566231488104 * z * (-6.0 * x2 * y2 + x4 + y4),
            -0.65638205684017 * x * (-10.0 * x2 * y2 + x4 + 5.0 * y4),
        ],
        -1,
    )


def rsh_cart_6(xyz: torch.Tensor):
    """Computes all real spherical harmonics up to degree 6.

    This is an autogenerated method. See
    https://github.com/cheind/torch-spherical-harmonics
    for more information.

    The Cartesian polynomials below were simplified using
    ``x**2 + y**2 + z**2 == 1`` (for instance Y20 is written as
    ``a * z**2 - b`` with a constant offset), so the input must be
    normalized; for other points the outputs are not harmonic values.

    Params:
        xyz: (N,...,3) tensor of points on the unit sphere. Any number of
            leading dimensions is accepted; components 0, 1 and 2 of the
            last axis are read as x, y and z.

    Returns:
        rsh: (N,...,49) real spherical harmonics
            projections of input. Ynm is found at index
            `n*(n+1) + m`, with `0 <= n <= degree` and
            `-n <= m <= n`. For floating-point ``xyz`` the result has
            the same dtype and device.
    """
    x = xyz[..., 0]
    y = xyz[..., 1]
    z = xyz[..., 2]

    # Monomials shared by several harmonics below.
    x2 = x**2
    y2 = y**2
    z2 = z**2
    xy = x * y
    xz = x * z
    yz = y * z
    x4 = x2**2
    y4 = y2**2

    return torch.stack(
        [
            xyz.new_tensor(0.282094791773878).expand(xyz.shape[:-1]),
            -0.48860251190292 * y,
            0.48860251190292 * z,
            -0.48860251190292 * x,
            1.09254843059208 * xy,
            -1.09254843059208 * yz,
            0.94617469575756 * z2 - 0.31539156525252,
            -1.09254843059208 * xz,
            0.54627421529604 * x2 - 0.54627421529604 * y2,
            -0.590043589926644 * y * (3.0 * x2 - y2),
            2.89061144264055 * xy * z,
            0.304697199642977 * y * (1.5 - 7.5 * z2),
            1.24392110863372 * z * (1.5 * z2 - 0.5) - 0.497568443453487 * z,
            0.304697199642977 * x * (1.5 - 7.5 * z2),
            1.44530572132028 * z * (x2 - y2),
            -0.590043589926644 * x * (x2 - 3.0 * y2),
            2.5033429417967 * xy * (x2 - y2),
            -1.77013076977993 * yz * (3.0 * x2 - y2),
            0.126156626101008 * xy * (52.5 * z2 - 7.5),
            0.267618617422916 * y * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z),
            1.48099765681286
            * z
            * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
            - 0.952069922236839 * z2
            + 0.317356640745613,
            0.267618617422916 * x * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z),
            0.063078313050504 * (x2 - y2) * (52.5 * z2 - 7.5),
            -1.77013076977993 * xz * (x2 - 3.0 * y2),
            -3.75501441269506 * x2 * y2
            + 0.625835735449176 * x4
            + 0.625835735449176 * y4,
            -0.65638205684017 * y * (-10.0 * x2 * y2 + 5.0 * x4 + y4),
            8.30264925952416 * xy * z * (x2 - y2),
            0.00931882475114763 * y * (52.5 - 472.5 * z2) * (3.0 * x2 - y2),
            0.0913054625709205 * xy * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z),
            0.241571547304372
            * y
            * (
                2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
                + 9.375 * z2
                - 1.875
            ),
            -1.24747010616985 * z * (1.5 * z2 - 0.5)
            + 1.6840846433293
            * z
            * (
                1.75
                * z
                * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
                - 1.125 * z2
                + 0.375
            )
            + 0.498988042467941 * z,
            0.241571547304372
            * x
            * (
                2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
                + 9.375 * z2
                - 1.875
            ),
            0.0456527312854602 * (x2 - y2) * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z),
            0.00931882475114763 * x * (52.5 - 472.5 * z2) * (x2 - 3.0 * y2),
            2.07566231488104 * z * (-6.0 * x2 * y2 + x4 + y4),
            -0.65638205684017 * x * (-10.0 * x2 * y2 + x4 + 5.0 * y4),
            4.09910463115149 * x**4 * xy
            - 13.6636821038383 * xy**3
            + 4.09910463115149 * xy * y**4,
            -2.36661916223175 * yz * (-10.0 * x2 * y2 + 5.0 * x4 + y4),
            0.00427144889505798 * xy * (x2 - y2) * (5197.5 * z2 - 472.5),
            0.00584892228263444
            * y
            * (3.0 * x2 - y2)
            * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z),
            0.0701870673916132
            * xy
            * (
                2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z)
                - 91.875 * z2
                + 13.125
            ),
            0.221950995245231
            * y
            * (
                -2.8 * z * (1.5 - 7.5 * z2)
                + 2.2
                * z
                * (
                    2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
                    + 9.375 * z2
                    - 1.875
                )
                - 4.8 * z
            ),
            -1.48328138624466
            * z
            * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
            + 1.86469659985043
            * z
            * (
                -1.33333333333333 * z * (1.5 * z2 - 0.5)
                + 1.8
                * z
                * (
                    1.75
                    * z
                    * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
                    - 1.125 * z2
                    + 0.375
                )
                + 0.533333333333333 * z
            )
            + 0.953538034014426 * z2
            - 0.317846011338142,
            0.221950995245231
            * x
            * (
                -2.8 * z * (1.5 - 7.5 * z2)
                + 2.2
                * z
                * (
                    2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
                    + 9.375 * z2
                    - 1.875
                )
                - 4.8 * z
            ),
            0.0350935336958066
            * (x2 - y2)
            * (
                2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z)
                - 91.875 * z2
                + 13.125
            ),
            0.00584892228263444
            * x
            * (x2 - 3.0 * y2)
            * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z),
            0.0010678622237645 * (5197.5 * z2 - 472.5) * (-6.0 * x2 * y2 + x4 + y4),
            -2.36661916223175 * xz * (-10.0 * x2 * y2 + x4 + 5.0 * y4),
            0.683184105191914 * x2**3
            + 10.2477615778787 * x2 * y4
            - 10.2477615778787 * x4 * y2
            - 0.683184105191914 * y2**3,
        ],
        -1,
    )


def rsh_cart_7(xyz: torch.Tensor):
    """Computes all real spherical harmonics up to degree 7.

    This is an autogenerated method. See
    https://github.com/cheind/torch-spherical-harmonics
    for more information.

    The Cartesian polynomials below were simplified using
    ``x**2 + y**2 + z**2 == 1`` (for instance Y20 is written as
    ``a * z**2 - b`` with a constant offset), so the input must be
    normalized; for other points the outputs are not harmonic values.

    Params:
        xyz: (N,...,3) tensor of points on the unit sphere. Any number of
            leading dimensions is accepted; components 0, 1 and 2 of the
            last axis are read as x, y and z.

    Returns:
        rsh: (N,...,64) real spherical harmonics
            projections of input. Ynm is found at index
            `n*(n+1) + m`, with `0 <= n <= degree` and
            `-n <= m <= n`. For floating-point ``xyz`` the result has
            the same dtype and device.
    """
    x = xyz[..., 0]
    y = xyz[..., 1]
    z = xyz[..., 2]

    # Monomials shared by several harmonics below.
    x2 = x**2
    y2 = y**2
    z2 = z**2
    xy = x * y
    xz = x * z
    yz = y * z
    x4 = x2**2
    y4 = y2**2

    return torch.stack(
        [
            xyz.new_tensor(0.282094791773878).expand(xyz.shape[:-1]),
            -0.48860251190292 * y,
            0.48860251190292 * z,
            -0.48860251190292 * x,
            1.09254843059208 * xy,
            -1.09254843059208 * yz,
            0.94617469575756 * z2 - 0.31539156525252,
            -1.09254843059208 * xz,
            0.54627421529604 * x2 - 0.54627421529604 * y2,
            -0.590043589926644 * y * (3.0 * x2 - y2),
            2.89061144264055 * xy * z,
            0.304697199642977 * y * (1.5 - 7.5 * z2),
            1.24392110863372 * z * (1.5 * z2 - 0.5) - 0.497568443453487 * z,
            0.304697199642977 * x * (1.5 - 7.5 * z2),
            1.44530572132028 * z * (x2 - y2),
            -0.590043589926644 * x * (x2 - 3.0 * y2),
            2.5033429417967 * xy * (x2 - y2),
            -1.77013076977993 * yz * (3.0 * x2 - y2),
            0.126156626101008 * xy * (52.5 * z2 - 7.5),
            0.267618617422916 * y * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z),
            1.48099765681286
            * z
            * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
            - 0.952069922236839 * z2
            + 0.317356640745613,
            0.267618617422916 * x * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z),
            0.063078313050504 * (x2 - y2) * (52.5 * z2 - 7.5),
            -1.77013076977993 * xz * (x2 - 3.0 * y2),
            -3.75501441269506 * x2 * y2
            + 0.625835735449176 * x4
            + 0.625835735449176 * y4,
            -0.65638205684017 * y * (-10.0 * x2 * y2 + 5.0 * x4 + y4),
            8.30264925952416 * xy * z * (x2 - y2),
            0.00931882475114763 * y * (52.5 - 472.5 * z2) * (3.0 * x2 - y2),
            0.0913054625709205 * xy * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z),
            0.241571547304372
            * y
            * (
                2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
                + 9.375 * z2
                - 1.875
            ),
            -1.24747010616985 * z * (1.5 * z2 - 0.5)
            + 1.6840846433293
            * z
            * (
                1.75
                * z
                * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
                - 1.125 * z2
                + 0.375
            )
            + 0.498988042467941 * z,
            0.241571547304372
            * x
            * (
                2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
                + 9.375 * z2
                - 1.875
            ),
            0.0456527312854602 * (x2 - y2) * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z),
            0.00931882475114763 * x * (52.5 - 472.5 * z2) * (x2 - 3.0 * y2),
            2.07566231488104 * z * (-6.0 * x2 * y2 + x4 + y4),
            -0.65638205684017 * x * (-10.0 * x2 * y2 + x4 + 5.0 * y4),
            4.09910463115149 * x**4 * xy
            - 13.6636821038383 * xy**3
            + 4.09910463115149 * xy * y**4,
            -2.36661916223175 * yz * (-10.0 * x2 * y2 + 5.0 * x4 + y4),
            0.00427144889505798 * xy * (x2 - y2) * (5197.5 * z2 - 472.5),
            0.00584892228263444
            * y
            * (3.0 * x2 - y2)
            * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z),
            0.0701870673916132
            * xy
            * (
                2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z)
                - 91.875 * z2
                + 13.125
            ),
            0.221950995245231
            * y
            * (
                -2.8 * z * (1.5 - 7.5 * z2)
                + 2.2
                * z
                * (
                    2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
                    + 9.375 * z2
                    - 1.875
                )
                - 4.8 * z
            ),
            -1.48328138624466
            * z
            * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
            + 1.86469659985043
            * z
            * (
                -1.33333333333333 * z * (1.5 * z2 - 0.5)
                + 1.8
                * z
                * (
                    1.75
                    * z
                    * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
                    - 1.125 * z2
                    + 0.375
                )
                + 0.533333333333333 * z
            )
            + 0.953538034014426 * z2
            - 0.317846011338142,
            0.221950995245231
            * x
            * (
                -2.8 * z * (1.5 - 7.5 * z2)
                + 2.2
                * z
                * (
                    2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
                    + 9.375 * z2
                    - 1.875
                )
                - 4.8 * z
            ),
            0.0350935336958066
            * (x2 - y2)
            * (
                2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z)
                - 91.875 * z2
                + 13.125
            ),
            0.00584892228263444
            * x
            * (x2 - 3.0 * y2)
            * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z),
            0.0010678622237645 * (5197.5 * z2 - 472.5) * (-6.0 * x2 * y2 + x4 + y4),
            -2.36661916223175 * xz * (-10.0 * x2 * y2 + x4 + 5.0 * y4),
            0.683184105191914 * x2**3
            + 10.2477615778787 * x2 * y4
            - 10.2477615778787 * x4 * y2
            - 0.683184105191914 * y2**3,
            -0.707162732524596
            * y
            * (7.0 * x2**3 + 21.0 * x2 * y4 - 35.0 * x4 * y2 - y2**3),
            2.6459606618019 * z * (6.0 * x**4 * xy - 20.0 * xy**3 + 6.0 * xy * y**4),
            9.98394571852353e-5
            * y
            * (5197.5 - 67567.5 * z2)
            * (-10.0 * x2 * y2 + 5.0 * x4 + y4),
            0.00239614697244565
            * xy
            * (x2 - y2)
            * (4.33333333333333 * z * (5197.5 * z2 - 472.5) - 3150.0 * z),
            0.00397356022507413
            * y
            * (3.0 * x2 - y2)
            * (
                3.25 * z * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z)
                + 1063.125 * z2
                - 118.125
            ),
            0.0561946276120613
            * xy
            * (
                -4.8 * z * (52.5 * z2 - 7.5)
                + 2.6
                * z
                * (
                    2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z)
                    - 91.875 * z2
                    + 13.125
                )
                + 48.0 * z
            ),
            0.206472245902897
            * y
            * (
                -2.625 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
                + 2.16666666666667
                * z
                * (
                    -2.8 * z * (1.5 - 7.5 * z2)
                    + 2.2
                    * z
                    * (
                        2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
                        + 9.375 * z2
                        - 1.875
                    )
                    - 4.8 * z
                )
                - 10.9375 * z2
                + 2.1875
            ),
            1.24862677781952 * z * (1.5 * z2 - 0.5)
            - 1.68564615005635
            * z
            * (
                1.75
                * z
                * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
                - 1.125 * z2
                + 0.375
            )
            + 2.02901851395672
            * z
            * (
                -1.45833333333333
                * z
                * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
                + 1.83333333333333
                * z
                * (
                    -1.33333333333333 * z * (1.5 * z2 - 0.5)
                    + 1.8
                    * z
                    * (
                        1.75
                        * z
                        * (
                            1.66666666666667 * z * (1.5 * z2 - 0.5)
                            - 0.666666666666667 * z
                        )
                        - 1.125 * z2
                        + 0.375
                    )
                    + 0.533333333333333 * z
                )
                + 0.9375 * z2
                - 0.3125
            )
            - 0.499450711127808 * z,
            0.206472245902897
            * x
            * (
                -2.625 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
                + 2.16666666666667
                * z
                * (
                    -2.8 * z * (1.5 - 7.5 * z2)
                    + 2.2
                    * z
                    * (
                        2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
                        + 9.375 * z2
                        - 1.875
                    )
                    - 4.8 * z
                )
                - 10.9375 * z2
                + 2.1875
            ),
            0.0280973138060306
            * (x2 - y2)
            * (
                -4.8 * z * (52.5 * z2 - 7.5)
                + 2.6
                * z
                * (
                    2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z)
                    - 91.875 * z2
                    + 13.125
                )
                + 48.0 * z
            ),
            0.00397356022507413
            * x
            * (x2 - 3.0 * y2)
            * (
                3.25 * z * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z)
                + 1063.125 * z2
                - 118.125
            ),
            0.000599036743111412
            * (4.33333333333333 * z * (5197.5 * z2 - 472.5) - 3150.0 * z)
            * (-6.0 * x2 * y2 + x4 + y4),
            9.98394571852353e-5
            * x
            * (5197.5 - 67567.5 * z2)
            * (-10.0 * x2 * y2 + x4 + 5.0 * y4),
            2.6459606618019 * z * (x2**3 + 15.0 * x2 * y4 - 15.0 * x4 * y2 - y2**3),
            -0.707162732524596
            * x
            * (x2**3 + 35.0 * x2 * y4 - 21.0 * x4 * y2 - 7.0 * y2**3),
        ],
        -1,
    )


# @torch.jit.script
def rsh_cart_8(xyz: torch.Tensor):
    """Computes all real spherical harmonics up to degree 8.

    This is an autogenerated method. See
    https://github.com/cheind/torch-spherical-harmonics
    for more information.

    The constant degree-0 term is created with ``xyz.new_tensor`` (as in the
    lower-degree variants), so the result has the dtype and device of ``xyz``.
    A single point of shape ``(3,)`` is also accepted.

    Params:
        xyz: (N,...,3) tensor of points on the unit sphere

    Returns:
        rsh: (N,...,81) real spherical harmonics
            projections of input. Ynm is found at index
            `n*(n+1) + m`, with `0 <= n <= degree` and
            `-n <= m <= n`.
    """
    x = xyz[..., 0]
    y = xyz[..., 1]
    z = xyz[..., 2]

    x2 = x**2
    y2 = y**2
    z2 = z**2
    xy = x * y
    xz = x * z
    yz = y * z
    x4 = x2**2
    y4 = y2**2
    return torch.stack(
        [
            # Scalar in xyz's dtype/device; expanding a 0-dim tensor also works
            # when xyz.shape[:-1] is empty.
            xyz.new_tensor(0.282094791773878).expand(xyz.shape[:-1]),
            -0.48860251190292 * y,
            0.48860251190292 * z,
            -0.48860251190292 * x,
            1.09254843059208 * xy,
            -1.09254843059208 * yz,
            0.94617469575756 * z2 - 0.31539156525252,
            -1.09254843059208 * xz,
            0.54627421529604 * x2 - 0.54627421529604 * y2,
            -0.590043589926644 * y * (3.0 * x2 - y2),
            2.89061144264055 * xy * z,
            0.304697199642977 * y * (1.5 - 7.5 * z2),
            1.24392110863372 * z * (1.5 * z2 - 0.5) - 0.497568443453487 * z,
            0.304697199642977 * x * (1.5 - 7.5 * z2),
            1.44530572132028 * z * (x2 - y2),
            -0.590043589926644 * x * (x2 - 3.0 * y2),
            2.5033429417967 * xy * (x2 - y2),
            -1.77013076977993 * yz * (3.0 * x2 - y2),
            0.126156626101008 * xy * (52.5 * z2 - 7.5),
            0.267618617422916 * y * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z),
            1.48099765681286
            * z
            * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
            - 0.952069922236839 * z2
            + 0.317356640745613,
            0.267618617422916 * x * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z),
            0.063078313050504 * (x2 - y2) * (52.5 * z2 - 7.5),
            -1.77013076977993 * xz * (x2 - 3.0 * y2),
            -3.75501441269506 * x2 * y2
            + 0.625835735449176 * x4
            + 0.625835735449176 * y4,
            -0.65638205684017 * y * (-10.0 * x2 * y2 + 5.0 * x4 + y4),
            8.30264925952416 * xy * z * (x2 - y2),
            0.00931882475114763 * y * (52.5 - 472.5 * z2) * (3.0 * x2 - y2),
            0.0913054625709205 * xy * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z),
            0.241571547304372
            * y
            * (
                2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
                + 9.375 * z2
                - 1.875
            ),
            -1.24747010616985 * z * (1.5 * z2 - 0.5)
            + 1.6840846433293
            * z
            * (
                1.75
                * z
                * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
                - 1.125 * z2
                + 0.375
            )
            + 0.498988042467941 * z,
            0.241571547304372
            * x
            * (
                2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
                + 9.375 * z2
                - 1.875
            ),
            0.0456527312854602 * (x2 - y2) * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z),
            0.00931882475114763 * x * (52.5 - 472.5 * z2) * (x2 - 3.0 * y2),
            2.07566231488104 * z * (-6.0 * x2 * y2 + x4 + y4),
            -0.65638205684017 * x * (-10.0 * x2 * y2 + x4 + 5.0 * y4),
            4.09910463115149 * x**4 * xy
            - 13.6636821038383 * xy**3
            + 4.09910463115149 * xy * y**4,
            -2.36661916223175 * yz * (-10.0 * x2 * y2 + 5.0 * x4 + y4),
            0.00427144889505798 * xy * (x2 - y2) * (5197.5 * z2 - 472.5),
            0.00584892228263444
            * y
            * (3.0 * x2 - y2)
            * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z),
            0.0701870673916132
            * xy
            * (
                2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z)
                - 91.875 * z2
                + 13.125
            ),
            0.221950995245231
            * y
            * (
                -2.8 * z * (1.5 - 7.5 * z2)
                + 2.2
                * z
                * (
                    2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
                    + 9.375 * z2
                    - 1.875
                )
                - 4.8 * z
            ),
            -1.48328138624466
            * z
            * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
            + 1.86469659985043
            * z
            * (
                -1.33333333333333 * z * (1.5 * z2 - 0.5)
                + 1.8
                * z
                * (
                    1.75
                    * z
                    * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
                    - 1.125 * z2
                    + 0.375
                )
                + 0.533333333333333 * z
            )
            + 0.953538034014426 * z2
            - 0.317846011338142,
            0.221950995245231
            * x
            * (
                -2.8 * z * (1.5 - 7.5 * z2)
                + 2.2
                * z
                * (
                    2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
                    + 9.375 * z2
                    - 1.875
                )
                - 4.8 * z
            ),
            0.0350935336958066
            * (x2 - y2)
            * (
                2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z)
                - 91.875 * z2
                + 13.125
            ),
            0.00584892228263444
            * x
            * (x2 - 3.0 * y2)
            * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z),
            0.0010678622237645 * (5197.5 * z2 - 472.5) * (-6.0 * x2 * y2 + x4 + y4),
            -2.36661916223175 * xz * (-10.0 * x2 * y2 + x4 + 5.0 * y4),
            0.683184105191914 * x2**3
            + 10.2477615778787 * x2 * y4
            - 10.2477615778787 * x4 * y2
            - 0.683184105191914 * y2**3,
            -0.707162732524596
            * y
            * (7.0 * x2**3 + 21.0 * x2 * y4 - 35.0 * x4 * y2 - y2**3),
            2.6459606618019 * z * (6.0 * x**4 * xy - 20.0 * xy**3 + 6.0 * xy * y**4),
            9.98394571852353e-5
            * y
            * (5197.5 - 67567.5 * z2)
            * (-10.0 * x2 * y2 + 5.0 * x4 + y4),
            0.00239614697244565
            * xy
            * (x2 - y2)
            * (4.33333333333333 * z * (5197.5 * z2 - 472.5) - 3150.0 * z),
            0.00397356022507413
            * y
            * (3.0 * x2 - y2)
            * (
                3.25 * z * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z)
                + 1063.125 * z2
                - 118.125
            ),
            0.0561946276120613
            * xy
            * (
                -4.8 * z * (52.5 * z2 - 7.5)
                + 2.6
                * z
                * (
                    2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z)
                    - 91.875 * z2
                    + 13.125
                )
                + 48.0 * z
            ),
            0.206472245902897
            * y
            * (
                -2.625 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
                + 2.16666666666667
                * z
                * (
                    -2.8 * z * (1.5 - 7.5 * z2)
                    + 2.2
                    * z
                    * (
                        2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
                        + 9.375 * z2
                        - 1.875
                    )
                    - 4.8 * z
                )
                - 10.9375 * z2
                + 2.1875
            ),
            1.24862677781952 * z * (1.5 * z2 - 0.5)
            - 1.68564615005635
            * z
            * (
                1.75
                * z
                * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
                - 1.125 * z2
                + 0.375
            )
            + 2.02901851395672
            * z
            * (
                -1.45833333333333
                * z
                * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
                + 1.83333333333333
                * z
                * (
                    -1.33333333333333 * z * (1.5 * z2 - 0.5)
                    + 1.8
                    * z
                    * (
                        1.75
                        * z
                        * (
                            1.66666666666667 * z * (1.5 * z2 - 0.5)
                            - 0.666666666666667 * z
                        )
                        - 1.125 * z2
                        + 0.375
                    )
                    + 0.533333333333333 * z
                )
                + 0.9375 * z2
                - 0.3125
            )
            - 0.499450711127808 * z,
            0.206472245902897
            * x
            * (
                -2.625 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
                + 2.16666666666667
                * z
                * (
                    -2.8 * z * (1.5 - 7.5 * z2)
                    + 2.2
                    * z
                    * (
                        2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
                        + 9.375 * z2
                        - 1.875
                    )
                    - 4.8 * z
                )
                - 10.9375 * z2
                + 2.1875
            ),
            0.0280973138060306
            * (x2 - y2)
            * (
                -4.8 * z * (52.5 * z2 - 7.5)
                + 2.6
                * z
                * (
                    2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z)
                    - 91.875 * z2
                    + 13.125
                )
                + 48.0 * z
            ),
            0.00397356022507413
            * x
            * (x2 - 3.0 * y2)
            * (
                3.25 * z * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z)
                + 1063.125 * z2
                - 118.125
            ),
            0.000599036743111412
            * (4.33333333333333 * z * (5197.5 * z2 - 472.5) - 3150.0 * z)
            * (-6.0 * x2 * y2 + x4 + y4),
            9.98394571852353e-5
            * x
            * (5197.5 - 67567.5 * z2)
            * (-10.0 * x2 * y2 + x4 + 5.0 * y4),
            2.6459606618019 * z * (x2**3 + 15.0 * x2 * y4 - 15.0 * x4 * y2 - y2**3),
            -0.707162732524596
            * x
            * (x2**3 + 35.0 * x2 * y4 - 21.0 * x4 * y2 - 7.0 * y2**3),
            5.83141328139864 * xy * (x2**3 + 7.0 * x2 * y4 - 7.0 * x4 * y2 - y2**3),
            -2.91570664069932
            * yz
            * (7.0 * x2**3 + 21.0 * x2 * y4 - 35.0 * x4 * y2 - y2**3),
            7.87853281621404e-6
            * (1013512.5 * z2 - 67567.5)
            * (6.0 * x**4 * xy - 20.0 * xy**3 + 6.0 * xy * y**4),
            5.10587282657803e-5
            * y
            * (5.0 * z * (5197.5 - 67567.5 * z2) + 41580.0 * z)
            * (-10.0 * x2 * y2 + 5.0 * x4 + y4),
            0.00147275890257803
            * xy
            * (x2 - y2)
            * (
                3.75 * z * (4.33333333333333 * z * (5197.5 * z2 - 472.5) - 3150.0 * z)
                - 14293.125 * z2
                + 1299.375
            ),
            0.0028519853513317
            * y
            * (3.0 * x2 - y2)
            * (
                -7.33333333333333 * z * (52.5 - 472.5 * z2)
                + 3.0
                * z
                * (
                    3.25 * z * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z)
                    + 1063.125 * z2
                    - 118.125
                )
                - 560.0 * z
            ),
            0.0463392770473559
            * xy
            * (
                -4.125 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z)
                + 2.5
                * z
                * (
                    -4.8 * z * (52.5 * z2 - 7.5)
                    + 2.6
                    * z
                    * (
                        2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z)
                        - 91.875 * z2
                        + 13.125
                    )
                    + 48.0 * z
                )
                + 137.8125 * z2
                - 19.6875
            ),
            0.193851103820053
            * y
            * (
                3.2 * z * (1.5 - 7.5 * z2)
                - 2.51428571428571
                * z
                * (
                    2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
                    + 9.375 * z2
                    - 1.875
                )
                + 2.14285714285714
                * z
                * (
                    -2.625 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
                    + 2.16666666666667
                    * z
                    * (
                        -2.8 * z * (1.5 - 7.5 * z2)
                        + 2.2
                        * z
                        * (
                            2.25
                            * z
                            * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
                            + 9.375 * z2
                            - 1.875
                        )
                        - 4.8 * z
                    )
                    - 10.9375 * z2
                    + 2.1875
                )
                + 5.48571428571429 * z
            ),
            1.48417251362228
            * z
            * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
            - 1.86581687426801
            * z
            * (
                -1.33333333333333 * z * (1.5 * z2 - 0.5)
                + 1.8
                * z
                * (
                    1.75
                    * z
                    * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
                    - 1.125 * z2
                    + 0.375
                )
                + 0.533333333333333 * z
            )
            + 2.1808249179756
            * z
            * (
                1.14285714285714 * z * (1.5 * z2 - 0.5)
                - 1.54285714285714
                * z
                * (
                    1.75
                    * z
                    * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
                    - 1.125 * z2
                    + 0.375
                )
                + 1.85714285714286
                * z
                * (
                    -1.45833333333333
                    * z
                    * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
                    + 1.83333333333333
                    * z
                    * (
                        -1.33333333333333 * z * (1.5 * z2 - 0.5)
                        + 1.8
                        * z
                        * (
                            1.75
                            * z
                            * (
                                1.66666666666667 * z * (1.5 * z2 - 0.5)
                                - 0.666666666666667 * z
                            )
                            - 1.125 * z2
                            + 0.375
                        )
                        + 0.533333333333333 * z
                    )
                    + 0.9375 * z2
                    - 0.3125
                )
                - 0.457142857142857 * z
            )
            - 0.954110901614325 * z2
            + 0.318036967204775,
            0.193851103820053
            * x
            * (
                3.2 * z * (1.5 - 7.5 * z2)
                - 2.51428571428571
                * z
                * (
                    2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
                    + 9.375 * z2
                    - 1.875
                )
                + 2.14285714285714
                * z
                * (
                    -2.625 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
                    + 2.16666666666667
                    * z
                    * (
                        -2.8 * z * (1.5 - 7.5 * z2)
                        + 2.2
                        * z
                        * (
                            2.25
                            * z
                            * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
                            + 9.375 * z2
                            - 1.875
                        )
                        - 4.8 * z
                    )
                    - 10.9375 * z2
                    + 2.1875
                )
                + 5.48571428571429 * z
            ),
            0.0231696385236779
            * (x2 - y2)
            * (
                -4.125 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z)
                + 2.5
                * z
                * (
                    -4.8 * z * (52.5 * z2 - 7.5)
                    + 2.6
                    * z
                    * (
                        2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z)
                        - 91.875 * z2
                        + 13.125
                    )
                    + 48.0 * z
                )
                + 137.8125 * z2
                - 19.6875
            ),
            0.0028519853513317
            * x
            * (x2 - 3.0 * y2)
            * (
                -7.33333333333333 * z * (52.5 - 472.5 * z2)
                + 3.0
                * z
                * (
                    3.25 * z * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z)
                    + 1063.125 * z2
                    - 118.125
                )
                - 560.0 * z
            ),
            0.000368189725644507
            * (-6.0 * x2 * y2 + x4 + y4)
            * (
                3.75 * z * (4.33333333333333 * z * (5197.5 * z2 - 472.5) - 3150.0 * z)
                - 14293.125 * z2
                + 1299.375
            ),
            5.10587282657803e-5
            * x
            * (5.0 * z * (5197.5 - 67567.5 * z2) + 41580.0 * z)
            * (-10.0 * x2 * y2 + x4 + 5.0 * y4),
            7.87853281621404e-6
            * (1013512.5 * z2 - 67567.5)
            * (x2**3 + 15.0 * x2 * y4 - 15.0 * x4 * y2 - y2**3),
            -2.91570664069932
            * xz
            * (x2**3 + 35.0 * x2 * y4 - 21.0 * x4 * y2 - 7.0 * y2**3),
            -20.4099464848952 * x2**3 * y2
            - 20.4099464848952 * x2 * y2**3
            + 0.72892666017483 * x4**2
            + 51.0248662122381 * x4 * y4
            + 0.72892666017483 * y4**2,
        ],
        -1,
    )


# Public API of this module: the closed-form real spherical harmonics up to
# degree 8, plus the complex spherical-harmonics module defined further below.
__all__ = [
    "rsh_cart_0",
    "rsh_cart_1",
    "rsh_cart_2",
    "rsh_cart_3",
    "rsh_cart_4",
    "rsh_cart_5",
    "rsh_cart_6",
    "rsh_cart_7",
    "rsh_cart_8",
    "SphHarm",
]


from typing import Optional

import torch


class SphHarm(torch.nn.Module):
    def __init__(self, m, n, dtype=torch.float32) -> None:
        super().__init__()
        self.dtype = dtype
        m = torch.tensor(list(range(-m + 1, m)))
        n = torch.tensor(list(range(n)))
        self.is_normalized = False
        vals = torch.cartesian_prod(m, n).T
        vals = vals[:, vals[0] <= vals[1]]
        m, n = vals.unbind(0)

        self.register_buffer("m", tensor=m)
        self.register_buffer("n", tensor=n)
        self.register_buffer("l_max", tensor=torch.max(self.n))

        f_a, f_b, initial_value, d0_mask_3d, d1_mask_3d = self._init_legendre()
        self.register_buffer("f_a", tensor=f_a)
        self.register_buffer("f_b", tensor=f_b)
        self.register_buffer("d0_mask_3d", tensor=d0_mask_3d)
        self.register_buffer("d1_mask_3d", tensor=d1_mask_3d)
        self.register_buffer("initial_value", tensor=initial_value)

    @property
    def device(self):
        return next(self.buffers()).device

    def forward(self, points: torch.Tensor) -> torch.Tensor:
        """Computes the spherical harmonics."""
        # Y_l^m = (-1) ^ m c_l^m P_l^m(cos(theta)) exp(i m phi)
        B, N, D = points.shape
        dtype = points.dtype
        theta, phi = points.view(-1, D).to(self.dtype).unbind(-1)
        cos_colatitude = torch.cos(phi)
        legendre = self._gen_associated_legendre(cos_colatitude)
        vals = torch.stack([self.m.abs(), self.n], dim=0)
        vals = torch.cat(
            [
                vals.repeat(1, theta.shape[0]),
                torch.arange(theta.shape[0], device=theta.device)
                .unsqueeze(0)
                .repeat_interleave(vals.shape[1], dim=1),
            ],
            dim=0,
        )
        legendre_vals = legendre[vals[0], vals[1], vals[2]]
        legendre_vals = legendre_vals.reshape(-1, theta.shape[0])
        angle = torch.outer(self.m.abs(), theta)
        vandermonde = torch.complex(torch.cos(angle), torch.sin(angle))
        harmonics = torch.complex(
            legendre_vals * torch.real(vandermonde),
            legendre_vals * torch.imag(vandermonde),
        )

        # Negative order.
        m = self.m.unsqueeze(-1)
        harmonics = torch.where(
            m < 0, (-1.0) ** m.abs() * torch.conj(harmonics), harmonics
        )
        harmonics = harmonics.permute(1, 0).reshape(B, N, -1).to(dtype)
        return harmonics

    def _gen_recurrence_mask(self) -> tuple[torch.Tensor, torch.Tensor]:
        """Generates mask for recurrence relation on the remaining entries.

        The remaining entries are with respect to the diagonal and offdiagonal
        entries.

        Args:
        l_max: see `gen_normalized_legendre`.
        Returns:
        torch.Tensors representing the mask used by the recurrence relations.
        """

        # Computes all coefficients.
        m_mat, l_mat = torch.meshgrid(
            torch.arange(0, self.l_max + 1, device=self.device, dtype=self.dtype),
            torch.arange(0, self.l_max + 1, device=self.device, dtype=self.dtype),
            indexing="ij",
        )
        if self.is_normalized:
            c0 = l_mat * l_mat
            c1 = m_mat * m_mat
            c2 = 2.0 * l_mat
            c3 = (l_mat - 1.0) * (l_mat - 1.0)
            d0 = torch.sqrt((4.0 * c0 - 1.0) / (c0 - c1))
            d1 = torch.sqrt(((c2 + 1.0) * (c3 - c1)) / ((c2 - 3.0) * (c0 - c1)))
        else:
            d0 = (2.0 * l_mat - 1.0) / (l_mat - m_mat)
            d1 = (l_mat + m_mat - 1.0) / (l_mat - m_mat)

        d0_mask_indices = torch.triu_indices(self.l_max + 1, 1)
        d1_mask_indices = torch.triu_indices(self.l_max + 1, 2)

        d_zeros = torch.zeros(
            (self.l_max + 1, self.l_max + 1), dtype=self.dtype, device=self.device
        )
        d_zeros[d0_mask_indices] = d0[d0_mask_indices]
        d0_mask = d_zeros

        d_zeros = torch.zeros(
            (self.l_max + 1, self.l_max + 1), dtype=self.dtype, device=self.device
        )
        d_zeros[d1_mask_indices] = d1[d1_mask_indices]
        d1_mask = d_zeros

        # Creates a 3D mask that contains 1s on the diagonal plane and 0s elsewhere.
        i = torch.arange(self.l_max + 1, device=self.device)[:, None, None]
        j = torch.arange(self.l_max + 1, device=self.device)[None, :, None]
        k = torch.arange(self.l_max + 1, device=self.device)[None, None, :]
        mask = (i + j - k == 0).to(self.dtype)
        d0_mask_3d = torch.einsum("jk,ijk->ijk", d0_mask, mask)
        d1_mask_3d = torch.einsum("jk,ijk->ijk", d1_mask, mask)
        return (d0_mask_3d, d1_mask_3d)

    def _recursive(self, i: int, p_val: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        coeff_0 = self.d0_mask_3d[i]
        coeff_1 = self.d1_mask_3d[i]
        h = torch.einsum(
            "ij,ijk->ijk",
            coeff_0,
            torch.einsum("ijk,k->ijk", torch.roll(p_val, shifts=1, dims=1), x),
        ) - torch.einsum("ij,ijk->ijk", coeff_1, torch.roll(p_val, shifts=2, dims=1))
        p_val = p_val + h
        return p_val

    def _init_legendre(self):
        a_idx = torch.arange(1, self.l_max + 1, dtype=self.dtype, device=self.device)
        b_idx = torch.arange(self.l_max, dtype=self.dtype, device=self.device)
        if self.is_normalized:
            # The initial value p(0,0).
            initial_value: torch.Tensor = torch.tensor(
                0.5 / (torch.pi**0.5), device=self.device
            )
            f_a = torch.cumprod(-1 * torch.sqrt(1.0 + 0.5 / a_idx), dim=0)
            f_b = torch.sqrt(2.0 * b_idx + 3.0)
        else:
            # The initial value p(0,0).
            initial_value = torch.tensor(1.0, device=self.device)
            f_a = torch.cumprod(1.0 - 2.0 * a_idx, dim=0)
            f_b = 2.0 * b_idx + 1.0

        d0_mask_3d, d1_mask_3d = self._gen_recurrence_mask()
        return f_a, f_b, initial_value, d0_mask_3d, d1_mask_3d

    def _gen_associated_legendre(self, x: torch.Tensor) -> torch.Tensor:
        r"""Computes associated Legendre functions (ALFs) of the first kind.

        The ALFs of the first kind are used in spherical harmonics. The spherical
        harmonic of degree `l` and order `m` can be written as
        `Y_l^m(θ, φ) = N_l^m * P_l^m(cos(θ)) * exp(i m φ)`, where `N_l^m` is the
        normalization factor and θ and φ are the colatitude and longitude,
        repectively. `N_l^m` is chosen in the way that the spherical harmonics form
        a set of orthonormal basis function of L^2(S^2). For the computational
        efficiency of spherical harmonics transform, the normalization factor is
        used in the computation of the ALFs. In addition, normalizing `P_l^m`
        avoids overflow/underflow and achieves better numerical stability. Three
        recurrence relations are used in the computation.

        Args:
        l_max: The maximum degree of the associated Legendre function. Both the
            degrees and orders are `[0, 1, 2, ..., l_max]`.
        x: A vector of type `float32`, `float64` containing the sampled points in
            spherical coordinates, at which the ALFs are computed; `x` is essentially
            `cos(θ)`. For the numerical integration used by the spherical harmonics
            transforms, `x` contains the quadrature points in the interval of
            `[-1, 1]`. There are several approaches to provide the quadrature points:
            Gauss-Legendre method (`scipy.special.roots_legendre`), Gauss-Chebyshev
            method (`scipy.special.roots_chebyu`), and Driscoll & Healy
            method (Driscoll, James R., and Dennis M. Healy. "Computing Fourier
            transforms and convolutions on the 2-sphere." Advances in applied
            mathematics 15, no. 2 (1994): 202-250.). The Gauss-Legendre quadrature
            points are nearly equal-spaced along θ and provide exact discrete
            orthogonality, (P^m)^T W P_m = I, where `T` represents the transpose
            operation, `W` is a diagonal matrix containing the quadrature weights,
            and `I` is the identity matrix. The Gauss-Chebyshev points are equally
            spaced, which only provide approximate discrete orthogonality. The
            Driscoll & Healy qudarture points are equally spaced and provide the
            exact discrete orthogonality. The number of sampling points is required to
            be twice as the number of frequency points (modes) in the Driscoll & Healy
            approach, which enables FFT and achieves a fast spherical harmonics
            transform.
        is_normalized: True if the associated Legendre functions are normalized.
            With normalization, `N_l^m` is applied such that the spherical harmonics
            form a set of orthonormal basis functions of L^2(S^2).

        Returns:
        The 3D array of shape `(l_max + 1, l_max + 1, len(x))` containing the values
        of the ALFs at `x`; the dimensions in the sequence of order, degree, and
        evalution points.
        """
        p = torch.zeros(
            (self.l_max + 1, self.l_max + 1, x.shape[0]), dtype=x.dtype, device=x.device
        )
        p[0, 0] = self.initial_value

        # Compute the diagonal entries p(l,l) with recurrence.
        y = torch.cumprod(
            torch.broadcast_to(torch.sqrt(1.0 - x * x), (self.l_max, x.shape[0])), dim=0
        )
        p_diag = self.initial_value * torch.einsum("i,ij->ij", self.f_a, y)
        # torch.diag_indices(l_max + 1)
        diag_indices = torch.stack(
            [torch.arange(0, self.l_max + 1, device=x.device)] * 2, dim=0
        )
        p[(diag_indices[0][1:], diag_indices[1][1:])] = p_diag

        diag_indices = torch.stack(
            [torch.arange(0, self.l_max, device=x.device)] * 2, dim=0
        )

        # Compute the off-diagonal entries with recurrence.
        p_offdiag = torch.einsum(
            "ij,ij->ij",
            torch.einsum("i,j->ij", self.f_b, x),
            p[(diag_indices[0], diag_indices[1])],
        )  # p[torch.diag_indices(l_max)])
        p[(diag_indices[0][: self.l_max], diag_indices[1][: self.l_max] + 1)] = (
            p_offdiag
        )

        # Compute the remaining entries with recurrence.
        if self.l_max > 1:
            for i in range(2, self.l_max + 1):
                p = self._recursive(i, p, x)
        return p
