import gymnasium as gym

import src.utils.torch_networks
from src.model_based_agents.CAPPO import CAPPO
from src.model_based_agents.CADQN import CADQN
from src.model_based_agents.CASAC import CASAC
from src.model_based_agents.RPC import RPC
from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize
from stable_baselines3.common.utils import set_random_seed
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.env_util import make_vec_env

from torch import nn
import cProfile
import yaml
from importlib import import_module
import wandb
from wandb.integration.sb3 import WandbCallback
import sys

def load_config(config_path):
    """Load a YAML configuration file.

    Args:
        config_path: Path of the YAML file to read.

    Returns:
        The object produced by ``yaml.safe_load`` for the file's contents.
    """
    # Decode explicitly as UTF-8 so parsing does not depend on the platform's
    # default locale encoding.
    with open(config_path, 'r', encoding='utf-8') as f:
        return yaml.safe_load(f)


def wrap_env(env_name, wraps, n_envs=1, env_config=None):
    """Build ``n_envs`` wrapped environments and return them as thunks.

    Every environment is created with
    ``gym.make(env_name, render_mode='rgb_array', **env_config)`` and then
    passed through each wrapper named in ``wraps``, in the order given.

    Args:
        env_name: Environment id handed to ``gym.make``.
        wraps: Iterable of dotted paths of the form ``"package.module.Class"``.
            Each class is imported and called with the environment as its
            only argument.
        n_envs: Number of independent environments to build.
        env_config: Optional dict of extra keyword arguments for ``gym.make``.
            ``None`` (the default) means no extra arguments.

    Returns:
        A list of ``n_envs`` zero-argument callables; the i-th callable
        returns the i-th (already constructed) wrapped environment.
    """
    # Avoid a shared mutable default argument.
    make_kwargs = {} if env_config is None else env_config

    # Resolve the wrapper classes once; they do not depend on the env index.
    wrapper_classes = []
    for wrapper_class_str in wraps:
        module_name, class_name = wrapper_class_str.rsplit('.', 1)
        wrapper_classes.append(getattr(import_module(module_name), class_name))

    def _constant(value):
        # Bind ``value`` now. A plain ``lambda: env`` inside the loop would be
        # late-binding, making every thunk return the last environment built.
        return lambda: value

    envs = []
    for _ in range(n_envs):
        # Create the base env once, then stack all wrappers on top of it so
        # that none of them is discarded.
        env = gym.make(env_name, render_mode='rgb_array', **make_kwargs)
        for wrapper_class in wrapper_classes:
            env = wrapper_class(env)
        envs.append(_constant(env))
    return envs


def train(game, algorithm, steps, k=None, seed=None):
    """Train one of the CAPPO / CADQN / CASAC agents on ``game`` and save it.

    The hyper-parameters are read from the YAML file that belongs to the
    chosen algorithm (``parameters_cappo.yml``, ``parameters_cadqn.yml`` or
    ``parameters_casac.yml``), under the top-level key ``game``.

    Args:
        game: Key of the game section in the config file. It is also the
            environment id passed to ``make_vec_env``.
        algorithm: One of ``"CAPPO"``, ``"CADQN"`` or ``"CASAC"``.
        steps: Value passed as ``total_timesteps`` to ``model.learn``.
        k: If not ``None``, overrides the ``cm_w`` model parameter.
        seed: Seed for ``set_random_seed`` and ``make_vec_env``. When ``None``
            the global RNGs are seeded with 0, the vectorised env receives
            ``seed=None`` and the output file names contain ``"None"``.

    Raises:
        ValueError: If ``algorithm`` is unknown, or if the configured
            ``activation_fn`` string cannot be evaluated.

    Side effects:
        Saves the model to ``<game>_<algorithm>_<seed>`` and, when the env is
        wrapped in ``VecNormalize``, its statistics to
        ``<game>_<algorithm>_<seed>_env.pkl``.
    """
    # Single source of truth for both the config file and the model class.
    algorithms = {
        "CAPPO": ('parameters_cappo.yml', CAPPO),
        "CADQN": ('parameters_cadqn.yml', CADQN),
        "CASAC": ('parameters_casac.yml', CASAC),
    }

    set_random_seed(0 if seed is None else seed)

    if algorithm not in algorithms:
        raise ValueError("Invalid algorithm name")
    config_file, model_cls = algorithms[algorithm]

    game_config = load_config(config_file)[game]
    params = game_config['parameters']
    env_section = game_config['environment']
    # NOTE(review): the config's environment 'name' and 'render_mode' entries
    # are not used; the env is created from `game` directly. Confirm that the
    # config key and the environment id are always meant to be identical.

    # Make sure policy_kwargs exists before inspecting it (a missing or empty
    # entry becomes an empty dict).
    policy_kwargs = params.get('policy_kwargs') or {}
    params['policy_kwargs'] = policy_kwargs
    activation_fn = policy_kwargs.get('activation_fn')
    if isinstance(activation_fn, str):
        # SECURITY: eval() executes arbitrary code taken from the config
        # file; only use config files from a trusted source.
        try:
            policy_kwargs['activation_fn'] = eval(activation_fn)
        except (NameError, AttributeError, SyntaxError) as err:
            # Fail loudly instead of silently handing a string to the policy.
            raise ValueError(
                f"Cannot resolve activation_fn {activation_fn!r}"
            ) from err

    if game == "roundabout-v0":
        env_config = {'config': env_section['config']}
    else:
        env_config = {}

    num_envs = env_section.get('num_envs', 1)

    # These games get their observations flattened; all others are used as-is.
    flattened_games = {
        "roundabout-v0",
        "highway-fast-v0",
        "DynamicObstaclesSwitch-8x8-v0",
        "SlipperyDistShift-v0",
    }
    wrapper_class = (
        gym.wrappers.FlattenObservation if game in flattened_games else None
    )

    # Create environment
    env = make_vec_env(
        game,
        seed=seed,
        n_envs=num_envs,
        vec_env_cls=DummyVecEnv,
        wrapper_class=wrapper_class,
        env_kwargs=env_config,
    )
    configured_wrappers = game_config.get('wrappers') or []
    if 'stable_baselines3.common.vec_env.VecNormalize' in configured_wrappers:
        env = VecNormalize(env, norm_reward=False)

    if k is not None:
        params['cm_w'] = k

    model = model_cls(
        "MlpPolicy",
        env,
        verbose=1,
        tensorboard_log=f'./logs/model_based/{game}/',
        device="cpu",
        **params,
    )
    model.learn(total_timesteps=steps)

    save_stem = f'{game}_{algorithm}_{seed}'
    model.save(save_stem)
    if isinstance(env, VecNormalize):
        # Persist the normalisation statistics alongside the model.
        env = model.get_env()
        env.save(save_stem + "_env.pkl")


if __name__ == "__main__":
    # Command line: GAME STEPS K SEED ALGORITHM
    # All five positional arguments are required because ALGORITHM is read
    # from the fifth position; any further arguments are ignored.
    valid_algorithms = ("CAPPO", "CADQN", "CASAC")
    if len(sys.argv) < 6:
        sys.exit(f"usage: {sys.argv[0]} GAME STEPS K SEED ALGORITHM")
    game, steps_arg, k_arg, seed_arg, algorithm = sys.argv[1:6]
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if algorithm not in valid_algorithms:
        sys.exit(f"ALGORITHM must be one of {', '.join(valid_algorithms)}")
    train(game, algorithm, int(steps_arg), float(k_arg), int(seed_arg))



