import os, sys, inspect, torch, numpy as np, random, importlib
# Put the current working directory (os.path.abspath('') resolves to it) at the
# front of the import search path -- presumably so project packages such as
# `envs` (imported dynamically in init_env) resolve when run from the repo root.
sys.path.insert(0, os.path.abspath(''))


def set_seed(value):
    """Seed NumPy's global RNG, Python's ``random`` module and PyTorch.

    The same *value* is passed to ``np.random.seed``, ``random.seed`` and
    ``torch.manual_seed``, in that order. Always returns ``None``.
    """
    for seeder in (np.random.seed, random.seed, torch.manual_seed):
        seeder(value)


# Environment names that init_env builds through the envs.GymGame wrapper
# (with a '-v4' suffix) instead of loading a dedicated envs.<name> module.
GYM_ENVS = [
    'InvertedPendulum',
    'Swimmer',
    'HalfCheetah',
]


def init_env(config_env):
    """Instantiate the environment described by *config_env*.

    ``config_env['env_name']`` selects the environment; every other key is
    forwarded as a keyword argument to its constructor.

    * Names listed in ``GYM_ENVS`` are built as
      ``envs.GymGame.GymGame(env_name + '-v4', **rest)``.
    * Any other name ``X`` is built as ``envs.X.X(**rest)``.

    The caller's mapping is not modified.

    Raises:
        KeyError: if ``config_env`` has no ``'env_name'`` key.
    """
    # Work on a copy: popping from the caller's dict would strip 'env_name'
    # and make a second call with the same config fail with KeyError.
    env_kwargs = dict(config_env)
    env_name = env_kwargs.pop('env_name')
    if env_name in GYM_ENVS:
        gym_game_cls = getattr(importlib.import_module('envs.GymGame'), 'GymGame')
        return gym_game_cls(env_name + '-v4', **env_kwargs)
    env_cls = getattr(importlib.import_module(f'envs.{env_name}'), env_name)
    return env_cls(**env_kwargs)


def divide_config(config, cl, adding_args=()):
    """Split *config* into the entries ``cl.__init__`` accepts and the rest.

    A key goes to the first dict when it names a parameter of ``cl.__init__``
    (positional-or-keyword *or* keyword-only), or appears in *adding_args*.
    All other keys go to the second dict. Insertion order of *config* is kept
    in both results.

    Args:
        config: mapping of option name -> value.
        cl: class whose ``__init__`` signature decides the split.
        adding_args: extra names to treat as belonging to ``cl``.

    Returns:
        ``(config_cl, config_noncl)`` as two new dicts.
    """
    spec = inspect.getfullargspec(cl.__init__)
    # kwonlyargs are reported separately from args; without them keyword-only
    # constructor parameters would be misrouted into the non-class config.
    cl_args = set(spec.args) | set(spec.kwonlyargs)
    cl_args.update(adding_args)
    config_cl = {k: v for k, v in config.items() if k in cl_args}
    config_noncl = {k: v for k, v in config.items() if k not in cl_args}
    return config_cl, config_noncl


def save_models(agents, model_names, path, prefix=''):
    """Write the ``state_dict`` of each named attribute of *agents* to disk.

    For every ``name`` in *model_names*, ``getattr(agents, name).state_dict()``
    is stored with ``torch.save`` at ``os.path.join(path, f'{prefix}{name}.pt')``.
    """
    for name in model_names:
        target_file = os.path.join(path, f'{prefix}{name}.pt')
        torch.save(getattr(agents, name).state_dict(), target_file)
        
        
def load_models(agents, model_names, path, prefix='', map_location=None):
    """Load saved state dicts back into the named attributes of *agents*.

    For every ``name`` in *model_names*, the file
    ``os.path.join(path, f'{prefix}{name}.pt')`` (the layout save_models
    writes) is read with ``torch.load`` and passed to
    ``getattr(agents, name).load_state_dict``.

    Args:
        agents: object holding the modules as attributes.
        model_names: attribute names to restore.
        path: directory containing the ``.pt`` files.
        prefix: filename prefix used when the files were saved.
        map_location: forwarded to ``torch.load``; e.g. ``'cpu'`` restores a
            checkpoint saved from GPU tensors on a CPU-only machine. ``None``
            keeps ``torch.load``'s default behavior.
    """
    for model_name in model_names:
        model_path = os.path.join(path, f'{prefix}{model_name}.pt')
        # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
        state_dict = torch.load(model_path, map_location=map_location)
        getattr(agents, model_name).load_state_dict(state_dict)

def _axis_count(action_n, action_dim):
    """Return the largest integer k with ``k ** action_dim <= action_n``."""
    # Start from the rounded float root and correct it with exact powers: the
    # float estimate can land just below an exact root
    # (1000 ** (1 / 3) == 9.999999999999998), which int() alone floors to 9.
    k = int(round(action_n ** (1 / action_dim)))
    while k > 0 and k ** action_dim > action_n:
        k -= 1
    while (k + 1) ** action_dim <= action_n:
        k += 1
    return k


def get_action_values(action_dim, action_min, axtion_max, action_n):
    """Build a uniform grid of discrete actions inside a box.

    Each of the ``action_dim`` axes gets ``k`` evenly spaced values from
    ``action_min[i]`` to ``axtion_max[i]`` (both inclusive). ``k`` is the
    largest integer with ``k ** action_dim <= action_n``. The Cartesian
    product of the axes is returned as an array with one action per row,
    shape ``(k ** action_dim, action_dim)`` when ``k >= 1``.

    Row order follows ``np.meshgrid``'s default ``indexing='xy'``; it is left
    unchanged since callers presumably map action indices to these rows.
    """
    action_axis_n = _axis_count(action_n, action_dim)
    axis_values = [np.linspace(action_min[i], axtion_max[i], action_axis_n)
                   for i in range(action_dim)]
    axis_grids = [axis_grid.reshape(-1) for axis_grid in np.meshgrid(*axis_values)]
    # Pair up the i-th entry of every flattened grid into one action vector.
    return np.array(list(zip(*axis_grids)))
