import argparse
import glob
import os

from mow.scripts.train_mow import MowTrainerConfig, TrainMoWConfig
from mow.utils.program import Program


class FewShotExpansionArgs(argparse.Namespace):
    """Typed view of the parsed arguments for the ``expand`` subcommand.

    Each attribute corresponds to an argument registered in
    ``FewShotExpansionProgram.add_arguments``.
    """

    config: str  # positional: path to the few-shot expansion config file
    datasets: list[str]  # --datasets/-d: one or more dataset paths (glob-expanded in main)
    num_samples: int  # --num-samples/-n: samples taken from each dataset (default 10)
    output_path: str  # --output-path/-o: where the expanded model goes (default "expanded_model")
    max_steps: int  # --max-steps/--steps: maximum training steps (default 100)
    learning_rate: float  # --learning-rate/--lr: training learning rate (default 1e-5)


def _expand_dataset_patterns(patterns: list[str]) -> list[str]:
    """Expand each dataset path/glob pattern in *patterns* into concrete paths.

    The order of the patterns is preserved, and the matches of each individual
    pattern are sorted, because ``glob.glob`` returns them in arbitrary
    filesystem order and the resulting dataset order should be reproducible.

    Raises:
        FileNotFoundError: if any pattern matches no path. Previously such a
            pattern was silently dropped, so a typo quietly removed a dataset.
    """
    paths: list[str] = []
    unmatched: list[str] = []
    for pattern in patterns:
        matches = glob.glob(pattern)
        if not matches:
            unmatched.append(pattern)
        # extend() keeps this linear; sum(lists, []) re-copies on every step.
        paths.extend(sorted(matches))
    if unmatched:
        raise FileNotFoundError(
            "No files match dataset path(s): " + ", ".join(map(repr, unmatched))
        )
    return paths


class FewShotExpansionProgram(
    Program,
    name="expand",
    args=FewShotExpansionArgs,
    help="Expand the model with few-shot data.",
):
    """``expand`` subcommand: expand a model using a few samples per dataset.

    The ``name``/``args``/``help`` class keywords are consumed by ``Program``
    (presumably to register this class as a CLI subcommand — see
    ``mow.utils.program``).
    """

    @staticmethod
    def add_arguments(parser: argparse.ArgumentParser):
        """Register the ``expand`` subcommand's arguments on *parser*."""
        parser.add_argument(
            "config",
            help="Path to the few-shot expansion config file.",
        )
        parser.add_argument(
            "--datasets",
            "-d",
            nargs="+",
            required=True,
            help="List of dataset paths or glob patterns to use for expansion.",
        )
        parser.add_argument(
            "--num-samples",
            "-n",
            type=int,
            default=10,
            help="Number of samples to use from each dataset (default: 10).",
        )
        parser.add_argument(
            "--output-path",
            "-o",
            default="expanded_model",
            help="Output path for the expanded model (default: expanded_model).",
        )
        parser.add_argument(
            "--max-steps",
            "--steps",
            type=int,
            default=100,
            help="Maximum number of training steps (default: 100).",
        )
        parser.add_argument(
            "--learning-rate",
            "--lr",
            type=float,
            default=1e-5,
            help="Learning rate for training (default: 1e-5).",
        )

    @staticmethod
    def main(args: FewShotExpansionArgs):
        """Load the config, resolve dataset paths, and run few-shot expansion.

        Raises:
            FileNotFoundError: if any ``--datasets`` entry matches no path.
        """
        # NOTE(review): these imports are presumably deferred so that building
        # the CLI parser does not import the training stack — confirm.
        from pathlib import Path

        from mow.scripts.few_shot_expansion import few_shot_expansion

        config = TrainMoWConfig.from_file(args.config)
        datasets = _expand_dataset_patterns(args.datasets)

        few_shot_expansion(
            config=config,
            datasets=datasets,
            num_samples=args.num_samples,
            output_path=Path(args.output_path),
            max_steps=args.max_steps,
            learning_rate=args.learning_rate,
        )
