from src.solvers.cg_boost import CGBoost
from src.solvers.lp_boost import LPBoost
from src.solvers.md_boost import MDBoost
from src.solvers.qrlp_boost import QRLPBoost
from src.solvers.neg_margins import NMBoost
from src.solvers.erlp_boost import ERLPBoost
from src.solvers.solver import Solver


def get_solver(solver_type: str) -> Solver:
    """
    Return a new Solver instance corresponding to the input type.

    Supported types are "cg_boost", "erlp_boost", "lp_boost", "md_boost",
    "qrlp_boost" and "neg_margins".

    Args:
        solver_type (str): Key identifying the solver to construct.

    Returns:
        Solver: A freshly constructed instance (created with no
        constructor arguments) of the solver class registered under
        ``solver_type``.

    Raises:
        ValueError: If solver_type is not recognized.

    """
    # Maps the public string key to the solver class; a new instance is
    # created on every call, so callers never share solver state.
    solver_map = {
        "cg_boost": CGBoost,
        "erlp_boost": ERLPBoost,
        "lp_boost": LPBoost,
        "md_boost": MDBoost,
        "qrlp_boost": QRLPBoost,
        "neg_margins": NMBoost,
    }

    # Single lookup instead of a membership test followed by indexing.
    solver_cls = solver_map.get(solver_type)
    if solver_cls is None:
        msg = f"Solver {solver_type} not recognized"
        raise ValueError(msg)

    return solver_cls()
