from rl_algos.single_agent.TD3.agent import Agent

from utils.misc import get_dataset, wandb_init


def td3_n(config_dict):
    """Configure and launch a TD3-N agent on ``config_dict['env']``.

    ``config_dict`` is mutated in place before the agent is built:

    * ``critic_lr``, ``actor_lr``, ``tau`` and ``update_ratio`` are set to
      fixed values, overwriting anything already present;
    * ``algo_name`` is set to ``'td3_n'`` and ``critic_factor`` to ``2``;
    * ``critic_ensemble_num`` and ``actor_ensemble_num`` are both set to
      ``config_dict['ensemble_num']``.

    Only after that is the optional wandb run initialised. The agent is
    then constructed with the observation/action sizes taken from the
    environment spaces. Finally, ``train_offline`` or ``train_online`` is
    called depending on ``config_dict['offline']``.

    Args:
        config_dict: Experiment configuration. It must contain at least
            ``'env'``, ``'ensemble_num'``, ``'use_wandb'`` and ``'offline'``.
            All of its entries are forwarded to ``Agent`` as keyword
            arguments.

    Raises:
        KeyError: If one of the required keys is missing.
    """
    env = config_dict['env']
    # NOTE(review): ``get_dataset`` is imported at module level but never
    # called here, so the agent is always built with ``dataset=None`` --
    # including for offline training. Confirm the Agent obtains its own data
    # in that case.
    dataset = None

    # Fixed optimisation hyperparameters for TD3-N; these intentionally
    # overwrite any values already present in config_dict.
    config_dict.update({
        'critic_lr': 5e-4,
        'actor_lr': 5e-4,
        'tau': 5e-3,
        'update_ratio': 1,
    })
    config_dict['algo_name'] = 'td3_n'
    config_dict['critic_factor'] = 2

    # Critics and actors share a single ensemble size. These keys are set
    # before wandb_init for two reasons: the run config is complete if
    # wandb_init records it, and a missing 'ensemble_num' fails before any
    # tracking run is started.
    ensemble_num = config_dict['ensemble_num']
    config_dict['critic_ensemble_num'] = ensemble_num
    config_dict['actor_ensemble_num'] = ensemble_num

    if config_dict['use_wandb']:
        wandb_init(config_dict)

    # Network input/output sizes come from the first dimension of the
    # environment's observation and action space shapes.
    obs_dims = env.observation_space.shape[0]
    action_dims = env.action_space.shape[0]

    agent = Agent(obs_dims=obs_dims,
                  action_dims=action_dims,
                  dataset=dataset,
                  **config_dict
                  )

    if config_dict['offline']:
        agent.train_offline(config_dict)
    else:
        agent.train_online(config_dict)


