import argparse
from Experiments import Baseline_FC
from Experiments import Random_Walks
from Experiments import Biased_Random_Walks
from Experiments import Grad_Biased_Randwalk
def build_parser():
    """Build the command-line parser for the network-construction experiments.

    Returns:
        argparse.ArgumentParser: parser whose namespace is passed unchanged to
        the selected experiment module's ``run(args)``.
    """
    parser = argparse.ArgumentParser(description='Network Construction')
    # 'Biased_Random_Walks' is the default, so it must also be a valid choice;
    # argparse never checks defaults against `choices`, and previously the
    # default silently matched no dispatch branch. 'PHEW' is kept as an alias.
    parser.add_argument('--experiment', type=str, default='Biased_Random_Walks',
                        choices=['Baseline_FC', 'Random_Walks', 'PHEW',
                                 'Biased_Random_Walks', 'Grad_Biased_Randwalk'])
    parser.add_argument('--expid', type=str, default='0')
    parser.add_argument('--model', type=str, default='vgg19-bn',
                        choices=['vgg11', 'vgg11-bn', 'vgg13', 'vgg13-bn',
                                 'vgg16', 'vgg16-bn', 'vgg19', 'vgg19-bn',
                                 'ResNet18', 'ResNet34', 'ResNet50', 'ResNet101',
                                 'ResNet152', 'lenet5', 'ResNet32', 'mlp_3'])
    parser.add_argument('--dataset', type=str, default='cifar10',
                        choices=['mnist', 'cifar10', 'cifar100', 'tiny-imagenet'])
    parser.add_argument('--optimizer', type=str, default='sgd',
                        choices=['sgd', 'adam', 'momentum', 'rms'])
    parser.add_argument('--train_batch_size', type=int, default=128)
    parser.add_argument('--test_batch_size', type=int, default=256)
    parser.add_argument('--epochs', type=int, default=160)
    parser.add_argument('--lr', type=float, default=0.1)
    parser.add_argument('--lr_drops', type=int, nargs='*', default=[80, 120])
    parser.add_argument('--lr_drop_rate', type=float, default=0.1)
    parser.add_argument('--weight_decay', type=float, default=0.0001)
    # nargs='*' so a command-line value is a list, like the default.
    parser.add_argument('--prune_perc', type=float, nargs='*',
                        default=[0.9, 0.95, 0.98])
    parser.add_argument('--pre_epochs', type=int, default=160)
    parser.add_argument('--compression', type=float, nargs='*',
                        default=[0.25, 0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0,
                                 2.25, 2.5, 2.75, 3.0, 3.25, 3.5, 3.75, 4.0,
                                 4.25, 4.5, 4.75, 5.0, 5.25, 5.5, 5.75, 6.0])
    parser.add_argument('--prune_dataset_size', type=int, default=20000)
    parser.add_argument('--prune_batch_size', type=int, default=256)
    parser.add_argument('--lottery_iterations', type=int, default=20)
    # Consistent underscore spelling first (it determines the dest,
    # `lottery_pre_epochs`); the original hyphenated flag remains an alias.
    parser.add_argument('--lottery_pre_epochs', '--lottery_pre-epochs',
                        type=int, default=20)
    parser.add_argument('--gpu', type=int, default=0)
    parser.add_argument('--workers', type=int, default=4)
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--synflow_iterations', type=int, default=100)
    return parser


def main(argv=None):
    """Parse arguments and run the selected experiment.

    Args:
        argv: optional argument list; ``None`` means ``sys.argv[1:]``.
    """
    args = build_parser().parse_args(argv)
    print(args.experiment)

    # Built at call time so the experiment modules are only needed when an
    # experiment is actually run. Every `--experiment` choice has an entry.
    runners = {
        'Baseline_FC': Baseline_FC.run,
        'Random_Walks': Random_Walks.run,
        'PHEW': Biased_Random_Walks.run,
        'Biased_Random_Walks': Biased_Random_Walks.run,
        'Grad_Biased_Randwalk': Grad_Biased_Randwalk.run,
    }
    runners[args.experiment](args)


if __name__ == '__main__':
    main()
