
import numpy as np
import argparse
import gzip, pickle

from lie_learn.spaces.spherical_quadrature import estimate_spherical_quadrature_weights
import lie_learn.spaces.S2 as S2


def get_projection_grid(b, grid_type="Driscoll-Healy"):
    '''Return a spherical sampling grid in angular and Euclidean coordinates.

    The angles come from lie_learn's ``S2.meshgrid`` for bandwidth ``b``.
    They are mapped onto the unit sphere centred at the origin (0, 0, 0);
    no translation is applied. ``theta`` is used as the polar angle measured
    from the +z axis and ``phi`` as the azimuth.

    Args:
        b: Grid bandwidth, forwarded to ``S2.meshgrid``.
        grid_type: Sampling-scheme name, forwarded to ``S2.meshgrid``.

    Returns:
        ``((theta, phi), (x, y, z))`` where ``theta`` and ``phi`` are the
        arrays returned by ``S2.meshgrid`` and ``x``, ``y``, ``z`` are arrays
        of the same shape holding
        ``(sin(theta) cos(phi), sin(theta) sin(phi), cos(theta))``.
    '''
    theta, phi = S2.meshgrid(b=b, grid_type=grid_type)
    # sin(theta) is shared by the x and y components; compute it once.
    sin_theta = np.sin(theta)
    x_ = sin_theta * np.cos(phi)
    y_ = sin_theta * np.sin(phi)
    z_ = np.cos(theta)
    return (theta, phi), (x_, y_, z_)
    

if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Pre-compute a spherical sampling grid and its quadrature weights.')
    parser.add_argument('--bandwidth', type=int, default=30,
                        help='bandwidth b passed to get_projection_grid (default: 30)')
    parser.add_argument('--lmax', type=int, default=10,
                        help='maximum degree passed to the quadrature-weight estimator (default: 10)')
    args = parser.parse_args()
    # Reject values that would produce an empty grid or a meaningless degree
    # before doing any expensive work.
    if args.bandwidth < 1:
        parser.error('--bandwidth must be a positive integer')
    if args.lmax < 0:
        parser.error('--lmax must be non-negative')

    ba_grid, xyz_grid = get_projection_grid(b=args.bandwidth)

    # Flatten every coordinate array to 1-D so each grid point is one entry.
    ba_grid = tuple(coord.flatten() for coord in ba_grid)
    xyz_grid = tuple(coord.flatten() for coord in xyz_grid)

    ## Pre-compute quadrature weights for the grid
    # We need to compute the integral of the forward FT as accurately as possible in order to be able to
    #   compute the inverse FT and reconstruct the original signals.
    # The package 'lie_learn' estimates the quadrature weights of the given points via least-squares regression
    #   using real spherical harmonics.
    # In the folder 'quadrature_figures' one can see that not using quadrature weights makes it much harder
    #   to reconstruct the original signals.
    # Ideally, one would construct the grid and the weights together, with the required symmetries, in order
    #   to get the integral as exact as possible (instead of using a least-squares approximation of the weights).
    #   That would probably produce the most accurate FT projections and resulting reconstructions.
    print('Computing quadrature weights...')
    # One row per grid point; columns are (theta, phi).
    sampling_set = np.column_stack(ba_grid)
    # Only the weights are used; the remaining least-squares diagnostics are discarded.
    quad_weights, _residuals, _rank, _s = estimate_spherical_quadrature_weights(
        sampling_set, args.lmax, normalization='seismology', condon_shortley=True)
    print('Done. Saving...')

    print('mean quadrature weight: %s' % np.mean(quad_weights))

    with gzip.open('ba_grid-b=%d.gz' % (args.bandwidth), 'wb') as f:
        pickle.dump(ba_grid, f)

    with gzip.open('xyz_grid-b=%d.gz' % (args.bandwidth), 'wb') as f:
        pickle.dump(xyz_grid, f)

    np.save('quad_weights-b=%d-lmax=%d.npy' % (args.bandwidth, args.lmax), quad_weights)

