import json
import os
import subprocess
from multiprocessing import Pool, cpu_count

from tqdm import tqdm

def run_cmd(task):
	"""Run ``main_eval_nearfar.py`` for a single ``(split, scene)`` task.

	The evaluation script is launched as a child process via the ``python``
	found on PATH, with ``--split`` and ``--env_scene`` set from *task*.

	Args:
		task: A ``(split, scene)`` tuple.

	Returns:
		The child process's exit code (0 on success). A non-zero code is also
		reported on stdout so failures are not silently lost.
	"""
	split, scene = task
	# Argument list (no shell): values are passed verbatim, without shell parsing.
	cmd = ["python", "main_eval_nearfar.py", "--split", split, "--env_scene", scene]
	cmd_str = " ".join(cmd)
	# flush so this banner is not held in the worker's stdout buffer while the child writes.
	print(f"Running: {cmd_str}", flush=True)
	result = subprocess.run(cmd)
	if result.returncode != 0:
		print(f"Failed (exit code {result.returncode}): {cmd_str}", flush=True)
	return result.returncode

def build_split_tasks(splits, scans_by_split):
	"""Build ``(split, scene)`` tasks covering every scene of each split.

	Args:
		splits: Split names to include, in the order tasks should be emitted.
		scans_by_split: Mapping from split name to its list of scan names.

	Returns:
		A list of ``(split, scene)`` tuples. Scenes within a split are
		deduplicated and sorted.

	Raises:
		ValueError: If a split in *splits* has no entry in *scans_by_split*.
	"""
	tasks = []
	for split in splits:
		if split not in scans_by_split:
			raise ValueError(f'Unknown split: {split}')
		# Keep only the part before the first '_' (presumably the base scene id
		# when scan names carry a suffix — TODO confirm), then dedupe and sort.
		scenes = sorted({scan.split('_')[0] for scan in scans_by_split[split]})
		tasks.extend((split, scene) for scene in scenes)
	return tasks


def build_error_tasks(error_files, scans_by_split):
	"""Map failed result filenames to unique ``(split, scene)`` re-run tasks.

	Filenames are expected to look like ``results_<scene>_<idx>.npy``. Several
	files may belong to the same scene, but the evaluation command only takes
	a split and a scene, so each scene is re-run once (first-seen order).

	Args:
		error_files: Result filenames whose evaluation should be redone.
		scans_by_split: Mapping from split name to its list of scan names. It is
			searched in iteration order, so the first split containing the
			scene wins.

	Returns:
		A list of unique ``(split, scene)`` tuples. Unparsable names and scenes
		that belong to no split are reported and skipped.
	"""
	tasks = []
	seen = set()
	for error in error_files:
		parts = error.split('_')
		if len(parts) < 2:
			print(f"Error: cannot parse a scene id from {error}.")
			continue
		error_scene = parts[1]
		split = next(
			(name for name, scans in scans_by_split.items() if error_scene in scans),
			None,
		)
		if split is None:
			print(f"Error: {error_scene} not in train, val, or test scans.")
			continue
		task = (split, error_scene)
		# Without this, one scene with N failed files would be launched N times in parallel.
		if task not in seen:
			seen.add(task)
			tasks.append(task)
	return tasks


if __name__ == "__main__":
	splits = ['train', 'val', 'test']
	train_scans = ['ZMojNkEp431', '1LXtFkjw3qL', 'sT4fr6TAbpF', 'r1Q1Z4BcV1o', 'cV4RVeZvu5T', 'EDJbREhghzL', 'PX4nDJXEHrG', 'YmJkqBEsHnH', 'ULsKaCPVFJR', '7y3sRwLe3Va', 'mJXqzFtmKg4', '759xd9YjKW5', '17DRP5sb8fy', 'ac26ZMwG7aT', 'Pm6F8kyY3z2', 'Vvot9Ly1tCj', 'PuKPg4mmafe', 'S9hNv5qa7GM', 'vyrNrziPKCB', 'SN83YJsR3w2', 'rPc6DW4iMge', 'r47D5H71a5s', 'qoiz87JEwZ2', '29hnd4uzFmX', '5LpN3gDmAk7', 'VFuaQ6m2Qom', 'i5noydFURQK', 'dhjEzFoUFzH', 'E9uDoFAP3SH', 's8pcmisQ38h', 'GdvgFV5R1Z5', '5q7pvUzZiYa', 'kEZ7cmS4wCh', 'JF19kD82Mey', 'pRbA3pwrgk9', '2n8kARJN3HM', 'HxpKQynjfin', 'VzqfbhrpDEA', 'Uxmj2M2itWa', 'V2XKFyX4ASd', 'JeFG25nYj2p', 'p5wJjkQkbXX', 'VLzqgDo317F', '8WUmhLawc2A', 'XcA2TqTSSAj', 'D7N2EKCX4Sj', 'aayBHfsNo7d', 'gZ6f7yhEvPG', 'D7G3Y4RVNrH', 'b8cTxDM8gDG', 'VVfe2KiqLaN', 'uNb9QFRL6hY', 'e9zR4mvMWw7', 'sKLMLpTHeUy', '82sE5b5pLXE', '1pXnuDYAj8r', 'jh4fc5c5qoQ', 'ur6pFq6Qu1A', 'B6ByNegPMKs', 'JmbYfDe2QKZ', 'gTV8FGcVJC9']
	val_scans = ['oLBMNvg9in8', 'TbHJrupSAjP', 'QUCTc6BB5sX', 'EU6Fwq7SyZv', 'zsNo4HB9uLZ', 'X7HyMhZNoso', 'x8F5xyUWy9e', '8194nk5LbLH', '2azQ1b91cZZ', 'pLe4wQe7qrG', 'Z6MFQCViBuw']
	test_scans = ['2t7WUuJeko7', '5ZKStnWn8Zo', 'ARNzJeq3xxb', 'RPmz2sHmrrY', 'UwV83HsGsw3', 'Vt2qJdWjCF2', 'WYY7iVyf5p8', 'YFuZgdQ5vWj', 'YVUC4YcDtcY', 'fzynW3qQPVF', 'gYvKGZ5eRqb', 'gxdoqLR6rwA', 'jtcxE69GiFV', 'pa4otMbVnkk', 'q9vSo1VnCiC', 'rqfALeAoiTq', 'wc2JMjhGNzB', 'yqstnuAEVhm']
	# Dict order (train, val, test) also decides which split a failed scene is assigned to.
	scans_by_split = {'train': train_scans, 'val': val_scans, 'test': test_scans}

	# Every scene of every split.
	all_tasks = build_split_tasks(splits, scans_by_split)

	# Result files whose evaluation failed and must be regenerated.
	errors = ['results_ur6pFq6Qu1A_0.npy', 'results_ur6pFq6Qu1A_1.npy', 'results_B6ByNegPMKs_0.npy', 'results_B6ByNegPMKs_1.npy', 'results_Vt2qJdWjCF2_1.npy', 'results_Vt2qJdWjCF2_2.npy', 'results_Vt2qJdWjCF2_3.npy', 'results_Vt2qJdWjCF2_4.npy', 'results_gxdoqLR6rwA_0.npy', 'results_gxdoqLR6rwA_1.npy', 'results_gxdoqLR6rwA_2.npy', 'results_gxdoqLR6rwA_3.npy', 'results_gxdoqLR6rwA_4.npy', 'results_gxdoqLR6rwA_5.npy']
	# errors = ['results_PuKPg4mmafe_1.npy']
	error_tasks = build_error_tasks(errors, scans_by_split)

	# True: re-run only the failed scenes. False: run the full evaluation (all_tasks).
	rerun_errors_only = True
	tasks = error_tasks if rerun_errors_only else all_tasks

	# Cap at 8 workers, never spawn more workers than tasks, and keep at least 1 (Pool(0) is invalid).
	num_workers = max(1, min(cpu_count(), 8, len(tasks)))

	with Pool(num_workers) as pool:
		list(tqdm(pool.imap_unordered(run_cmd, tasks), total=len(tasks)))