import argparse
import json
import os
import singleshot

def str2bool(value):
    """Convert a command-line string into a real boolean.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so
    every non-empty value (including ``"False"`` and ``"0"``) parses as
    ``True``. This converter interprets the text instead.

    Args:
        value: The raw option value (a ``str``), or an actual ``bool``.

    Returns:
        ``True`` or ``False``.

    Raises:
        argparse.ArgumentTypeError: If the text is not a recognised boolean
            spelling; argparse reports this as a usage error.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', 'on', '1'):
        return True
    # The empty string is kept falsy because bool('') was False before.
    if text in ('false', 'f', 'no', 'n', 'off', '0', ''):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got {!r}'.format(value))


def build_parser():
    """Build the argument parser for the network compression experiments.

    Returns:
        An ``argparse.ArgumentParser`` with the training, pruning and
        experiment options registered.
    """
    parser = argparse.ArgumentParser(description='Network Compression')
    # Training Hyperparameters
    training_args = parser.add_argument_group('training')
    training_args.add_argument('--dataset', type=str, default='cifar10', choices=['cifar10', 'imagenet', 'speed_test'])
    training_args.add_argument('--model', type=str, default='resnet18', help='model architecture (default: resnet18)')
    training_args.add_argument('--groups', type=int, default=1, help='number of groups for groups convolution')
    training_args.add_argument('--width-factor', type=float, default=1.0, help='width factor for the model')
    training_args.add_argument('--dense-classifier', type=str2bool, default=False, help='ensure last layer of model is dense (default: False)')
    training_args.add_argument('--pretrained', type=str2bool, default=False, help='load pretrained weights (default: False)')
    training_args.add_argument('--optimizer', type=str, default='adam', choices=['sgd', 'momentum', 'adam', 'rms'], help='optimizer (default: adam)')
    training_args.add_argument('--train-batch-size', type=int, default=64, help='input batch size for training (default: 64)')
    training_args.add_argument('--test-batch-size', type=int, default=256, help='input batch size for testing (default: 256)')
    training_args.add_argument('--pre-epochs', type=int, default=0, help='number of epochs to train before pruning (default: 0)')
    training_args.add_argument('--post-epochs', type=int, default=10, help='number of epochs to train after pruning (default: 10)')
    training_args.add_argument('--lr-scheduler', type=str, default='drop', choices=['drop', 'linear'], help='learning rate scheduler')
    training_args.add_argument('--lr', type=float, default=0.001, help='learning rate (default: 0.001)')
    training_args.add_argument('--lr-drops', type=int, nargs='*', default=[], help='list of learning rate drops (default: [])')
    training_args.add_argument('--lr-step-size', type=int, default=0, help='learning rate drop frequency')
    training_args.add_argument('--lr-drop-rate', type=float, default=0.1, help='multiplicative factor of learning rate drop (default: 0.1)')
    training_args.add_argument('--weight-decay', type=float, default=0.0, help='weight decay (default: 0.0)')
    # Pruning Hyperparameters
    pruning_args = parser.add_argument_group('pruning')
    pruning_args.add_argument('--pruner',
                              type=str,
                              default='synflow',
                              choices=['synflow', 'snip', 'grasp', 'lottery', 'opt_params', 'opt_flops', 'opt_both'],
                              help='prune strategy (default: synflow)')
    pruning_args.add_argument('--compression',
                              type=float,
                              default=1.0,
                              help='quotient of prunable non-zero prunable parameters before and after pruning (default: 1.0)')
    pruning_args.add_argument('--compression-flops',
                              type=float,
                              default=None,
                              help='quotient of prunable non-zero prunable flops before and after pruning (default: None)')
    pruning_args.add_argument('--prune-epochs', type=int, default=1, help='number of iterations for scoring (default: 1)')
    pruning_args.add_argument('--compression-schedule',
                              type=str,
                              default='exponential',
                              choices=['linear', 'exponential', 'constant', 'increase'],
                              help='whether to use a linear or exponential compression schedule (default: exponential)')
    pruning_args.add_argument('--mask-scope', type=str, default='global', choices=['global', 'pregrouping', 'precropping', 'filter'])
    pruning_args.add_argument('--prune-dataset-ratio', type=int, default=10, help='ratio of prune dataset size and number of classes (default: 10)')
    pruning_args.add_argument('--prune-batch-size', type=int, default=256, help='input batch size for pruning (default: 256)')
    pruning_args.add_argument('--Lambda', type=float, default=None, help='lambda used in kernel pruning regularization (default: None)')
    pruning_args.add_argument('--prune-bias', type=str2bool, default=False, help='whether to prune bias parameters (default: False)')
    pruning_args.add_argument('--prune-batchnorm', type=str2bool, default=False, help='whether to prune batchnorm layers (default: False)')
    pruning_args.add_argument('--prune-residual', type=str2bool, default=False, help='whether to prune residual connections (default: False)')
    # NOTE: the next two options use store_false, so the attribute is True by
    # default and passing the flag sets it to False (args.no_prune_linear /
    # args.prune_shortcut). Kept unchanged because downstream code reads them.
    pruning_args.add_argument('--no-prune-linear', action='store_false', help='whether to prune linear layers (default: True)')
    pruning_args.add_argument('--prune-shortcut', action='store_false', help='whether to prune shortcut 1x1 conv2D layers (default: True)')
    pruning_args.add_argument('--prune-pw-only', action='store_true', help='only prune pointwise convolution, for ShuffleNet only now')
    pruning_args.add_argument('--fix-shuffle', action='store_true', help='enabling fixed channel shuffle')
    pruning_args.add_argument('--prune-train-mode', type=str2bool, default=False, help='whether to prune in train mode (default: False)')
    pruning_args.add_argument('--reinitialize', action='store_true', help='whether to reinitialize weight parameters after pruning (default: False)')
    pruning_args.add_argument('--skip-last', action='store_true', help='whether to skip pruning the very last Linear layer')
    pruning_args.add_argument('--sqrt',
                              action='store_true',
                              help='whether to use sqare root of the sparsity or sparsity to initialize the weight (default: no sqrt)')
    pruning_args.add_argument('--reinitialize-sparse',
                              action='store_true',
                              help='whether to reinitialize weight parameters after pruning in sparse mode (default: False)')
    pruning_args.add_argument('--shuffle', action='store_true', help='whether to shuffle masks after pruning (default: False)')
    pruning_args.add_argument('--uniform_shuffle', action='store_true', help='whether to shuffle masks after pruning uniformly (default: False)')
    pruning_args.add_argument('--invert', action='store_true', help='whether to invert scores during pruning (default: False)')
    pruning_args.add_argument('--random', action='store_true', help='use random score within each layer')
    pruning_args.add_argument('--expand', action='store_true', help='whether allow expand the layers when scope = opt_both')
    pruning_args.add_argument('--prune-epoch-list', type=int, nargs='*', default=[], help='list of prune epochs for singleshot (default: [])')
    ## Experiment Hyperparameters ##
    parser.add_argument('--expid', type=str, default='', help='name used to save results (default: "")')
    parser.add_argument('--result-dir', type=str, default='Results/data', help='path to directory to save results (default: "Results/data")')
    parser.add_argument('--resume', type=str, default=None, help='path to directory to resume LT prune')
    parser.add_argument('--gather-result-path', type=str, default=None, help='path to the file gather results across different runs')
    parser.add_argument('--data-dir', type=str, default='/path/to/imagenet', help='path to directory to store the data')
    parser.add_argument('--gpu', type=int, default=0, help='number of GPU device to use (default: 0)')
    parser.add_argument('--workers', type=int, default=8, help='number of data loading workers (default: 8)')
    parser.add_argument('--no-cuda', action='store_true', help='disables CUDA training')
    parser.add_argument('--parallel', action='store_true', help='enables multi-CUDA parallel training')
    parser.add_argument('--scores-only', action='store_true', help='only compute and store the data for the pruning')
    parser.add_argument('--speed-test', type=str, default='', choices=['', 'inference', 'training'])
    parser.add_argument('--flops', action='store_true', help='change pruning objective to flops, pruning objective is #params by default')
    parser.add_argument('--seed', type=int, default=None, help='random seed (default: None)')
    parser.add_argument('--verbose', action='store_true', help='print statistics during training and testing')
    parser.add_argument('--print-model', action='store_true', help='print the model definition before training')
    return parser


def prepare_result_dir(args):
    """Set ``args.save`` and, when an experiment id is given, create its directory.

    With an empty ``args.expid`` a warning is printed and ``args.save`` is
    set to ``False``. Otherwise ``args.result_dir`` is rewritten to
    ``<result_dir>/<expid>``, that directory is created if missing, and
    ``args.save`` is set to ``True``.

    Args:
        args: Parsed namespace providing ``expid`` and ``result_dir``;
            mutated in place.

    Returns:
        The same ``args`` object.
    """
    if args.expid == "":
        print("WARNING: this experiment is not being saved.")
        args.save = False
        return args

    result_dir = '{}/{}'.format(args.result_dir, args.expid)
    args.save = True
    args.result_dir = result_dir
    # exist_ok tolerates re-running an experiment with the same id.
    os.makedirs(result_dir, exist_ok=True)
    return args


def save_args(args):
    """Write the parsed arguments to ``args.json`` inside ``args.result_dir``.

    Does nothing when ``args.save`` is falsy.
    """
    if not args.save:
        return
    with open(args.result_dir + '/args.json', 'w') as f:
        json.dump(vars(args), f, sort_keys=True, indent=4)


def main(argv=None):
    """Parse the command line, set up the result directory and run the experiment.

    Args:
        argv: Optional list of argument strings; ``None`` makes argparse
            read ``sys.argv[1:]``.
    """
    args = build_parser().parse_args(argv)

    ## Construct Result Directory ##
    prepare_result_dir(args)

    ## Save Args ##
    save_args(args)

    ## Run Experiment ##
    singleshot.run(args)


if __name__ == '__main__':
    main()
