#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6
# Pytorch version: 1.1.0 or above


## Star-convexity verification: store the weights theta, the loss l, and the updates u (i.e., gradients for SGD).
## All experiments on CIFAR-10 with optimizer=SGD, momentum=0.9, learning rate=0.05 (small Inception),
## learning rate=0.1 (small AlexNet and MLPs), weight decay=0.95 per epoch.
## NOTE: the argument defaults below (e.g. --learning-rate 0.01) differ from these values.
import argparse

parser = argparse.ArgumentParser()

parser.add_argument('--command', default='train', choices=['train', 'test'])
parser.add_argument('--data', default='cifar10', choices=['cifar10', 'cifar100', 'MNIST'])
# num_classes, train_size, test_size and input_dim are not CLI options:
# format_experiment_name() derives them from --data.

parser.add_argument('--U-batch-size', type=int, default=256)
parser.add_argument('--A-batch-size', type=int, default=256)
parser.add_argument('--epochs', type=int, default=10)  # epochs=20 can ensure complete training
# Centralized model of CIFAR-10. format_experiment_name() overrides this to 0.005
# for MNIST and divides it by 100 when --algorithm adam is used.
parser.add_argument('--learning-rate', type=float, default=0.01)

parser.add_argument('--arch', default='alexnet',
                    choices=['ae', 'conv_ae', 'unet', 'mlp', 'alexnet', 'inception',
                             'vgg11', 'vgg16', 'resnet18', 'resnet34'],
                    help='AlexNet and MLP can be applied to MNIST, while ResNet cannot')
parser.add_argument('--mlp-spec', default='256x128', help='mlp spec: e.g. 512x128x512 indicates 3 hidden layers')
parser.add_argument('--name', default='', help='Experiment name')
parser.add_argument('--script-path', default='./saved_paras', help='The script path')
parser.add_argument('--loss', default='cross_entropy', choices=['cross_entropy', 'MSE', 'KL_divergence'])
parser.add_argument('--adjust-lr', type=int, default=-1)  # default=9999

# ---------------------------------- Proj 1: gamma optimization ----------------------------------
parser.add_argument('--activation', default='relu', choices=['sigmoid', 'tanh', 'relu', 'leaky_relu'],
                    help='The activation function of the model')
parser.add_argument('--dropout', default='no_drop', choices=['no_drop', 'dropout_1', 'dropout'],
                    help='Whether to add dropout layers to the model')
parser.add_argument('--bn', default='bn', choices=['no_bn', 'bn', 'no_1', 'no_2', 'no_3', 'no_4', 'no_12', 'no_123'],
                    help='Whether to add batch normalization layers to the model')
parser.add_argument('--skip', type=int, default=0, choices=[0, 1, 2, 3, 4],
                    help='The number of skip connections to delete')
parser.add_argument('--algorithm', default='sgd', choices=['sgd', 'sgd_m', 'adam'],
                    help='The specific optimizer to use')
parser.add_argument('--mo', type=float, default=0.5, help='The momentum term in sgd-m, {0.1, 0.5, 0.9}')

# ---------------------------------- Proj 4: assisted learning -----------------------------------
parser.add_argument('--learn-type', default='SGD', choices=['SGD', 'USGD', 'fedavg', 'SAASGD'],
                    help='The type of learning to use')
parser.add_argument('--random', type=int, default=0, choices=[0, 1],
                    help='Whether to use the general input data, 0==False, 1==True')
parser.add_argument('--U-iid-ratio', type=float, default=0.1,
                    help='The ratio of the major class for a User to do iid (0.1) or non-iid (>0.1) sampling '
                         'when random==False, {0.1,0.3,0.5,0.7,0.9}')
parser.add_argument('--A-iid-ratio', type=float, default=0.1,
                    help='The ratio of the major class for an Agent to do iid (0.1) or non-iid (>0.1) sampling '
                         'when random==False, {0.1,0.3,0.5,0.7,0.9}')
parser.add_argument('--unequal', type=int, default=0, choices=[0, 1],
                    help='Whether to use unequal input data for users, 0==False, 1==True')
# Author's note: fewer users performs better, but it does not matter much.
parser.add_argument('--U-num-users', type=int, default=1, help='The number of User-side users or devices to create')
parser.add_argument('--A-num-users', type=int, default=1, help='The number of Agent-side users or devices to create')
parser.add_argument('--user-frac', type=float, default=1.0, help='The fraction of valid users chosen from all users')
# Author's note: a larger number of local steps performs much better.
parser.add_argument('--local-steps-init', type=int, default=2000,
                    help='The total number of local steps for local updates, default: 2000')
parser.add_argument('--local-interval', type=int, default=50,
                    help='The local interval for delivering the models between user and agent, {1,2,5,10}')
parser.add_argument('--data-init', type=int, default=50000,
                    help='The total number of initial data samples to create, default: 50000')
parser.add_argument('--rou', type=float, default=0.1,
                    help='The User data-size ratio = User data-size/data_init, '
                         'so the Agent data-size ratio = 1-rou {0.1 ~ 0.5}')
parser.add_argument('--scale-range', type=int, default=100,
                    help='The maximum range used to vary the initial data samples')
parser.add_argument('--plot-hist', type=int, default=None, choices=[0, 1],
                    help='Choose if plot the model histograms for differential privacy, '
                         'None==False, 0==save models, 1==plot hist')
parser.add_argument('--s', type=float, default=0.01,
                    help='When args.command=="test", the std of Gaussian noise to be added to model parameters '
                         'for differential privacy, {0.01,0.03,0.1,0.3}')
parser.add_argument('--eps', type=float, default=None,
                    help='The parameter epsilon for differential privacy to compute sigma, {0.1,1.0,10.0}')
parser.add_argument('--over-sampling', type=int, default=0, choices=[0, 1],
                    help='Whether to use over-sampling to address the class-imbalance problem for the SGD framework, '
                         '0==False, 1==True')


def _apply_dataset_config(args):
    """Set dataset-dependent fields on ``args`` in place and rescale the learning rate.

    Sets ``num_classes``, ``train_size``, ``test_size`` and ``input_dim`` from
    ``args.data``. Any value other than 'MNIST' or 'cifar100' gets the CIFAR-10
    settings. After that, ``learning_rate`` is divided by 100 when
    ``args.algorithm == 'adam'``.

    NOTE(review): not idempotent. Calling this twice divides the adam learning
    rate by 100 twice. parse_args() calls it exactly once per parsed namespace.
    """
    if args.data == 'MNIST':
        args.num_classes = 10
        args.train_size = 60000
        args.test_size = 10000
        args.input_dim = 1
        # NOTE(review): unconditionally replaces any user-supplied --learning-rate.
        args.learning_rate = 0.005
    elif args.data == 'cifar100':
        args.num_classes = 100
        # NOTE(review): much smaller than the full CIFAR-100 split; confirm this is intended.
        args.train_size = 500
        args.test_size = 100
        args.input_dim = 3
    else:
        args.num_classes = 10
        args.train_size = 50000
        args.test_size = 10000
        args.input_dim = 3

    if args.algorithm == 'adam':
        args.learning_rate = args.learning_rate / 100


def format_experiment_name(args):
    """Build the experiment name and fill in dataset-derived settings on ``args``.

    Side effects (in place on ``args``): sets ``num_classes``, ``train_size``,
    ``test_size`` and ``input_dim``. Also overrides ``learning_rate`` to 0.005
    for MNIST and divides it by 100 for adam.

    Args:
        args: parsed namespace providing the attributes read below (data, arch,
            name, learn_type, algorithm, learning_rate, eps, U/A batch sizes,
            local_interval, U/A iid ratios, local_steps_init, data_init, rou,
            epochs, over_sampling).

    Returns:
        The experiment name. Example:
        'cifar10_alexnet_SAASGD_exp_eps_1.0_Ubs_64_Abs_128_inter_5_Uiidr_0.3_'
        'Aiidr_0.5_ls_100_ds_1000_rou_0.2_rounds_3_oversamp'.
    """
    _apply_dataset_config(args)

    parts = [args.data + '_', args.arch + '_', args.learn_type + '_']
    if args.name:
        parts.append(args.name + '_')

    if args.learn_type == 'SAASGD':
        # SAASGD runs have separate User/Agent batch sizes and an exchange interval.
        if args.eps is not None:
            parts.append('eps_' + str(args.eps) + '_')
        parts.append('Ubs_' + str(args.U_batch_size) + '_Abs_' + str(args.A_batch_size))
        parts.append('_inter_' + str(args.local_interval))
    else:
        parts.append('bs_' + str(args.U_batch_size))

    parts.append('_Uiidr_' + str(args.U_iid_ratio) + '_Aiidr_' + str(args.A_iid_ratio))
    parts.append('_ls_' + str(args.local_steps_init))
    parts.append('_ds_' + str(args.data_init) + '_rou_' + str(args.rou))
    parts.append('_rounds_' + str(args.epochs))
    if args.over_sampling == 1:
        parts.append('_oversamp')
    return ''.join(parts)


def parse_args(argv=None):
    """Parse command-line arguments and attach the derived experiment name.

    Args:
        argv: optional list of argument strings. When None (the default),
            argparse reads ``sys.argv[1:]``, exactly as before. Passing a list
            lets tests or notebooks build a configuration without touching
            ``sys.argv``.

    Returns:
        The parsed namespace. format_experiment_name() has filled in its
        dataset-derived fields (num_classes, train_size, test_size, input_dim)
        and may have adjusted ``learning_rate``. ``exp_name`` holds the
        experiment name.
    """
    args = parser.parse_args(argv)
    args.exp_name = format_experiment_name(args)
    return args
