from rl_algos.single_agent.BC.state_agent import Agent as StateAgent
from rl_algos.single_agent.BC.action_agent import Agent as ActionAgent
from rl_algos.single_agent.BC.diff_state_agent import Agent as DiffStateAgent
from rl_algos.single_agent.BC.cont_diff_state_agent import Agent as ContDiffStateAgent

from utils.misc import get_dataset, wandb_init

def bc_offline(config_dict, agent_type=''):
    """Build a BC agent of the requested variant and train it.

    The caller's ``config_dict`` is updated in place with fixed BC settings
    (``bc_lr``, ``ensemble_num``, ``critic_factor``, ``agent_type``,
    ``algo_name`` and, for the ``'state'`` / ``'cont_diff_state'`` variants,
    ``num_env_steps``). These overwrite any values the caller supplied. The
    whole dict is then forwarded as keyword arguments to the agent
    constructor and passed to the agent's training method.

    Args:
        config_dict: Run configuration. Must contain ``'env'`` (an object
            exposing ``observation_space.shape`` and ``action_space.shape``),
            ``'use_wandb'`` and ``'offline'``.
        agent_type: One of ``'state'``, ``'diff_state'``,
            ``'cont_diff_state'`` or ``'action'``.

    Returns:
        The trained agent instance.

    Raises:
        ValueError: If ``agent_type`` is not a supported variant. This is
            checked before the dataset is loaded, wandb is initialised, or
            ``config_dict`` is modified.
    """
    supported_types = ('state', 'diff_state', 'cont_diff_state', 'action')
    # Fail fast: previously an unknown type only surfaced as an
    # UnboundLocalError on `agent`, after dataset loading and wandb init.
    if agent_type not in supported_types:
        raise ValueError(
            f"Unknown BC agent_type {agent_type!r}; expected one of "
            f"{', '.join(repr(t) for t in supported_types)}"
        )

    env = config_dict['env']
    dataset = get_dataset(env, config_dict)

    # Fixed BC hyper-parameters and run identifiers; these overwrite any
    # caller-supplied values.
    config_dict.update({
        'bc_lr': 1e-4,
        'ensemble_num': 1,
        'critic_factor': 1,
        'agent_type': agent_type,
        'algo_name': f'bc_{agent_type}',
    })

    if config_dict['use_wandb']:
        wandb_init(config_dict)

    # agent_type has been validated above, so this lookup cannot fail.
    agent_cls = {
        'state': StateAgent,
        'diff_state': DiffStateAgent,
        'cont_diff_state': ContDiffStateAgent,
        'action': ActionAgent,
    }[agent_type]

    # These variants train for a fixed 100k env steps.
    # NOTE(review): this override happens after wandb_init, so depending on
    # how wandb_init records the config, the logged num_env_steps may differ
    # from the value actually used. Confirm whether that is intended.
    if agent_type in ('state', 'cont_diff_state'):
        config_dict['num_env_steps'] = 100000

    agent = agent_cls(obs_dims=env.observation_space.shape[0],
                      action_dims=env.action_space.shape[0],
                      dataset=dataset,
                      **config_dict)

    if config_dict['offline']:
        agent.train_offline(config_dict)
    else:
        agent.train_online(config_dict)

    return agent



