import os

from dataset import get_datalist

def get_samples(root):
    """Collect the sample directories of each of the nine ``param{i}`` folds.

    Args:
        root: Dataset root directory; it must contain the sub-directories
            ``param0`` ... ``param8``.

    Returns:
        A list of nine lists (one per fold, in fold order).  Each inner list
        holds the paths, relative to ``root`` (``param{i}/<name>``), of the
        sub-directories found directly inside that fold.  Regular files are
        ignored.  Entries are sorted by name so the result is reproducible.

    Raises:
        OSError: If one of the ``param{i}`` directories cannot be listed
            (for example ``FileNotFoundError`` when it is missing).
    """
    samples = []
    for fold in (f'param{i}' for i in range(9)):
        fold_dir = os.path.join(root, fold)
        # os.listdir() yields entries in arbitrary, filesystem-dependent
        # order; sort them so the sample order is stable across runs/machines.
        samples.append([
            os.path.join(fold, entry)
            for entry in sorted(os.listdir(fold_dir))
            if os.path.isdir(os.path.join(fold_dir, entry))
        ])
    # Original author's note on the expected dataset size:
    # 100 + 99 + 97 + 100 + 100 + 96 + 100 + 98 + 99 = 889 samples
    return samples

def load_train_val_fold(args, preprocessed):
    """Build the training and validation datasets for one held-out fold.

    The folds returned by ``get_samples(args.data_dir)`` are split by index:
    fold ``args.fold_id`` becomes the validation list and all other folds are
    concatenated, in fold order, into the training list.

    Args:
        args: Object providing ``data_dir``, ``fold_id`` and ``save_dir``.
        preprocessed: Forwarded unchanged to ``get_datalist``.

    Returns:
        A ``(train_dataset, val_dataset, coef_norm)`` tuple, where
        ``coef_norm`` is the second value returned by the training
        ``get_datalist`` call (made with ``norm=True``) and is passed on to
        the validation call.
    """
    folds = get_samples(args.data_dir)
    held_out = args.fold_id

    # Every fold except the held-out one contributes to the training list.
    train_names = [
        name
        for idx, fold_samples in enumerate(folds)
        if idx != held_out
        for name in fold_samples
    ]

    # An out-of-range fold id leaves the validation list as None.
    # NOTE(review): confirm that get_datalist accepts None as its sample list.
    if 0 <= held_out < len(folds):
        val_names = folds[held_out]
    else:
        val_names = None

    train_dataset, coef_norm = get_datalist(
        args.data_dir,
        train_names,
        norm=True,
        savedir=args.save_dir,
        preprocessed=preprocessed,
    )
    val_dataset = get_datalist(
        args.data_dir,
        val_names,
        coef_norm=coef_norm,
        savedir=args.save_dir,
        preprocessed=preprocessed,
    )

    return train_dataset, val_dataset, coef_norm

