import argparse

from graphgym.utils.agg_runs_v2 import agg_batch


def parse_args(argv=None):
    """Parse the command-line arguments for aggregating a batch of runs.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default), argparse falls back to ``sys.argv[1:]``, so
            existing callers that pass nothing behave exactly as before.

    Returns:
        An ``argparse.Namespace`` with attributes ``dir`` (str, required),
        ``metric`` (str, default ``'accuracy'``), ``num_splits``
        (int, default ``3``) and ``epochs`` (str or ``None``).
    """
    # The description previously read 'Train a classification model', a
    # copy-paste leftover; this script only aggregates existing results.
    parser = argparse.ArgumentParser(
        description='Aggregate results from a batch of experiment runs'
    )
    parser.add_argument(
        '--dir',
        dest='dir',
        help='Dir for batch of results',
        required=True,
        type=str
    )
    parser.add_argument(
        '--metric',
        dest='metric',
        help='metric to select best epoch',
        required=False,
        type=str,
        default='accuracy'
    )
    parser.add_argument(
        '--num_splits',
        dest='num_splits',
        help='Number of splits',
        required=False,
        type=int,
        default=3
    )
    # NOTE: kept for CLI compatibility; the visible script body does not
    # currently forward this value to agg_batch.
    parser.add_argument(
        '--epochs',
        dest='epochs',
        help='epochs to evaluate',
        required=False,
        type=str,
        default=None
    )

    return parser.parse_args(argv)


args = parse_args()
print(f'Starting to aggregate results from: {args.dir}')

# Optimisation direction for each supported metric: True when a larger value
# is better. Passed to agg_batch as `maximize_metric`.
maximize_metric = {
    'accuracy': True,
    'mae': False,
    'mse': False,
}

# Fail with a readable message instead of a bare KeyError traceback when the
# requested metric has no known optimisation direction.
if args.metric not in maximize_metric:
    raise SystemExit(
        f'Unsupported --metric {args.metric!r}; '
        f'expected one of: {", ".join(maximize_metric)}'
    )

# Aggregate once per epoch-selection policy. A previously used 'early10'
# policy can be appended to this tuple if it is needed again.
for policy in ('best', 'last'):
    agg_batch(dir=args.dir,
              metric_name=args.metric,
              maximize_metric=maximize_metric[args.metric],
              policy=policy,
              num_splits=args.num_splits)
