from .a_star_search import AStarSearch
from .differentiable_tree_search import DifferentiableTreeSearch
from .model_free_q_network import ModelFreeQNetwork
from .treeqn import TreeQN


def get_solver(args):
    """Build the solver selected by ``args.solver``.

    ``args.solver`` is matched against the class names of the supported
    solvers. The chosen class is constructed positionally from these
    attributes of ``args``:

    * ``ModelFreeQNetwork``: ``d_hidden``
    * ``TreeQN``: ``d_hidden``, ``depth``
    * ``AStarSearch``: ``d_hidden``, ``n_trials``, ``exploration_constant``
    * ``DifferentiableTreeSearch``: ``d_hidden``, ``n_trials``

    Args:
        args: Object exposing ``solver`` plus the hyper-parameters listed
            above (presumably an ``argparse.Namespace`` — confirm at caller).

    Returns:
        A new instance of the selected solver class.

    Raises:
        SystemExit: If ``args.solver`` is not one of the supported names.
            The message is written to stderr and the exit status is 1.
    """
    # Each entry is a zero-argument factory, so only the selected solver's
    # hyper-parameters are read from ``args`` (e.g. ``depth`` need not exist
    # unless TreeQN is chosen).
    factories = {
        ModelFreeQNetwork.__name__: lambda: ModelFreeQNetwork(
            args.d_hidden,
        ),
        TreeQN.__name__: lambda: TreeQN(
            args.d_hidden,
            args.depth,
        ),
        AStarSearch.__name__: lambda: AStarSearch(
            args.d_hidden,
            args.n_trials,
            args.exploration_constant,
        ),
        DifferentiableTreeSearch.__name__: lambda: DifferentiableTreeSearch(
            args.d_hidden,
            args.n_trials,
        ),
    }

    factory = factories.get(args.solver)
    if factory is None:
        # The previous ``exit(0)`` reported success to the shell on a failed
        # run, and ``exit`` is a site-module helper intended only for the
        # interactive interpreter. Raising SystemExit with a message prints
        # it to stderr and exits with status 1.
        raise SystemExit(f"Invalid solver passed! Solver: {args.solver}")
    return factory()
