from dataclasses import dataclass, field

import yaml
from argparse_dataclass import ArgumentParser

from wmcal.experiments import run_experiment
from wmcal.utils import set_float_dtype


@dataclass
class Args:
    """Command-line options for running one experiment from a YAML config.

    Each field's ``metadata["help"]`` string is the help text intended for
    the corresponding command-line flag.
    """

    # Path to the experiment's YAML configuration; has no default, so it
    # must always be supplied.
    cfg: str = field(metadata={"help": "Path to YAML config file"})
    # Force the experiment to run again even if it already completed.
    redo: bool = field(
        metadata={"help": "Rerun experiment even if already completed"},
        default=False,
    )
    # Switch on debug-level logging output.
    debug: bool = field(
        metadata={"help": "Enable debug logging"},
        default=False,
    )


def main(args: Args) -> None:
    """Load the YAML experiment config at ``args.cfg`` and run it.

    Args:
        args: Parsed command-line options. ``cfg`` is the config file path;
            ``redo`` and ``debug`` are forwarded to ``run_experiment``.
    """
    # NOTE(review): this helper (from wmcal.utils) is described as patching
    # numpy to default to float16 for experimentation — confirm in its source.
    set_float_dtype()

    # Decode explicitly as UTF-8: without ``encoding``, open() uses the
    # platform's locale encoding, so the same config file could decode
    # differently (or fail) on different machines.
    with open(args.cfg, "r", encoding="utf-8") as f:
        # safe_load builds only plain YAML types (no arbitrary Python objects).
        # NOTE(review): an empty file yields None, which is passed on as-is.
        config_dict = yaml.safe_load(f)

    run_experiment(config_dict, redo=args.redo, debug=args.debug)


if __name__ == "__main__":
    # Derive the CLI from the Args dataclass, parse sys.argv, and run.
    main(ArgumentParser(Args).parse_args())
