import numpy as np
import argparse

import sys, os
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

from knn import *
from manifolds import *
from data_utils import *
from ide import *
from curv_utils import *

# Manifold identifiers listed for this family of experiments.
# NOTE(review): main() in this file does not consult this list; it only
# accepts manifold names containing 'sphere', which none of these entries do.
AVAILABLE_MANIFOLDS = [
    'lin_nsp',
    'gman_proj',
    'gman_vec',
    'flag_vec',
    'steifel_proj',
    'steifel_vec1',
    'sun_pauli',
    'fractal_hofs',
]

# Names accepted for --ide_name; main() rejects anything not listed here.
IDE_METHODS = [
    'lpca_maxgap',
    'lpca_ratio',
    'lpca_fo',
    'mle',
    'danco',
    'corrint',
    'twonn',
    'abid',
    'gride',
]

# Single run experiments # 

def parser(argv=None):
    """Define the command-line interface and parse the arguments.

    Parameters
    ----------
    argv : list of str, optional
        Argument strings to parse.  ``None`` (the default) lets argparse read
        ``sys.argv[1:]``, so existing ``parser()`` callers are unaffected.

    Returns
    -------
    argparse.Namespace
        The parsed options (manifold, IDE method, method-specific knobs,
        seeds, parallelism and the curvature switch).
    """
    # Named `cli` so the local does not shadow this function's own name.
    cli = argparse.ArgumentParser(description="IDE estimation for various manifolds")

    # Manifold parameters
    cli.add_argument('--manifold_name', type=str, default='hSphere',
                     help="Name for the manifold. This script only accepts names "
                          "containing 'sphere' (case-insensitive), e.g. 'hSphere'.")
    cli.add_argument('--di', type=int, default=3,
                     help='Intrinsic dimension, -1 implies not needed.')
    cli.add_argument('--da', type=int, default=4,
                     help='Ambient dimension, -1 implies not needed.')
    cli.add_argument('--sq_eps', type=float, default=1.0,
                     help='Squeeze amount. Must be strictly less than 2.0')

    # General IDE parameters
    cli.add_argument('--ide_name', type=str, default='lpca_maxgap',
                     help='Name for the IDE technique being tested.')
    cli.add_argument('--k', type=int, default=100, help='k nearest neighbors')
    cli.add_argument('--N', type=int, default=5000, help='# of samples')
    cli.add_argument('--sigma', type=float, default=0.0,
                     help='Noise scale for additive noise')
    cli.add_argument('--noise_type', type=str, default='isoGauss',
                     help='Noise type for our experiments.')

    # Parameters specific to individual IDE techniques
    cli.add_argument('--eps', type=float, default=0.1,
                     help='epsilon parameter for lPCA')
    cli.add_argument('--dNoise', type=float, default=0.1,
                     help='noise parameter for MLE')
    cli.add_argument('--dFrac', type=float, default=0.1,
                     help='discard Fraction parameter for TwoNN')
    # NOTE(review): --d is parsed but not read by main() in this file.
    cli.add_argument('--d', type=int, default=1, help='d parameter for ESS')
    cli.add_argument('--k1', type=int, default=10,
                     help='k1 parameter for CorrInt/DANCo')
    cli.add_argument('--k2', type=float, default=2.0,
                     help='k2 parameter for CorrInt, given as a multiple of k1')
    cli.add_argument('--n1', type=int, default=10, help='n1 parameter for Gride')
    cli.add_argument('--n2', type=float, default=2.0, help='n2 parameter for Gride')

    # General parameters
    # NOTE(review): main() in this file seeds data generation with --seed and
    # never reads --dg_seed; confirm which of the two is intended.
    cli.add_argument('--dg_seed', type=int, default=42,
                     help='Seed for manifold data generation.')
    cli.add_argument('--seed', type=int, default=0,
                     help='Seed to control randomness for IDE experiments.')
    cli.add_argument('--n_jobs', type=int, default=1,
                     help='Number of available CPUs.')
    cli.add_argument('--if_curv', type=int, default=0,
                     help='Whether curvature needs to be computed or not. '
                          '0 means yes, !=0 means no.')

    return cli.parse_args(argv)


def _sample_noise(noise_type, shape):
    """Draw one unscaled noise array of 2-D ``shape`` from the global NumPy RNG.

    Raises ``ValueError`` for an unrecognised ``noise_type`` instead of
    leaving the caller with an unbound variable.
    """
    if noise_type == 'isoGauss':
        return np.random.randn(*shape)
    if noise_type == 'exp':
        # Exponential magnitudes with a random sign per entry.
        magnitude = np.random.exponential(scale=1.0, size=shape)
        signs = np.random.choice([-1, 1], size=shape)
        return signs * magnitude
    if noise_type == 'uniform':
        return np.random.uniform(low=-1.0, high=1.0, size=shape)
    if noise_type == 'anisoGauss':
        dim = shape[1]
        # multivariate_normal needs a symmetric positive-semidefinite
        # covariance; a raw Gaussian matrix is neither.  A.A^T is always PSD,
        # and dividing by dim keeps the expected covariance at the identity.
        factor = np.random.randn(dim, dim)
        cov = factor.dot(factor.T) / dim
        return np.random.multivariate_normal(mean=np.zeros(dim), cov=cov,
                                             size=shape[0])
    raise ValueError(
        "Unknown noise_type {!r}; expected one of 'isoGauss', 'exp', "
        "'uniform', 'anisoGauss'.".format(noise_type))


def _run_ide(args, samples, knn, m_params):
    """Run the estimator named by ``args.ide_name``; return ``(info_dict, params)``.

    ``args.ide_name`` must already have been validated against IDE_METHODS.
    """
    name = args.ide_name
    # Keyword arguments every ide_* helper receives.
    common = dict(samples=samples, precomputed_knn=knn,
                  manifold_name=args.manifold_name, m_params=m_params,
                  n_jobs=args.n_jobs)

    # The three local-PCA variants differ only in `ver` and whether eps is used.
    lpca_variants = {'lpca_maxgap': ('maxgap', None),
                     'lpca_ratio': ('ratio', args.eps),
                     'lpca_fo': ('fo', args.eps)}
    if name in lpca_variants:
        ver, eps = lpca_variants[name]
        params = (args.k, eps, args.seed)
        return ide_lpca(ver=ver, params=params, **common), params
    if name == 'mle':
        params = (args.k, args.dNoise, args.seed)
        return ide_mle(params=params, **common), params
    if name == 'danco':
        # k1 is never allowed to be smaller than k.
        params = (args.k, max(args.k1, args.k), args.seed)
        return ide_danco(params=params, **common), params
    if name == 'twonn':
        params = (args.k, args.dFrac, args.seed)
        return ide_twonn(params=params, **common), params
    if name == 'corrint':
        # --k2 is a multiplier of k1; clamp the result below the sample count
        # and shrink k1 to half of the clamped value when that happens.
        k1 = args.k1
        k2 = int(args.k2 * k1)
        if k2 >= args.N - 1:
            k2 = args.N - 2
            k1 = k2 // 2
        params = (k1, k2, args.seed)
        return ide_corrint(params=params, **common), params
    if name == 'abid':
        params = (args.k, args.seed)
        return ide_abid(params=params, **common), params
    if name == 'gride':
        params = (args.k, args.n1, args.n2, args.seed)
        return ide_gride(params=params, **common), params
    raise ValueError('Please choose an available IDE technique.')


def main():
    """Run a single squeezed-manifold IDE experiment and save the result.

    Steps: parse the CLI, sample the hypersphere, rescale every ambient
    coordinate by a random "squeeze" factor, precompute the kNN graph,
    optionally estimate local density/curvature, optionally add noise, run
    the selected intrinsic-dimension estimator and hand everything to
    ``save_ide_info``.

    Raises
    ------
    ValueError
        If the manifold name does not contain 'sphere', if ``--ide_name`` is
        not in IDE_METHODS, or if ``--sigma > 0`` is combined with an unknown
        ``--noise_type``.
    """
    args = parser()

    if 'sphere' not in args.manifold_name.lower():
        raise ValueError('Please choose an available manifold.')
    # Validate the estimator name up front so a typo fails before the
    # expensive kNN / curvature computations rather than after them.
    if args.ide_name not in IDE_METHODS:
        raise ValueError('Please choose an available IDE technique.')

    # NOTE(review): data generation is seeded with --seed; --dg_seed is unused.
    samples = hSphere_isometry(N=args.N, di=args.di, da=args.da, seed=args.seed)
    m_params = (args.di, args.da, args.seed)

    # Flatten any trailing axes so each sample is a single row vector.
    if samples.ndim > 2:
        samples = samples.reshape((samples.shape[0], -1))

    # k must leave room for the point itself plus at least one other point.
    if args.k >= args.N - 1:
        args.k = args.N - 2

    np.random.seed(args.seed)
    indices = np.random.permutation(samples.shape[0])[:args.N]
    samples = samples[indices]

    with_curv = args.if_curv == 0
    # NOTE(review): with curvature enabled the kNN graph holds N//50
    # neighbours regardless of --k; presumably the ide_* helpers need at
    # least k neighbours -- confirm this is sufficient.
    max_k = args.N // 50 if with_curv else args.k

    # Squeeze operation: per-coordinate scale drawn uniformly from
    # [1 - sq_eps/2, 1 + sq_eps/2).
    squeeze_op = 1 - args.sq_eps / 2 + args.sq_eps * np.random.rand(samples.shape[1])
    samples = np.multiply(squeeze_op[np.newaxis, :], samples)

    knn = precompute_knn(samples, kmax=max_k + 1, n_jobs=args.n_jobs)
    knn = (knn[0][:, 1:], knn[1][:, 1:])  # Exclude self.

    # Curvature and density estimation (on the noise-free samples).
    if with_curv:
        kvals = np.logspace(np.log10(max_k // 2), np.log10(max_k), num=5, dtype=int)
        lds_ests = []
        lcs_ests = []
        cvs_ests = []
        for k in kvals:
            # The intrinsic dimension comes from the CLI; the bare name `di`
            # used here previously was never defined in this file.
            lds = local_density_estimation(X=samples, knns=knn, kmax=k, di=args.di)
            lcs, cvs = mean_curvature_quadratic_fit(X=samples, knns=knn,
                                                    kmax=k // 2, di=args.di)
            lds_ests.append(lds)
            lcs_ests.append(lcs)
            cvs_ests.append(cvs)

    if args.sigma > 0:
        samples += args.sigma * _sample_noise(args.noise_type, samples.shape)

    info_dict, params = _run_ide(args, samples, knn, m_params)

    if with_curv and info_dict is not None:
        info_dict['kvals'] = kvals
        info_dict['lcs'] = lcs_ests
        info_dict['lds'] = lds_ests
        info_dict['cvs'] = cvs_ests

    params += (args.N, args.sigma, args.sq_eps)
    # Results are stored under a 'sq_'-prefixed manifold name.
    args.manifold_name = 'sq_' + args.manifold_name

    if with_curv:
        root_dir = './data/squeeze_curve_small/'
    else:
        root_dir = './data/squeeze_no_curve/'

    save_ide_info(info_dict=info_dict,
                  ide_name=args.ide_name,
                  params=params,
                  manifold_name=args.manifold_name,
                  m_params=m_params,
                  root_dir=root_dir)
    
# Run the experiment only when this file is executed as a script.
if __name__ == '__main__':
    main()
