from method.Solvers.cvxpy.cp_bmatching import BmatchingSolver
from method.Solvers.cvxpy.cp_kp import CpKPSolver
from method.Solvers.cvxpy.cp_port import CpPortfolioSolver
from method.Solvers.grb.grb_energy import ICONGrbSolver
from method.Solvers.grb.grb_knapsack import KPGrbSolver
from method.Solvers.grb.grb_qpsolver import QPGrbSolver
from method.Solvers.heuristic.dp import DPSolver
from method.Solvers.heuristic.TopKSolver import TopKSolver
from method.Solvers.neural.BudgetallocSolver import budgetallocSolver
from method.Solvers.neural.softTopkSolver import softTopkSolver


################################# Wrappers ################################################
def solver_wrapper(args, conf, problem, registry=None):
    """Instantiate the solver selected by ``args.problem`` and ``args.solver``.

    The solver class is looked up in a ``{problem_name: {solver_name: cls}}``
    registry. It is constructed with keyword arguments formed by merging
    ``problem.init_API()`` with ``conf["solver"][args.solver]``. On duplicate
    keys, the configuration value wins.

    Args:
        args: Object exposing ``problem`` and ``solver`` attributes, used as
            keys into the registry.
        conf: Mapping with a ``"solver"`` entry that maps solver names to
            keyword-argument mappings for the solver constructor.
        problem: Object whose ``init_API()`` returns a mapping of keyword
            arguments for the solver constructor.
        registry: Optional ``{problem_name: {solver_name: solver_cls}}``
            mapping. When ``None`` (the default), the built-in registry below
            is used. Pass a custom mapping to plug in additional solvers.

    Returns:
        A new instance of the selected solver class.

    Raises:
        KeyError: If ``args.problem`` is not registered, if ``args.solver`` is
            not registered for that problem, or if ``conf["solver"]`` has no
            entry for ``args.solver``. The message lists the available options.
    """
    if registry is None:
        registry = {
            "budgetalloc": {"neural": budgetallocSolver},
            "bipartitematching": {"cvxpy": BmatchingSolver},
            "portfolio": {"cvxpy": CpPortfolioSolver},
            "cubic": {"heuristic": TopKSolver, "neural": softTopkSolver},
            "energy": {"gurobi": ICONGrbSolver},
            "knapsack": {
                "gurobi": KPGrbSolver,
                "heuristic": DPSolver,
                "qptl": QPGrbSolver,
                "cvxpy": CpKPSolver,
            },
        }

    # Validate the selection before calling problem.init_API(), so a typo in
    # the CLI/config fails fast with an actionable message instead of a bare
    # KeyError raised from a nested dict lookup.
    if args.problem not in registry:
        raise KeyError(
            f"Unknown problem {args.problem!r}; "
            f"available problems: {sorted(registry)}"
        )
    solvers_for_problem = registry[args.problem]
    if args.solver not in solvers_for_problem:
        raise KeyError(
            f"Solver {args.solver!r} is not available for problem "
            f"{args.problem!r}; available solvers: {sorted(solvers_for_problem)}"
        )
    solver_cls = solvers_for_problem[args.solver]

    solver_confs = conf["solver"]
    if args.solver not in solver_confs:
        raise KeyError(
            f"No configuration for solver {args.solver!r} in conf['solver']; "
            f"configured solvers: {sorted(solver_confs)}"
        )

    # Later unpacking wins: solver config overrides problem-provided defaults.
    solve_kwargs = {**problem.init_API(), **solver_confs[args.solver]}
    return solver_cls(**solve_kwargs)
