import argparse

from mow.utils.program import Program


class SearchArgs(argparse.Namespace):
    """Typed view of the parsed arguments for the ``search`` program."""

    # Model type to search hyperparameters for (positional ``type``).
    type: str
    # Path to the model config file (positional ``config``).
    config: str
    # Name of the hyperparameter search run (``--name``).
    name: str
    # Number of trials to run for the search (``--num-trials`` / ``-n``).
    num_trials: int


class SearchProgram(
    Program,
    args=SearchArgs,
    name="search",
    help="Search hyperparameters for a model.",
):
    @staticmethod
    def add_arguments(parser: argparse.ArgumentParser):
        parser.add_argument(
            "type",
            choices=["router"],
            help="Type of model to search hyperparameters for",
        )
        parser.add_argument("config", help="Path to the model config file")
        parser.add_argument(
            "--name",
            default="hyperparameter_search",
            help="Name of the hyperparameter search run",
        )
        parser.add_argument(
            "--num-trials",
            "-n",
            type=int,
            default=100,
            help="Number of trials to run for hyperparameter search",
        )

    @staticmethod
    def main(args: SearchArgs):
        match args.type:
            case "router":
                from mow.scripts.train_router import (
                    TrainRouterConfig,
                    hyperparameter_search,
                )

                config = TrainRouterConfig.from_file(args.config)
                hyperparameter_search(
                    config, n_trials=args.num_trials, study_name=args.name
                )
            case "mow":
                from mow.scripts.train_mow import (
                    TrainMoWConfig,
                    hyperparameter_search,
                )

                config = TrainMoWConfig.from_file(args.config)
                hyperparameter_search(
                    config, n_trials=args.num_trials, study_name=args.name
                )
            case _:
                raise ValueError(f"Unknown model type: {args.type}")
