import os
from k_level_policy_gradients.src.utils.config_loader import load_config
from k_level_policy_gradients.src.utils.slurm_launcher import Launcher
from k_level_policy_gradients.src.utils.slurm_launcher import is_local

# Experiment parameters. get_params() copies this dict, adds the merged
# env/env_run/agent config sections, and main() forwards the result to
# Launcher.add_experiment as keyword arguments.
exp_params = {
    "env__": "sc2",  # choose between "sc2" and "mamujoco"
    "agent__": "kfacmac",  # e.g. facmac, maddpg, kfacmac, kmaddpg
    "log_wandb": False,  # only enable if you have configured wandb
    "log_best_agents": False,  # store the best agents up to the current iteration
    "log_online_data": False,  # store metrics from the training phase
    "note": "",  # leave a free-text note describing this experiment
}

# SC2 parameters: overrides applied on top of the "sc2" entries of
# configs/env.yaml and configs/env_run.yaml by get_params().
sc2_params = {"map_name": "MMM"}
sc2_run_params = {}

# Mamujoco parameters: overrides applied on top of the "mamujoco" entries of
# configs/env.yaml and configs/env_run.yaml by get_params().
mamujoco_params = {
    "scenario": "HalfCheetah",
    "partitioning": "2x3",
}
mamujoco_run_params = {}

# Agents: get_params() looks up "<agent__>_params" and applies it on top of
# configs/agent.yaml; agents without such a dict use the YAML values as-is.
kfacmac_params = {"k_level": 2}

# Launcher parameters
# NOTE: exp_name is not derived from the settings above; keep it in sync with
# "agent__", the k_level and the map/scenario whenever you change them.
exp_name = "k2facmac_MMM"
exp_file = "run_experiment"
n_seeds = 1
n_experiments_in_parallel = n_seeds  # run every seed at the same time
n_cores_per_experiment = 1
n_cores = n_experiments_in_parallel * n_cores_per_experiment
memory_per_core = 4000  # presumably MB per core — TODO confirm against Launcher
gres = None  # generic resource request; None presumably requests none — TODO confirm
partition = ""  # empty presumably means the cluster default — TODO confirm

# Launcher parameters (don't change unless you know what you are doing)
hours = 36  # job time limit, in hours per the name
base_dir = os.path.dirname(os.path.realpath(__file__)) + "/results/running"
conda_env = "kpg_torch"
use_timestamp = True


def main():
    """Create the launcher for this experiment and start it.

    Launcher settings come from the module-level constants; the experiment's
    keyword arguments come from ``get_params()``. Whether the run happens
    locally is decided by ``is_local()``.
    """
    # Resolve experiment arguments first so config-loading errors surface
    # before the launcher is constructed.
    experiment_kwargs = get_params()

    launcher_settings = dict(
        exp_name=exp_name,
        exp_file=exp_file,
        n_seeds=n_seeds,
        n_exps_in_parallel=n_experiments_in_parallel,
        n_cores=n_cores,
        memory_per_core=memory_per_core,
        hours=hours,
        base_dir=base_dir,
        conda_env=conda_env,
        gres=gres,
        partition=partition,
        use_timestamp=use_timestamp,
    )
    launcher = Launcher(**launcher_settings)
    launcher.add_experiment(**experiment_kwargs)

    run_locally = is_local()
    launcher.run(local=run_locally)


def _load_section(dir_path, config_name, key, overrides):
    """Load entry ``key`` of ``configs/<config_name>.yaml`` and apply ``overrides``."""
    section = load_config(
        f"{dir_path}/k_level_policy_gradients/src/configs/{config_name}.yaml"
    )[key]
    section.update(overrides)
    return section


def get_params():
    """Assemble the keyword arguments for ``Launcher.add_experiment``.

    Starts from a shallow copy of the module-level ``exp_params`` and adds three
    merged config sections:

    * ``env_params``     -- ``env.yaml[env]`` updated with ``<env>_params``
    * ``env_run_params`` -- ``env_run.yaml[env]`` updated with ``<env>_run_params``
    * ``agent_params``   -- ``agent.yaml[agent]`` updated with ``<agent>_params``

    where ``env``/``agent`` are ``exp_params["env__"]``/``exp_params["agent__"]``
    and the override dicts are module-level globals (missing ones mean "no
    overrides").

    Returns:
        dict: a new dict; the module-level ``exp_params`` is left unmodified,
        so repeated calls do not leak derived keys into the global config.

    Raises:
        KeyError: if ``exp_params`` lacks ``"env__"`` or ``"agent__"``.
        Indexing the loaded config with a missing env/agent name presumably
        raises as well (depends on what ``load_config`` returns — TODO confirm).
    """
    dir_path = os.path.dirname(os.path.realpath(__file__))
    # Work on a copy: the original mutated the module-level dict in place.
    params = dict(exp_params)
    module_vars = globals()

    env = params["env__"]
    # Overrides are found by naming convention, e.g. "sc2_params"/"sc2_run_params".
    params["env_params"] = _load_section(
        dir_path, "env", env, module_vars.get(f"{env}_params", {})
    )
    params["env_run_params"] = _load_section(
        dir_path, "env_run", env, module_vars.get(f"{env}_run_params", {})
    )

    agent = params["agent__"]
    params["agent_params"] = _load_section(
        dir_path, "agent", agent, module_vars.get(f"{agent}_params", {})
    )
    return params


if __name__ == "__main__":
    # Launch only when executed as a script, not when this module is imported.
    main()
