# %%
import os, sys
import re

import numpy as np
import torch

from lib.generate_datasets import GenDataTwoUnitShiftedGausses_WithDifferentSD
from lib.utils import set_seed_everything

# Run-mode flags. Adding 'TEST' (uncomment the line below) selects the reduced
# test configuration used further down in this script.
MODE = []
#MODE += ['TEST']

# The output directory name is derived from this script's own file name, which
# must look like '<prefix>_generate_data_<name>.py' or
# '<prefix>_generate_data_<name>-<suffix>.py'; <name> becomes the directory.
python_file_name = os.path.basename(__file__)
# The '.' before 'py' is escaped and the pattern is anchored at the end of the
# name: with a bare '.' the lazy group stopped at the first '<any char>py'
# (e.g. '..._generate_data_happy.py' produced 'ha' instead of 'happy').
_name_match = re.search(r'.+_generate_data_(.+?)(|-.*)\.py$', python_file_name)
if _name_match is None:
    # Fail with an actionable message instead of "'NoneType' has no attribute 'group'".
    raise ValueError(
        f"Cannot derive the output directory from script name {python_file_name!r}: "
        "expected '<prefix>_generate_data_<name>[-<suffix>].py'")
out_dirname = _name_match.group(1)
PARENT_DATA_DIR = os.path.join('data/', out_dirname)


# Number of samples requested from the generator for each split.
N_sample_train = 10000
N_sample_eval = 10000
N_sample_test = 10000
# Values swept for the sigma of the denominator distribution.
denominator_sigma_list = [1.0, 1.1, 1.2, 1.4, 1.6, 2.0, 2.5, 3.0]


# %%
# Pick the output root and the number of data sets per sigma value in one go:
# a test run writes a single data set under a 'TEST' sub-directory, a normal
# run writes 100 data sets directly under PARENT_DATA_DIR.
top_outdir, N_out_data = (
    (os.path.join(PARENT_DATA_DIR, 'TEST'), 1)
    if 'TEST' in MODE
    else (PARENT_DATA_DIR, 100)
)
os.makedirs(top_outdir, exist_ok=True)


# Fix the seed once, before any sampling, so the whole sweep is generated from
# a single seeded run (presumably seeds every RNG in use -- see lib.utils).
set_seed_everything(0)
for denominator_sigma in denominator_sigma_list:
    # NOTE(review): the meaning of the leading and trailing 3 is defined by the
    # generator class in lib.generate_datasets -- confirm there.
    gen_data = GenDataTwoUnitShiftedGausses_WithDifferentSD(
        3,
        0.0, denominator_sigma,  # denominator: rho, sigma
        0.0, 4,                  # numerator: rho, sigma
        3)
    kl_div = gen_data.get_true_KL()
    print('KL divergence', kl_div)

    # All data sets for this sigma live under one directory whose name encodes
    # sigma * 1000 as a zero-padded integer (e.g. 1.1 -> 'denominator_sigma_1100').
    sigma_dir = os.path.join(
        top_outdir, f'denominator_sigma_{denominator_sigma*1000:04.0f}')

    for _i_out in range(N_out_data):
        out_dir = os.path.join(sigma_dir, f'{_i_out:04}')
        os.makedirs(out_dir, exist_ok=True)

        # Draw the three splits in a fixed order (train, eval, test) so the
        # random stream is consumed the same way on every run.
        train_set = gen_data.sample(N_sample_train)
        eval_set = gen_data.sample(N_sample_eval)
        test_set = gen_data.sample(N_sample_test)

        # One entry per archive: file name and the arrays stored in it.
        # train.npz stores only the two sample sets (no true_density_rate);
        # test.npz additionally records the true KL divergence.
        archives = (
            ('train.npz', dict(
                de=train_set[0],
                nu=train_set[1])),
            ('eval.npz', dict(
                de=eval_set[0],
                nu=eval_set[1],
                true_density_rate=eval_set[2])),
            ('test.npz', dict(
                de=test_set[0],
                nu=test_set[1],
                true_density_rate=test_set[2],
                KL_div=kl_div)),
        )
        for archive_name, arrays in archives:
            np.savez_compressed(os.path.join(out_dir, archive_name), **arrays)
        

# %%
