from rl_algos.single_agent.DecQN.action_agent import Agent as ActionAgent

from utils.misc import get_dataset, wandb_init
from utils import env_wrappers

def decqn(config_dict, agent_type='state'):
    """Configure, build and train a DecQN ``ActionAgent``.

    DecQN-specific hyperparameters are written into ``config_dict`` in place.
    They overwrite any values the caller already stored under the same keys,
    so afterwards the caller's dict reflects the exact configuration used.

    Args:
        config_dict: Mutable run configuration.
            - It must contain ``'env'``, an object exposing
              ``observation_space.shape`` and ``action_space.shape``.
            - ``'use_wandb'`` and ``'offline'`` are optional flags that
              default to ``False`` when absent.
            - The whole dict is forwarded as keyword arguments to
              ``ActionAgent``. It therefore must not contain ``'obs_dims'``,
              ``'action_dims'`` or ``'dataset'``, which are passed explicitly
              and would collide.
        agent_type: Label stored under ``'agent_type'``. It is also used to
            build ``'algo_name'`` as ``f'decqn_{agent_type}'``.

    Returns:
        The ``ActionAgent`` instance, after ``train_offline`` or
        ``train_online`` has been called on it.
    """
    env = config_dict['env']

    # Optimiser / target-network settings. These unconditionally override any
    # caller-supplied values under the same keys.
    lr_info = {
        'critic_lr': 5e-5,
        'actor_lr': 1e-4,
        'tau': 1e-3,    # presumably the soft target-update rate — confirm in ActionAgent
        'n_steps': 3,   # presumably the n-step return horizon — confirm in ActionAgent
    }
    config_dict['critic_factor'] = 1
    config_dict.update(lr_info)

    # Exploration schedule. Assumed: epsilon-greedy, with epsilon multiplied by
    # exploration_decay down to epsilon_min. TODO: confirm against ActionAgent.
    exploration_info = {
        'epsilon': 1,
        'epsilon_min': 0.05,
        'exploration_decay': 0.99995,
    }
    config_dict.update(exploration_info)

    config_dict['agent_type'] = agent_type
    config_dict['algo_name'] = f'decqn_{agent_type}'

    # Called only after every key above is set, so wandb_init receives the
    # fully populated config. A missing flag means "no W&B logging".
    if config_dict.get('use_wandb', False):
        wandb_init(config_dict)

    agent = ActionAgent(
        obs_dims=env.observation_space.shape[0],
        action_dims=env.action_space.shape[0],
        dataset=None,
        **config_dict,
    )

    # A missing 'offline' flag falls back to online training instead of
    # raising KeyError.
    if config_dict.get('offline', False):
        agent.train_offline(config_dict)
    else:
        agent.train_online(config_dict)

    return agent


                
