''' A demo file showing how to configure experiment variants and launch experiments.
    NOTE: Running this demo creates a `data` directory in the directory you run it from.
'''
import os
from exptools.launching.variant import VariantLevel, make_variants, update_config
from exptools.launching.affinity import encode_affinity, quick_affinity_code
from exptools.launching.exp_launcher import run_experiments
from exptools.launching.slurm import build_slurm_resource
from exptools.launching.exp_launcher import run_on_slurm

# Base configuration shared by every experiment variant. main() merges each
# variant's overrides on top of this dict via update_config(default_config, variant).
default_config = dict(
    env= "Reacher-v2", # (Swimmer-v2, Hopper-v2, Walker2d-v2, HalfCheetah-v2, Ant-v2)
    env_kwargs= dict(
        expert_path= "",
        student_path= None,
        # NOTE(review): "panelty" (sic) -- presumably the launched script reads
        # this exact key name, so do not rename it here in isolation.
        expert_panelty= 0,
        epsilon= 0.01, # swept per environment in main() (log dir prefix "a")
        hacked= False, # swept in main() (log dir name "physicalAttack{}")
        mode= "il",
        build_backup_student= False,
        optimizer_kwargs= dict(
            learning_rate= 0.001, # assuming Adam Optimizer
            beta1= 0.9,
            beta2= 0.999,
            epsilon= 1e-08,
            use_locking= False,
        ),
    ),
    env_type= None,
    seed= None, # filled in per variant by the seed level in main()
    alg= "ppo2",
    algo_kwargs= dict( # some other args specified to algorithms
        nsteps=2048,
        ent_coef=0.0,
        vf_coef=0.5,
        lr=3e-4,
        max_grad_norm=0.5,
        gamma=0.99,
        lam=0.95,
        log_interval=10,
        nminibatches=4,
        noptepochs=4,
        cliprange=0.2,
        save_interval=0,
    ),
    train_attacker= True,
    train_student= True,
    # NOTE: main() overwrites num_timesteps for every variant, deriving it from
    # this value and dagger_itr.
    num_timesteps= int(1e6), # num of iterations for training TRPO attack
    network= None, # (mlp, cnn, lstm, cnn_lstm, conv_only)
    gamestate= None,
    num_env= None,
    reward_scale= 1.0,
    pretrain_iterations= 2000,
    # save_path: change to log_dir
    save_video_interval= 0, # Save video every x steps (0 = disabled)
    save_video_length= 200, # Length of recorded video
    play= False,
    render= False,
    zero_order= True,
    trpo_lr= None,
    buffer_size= int(5e6), # main() sweeps this (and shrinks it in debug mode)
    dagger_itr= int(2e3),
    student_steps= int(1e3), # num of student training iteration in supervised learning manner
    batch_size= 64,
    num_rollouts= int(40), # num of rollouts for expert labeling
    log_kwargs= dict(
        save_interval= 25,
        n_file_kept= 50,
    ),
    script_name= "defense/run_aril.py", # script launched for every variant
)

def main(args):
    """Build the ARIL experiment variant grid and launch it locally or on SLURM.

    The grid is assembled from a sequence of ``VariantLevel`` objects: buffer
    size, which parts are trained, attack type, per-environment model paths
    and hyper-parameters, and random seeds. It is expanded with
    ``make_variants`` and each variant is merged over ``default_config``.
    Which values are swept is controlled by (un)commenting entries in the
    lists below.

    Args:
        args: parsed CLI namespace. Reads ``args.debug`` (int; when non-zero
            the runs are shrunk and the value is forwarded as the launcher's
            ``debug_mode``) and ``args.where`` (``"local"`` or ``"slurm"``).

    Raises:
        RuntimeError: if ``exp_env`` is not a supported environment, or if
            ``args.where`` is neither ``"local"`` nor ``"slurm"``.
    """
    experiment_title = "aril_experiment"

    # set up variants
    variant_levels = list()

    values = [
        # [int(5e4), ],
        # [int(5e5), ],
        # [int(5e5), ],
        [int(1e6), ],
        # [int(5e6), ],
    ]
    # Empty dir name: presumably this level adds no log-dir component.
    dir_names = ["" for _ in values]
    keys = [("buffer_size",)]
    variant_levels.append(VariantLevel(keys, values, dir_names))

    values = [
        # [False, True, False],
        # [True, False, False],
        [True, True, False],
    ]
    dir_names = ["tAttacker{}tStudent{}".format(*v) for v in values]
    # dir_names = ["".format(*v) for v in values]
    keys = [("train_attacker",), ("train_student",), ("env_kwargs", "build_backup_student"),]
    variant_levels.append(VariantLevel(keys, values, dir_names))

    values = [
        [True, ],
        # [False, ],
    ]
    dir_names = ["physicalAttack{}".format(*v) for v in values]
    keys = [("env_kwargs", "hacked",), ]
    variant_levels.append(VariantLevel(keys, values, dir_names))

    # default_config["script_name"] = "attack/run_attack.py"

    exp_env = "Swimmer"
    # exp_env = "Hopper"
    # exp_env = "Walker"
    # exp_env = "HalfCheetah"
    # exp_env = "Ant"
    #########################################
    if exp_env == "Ant":
        # model_dir = "./data/slurm/rarl_experiment/20210216/Ant-v2/a1.00e-02"
        model_dir = "./data/dagger/Ant-v2/"
        values = [
            ["Ant-v2",
                os.path.join(model_dir, "run_0/expert_snapshot"),
                os.path.join(model_dir, "run_0/snapshot-475"),
            ],
        ]
        dir_names = ["{}".format(v[0]) for v in values]
        keys = [("env",), ("env_kwargs", "expert_path"), ("env_kwargs", "student_path"),]
        variant_levels.append(VariantLevel(keys, values, dir_names))

        values = [
            # [1.0, ], # sensory >
            # [0.75, ], # sensory ~=
            # [0.5, ], # sensory ~=
            # [0.25, ], # sensory ~=
            # [0.175, ], # sensory ~=
            # [0.15, ],
            # [0.125, ],
            # [0.1, ], # sensory <, physical >
            # [0.075, ], # physical >
            # [0.0625, ], # physical >
            # [0.06, ],
            # [0.05625, ], # physical >
            # [0.05, ], # physical >
            # [0.04,],
            # [0.03,], # physical ~<
            # [0.02,], # physical ~<
            [0.01, ], # rarl is trained under 1e-2, physical <
        ]
        dir_names = ["a{}".format(v[0]) for v in values]
        keys = [("env_kwargs", "epsilon",), ]
        variant_levels.append(VariantLevel(keys, values, dir_names))
    ##########################################
    elif exp_env == "Swimmer":
        # model_dir = os.path.abspath("./data/local/aril_experiment/20201206/tAttackerFalsetStudentTrue/env_epsilon-1e-08/Swimmer-v2")
        model_dir = os.path.abspath("./data/dagger/Swimmer-v2")
        values = [
            ["Swimmer-v2",
                os.path.join(model_dir, "run_0/expert_snapshot"),
                os.path.join(model_dir, "run_0/snapshot-699"),
            ],
        ]
        dir_names = ["{}".format(v[0]) for v in values]
        keys = [("env",), ("env_kwargs", "expert_path"), ("env_kwargs", "student_path"),]
        variant_levels.append(VariantLevel(keys, values, dir_names))

        values = [
            # [0.5, ],
            # [0.25, ],
            # [0.225, ], # selected for sensory attack, physical >
            # [0.1, ], # physical ~>
            # [0.05, ], # physical >
            # [0.015, ], # physical ~~
            [0.014, ], # physical ~~
            # [0.0135, ],
            # [0.013, ], # physical ~~
            # [0.012, ],
            # [0.011, ],
            # [0.01, ], # physical >
            # [0.0075, ], # physical >
            # [0.00625, ],
            # [0.005, ], # rarl is trained under 5e-3
        ]
        dir_names = ["a{}".format(v[0]) for v in values]
        keys = [("env_kwargs", "epsilon",), ]
        variant_levels.append(VariantLevel(keys, values, dir_names))

        values = [
            # [-0.01, 1.],
            [-0.001, 1e3],
            # [0.01,  1.],
            # [0.0,  1.],
        ]
        dir_names = ["eEnt{}vf{}".format(*v) for v in values]
        keys = [("algo_kwargs", "ent_coef"), ("algo_kwargs", "vf_coef"), ]
        variant_levels.append(VariantLevel(keys, values, dir_names))
    #########################################
    elif exp_env == "Hopper":
        # model_dir = os.path.abspath("./data/local/aril_experiment/20201206/tAttackerFalsetStudentTrue/env_epsilon-1e-08/Hopper-v2")
        # model_dir = os.path.abspath("./data/dagger/Hopper-v2")
        model_dir = os.path.abspath("./data/slurm/aril_experiment/20210208/tAttackerFalsetStudentTrue/Hopper-v2/seed-666")
        values = [
            ["Hopper-v2",
                os.path.join(model_dir, "run_0/expert_snapshot"),
                os.path.join(model_dir, "run_0/snapshot-999"),
            ],
        ]
        dir_names = ["{}".format(v[0]) for v in values]
        keys = [("env",), ("env_kwargs", "expert_path"), ("env_kwargs", "student_path"),]
        variant_levels.append(VariantLevel(keys, values, dir_names))

        values = [
            [25, 100, int(1000)],
        ]
        dir_names = ["" for _ in values]
        keys = [("log_kwargs", "save_interval",), ("log_kwargs", "n_file_kept",), ("dagger_itr",)]
        variant_levels.append(VariantLevel(keys, values, dir_names))

        # values = [
        #     [0.1, ],
        #     # [0.025, ],
        #     # [0.01, ], # rarl is trained under 0.01
        #     # [1e-5, ],
        #     # [0.0, ],
        # ]
        # dir_names = ["a{}".format(v[0]) for v in values]
        # keys = [("env_kwargs", "epsilon",), ]
        # variant_levels.append(VariantLevel(keys, values, dir_names))
    ########################################
    elif exp_env == "Walker":
        # model_dir = os.path.abspath("./data/local/aril_experiment/20201206/tAttackerFalsetStudentTrue/env_epsilon-1e-16/Walker2d-v2")
        model_dir = os.path.abspath("./data/dagger/Walker2d-v2")
        values = [
            ["Walker2d-v2",
                os.path.join(model_dir, "run_0/expert_snapshot"),
                os.path.join(model_dir, "run_0/snapshot-700"),
            ],
        ]
        dir_names = ["{}".format(v[0]) for v in values]
        keys = [("env",), ("env_kwargs", "expert_path"), ("env_kwargs", "student_path"),]
        variant_levels.append(VariantLevel(keys, values, dir_names))

        values = [
            # [1.0, ], # >
            # [0.5, ], # >
            # [0.25, ], # >
            # [0.2, ], # >
            # [0.15, ], # >
            [0.125, ],
            # [0.1, ], # <
            # [0.01, ], # rarl is trained under 0.01
            # [0.0, ],
        ]
        dir_names = ["a{}".format(v[0]) for v in values]
        keys = [("env_kwargs", "epsilon",), ]
        variant_levels.append(VariantLevel(keys, values, dir_names))
    ########################################
    elif exp_env == "HalfCheetah":
        ### success experiments with good curves, no need to run again
        # model_dir = os.path.abspath("./data/local/aril_experiment/20201229/tAttackerFalsetStudentTrue/pre_lr-0.001/dagger_ep1e-07/HalfCheetah-v2/a1.0/seed-12138")
        model_dir = os.path.abspath("./data/dagger/HalfCheetah-v2")
        values = [
            ["HalfCheetah-v2",
                os.path.join(model_dir, "run_0/expert_snapshot"),
                os.path.join(model_dir, "run_0/snapshot-570"),
            ],
        ]
        dir_names = ["{}".format(v[0]) for v in values]
        keys = [("env",), ("env_kwargs", "expert_path"), ("env_kwargs", "student_path"),]
        variant_levels.append(VariantLevel(keys, values, dir_names))

        values = [
            # [1.5, ], # This epsilon is working in verifying ARIL
            # [1.25, ], # 1.0 is trained under rarl and workable under attack-only
            # [0.5, ], # verifying physical attack >>>
            # [0.125, ], # physical >
            # [0.05, ], # physical ~=
            [0.024,],
            # [0.0125, ], # physical <
        ]
        dir_names = ["a{}".format(v[0]) for v in values]
        keys = [("env_kwargs", "epsilon",), ]
        variant_levels.append(VariantLevel(keys, values, dir_names))
    ##########################################
    else:
        # Without this guard a mistyped name would silently run default_config's env.
        raise RuntimeError("Unknown experiment env: {}".format(exp_env))

    print("Experiment env name: ", exp_env)

    values = [
        [666,],
        [12138,],
        [65535,],
        [2048,],
    ]
    dir_names = ["seed-{}".format(*v) for v in values]
    keys = [("seed",), ]
    variant_levels.append(VariantLevel(keys, values, dir_names))

    # get all variants and their own log directory
    variants, log_dirs = make_variants(*variant_levels)
    for i, variant in enumerate(variants):
        variants[i] = update_config(default_config, variant)
        config = variants[i]
        # if not config["train_attacker"]:
        #     config["env_kwargs"]["student_path"] = None

        # if config["train_student"]:
        #     # make openai log 10 times a dagger_itr
        # Derive the per-run num_timesteps from the *merged* config, so that
        # per-variant overrides (e.g. the Hopper `dagger_itr` level above) are
        # honored instead of always using default_config's values.
        config["num_timesteps"] = int(config["num_timesteps"] / config["dagger_itr"] * 7)
        # if not config["train_student"]:
        #     config["dagger_itr"] = 2

        if args.debug:
            config["buffer_size"] = 10000
            config["num_timesteps"] = int(5e4)  # int, like the non-debug path

    if args.where == "local":
        affinity_code = encode_affinity(
            n_cpu_core=24,  # Total number to use on machine (not virtual).
            n_gpu=0,  # Total number to use on machine.
            cpu_reserved=1,  # Number CPU to reserve per GPU.
            contexts_per_gpu=6,  # e.g. 2 will put two experiments per GPU.
            gpu_per_run=0,  # For multi-GPU optimizaion.
            n_socket=None,  # Leave None for auto-detect.
            alternating=False,  # True for altenating sampler.
            set_affinity=True,  # Everything same except psutil.Process().cpu_affinity(cpus)
        )
        # NOTE: you can also use encode_affinity to specifying how to distribute each
        # experiment in your computing nodes.
        run_experiments(
            script= default_config["script_name"],
            affinity_code= affinity_code,
            experiment_title= experiment_title + ("--debug" if args.debug else ""),
            runs_per_setting= 1, # how many times to run repeated experiments
            variants= variants,
            log_dirs= log_dirs,
            debug_mode= args.debug, # if greater than 0, the launcher will run one variant in this process)
        )
    elif args.where == "slurm":
        slurm_resource = build_slurm_resource(
            mem= "16G",
            time= "3-12:00:00",
            n_gpus= 0,
            partition= "short",
            # cuda_module= "cuda-10.0",
        )
        run_on_slurm(
            script= default_config["script_name"],
            slurm_resource= slurm_resource,
            experiment_title= experiment_title + ("--debug" if args.debug else ""),
            # experiment_title= "temp_test" + ("--debug" if args.debug else ""),
            script_name= exp_env + "_exp",
            runs_per_setting= 1,
            variants= variants,
            log_dirs= log_dirs,
            debug_mode= args.debug,
        )
    else:
        raise RuntimeError("Unspecified experiment resource.")

if __name__ == "__main__":
    # Command-line entry point: parse launch options, optionally wait for a
    # remote debugger, then hand off to main().
    import argparse

    cli = argparse.ArgumentParser()
    cli.add_argument(
        '--debug',
        type= int,
        default= 0,
        help= 'A common setting of whether to entering debug mode for remote attach',
    )
    cli.add_argument(
        '--where',
        type= str,
        default= "local",
        choices= ["local", "slurm",],
        help= 'decide where to run the experiment',
    )
    args = cli.parse_args()

    if args.debug > 0:
        # Remote-attach debugging via ptvsd: listen on all interfaces, block
        # until a debugger attaches, then stop at a breakpoint right here.
        import ptvsd
        import sys

        debug_host, debug_port = '0.0.0.0', 6789
        ip_address = (debug_host, debug_port)
        command_line = " ".join(sys.argv)
        print("Process: " + command_line)
        print("Is waiting for attach at address: %s:%d" % ip_address, flush= True)
        ptvsd.enable_attach(address=ip_address, redirect_output= True)
        ptvsd.wait_for_attach()
        print("Process attached, start running into experiment...", flush= True)
        ptvsd.break_into_debugger()

    main(args)