import numpy as np
import argparse
import h5py
import pickle

from knn import *
from manifolds import *
from data_utils import *


# Manifold identifiers accepted by ``--manifold_name``; main() rejects any
# other value. Each identifier selects one data_gen_* generator in main().
# NOTE(review): 'steifel' looks like a misspelling of 'Stiefel', but it is part
# of the CLI interface (and main()'s dispatch), so it is kept verbatim.
AVAILABLE_MANIFOLDS = [
    'lin_nsp',
    'gman_proj',
    'gman_vec',
    'flag_vec',
    'steifel_proj',
    'steifel_vec1',
    'sun_pauli',
    'fractal_hofs',
]

def parser(argv=None):
    """Build the command-line parser for sample generation and parse arguments.

    Args:
        argv: Optional list of argument strings to parse. When ``None`` (the
            default), argparse falls back to ``sys.argv[1:]``, which matches the
            previous behaviour. Passing a list allows programmatic use and tests.

    Returns:
        argparse.Namespace with attributes ``manifold_name`` (str), ``d1``,
        ``d2``, ``d3``, ``seed`` and ``n_jobs`` (all int).
    """
    # Named ``arg_parser`` so the local does not shadow this function's name.
    arg_parser = argparse.ArgumentParser(description="Creating samples of data generation")

    arg_parser.add_argument('--manifold_name', type=str, default='lin_nsp', help='Name for the manifold, must match one of the AVAILABLE_MANIFOLDS.')
    # A value of -1 for a dimension signals that the chosen manifold does not use it.
    arg_parser.add_argument('--d1', type=int, default=20, help='First integer dimension, -1 implies not needed.')
    arg_parser.add_argument('--d2', type=int, default=10, help='Second integer dimension, -1 implies not needed.')
    arg_parser.add_argument('--d3', type=int, default=-1, help='Third integer dimension, -1 implies not needed.')
    arg_parser.add_argument('--seed', type=int, default=42, help='Seed to control randomness.')
    arg_parser.add_argument('--n_jobs', type=int, default=1, help='Number of CPUs for computing kNNs.')

    return arg_parser.parse_args(argv)



def main():
    """Generate samples for the manifold chosen on the command line and save them.

    Reads options via ``parser()``, builds the positional parameter tuple the
    selected ``data_gen_*`` function expects, calls it, and passes the result to
    ``save_info`` together with the same parameters and
    ``root_dir='./data/samples/'``.

    Raises:
        ValueError: if ``--manifold_name`` is not in AVAILABLE_MANIFOLDS, or is
            listed there but has no generator registered below.
    """
    args = parser()

    if args.manifold_name not in AVAILABLE_MANIFOLDS:
        raise ValueError('Please choose an available manifold.')

    # Parameter tuples, in the positional order each generator receives them.
    # NOTE(review): --n_jobs is parsed but not forwarded to any generator here.
    two_dim_params = (args.d1, args.d2, args.seed)
    one_dim_params = (args.d1, args.seed)
    flag_params = (args.d1, args.d2, args.d3, args.seed)

    # name -> (generator, params). The generators are wrapped in lambdas so the
    # star-imported data_gen_* name is only looked up for the manifold actually
    # selected, exactly as the original if/elif chain behaved.
    dispatch = {
        'lin_nsp': (lambda p: data_gen_lin_nsp(params=p), two_dim_params),
        'gman_proj': (lambda p: data_gen_gman_proj(params=p), two_dim_params),
        'gman_vec': (lambda p: data_gen_gman_vec(params=p), two_dim_params),
        'flag_vec': (lambda p: data_gen_flag_vec(params=p), flag_params),
        'steifel_proj': (lambda p: data_gen_steifel_proj(params=p), two_dim_params),
        'steifel_vec1': (lambda p: data_gen_steifel_vec1(params=p), two_dim_params),
        'sun_pauli': (lambda p: data_gen_sun_mod_pauli(params=p), one_dim_params),
        'fractal_hofs': (lambda p: data_gen_fractal_hofstadter(params=p), one_dim_params),
    }

    try:
        generate, params = dispatch[args.manifold_name]
    except KeyError:
        # Previously this case fell through to an else-branch that left
        # ``params`` unbound (crashing inside the save_info call) and passed
        # info_dict=None; fail loudly and clearly instead.
        raise ValueError(
            f'Unknown error! No generator registered for {args.manifold_name!r}; '
            'check capitalization, spelling, etc. for name.'
        ) from None

    info_dict = generate(params)

    save_info(info_dict=info_dict,
              params=params,
              manifold_name=args.manifold_name,
              root_dir='./data/samples/')
    
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
