import os
import gym
import hydra
from iq_learn.agent.sac import SAC
from iq_learn.agent.softq import SoftQ
from iq_learn.agent.softq_models import *


def make_agent(ob_space, ac_space, args, load_agent_path=None):
    """Construct the agent named in the config and optionally restore weights.

    Args:
        ob_space: Observation space; only ``shape[0]`` is read.
        ac_space: Action space. The Soft-Q branch reads ``n``; the SAC branch
            reads ``shape[0]``, ``low`` and ``high``.
        args: Run configuration. ``args.agent.name`` selects the agent
            (``'softq'`` picks Soft-Q, any other value picks SAC),
            ``args.train.batch`` is forwarded to the constructor, and
            ``args.agent.obs_dim`` / ``args.agent.action_dim`` are overwritten
            in place before the agent is built.
        load_agent_path: Optional checkpoint path. It is resolved with
            ``hydra.utils.to_absolute_path``; if no file exists there, a
            notice is printed and the freshly built agent is returned as is.

    Returns:
        The constructed ``SoftQ`` or ``SAC`` instance.
    """
    n_obs = ob_space.shape[0]
    use_softq = args.agent.name == 'softq'

    # Work out the action dimensionality (and, for SAC, the action bounds).
    if use_softq:
        print('--> Using Soft-Q agent')
        n_act = ac_space.n
        bounds = None  # not used by Soft-Q
    else:
        print('--> Using SAC agent')
        n_act = ac_space.shape[0]
        # Single scalar range spanning every action dimension.
        bounds = [float(ac_space.low.min()), float(ac_space.high.max())]

    # Record the dimensions on the config before handing it to the agent.
    args.agent.obs_dim = n_obs
    args.agent.action_dim = n_act

    if use_softq:
        agent = SoftQ(n_obs, n_act, args.train.batch, args)
    else:
        agent = SAC(n_obs, n_act, bounds, args.train.batch, args)

    if not load_agent_path:
        return agent

    checkpoint = hydra.utils.to_absolute_path(load_agent_path)
    if not os.path.isfile(checkpoint):
        print("[Attention]: Did not find checkpoint {}".format(load_agent_path))
        return agent

    print("=> loading pretrain '{}'".format(load_agent_path))
    agent.load(checkpoint)
    return agent
