import time
import click
import socket
import multiprocessing as mp
from chester.run_exp import run_experiment_lite, VariantGenerator
from GNS.fabric_vsf.vismpc.scripts.predict_softgym import run_task


def _compile_settings(mode, idx):
    """Return the ``(compile_script, wait_compile)`` pair passed to chester for one launch.

    Args:
        mode: Launch mode from the command line (e.g. 'local', 'seuss', 'autobot', 'ec2').
        idx: Index of the variant being launched.

    Returns:
        For 'seuss'/'autobot': no compile script; the first launch does not wait and every
        later launch waits 120 seconds. For 'ec2': ``'compile_1.0.sh'`` with no wait.
        For any other mode: ``(None, None)``.
    """
    if mode in ('seuss', 'autobot'):
        # NOTE(review): no compile script is set, so the first launch does not compile
        # softgym itself; later launches still wait 120 s, presumably for a compilation
        # triggered elsewhere — confirm this is intended.
        return None, (None if idx == 0 else 120)
    if mode == 'ec2':
        return 'compile_1.0.sh', None
    return None, None


def _prune_and_wait(popens, max_parallel=10, poll_interval=10):
    """Block until fewer than ``max_parallel`` of ``popens`` are still running.

    Finished processes (``poll()`` is not None) are dropped before the first check, so
    the call returns immediately when a slot is already free instead of sleeping once.

    Args:
        popens: Process handles returned by ``run_experiment_lite``.
        max_parallel: Maximum number of processes allowed to run concurrently.
        poll_interval: Seconds to sleep between checks while the pool is full.

    Returns:
        A new list containing only the processes that are still running.
    """
    running = [p for p in popens if p.poll() is None]
    while len(running) >= max_parallel:
        time.sleep(poll_interval)
        running = [p for p in running if p.poll() is None]
    return running


@click.command()
@click.argument('mode', type=str, default='local')
@click.option('--debug/--no-debug', default=True)
@click.option('--dry/--no-dry', default=False)
def main(mode, debug, dry):
    """Launch vsf prediction runs (run_task) through chester, one per variant.

    MODE is the chester launch mode (default 'local'). With --debug (the default) a small
    debug configuration is used, the launcher waits for the subprocess, and only the first
    variant is launched. --dry is forwarded to chester's ``dry`` flag.
    """
    exp_prefix = '0612-vsf-predict-dual-small'
    vg = VariantGenerator()

    if not debug:
        vg.add('max_episodes', [10])
        # bimanual small
        # vg.add('model_dir', ['/data/user2w2/VCD_policy/data/local/0530-vsf-train-bimanual-small/0530-vsf-train-bimanual-small_2021_05_31_12_58_58_0001/output_data'])
        # vg.add('data_dir', ['/data/user2w2/VCD_policy/data/local/0530-vsf-train-bimanual-small/0530-vsf-train-bimanual-small_2021_05_31_12_58_58_0001/train_data'])
        # vg.add('input_img', ['./GNS/fabric_vsf/data/fold_test_data-8.pkl'])

        # dual small (active configuration)
        vg.add('model_dir', ['/data/user2w2/VCD_policy/data/local/0609-vsf-dual-small-5-fold-input-1-target-3/0609-vsf-dual-small-5-fold-input-1-target-3_2021_06_09_20_12_56_0001/output_data'])
        vg.add('data_dir', ['/data/user2w2/VCD_policy/data/local/0609-vsf-dual-small-5-fold-input-1-target-3/0609-vsf-dual-small-5-fold-input-1-target-3_2021_06_09_20_12_56_0001/train_data'])
        vg.add('input_img', ['./GNS/fabric_vsf/data/fold_test_data-8.pkl'])

        # single small
        # vg.add('model_dir', ['/data/user2w2/VCD_policy/data/local/0601-vsf-train-single-small/0601-vsf-train-single-small_2021_06_02_13_51_25_0001/output_data'])
        # vg.add('data_dir', ['/data/user2w2/VCD_policy/data/local/0601-vsf-train-single-small/0601-vsf-train-single-small_2021_06_02_13_51_25_0001/train_data'])
        # vg.add('input_img', ['./GNS/fabric_vsf/data/vsf-fold-test-data-4_debug_2021_06_06_15_38_09_0001/softgym_traj_0_10_pickle'])

        # dual large
        # vg.add('model_dir', ['/data/user2w2/VCD_policy/data/local/0531-vsf-train-bimanual-large/0531-vsf-train-bimanual-large_2021_06_01_01_29_02_0001/output_data'])
        # vg.add('data_dir', ['/data/user2w2/VCD_policy/data/local/0531-vsf-train-bimanual-large/0531-vsf-train-bimanual-large_2021_06_01_01_29_02_0001/train_data'])
        # vg.add('input_img', ['./GNS/fabric_vsf/data/vsf-fold-test-data-large-8_debug_2021_06_09_20_22_16_0001/softgym_traj_0_10_pickle'])

        # single large
        # NOTE(review): these model/data paths are identical to the "dual large" ones above —
        # confirm the intended single-large checkpoint.
        # vg.add('model_dir', ['/data/user2w2/VCD_policy/data/local/0531-vsf-train-bimanual-large/0531-vsf-train-bimanual-large_2021_06_01_01_29_02_0001/output_data'])
        # vg.add('data_dir', ['/data/user2w2/VCD_policy/data/local/0531-vsf-train-bimanual-large/0531-vsf-train-bimanual-large_2021_06_01_01_29_02_0001/train_data'])
        # vg.add('input_img', ['./GNS/fabric_vsf/data/vsf-fold-test-data-large-4_debug_2021_06_09_20_37_32_0001/softgym_traj_0_10_pickle'])

        vg.add('adim', ['8'])

    else:
        # NOTE(review): the debug configuration sets neither 'input_img' nor 'adim';
        # confirm run_task has defaults for them.
        vg.add('max_episodes', [2])
        vg.add('model_dir', ['GNS/fabric_vsf/data/output_data/'])
        vg.add('data_dir', ['GNS/fabric_vsf/data/train_data/'])

        exp_prefix += '_debug'

    # Enumerate the variant grid once and reuse it for both reporting and launching.
    variants = vg.variants()
    print('Number of configurations: ', len(variants))
    print("exp_prefix: ", exp_prefix)

    hostname = socket.gethostname()
    gpu_num = 1  # NOTE: hard-coded to 1 instead of torch.cuda.device_count()

    sub_process_popens = []
    for idx, vv in enumerate(variants):
        # Throttle: keep at most 10 launched experiments running at the same time.
        sub_process_popens = _prune_and_wait(sub_process_popens)
        compile_script, wait_compile = _compile_settings(mode, idx)
        # On autobot hosts, assign variants round-robin to the available GPUs.
        if hostname.startswith('autobot') and gpu_num > 0:
            env_var = {'CUDA_VISIBLE_DEVICES': str(idx % gpu_num)}
        else:
            env_var = None
        cur_popen = run_experiment_lite(
            stub_method_call=run_task,
            variant=vv,
            mode=mode,
            dry=dry,
            use_gpu=True,
            exp_prefix=exp_prefix,
            wait_subprocess=debug,
            compile_script=compile_script,
            wait_compile=wait_compile,
            env=env_var
        )
        if cur_popen is not None:
            sub_process_popens.append(cur_popen)
        if debug:
            # Debug mode launches only the first variant.
            break


if __name__ == "__main__":
    # Hand control to click, which parses sys.argv into MODE / --debug / --dry.
    main()
