import json
import math
import os
from datetime import datetime
from socket import gethostname
from subprocess import run

from doexp import cmd, In, Out, GLOBAL_CONTEXT


# Hostname of the machine evaluating this file; compared against SLURM_HOST
# further down to decide whether the Tianshou experiments are registered.
HOST = gethostname()

# Weights & Biases destination. The entity is both exported through the
# process environment and passed explicitly as --wandb-entity when the
# commands below are built.
WANDB_ENTITY = "[redacted]"
os.environ["WANDB_ENTITY"] = WANDB_ENTITY
os.environ["WANDB_PROJECT"] = "[redacted]"

# Seeds iterated over by the experiment loops below.
seeds = list(range(10))

# Used by Tianshou
# NOTE: despite the name, this list mixes -v3 and -v2 environment ids.
mujoco_env_names_v3 = [
    "HalfCheetah-v3",
    "Walker2d-v3",
    "Hopper-v3",
    "Swimmer-v3",
    "InvertedPendulum-v2",
    "Reacher-v2",
]

# Used by trust-region-layers
mujoco_envs = [
    "HalfCheetah-v2",
    "Hopper-v2",
    "Walker2d-v2",
    "Swimmer-v3",
    "InvertedDoublePendulum-v2",
    "Reacher-v2",
]


# Number of epochs passed to every Tianshou command via --epoch; the per-epoch
# step count is derived as ceil(total_steps / EPOCHS).
EPOCHS = 500  # Basically, this determines the x-axis plot resolution for Tianshou logs


# Meta-World task names (50 entries, written without a version suffix).
# The position of a task in this list is used as a priority tie-breaker
# (-env_i) by the loops below, so the order is significant.
MT50_ENV_NAMES = [
    "assembly",
    "basketball",
    "bin-picking",
    "box-close",
    "button-press-topdown",
    "button-press-topdown-wall",
    "button-press",
    "button-press-wall",
    "coffee-button",
    "coffee-pull",
    "coffee-push",
    "dial-turn",
    "disassemble",
    "door-close",
    "door-lock",
    "door-open",
    "door-unlock",
    "hand-insert",
    "drawer-close",
    "drawer-open",
    "faucet-open",
    "faucet-close",
    "hammer",
    "handle-press-side",
    "handle-press",
    "handle-pull-side",
    "handle-pull",
    "lever-pull",
    "peg-insert-side",
    "pick-place-wall",
    "pick-out-of-hole",
    "reach",
    "push-back",
    "push",
    "pick-place",
    "plate-slide",
    "plate-slide-side",
    "plate-slide-back",
    "plate-slide-back-side",
    "peg-unplug-side",
    "soccer",
    "stick-push",
    "stick-pull",
    "push-wall",
    "reach-wall",
    "shelf-place",
    "sweep-into",
    "sweep",
    "window-open",
    "window-close",
]

# Ten Meta-World task ids used as transfer targets below. Unlike
# MT50_ENV_NAMES, these entries carry an explicit "-v2" suffix.
MT10_ENV_NAMES = [
    "reach-v2",
    "push-v2",
    "pick-place-v2",
    "door-open-v2",
    "drawer-open-v2",
    "drawer-close-v2",
    "button-press-topdown-v2",
    "peg-insert-side-v2",
    "window-open-v2",
    "window-close-v2",
]


def _tianshou_cmd(
    *,
    script,
    log_prefix,
    default_total_steps,
    seed,
    env,
    group,
    priority,
    cores,
    add_to_path,
    total_steps,
    extra_args,
):
    """Register one Tianshou training command; shared by the launchers below.

    Args:
        script: Path of the training script handed to ``python``.
        log_prefix: Top-level directory name used for the ``--log-dir`` output.
        default_total_steps: Step budget used when ``total_steps`` is None.
        seed: Value passed as ``--seed``; also part of the log directory name.
        env: Value passed as ``--env``; also part of the log directory name.
        group: Value passed as ``--wandb-group``; also part of the log
            directory name.
        priority: Scheduling priority forwarded to ``cmd``. ``None`` means
            ``(50, -seed)``.
        cores: Core count forwarded to ``cmd``.
        add_to_path: Names of entries in ``extra_args`` to encode in the log
            directory name. ``None`` means the first five keys of
            ``extra_args``. A name that is not in ``extra_args`` is rendered
            with the value ``None``.
        total_steps: Total step budget; split evenly (rounded up) over EPOCHS.
        extra_args: Mapping of additional hyperparameters. Each ``key=value``
            pair is forwarded as ``--key=value`` with underscores in the key
            replaced by dashes.

    Returns:
        Whatever ``cmd`` returns for the registered command.
    """
    if total_steps is None:
        total_steps = default_total_steps
    if priority is None:
        priority = (50, -seed)
    if add_to_path is None:
        add_to_path = list(extra_args)[:5]

    # The path fragment is kept even when it is empty (which yields a double
    # underscore before "group=") so that directory names stay identical to
    # those of runs that were produced before.
    kwargs_path = "_".join(
        f"{k.replace('_', '-')}={extra_args.get(k)}" for k in add_to_path
    )
    extra_flags = [f"--{k.replace('_', '-')}={v}" for k, v in extra_args.items()]

    return cmd(
        "python",
        script,
        "--seed",
        seed,
        "--env",
        env,
        "--epoch", EPOCHS,
        "--step-per-epoch", math.ceil(total_steps / EPOCHS),
        *extra_flags,
        "--wandb-entity", WANDB_ENTITY,
        "--wandb-group",
        group,
        "--log-dir",
        Out(f"{log_prefix}/env={env}_seed={seed}_{kwargs_path}_group={group}/"),
        warmup_time=3,
        ram_gb=6,
        priority=priority,
        cores=cores,
    )


def mujoco_fixpo_tianshou(
    seed, env, group, priority=None, cores=2, add_to_path=None, total_steps=None, **kwargs
):
    """Register a ``src/mujoco_fixpo_tianshou.py`` run (default 10M steps).

    Extra keyword arguments are forwarded to the script as ``--key=value``
    flags; outputs go under ``fixpo_tianshou/``.
    """
    return _tianshou_cmd(
        script="src/mujoco_fixpo_tianshou.py",
        log_prefix="fixpo_tianshou",
        default_total_steps=10_000_000,
        seed=seed,
        env=env,
        group=group,
        priority=priority,
        cores=cores,
        add_to_path=add_to_path,
        total_steps=total_steps,
        extra_args=kwargs,
    )


def mujoco_ppo_tianshou(
    seed, env, group, priority=None, cores=2, add_to_path=None, total_steps=None, **kwargs
):
    """Register a ``src/mujoco_ppo_tianshou.py`` run (default 10M steps).

    Extra keyword arguments are forwarded to the script as ``--key=value``
    flags; outputs go under ``ppo_tianshou/``.
    """
    return _tianshou_cmd(
        script="src/mujoco_ppo_tianshou.py",
        log_prefix="ppo_tianshou",
        default_total_steps=10_000_000,
        seed=seed,
        env=env,
        group=group,
        priority=priority,
        cores=cores,
        add_to_path=add_to_path,
        total_steps=total_steps,
        extra_args=kwargs,
    )


def metaworld_fixpo_tianshou(
    seed, env, group, priority=None, cores=2, add_to_path=None, total_steps=None, **kwargs
):
    """Register a ``src/metaworld_fixpo_tianshou.py`` run (default 20M steps).

    Extra keyword arguments are forwarded to the script as ``--key=value``
    flags; outputs go under ``fixpo_tianshou/``.
    """
    return _tianshou_cmd(
        script="src/metaworld_fixpo_tianshou.py",
        log_prefix="fixpo_tianshou",
        default_total_steps=20_000_000,
        seed=seed,
        env=env,
        group=group,
        priority=priority,
        cores=cores,
        add_to_path=add_to_path,
        total_steps=total_steps,
        extra_args=kwargs,
    )


def metaworld_ppo_tianshou(
    seed, env, group, priority=None, cores=2, add_to_path=None, total_steps=None, **kwargs
):
    """Register a ``src/metaworld_ppo_tianshou.py`` run (default 20M steps).

    Extra keyword arguments are forwarded to the script as ``--key=value``
    flags; outputs go under ``ppo_tianshou/``.
    """
    return _tianshou_cmd(
        script="src/metaworld_ppo_tianshou.py",
        log_prefix="ppo_tianshou",
        default_total_steps=20_000_000,
        seed=seed,
        env=env,
        group=group,
        priority=priority,
        cores=cores,
        add_to_path=add_to_path,
        total_steps=total_steps,
        extra_args=kwargs,
    )

# Tianshou experiments are only registered on the SLURM host.
# NOTE(review): SLURM_HOST is neither defined nor imported in the visible part
# of this file; confirm where it is supposed to come from, otherwise this
# comparison raises NameError.
if HOST == SLURM_HOST:
    # for seed in seeds:
    #     for env in mujoco_env_names_v3:
    #             total_steps = 10_000_000
    #             group = "fixpo-tianshou-mujoco-core-profile"
    #             cores = seed + 2
    #             cmd(
    #                 "python",
    #                 "src/mujoco_fixpo_tianshou.py",
    #                 "--seed",
    #                 seed,
    #                 "--env",
    #                 env,
    #                 "--epoch", EPOCHS,
    #                 "--step-per-epoch", math.ceil(total_steps / EPOCHS),
    #                 "--wandb-entity", WANDB_ENTITY,
    #                 "--wandb-group",
    #                 "fixpo-tianshou-mujoco-core-profile",
    #                 "--log-dir",
    #                 Out(f"fixpo_tianshou/env={env}_seed={seed}_cores={cores}_group={group}/"),
    #                 warmup_time=3,
    #                 ram_gb=6,
    #                 priority=(60, -seed),
    #                 cores=cores,
    #             )
    for seed in seeds:
        if seed < 3:
            base_priority = 70
        else:
            # Seeds >= 3 are currently skipped altogether, so the priority
            # assigned here is never used; it only takes effect if the
            # `continue` is removed.
            base_priority = 40
            continue

        # Base runs: FixPO on every MT50 task.
        for env_i, env in enumerate(MT50_ENV_NAMES):
            metaworld_fixpo_tianshou(
                seed=seed,
                env=env,
                group="fixpo-tianshou-metaworld",
                step_per_collect=10_000,
                priority=(base_priority, -seed, -env_i),
            )

        for env_i, env in enumerate(MT10_ENV_NAMES):
            # Transfer: start from the pick-place policy written by the base
            # run above (same seed) and train on the MT10 task.
            group = "fixpo-tianshou-metaworld-transfer"
            cmd(
                "python",
                "src/metaworld_fixpo_tianshou.py",
                "--seed",
                seed,
                "--env",
                env,
                "--base-task-path", In(f"fixpo_tianshou/env=pick-place_seed={seed}_step-per-collect=10000_group=fixpo-tianshou-metaworld/policy.pth"),
                "--wandb-entity", WANDB_ENTITY,
                "--wandb-group", group,
                "--log-dir",
                Out(f"fixpo_tianshou/env={env}_seed={seed}_group={group}/"),
                warmup_time=3,
                ram_gb=6,
                priority=(base_priority - 5, -seed, -env_i),
                cores=2,
            )

            # Transfer back: start from the policy produced by the transfer
            # run above and train on pick-place again.
            back_group = "fixpo-tianshou-metaworld-transfer-back"
            cmd(
                "python",
                "src/metaworld_fixpo_tianshou.py",
                "--seed",
                seed,
                "--env",
                "pick-place",
                "--base-task-path", In(f"fixpo_tianshou/env={env}_seed={seed}_group={group}/policy.pth"),
                "--wandb-entity", WANDB_ENTITY,
                # Log under the transfer-back group so these runs are not mixed
                # into the forward-transfer group (this previously passed
                # `group`, although the log dir already used `back_group`).
                "--wandb-group", back_group,
                "--log-dir",
                Out(f"fixpo_tianshou/env=pick-place_base_env={env}_seed={seed}_group={back_group}/"),
                warmup_time=3,
                ram_gb=6,
                priority=(base_priority - 5, -seed, -env_i),
                cores=2,
            )

            # metaworld_fixpo_tianshou(
            #     seed=seed,
            #     env=env,
            #     group="fixpo-tianshou-metaworld",
            #     step_per_collect=50_000,
            #     priority=(61, -seed, -env_i),
            # )
            # metaworld_fixpo_tianshou(
            #     seed=seed,
            #     env=env,
            #     group="fixpo-tianshou-metaworld",
            #     fixup_every_repeat=0,
            #     priority=(60, -seed, -env_i),
            # )
            # metaworld_ppo_tianshou(
            #     seed=seed,
            #     env=env,
            #     group="ppo-tianshou-metaworld",
            #     priority=(61, -seed, -env_i),
            # )
            # NOTE(review): this PPO baseline sits inside the MT10 loop, so it
            # receives the "-v2" suffixed task ids; confirm it was not meant
            # to run over MT50_ENV_NAMES instead.
            metaworld_ppo_tianshou(
                seed=seed,
                env=env,
                group="ppo-tianshou-metaworld",
                max_grad_norm=0.1,
                step_per_collect=10_000,
                priority=(50, -seed, -env_i),
            )

    # MuJoCo sweep: one PPO baseline plus a fixed set of FixPO variants for
    # every (seed, environment) pair.
    for seed in seeds:
        for env_i, env in enumerate(mujoco_env_names_v3):
            mujoco_ppo_tianshou(
                seed=seed,
                env=env,
                group="ppo-tianshou-mujoco",
                priority=(70, -env_i, -seed),
            )

            target_coeff = 3
            # (leading priority component, extra hyperparameters) per variant.
            # The key order inside each dict matters: it determines both the
            # command-line flag order and the log directory name.
            fixpo_variants = [
                (60, {"target_coeff": target_coeff}),
                (60, {"fixup_loop": 0, "target_coeff": target_coeff}),
                (60, {"fixup_every_repeat": 0, "target_coeff": target_coeff}),
                (60, {"kl_target_stat": "mean", "target_coeff": target_coeff}),
                (60, {"kl_target_stat": "max", "target_coeff": 1}),
                (60, {"init_beta": 10, "beta_lr": 0}),
                (
                    30,
                    {
                        "kl_target_stat": "mean",
                        "fixup_loop": 0,
                        "target_coeff": target_coeff,
                    },
                ),
            ]
            for variant_priority, overrides in fixpo_variants:
                mujoco_fixpo_tianshou(
                    seed=seed,
                    env=env,
                    group="fixpo-tianshou-mujoco",
                    priority=(variant_priority, -env_i, -seed),
                    **overrides,
                )

            # for eps_kl_args in [{"eps_kl": 0.2}, {}, {"eps_kl": 1.0}]:
            #     for target_coeff in [2, 3, 5]:
            #         mujoco_fixpo_tianshou(
            #             seed=seed,
            #             env=env,
            #             group="fixpo-tianshou-mujoco",
            #             target_coeff=target_coeff,
            #             **eps_kl_args,
            #             priority=(50, -env_i, -seed, -target_coeff))
            #         mujoco_fixpo_tianshou(
            #             seed=seed,
            #             env=env,
            #             group="fixpo-tianshou-mujoco",
            #             fixup_every_repeat=0,
            #             target_coeff=target_coeff,
            #             **eps_kl_args,
            #             priority=(51, -env_i, -seed))
            #     mujoco_fixpo_tianshou(
            #         seed=seed,
            #         env=env,
            #         group="fixpo-tianshou-mujoco",
            #         fixup_loop=0,
            #         **eps_kl_args,
            #         priority=(50, -env_i, -seed))
            #     mujoco_fixpo_tianshou(
            #         seed=seed,
            #         env=env,
            #         group="fixpo-tianshou-mujoco",
            #         kl_target_stat="mean",
            #         **eps_kl_args,
            #         priority=(50, -env_i, -seed))

# Base trust-region-layers configurations; per-run copies are written to
# data_tmp/ below with the run-specific fields filled in.
with open("trust-region-layers/configs/pg/mujoco_kl_config.json") as f:
    kl_config = json.load(f)

with open("trust-region-layers/configs/pg/mujoco_papi_config.json") as f:
    papi_conf = json.load(f)

for seed in seeds:
    # Iterates the list labelled "Used by trust-region-layers". The loop used
    # to reference `mujoco_envs_remaining`, which is not defined in this file.
    for env in mujoco_envs:
        # The kl run is registered before the papi run for each (seed, env).
        # Both base dicts are updated in place; every field set here is
        # overwritten again on the next iteration.
        for variant, conf in (("kl", kl_config), ("papi", papi_conf)):
            conf_name = f"data_tmp/trust-region-layers_{variant}_seed={seed}_env={env}.json"
            out_dir = f"trust-region-layers_{variant}_seed={seed}_env={env}/"
            conf["n_envs"] = 1  # Should only affect sampling speed
            conf["seed"] = seed
            conf["game"] = env
            conf["out_dir"] = f"data_tmp/{out_dir}"
            # The timestamp makes exp_name differ every time this file is
            # evaluated.
            conf["exp_name"] = f'seed_{seed}_env_{env}_{datetime.now().strftime("%Y%m%d_%H%M%S_%f")}'
            with open(conf_name, "w") as f:
                json.dump(conf, f, indent=2)
            cmd(
                "python",
                "trust-region-layers/main.py",
                conf_name,
                extra_outputs=[Out(out_dir)],
                cores=1,
                ram_gb=6,
                priority=(50, -seed),
            )

# trust-region-layers (KL variant) on every Meta-World MT50 task.
for seed in seeds:
    for env_i, env in enumerate(MT50_ENV_NAMES):
        # Task names are passed to trust-region-layers with a "metaworld-"
        # prefix.
        env = f"metaworld-{env}"
        conf_name = f"data_tmp/trust-region-layers_kl_seed={seed}_env={env}_logged.json"
        out_dir = f"trust-region-layers_kl_seed={seed}_env={env}/"

        # Start from the MuJoCo KL config and override per-run fields. The
        # base dict itself is left untouched.
        mt_kl_conf = {
            **kl_config,
            "n_envs": 10,  # Should only affect sampling speed
            "n_test_envs": 10,  # Should only affect sampling speed
            "seed": seed,
            "game": env,
            "out_dir": f"data_tmp/{out_dir}",
            "exp_name": f'seed_{seed}_env_{env}_{datetime.now().strftime("%Y%m%d_%H%M%S_%f")}',
            # Meta-World Specific
            "rollout_steps": 50_000,
            "max_entropy_coeff": 0.01,
            "max_episode_length": 500,
            "epochs": 10,
            "train_steps": 200,
            "hidden_sizes_policy": [128, 128],
            "hidden_sizes_vf": [128, 128],
        }

        with open(conf_name, "w") as f:
            json.dump(mt_kl_conf, f, indent=2)
        cmd(
            "python",
            "trust-region-layers/main.py",
            conf_name,
            "--wandb-group=trust-region-layers-kl-metaworld-logged",
            extra_outputs=[Out(out_dir)],
            cores=2,
            ram_gb=8,
            priority=(20, -seed, -env_i),
        )
