import argparse

from mow.scripts.train_router import TrainRouterConfig
from mow.utils.program import Program


class PlotProgram(Program, name="plot", help="Plot the results of a model."):
    """Top-level ``plot`` command; the actual plotting lives in its subcommands."""

    @staticmethod
    def add_arguments(parser: argparse.ArgumentParser):
        """Attach the plotting subcommands to *parser*.

        ``dest="target"`` is forwarded to ``Program.build_subprogram``; presumably
        it names the namespace attribute recording the chosen subcommand —
        confirm against ``Program``.
        """
        subprograms = [PlotRouterProgram]
        PlotProgram.build_subprogram(parser, subprograms, dest="target")

    @staticmethod
    def main(args: argparse.Namespace):
        """Hand the parsed *args* to ``Program.run_subprogram`` for dispatch."""
        PlotProgram.run_subprogram(args)


class PlotRouterArgs(argparse.Namespace):
    """Typed view of the namespace parsed for the ``plot router`` subcommand."""

    # Positional argument: path to the train-router config file.
    train_router_config: str
    # Destination file for the rendered plot (``--output`` / ``-o``).
    output: str
    # Set by ``--color-with-class`` / ``--cls``; ``False`` unless the flag is given.
    color_with_class: bool = False


class PlotRouterProgram(
    Program, args=PlotRouterArgs, name="router", help="Plot the router results."
):
    """``plot router`` subcommand: plot router results from a train-router config."""

    @staticmethod
    def add_arguments(parser: argparse.ArgumentParser):
        """Register this subcommand's arguments on *parser*.

        Args:
            parser: Parser to which the ``router`` arguments are added.
        """
        parser.add_argument(
            "train_router_config",
            help="Path to the train router config file.",
        )
        parser.add_argument(
            "--output",
            "-o",
            default="router_plot.png",
            # argparse expands ``%(default)s`` from ``default=`` above, so the
            # help text can never drift out of sync with the real default.
            help="Output file for the router plot (default: %(default)s).",
        )
        parser.add_argument(
            "--color-with-class",
            "--cls",
            action="store_true",
            help="Color the nodes in the plot based on their class.",
        )

    @staticmethod
    def main(args: PlotRouterArgs):
        """Load the train-router config and plot its router results.

        Args:
            args: Parsed arguments: config path, output file, and the
                class-coloring flag.
        """
        # Deferred import — presumably so plotting dependencies are only loaded
        # when this subcommand actually runs.
        from mow.scripts.plot_router_results import plot_router_results

        config = TrainRouterConfig.from_file(args.train_router_config)

        plot_router_results(
            config=config,
            output=args.output,
            color_with_class=args.color_with_class,
        )
