from algorithms.sac import SAC
from algorithms.sac_byol import SAC_BYOL
from algorithms.sac_byol_sharedproj import SAC_BYOL_SharedProj
from algorithms.sac_byol_noinv import SAC_BYOL_NOINV
from algorithms.sac_byol_noinv_aug import SAC_BYOL_NOINV_AUG
from algorithms.sac_simsiam import SAC_SIMSIAM

from algorithms.rad import RAD
from algorithms.rad_byol import RAD_BYOL
from algorithms.rad_byol_sharedproj import RAD_BYOL_SharedProj
from algorithms.rad_byol_sharedproj_aug import RAD_BYOL_SharedProj_AUG
from algorithms.rad_byol_sharedproj_cmc import RAD_BYOL_SharedProj_CMC
from algorithms.rad_byol_sharedproj_noinv import RAD_BYOL_SharedProj_NOINV
from algorithms.rad_byol_aug import RAD_BYOL_AUG
from algorithms.rad_byol_noinv import RAD_BYOL_NOINV
from algorithms.rad_byol_noinv_aug import RAD_BYOL_NOINV_AUG
from algorithms.rad_simsiam import RAD_SIMSIAM

# Registry mapping the ``args.agent`` name to the agent class that
# ``make_agent`` instantiates. Each key must map to the class of the same name;
# a mismatch silently trains a different algorithm than the one requested.
algorithm = {
    'SAC': SAC,
    'SAC_BYOL': SAC_BYOL,  # was mistakenly mapped to RAD
    'SAC_BYOL_SharedProj': SAC_BYOL_SharedProj,
    'RAD_BYOL_SharedProj_AUG': RAD_BYOL_SharedProj_AUG,
    'RAD_BYOL_SharedProj_CMC': RAD_BYOL_SharedProj_CMC,
    'RAD_BYOL_SharedProj_NOINV': RAD_BYOL_SharedProj_NOINV,
    'SAC_BYOL_NOINV': SAC_BYOL_NOINV,
    'SAC_BYOL_NOINV_AUG': SAC_BYOL_NOINV_AUG,
    'SAC_SIMSIAM': SAC_SIMSIAM,
    'RAD': RAD,
    'RAD_BYOL': RAD_BYOL,
    'RAD_BYOL_SharedProj': RAD_BYOL_SharedProj,
    'RAD_BYOL_AUG': RAD_BYOL_AUG,
    'RAD_BYOL_NOINV': RAD_BYOL_NOINV,
    'RAD_BYOL_NOINV_AUG': RAD_BYOL_NOINV_AUG,
    'RAD_SIMSIAM': RAD_SIMSIAM,
}

def make_agent(obs_shape, action_shape, args, device):
    """Instantiate the agent class registered under ``args.agent``.

    Looks ``args.agent`` up in the module-level ``algorithm`` registry and
    constructs that class. The observation/action shapes and the device are
    forwarded, along with the hyperparameters read from ``args``.

    Args:
        obs_shape: Passed through unchanged as the ``obs_shape`` keyword.
        action_shape: Passed through unchanged as the ``action_shape`` keyword.
        args: Namespace-like object exposing ``agent`` plus every
            hyperparameter attribute forwarded below (``hidden_dim``,
            ``discount``, ..., ``data_augs``).
        device: Passed through unchanged as the ``device`` keyword.

    Returns:
        The newly constructed agent instance.

    Raises:
        ValueError: If ``args.agent`` is not a key of ``algorithm``.
    """
    try:
        agent_cls = algorithm[args.agent]
    except KeyError:
        # Raise explicitly rather than ``assert``: the old
        # ``assert '<message>'`` tested a non-empty (always truthy) string,
        # so it never fired and the function silently returned None.
        raise ValueError(
            'agent is not supported: %s (supported: %s)'
            % (args.agent, ', '.join(sorted(algorithm)))
        ) from None

    return agent_cls(
        obs_shape=obs_shape,
        action_shape=action_shape,
        device=device,
        hidden_dim=args.hidden_dim,
        discount=args.discount,
        init_temperature=args.init_temperature,
        alpha_lr=args.alpha_lr,
        alpha_beta=args.alpha_beta,
        actor_lr=args.actor_lr,
        actor_beta=args.actor_beta,
        actor_log_std_min=args.actor_log_std_min,
        actor_log_std_max=args.actor_log_std_max,
        actor_update_freq=args.actor_update_freq,
        critic_lr=args.critic_lr,
        critic_beta=args.critic_beta,
        critic_tau=args.critic_tau,
        critic_target_update_freq=args.critic_target_update_freq,
        encoder_type=args.encoder_type,
        encoder_feature_dim=args.encoder_feature_dim,
        encoder_lr=args.encoder_lr,
        encoder_tau=args.encoder_tau,
        num_layers=args.num_layers,
        num_filters=args.num_filters,
        log_interval=args.log_interval,
        detach_encoder=args.detach_encoder,
        latent_dim=args.latent_dim,
        data_augs=args.data_augs,
    )