import glob
import os
import random
import sys

# Dataset directory names under the pretraining data root. They double as the
# filename prefixes for the train/val split files this script writes.
location_dataset_name, navigation_dataset_name = (
    'location-2000-v1',
    'navigation-refine-v1',
)

def generate_split(data_size, train_rate, seed=None):
    """Randomly partition ``range(data_size)`` into train and validation indexes.

    Args:
        data_size: Number of items to split; must be non-negative.
        train_rate: Fraction of items assigned to the training split, in
            [0, 1]. The training size is ``int(data_size * train_rate)``,
            i.e. rounded down.
        seed: Optional seed for a private ``random.Random`` instance, giving
            a reproducible split without touching global RNG state. When
            ``None`` (the default), the module-level ``random`` generator is
            used, so callers can still control it with ``random.seed``.

    Returns:
        A ``(train_indexes, val_indexes)`` tuple of disjoint lists that
        together contain every integer in ``range(data_size)`` exactly once,
        in shuffled order.

    Raises:
        ValueError: If ``data_size`` is negative or ``train_rate`` is not in
            [0, 1] (this also rejects NaN).
    """
    if data_size < 0:
        raise ValueError(f'data_size must be non-negative, got {data_size}')
    # Written as a chained comparison so NaN (which compares False) is rejected too.
    if not 0 <= train_rate <= 1:
        raise ValueError(f'train_rate must be in [0, 1], got {train_rate}')

    train_size = int(data_size * train_rate)
    indexes = list(range(data_size))
    # Fall back to the global generator to keep the original behavior when no seed is given.
    rng = random if seed is None else random.Random(seed)
    rng.shuffle(indexes)

    return indexes[:train_size], indexes[train_size:]

# Root of the pretraining data. NOTE(review): this is resolved relative to the
# current working directory, not this script's location — run from the
# expected directory.
DATA_ROOT = '../map-pretrain-data'
SPLIT_DIR = f'{DATA_ROOT}/data-split'
# Fraction of each dataset assigned to the training split.
TRAIN_RATE = 0.8
# Fixed seed so re-running the script reproduces the same split.
SPLIT_SEED = 42

location_pickle_paths = glob.glob(f'{DATA_ROOT}/{location_dataset_name}/*.pkl')
navigation_paths = glob.glob(f'{DATA_ROOT}/{navigation_dataset_name}/*')

# An empty match almost always means a wrong working directory or dataset
# name; fail loudly instead of silently writing empty split files.
for dataset_name, dataset_paths in (
    (location_dataset_name, location_pickle_paths),
    (navigation_dataset_name, navigation_paths),
):
    if not dataset_paths:
        sys.exit(f"No files found for dataset '{dataset_name}' under {DATA_ROOT}")

# Only the number of matched files is used: the splits are positional indexes.
# NOTE(review): consumers must enumerate the dataset files in a consistent
# order (glob order is not guaranteed) for these indexes to be meaningful.
random.seed(SPLIT_SEED)
location_train, location_val = generate_split(len(location_pickle_paths), TRAIN_RATE)
navigation_train, navigation_val = generate_split(len(navigation_paths), TRAIN_RATE)

os.makedirs(SPLIT_DIR, exist_ok=True)

# Each split file holds the indexes as a single line of space-separated integers.
split_files = {
    f'{location_dataset_name}-train.txt': location_train,
    f'{location_dataset_name}-val.txt': location_val,
    f'{navigation_dataset_name}-train.txt': navigation_train,
    f'{navigation_dataset_name}-val.txt': navigation_val,
}
for split_filename, split_indexes in split_files.items():
    with open(f'{SPLIT_DIR}/{split_filename}', 'w', encoding='utf-8') as file:
        file.write(' '.join(str(x) for x in split_indexes))