import numpy as np
import os
import argparse
from tqdm import tqdm

from utils.rf_dataset import NpyStorage


def split_train_val(soi, train_frac, target_train, target_val):
    """Randomly split a stored signal set into training and validation folders.

    Every item of ``NpyStorage(soi, load_to_ram=False)`` is written exactly
    once with ``np.save`` as ``sig_<k>.npy`` into either ``target_train`` or
    ``target_val``; ``k`` counts up from 0 independently in each folder.
    Which items go to validation is decided by shuffling a boolean mask with
    NumPy's global random state, so call ``np.random.seed`` beforehand if a
    reproducible split is needed.

    Args:
        soi: Dataset handle passed straight to ``NpyStorage`` (presumably a
            path to the stored signals -- confirm against
            ``utils.rf_dataset``).
        train_frac: Fraction of the items, within [0, 1], assigned to the
            training split. The training count is ``int(n * train_frac)``
            (truncated); all remaining items go to validation.
        target_train: Output directory for training items; created if missing.
        target_val: Output directory for validation items; created if missing.

    Raises:
        ValueError: If ``train_frac`` is NaN or lies outside [0, 1].
    """
    # Reject bad fractions before touching the dataset or the filesystem.
    # Without this check a value > 1 silently puts everything in training
    # and a negative value puts everything in validation.
    if not 0.0 <= train_frac <= 1.0:
        raise ValueError(
            f"train_frac must be within [0, 1], got {train_frac!r}"
        )

    soi_store = NpyStorage(soi, load_to_ram=False)
    n = len(soi_store)
    train_size = int(n * train_frac)
    val_size = n - train_size

    # Mask with exactly `val_size` True entries; shuffling it in place picks
    # the validation items uniformly at random.
    in_val = np.zeros(n, dtype=bool)
    in_val[:val_size] = True
    np.random.shuffle(in_val)

    os.makedirs(target_train, exist_ok=True)
    os.makedirs(target_val, exist_ok=True)

    # Separate counters keep the file numbering contiguous in each folder.
    id_train = 0
    id_val = 0
    for i in tqdm(range(n)):
        if in_val[i]:
            np.save(os.path.join(target_val, f"sig_{id_val}.npy"), soi_store[i])
            id_val += 1
        else:
            np.save(os.path.join(target_train, f"sig_{id_train}.npy"), soi_store[i])
            id_train += 1

    print(f"Saved {id_train} training examples")
    print(f"Saved {id_val} validation examples")


if __name__ == "__main__":
    # Command-line entry point: parse the four options and run the split.
    parser = argparse.ArgumentParser(
        description="Randomly split a stored signal set into training and "
                    "validation folders of .npy files."
    )

    # Every option is mandatory. Left optional, a missing one would arrive as
    # None and only fail later inside the split with an unrelated TypeError.
    parser.add_argument("--soi", type=str, required=True,
                        help="dataset location handed to NpyStorage")
    parser.add_argument("--train_frac", type=float, required=True,
                        help="fraction of the items assigned to the training split")
    parser.add_argument("--target_train", type=str, required=True,
                        help="output directory for the training items")
    parser.add_argument("--target_val", type=str, required=True,
                        help="output directory for the validation items")

    args = parser.parse_args()
    split_train_val(args.soi,
                    args.train_frac,
                    args.target_train,
                    args.target_val)
