# %%
import os, sys
import re
import itertools

import numpy as np
import torch

from lib.generate_datasets import GenDataMultiGaussVsSingleGauss_with_KL
from lib.utils import set_seed_everything

# Run-mode flags. Uncomment the next line to write all output under a 'TEST'
# subdirectory of the dataset directory instead of the directory itself.
MODE = []
#MODE += ['TEST']

# The dataset directory name is derived from this script's file name, which is
# expected to look like '<prefix>_generate_data_<name>[-<suffix>].py'; the data
# is then written under 'data/<name>'.
python_file_name = os.path.basename(__file__)
# The dot before 'py' is escaped so it only matches a literal '.'.
_file_name_match = re.search(r'.+_generate_data_(.+?)(|-.*)\.py', python_file_name)
if _file_name_match is None:
    # Fail with an explicit message rather than an AttributeError on None.
    raise ValueError(
        f'cannot derive the output directory from script name {python_file_name!r}; '
        "expected '<prefix>_generate_data_<name>[-<suffix>].py'")
out_dirname = _file_name_match.group(1)
PARENT_DATA_DIR = os.path.join('data/', out_dirname)


# Number of samples drawn for each of the train / eval / test splits of every
# replicate dataset.
N_sample_train = N_sample_eval = N_sample_test = 10000

# Sweep grid. Every combination of these values gets its own generator:
#   N_modal_list  -> number of modes (n_mdl) passed to the generator
#   DIM_data_list -> data dimensionality (dim)
#   KL_div_list   -> nominal KL value (kl) passed to the generator and used in
#                    the directory name
N_modal_list = list(range(1, 5))
DIM_data_list = [5]
KL_div_list = [1, 2, *range(4, 15, 2)]

# Replicate datasets written per grid point (subdirectories '0000', '0001', ...).
N_out_data = 100


# %%
# Root output directory; in TEST mode everything goes under a 'TEST'
# subdirectory so trial runs stay separate from the real datasets.
if 'TEST' in MODE:
    top_outdir = os.path.join(PARENT_DATA_DIR, 'TEST')
else:
    top_outdir = PARENT_DATA_DIR

# One generator per (dimension, nominal KL, number of modes) grid point, each
# producing N_out_data replicate datasets.
for dim, kl, n_mdl in itertools.product(DIM_data_list, KL_div_list, N_modal_list):
    # Reseed at every grid point so a configuration's data does not depend on
    # which configurations were generated before it (assumes
    # set_seed_everything seeds every RNG the generator uses -- confirm).
    set_seed_everything(0)
    gen_data = GenDataMultiGaussVsSingleGauss_with_KL(kl, dim, n_mdl)
    # KL value reported by the generator; saved in test.npz, separately from
    # the nominal `kl` that names the directory.
    kl_div = gen_data.get_true_KL()
    data_dir = f'dim-{dim}_kl-{kl}_nmdl-{n_mdl}'
    for _i_out in range(N_out_data):
        data_id = f'{_i_out:04}'
        out_dir = os.path.join(top_outdir, data_dir, data_id)
        os.makedirs(out_dir, exist_ok=True)
        # The draw order train -> eval -> test determines which part of the
        # RNG stream each split receives; keep it fixed for reproducibility.
        d_train_tuple = gen_data.sample(N_sample_train)
        d_eval_tuple = gen_data.sample(N_sample_eval)
        d_test_tuple = gen_data.sample(N_sample_test)
        # Each sample() result is indexed positionally as
        # (de, nu, true_density_rate) -- presumably denominator samples,
        # numerator samples and the true density ratio at those samples;
        # confirm against the generator.
        # The train split stores only the two sample sets, not the true
        # density ratio.
        np.savez_compressed(
            os.path.join(out_dir, 'train.npz'),
            de=d_train_tuple[0],
            nu=d_train_tuple[1])
        np.savez_compressed(
            os.path.join(out_dir, 'eval.npz'),
            de=d_eval_tuple[0],
            nu=d_eval_tuple[1],
            true_density_rate=d_eval_tuple[2])
        # The test split additionally carries the generator's KL value.
        np.savez_compressed(
            os.path.join(out_dir, 'test.npz'),
            de=d_test_tuple[0],
            nu=d_test_tuple[1],
            true_density_rate=d_test_tuple[2],
            KL_div=kl_div)
        

# %%
