# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Command line argument parsing"""
import argparse
from munch import Munch

# Module-level parser: every option registered below becomes an attribute on
# the namespace returned by PARSER.parse_args().
PARSER = argparse.ArgumentParser(description="UNet-medical")

# --- Execution mode and filesystem locations --------------------------------
PARSER.add_argument('--exec_mode',
                    choices=['train', 'train_and_predict', 'predict', 'evaluate', 'train_and_evaluate'],
                    type=str,
                    default='train_and_evaluate',
                    help="""Execution mode of running the model""")

PARSER.add_argument('--model_dir',
                    type=str,
                    default='/results',
                    help="""Output directory for information related to the model""")

# No default is given, so the parsed value is None unless the flag is passed.
PARSER.add_argument('--data_dir',
                    type=str,
                    help="""Input directory containing the dataset for training the model""")

PARSER.add_argument('--log_dir',
                    type=str,
                    default=None,
                    help="""Output directory for training logs""")

# --- Training hyper-parameters ----------------------------------------------
PARSER.add_argument('--batch_size',
                    type=int,
                    default=1,
                    help="""Size of each minibatch per GPU""")

PARSER.add_argument('--learning_rate',
                    type=float,
                    default=0.0001,
                    help="""Learning rate coefficient for AdamOptimizer""")

PARSER.add_argument('--fold',
                    type=int,
                    default=None,
                    help="""Chosen fold for cross-validation. Use None to disable cross-validation""")

PARSER.add_argument('--max_steps',
                    type=int,
                    default=1000,
                    help="""Maximum number of steps (batches) used for training""")

PARSER.add_argument('--weight_decay',
                    type=float,
                    default=0.0005,
                    help="""Weight decay coefficient""")

# --- Logging / evaluation cadence -------------------------------------------
PARSER.add_argument('--log_every',
                    type=int,
                    default=100,
                    help="""Log performance every n steps""")

# The help text used to be a verbatim copy of --log_every's, which made the two
# options indistinguishable in --help output.
# NOTE(review): the default of 0 presumably disables periodic evaluation -
# confirm against the training loop that consumes this value.
PARSER.add_argument('--evaluate_every',
                    type=int,
                    default=0,
                    help="""Evaluate performance every n steps""")

PARSER.add_argument('--warmup_steps',
                    type=int,
                    default=200,
                    help="""Number of warmup steps""")

PARSER.add_argument('--seed',
                    type=int,
                    default=0,
                    help="""Random seed""")

# --- Paired on/off switches -------------------------------------------------
# Each pair writes to one shared dest; the explicit set_defaults pins the value
# used when neither flag of the pair is supplied.
PARSER.add_argument('--augment', dest='augment', action='store_true',
                    help="""Perform data augmentation during training""")
PARSER.add_argument('--no-augment', dest='augment', action='store_false')
PARSER.set_defaults(augment=False)

PARSER.add_argument('--benchmark', dest='benchmark', action='store_true',
                    help="""Collect performance metrics during training""")
PARSER.add_argument('--no-benchmark', dest='benchmark', action='store_false')
# Previously this line set `augment` a second time (copy-paste slip); the
# default that belongs to this pair is `benchmark`.
PARSER.set_defaults(benchmark=False)

# --- Boolean feature switches (False unless the flag is present) ------------
PARSER.add_argument('--use_amp', '--amp', dest='use_amp', action='store_true',
                    help="""Train using TF-AMP""")

PARSER.add_argument('--use_xla', '--xla', dest='use_xla', action='store_true',
                    help="""Train using XLA""")

PARSER.add_argument('--use_trt', dest='use_trt', action='store_true',
                    help="""Use TF-TRT""")

PARSER.add_argument('--resume_training', dest='resume_training', action='store_true',
                    help="""Resume training from a checkpoint""")


def parse_args(flags):
    """Copy the known command line options from ``flags`` into a ``Munch``.

    Args:
        flags: Object exposing one attribute per option listed below, such as
            the namespace produced by ``PARSER.parse_args()``.

    Returns:
        A ``Munch`` keyed by option name, holding the value read from the
        attribute of the same name on ``flags``.

    Raises:
        AttributeError: If ``flags`` lacks any of the listed attributes.
    """
    # Listed explicitly so that only these options are copied, in this order.
    option_names = (
        'exec_mode',
        'model_dir',
        'data_dir',
        'log_dir',
        'batch_size',
        'learning_rate',
        'fold',
        'max_steps',
        'weight_decay',
        'log_every',
        'evaluate_every',
        'warmup_steps',
        'augment',
        'benchmark',
        'seed',
        'use_amp',
        'use_trt',
        'use_xla',
        'resume_training',
    )
    collected = {option: getattr(flags, option) for option in option_names}
    return Munch(collected)
