import os, sys
import re

import numpy as np


from lib.generate_datasets import GenDataTwoShiftedGausses
from lib.utils import set_seed_everything


# Run-mode switches. Uncomment the line below to write a small dataset into
# a separate "TEST" sub-directory instead of the full run.
MODE = []
#MODE += ['TEST']

# The output directory is derived from this script's own filename:
#   "<prefix>_generate_data_<name>.py"  ->  data/<name>
# The pattern must match the whole basename with a literal ".py" suffix;
# an unescaped "." would also accept names such as "..._generate_data_x_py".
python_file_name = os.path.basename(__file__)
_name_match = re.fullmatch(r'.+_generate_data_(.+)\.py', python_file_name)
if _name_match is None:
    # Fail loudly with a readable message instead of an opaque IndexError.
    raise ValueError(
        f'script name {python_file_name!r} does not match the expected '
        "pattern '<prefix>_generate_data_<name>.py'")
out_dirname = _name_match.group(1)
PARENT_DATA_DIR = os.path.join('data/', out_dirname)

# Sample counts requested from the generator for each data split.
N_sample_train = 10000
N_sample_eval = 1000
N_sample_test = 1000
# First argument to GenDataTwoShiftedGausses; also embedded in the
# per-dimension output sub-directory name.
dist_centers_nu_and_de = 5

# TEST mode produces only a few datasets, written under a "TEST" sub-dir so
# they never mix with the full run. The dimension list is the same either way.
_is_test_run = 'TEST' in MODE
DIM_data_list = [5]
N_out_data = 10 if _is_test_run else 100
top_outdir = (os.path.join(PARENT_DATA_DIR, 'TEST') if _is_test_run
              else PARENT_DATA_DIR)
os.makedirs(top_outdir, exist_ok=True)

# Seed once before any sampling so the whole run is reproducible.
set_seed_everything(0)
for dim in DIM_data_list:
    generator = GenDataTwoShiftedGausses(dist_centers_nu_and_de, dim)
    dim_subdir = f'Dim_{dim:03}_dist{dist_centers_nu_and_de}'
    for idx in range(N_out_data):
        out_dir = os.path.join(top_outdir, dim_subdir, f'{idx:04}')
        os.makedirs(out_dir, exist_ok=True)

        # Draw the splits strictly in train -> eval -> test order. If the
        # generator draws from the seeded global RNG (presumably; not visible
        # here), reordering would change every dataset.
        train_data, eval_data, test_data = [
            generator.sample(n_samples)
            for n_samples in (N_sample_train, N_sample_eval, N_sample_test)
        ]

        print(out_dir)
        print('max rate', np.max(train_data[2]))

        # Each sample tuple is indexed as (de, nu, true_density_rate). The
        # train archive deliberately stores only the samples; eval/test
        # additionally keep the true density rate.
        np.savez_compressed(
            os.path.join(out_dir, 'train.npz'),
            de=train_data[0],
            nu=train_data[1])
        for file_name, split in (('eval.npz', eval_data),
                                 ('test.npz', test_data)):
            np.savez_compressed(
                os.path.join(out_dir, file_name),
                de=split[0],
                nu=split[1],
                true_density_rate=split[2])
