from .comlib import *
from .test_mnist_item import test_rl_mnist2_plot_compare
from .poison_rate import test_poison_ratio
from .data_num import test_data_num

def main(argv=None):
    """Parse command-line arguments and run the selected experiment.

    Sub-commands and what they call:
      run_iid_poison_rate      -> test_poison_ratio("iid")
      run_non_iid_poison_rate  -> test_poison_ratio("non_iid")
      run_iid_data_num         -> test_data_num()
      plot --name NAME         -> test_rl_mnist2_plot_compare(NAME)

    If no sub-command is given, the parser's help text is printed instead.

    Args:
        argv: Argument list to parse, excluding the program name. ``None``
            (the default) makes argparse read ``sys.argv[1:]``, so a plain
            ``main()`` call behaves exactly as before.

    Raises:
        SystemExit: Raised by argparse on invalid arguments, e.g. ``plot``
            without ``--name``.
    """
    # Import explicitly rather than relying on ``from .comlib import *``
    # happening to re-export the argparse module.
    import argparse

    parser = argparse.ArgumentParser(description="")
    subparsers = parser.add_subparsers(dest="command")

    subparsers.add_parser("run_iid_poison_rate")
    subparsers.add_parser("run_non_iid_poison_rate")
    subparsers.add_parser("run_iid_data_num")

    plot_parser = subparsers.add_parser("plot")
    plot_parser.add_argument("--name", type=str, required=True, help="poison_rate or data_num")

    args = parser.parse_args(argv)

    # Lambdas defer both the call and the global-name lookup until the
    # chosen command actually runs, so only the selected experiment executes.
    commands = {
        "run_iid_poison_rate": lambda: test_poison_ratio("iid"),
        "run_non_iid_poison_rate": lambda: test_poison_ratio("non_iid"),
        "run_iid_data_num": lambda: test_data_num(),
        "plot": lambda: test_rl_mnist2_plot_compare(args.name),
    }

    action = commands.get(args.command)
    if action is None:
        # No sub-command given (args.command is None): show usage.
        parser.print_help()
    else:
        action()

# Run the CLI only when this module is executed directly; importing it has no
# side effects. Because the module uses relative imports (``from .xxx``), it
# must be run as part of its package, e.g. ``python -m <package>.<module>``.
if __name__ == "__main__":
    main()
