import numpy as np
from scipy.special import roots_jacobi

def _polar_grid(num_radial, num_angular, quadrature="rectangular"):
    """Build a tensor-product polar quadrature grid, in Cartesian coordinates.

    Parameters
    ----------
    num_radial : int
        Number of radial nodes.
    num_angular : int
        Number of equispaced angular nodes in [0, 2*pi).
    quadrature : str
        Radial rule. ``"rectangular"`` uses equispaced radii in [0, 1) with
        weights ``r / num_radial``. ``"gauss-jacobi"`` uses Gauss-Jacobi
        nodes (alpha=0, beta=1) mapped from [-1, 1] onto [0, 1].

    Returns
    -------
    polar_grid : ndarray, shape (num_radial, num_angular, 2)
        ``(x, y)`` coordinates of every node, ``x = r cos(t)``,
        ``y = r sin(t)``.
    polar_weights : ndarray, shape (num_radial, num_angular)
        Product of the radial and angular weights at every node.

    Raises
    ------
    ValueError
        If ``quadrature`` is not one of the two supported rules.
    """
    if quadrature not in ("rectangular", "gauss-jacobi"):
        raise ValueError(f"Unknown quadrature rule: {quadrature}")

    if quadrature == "gauss-jacobi":
        nodes, node_weights = roots_jacobi(num_radial, 0, 1)

        # Affine map of the nodes from [-1, 1] to [0, 1].
        radii = (nodes + 1) / 2

        # One factor of two comes from dx = 2 dr, the other from the weight
        # function (1 + x) = 2 r, which is rescaled as well.
        radii_weights = node_weights / (2 * 2)
    else:
        radii = np.linspace(0, 1, num_radial, endpoint=False)
        radii_weights = radii / num_radial

    angles = np.linspace(0, 2 * np.pi, num_angular, endpoint=False)
    angle_weights = np.full((num_angular,), 2 * np.pi / num_angular)

    # Axis 0 runs over radii and axis 1 over angles ("ij" indexing).
    rr, tt = np.meshgrid(radii, angles, indexing="ij")
    nodes_xy = np.stack((rr * np.cos(tt), rr * np.sin(tt)), axis=-1)

    weights = radii_weights[:, None] * angle_weights[None, :]

    return nodes_xy, weights

def _cartesian_grid(L):
    """Return the L x L Cartesian sample grid covering [-1, 1) x [-1, 1).

    Parameters
    ----------
    L : int
        Number of samples along each axis; the sample spacing is ``2 / L``.

    Returns
    -------
    ndarray, shape (L, L, 2)
        ``grid[i, j]`` holds ``(-1 + i * 2 / L, -1 + j * 2 / L)``.
    """
    step_cartesian = 2 / L

    # Take the point count from L itself instead of a float-step slice:
    # np.mgrid[-1:1:step] sizes each axis as ceil(2 / step), and rounding in
    # 2 / (2 / L) can exceed L (e.g. L = 49 gives 49.00000000000001), which
    # produced a spurious extra sample at ~1.0 along each axis.
    num_points = int(np.ceil(L))

    # Same arithmetic mgrid uses for its values (index * step + start), so
    # the coordinates are unchanged wherever the old size was already right.
    axis_coords = np.arange(num_points) * step_cartesian - 1

    cartesian_grid = np.meshgrid(axis_coords, axis_coords, indexing="ij")

    return np.stack(cartesian_grid, axis=-1)



def default_rbf_params(params):
    """Resolve ``params`` into a complete RBF parameter dictionary.

    Parameters
    ----------
    params : str or dict
        Either the name of a preset (``"compr"``, ``"compr2"``, ...,
        ``"compr9"``) selecting an (angular, radial) oversampling pair, or a
        dictionary of explicit overrides.

    Returns
    -------
    dict
        The result of passing the (possibly expanded) parameters through
        ``_fill_rbf_params``.

    Raises
    ------
    KeyError
        If ``params`` is a string that is not a known preset name.
    """
    # Preset name -> (oversampling_angular, oversampling_radial).
    presets = {
        "compr": (1, 2),
        "compr2": (1, 0.5),
        "compr3": (0.5, 2),
        "compr4": (0.5, 0.5),
        "compr5": (0.75, 2),
        "compr6": (0.75, 1.5),
        "compr7": (0.75, 1),
        "compr8": (1.0, 1.5),
        "compr9": (1.5, 1.0),
    }

    if isinstance(params, str):
        angular, radial = presets[params]
        params = {
            "oversampling_angular": angular,
            "oversampling_radial": radial,
        }

    return _fill_rbf_params(params)

def _fill_rbf_params(params):
    """Return a new dict of RBF parameters with defaults filled in.

    The defaults are ``L=64``, ``oversampling=2``, ``b=0.5``,
    ``normalize=True`` and ``quadrature="gauss-jacobi"``. Entries in
    ``params`` override the defaults, and any extra keys are carried over.
    ``params`` itself is not modified.
    """
    filled = dict(
        L=64,
        oversampling=2,
        b=0.5,
        normalize=True,
        quadrature="gauss-jacobi",
    )
    filled.update(params)
    return filled

def _polar_grid_size(params):
    """Compute the polar grid dimensions from an RBF parameter dict.

    Parameters
    ----------
    params : dict
        Must contain ``"L"`` and ``"oversampling"``. The optional keys
        ``"oversampling_radial"`` and ``"oversampling_angular"`` override
        the common ``"oversampling"`` factor for their respective axis.

    Returns
    -------
    (int, int)
        ``(num_radial, num_angular)``, where
        ``num_radial = ceil(L / 2 * oversampling_radial)`` and
        ``num_angular = ceil(2 * L * oversampling_angular)``.
    """
    size = params["L"]
    common = params["oversampling"]

    # Per-axis factors fall back to the common oversampling factor.
    radial_factor = params.get("oversampling_radial", common)
    angular_factor = params.get("oversampling_angular", common)

    radial_count = int(np.ceil(size / 2 * radial_factor))
    angular_count = int(np.ceil(2 * size * angular_factor))

    return radial_count, angular_count

