import argparse

from mow.utils.program import Program


class TrainArgs(argparse.Namespace):
    """Typed view of the namespace parsed for the ``train`` sub-command."""

    # Model family to train; the parser limits it to "expert", "mow" or "router".
    type: str
    # Path to the model config file, passed to the chosen trainer config's ``from_file``.
    config: str


# ``train`` sub-command: trains an expert, MoW, or router model from a config file.
# Registration metadata (parsed-args type, command name, help text) is supplied
# to the ``Program`` base through class keyword arguments.
class TrainProgram(
    Program,
    args=TrainArgs,
    name="train",
    help="Train a model.",
):
    @staticmethod
    def add_arguments(parser: argparse.ArgumentParser):
        """Register the positional ``type`` and ``config`` arguments on *parser*.

        Args:
            parser: The sub-command parser to populate.
        """
        supported_kinds = ["expert", "mow", "router"]
        parser.add_argument(
            "type",
            choices=supported_kinds,
            help="Type of model to train",
        )
        parser.add_argument(
            "config",
            help="Path to the model config file",
        )

    @staticmethod
    def main(args: TrainArgs):
        """Load the config for the requested model type and run its trainer.

        Args:
            args: Parsed arguments; ``args.type`` selects the trainer and
                ``args.config`` is the config file path handed to it.

        Raises:
            ValueError: If ``args.type`` is not one of the known model types.
        """
        kind = args.type
        # Each trainer module is imported only inside its own branch, presumably
        # so that running one trainer does not load the others' dependencies.
        if kind == "expert":
            from mow.scripts.train_expert import TrainExpertConfig, train_expert

            train_expert(TrainExpertConfig.from_file(args.config))
        elif kind == "mow":
            from mow.scripts.train_mow import TrainMoWConfig, train_mow

            train_mow(TrainMoWConfig.from_file(args.config))
        elif kind == "router":
            from mow.scripts.train_router import TrainRouterConfig, train_router

            train_router(TrainRouterConfig.from_file(args.config))
        else:
            # Unreachable through the CLI (argparse enforces ``choices``), but
            # guards direct callers that build ``TrainArgs`` by hand.
            raise ValueError(f"Unknown model type: {args.type}")
