r"""Spherical Harmonics as polynomials of x, y, z
"""
from typing import Union, List, Any

import math

import torch
import numpy as np

from e3nn import o3

import scipy

from utils import cartesian_to_spherical__numpy


def complex_spherical_harmonics(l: Union[int, List[int], str, o3.Irreps],
                                x: np.ndarray,
                                normalization: str = 'integral') -> np.ndarray:
    '''
    Evaluate complex spherical harmonics of the requested degrees with scipy.

    uses numpy, so this is non-differentiable

    Parameters
    ----------
    l : int, list of int, str or o3.Irreps
        Degrees to evaluate. A string is parsed with ``o3.Irreps``; for an
        ``o3.Irreps`` every degree is repeated according to its multiplicity
        (the parity is ignored). Any other non-int value is consumed as an
        iterable of degrees.
    x : np.ndarray
        Cartesian points, handed unchanged to ``cartesian_to_spherical__numpy``.
        Presumably of shape (N, 3) -- that helper is defined outside this
        function, confirm its contract there.
    normalization : {'integral', 'component', 'norm'}
        Follows the e3nn naming:
        'integral'  -- orthonormal on the sphere (scipy's native convention),
        'component' -- ``sum_m |Y_l^m|^2 == 2l + 1`` for every degree,
        'norm'      -- ``sum_m |Y_l^m|^2 == 1`` for every degree.

    Returns
    -------
    np.ndarray
        Complex array with one row per point and ``2l + 1`` columns
        (m = -l, ..., l) for every requested degree, concatenated in the
        order the degrees were requested.

    Raises
    ------
    ValueError
        If ``normalization`` is not one of the supported names or if no
        degree is requested.
    '''
    if normalization not in ('integral', 'component', 'norm'):
        raise ValueError(
            "normalization must be 'integral', 'component' or 'norm', got %r" % (normalization,)
        )

    irreps_out = l
    if isinstance(irreps_out, str):
        irreps_out = o3.Irreps(irreps_out)

    if isinstance(irreps_out, o3.Irreps):
        degrees = []
        for mul, (degree, _parity) in irreps_out:
            degrees.extend([degree] * mul)
    elif isinstance(irreps_out, int):
        degrees = [irreps_out]
    else:
        degrees = list(irreps_out)

    if not degrees:
        raise ValueError('at least one degree l must be requested')

    lmax = max(degrees)

    # Convert to spherical coordinates. Only the two angle columns are used:
    # scipy evaluates the harmonics from angles alone, so the radius (column 0)
    # is ignored and no unnormalized version can be requested.
    rtp = cartesian_to_spherical__numpy(x)
    polar, azimuth = rtp[:, 1], rtp[:, 2]

    # All degrees 0..lmax, laid out so that degree d occupies columns d^2 .. (d+1)^2 - 1.
    sh = _complex_spherical_harmonics(lmax, polar, azimuth)

    # Pick (and possibly repeat / reorder) the requested degrees unless the
    # request is exactly 0, 1, ..., lmax, in which case the layout already matches.
    if degrees != list(range(lmax + 1)):
        sh = np.hstack([
            sh[..., d * d:(d + 1) * (d + 1)]
            for d in degrees
        ])

    # scipy's harmonics are already orthonormal on the sphere, i.e. they are
    # 'integral'-normalized with sum_m |Y_l^m|^2 = (2l + 1) / (4 pi).
    # Rescaling is therefore only needed for the other two conventions.
    if normalization == 'component':
        sh = sh * math.sqrt(4 * math.pi)
    elif normalization == 'norm':
        sh = sh * np.hstack([
            np.full(2 * d + 1, math.sqrt(4 * math.pi / (2 * d + 1)))
            for d in degrees
        ])
    # if normalization == 'integral' nothing happens

    return sh

def _complex_spherical_harmonics(lmax, t, p):
    # scipy's implementation of the complex spherical harmonics returns a conveniently formatted output for fixed l

    ret_val = []
    for l in range(lmax+1):
        ms = np.arange(-l, l+1)
        ret_val.append(scipy.special.sph_harm(ms,l,p.reshape(-1, 1),t.reshape(-1, 1))) # scipy uses the opposite convention for theta and phi
    return np.hstack(ret_val)