import argparse
import json
import os
import logging
from Experiments import singleshot
from Experiments import extract_mask
from Experiments.theory import unit_conservation
from Experiments.theory import layer_conservation
from Experiments.theory import imp_conservation
from Experiments.theory import schedule_conservation
from Utils.logger import *
from datetime import datetime
import wandb

# Make sure WANDB_API_KEY exists in the environment without clobbering a key
# the user has already exported. The original unconditional assignment to ''
# erased any real key; when the variable is unset the result is the same
# empty value as before.
os.environ.setdefault('WANDB_API_KEY', '')

def str2bool(value):
    """Convert a command-line token into a real boolean.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so any
    non-empty text (including ``"False"``) becomes ``True``. This converter
    parses the text instead.

    Args:
        value: The raw option value (a string), or an existing ``bool``.

    Returns:
        ``True`` for ``true/t/yes/y/1``, ``False`` for ``false/f/no/n/0`` or
        the empty string (the empty string was the only way to get ``False``
        with ``type=bool``, so it is kept for backward compatibility).

    Raises:
        argparse.ArgumentTypeError: If the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    if text in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean value (true/false), got {!r}'.format(value))


def build_parser():
    """Build the command-line parser for the compression experiments.

    Returns:
        argparse.ArgumentParser: Parser with the ``training`` and ``pruning``
        argument groups plus the experiment / wandb options.
    """
    parser = argparse.ArgumentParser(description='Network Compression')

    # Training Hyperparameters
    training_args = parser.add_argument_group('training')
    training_args.add_argument('--dataset', type=str, default='mnist',
                        choices=['mnist', 'cifar10', 'cifar100', 'tiny-imagenet', 'imagenet'],
                        help='dataset (default: mnist)')
    training_args.add_argument('--model', type=str, default='fc', choices=['fc', 'conv',
                        'vgg11', 'vgg11-bn', 'vgg13', 'vgg13-bn', 'vgg16', 'vgg16-bn', 'vgg19', 'vgg19-bn',
                        'resnet18', 'resnet20', 'resnet32', 'resnet34', 'resnet44', 'resnet50',
                        'resnet56', 'resnet101', 'resnet110', 'resnet152', 'resnet1202',
                        'wide-resnet18', 'wide-resnet20', 'wide-resnet32', 'wide-resnet34', 'wide-resnet44', 'wide-resnet50',
                        'wide-resnet56', 'wide-resnet101', 'wide-resnet110', 'wide-resnet152', 'wide-resnet1202',
                        'wide-resnet20-64', 'wide-resnet20-128', 'wide-resnet20-256', 'wide-resnet20-512'],
                        help='model architecture (default: fc)')
    training_args.add_argument('--hidden_dim', type=int, default=100,
                        help='hidden dimension (default: 100)')
    training_args.add_argument('--n_layers', type=int, default=4,
                        help='number of hidden layers (default: 4)')
    training_args.add_argument('--model-class', type=str, default='default',
                        choices=['default', 'lottery', 'tinyimagenet', 'imagenet'],
                        help='model class (default: default)')
    training_args.add_argument('--dense-classifier', type=str2bool, default=False,
                        help='ensure last layer of model is dense (default: False)')
    training_args.add_argument('--pretrained', type=str2bool, default=False,
                        help='load pretrained weights (default: False)')
    training_args.add_argument('--optimizer', type=str, default='adam', choices=['sgd', 'momentum',
                        'adam', 'rms', 'adagrad'],
                        help='optimizer (default: adam)')
    training_args.add_argument('--train-batch-size', type=int, default=64,
                        help='input batch size for training (default: 64)')
    training_args.add_argument('--test-batch-size', type=int, default=256,
                        help='input batch size for testing (default: 256)')
    training_args.add_argument('--pre-epochs', type=int, default=0,
                        help='number of epochs to train before pruning (default: 0)')
    training_args.add_argument('--post-epochs', type=int, default=10,
                        help='number of epochs to train after pruning (default: 10)')
    training_args.add_argument('--lr', type=float, default=0.001,
                        help='learning rate (default: 0.001)')
    training_args.add_argument('--lr-drops', type=int, nargs='*', default=[60, 120],
                        help='list of learning rate drops (default: [60, 120])')
    training_args.add_argument('--lr-drop-rate', type=float, default=0.1,
                        help='multiplicative factor of learning rate drop (default: 0.1)')
    training_args.add_argument('--weight-decay', type=float, default=1e-4,
                        help='weight decay (default: 1e-4)')
    training_args.add_argument('--update-frequency', type=int, default=20,
                        help='update frequency for PathNorm optimizer (default: 20)')
    training_args.add_argument('--lambda-reg', type=float, default=0.1,
                        help='regularization parameter for PathNorm optimizer (default: 0.1)')
    training_args.add_argument('--path-momentum', type=float, default=0.9,
                        help='momentum parameter for PathNorm optimizer (default: 0.9)')
    training_args.add_argument('--use-path-aware-clipping', type=str2bool, default=False,
                        help='use path-aware clipping for PathNorm optimizer (default: False)')
    training_args.add_argument('--base-clip-norm', type=float, default=1.0,
                        help='base clipping norm for PathNorm optimizer (default: 1.0)')

    # Pruning Hyperparameters
    pruning_args = parser.add_argument_group('pruning')
    pruning_args.add_argument('--pruner', type=str, default='rand',
                        choices=['rand', 'mag', 'snip', 'grasp', 'synflow'],
                        help='prune strategy (default: rand)')
    pruning_args.add_argument('--compression', type=float, default=1.0,
                        help='quotient of prunable non-zero parameters before and after pruning (default: 1.0)')
    pruning_args.add_argument('--prune-epochs', type=int, default=1,
                        help='number of iterations for scoring (default: 1)')
    pruning_args.add_argument('--compression-schedule', type=str, default='exponential',
                        choices=['linear', 'exponential'],
                        help='whether to use a linear or exponential compression schedule (default: exponential)')
    pruning_args.add_argument('--mask-scope', type=str, default='global', choices=['global', 'local'],
                        help='masking scope (global or local) (default: global)')
    pruning_args.add_argument('--prune-dataset-ratio', type=int, default=10,
                        help='ratio of prune dataset size and number of classes (default: 10)')
    pruning_args.add_argument('--prune-batch-size', type=int, default=256,
                        help='input batch size for pruning (default: 256)')
    pruning_args.add_argument('--prune-bias', type=str2bool, default=False,
                        help='whether to prune bias parameters (default: False)')
    pruning_args.add_argument('--prune-batchnorm', type=str2bool, default=False,
                        help='whether to prune batchnorm layers (default: False)')
    pruning_args.add_argument('--prune-residual', type=str2bool, default=False,
                        help='whether to prune residual connections (default: False)')
    pruning_args.add_argument('--prune-train-mode', type=str2bool, default=False,
                        help='whether to prune in train mode (default: False)')
    pruning_args.add_argument('--reinitialize', type=str2bool, default=False,
                        help='whether to reinitialize weight parameters after pruning (default: False)')
    pruning_args.add_argument('--shuffle', type=str2bool, default=False,
                        help='whether to shuffle masks after pruning (default: False)')
    pruning_args.add_argument('--invert', type=str2bool, default=False,
                        help='whether to invert scores during pruning (default: False)')
    pruning_args.add_argument('--pruner-list', type=str, nargs='*', default=[],
                        help='list of pruning strategies for singleshot (default: [])')
    pruning_args.add_argument('--prune-epoch-list', type=int, nargs='*', default=[],
                        help='list of prune epochs for singleshot (default: [])')
    pruning_args.add_argument('--compression-list', type=float, nargs='*', default=[],
                        help='list of compression ratio exponents for singleshot/multishot (default: [])')
    pruning_args.add_argument('--level-list', type=int, nargs='*', default=[],
                        help='list of number of prune-train cycles (levels) for multishot (default: [])')

    ## Experiment Hyperparameters ##
    parser.add_argument('--experiment', type=str, default='singleshot',
                        choices=['singleshot', 'multishot', 'unit-conservation', 'extract-mask',
                        'layer-conservation', 'imp-conservation', 'schedule-conservation'],
                        help='experiment name (default: singleshot)')
    parser.add_argument('--expid', type=str, default='',
                        help='name used to save results (default: "")')
    parser.add_argument('--result-dir', type=str, default='Results/data',
                        help='path to directory to save results (default: "Results/data")')
    parser.add_argument('--gpu', type=int, default=0,
                        help='number of GPU device to use (default: 0)')
    parser.add_argument('--workers', type=int, default=4,
                        help='number of data loading workers (default: 4)')
    parser.add_argument('--no-cuda', action='store_true',
                        help='disables CUDA training')
    parser.add_argument('--seed', type=int, default=1,
                        help='random seed (default: 1)')
    parser.add_argument('--verbose', action='store_true',
                        help='print statistics during training and testing')
    parser.add_argument('--is_stat_eigenvalues', action='store_true', default=False,
                        help='Trace Eigenvalues of Hessian')

    parser.add_argument('--delta', type=float, default=1.0,
                        help='delta for Huber Loss used in AMMD optimizer')

    parser.add_argument('--is_wandb', action='store_true', default=False)
    parser.add_argument('--wandb_note', type=str, default='')
    parser.add_argument('--wandb_tags', type=str, nargs='*', default=[])
    parser.add_argument('--wandb_project', type=str, default="")
    return parser


def main(argv=None):
    """Parse arguments, set up wandb and logging, then run the experiment.

    Args:
        argv: Optional list of argument strings; ``None`` makes argparse read
            ``sys.argv[1:]`` (the behaviour of the original script).
    """
    run_time = datetime.now().strftime('%Y-%m-%d_%H:%M')
    args = build_parser().parse_args(argv)

    # vars() returns the namespace's own dict, so attributes attached to
    # ``args`` below also show up in ``args_dict``.
    args_dict = vars(args)

    args.wandb_name = f"{args.pruner}_{args.compression}"
    args.wandb_group = f"{args.dataset}_{args.model}"
    args.wandb_job_type = f"compression_{args.compression}"

    setting_name = f"{args.model} {args.dataset}".capitalize()

    if args.is_wandb:
        # start a new wandb run to track this script
        wandb.init(
            # set the wandb project where this run will be logged
            project=f"Sparse-Optimization {setting_name} {args.wandb_project}",
            name=args.wandb_name,
            group=args.wandb_group,
            job_type=args.wandb_job_type,
            notes=args.wandb_note,
            tags=args.wandb_tags,
            # track hyperparameters and run metadata
            config=args_dict,
        )
        wandb.define_metric('Steps')
        wandb.define_metric("*", step_metric="Steps")
        args.wb = wandb
    else:
        args.wb = None

    if args.experiment == 'extract-mask':
        # Create the directory up front so the log file can be opened even on
        # a fresh checkout.
        os.makedirs('./logging', exist_ok=True)
        args.logger = setup_logger(log_file='./logging/extract_mask.log')
    else:
        logging_path = f"./logging/{args.dataset}/{args.model_class}/{args.model}/{args.pruner}/{args.optimizer}"
        os.makedirs(logging_path, exist_ok=True)
        # No flipping options are declared by this parser, so read the switch
        # defensively; the original unconditional ``args.is_flipping`` raised
        # AttributeError on every run that reached this branch.
        if getattr(args, 'is_flipping', False):
            logfile = f"{run_time}_comp_{args.compression}_type_{args.flipping_type}_r_{args.flipping_ratio}_freq_{args.flipping_freq}_logs.log"
        else:
            logfile = f"{run_time}_comp_{args.compression}_logs.log"
        args.logging_file = f'{logging_path}/{logfile}'
        args.logger = setup_logger(log_file=args.logging_file)

        # Best effort: failing to dump the arguments must not abort the run,
        # but Ctrl-C / SystemExit should still propagate.
        try:
            print_and_log(args.logger, args_dict)
        except Exception:
            print_and_log(args.logger, 'Cannot log arguments')

    ## Construct Result Directory ##
    # NOTE(review): the original computed '<result_dir>/<experiment>/<expid>'
    # into an unused local; args.result_dir itself is passed on unchanged.
    if args.expid == "":
        print("WARNING: this experiment is not being saved.")
        args.save = False
    else:
        args.save = True

    ## Run Experiment ##
    runners = {
        'singleshot': singleshot.run,
        'extract-mask': extract_mask.run,
    }
    runner = runners.get(args.experiment)
    if runner is None:
        # Accepted by --experiment but never dispatched by this script;
        # say so instead of exiting silently.
        print("WARNING: experiment '{}' has no runner in this script; nothing was run.".format(args.experiment))
    else:
        runner(args)


if __name__ == '__main__':
    main()
