# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
#-------------------------------------------------------------------------
#
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os

import torch

from fairseq.models import ARCH_MODEL_REGISTRY, ARCH_CONFIG_REGISTRY
from fairseq.criterions import CRITERION_REGISTRY
from fairseq.optim import OPTIMIZER_REGISTRY
from fairseq.optim.lr_scheduler import LR_SCHEDULER_REGISTRY


def get_training_parser():
    """Build the complete argument parser used for training.

    Starts from the shared base parser and then registers, in this order:
    dataset options (with both training and generation extras), model,
    optimization, checkpointing, inference, performance and LPMM
    quantization options.

    Returns:
        The populated ``argparse.ArgumentParser``.
    """
    parser = get_parser('Trainer')
    add_dataset_args(parser, train=True, gen=True)
    for register in (add_model_args, add_optimization_args, add_checkpoint_args,
                     add_inference_args, add_perf_args, add_lpmm_args):
        register(parser)
    return parser


def get_inference_parser():
    """Build the argument parser used for standalone generation.

    Registers the shared base options, dataset options with the generation
    extras, inference options and performance options, in that order.

    Returns:
        The populated ``argparse.ArgumentParser``.
    """
    inference_parser = get_parser('Generation')
    add_dataset_args(inference_parser, gen=True)
    for register in (add_inference_args, add_perf_args):
        register(inference_parser)
    return inference_parser


def parse_args_and_arch(parser, input_args=None, parse_known=False):
    """Parse arguments in two passes so component-specific options are known.

    The first pass discovers ``--arch`` and ``--lr-scheduler`` (when the parser
    defines them). The matching registry entries then register their own
    options, and a second pass parses everything.

    Args:
        parser: the ``argparse.ArgumentParser`` to parse with. It is mutated:
            new argument groups/options may be added to it.
        input_args: list of argument strings to parse instead of ``sys.argv``.
        parse_known: if True, tolerate unrecognised arguments and return them.

    Returns:
        ``args`` when ``parse_known`` is False, otherwise ``(args, extra)``
        where ``extra`` is the list of unrecognised argument strings.
    """
    # First pass: unknown (not yet registered) options are simply ignored.
    preliminary, _ = parser.parse_known_args(input_args)

    if hasattr(preliminary, 'arch'):
        # SUPPRESS means an attribute only appears on the namespace when it is
        # passed explicitly or the architecture's add_args gives it a default.
        arch_group = parser.add_argument_group(
            'Model-specific configuration',
            argument_default=argparse.SUPPRESS,
        )
        ARCH_MODEL_REGISTRY[preliminary.arch].add_args(arch_group)

    # Optimizer-specific options are deliberately not registered here; the
    # optimizer hyperparameters live in add_optimization_args instead.
    if hasattr(preliminary, 'lr_scheduler'):
        LR_SCHEDULER_REGISTRY[preliminary.lr_scheduler].add_args(parser)

    # Second pass, now aware of every registered option.
    extra = None
    if parse_known:
        args, extra = parser.parse_known_args(input_args)
    else:
        args = parser.parse_args(input_args)

    # Validation batch size falls back to the training batch size.
    if hasattr(args, 'max_sentences_valid') and args.max_sentences_valid is None:
        args.max_sentences_valid = args.max_sentences

    args.max_positions = (args.max_source_positions, args.max_target_positions)

    # BLEU scoring (online eval or a BLEU stopping target) needs BPE removed.
    # Default to the '@@ ' separator unless the user already chose one.
    wants_bleu = hasattr(args, 'target_bleu') and (args.online_eval or args.target_bleu)
    if wants_bleu and not args.remove_bpe:
        args.remove_bpe = '@@ '

    # Let the architecture config function adjust the parsed namespace in place
    # (presumably filling in architecture defaults -- see the registry entry).
    if hasattr(args, 'arch'):
        ARCH_CONFIG_REGISTRY[args.arch](args)

    return (args, extra) if parse_known else args


def get_parser(desc):
    """Create the base ArgumentParser shared by every entry point.

    Args:
        desc: text appended to the toolkit name in the parser description.

    Returns:
        An ``argparse.ArgumentParser`` with the common options for logging
        interval, random seed, mixed precision, DLLogger stats file, output
        directory and the pre-training sanity check.
    """
    description = 'Facebook AI Research Sequence-to-Sequence Toolkit -- ' + desc
    parser = argparse.ArgumentParser(description=description)
    add = parser.add_argument
    add('--log-interval', type=int, default=500, metavar='N',
        help='print aggregated stats and flush json log every N iteration')
    add('--seed', type=int, default=1, metavar='N',
        help='pseudo random number generator seed')
    add('--amp', action='store_true', help='use Automatic Mixed Precision')
    add('--stat-file', type=str, default='run_log.json',
        help='Name of the file containing DLLogger output')
    add('--save-dir', default='results', metavar='DIR',
        help='path to save checkpoints and logs')
    add('--do-sanity-check', action='store_true',
        help='Perform evaluation on test set before running the training')
    return parser


def _str_to_bool(value):
    """Convert a command-line string such as 'True', 'false', '1' or 'no' to a bool.

    ``type=bool`` cannot be used for this: argparse would call ``bool('False')``,
    which is True because every non-empty string is truthy. Values that are
    already bools are returned unchanged.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean
            spelling. argparse reports this as a usage error.
    """
    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in ('true', 't', 'yes', 'y', '1'):
        return True
    if normalized in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean value (true/false), got {!r}'.format(value))


def add_dataset_args(parser, train=False, gen=False):
    """Register dataset and data-loading options.

    ``--max-tokens``, ``--max-sentences`` and the subset/sharding options go
    into a 'Dataset and data loading' group. Language, padding and length
    options are attached to ``parser`` directly.

    Args:
        parser: the ``argparse.ArgumentParser`` to extend.
        train: also add the positional ``data`` directory plus the training
            and validation subset options.
        gen: also add the generation subset and sharding options.

    Returns:
        The 'Dataset and data loading' argument group.
    """
    group = parser.add_argument_group('Dataset and data loading')
    group.add_argument('--max-tokens', type=int, metavar='N',
                       help='maximum number of tokens in a batch')
    group.add_argument('--max-sentences', '--batch-size', type=int, metavar='N',
                       help='maximum number of sentences in a batch')
    parser.add_argument('-s', '--source-lang', default=None, metavar='SRC',
                        help='source language')
    parser.add_argument('-t', '--target-lang', default=None, metavar='TARGET',
                        help='target language')
    parser.add_argument('--raw-text', action='store_true',
                        help='load raw text dataset')
    # type=bool would turn every non-empty string (including 'False') into
    # True, so the value is parsed explicitly. Defaults stay real bools.
    parser.add_argument('--left-pad-source', default=True, type=_str_to_bool, metavar='BOOL',
                        help='pad the source on the left (default: True)')
    parser.add_argument('--left-pad-target', default=False, type=_str_to_bool, metavar='BOOL',
                        help='pad the target on the left (default: False)')
    parser.add_argument('--max-source-positions', default=1024, type=int, metavar='N',
                        help='max number of tokens in the source sequence')
    parser.add_argument('--max-target-positions', default=1024, type=int, metavar='N',
                        help='max number of tokens in the target sequence')
    parser.add_argument('--pad-sequence', default=1, type=int, metavar='N',
                        help='Pad sequences to a multiple of N')
    if train:
        parser.add_argument('data', metavar='DIR', help='path to data directory')
        group.add_argument('--train-subset', default='train', metavar='SPLIT',
                           choices=['train', 'valid', 'test'],
                           help='data subset to use for training (train, valid, test)')
        group.add_argument('--valid-subset', default='valid', metavar='SPLIT',
                           help='comma separated list of data subsets to use for validation'
                                ' (train, valid, valid1, test, test1)')
        # When left unset, parse_args_and_arch falls back to --max-sentences.
        group.add_argument('--max-sentences-valid', type=int, metavar='N',
                           help='maximum number of sentences in a validation batch'
                                ' (defaults to --max-sentences)')
    if gen:
        group.add_argument('--gen-subset', default='test', metavar='SPLIT',
                           help='data subset to generate (train, valid, test)')
        group.add_argument('--num-shards', default=1, type=int, metavar='N',
                           help='shard generation over N shards')
        group.add_argument('--shard-id', default=0, type=int, metavar='ID',
                           help='id of the shard to generate (id < num_shards)')
    return group

def add_optimization_args(parser):
    """Register the 'Optimization' argument group.

    Covers stopping criteria, gradient clipping and update frequency,
    optimizer selection and hyperparameters (including Adam betas/eps), and
    learning-rate scheduler selection. ``--label-smoothing`` is attached to
    ``parser`` itself rather than to the group.

    Returns:
        The 'Optimization' argument group.
    """
    optim = parser.add_argument_group('Optimization')

    # Stopping criteria and gradient handling.
    optim.add_argument('--max-epoch', '--me', type=int, default=0, metavar='N',
                       help='force stop training at specified epoch')
    optim.add_argument('--max-update', '--mu', type=int, default=0, metavar='N',
                       help='force stop training at specified update')
    optim.add_argument('--target-bleu', type=float, default=0.0, metavar='TARGET',
                       help='force stop training after reaching target bleu')
    optim.add_argument('--clip-norm', type=float, default=25, metavar='NORM',
                       help='clip threshold of gradients')
    optim.add_argument('--update-freq', type=int, nargs='+', default=[1],
                       help='update parameters every N_i batches, when in epoch i')

    # Optimizer definitions can be found under fairseq/optim/
    known_optimizers = ', '.join(OPTIMIZER_REGISTRY.keys())
    optim.add_argument('--optimizer', default='nag', metavar='OPT',
                       help='optimizer: {} (default: nag)'.format(known_optimizers))
    optim.add_argument('--lr', '--learning-rate', type=float, nargs='+', default=[0.25],
                       help='learning rate for the first N epochs; all epochs >N using LR_N'
                            ' (note: this may be interpreted differently depending on --lr-scheduler)')
    optim.add_argument('--momentum', type=float, default=0.99, metavar='M',
                       help='momentum factor')
    optim.add_argument('--weight-decay', '--wd', type=float, default=0.0, metavar='WD',
                       help='weight decay')

    # Learning rate schedulers can be found under fairseq/optim/lr_scheduler/
    known_schedulers = ', '.join(LR_SCHEDULER_REGISTRY.keys())
    optim.add_argument('--lr-scheduler', default='reduce_lr_on_plateau',
                       help='learning rate scheduler: {} (default: reduce_lr_on_plateau)'.format(
                           known_schedulers))
    optim.add_argument('--lr-shrink', type=float, default=0.1, metavar='LS',
                       help='learning rate shrink factor for annealing, lr_new = (lr * lr_shrink)')
    optim.add_argument('--min-lr', type=float, default=1e-5, metavar='LR',
                       help='minimum learning rate')

    # Adam hyperparameters.
    optim.add_argument('--adam-betas', type=float, nargs=2, default=(0.9, 0.999), metavar='B1 B2',
                       help='betas for Adam optimizer')
    optim.add_argument('--adam-eps', type=float, default=1e-8, metavar='D',
                       help='epsilon for Adam optimizer')

    # Criterion option, registered on the parser rather than the group.
    parser.add_argument('--label-smoothing', type=float, default=0., metavar='D',
                        help='epsilon for label smoothing, 0 means no label smoothing')

    return optim


def add_checkpoint_args(parser):
    """Register the 'Checkpointing' argument group.

    Adds the restore file name, the save and validation intervals (in
    epochs), and switches that disable saving or per-epoch checkpoints.

    Returns:
        The 'Checkpointing' argument group.
    """
    checkpointing = parser.add_argument_group('Checkpointing')
    add = checkpointing.add_argument
    add('--restore-file', default='checkpoint_last.pt',
        help='filename in save-dir from which to load checkpoint')
    add('--save-interval', metavar='N', type=int, default=1,
        help='save a checkpoint every N epochs')
    add('--no-save', action='store_true', help="don't save models or checkpoints")
    add('--no-epoch-checkpoints', action='store_true',
        help='only store last and best checkpoints')
    add('--validate-interval', metavar='N', type=int, default=1,
        help='validate every N epochs')
    return checkpointing


def add_common_eval_args(group):
    """Add the evaluation options shared by generation-style commands.

    Args:
        group: an argparse parser or argument group. The options are added
            to it in place.

    Returns:
        None.
    """
    add = group.add_argument
    add('--path', metavar='FILE', help='path(s) to model file(s), colon separated')
    add('--file', type=str, default=None, metavar='FILE',
        help='path to a file with input data for inference')
    # A bare --remove-bpe (with no value) sets the value to '@@ '.
    add('--remove-bpe', nargs='?', const='@@ ', default=None,
        help='remove BPE tokens before scoring')
    for flag, text in (('--cpu', 'generate on CPU'),
                       ('--quiet', 'only print final scores')):
        add(flag, action='store_true', help=text)


def add_inference_args(parser):
    """Register the 'Generation' argument group used for decoding and evaluation.

    Starts with the shared evaluation options from ``add_common_eval_args``.
    Then adds beam-search, length-control, penalty, sampling,
    online-evaluation, BPE-codes, buffering and fp16 options.

    Returns:
        The 'Generation' argument group.
    """
    gen = parser.add_argument_group('Generation')
    add_common_eval_args(gen)

    # Beam size and number of hypotheses returned.
    gen.add_argument('--beam', type=int, default=4, metavar='N', help='beam size')
    gen.add_argument('--nbest', type=int, default=1, metavar='N',
                     help='number of hypotheses to output')

    # Length controls; both max-length options share one help text.
    max_len_help = ('generate sequences of maximum length ax + b, '
                    'where x is the source length')
    gen.add_argument('--max-len-a', type=float, default=0, metavar='N', help=max_len_help)
    gen.add_argument('--max-len-b', type=int, default=200, metavar='N', help=max_len_help)
    gen.add_argument('--min-len', type=float, default=1, metavar='N',
                     help='minimum generation length')
    gen.add_argument('--no-early-stop', action='store_true',
                     help='continue searching even after finalizing k=beam '
                          'hypotheses; this is more correct, but increases '
                          'generation time by 50%%')

    # Scoring and penalties.
    gen.add_argument('--unnormalized', action='store_true',
                     help='compare unnormalized hypothesis scores')
    gen.add_argument('--no-beamable-mm', action='store_true',
                     help="don't use BeamableMM in attention layers")
    gen.add_argument('--lenpen', type=float, default=1,
                     help='length penalty: <1.0 favors shorter, >1.0 favors longer sentences')
    gen.add_argument('--unkpen', type=float, default=0,
                     help='unknown word penalty: <0 produces more unks, >0 produces fewer')
    gen.add_argument('--replace-unk', nargs='?', const=True, default=None,
                     help='perform unknown replacement (optionally with alignment dictionary)')
    gen.add_argument('--prefix-size', type=int, default=0, metavar='PS',
                     help='initialize generation by target prefix of given length')

    # Sampling instead of beam search.
    gen.add_argument('--sampling', action='store_true',
                     help='sample hypotheses instead of using beam search')
    gen.add_argument('--sampling-topk', type=int, default=-1, metavar='PS',
                     help='sample from top K likely next words instead of all words')
    gen.add_argument('--sampling-temperature', type=float, default=1, metavar='N',
                     help='temperature for random sampling')
    gen.add_argument('--print-alignment', action='store_true',
                     help='if set, uses attention feedback to compute and print alignment to source tokens')

    # Online evaluation during training.
    gen.add_argument('--online-eval', action='store_true',
                     help='score model at the end of epoch')
    gen.add_argument('--save-predictions', action='store_true',
                     help='Save predictions produced with online evaluation')
    gen.add_argument('--test-cased-bleu', action='store_true',
                     help='Use cased bleu for online eval')

    # Input preprocessing, buffering and precision.
    gen.add_argument('--bpe-codes', type=str, default=None, metavar='CODES',
                     help='file with bpe codes')
    gen.add_argument('--buffer-size', type=int, default=64, metavar='N',
                     help='read this many sentences into a buffer before processing them')
    gen.add_argument('--fp16', action='store_true', help='use fp16 precision')
    return gen

def add_model_args(parser):
    """Register the 'Model configuration' group: ``--arch`` and ``--criterion``.

    ``--arch`` is mandatory and must name an entry of ARCH_MODEL_REGISTRY.
    ``--criterion`` defaults to ``cross_entropy`` and must name an entry of
    CRITERION_REGISTRY.

    Returns:
        The 'Model configuration' argument group.
    """
    group = parser.add_argument_group('Model configuration')

    # Model definitions can be found under fairseq/models/
    #
    # The model architecture can be specified in several ways.
    # In increasing order of priority:
    # 1) model defaults (lowest priority)
    # 2) --arch argument
    #
    # --arch is required, so it must always appear on the command line and a
    # default would never take effect. No default is declared, and the help
    # text does not advertise one.
    group.add_argument(
        '--arch', '-a', metavar='ARCH', required=True,
        choices=ARCH_MODEL_REGISTRY.keys(),
        help='model architecture: {}'.format(
            ', '.join(ARCH_MODEL_REGISTRY.keys())),
    )

    # Criterion definitions can be found under fairseq/criterions/
    group.add_argument(
        '--criterion', default='cross_entropy', metavar='CRIT',
        choices=CRITERION_REGISTRY.keys(),
        help='training criterion: {} (default: cross_entropy)'.format(
            ', '.join(CRITERION_REGISTRY.keys())),
    )

    return group


def add_perf_args(parser):
    """Register the 'Performance' group of opt-in fusion switches.

    Every option is a ``store_true`` flag, so each one defaults to False.

    Returns:
        The 'Performance' argument group.
    """
    perf = parser.add_argument_group('Performance')
    fusion_flags = (
        ('--fuse-dropout-add', 'Fuse dropout and residual adds.'),
        ('--fuse-relu-dropout', 'Fuse Relu and Dropout.'),
        ('--fuse-layer-norm', "Use APEX's FusedLayerNorm instead of torch.nn.LayerNorm"),
    )
    for flag, description in fusion_flags:
        perf.add_argument(flag, action='store_true', help=description)
    return perf


def add_lpmm_args(parser):
    """Register the 'LPMM Quantization' argument group.

    Adds the quantization config path, an integer enable switch, the
    compression bit widths for parameters, gradients, momentum and square
    momentum, and several optional string/float settings that default to
    None.

    Returns:
        The 'LPMM Quantization' argument group.
    """
    lpmm = parser.add_argument_group('LPMM Quantization')
    # NOTE(review): relative path; presumably resolved against the current
    # working directory wherever the config is loaded -- confirm at use site.
    lpmm.add_argument("--q_cfg", type=str, default="../../lpmm/configs/default.yml")
    # NOTE(review): integer switch, presumably 0 = disabled; confirm at use site.
    lpmm.add_argument('--lpmm_enable', type=int, default=0)

    # Compression bit widths: (flag, default bits, help text).
    for flag, bits, text in (
            ('--pb', 8, 'parameter compression bits'),
            ('--gb', 4, 'gradient compression bits'),
            ('--mb', 4, 'momentum compression bits'),
            ('--sqmb', 4, 'square momentum compression bits'),
    ):
        lpmm.add_argument(flag, type=int, default=bits, help=text)

    # Optional settings: (flag, value type); each is None unless given.
    for flag, kind in (
            ('--truncated-mode', str),
            ('--truncated-factor', float),
            ('--truncated-global-factor', float),
            ("--quant_type", str),
            ("--round_type", str),
            ("--q_oracle", str),
    ):
        lpmm.add_argument(flag, type=kind, default=None)

    return lpmm
