import os
from glob import glob
from pathlib import Path
import json
# Upper bound on the number of session files taken per (folder, object-count) pair.
num_session_limit = 10000000
# Root directory that contains one sub-directory per preference folder.
root = "artifacts/demo-pref:v1"
# Preference folders to include in the split. Every entry must end with a
# comma: adjacent string literals are silently concatenated by Python, which
# previously merged two folder names into a single non-existent one.
# Each folder is listed exactly once so that len(folder_list) (used in the
# output filename) equals the number of distinct preferences.
folder_list = [
    'pref_124_0653-top_first-True_front-center',
    'pref_124_0653-top_first-False_front-center',
    'pref_5234_160-top_first-True_front-center',
    'pref_5234_160-top_first-False_front-center',
    'pref_325_1460-top_first-False_front-center',
    'pref_253_0641-top_first-True_front-center',
    'pref_461_0532-top_first-True_front-center',
    'pref_0123_456-top_first-True_back-center',
    'pref_321_4056-top_first-True_back-left',
    'pref_5610_432-top_first-True_front-right',
    'pref_0163_245-top_first-True_front-right',
    'pref_3210_654-top_first-True_back-center',
    'pref_3541_260-top_first-False_back-left',
    'pref_415_0263-top_first-True_front-center',
    'pref_523_0641-top_first-False_front-left',
    'pref_2354_601-top_first-False_front-center',
]


# Held-out alternative for folder_list (kept commented out):
# folder_list =[
#     'pref_325_1460-top_first-True_front-center',
#     'pref_253_0641-top_first-False_front-center', 
#     'pref_123_6504-top_first-True_back-right',
#     'pref_1453_062-top_first-False_front-left',
#     'pref_5104_362-top_first-True_front-center',
# ]

# Object-count sub-directories to scan inside each preference folder.
# Alternatives left by the author: ['5', '6', '7'] or ['3', '4', '5', '8', '9', '10'].
num_obj_list = ['6', '7']

# Output filename tag: number of preference folders, the joined object
# counts, and the per-folder session cap.
suffix = "num_pref-{}_num_obj-{}_num_demo-{}".format(
    len(folder_list),
    "-".join(num_obj_list),
    num_session_limit,
)

# Fraction of each session list assigned to every split.
split_ratios = {"train": 0.8, "val": 0.1, "test": 0.1}

# split name -> {folder name -> list of session file paths}; filled in below.
session_split = {split_name: {} for split_name in ("train", "val", "test")}

# Fill session_split: for every preference folder, gather the session files of
# each requested object count and cut them into train / val / test slices.
for folder_name in folder_list:
    for split_name in ("train", "val", "test"):
        session_split[split_name][folder_name] = []
    for num_obj in num_obj_list:
        pattern = Path(root, folder_name, num_obj, "sess_*.json").as_posix()
        # glob() returns paths in filesystem-dependent order; sort before
        # applying the cap and the split so both are reproducible across runs.
        session_list = sorted(glob(pattern))[:num_session_limit]
        num_sessions = len(session_list)
        # Slice boundaries are floored, so a few trailing sessions may end up
        # in none of the splits when the counts do not divide evenly.
        train_end = int(num_sessions * split_ratios["train"])
        val_end = train_end + int(num_sessions * split_ratios["val"])
        test_end = val_end + int(num_sessions * split_ratios["test"])
        session_split["train"][folder_name] += session_list[:train_end]
        session_split["val"][folder_name] += session_list[train_end:val_end]
        session_split["test"][folder_name] += session_list[val_end:test_end]

# Write the split mapping as indented JSON. The target directory is created
# first because open(..., "w") does not create missing parent directories.
output_path = f"session_split_filepaths/num_pref/{suffix}.json"
os.makedirs(os.path.dirname(output_path), exist_ok=True)
with open(output_path, "w") as f:
    json.dump(session_split, f, indent=4)
