from rl_algos.single_agent.SAC_N.agent import Agent
from utils.misc import get_dataset, wandb_init

def sac_n(config_dict):
    """Configure and launch a SAC-N agent from ``config_dict``.

    ``config_dict`` is mutated in place. The function reads the keys
    ``'env'``, ``'use_wandb'``, ``'ensemble_num'`` and ``'offline'`` and
    raises ``KeyError`` if one of them is missing when it is needed.

    It writes (overwriting any value already present):

    * ``'is_continuous'``, ``'min_val'``, ``'max_val'`` -- the latter two
      taken from ``env.action_space.low`` / ``.high``;
    * the fixed hyperparameters ``'critic_lr'``, ``'actor_lr'``, ``'tau'``,
      ``'learnable_temperature'``, ``'alpha_lr'``, ``'n_steps'``;
    * ``'algo_name'`` and ``'critic_factor'``;
    * ``'critic_ensemble_num'`` and ``'actor_ensemble_num'``, both copied
      from ``'ensemble_num'``.

    The whole dictionary is then forwarded to ``Agent`` as keyword
    arguments, and the agent is trained offline or online depending on
    ``config_dict['offline']``.

    NOTE(review): ``env`` presumably follows the Gym-style interface
    (``action_space`` / ``observation_space`` exposing ``shape``) -- confirm
    against the caller.
    NOTE(review): the agent always receives ``dataset=None``, even on the
    offline path -- confirm that ``Agent`` obtains its data elsewhere.

    Returns:
        None.
    """
    env = config_dict['env']

    # This entry point always flags the action space as continuous and
    # forwards the action-space bounds to the agent.
    config_dict['is_continuous'] = True
    action_space = env.action_space
    config_dict['min_val'] = action_space.low
    config_dict['max_val'] = action_space.high

    # Fixed settings for this algorithm; they replace any values the caller
    # supplied under the same keys.
    config_dict.update({
        'critic_lr': 7e-5,
        'actor_lr': 7e-5,
        'tau': 1e-3,  # for target network
        'learnable_temperature': True,
        'alpha_lr': 1e-3,
        'n_steps': 1,
        'algo_name': 'sac_n',
        'critic_factor': 2,
    })

    # wandb is initialised before the ensemble keys below are written.
    if config_dict['use_wandb']:
        wandb_init(config_dict)

    # Actor and critic share the same ensemble size.
    n_members = config_dict['ensemble_num']
    for ensemble_key in ('critic_ensemble_num', 'actor_ensemble_num'):
        config_dict[ensemble_key] = n_members

    agent = Agent(
        obs_dims=env.observation_space.shape[0],
        action_dims=env.action_space.shape[0],
        dataset=None,
        **config_dict,
    )

    # Pick the training routine according to the 'offline' flag.
    run_training = (
        agent.train_offline if config_dict['offline'] else agent.train_online
    )
    run_training(config_dict)


