from rl_algos.single_agent.Combined_agent.agent import Agent

from utils.misc import get_dataset, wandb_init

def combined(config_dict, agent_type='diff_state'):
    """Configure and train an ``Agent`` for the 'combined' algorithm.

    Loads the dataset for ``config_dict['env']`` via ``get_dataset``, writes
    this algorithm's hyper-parameters into ``config_dict``, calls
    ``wandb_init`` when ``use_wandb`` is truthy, builds an ``Agent`` and runs
    ``train_offline`` or ``train_online`` depending on ``config_dict['offline']``.

    Args:
        config_dict: Experiment configuration. Must contain ``'env'`` and
            ``'offline'``; ``'use_wandb'`` is optional (treated as False when
            missing). Mutated in place: the hyper-parameters below and
            ``'algo_name': 'combined'`` are written into it, overwriting any
            existing values for those keys.
        agent_type: Currently unused by this function; kept so existing
            callers that pass it keep working.

    Returns:
        The ``Agent`` instance, after its training method has returned.
    """
    env = config_dict['env']
    # Loaded before the hyper-parameters below are written, so get_dataset
    # sees the caller's original configuration.
    dataset = get_dataset(env, config_dict)

    # Learning-rate / target-update settings.
    # NOTE: `update` overwrites any caller-supplied values for these keys.
    lr_info = {
        'critic_lr': 5e-4,
        'actor_lr': 5e-4,
        'bc_lr': 1e-3,
        'tau': 5e-3,
        'sample_type': 'double_q',
    }
    config_dict.update(lr_info)

    # Exploration parameters consumed by Agent (semantics defined there).
    exploration_info = {
        'epsilon': 1,
        'epsilon_min': 0.05,
        'exploration_decay': 0.999,
    }
    config_dict.update(exploration_info)

    config_dict['algo_name'] = 'combined'

    # Offline/online sample-rate settings. Despite being "defaults", these
    # are forced: they replace any values already present in config_dict.
    # NOTE(review): presumably controls the offline-vs-online sampling mix
    # inside Agent; confirm there.
    config_dict['offline_sample_rate'] = 0.5
    config_dict['max_sample_rate'] = 1
    config_dict['min_sample_rate'] = 0
    config_dict['increment_sample_rate'] = -0.1

    # Logging is optional; a missing flag means "no wandb" instead of KeyError.
    if config_dict.get('use_wandb', False):
        wandb_init(config_dict)

    # NOTE(review): the whole config (including 'env') is forwarded as keyword
    # arguments. A 'dataset', 'obs_dims' or 'action_dims' key in config_dict
    # would clash with the explicit arguments and raise TypeError.
    # Assumes gym-style spaces whose shape is 1-D — TODO confirm.
    agent = Agent(
        obs_dims=env.observation_space.shape[0],
        action_dims=env.action_space.shape[0],
        dataset=dataset,
        **config_dict,
    )

    if config_dict['offline']:
        agent.train_offline(config_dict)
    else:
        agent.train_online(config_dict)

    return agent


