import numpy as np
from numba import jit, prange

from persistence_spheres_utils import from_DGMS_to_H, decompose_dgm, make_weighting, persistence, spherical_to_card_vec, meshgrid

import pyshtools as pysh

class PerSPhere(object):
    """Spherical representation of a collection of diagrams ``DGMS``.

    On construction every diagram is evaluated on a regular (theta, phi) grid
    over the sphere (via ``from_DGMS_to_H``) and each resulting grid function
    is expanded into spherical-harmonic coefficients with pyshtools.
    (``DGMS`` are presumably persistence diagrams — confirm against
    ``persistence_spheres_utils``.)

    Attributes set by ``__init__``:
        DGMS, n_theta, n_phi, eps, weighting: the constructor arguments.
            ``eps`` is stored but not used by any method of this class.
        theta_grid: ``n_theta`` evenly spaced values in ``[0, pi]``.
        phi_grid: ``n_phi`` evenly spaced values in ``[0, 2*pi]``.
        pts: ``spherical_to_card_vec`` applied to the (phi, theta) grid.
        H: output of ``from_DGMS_to_H``; one grid function per entry.
        coeff_shape: shape of the coefficient array of a single grid function.
        sph_armonics: array with one row of *flattened* coefficients per
            entry of ``H``.
    """

    def __init__(self, DGMS, n_theta=100, n_phi=200, eps=0.00001, weighting=persistence):
        self.DGMS = DGMS
        self.n_theta = n_theta
        self.n_phi = n_phi
        self.eps = eps
        self.weighting = weighting

        self.theta_grid = np.linspace(0, np.pi, n_theta)
        self.phi_grid = np.linspace(0, 2 * np.pi, n_phi)

        # grid has shape (n_theta, n_phi, 2): channel 0 holds phi, channel 1 theta.
        phiv, thetav = meshgrid(self.phi_grid, self.theta_grid)
        grid = np.zeros((n_theta, n_phi, 2))
        grid[:, :, 0] = phiv
        grid[:, :, 1] = thetav

        self.pts = spherical_to_card_vec(grid)

        self.H = from_DGMS_to_H(self.DGMS, self.pts, self.weighting, self.n_theta, self.n_phi)

        # Expand every grid function exactly once; the shape of the first
        # expansion is reused instead of recomputing it.
        coeffs = [self.get_sph_armonics(h) for h in self.H]
        self.coeff_shape = coeffs[0].shape
        self.sph_armonics = np.array([c.flatten() for c in coeffs])

    def get_sph_armonics(self, h):
        """Return the spherical-harmonic coefficient array of the grid ``h``.

        ``h`` is flipped along its first axis, wrapped in a pyshtools
        ``SHGrid`` (``grid='DH'``) and expanded with ``normalization='ortho'``.
        The raw ``coeffs`` ndarray of the resulting ``SHCoeffs`` is returned.
        """
        # The row flip is undone in from_sph_armonics_to_h.
        # NOTE(review): presumably aligns row 0 with pyshtools' 90N row — confirm.
        grid = pysh.SHGrid.from_array(h[::-1, :], grid='DH')
        coeffs = grid.expand(normalization='ortho')
        return coeffs.coeffs

    def from_sph_armonics_to_h(self, v):
        """Reconstruct a grid function from spherical-harmonic coefficients.

        Inverse of ``get_sph_armonics``.

        Parameters
        ----------
        v : array_like
            Coefficients with shape ``self.coeff_shape``, or any shape holding
            the same number of elements (e.g. a row of ``self.sph_armonics``).

        Returns
        -------
        ndarray
            The grid function, in the same row order as the grids passed to
            ``get_sph_armonics``.
        """
        v = np.asarray(v)
        if v.shape != self.coeff_shape:
            v = v.reshape(self.coeff_shape)

        coeffs = pysh.SHCoeffs.from_array(v, normalization='ortho')

        # expand() defaults to extend=True, which APPENDS a 90S latitude row
        # and a 360E longitude column. Drop them *before* undoing the row flip;
        # flipping first would drop the genuine 90N row instead and shift the
        # result by one latitude row.
        extended = coeffs.expand(grid='DH2').to_array()
        return extended[:-1, :-1][::-1, :]

    def decompose(self):
        """Decompose every diagram with ``decompose_dgm``.

        Stores one result per diagram in ``self.H_decomposition``, using this
        instance's ``weighting``, ``n_theta`` and ``n_phi``.
        """
        self.H_decomposition = [
            decompose_dgm(dgm, weighting=self.weighting, n_theta=self.n_theta, n_phi=self.n_phi)
            for dgm in self.DGMS
        ]

    def decompose_in_vec(self, save_memory=True):
        """Expand every decomposed grid into spherical-harmonic coefficients.

        Runs ``decompose`` first if ``H_decomposition`` is not yet available.
        Stores the result in ``self.H_decomposition_vec``.

        Unlike ``sph_armonics``, the coefficient arrays are not flattened. When
        all diagrams yield the same number of grids, the result is a regular
        ndarray. Otherwise it is a 1-D object array holding one ndarray per
        diagram.

        Parameters
        ----------
        save_memory : bool
            If True, delete ``self.H_decomposition`` afterwards.
        """
        if not hasattr(self, 'H_decomposition'):
            self.decompose()

        per_dgm = [[self.get_sph_armonics(h) for h in h_dgm] for h_dgm in self.H_decomposition]

        try:
            self.H_decomposition_vec = np.array(per_dgm)
        except ValueError:
            # Ragged nesting (diagrams decomposing into different numbers of
            # grids) cannot form a regular array in modern numpy; keep one
            # regular array per diagram inside an object array instead.
            ragged = np.empty(len(per_dgm), dtype=object)
            for i, coeff_list in enumerate(per_dgm):
                ragged[i] = np.array(coeff_list)
            self.H_decomposition_vec = ragged

        if save_memory:
            del self.H_decomposition
        


