import gc
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

from marl_exp.env import env_creator
from stable_baselines3.common.vec_env import SubprocVecEnv
from marl_exp.policy import PolicyLinProgMARL
from config import SUBSTRATE_NAME, N_ENVS, RENDER, COIN_CONFIG, REWARD_TYPE, EXTRA_SW_PROP


def run_marl(substrate_name, coin_config, reward_type, save_path='',
             SW_prop=0, extra_SW_prop=0, render=False, train_principal=True, val_agent=True):
    """Build vectorized environments, train a PolicyLinProgMARL, and return its result.

    Creates a ``SubprocVecEnv`` with ``N_ENVS`` workers, each built via
    ``env_creator(substrate_name, coin_config)``. It then constructs a
    ``PolicyLinProgMARL`` with fixed hyperparameters and calls ``train()`` on it.
    The worker processes are always shut down before returning, even if policy
    construction or training raises.

    Args:
        substrate_name: Forwarded to ``env_creator``.
        coin_config: Forwarded to ``env_creator``. Must provide ``"grid_size"``
            and ``"max_steps"`` entries, which are combined into a run tag
            of the form ``"<grid_size>_<max_steps>"``.
        reward_type: Forwarded to ``PolicyLinProgMARL``.
        save_path: Forwarded to ``PolicyLinProgMARL``.
        SW_prop: Forwarded to ``PolicyLinProgMARL``.
        extra_SW_prop: Forwarded to ``PolicyLinProgMARL``.
        render: Currently unused. No render environment is created, and
            ``None`` is passed as the policy's render env.
        train_principal: Forwarded to ``PolicyLinProgMARL``.
        val_agent: Forwarded to ``PolicyLinProgMARL``.

    Returns:
        Whatever ``policy.train()`` returns.
    """
    env = SubprocVecEnv([lambda: env_creator(substrate_name, coin_config) for _ in range(N_ENVS)])
    try:
        # NOTE(review): ``render`` is not wired up; a render env is never built.
        render_env = None

        lr_start = 5e-4
        lr_end = 1e-4
        n_batches = 62_500
        buffer_size = 100_000
        log_freq = 2500
        n_eval_episodes = 5 * N_ENVS
        normalize_obs = False
        tags = [f'{coin_config["grid_size"]}_{coin_config["max_steps"]}']

        policy = PolicyLinProgMARL(env,
                                   render_env,
                                   reward_type=reward_type,
                                   SW_prop=SW_prop,
                                   extra_SW_prop=extra_SW_prop,
                                   hid_size=64,
                                   n_hid_layers=1,
                                   gamma=.99,
                                   normalize_obs=normalize_obs,
                                   lr_start=lr_start, lr_end=lr_end,
                                   eps_start=.4, eps_end=0.,
                                   batch_size=128,
                                   n_batches=n_batches,
                                   n_interactions=1,
                                   buffer_size=buffer_size,
                                   n_warm_start_batches=25,
                                   target_update_freq=100,
                                   log_wandb=True,
                                   log_freq=log_freq,
                                   n_eval_episodes=n_eval_episodes,
                                   scheduling_speed=1.5,
                                   train_principal=train_principal,
                                   val_agent=val_agent,
                                   save_path=save_path,
                                   tags=tags,
                                   )
        prop = policy.train()
        # Drop the policy before collecting, presumably so repeated calls in
        # one process do not accumulate memory.
        del policy
    finally:
        # Shut down the SubprocVecEnv worker processes explicitly instead of
        # relying on garbage collection or interpreter exit. SB3's
        # SubprocVecEnv.close() returns early if the env is already closed.
        env.close()
    gc.collect()
    return prop


if __name__ == "__main__":
    # Print the configured substrate name before starting the run.
    print(SUBSTRATE_NAME)
    # Launch a single training run using the settings from ``config``;
    # the value returned by run_marl is not used here.
    run_marl(
        SUBSTRATE_NAME,
        COIN_CONFIG,
        REWARD_TYPE,
        extra_SW_prop=EXTRA_SW_PROP,
        render=RENDER,
    )
