import argparse
from typing import Any, Optional

from src.fid_utils.fid.multi_step import run_from_config as run_fid_multi_step
from .config import get_configs


def run(
        index: int,
        model_load_path: str,
        save_folder: Optional[str] = None,
        model_load_keys: Optional[list[str]] = None,
) -> None:
    """Build configs via ``get_configs`` and run the one at ``index``.

    Args:
        index: Position of the config to execute within the list returned by
            ``get_configs``. Must satisfy ``0 <= index < len(configs)``.
        model_load_path: Forwarded to ``get_configs``.
        save_folder: Forwarded to ``get_configs``; ``None`` by default.
        model_load_keys: Forwarded to ``get_configs``; ``None`` by default.

    Raises:
        IndexError: If ``index`` is negative or not less than the number of
            configs produced by ``get_configs``.
    """
    configs: list[dict[str, Any]] = get_configs(
        model_load_path=model_load_path,
        save_folder=save_folder,
        model_load_keys=model_load_keys,
    )
    # Reject negative indices explicitly: plain list indexing would wrap them
    # around and silently run a config counted from the end of the list.
    if not 0 <= index < len(configs):
        raise IndexError(
            f'config index {index} out of range; '
            f'expected 0 <= index < {len(configs)}'
        )
    run_fid_multi_step(configs[index])


def run_from_config(config: dict[str, Any]) -> None:
    """Call :func:`run` with the entries of ``config`` as keyword arguments.

    Each key of ``config`` must name a parameter of :func:`run`.
    """
    keyword_args = dict(config)
    run(**keyword_args)


def parse_args(argv: Optional[list[str]] = None) -> argparse.Namespace:
    """Parse the command-line options consumed by :func:`run`.

    Args:
        argv: Argument strings to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, preserving the original behavior; passing
            an explicit list lets tests and other callers parse directly.

    Returns:
        Namespace with ``index``, ``model_load_path``, ``save_folder`` and
        ``model_load_keys`` attributes. The names match :func:`run`'s
        parameters so the namespace can be unpacked into it.
    """
    parser = argparse.ArgumentParser(
        description='Run the config selected by --index from get_configs.',
    )
    parser.add_argument(
        '--index', type=int, required=True,
        help='Position of the config to run in the list built by get_configs.',
    )
    parser.add_argument(
        '--model_load_path', type=str, required=True,
        help='Passed to get_configs as model_load_path.',
    )
    # Optional flags: argparse already defaults to required=False and
    # default=None, so the omitted settings keep the original behavior.
    parser.add_argument(
        '--save_folder', type=str,
        help='Passed to get_configs as save_folder (default: None).',
    )
    parser.add_argument(
        '--model_load_keys', type=str, nargs='+',
        help='One or more keys passed to get_configs as model_load_keys '
             '(default: None).',
    )
    return parser.parse_args(argv)


def get_config_from_args(args: argparse.Namespace) -> dict[str, Any]:
    """Convert parsed CLI arguments into a keyword-argument dict.

    Returns a shallow copy. ``vars(args)`` on its own returns the namespace's
    live ``__dict__``, so mutating that result would also modify ``args``.
    """
    return dict(vars(args))


def main() -> None:
    """Command-line entry point: parse flags, turn them into a config, run it."""
    run_from_config(get_config_from_args(parse_args()))


if __name__ == "__main__":
    # Execute only when invoked as a script, not when imported.
    main()
