import os
import subprocess
import sys

# Environments to sweep, in launch order.
env_ids = [
    "media_streaming",
    "bridge_crossing",
    "colour_bomb_grid_world",
    "colour_bomb_grid_world_v2",
]

# Property file numbers used for each environment. Note that
# colour_bomb_grid_world_v2 uses properties 1, 2 and 4 (there is no 3 here).
_property_numbers = {
    "media_streaming": (1,),
    "bridge_crossing": (1,),
    "colour_bomb_grid_world": (1, 2),
    "colour_bomb_grid_world_v2": (1, 2, 4),
}

# Environment id -> list of property file paths, keyed in env_ids order.
property_paths = {
    env: [
        f"./properties/{env}/property_{number}.py"
        for number in _property_numbers[env]
    ]
    for env in env_ids
}

# Per-environment settings. Each dict is keyed by environment id and is
# forwarded to the training scripts as the command-line flag named below.

# --num-frames
num_frames = {
    "media_streaming": 25000,
    "bridge_crossing": 300000,
    "colour_bomb_grid_world": 100000,
    "colour_bomb_grid_world_v2": 300000,
}

# --random-action-probability
random_action_probabilities = {
    "media_streaming": 0.0,
    "bridge_crossing": 0.04,
    "colour_bomb_grid_world": 0.1,
    "colour_bomb_grid_world_v2": 0.1,
}

# --cost-coeff (only passed to train_modified_q_learning.py)
cost_coeff = {
    "media_streaming": 100.0,
    "bridge_crossing": 10.0,
    "colour_bomb_grid_world": 10.0,
    "colour_bomb_grid_world_v2": 10.0,
}

# --episode-length
episode_lengths = {
    "media_streaming": 40,
    "colour_bomb_grid_world": 100,
    "bridge_crossing": 400,
    "colour_bomb_grid_world_v2": 250,
}

# --df for the (modified) Q-learning scripts, --tp-df for the shielded one
discount_factors = {
    "media_streaming": 0.95,
    "colour_bomb_grid_world": 0.95,
    "bridge_crossing": 0.99,
    "colour_bomb_grid_world_v2": 0.95,
}

# --log-every
log_every = {
    "media_streaming": 400,
    "colour_bomb_grid_world": 1000,
    "bridge_crossing": 2000,
    "colour_bomb_grid_world_v2": 2500,
}

# Per-property settings for the shielded runs, keyed by property file path.
# The colour_bomb_grid_world_v2/property_3.py entries are kept even though
# that path is not listed in property_paths.

# Passed as --sat-prob.
satisfaction_probabilities = {
    "./properties/media_streaming/property_1.py": 0.999,
    # 0.85 is a good option; 0.99 gives the most restrictive shield that is
    # still possible.
    "./properties/bridge_crossing/property_1.py": 0.85,
    # Suggested values: 0.85 (least restrictive), 0.999 (very restrictive),
    # and a medium setting. NOTE(review): the original note listed 0.99 among
    # the options but described 0.98 as "medium" -- confirm which is meant.
    "./properties/colour_bomb_grid_world/property_1.py": 0.99,
    # 0.88 is a good option; 0.99 avoids all bombs and 0.999 is equivalent to
    # property 1 with 0.99.
    "./properties/colour_bomb_grid_world/property_2.py": 0.88,
    # Same suggested values (and the same 0.98 / 0.99 ambiguity) as
    # colour_bomb_grid_world/property_1.py.
    "./properties/colour_bomb_grid_world_v2/property_1.py": 0.98,
    # 0.85 is a good option; 0.99 avoids all bombs and 0.999 is equivalent to
    # property 1 with 0.99.
    "./properties/colour_bomb_grid_world_v2/property_2.py": 0.85,
    "./properties/colour_bomb_grid_world_v2/property_3.py": 0.999,
    # 0.96 is a good lower bound (the original note trailed the dict, so it
    # presumably refers to this last entry -- TODO confirm).
    "./properties/colour_bomb_grid_world_v2/property_4.py": 0.999,
}

# Passed as --num-samples.
num_samples = {
    "./properties/media_streaming/property_1.py": 8000,
    "./properties/bridge_crossing/property_1.py": 8000,
    "./properties/colour_bomb_grid_world/property_1.py": 16000,
    "./properties/colour_bomb_grid_world/property_2.py": 8000,
    "./properties/colour_bomb_grid_world_v2/property_1.py": 16000,
    "./properties/colour_bomb_grid_world_v2/property_2.py": 8000,
    "./properties/colour_bomb_grid_world_v2/property_3.py": 16000,
    "./properties/colour_bomb_grid_world_v2/property_4.py": 1000,
}

# Shield variants swept by the shielded Q-learning runs below.
shielding_type = ["action_cond_safe"]
model_checking_type = ["exact", "mc"]

# Optional flags implied by each model-checking type, stored as
# (pretrained flag, approximate-model flag). An empty string means the flag is
# not passed to the training script.
_MC_TYPE_FLAGS = {
    "exact": ("--pretrained-backup", ""),
    "mc": ("", "--approximate-model"),
}


def _mc_type_flags(mc_type):
    """Return the ``(pretrained, approximate)`` flags for *mc_type*.

    Raises:
        ValueError: if *mc_type* has no entry in ``_MC_TYPE_FLAGS``. Failing
            loudly avoids silently reusing the flags of a previous iteration.
    """
    try:
        return _MC_TYPE_FLAGS[mc_type]
    except KeyError:
        raise ValueError(
            f"unknown model checking type {mc_type!r}; "
            f"expected one of {sorted(_MC_TYPE_FLAGS)}"
        ) from None


def _common_args(property_path, env_id, seed, logdir):
    """Return the arguments shared by all three training scripts, in order."""
    return [
        "--property", property_path,
        "--num-frames", num_frames[env_id],
        "--env", env_id,
        "--random-action-probability", random_action_probabilities[env_id],
        "--episode-length", episode_lengths[env_id],
        "--log-every", log_every[env_id],
        "--seed", seed,
        "--logdir", logdir,
    ]


def _launch(script, args):
    """Run ``python <script> <args...>`` and wait for it to finish.

    The command is passed as an argument list, so no shell quoting is
    involved. A run that cannot be started or that exits with a non-zero
    status is reported on stderr; the sweep then continues with the next run.
    """
    command = ["python", script] + [str(arg) for arg in args]
    try:
        returncode = subprocess.run(command).returncode
    except OSError as err:
        print(f"warning: could not start {' '.join(command)}: {err}",
              file=sys.stderr)
        return
    if returncode != 0:
        print(f"warning: {' '.join(command)} exited with status {returncode}",
              file=sys.stderr)


for env_id in env_ids:
    for i, property_path in enumerate(property_paths[env_id]):
        # NOTE(review): the directory index is the property's position in
        # property_paths[env_id], not the number in its file name, so
        # colour_bomb_grid_world_v2/property_4.py logs under "property_3".
        # Left unchanged so existing log directories are still recognised --
        # TODO confirm this naming is intended.
        property_logdir = f"./logdir/{env_id}/property_{i+1}"

        for seed in range(10):
            # A run is skipped when its log directory already exists, so an
            # interrupted sweep can be resumed by re-running this script.

            # Plain Q-learning.
            logdir = f"{property_logdir}/q_learning_{seed}"
            if not os.path.isdir(logdir):
                _launch(
                    "train_q_learning.py",
                    _common_args(property_path, env_id, seed, logdir)
                    + ["--df", discount_factors[env_id]],
                )

            # Modified Q-learning (additionally takes a cost coefficient).
            logdir = f"{property_logdir}/modified_q_learning_{seed}"
            if not os.path.isdir(logdir):
                _launch(
                    "train_modified_q_learning.py",
                    _common_args(property_path, env_id, seed, logdir)
                    + [
                        "--cost-coeff", cost_coeff[env_id],
                        "--df", discount_factors[env_id],
                    ],
                )

            # Shielded Q-learning, once per (shielding type, model checking
            # type) combination.
            for sh_type in shielding_type:
                for mc_type in model_checking_type:
                    pretrained, approximate = _mc_type_flags(mc_type)
                    variant = (
                        f"{mc_type}"
                        f"{'_approx' if approximate else ''}"
                        f"{'_pretrained' if pretrained else ''}"
                        f"_{sh_type}"
                    )
                    logdir = f"{property_logdir}/{variant}_q_learning_{seed}"
                    if not os.path.isdir(logdir):
                        # Only flags that are actually set are passed on.
                        optional_flags = [
                            flag for flag in (pretrained, approximate) if flag
                        ]
                        _launch(
                            "train_shielded_q_learning.py",
                            _common_args(property_path, env_id, seed, logdir)
                            + ["--model-checking-type", mc_type]
                            + optional_flags
                            + [
                                "--shielding-type", sh_type,
                                "--num-samples", num_samples[property_path],
                                "--sat-prob", satisfaction_probabilities[property_path],
                                "--prior-type", "identity",
                                "--device-type", "gpu",
                                "--tp-df", discount_factors[env_id],
                            ],
                        )
