# %%
import SFL_passive
import SFL_threesets
import argparse
import SFL_multilevel

# Command-line interface. The parsed namespace is handed as-is to the selected
# driver's MIA(...) constructor below, so attribute names (dest values) here
# are what the SFL_* modules read.
# NOTE(review): the description looks copy-pasted from the PyTorch ImageNet
# example; confirm the intended wording before changing it.
parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')

# training setting
parser.add_argument('--filename', required=True, type=str, help='please type save_file name for the testing purpose')
parser.add_argument('--random_seed', default=123, type=int, help='random_seed')
parser.add_argument('--arch', default="resnet20", type=str, help='model architecture name (e.g. resnet20)')
parser.add_argument('--cutlayer', default=3, type=int, help='number of layers in client-side')
parser.add_argument('--batch_size', default=128, type=int, help='training batch size in server side, the client side batch size is divided by the num of agent.')
parser.add_argument('--num_agent', default=10, type=int, help='number of clients')
parser.add_argument('--num_epochs', default=200, type=int, help='number of epochs')
parser.add_argument('--learning_rate', default=0.1, type=float, help='Learning Rate for server-side model')
parser.add_argument('--local_lr', default=0.1, type=float, help='Learning Rate for client-side model')
parser.add_argument('--load', action='store_true', default=False, help='if True, load pre-trained model.')
parser.add_argument('--save', action='store_true', default=False, help='if True, save the best trained model.')

# Non-IID dataset setting
parser.add_argument('--dataset', default="cifar10", type=str, help='dataset: CIFAR10 or CIFAR100')
# Raw string: the help text contains literal backslashes ("dir\lambda"). In a
# plain string "\l" is an invalid escape sequence (DeprecationWarning, and
# SyntaxWarning on Python 3.12+); the raw string yields the same text.
parser.add_argument('--datadist', default="iid", type=str, help=r'iid,dir\lambda,orderdir\lambda. The last one is only used for 50 clients to reproduce the results.')

# Heterogeneous clients setting
parser.add_argument('--max_channel', required=True, type=int, help='max channel of bottleneck layer of clients compared with the original channel size of the split model.')
parser.add_argument('--channeldist', required=True, type=str, help='channel_a_b_c. a,b,c represents the channel size used for low, middle and high clients compared with the max_channel a,b,c<=1.')
parser.add_argument('--comm_env',required=True, type=str, help="device_a_b_c. a,b,c represents the number of low, middle, and high clients(a+b+c=1).")

# hyperparameter for logit calibration
parser.add_argument('--tau',  default=5, type=float, help='The parameter of the logit calibration')

# Other competitive methods
parser.add_argument('--sparsity',  action='store_true', default=False, help='if True, use --sparsity for both activations and gradients.')
parser.add_argument('--indicator', default="useBL", type=str, help='the policy used to handle heterogenity. useBL or dropdata')

# Use bottleneck layer
parser.add_argument('--no_subnetwork', action='store_true', default=False, help='if True, do not use sub-network. Naive different-sized BL.')
parser.add_argument('--heteroSFL',  action='store_true', default=False, help='if True, use --heteroSFL.')
parser.add_argument('--no_BDKS',  action='store_true', default=False, help='if True, use --no_BDKS for heteroSFL.')
parser.add_argument('--no_W2N',  action='store_true', default=False, help='if True, use --no_W2N, where for heteroSFL narrow in high-end trained by logit calibration.')
parser.add_argument('--theta',  default=0.1, type=float, help='The parameter of the minority class hard threshold.')
parser.add_argument('--alpha',  default=0.1, type=float, help='The parameter to control the N2W magnitute.')

# 3 client set settings
parser.add_argument('--threesets', action='store_true', default=False, help='if True, use another python code to run the threesets case.')
# '--multiset' is accepted as a correctly spelled alias of the historical
# '--mutliset' flag. argparse derives dest from the FIRST long option string,
# so the value is still stored as args.mutliset for existing code.
parser.add_argument('--mutliset', '--multiset', action='store_true', default=False, help='if True, use another python code to run the mutliset case.')

# Parse the real command line (sys.argv), then choose which driver module
# builds the MIA object. The flags are checked in a fixed order:
# --threesets wins over --mutliset if both are given, and SFL_passive is the
# fallback when neither is set.
args = parser.parse_args()

if args.threesets:
    driver = SFL_threesets
elif args.mutliset:
    driver = SFL_multilevel
else:
    driver = SFL_passive
mi = driver.MIA(args)

# Log the full parsed configuration through the driver's logger, then run it.
mi.logger.debug(str(args))
log_frequency = 500
mi(verbose=True, progress_bar=True, log_frequency=log_frequency)


    


# %%
