#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6

import argparse
import torch


def _parse_float_list(parser, name, raw):
    """Parse a comma-separated list of floats such as ``"0.1,0.2"`` or ``"[0.1, 0.2]"``.

    Surrounding whitespace and square brackets are ignored, as are empty
    entries (e.g. produced by a trailing comma). Malformed input is reported
    through ``parser.error`` so the user gets a normal argparse usage message
    (and ``SystemExit``) instead of a bare ``ValueError`` traceback.

    Args:
        parser: the ``argparse.ArgumentParser`` used to report errors.
        name: option name without leading dashes, used in error messages.
        raw: the raw string value supplied for the option.

    Returns:
        A non-empty list of floats.
    """
    tokens = [tok.strip() for tok in raw.strip().strip('[]').split(',')]
    tokens = [tok for tok in tokens if tok]
    if not tokens:
        parser.error('argument --{}: expected at least one comma-separated number'.format(name))
    try:
        return [float(tok) for tok in tokens]
    except ValueError:
        parser.error("argument --{}: invalid comma-separated float list: '{}'".format(name, raw))


def args_parser(argv=None):
    """Build the experiment argument parser, parse arguments, and post-process them.

    Post-processing performed after parsing:
      * ``args.beta`` and ``args.gamma`` are converted from comma-separated
        strings into lists of floats;
      * ``args.device`` is set to ``torch.device('cuda:<gpu>')`` when CUDA is
        available and ``--gpu`` is non-negative, otherwise to the CPU.

    Args:
        argv: optional list of argument strings to parse. Defaults to ``None``,
            in which case argparse reads ``sys.argv[1:]`` (the original behavior).

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser()
    # federated arguments
    parser.add_argument('--epochs', type=int, default=500, help="rounds of training")
    parser.add_argument('--round', type=int, default=0, help="rounds of communication")
    parser.add_argument('--num_users', type=int, default=10, help="number of users: K")
    parser.add_argument('--frac', type=float, default=1, help="the fraction of clients: C")
    parser.add_argument('--local_ep', type=int, default=5, help="the number of local epochs: E")
    parser.add_argument('--local_bs', type=int, default=10, help="local batch size: B")
    parser.add_argument('--bs', type=int, default=128, help="test batch size")
    parser.add_argument('--lr', type=float, default=0.002, help="learning rate")
    # The old help claimed "default: 0.5"; let argparse report the real default.
    parser.add_argument('--momentum', type=float, default=0, help="SGD momentum (default: %(default)s)")
    # Parsed into a list of floats after parse_args (see _parse_float_list).
    parser.add_argument('--gamma', type=str, default="0.1,0.1,0.05", help="local step sizes (comma-separated)")  # lr of y, theta, x
    parser.add_argument('--beta', type=str, default="0.01,0.01,0.01", help="local step sizes (comma-separated)")
    parser.add_argument('--split', type=str, default='user', help="train-test split type, user or sample")

    # bilevel arguments
    parser.add_argument('--neumann', type=int, default=5, help="The iteration of Neumann series")
    parser.add_argument('--inner_ep', type=int, default=1, help="the number of hyper local epochs: H_E")
    # NOTE(review): help text is duplicated from --inner_ep; confirm what outer_tau controls.
    parser.add_argument('--outer_tau', type=int, default=5, help="the number of hyper local epochs: H_E")
    parser.add_argument('--hlr', type=float, default=0.001, help="learning rate")
    parser.add_argument('--hvp_method', type=str, default='global_batch', help='hvp method')
    parser.add_argument('--no_blo', action='store_true', help='whether blo or not')
    # Minmax arguments
    parser.add_argument('--minmax_s', type=int, default=5, help="The heterogeneity of synthetic dataset")

    # model arguments
    parser.add_argument('--model', type=str, default='mlp', help='model name')
    parser.add_argument('--kernel_num', type=int, default=9, help='number of each kind of kernel')
    parser.add_argument('--kernel_sizes', type=str, default='3,4,5',
                        help='comma-separated kernel size to use for convolution')
    parser.add_argument('--norm', type=str, default='batch_norm', help="batch_norm, layer_norm, or None")
    parser.add_argument('--num_filters', type=int, default=32, help="number of filters for conv nets")
    # NOTE: kept as a string ('True'/'False'); consumers must compare against the string.
    parser.add_argument('--max_pool', type=str, default='True',
                        help="Whether use max pooling rather than strided convolutions")
    parser.add_argument('--optim', type=str, default='sgd', help='optimizer name')
    parser.add_argument('--topoModel', type=str, default='circle', choices=['exp','circle', 'linear', 'circle1', 'circle2', 'cent'],
                        help='topology model')
    parser.add_argument('--alg', type=str, default='SUN-HR', choices=['SUN-SE','SUN-GT', 'SUN-HR','DSGDA-GT'])

    # other arguments
    parser.add_argument('--dataset', type=str, default='mnist', help="name of dataset")
    parser.add_argument('--iid', action='store_true', help='whether i.i.d or not')
    parser.add_argument('--num_classes', type=int, default=10, help="number of classes")
    parser.add_argument('--num_channels', type=int, default=1, help="number of channels of images")
    parser.add_argument('--gpu', type=int, default=2, help="GPU ID, -1 for CPU")
    parser.add_argument('--stopping_rounds', type=int, default=10, help='rounds of early stopping')
    parser.add_argument('--verbose', action='store_true', help='verbose print')
    parser.add_argument('--seed', type=int, default=1, help='random seed (default: 1)')
    parser.add_argument('--all_clients', action='store_true', help='aggregation over all clients')
    parser.add_argument('--output', type=str, default=None, help="output path")

    # dataset related arguments
    parser.add_argument('--p', type=float, default=0.5, help="fraction of data to be used for each client")
    parser.add_argument('--in_channels',
                        help='input channels for image dataset (ignored when `Shakespeare` dataset is used)', type=int,
                        default=3)
    parser.add_argument('--test_fraction', help='fraction of test dataset at each client', type=float, default=0.2)

    # dataset split scenario
    parser.add_argument('--split_type', help='type of an experiment to conduct: iid|pathological|dirichlet|realistic',
                        type=str, default="realistic")

    # federated learning arguments
    parser.add_argument('--algorithm', help='type of an algorithm to use', type=str, default="fedavg",
                        choices=['fedavg', 'fedprox', 'scaffold', 'lg-fedavg', 'fedper', 'fedrep', 'ditto', 'apfl',
                                 'pfedme', 'superfed-mm', 'superfed-lm'])

    # optimization related arguments
    parser.add_argument('--optimizer', help='type of optimization method (should be a module of `torch.optim`)',
                        type=str, default='SGD')
    parser.add_argument('--criterion',
                        help='type of criterion for objective function (should be a module of `torch.nn`)', type=str,
                        default='CrossEntropyLoss')
    parser.add_argument('--tau', help='constant for fine tuning head or updating a local model (for fedrep, ditto)',
                        type=int, default=5)

    # model related arguments
    parser.add_argument('--model_name', help='model to use [TwoNN|TwoCNN|NextCharLM|ResNet9|MobileNet|VGG9]', type=str,
                        default="TwoNN", choices=['TwoNN', 'TwoCNN', 'NextCharLM', 'ResNet9', 'MobileNet', 'VGG9'])
    parser.add_argument('--fc_type', help='type of fully connected layer', type=str,
                        choices=['StandardLinear', 'LinesLinear'], default='StandardLinear')
    # default arguments
    parser.add_argument('--exp_name', help='experiment name', type=str, default="femnist")
    parser.add_argument('--data_path', help='data path', type=str, default='./data')
    parser.add_argument('--log_path', help='log path', type=str, default='./log')
    parser.add_argument('--result_path', help='result path', type=str, default='./result')
    parser.add_argument('--plot_path', help='plot path', type=str, default='./plot')
    parser.add_argument('--n_jobs', help='workers for multiprocessing', type=int, default=1)

    args = parser.parse_args(argv)
    args.beta = _parse_float_list(parser, 'beta', args.beta)
    args.gamma = _parse_float_list(parser, 'gamma', args.gamma)
    # Any negative GPU id (not only -1) means CPU; 'cuda:-N' is not a valid device.
    use_cuda = torch.cuda.is_available() and args.gpu >= 0
    args.device = torch.device('cuda:{}'.format(args.gpu) if use_cuda else 'cpu')
    return args
