from rl_algos.single_agent.Revalued.diff_state_agent import Agent as DiffStateAgent
from rl_algos.single_agent.Revalued.action_agent import Agent as ActionAgent

from utils.misc import get_dataset, wandb_init
from utils import env_wrappers

def revalued(config_dict, agent_type='state'):
    """Build a Revalued agent from ``config_dict`` and train it.

    ``config_dict`` is mutated in place: the fixed optimisation and
    exploration hyperparameters below overwrite any same-named keys, and
    ``algo_name`` is set to ``f'revalued_{agent_type}'``.  The resulting dict
    is forwarded as keyword arguments to the agent constructor and passed to
    the training entry point.

    Args:
        config_dict: Run configuration.  Keys read here are ``'env'``,
            ``'use_wandb'`` and ``'offline'``; the ``'action'`` agent
            additionally reads ``'dm_suite'`` and ``'discrete_bins'``.
        agent_type: ``'diff_state'`` selects ``DiffStateAgent`` (built with
            the loaded dataset); ``'action'`` selects ``ActionAgent`` (built
            with ``dataset=None``, ``n_steps`` forced to 3 and
            ``action_bins`` copied from ``discrete_bins``).
            NOTE(review): the default ``'state'`` matches neither agent and
            is rejected; it is kept only so the signature is unchanged --
            confirm whether ``'diff_state'`` was the intended default.

    Raises:
        ValueError: If ``agent_type`` is not ``'diff_state'`` or ``'action'``.
        AssertionError: If ``agent_type == 'action'`` and
            ``config_dict['dm_suite']`` is falsy.
    """
    # Reject unknown agent types before any side effect (dataset loading,
    # config mutation, wandb run creation).  Previously this surfaced only at
    # the very end as an UnboundLocalError on `agent`.
    valid_agent_types = ('diff_state', 'action')
    if agent_type not in valid_agent_types:
        raise ValueError(
            f'Unknown agent_type {agent_type!r}; '
            f'expected one of {valid_agent_types}'
        )

    env = config_dict['env']
    dataset = get_dataset(env, config_dict)

    # Learning-rate / target-update settings; these always override whatever
    # the caller supplied under the same keys.
    config_dict.update({
        'critic_lr': 1e-4,
        'actor_lr': 1e-4,
        'tau': 1e-3,
        'n_steps': 1,
        'sample_type': 'double_q',
    })

    # Epsilon exploration schedule settings.
    config_dict.update({
        'epsilon': 1,
        'epsilon_min': 0.05,
        'exploration_decay': 0.999,
    })

    config_dict['algo_name'] = f'revalued_{agent_type}'

    if config_dict['use_wandb']:
        wandb_init(config_dict)

    obs_dims = env.observation_space.shape[0]

    if agent_type == 'diff_state':
        agent = DiffStateAgent(
            obs_dims=obs_dims,
            action_dims=env.action_space.shape[0],
            dataset=dataset,
            **config_dict,
        )
    else:  # agent_type == 'action' (guaranteed by the validation above)
        # Explicit raise instead of `assert` so the check is not stripped
        # under `python -O`; exception type and message are unchanged.
        if not config_dict['dm_suite']:
            raise AssertionError('Only use with DM suite environments')
        config_dict['n_steps'] = 3
        config_dict['action_bins'] = config_dict['discrete_bins']
        agent = ActionAgent(
            obs_dims=obs_dims,
            action_dims=env.action_space.shape[0],
            dataset=None,
            **config_dict,
        )

    if config_dict['offline']:
        agent.train_offline(config_dict)
    else:
        agent.train_online(config_dict)


