''' A demo script showing how to configure experiment variants and launch experiments.
    NOTE: Running this demo creates a `data` directory in the current working directory.
'''
import copy

from exptools.launching.variant import VariantLevel, make_variants, update_config
from exptools.launching.affinity import encode_affinity, quick_affinity_code
from exptools.launching.exp_launcher import run_experiments
from exptools.launching.slurm import build_slurm_resource
from exptools.launching.exp_launcher import run_on_slurm

# Baseline configuration shared by every experiment variant; ``main`` merges
# per-variant overrides into it via ``update_config``.
default_config = dict(
    env="Reacher-v2",  # alternatives: Swimmer-v2, Hopper-v2, Walker2d-v2, HalfCheetah-v2
    env_type=None,
    env_kwargs=dict(
        attacker_space=1e-4,  # size of the attacker's action space (sometimes called `epsilon`)
        random_seed=666,
        # Probability of attacking the victim when the environment is used to
        # train the victim; choosing it is what makes the victim robust.
        be_attacked_prob=0.5,
    ),
    alg="ppo2",
    algo_kwargs=dict(  # extra arguments specific to the chosen algorithm
        nsteps=2048,
        ent_coef=0.0,
        vf_coef=0.5,
        lr=3e-4,
        max_grad_norm=0.5,
        gamma=0.99,
        lam=0.95,
        log_interval=10,
        nminibatches=4,
        noptepochs=4,
        cliprange=0.2,
        save_interval=0,
        seed=666,
    ),
    save_video_interval=0,
    network=None,
    attacker_name="attacker",  # TensorFlow variable scope for the attacker
    victim_name="victim",  # TensorFlow variable scope for the victim
    alternate_itr=100,
    attacker_inner_timesteps=int(1e6),
    victim_inner_timesteps=int(1e6),
    log_kwargs=dict(
        save_interval=1,
        n_file_kept=1000,
    ),
)

def _build_variants(exp_env, debug):
    """Build every experiment variant config and its log directory.

    Three variant levels are combined with ``make_variants``:
      1. victim name plus attacker/victim inner timestep budgets,
      2. the environment id (``exp_env + "-v2"``),
      3. the attacker action-space size ``env_kwargs.attacker_space``.

    Args:
        exp_env: environment name prefix, e.g. ``"Ant"``.
        debug: when truthy, override each variant with small timestep budgets
            and ``algo_kwargs.save_interval = 2`` for quick debug runs.

    Returns:
        ``(variants, log_dirs)``: the merged config dicts and the log
        directories produced by ``make_variants``.
    """
    variant_levels = list()

    # Level 1: victim name and inner-loop timestep budgets (no sub-directory).
    values = [
        ["expert", 2048 * int(2e3), 2048 * int(2e3)],
    ]
    dir_names = ["" for _ in values]
    keys = [("victim_name",), ("attacker_inner_timesteps",), ("victim_inner_timesteps",)]
    variant_levels.append(VariantLevel(keys, values, dir_names))

    # Level 2: environment id.
    values = [
        # ["Swimmer-v2"],
        # ["Hopper-v2"], # X
        # ["Walker2d-v2"], # X
        # ["HalfCheetah-v2"],
        [exp_env + "-v2", ],
    ]
    dir_names = ["{}".format(*v) for v in values]
    keys = [("env",)]
    variant_levels.append(VariantLevel(keys, values, dir_names))

    # Level 3: attacker action-space size; directory names look like "a5.00e-01".
    values = [
        # [10.0,],
        # [5.0,],
        # [1.0,],
        [0.5,],
        [1e-2,],
        [5e-3,],
        [1e-3,],
        # [5e-4,],
        # [1e-4,],
    ]
    dir_names = ["a{0:.2e}".format(*v) for v in values]
    keys = [
        ("env_kwargs", "attacker_space"),
    ]
    variant_levels.append(VariantLevel(keys, values, dir_names))

    # Get all variants and their own log directory.
    variants, log_dirs = make_variants(*variant_levels)
    for i, variant in enumerate(variants):
        # Deep-copy the merged config. NOTE(review): a merge like update_config
        # may return nested dicts it did not override (e.g. ``algo_kwargs``)
        # shared with ``default_config``; without the copy, the in-place debug
        # override below would leak into default_config and every variant.
        config = copy.deepcopy(update_config(default_config, variant))
        if debug:
            config["algo_kwargs"]["save_interval"] = 2
            config["attacker_inner_timesteps"] = 40960
            config["victim_inner_timesteps"] = 40960
        variants[i] = config
    return variants, log_dirs


def main(args):
    """Configure the RARL experiment variants and launch them.

    Args:
        args: parsed command-line namespace. Reads ``args.debug`` (int; truthy
            enables debug overrides and the ``--debug`` title suffix, and is
            forwarded as ``debug_mode``) and ``args.where`` (``"local"`` or
            ``"slurm"``).

    Raises:
        RuntimeError: if ``args.where`` is neither ``"local"`` nor ``"slurm"``.
    """
    experiment_title = "rarl_experiment"
    exp_env = "Ant"

    variants, log_dirs = _build_variants(exp_env, args.debug)
    # Debug runs get a "--debug" suffix on the experiment title.
    title = experiment_title + ("--debug" if args.debug else "")

    if args.where == "local":
        affinity_code = encode_affinity(
            n_cpu_core=28,  # Total number to use on machine (not virtual).
            n_gpu=4,  # Total number to use on machine.
            cpu_reserved=1,  # Number CPU to reserve per GPU.
            contexts_per_gpu=7,  # e.g. 2 will put two experiments per GPU.
            gpu_per_run=1,  # For multi-GPU optimization.
            n_socket=None,  # Leave None for auto-detect.
            alternating=False,  # True for alternating sampler.
            set_affinity=True,  # Everything same except psutil.Process().cpu_affinity(cpus)
        )
        run_experiments(
            script="defense/run_rarl.py",
            affinity_code=affinity_code,
            experiment_title=title,
            runs_per_setting=1,  # how many times to repeat each setting
            variants=variants,
            log_dirs=log_dirs,
            debug_mode=args.debug,  # if greater than 0, the launcher will run one variant in this process
        )
    elif args.where == "slurm":
        slurm_resource = build_slurm_resource(
            mem="16G",
            time="3-12:00:00",
            n_gpus=1,
            partition="short",
            cuda_module="cuda-10.0",
        )
        run_on_slurm(
            script="defense/run_rarl.py",
            slurm_resource=slurm_resource,
            experiment_title=title,
            script_name=exp_env + "_exp",
            runs_per_setting=1,
            variants=variants,
            log_dirs=log_dirs,
            debug_mode=args.debug,
        )
    else:
        raise RuntimeError("Unspecified experiment resource.")

if __name__ == "__main__":
    import argparse

    def _wait_for_remote_debugger():
        """Open a ptvsd endpoint, block until a client attaches, then break in."""
        import ptvsd
        import sys

        endpoint = ('0.0.0.0', 6789)
        print("Process: " + " ".join(sys.argv[:]))
        print("Is waiting for attach at address: %s:%d" % endpoint, flush=True)
        # Bind on all interfaces so a debugger on another machine can attach.
        ptvsd.enable_attach(address=endpoint, redirect_output=True)
        # Execution is paused here until a debugger client connects.
        ptvsd.wait_for_attach()
        print("Process attached, start running into experiment...", flush=True)
        ptvsd.break_into_debugger()

    cli = argparse.ArgumentParser()
    cli.add_argument(
        '--debug',
        type=int,
        default=0,
        help='A common setting of whether to entering debug mode for remote attach',
    )
    cli.add_argument(
        '--where',
        type=str,
        default="local",
        choices=["local", "slurm"],
        help='decide where to run the experiment',
    )
    args = cli.parse_args()

    # Any positive --debug value waits for a remote debugger before running.
    if args.debug > 0:
        _wait_for_remote_debugger()

    main(args)