import argparse
import ast
import re

# using GPU if available
# device = torch.device("cpu")


# Dataset-specific overrides applied by ``--preset``. Each entry maps an
# ``args`` attribute name to the value the preset forces; these values replace
# whatever was supplied on the command line for the same options.
_PRESETS = {
    "mnist-1": dict(
        unlabeled=60000, dataset="mnist", positive_flag=1,
        positive_label_list=[0, 2, 4, 6, 8], prior=0.5,
        positive_weight=1., elr_weight=0.,
        scheduler_gamma=.25, milestones=[20, 40, 60, 80],
        start_hmix=5, p_lower=0.05, p_upper=0.95),
    "mnist-2": dict(
        unlabeled=60000, dataset="mnist", positive_flag=2,
        positive_label_list=[1, 3, 5, 7, 9], prior=0.5,
        positive_weight=1., unlabeled_weight=0.9, elr_weight=0.,
        scheduler_gamma=.5, milestones=[20, 40, 60, 80, 100],
        start_hmix=5, p_lower=0.25, p_upper=0.75),
    "fmnist-1": dict(
        unlabeled=60000, dataset="fashionmnist", positive_flag=1,
        positive_label_list=[0, 2, 4, 6], prior=0.4,
        positive_weight=1., elr_weight=0.,
        scheduler_gamma=.25, milestones=[20, 40, 60, 80],
        start_hmix=5, p_lower=0.05, p_upper=0.95),
    "fmnist-2": dict(
        unlabeled=60000, dataset="fashionmnist", positive_flag=2,
        positive_label_list=[1, 3, 5, 7, 8, 9], prior=0.6,
        positive_weight=1., unlabeled_weight=0.9, elr_weight=0.,
        scheduler_gamma=.5, milestones=[20, 40, 60, 80, 100],
        start_hmix=5, p_lower=0.25, p_upper=0.75),
    "fmnist-3": dict(
        unlabeled=60000, dataset="fashionmnist", positive_flag=1,
        positive_label_list=[1, 4, 7], prior=0.3,
        positive_weight=1., elr_weight=0.,
        scheduler_gamma=.25, milestones=[20, 40, 60, 80],
        start_hmix=5, p_lower=0.05, p_upper=0.95),
    "fmnist-4": dict(
        unlabeled=60000, dataset="fashionmnist", positive_flag=2,
        positive_label_list=[0, 2, 3, 5, 6, 8, 9], prior=0.7,
        positive_weight=1., unlabeled_weight=0.9, elr_weight=0.,
        scheduler_gamma=.5, milestones=[20, 40, 60, 80, 100],
        start_hmix=5, p_lower=0.25, p_upper=0.75),
    "cifar10-1": dict(
        unlabeled=50000, dataset="cifar10", positive_flag=1,
        positive_label_list=[0, 1, 8, 9], prior=0.4,
        positive_weight=1., elr_weight=0.,
        scheduler_gamma=.5, milestones=[20, 40, 60, 80, 100],
        mix_layer=3, start_hmix=10, p_lower=0.2, p_upper=0.8),
    "cifar10-2": dict(
        unlabeled=50000, dataset="cifar10", positive_flag=2,
        positive_label_list=[2, 3, 4, 5, 6, 7], prior=0.6,
        positive_weight=1., elr_weight=0.,
        scheduler_gamma=.25, milestones=[20, 40, 60, 80, 100],
        mix_layer=3, start_hmix=10, p_lower=0.2, p_upper=0.8),
    "stl10-1": dict(
        unlabeled=105000, dataset="stl10", positive_flag=1,
        positive_label_list=[0, 2, 3, 8, 9], prior=0.506,
        positive_weight=1., elr_weight=0.,
        scheduler_gamma=.5, milestones=[20, 40, 60, 80],
        mix_layer=3, start_hmix=10, p_lower=0.2, p_upper=0.8),
    "stl10-2": dict(
        unlabeled=105000, dataset="stl10", positive_flag=2,
        positive_label_list=[1, 4, 5, 6, 7], prior=0.494,
        positive_weight=1., unlabeled_weight=.8, elr_weight=0.,
        scheduler_gamma=.3, milestones=[20, 40, 60, 80],
        mix_layer=3, start_hmix=10, p_lower=0.15, p_upper=0.85),
    "alzheimer": dict(
        unlabeled=5121, dataset="alzheimer", positive_flag=2,
        positive_label_list=[0, 1, 3], prior=0.5,
        positive_weight=1., elr_weight=0.,
        scheduler_gamma=.25, milestones=[20, 40, 60, 80, 100],
        mix_layer=3, start_hmix=10, p_lower=0.2, p_upper=0.8),
}

# Presets named "<prefix><prior>" (e.g. "fmnist-prior0.2"): the base overrides
# below are applied and ``prior`` is parsed from the preset name. Matching is
# by prefix, so "mnist-prior" can never accidentally match "fmnist-prior…".
_PRIOR_PRESETS = {
    "cifar10-prior": dict(unlabeled=50000, dataset="cifar10"),
    "fmnist-prior": dict(unlabeled=60000, dataset="fashionmnist", positive_flag=1),
    "mnist-prior": dict(unlabeled=60000, dataset="mnist", positive_flag=1),
}
_PRIOR_SUFFIX = re.compile(r"prior([0-9.]+)")

# Upper bound on ``--unlabeled`` per dataset; datasets not listed are unchecked.
_MAX_UNLABELED = {
    "mnist": 60000,
    "fashionmnist": 60000,
    "cifar10": 50000,
    "stl10": 105000,
}


def _parse_literal(text):
    """argparse ``type=`` hook: parse a Python literal such as ``"[0, 2, 4]"``.

    Uses ``ast.literal_eval`` instead of ``eval`` so command-line text cannot
    execute arbitrary code. Non-string values are returned unchanged.

    Raises:
        argparse.ArgumentTypeError: if *text* is not a valid Python literal,
            which argparse turns into a normal usage error.
    """
    if not isinstance(text, str):
        return text
    try:
        return ast.literal_eval(text)
    except (ValueError, SyntaxError, TypeError) as err:
        raise argparse.ArgumentTypeError(
            "not a valid Python literal: %r" % (text,)) from err


def _build_parser():
    """Create the ``ArgumentParser`` that defines every training option."""
    parser = argparse.ArgumentParser(
        description='non-negative / unbiased PU learning Pytorch implementation',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--seed', type=int, default=0, help='seed')
    parser.add_argument('--batchsize', '-b', type=int, default=30000,
                        help='Mini batch size')
    parser.add_argument('--gpu', default=0, type=int,
                        help='id(s) for CUDA_VISIBLE_DEVICES')
    parser.add_argument('--dataset', '-d', default='cifar10', type=str,
                        choices=['mnist', 'cifar10', 'imdb', 'yelp', 'amazon',
                                 'alzheimer', 'stl10', 'fashionmnist'],
                        help='The dataset name')
    parser.add_argument('--positive-flag', default=1, type=int, choices=[1, 2])
    parser.add_argument('--positive-label-list', type=_parse_literal,
                        default=[0, 1, 2, 3, 4, 5])
    parser.add_argument('--prior', type=float, default=0.5)
    parser.add_argument('--labeled', '-l', default=100, type=int,
                        help='# of labeled data')
    # NOTE(review): 59900 exceeds the 50000 cap enforced for the default
    # dataset 'cifar10', so a bare invocation fails validation unless a
    # --preset or an explicit -u is given.
    parser.add_argument('--unlabeled', '-u', default=59900, type=int,
                        help='# of unlabeled data')
    parser.add_argument('--epoch', '-e', default=500, type=int,
                        help='# of epochs to learn')
    parser.add_argument('--val-iterations', '-v', default=20, type=int,
                        help='# of iterations per-epoch')
    parser.add_argument('--beta', '-B', default=0., type=float,
                        help='Beta parameter of nnPU')
    parser.add_argument('--gamma', '-G', default=1., type=float,
                        help='Gamma parameter of nnPU')
    parser.add_argument('--loss-gamma', '-GL', default=1, type=int,
                        help='Gamma parameter of Loss')
    parser.add_argument('--loss-tau', '-TL', default=0.5, type=float,
                        help='Tau parameter of Loss')
    parser.add_argument('--loss-eta', '-EL', default=1.0, type=float,
                        help='Eta parameter of Loss')
    parser.add_argument('--loss', type=str, default="sigmoid",
                        choices=[
                            'hinge', 'squared-hinge', 'logistic', 'sigmoid',
                            'focal', 'exponential', 'perceptron', 'zero-one',
                            'squared', 'tangent', 'savage', 'unhinged',
                            'gsigmoid', 'rescaled-hinge', 'pinball', 'double-hinge',
                            'ramp', 'smooth-hinge', 'modified-huber', 'mine'
                        ],
                        help='The name of a loss function')
    # NOTE(review): the default '3mlp' is not among `choices`. argparse does not
    # validate defaults against choices, so '3mlp' is reachable only as the
    # default value. Confirm the model factory handles it.
    parser.add_argument('--model', '-m', default='3mlp',
                        choices=['mlp', 'cnn_cifar', 'cnn_stl', 'lenet5', 'resnet50'],
                        help='The name of a classification model')
    parser.add_argument('--optim', '-O', default='adam',
                        choices=['adam', 'sgd', 'adagrad'],
                        help='The name of a optimizer')
    parser.add_argument('--momentum', '-mom', default=0, type=float,
                        help='momentum')
    parser.add_argument('--balance-weight', '-bw', default=1e-3, type=float,
                        help='balance weight for BalancePU and PULB+BalancePU')
    parser.add_argument('--lam-f', '-lamf', default=1e-2, type=float,
                        help='lam_f para for FOPU')
    # `store_true` with default=True can never be switched off from the CLI.
    # The `--no-…` companion flags (here and below) restore that ability
    # without changing the default or the existing flag.
    parser.add_argument('--nesterov', '-nest', action='store_true', default=True,
                        help='use nesterov momentum')
    parser.add_argument('--no-nesterov', dest='nesterov', action='store_false',
                        help='disable nesterov momentum')
    parser.add_argument('--stepsize', '-s', default=1e-3, type=float,
                        help='Stepsize of gradient method')
    parser.add_argument('--weight-decay', '-w', default=1e-6, type=float,
                        help='Weight decay of gradient method')

    parser.add_argument('--milestones', nargs='+', default=[50, 100], type=int,
                        help='milestones for scheduler')
    parser.add_argument('--scheduler-gamma', default=0.1, type=float,
                        help='gamma for scheduler')

    parser.add_argument('--out', '-o', default='result',
                        help='Directory to output the result')
    parser.add_argument('--num-classifier', '-numc', type=int, default=2)
    parser.add_argument('--m-percent', '-mp', type=float, default=1)

    # Positive-loss coefficient and its ramp-down window.
    parser.add_argument('--positive-weight', default=1.0, type=float,
                        help='coefficient for positive loss')
    parser.add_argument('--positive-rampdown-starts', default=200, type=int,
                        metavar='EPOCHS',
                        help='epoch at which positive loss ramp-down starts')
    parser.add_argument('--positive-rampdown-ends', default=200, type=int,
                        metavar='EPOCHS',
                        help='epoch at which positive loss ramp-down ends')

    # Unlabeled-loss coefficient and its ramp-down window.
    parser.add_argument('--unlabeled-weight', default=1.0, type=float,
                        help='coefficient for unlabeled loss')
    parser.add_argument('--unlabeled-rampdown-starts', default=200, type=int,
                        metavar='EPOCHS',
                        help='epoch at which unlabeled loss ramp-down starts')
    parser.add_argument('--unlabeled-rampdown-ends', default=200, type=int,
                        metavar='EPOCHS',
                        help='epoch at which unlabeled loss ramp-down ends')

    # Pseudo-label (target) coefficient and its ramp-up window.
    parser.add_argument('--target-weight', default=1.0, type=float)
    parser.add_argument('--target-rampup-starts', default=200, type=int,
                        metavar='EPOCHS')
    parser.add_argument('--target-rampup-ends', default=200, type=int,
                        metavar='EPOCHS')

    # Mean-teacher / EMA settings.
    parser.add_argument('--mean-teacher', default=True, action='store_true')
    parser.add_argument('--no-mean-teacher', dest='mean_teacher',
                        action='store_false', help='disable the mean teacher')
    parser.add_argument('--ema-update', default=False, action='store_true')
    parser.add_argument('--ema-start', type=int, default=0)
    parser.add_argument('--ema-end', type=int, default=100)
    parser.add_argument('--ema-step', type=int, default=40000)
    parser.add_argument('--ema-decay', type=float, default=0.997)

    # Entropy and early-learning regularization.
    parser.add_argument('--entropy-weight', default=0.1, type=float,
                        help='coefficient for entropy regularization')
    parser.add_argument('--elr-weight', default=0., type=float,
                        help='coefficient for early-learning regularization')
    parser.add_argument('--elr-rampdown-starts', default=200, type=int,
                        metavar='EPOCHS',
                        help='epoch at which elr regularization ramp-down starts')
    parser.add_argument('--elr-rampdown-ends', default=200, type=int,
                        metavar='EPOCHS',
                        help='epoch at which elr regularization ramp-down ends')

    # Mixup.
    parser.add_argument('--mix-option', default=True, action='store_true',
                        help='mix option, whether to mix or not')
    parser.add_argument('--no-mix-option', dest='mix_option',
                        action='store_false', help='disable mixup')
    parser.add_argument('--separate-mix', default=False, action='store_true',
                        help='mix separate from labeled data and unlabeled data')
    parser.add_argument(
        '--mix-layer', default=-1, type=int,
        help='number of layers on which mixup is applied including input layer')
    parser.add_argument('--alpha', default=1.0, type=float,
                        help='alpha for beta distribution')

    # Heuristic mixup.
    parser.add_argument('--h-positive', type=int, default=100)
    parser.add_argument('--start-hmix', type=int, default=20)
    parser.add_argument('--p-upper', type=float, default=0.6)
    parser.add_argument('--p-lower', type=float, default=0.4)

    parser.add_argument('--no-progress', action='store_true',
                        help="don't use progress bar")
    parser.add_argument("--local_rank", type=int, default=-1,
                        help="For distributed training: local_rank")

    # The default 'cifar10' is not one of the choices and matches no preset,
    # so it applies no overrides.
    parser.add_argument(
        '--preset', '-p', type=str, default='cifar10',
        choices=[
            'mnist-1', 'mnist-2', 'fmnist-1', 'fmnist-2',
            'cifar10-1', 'cifar10-2', 'stl10-1', 'stl10-2',
            'alzheimer', 'fmnist-3', 'fmnist-4',
            'cifar10-prior0.1', 'cifar10-prior0.2',
            'fmnist-prior0.1', 'fmnist-prior0.2',
            'mnist-prior0.1', 'mnist-prior0.2',
            'cifar10-prior0.8', 'cifar10-prior0.9',
            'fmnist-prior0.8', 'fmnist-prior0.9',
            'mnist-prior0.8', 'mnist-prior0.9'
        ],
        help="Preset of dataset-specific settings; overrides the dataset, "
        "unlabeled count, prior and related arguments. The default "
        "'cifar10' is not a preset and applies no overrides")

    parser.add_argument(
        '--method', '-method', type=str, default='nnPU-objective',
        choices=[
            'uPU',
            'nnPU-objective', 'nnPU-out',
            'absPU',
            'DistPU-brief', 'DistPU',
            'FOPU',
            'PULB', 'PULB2',
            'BalancePU', 'BalancePU2',
            'PULB+BalancePU', 'PULB2+BalancePU2',
            'AdaptivePU', 'CurvatureAdjustedPU', 'AdaptiveSmoothPU',
            'CompositePU', 'AdaptiveRiskPU', 'AdaptiveWeightPU',
            'InstanceWeightedPU',
            'PULLP'
        ])

    parser.add_argument('--co-mu', type=float, default=2e-3,
                        help='coefficient of L_ent')
    parser.add_argument('--alpha-mix', type=float, default=6.0)

    parser.add_argument('--co-entropy', type=float, default=0.004)
    parser.add_argument('--co-mix-entropy', type=float, default=0.04)
    parser.add_argument('--co-mixup', type=float, default=5.0)

    parser.add_argument('--a0', type=float, default=2.0)
    parser.add_argument('--b0', type=float, default=1.0)
    parser.add_argument('--b1', type=float, default=0.8)
    parser.add_argument('--k1', type=float, default=8.0)
    parser.add_argument('--k2', type=float, default=8.0)

    parser.add_argument('--c-alpha', type=float, default=2.0)
    parser.add_argument('--c-delta', type=float, default=2.5)

    parser.add_argument('--n-bag', type=int, default=1)

    parser.add_argument('--threshold-strategy', type=str, default='fixed',
                        choices=['fixed', 'adaptive', 'class_adaptive'])

    # Options used by train_soft_MA.py.
    parser.add_argument('--use-pico-update', type=int, default=0)

    # Options used by train_soft_TS.py.
    parser.add_argument('--TS-mean-teacher', type=int, default=0)
    # Consistency-loss parameters (weight, ramp-up period, loss type).
    parser.add_argument('--consistency', type=float, default=0.3,
                        help='一致性损失的权重系数')
    parser.add_argument('--consistency-rampup', type=int, default=400,
                        help='一致性损失的rampup周期')
    parser.add_argument('--consistency-type', type=str, default='mse',
                        choices=['mse', 'mae', 'smooth_l1', 'kl', 'cosine'],
                        help='一致性损失函数类型')
    # Teacher-network soft-label parameters.
    parser.add_argument('--use-teacher-soft-labels', action='store_true',
                        default=True, help='是否使用教师网络的软标签')
    parser.add_argument('--no-use-teacher-soft-labels',
                        dest='use_teacher_soft_labels', action='store_false',
                        help='do not use teacher soft labels')
    parser.add_argument('--use-ema-for-test', action='store_true', default=True,
                        help='是否在测试时使用EMA模型')
    parser.add_argument('--no-use-ema-for-test', dest='use_ema_for_test',
                        action='store_false',
                        help='do not use the EMA model at test time')
    # Mixup-consistency constraint parameters.
    parser.add_argument('--TS-mixup-consistency', action='store_true',
                        default=False, help='是否使用特殊的一致性损失')
    parser.add_argument('--TS-mixup-consistency-type', type=str, default='mse',
                        choices=['mse', 'kl'],
                        help='一致性损失类型')
    parser.add_argument('--TS-mixup-alpha', type=float, default=0.1,
                        help='一致性损失的阈值')
    parser.add_argument('--TS-mixup-weight', type=float, default=0.3,
                        help='一致性损失的权重')
    # Other threshold-strategy parameters (soft-label warm-up period).
    parser.add_argument('--soft-label-warmup', type=int, default=0,
                        help='软标签预热期')

    # Options used by train_soft_mixup.py.
    # NOTE(review): the --mixup-alpha help text ("threshold of the consistency
    # loss") looks copy-pasted; confirm what this value controls.
    parser.add_argument('--mixup-alpha', type=float, default=1.0,
                        help='一致性损失的阈值')
    parser.add_argument('--mixup-prob', type=float, default=0.5,
                        help='应用mixup的概率')
    parser.add_argument('--mixup-weight', type=float, default=0.3,
                        help='一致性损失的权重')

    # Warm-up.
    parser.add_argument('--warmup-epochs', type=int, default=25)
    return parser


def _apply_preset(args):
    """Overwrite fields of *args* in place according to ``args.preset``.

    Exact preset names come from ``_PRESETS``. Names of the form
    ``<prefix>prior<p>`` use ``_PRIOR_PRESETS`` and set ``args.prior = p``.
    Unknown names (including the default 'cifar10') leave *args* untouched.
    """
    overrides = _PRESETS.get(args.preset)
    if overrides is None:
        for prefix, base in _PRIOR_PRESETS.items():
            if args.preset.startswith(prefix):
                overrides = dict(base)
                match = _PRIOR_SUFFIX.search(args.preset)
                if match:
                    overrides["prior"] = float(match.group(1))
                break
    for name, value in (overrides or {}).items():
        # Copy lists so mutating args can never corrupt the shared table.
        setattr(args, name, list(value) if isinstance(value, list) else value)


def _validate_args(parser, args):
    """Reject inconsistent final values through ``parser.error``.

    ``parser.error`` prints the usage line plus the message and raises
    ``SystemExit(2)``. Unlike bare ``assert`` statements, these checks are
    not stripped when Python runs with ``-O``.
    """
    if not args.batchsize > 0:
        parser.error("--batchsize must be > 0 (got %r)" % (args.batchsize,))
    if not args.epoch > 0:
        parser.error("--epoch must be > 0 (got %r)" % (args.epoch,))
    if not 0 < args.labeled < 30000:
        parser.error("--labeled must be in (0, 30000) (got %r)" % (args.labeled,))
    limit = _MAX_UNLABELED.get(args.dataset)
    if limit is not None and not 0 < args.unlabeled <= limit:
        parser.error("--unlabeled must be in (0, %d] for dataset %r (got %r)"
                     % (limit, args.dataset, args.unlabeled))
    # Written as `not (lo <= x)` so that NaN is rejected, as the original asserts did.
    if not 0. <= args.beta:
        parser.error("--beta must be >= 0 (got %r)" % (args.beta,))
    if not 0. <= args.gamma <= 1.:
        parser.error("--gamma must be in [0, 1] (got %r)" % (args.gamma,))


def process_args():
    """Parse ``sys.argv``, apply ``--preset`` overrides, and validate.

    Returns:
        argparse.Namespace holding every option defined by ``_build_parser``,
        with preset overrides already applied.

    Exits with ``SystemExit(2)`` (via ``parser.error``) on unparsable input
    or when the final values fail the checks in ``_validate_args``.
    """
    parser = _build_parser()
    args = parser.parse_args()
    _apply_preset(args)
    _validate_args(parser, args)
    return args
