import glob
import os
from sklearn.model_selection import train_test_split
from tqdm.contrib.concurrent import process_map


def split_one_dataset(origin_path, spilt_ratio=0.3, random_state=None):
    """Split one ``<path> <label>`` list file into stratified train/test lists.

    Each non-blank line of ``origin_path`` is read as a sample path followed
    by a whitespace-separated label. The samples are split with
    ``train_test_split(..., stratify=labels)`` and written next to the input
    as ``<name>_train.txt`` and ``<name>_test.txt``, where ``<name>`` is the
    input file name without its extension and with any ``_list`` removed.

    Args:
        origin_path: Path of the list file to split.
        spilt_ratio: Forwarded to ``train_test_split`` as ``test_size``.
        random_state: Forwarded to ``train_test_split``; pass an int for a
            reproducible split. ``None`` keeps the previous (unseeded)
            behaviour.

    Raises:
        ValueError: If a non-blank line has no label field. The call to
            ``train_test_split`` may also raise when a stratified split is
            not possible (see the scikit-learn documentation).
    """
    # Read with the same encoding the outputs are written in, instead of
    # depending on the platform's locale default.
    with open(origin_path, encoding='utf-8') as f:
        samples = f.read().splitlines()

    paths, labels = [], []
    for line_no, sample in enumerate(samples, start=1):
        sample = sample.strip()
        if not sample:
            # Tolerate blank lines (e.g. a trailing empty line).
            continue
        # Split on the LAST whitespace run so that paths containing spaces
        # stay intact; the label is always the final field.
        fields = sample.rsplit(maxsplit=1)
        if len(fields) != 2:
            raise ValueError(
                f"{origin_path}:{line_no}: expected '<path> <label>', got {sample!r}"
            )
        path, label = fields
        paths.append(path)
        labels.append(label)

    train_paths, test_paths, train_labels, test_labels = train_test_split(
        paths,
        labels,
        test_size=spilt_ratio,
        stratify=labels,
        random_state=random_state,
    )

    # Strip '_list' from the file name only; doing it on the whole path would
    # also rewrite directory names that happen to contain '_list'.
    stem, _ = os.path.splitext(origin_path)
    directory, base_name = os.path.split(stem)
    out_prefix = os.path.join(directory, base_name.replace('_list', ''))

    with open(out_prefix + '_train.txt', 'w', encoding='utf-8') as f:
        f.write('\n'.join(f"{path} {label}" for path, label in zip(train_paths, train_labels)))
    with open(out_prefix + '_test.txt', 'w', encoding='utf-8') as f:
        f.write('\n'.join(f"{path} {label}" for path, label in zip(test_paths, test_labels)))


if __name__ == "__main__":
    data_root = '/home/username/datasets/office/'
    # To split a single list file instead of the whole directory:
    # split_one_dataset(os.path.join(data_root, 'Real.txt'))

    # Skip the *_train.txt / *_test.txt files written by earlier runs;
    # otherwise re-running the script would split its own outputs again.
    paths = [
        path
        for path in glob.glob(os.path.join(data_root, '*.txt'))
        if not path.endswith(('_train.txt', '_test.txt'))
    ]
    # Split every remaining list file, one worker-process task per file.
    process_map(split_one_dataset, paths)

