import numpy as np
import os
import sys

# Earlier feature-name set, kept for reference: ['mv', 'gv', 'g', 'mom', 't']
RP_Feature_Names = ['mt', 'gt', 'g', 'mom5', 'mom9', 'mom99']

DM_Feature_Name = ['g']

# Registry of recorded datasets, looked up by ``load_SR_dataset(which=key)``.
# Each value is unpacked there as the 5-tuple
#   (file name relative to this module's directory,
#    feature names,
#    l2o_num_grad_features,
#    N_pre,
#    free-text description).
# Keys must be unique: when a key is repeated in a dict literal, Python
# silently keeps only the LAST entry.
__GENERATED_ALL__ = {
    'RP_s': (
        'SR_RP_s_260917.npy',
        RP_Feature_Names[:2],
        2, 20,
        '#2 RP large, final result, for distill RP_s',
    ),
    'RP_s_i': (
        'SR_RP_s_i_f4_94440.npy',
        RP_Feature_Names[:4],
        4, 20,
        '#4 RP large, final result, for distill Adam & mom5',
    ),
    # Formerly also registered as 'RP_s', which shadowed the final result above.
    'RP_s_first': (
        'SR_RP_s_26080.npy',
        RP_Feature_Names[:2],
        2, 20,
        'first results, #2 RP, not used',
    ),
}






def load_SR_dataset(which, N_train=1000, feature_names=None, verbose=False,
                    N_test=1000):
    """Load a recorded (or the newest) dataset and split it by rows.

    Args:
        which: Key into ``__GENERATED_ALL__`` selecting a recorded dataset, or
            ``None`` to load the most recently modified ``.npy`` file found
            by ``get_newest_file`` in this module's directory.
        N_train: Number of leading rows returned as the training split.
        feature_names: Required when ``which`` is ``None`` and returned
            unchanged; ignored (replaced by the recorded names) otherwise.
        verbose: If true, print which file is being loaded.
        N_test: Number of rows after the training split returned as the test
            split. Defaults to 1000, the value that was previously hard-coded.

    Returns:
        ``(Xy_train, Xy_test, Xy_fit, feature_names, l2o_num_grad_features,
        N_pre)``. ``Xy_fit`` holds every row after the train and test splits.
        When ``which`` is ``None``, ``l2o_num_grad_features`` is ``None`` and
        ``N_pre`` is 20.

    Raises:
        ValueError: ``which`` is ``None`` and ``feature_names`` was not given.
        KeyError: ``which`` is not a key of ``__GENERATED_ALL__``.
    """
    if which is None:
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if feature_names is None:
            raise ValueError('feature_names is required when which is None')
        # get_newest_file already returns a path that includes its directory,
        # so it must NOT be joined with the module directory a second time.
        fpath = get_newest_file(filter='.npy')
        if verbose: print(f'loading newest Dataset:\n\t {fpath}')
        l2o_num_grad_features = None
        N_pre = 20
    else:
        try:
            fname, feature_names, l2o_num_grad_features, N_pre, _desc = \
                __GENERATED_ALL__[which]
        except KeyError:
            raise KeyError(
                f'unknown dataset {which!r}; known: {sorted(__GENERATED_ALL__)}'
            ) from None
        if verbose: print(f'loading recorded Dataset:\n\t {fname}')
        # Recorded file names are relative to this module's directory.
        fpath = os.path.join(os.path.dirname(os.path.abspath(__file__)), fname)

    Xy_all = np.load(fpath)

    Xy_train = Xy_all[:N_train]
    Xy_test = Xy_all[N_train:N_train + N_test]
    Xy_fit = Xy_all[N_train + N_test:]

    return Xy_train, Xy_test, Xy_fit, feature_names, l2o_num_grad_features, N_pre







# print(os.getcwd()) # outer
# print(os.path.abspath(__file__)) # inner, but return file name rather than dir

def get_newest_file(file_dir='.', filter='file'):
    """Return the path of the most recently modified matching entry.

    Args:
        file_dir: Directory to scan. The default ``'.'`` means the directory
            containing this module, not the current working directory.
        filter: Selects candidate entries:
            ``'file'`` -- entries that are not directories;
            ``'dir'``  -- directories;
            ``''``     -- every entry;
            any other string -- entries whose name ends with that suffix
            (e.g. ``'.npy'``).

    Returns:
        ``os.path.join(file_dir, name)`` for the candidate with the latest
        modification time.

    Raises:
        FileNotFoundError: No entry of ``file_dir`` matches ``filter``
            (including when the directory is empty).
    """
    if file_dir == '.':
        # Resolve relative to this module so callers find the data files
        # regardless of the CWD; abspath avoids an empty dirname ('') when
        # __file__ is a bare file name.
        file_dir = os.path.dirname(os.path.abspath(__file__))

    def _full(name):
        # Entries from os.listdir are bare names; tests on them must use the
        # full path or they would be resolved against the CWD.
        return os.path.join(file_dir, name)

    def _matches(name):
        if filter == 'file':
            return not os.path.isdir(_full(name))
        if filter == 'dir':
            return os.path.isdir(_full(name))
        if filter:
            return name.endswith(filter)
        return True  # filter == '': accept everything

    candidates = [name for name in os.listdir(file_dir) if _matches(name)]
    if not candidates:
        raise FileNotFoundError(
            f'no entry matching filter={filter!r} in {file_dir!r}')

    # Compare real mtimes for every candidate (directories included).
    newest = max(candidates, key=lambda name: os.path.getmtime(_full(name)))
    return _full(newest)



