import argparse
import re

# using GPU if available
# device = torch.device("cpu")


# Overrides applied when ``--preset`` equals one of these names exactly.
# List values are copied when applied, so the table is never shared with
# (or mutated through) a returned namespace.
_EXACT_PRESETS = {
    "mnist-1": {
        "unlabeled": 60000,
        "dataset": "mnist",
        "positive_flag": 1,
        "positive_label_list": [0, 2, 4, 6, 8],
        "prior": 0.5,
        "positive_weight": 1.,
        "elr_weight": 0.,
        "scheduler_gamma": .25,
        "milestones": [20, 40, 60, 80],
        "start_hmix": 5,
        "p_lower": 0.05,
        "p_upper": 0.95,
    },
    "mnist-2": {
        "unlabeled": 60000,
        "dataset": "mnist",
        "positive_flag": 2,
        "positive_label_list": [1, 3, 5, 7, 9],
        "prior": 0.5,
        "positive_weight": 1.,
        "unlabeled_weight": 0.9,
        "elr_weight": 0.,
        "scheduler_gamma": .5,
        "milestones": [20, 40, 60, 80, 100],
        "start_hmix": 5,
        "p_lower": 0.25,
        "p_upper": 0.75,
    },
    "fmnist-1": {
        "unlabeled": 60000,
        "dataset": "fashionmnist",
        "positive_flag": 1,
        "positive_label_list": [0, 2, 4, 6],
        "prior": 0.4,
        "positive_weight": 1.,
        "elr_weight": 0.,
        "scheduler_gamma": .25,
        "milestones": [20, 40, 60, 80],
        "start_hmix": 5,
        "p_lower": 0.05,
        "p_upper": 0.95,
    },
    "fmnist-2": {
        "unlabeled": 60000,
        "dataset": "fashionmnist",
        "positive_flag": 2,
        "positive_label_list": [1, 3, 5, 7, 8, 9],
        "prior": 0.6,
        "positive_weight": 1.,
        "unlabeled_weight": 0.9,
        "elr_weight": 0.,
        "scheduler_gamma": .5,
        "milestones": [20, 40, 60, 80, 100],
        "start_hmix": 5,
        "p_lower": 0.25,
        "p_upper": 0.75,
    },
    "fmnist-3": {
        "unlabeled": 60000,
        "dataset": "fashionmnist",
        "positive_flag": 1,
        "positive_label_list": [1, 4, 7],
        "prior": 0.3,
        "positive_weight": 1.,
        "elr_weight": 0.,
        "scheduler_gamma": .25,
        "milestones": [20, 40, 60, 80],
        "start_hmix": 5,
        "p_lower": 0.05,
        "p_upper": 0.95,
    },
    "fmnist-4": {
        "unlabeled": 60000,
        "dataset": "fashionmnist",
        "positive_flag": 2,
        "positive_label_list": [0, 2, 3, 5, 6, 8, 9],
        "prior": 0.7,
        "positive_weight": 1.,
        "unlabeled_weight": 0.9,
        "elr_weight": 0.,
        "scheduler_gamma": .5,
        "milestones": [20, 40, 60, 80, 100],
        "start_hmix": 5,
        "p_lower": 0.25,
        "p_upper": 0.75,
    },
    "cifar10-1": {
        "unlabeled": 50000,
        "dataset": "cifar10",
        "positive_flag": 1,
        "positive_label_list": [0, 1, 8, 9],
        "prior": 0.4,
        "positive_weight": 1.,
        "elr_weight": 0.,
        "scheduler_gamma": .5,
        "milestones": [20, 40, 60, 80, 100],
        "mix_layer": 3,
        "start_hmix": 10,
        "p_lower": 0.2,
        "p_upper": 0.8,
    },
    "cifar10-2": {
        "unlabeled": 50000,
        "dataset": "cifar10",
        "positive_flag": 2,
        "positive_label_list": [2, 3, 4, 5, 6, 7],
        "prior": 0.6,
        "positive_weight": 1.,
        "elr_weight": 0.,
        "scheduler_gamma": .25,
        "milestones": [20, 40, 60, 80, 100],
        "mix_layer": 3,
        "start_hmix": 10,
        "p_lower": 0.2,
        "p_upper": 0.8,
    },
    "stl10-1": {
        "unlabeled": 105000,
        "dataset": "stl10",
        "positive_flag": 1,
        "positive_label_list": [0, 2, 3, 8, 9],
        "prior": 0.506,
        "positive_weight": 1.,
        "elr_weight": 0.,
        "scheduler_gamma": .5,
        "milestones": [20, 40, 60, 80],
        "mix_layer": 3,
        "start_hmix": 10,
        "p_lower": 0.2,
        "p_upper": 0.8,
    },
    "stl10-2": {
        "unlabeled": 105000,
        "dataset": "stl10",
        "positive_flag": 2,
        "positive_label_list": [1, 4, 5, 6, 7],
        "prior": 0.494,
        "positive_weight": 1.,
        "unlabeled_weight": .8,
        "elr_weight": 0.,
        "scheduler_gamma": .3,
        "milestones": [20, 40, 60, 80],
        "mix_layer": 3,
        "start_hmix": 10,
        "p_lower": 0.15,
        "p_upper": 0.85,
    },
    "alzheimer": {
        "unlabeled": 5121,
        "dataset": "alzheimer",
        "positive_flag": 2,
        "positive_label_list": [1],
        "prior": 0.5,
        "positive_weight": 1.,
        "elr_weight": 0.,
        "scheduler_gamma": .25,
        "milestones": [20, 40, 60, 80, 100],
        "mix_layer": 3,
        "start_hmix": 10,
        "p_lower": 0.2,
        "p_upper": 0.8,
    },
    "imagenette": {
        "unlabeled": 6000,
        "dataset": "imagenet",
        "positive_label_list": [0, 1, 2, 8, 9],
        "prior": 0.5,
    },
}

# Preset families matched by substring, e.g. "cifar10-prior0.2".  They are
# tried in this order; "fmnist-prior" must come before "mnist-prior" because
# the latter is a substring of the former.
_PRIOR_PRESET_FAMILIES = (
    ("cifar10-prior", {"unlabeled": 50000, "dataset": "cifar10"}),
    ("fmnist-prior", {"unlabeled": 60000, "dataset": "fashionmnist",
                      "positive_flag": 1}),
    ("mnist-prior", {"unlabeled": 60000, "dataset": "mnist",
                     "positive_flag": 1}),
)

# Extracts the numeric class prior embedded in a "...-prior<value>" preset.
_PRIOR_IN_PRESET = re.compile(r"prior([0-9.]+)")

# Upper bound on --unlabeled per dataset; datasets not listed are unchecked.
_MAX_UNLABELED = {
    "mnist": 60000,
    "fashionmnist": 60000,
    "cifar10": 50000,
    "stl10": 105000,
}


def _build_parser():
    """Create the argument parser with every command-line option."""
    parser = argparse.ArgumentParser(
        description=
        'non-negative / unbiased PU learning Pytorch implementation',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    add = parser.add_argument

    add('--p-quantities', type=int, nargs='+',
        default=[1000, 500, 200, 100, 50, 30, 20, 10, 5],
        help='List of P quantities to test')
    add('--n-runs', type=int, default=1,
        help='Number of runs per P quantity')
    add('--feature-eval-freq', type=int, default=10)

    add('--seed', type=int, default=0, help='seed')
    # NOTE(review): process_args() always replaces this value with one
    # derived from --labeled, so passing --batchsize currently has no effect.
    add('--batchsize', '-b', type=int, default=512, help='Mini batch size')
    add('--gpu', default=0, type=int, help='id(s) for CUDA_VISIBLE_DEVICES')
    add('--dataset', '-d', default='cifar10', type=str,
        choices=['mnist', 'cifar10', 'imdb', 'yelp', 'amazon', 'alzheimer',
                 'stl10', 'fashionmnist', 'imagenet'],
        help='The dataset name')
    add('--positive-flag', default=1, type=int, choices=[1, 2])
    # SECURITY NOTE(review): eval() executes arbitrary Python taken from the
    # command line.  Kept as-is to preserve the accepted input syntax;
    # consider ast.literal_eval if only literal lists need to be supported.
    add('--positive-label-list',
        type=lambda s: eval(s) if isinstance(s, str) else s,
        default=[0, 1, 2, 3, 4, 5])
    add('--prior', type=float, default=0.5)
    add('--labeled', '-l', default=100, type=int, help='# of labeled data')
    add('--unlabeled', '-u', default=59900, type=int,
        help='# of unlabeled data')
    add('--epoch', '-e', default=100, type=int, help='# of epochs to learn')
    add('--val-iterations', '-v', default=20, type=int,
        help='# of iterations per-epoch')
    add('--beta', '-B', default=0., type=float,
        help='Beta parameter of nnPU')
    add('--gamma', '-G', default=1., type=float,
        help='Gamma parameter of nnPU')
    add('--loss-gamma', '-GL', default=1, type=int,
        help='Gamma parameter of Loss')
    add('--loss-tau', '-TL', default=0.5, type=float,
        help='Tau parameter of Loss')
    add('--loss-eta', '-EL', default=1.0, type=float,
        help='Eta parameter of Loss')
    add('--loss', type=str, default="sigmoid",
        choices=[
            'hinge', 'squared-hinge', 'logistic', 'sigmoid',
            'focal', 'exponential', 'perceptron', 'zero-one',
            'squared', 'tangent', 'savage', 'unhinged',
            'gsigmoid', 'rescaled-hinge', 'pinball', 'double-hinge',
            'ramp', 'smooth-hinge', 'modified-huber', 'mine'
        ],
        help='The name of a loss function')
    # NOTE(review): the default '3mlp' is not one of the choices.  argparse
    # does not check defaults against choices, so it passes through
    # unvalidated when --model is omitted.
    add('--model', '-m', default='3mlp',
        choices=['mlp', 'cnn_cifar', 'cnn_stl', 'lenet5', 'resnet50'],
        help='The name of a classification model')
    add('--optim', '-O', default='adam',
        choices=['adam', 'sgd', 'adagrad'],
        help='The name of a optimizer')
    add('--momentum', '-mom', default=0, type=float, help='momentum')
    add('--balance-weight', '-bw', default=1e-3, type=float,
        help='balance weight for BalancePU and PULB+BalancePU')
    add('--lam-f', '-lamf', default=1e-2, type=float,
        help='lam_f para for FOPU')
    # NOTE(review): store_true with default=True means this flag is always
    # True and cannot be disabled from the command line.  The same applies
    # to --mean-teacher, --mix-option, --use-teacher-soft-labels and
    # --use-ema-for-test below.
    add('--nesterov', '-nest', action='store_true', default=True,
        help='use nesterov momentum')
    add('--stepsize', '-s', default=1e-3, type=float,
        help='Stepsize of gradient method')
    add('--weight-decay', '-w', default=1e-6, type=float,
        help='Weight decay of gradient method')

    add('--milestones', nargs='+', default=[50, 100], type=int,
        help='milestones for scheduler')
    add('--scheduler-gamma', default=0.1, type=float,
        help='gamma for scheduler')

    add('--out', '-o', default='result',
        help='Directory to output the result')
    add('--num-classifier', '-numc', type=int, default=2)
    add('--m-percent', '-mp', type=float, default=1)

    # parameters for positive loss
    add('--positive-weight', default=1.0, type=float,
        help='coefficient for positive loss')
    add('--positive-rampdown-starts', default=200, type=int,
        metavar='EPOCHS',
        help='epoch at which positive loss ramp-down starts')
    add('--positive-rampdown-ends', default=200, type=int,
        metavar='EPOCHS',
        help='epoch at which positive loss ramp-down ends')

    # parameters for unlabeled loss
    add('--unlabeled-weight', default=1.0, type=float,
        help='coefficient for unlabeled loss')
    add('--unlabeled-rampdown-starts', default=200, type=int,
        metavar='EPOCHS',
        help='epoch at which unlabeled loss ramp-down starts')
    add('--unlabeled-rampdown-ends', default=200, type=int,
        metavar='EPOCHS',
        help='epoch at which unlabeled loss ramp-down ends')

    # parameters for pseudo-labels
    add('--target-weight', default=1.0, type=float)
    add('--target-rampup-starts', default=200, type=int, metavar='EPOCHS')
    add('--target-rampup-ends', default=200, type=int, metavar='EPOCHS')

    # parameters for mean-teacher
    add('--mean-teacher', default=True, action='store_true')
    add('--ema-update', default=False, action='store_true')
    add('--ema-start', type=int, default=0)
    add('--ema-end', type=int, default=100)
    add('--ema-step', type=int, default=40000)
    add('--ema-decay', type=float, default=0.997)

    # parameter for entropy regularization
    add('--entropy-weight', default=0.1, type=float,
        help='coefficient for entropy regularization')

    add('--elr-weight', default=0., type=float,
        help='coefficient for early-learning regularization')
    add('--elr-rampdown-starts', default=200, type=int, metavar='EPOCHS',
        help='epoch at which elr regularization ramp-down starts')
    add('--elr-rampdown-ends', default=200, type=int, metavar='EPOCHS',
        help='epoch at which elr regularization ramp-down ends')

    # parameters for mixup
    add('--mix-option', default=True, action='store_true',
        help='mix option, whether to mix or not')
    add('--separate-mix', default=False, action='store_true',
        help='mix separate from labeled data and unlabeled data')
    add('--mix-layer', default=-1, type=int,
        help='number of layers on which mixup is applied including input layer')
    add('--alpha', default=1.0, type=float,
        help='alpha for beta distribution')

    # parameters for heuristic mixup
    add('--h-positive', type=int, default=100)
    add('--start-hmix', type=int, default=20)
    add('--p-upper', type=float, default=0.6)
    add('--p-lower', type=float, default=0.4)

    add('--no-progress', action='store_true', help="don't use progress bar")
    add("--local_rank", type=int, default=-1,
        help="For distributed training: local_rank")

    # NOTE(review): the help text names presets ("mnist", "mnist-6mlp",
    # "cifar10") that _apply_preset() does not handle, and the default
    # 'cifar10' matches no preset, so by default no overrides are applied.
    add('--preset', '-p', type=str, default='cifar10',
        help="Preset of configuration\n"
             "mnist: The setting of Figure1\n"
             "mnist-6mlp: The setting of MNIST experiment in Experiment\n"
             "cifar10: The setting of CIFAR10 experiment in Experiment")

    add('--method', '-method', type=str, default='nnPU-objective',
        choices=[
            'uPU',
            'nnPU-objective', 'nnPU-out',
            'absPU',
            'DistPU-brief', 'DistPU',
            'PULB',
            'BalancePU',
            'PULB+BalancePU',
            'ScalePU-brief',
            'ScalePU',
            'ScalePU-VarianceLambda',
            'ScalePU-Focal',
            'DistPU-VarianceLambda',
            'absPU-VarianceLambda',
            'ScalePU-DVAW',
            'ScalePU-FocalDVAW',
            'ScalePU-SoftPseudo',
            'ScalePU-SoftPseudo2',
            'ScalePU-Geometric'
        ])

    add('--co-mu', type=float, default=2e-3, help='coefficient of L_ent')
    add('--alpha-mix', type=float, default=6.0)

    add('--co-entropy', type=float, default=0.004)
    add('--co-mix-entropy', type=float, default=0.04)
    add('--co-mixup', type=float, default=5.0)

    add('--a0', type=float, default=2.0)
    add('--b0', type=float, default=1.0)
    add('--b1', type=float, default=0.8)
    add('--k1', type=float, default=8.0)
    add('--k2', type=float, default=8.0)

    add('--c-alpha', type=float, default=2.0)
    add('--c-delta', type=float, default=2.5)

    add('--n-bag', type=int, default=1)

    add('--threshold-strategy', type=str, default='fixed',
        choices=['fixed', 'adaptive', 'class_adaptive'])

    # train_soft_MA.py
    add('--use-pico-update', type=int, default=0)

    # train_soft_TS.py
    add('--TS-mean-teacher', type=int, default=0)
    # consistency-loss parameters
    add('--consistency', type=float, default=0.3,
        help='一致性损失的权重系数')
    add('--consistency-rampup', type=int, default=400,
        help='一致性损失的rampup周期')
    add('--consistency-type', type=str, default='mse',
        choices=['mse', 'mae', 'smooth_l1', 'kl', 'cosine'],
        help='一致性损失函数类型')
    # teacher-network soft-label parameters
    add('--use-teacher-soft-labels', action='store_true', default=True,
        help='是否使用教师网络的软标签')
    add('--use-ema-for-test', action='store_true', default=True,
        help='是否在测试时使用EMA模型')
    # consistency-constraint parameters
    add('--TS-mixup-consistency', action='store_true', default=False,
        help='是否使用特殊的一致性损失')
    add('--TS-mixup-consistency-type', type=str, default='mse',
        choices=['mse', 'kl'],
        help='一致性损失类型')
    add('--TS-mixup-alpha', type=float, default=0.1,
        help='一致性损失的阈值')
    add('--TS-mixup-weight', type=float, default=0.3,
        help='一致性损失的权重')
    # other parameters for the threshold strategy
    add('--soft-label-warmup', type=int, default=0,
        help='软标签预热期')

    # train_soft_mixup.py
    add('--mixup-alpha', type=float, default=1.0,
        help='一致性损失的阈值')
    add('--mixup-prob', type=float, default=0.5,
        help='应用mixup的概率')
    add('--mixup-weight', type=float, default=0.3,
        help='一致性损失的权重')

    # warm-up
    add('--warmup-epochs', type=int, default=25)

    # ScalePU
    add('--lambda-reg', type=float, default=0.1)

    # ScalePU-VarianceLambda
    add('--var-threshold', type=float, default=0.1)

    # ScalePU-Focal
    add('--focal-mode', type=str, default='topk')
    add('--focal-ratio', type=float, default=0.9)

    # ScalePU-SoftPseudo
    add('--tau-init', type=float, default=0.7,
        help='Initial confidence threshold for pseudo-labeling')
    add('--tau-target', type=float, default=0.9,
        help='Target confidence threshold after warmup')

    # Holistic-PU
    add('--warming-epochs', default=30, type=int,
        help='Warming phase epochs')
    add('--ft-epochs', default=70, type=int,
        help='Fine-tuning phase epochs')
    add('--holistic-alpha', default=2, type=float,
        help='Scaling parameter for trend score')
    add('--label-smoothing', default=0.1, type=float,
        help='Label smoothing parameter')

    # geometric regularization parameters
    add('--gamma_geo', type=float, default=0.01,
        help='Geometric regularization strength')
    add('--beta_sep', type=float, default=1.0,
        help='Separation weight for geometric regularization')
    add('--margin', type=float, default=1.0, help='Separation margin')

    # meta-learning parameters (required by version 2)
    add('--use_meta_learning', action='store_true', default=False,
        help='Use meta-learning to optimize hyperparameters')
    add('--meta_iterations', type=int, default=50,
        help='Number of meta-learning iterations')
    add('--inner_steps', type=int, default=3,
        help='Inner loop training steps for meta-learning')
    add('--inner_lr', type=float, default=0.01,
        help='Learning rate for inner loop')
    add('--meta_lr', type=float, default=0.01,
        help='Learning rate for meta-parameter optimization')

    return parser


def _assign_overrides(args, overrides):
    """Copy each override onto ``args``, giving lists a fresh copy."""
    for key, value in overrides.items():
        setattr(args, key, list(value) if isinstance(value, list) else value)


def _apply_preset(args):
    """Overwrite fields of ``args`` according to ``args.preset``.

    Preset values take precedence over anything given on the command line.
    A preset name that matches nothing leaves ``args`` untouched.
    """
    overrides = _EXACT_PRESETS.get(args.preset)
    if overrides is not None:
        _assign_overrides(args, overrides)
        return

    for marker, family_overrides in _PRIOR_PRESET_FAMILIES:
        if marker in args.preset:
            _assign_overrides(args, family_overrides)
            found = _PRIOR_IN_PRESET.search(args.preset)
            if found:
                # e.g. "cifar10-prior0.2" -> prior = 0.2
                args.prior = float(found.group(1))
            return


def _batchsize_for(labeled):
    """Return the mini-batch size used for a given number of labeled samples."""
    if labeled > 64:
        return 64
    if labeled > 16:
        return 16
    return 4


def _validate_args(parser, args):
    """Abort via ``parser.error`` if any argument is out of range.

    Explicit checks are used instead of ``assert`` so that validation still
    runs when Python is started with ``-O``.  Each condition is written as
    ``not (valid)`` so NaN values are rejected too.
    """
    if not (args.batchsize > 0):
        parser.error('--batchsize must be positive (got {})'.format(
            args.batchsize))
    if not (args.epoch > 0):
        parser.error('--epoch must be positive (got {})'.format(args.epoch))
    if not (0 < args.labeled < 30000):
        parser.error('--labeled must satisfy 0 < labeled < 30000 '
                     '(got {})'.format(args.labeled))
    limit = _MAX_UNLABELED.get(args.dataset)
    if limit is not None and not (0 < args.unlabeled <= limit):
        parser.error('--unlabeled must satisfy 0 < unlabeled <= {} for '
                     'dataset {!r} (got {})'.format(
                         limit, args.dataset, args.unlabeled))
    if not (0. <= args.beta):
        parser.error('--beta must be non-negative (got {})'.format(args.beta))
    if not (0. <= args.gamma <= 1.):
        parser.error('--gamma must be within [0, 1] (got {})'.format(
            args.gamma))


def process_args():
    """Parse ``sys.argv``, apply the selected preset and validate the result.

    Processing order:
      1. parse the command line;
      2. overwrite fields according to ``--preset`` (preset values win over
         explicit command-line values);
      3. derive ``batchsize`` from ``labeled`` (this always replaces any
         ``--batchsize`` given on the command line);
      4. range-check the final values.

    NOTE(review): with no arguments at all, the defaults (dataset 'cifar10',
    unlabeled 59900, preset 'cifar10' which matches no preset) fail the
    unlabeled-size check, so a dataset-appropriate preset or --unlabeled
    value must be supplied.

    Returns:
        argparse.Namespace holding the final configuration.

    Raises:
        SystemExit: on malformed command-line input or out-of-range values
            (reported through ``ArgumentParser.error``).
    """
    parser = _build_parser()
    args = parser.parse_args()
    _apply_preset(args)
    args.batchsize = _batchsize_for(args.labeled)
    _validate_args(parser, args)
    return args
