import os
import sklearn.model_selection
import numpy as np

def train_test_split(config, y_labels):
    """Return stratified train/validation index arrays, cached on disk.

    The split is stored under ``<dataset_root>/<dataset_name>/data_splits`` as
    ``train_idx_<valid_size>.npy`` and ``valid_idx_<valid_size>.npy`` and is
    reused on later calls when both files exist and together cover exactly
    ``len(y_labels)`` samples; otherwise it is recomputed with
    ``sklearn.model_selection.train_test_split`` (stratified on ``y_labels``)
    and written back. The training indices are then shuffled in place with a
    NumPy generator seeded by ``config.dataset_order_seed``; the validation
    indices are returned in their stored order.

    Args:
        config: object exposing ``dataset_root``, ``dataset_name``,
            ``dataset_valid_size`` (forwarded to sklearn as ``test_size`` and
            used in the cache file names) and ``dataset_order_seed``.
        y_labels: per-sample labels used for stratification; its length is
            taken as the number of samples.

    Returns:
        ``(train_idx, valid_idx)`` as NumPy arrays of sample indices, on both
        the cache-hit and cache-miss paths.
    """
    base_path = os.path.join(config.dataset_root, config.dataset_name, 'data_splits')
    os.makedirs(base_path, exist_ok=True)
    train_split_path = os.path.join(base_path, f'train_idx_{config.dataset_valid_size}.npy')
    valid_split_path = os.path.join(base_path, f'valid_idx_{config.dataset_valid_size}.npy')
    n_samples = len(y_labels)

    train_idx = valid_idx = None
    if os.path.exists(train_split_path) and os.path.exists(valid_split_path):
        train_idx, valid_idx = np.load(train_split_path), np.load(valid_split_path)
        # The cache key does not encode the dataset size, so a split written
        # for a different version of the dataset must not be reused silently.
        if train_idx.size + valid_idx.size == n_samples:
            print('==> Using precomputed data splits')
        else:
            print('==> Precomputed data splits do not match dataset size, recomputing')
            train_idx = valid_idx = None

    if train_idx is None:
        print('==> Computing data splits')
        # An ndarray (not a list) makes sklearn return ndarrays, so the return
        # type matches the np.load path above. No random_state is passed, so a
        # freshly computed split differs between runs; it is persisted to disk
        # so subsequent runs reuse it.
        indices = np.arange(n_samples)
        train_idx, valid_idx = sklearn.model_selection.train_test_split(
            indices, test_size=config.dataset_valid_size, stratify=y_labels)
        np.save(train_split_path, train_idx)
        np.save(valid_split_path, valid_idx)

    # Shuffle the training order in place; the file on disk keeps the
    # unshuffled split, so the order depends only on dataset_order_seed.
    rng = np.random.default_rng(config.dataset_order_seed)
    rng.shuffle(train_idx)

    return train_idx, valid_idx