from .base import Benchmark
from .lotka_volterra import LotkaVolterra
from .lotka_volterra_no_rescale import LotkaVolterraNoRescale
from .slcp import SLCP
from .spatialsir import SpatialSIR
from .two_moons import TwoMoons
from .galaxies import Galaxies

# from .gw import GW


def load_benchmark(config: dict) -> Benchmark:
    """Instantiate the benchmark selected by ``config``.

    Args:
        config: Mapping containing ``"benchmark"`` (the benchmark identifier)
            and ``"data_path"`` (passed unchanged as the sole argument to the
            selected benchmark's constructor).

    Returns:
        A new instance of the benchmark class registered under
        ``config["benchmark"]``.

    Raises:
        KeyError: If ``config`` has no ``"benchmark"`` key, or has no
            ``"data_path"`` key when the benchmark name is supported.
        NotImplementedError: If ``config["benchmark"]`` is not one of the
            supported identifiers.
    """
    # Single source of truth for identifier -> benchmark class.
    registry = {
        "slcp": SLCP,
        "spatialsir": SpatialSIR,
        "lotka_volterra": LotkaVolterra,
        "lotka_volterra_no_rescale": LotkaVolterraNoRescale,
        "two_moons": TwoMoons,
        "galaxies": Galaxies,
    }

    name = config["benchmark"]
    try:
        benchmark_cls = registry[name]
    except (KeyError, TypeError):
        # TypeError covers unhashable names; the original ``==`` chain
        # reported those as unsupported too, so keep raising
        # NotImplementedError for them.
        supported = ", ".join(sorted(registry))
        raise NotImplementedError(
            f"Benchmark not implemented: {name!r} (supported: {supported})"
        ) from None

    # data_path is only read once the name is known to be valid, matching
    # the original lookup order.
    return benchmark_cls(config["data_path"])
