import argparse

import numpy as np
import torch

# Compute device shared by the module: CUDA when torch reports it as
# available, otherwise the CPU.
device = (torch.device('cuda')
          if torch.cuda.is_available() else torch.device('cpu'))


def _build_parser():
    """Create the argument parser with every command-line option and default."""
    parser = argparse.ArgumentParser(
        description=
        'non-negative / unbiased PU learning Pytorch implementation',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--seed', type=int, default=0, help='seed')
    parser.add_argument('--batchsize',
                        '-b',
                        type=int,
                        default=30000,
                        help='Mini batch size')
    # The choices cover every dataset name a preset can assign, so the same
    # datasets can also be requested directly.
    parser.add_argument('--dataset',
                        '-d',
                        default='cifar10',
                        type=str,
                        choices=[
                            'mnist', 'fashionmnist', 'cifar10', 'stl10',
                            'imdb', 'yelp', 'yelp_full', 'amazon', '20ng'
                        ],
                        help='The dataset name')
    parser.add_argument('--positive-flag',
                        default=1,
                        type=int,
                        choices=[1, 2])
    parser.add_argument('--labeled',
                        '-l',
                        default=100,
                        type=int,
                        help='# of labeled data')
    parser.add_argument('--unlabeled',
                        '-u',
                        default=59900,
                        type=int,
                        help='# of unlabeled data')
    parser.add_argument('--epoch',
                        '-e',
                        default=100,
                        type=int,
                        help='# of epochs to learn')
    parser.add_argument('--val-iterations',
                        '-v',
                        default=20,
                        type=int,
                        help='# of iterations per-epoch')
    parser.add_argument('--beta',
                        '-B',
                        default=0.,
                        type=float,
                        help='Beta parameter of nnPU')
    parser.add_argument('--gamma',
                        '-G',
                        default=1.,
                        type=float,
                        help='Gamma parameter of nnPU')
    parser.add_argument('--loss',
                        type=str,
                        default="sigmoid",
                        choices=['sigmoid', 'logistic'],
                        help='The name of a loss function')
    parser.add_argument('--model',
                        '-m',
                        default='3mlp',
                        choices=[
                            'linear',
                            '3mlp',
                            '4mlp',
                            '6mlp',
                            'cnn',
                            'cnn1d_nlp',
                            'cnn_nlp',
                            'lenet',
                            'cnnstl',
                        ],
                        help='The name of a classification model')
    parser.add_argument('--optim',
                        '-O',
                        default='adam',
                        choices=[
                            'adam',
                            'sgd',
                            'adagrad',
                        ],
                        help='The name of a optimizer')
    parser.add_argument('--stepsize',
                        '-s',
                        default=1e-3,
                        type=float,
                        help='Stepsize of gradient method')
    parser.add_argument('--weight-decay',
                        '-w',
                        default=5e-3,
                        type=float,
                        help='Weight decay of gradient method')

    parser.add_argument('--milestones',
                        nargs='+',
                        default=[50, 100],
                        type=int,
                        help='milestones for scheduler')
    parser.add_argument('--scheduler-gamma',
                        default=0.1,
                        type=float,
                        help='gamma for scheduler')

    parser.add_argument('--out',
                        '-o',
                        default='result',
                        help='Directory to output the result')

    # parameters for positive loss
    parser.add_argument('--positive-weight',
                        default=1.0,
                        type=float,
                        help='coefficient for positive loss')
    parser.add_argument('--positive-rampdown-starts',
                        default=200,
                        type=int,
                        metavar='EPOCHS',
                        help='epoch at which positive loss ramp-down starts')
    parser.add_argument('--positive-rampdown-ends',
                        default=200,
                        type=int,
                        metavar='EPOCHS',
                        help='epoch at which positive loss ramp-down ends')

    # parameters for unlabeled loss
    parser.add_argument('--unlabeled-weight',
                        default=1.0,
                        type=float,
                        help='coefficient for unlabeled loss')
    parser.add_argument('--unlabeled-rampdown-starts',
                        default=200,
                        type=int,
                        metavar='EPOCHS',
                        help='epoch at which unlabeled loss ramp-down starts')
    parser.add_argument('--unlabeled-rampdown-ends',
                        default=200,
                        type=int,
                        metavar='EPOCHS',
                        help='epoch at which unlabeled loss ramp-down ends')

    # parameters for pseudo-labels
    parser.add_argument('--target-weight', default=1.0, type=float)
    parser.add_argument('--target-rampup-starts',
                        default=200,
                        type=int,
                        metavar='EPOCHS')
    parser.add_argument('--target-rampup-ends',
                        default=200,
                        type=int,
                        metavar='EPOCHS')

    # parameters for mean-teacher
    # --mean-teacher defaults to True, so on its own the flag is a no-op; the
    # paired --no-mean-teacher is the only way to switch the value off.
    parser.add_argument('--mean-teacher', default=True, action='store_true')
    parser.add_argument('--no-mean-teacher',
                        dest='mean_teacher',
                        action='store_false',
                        help='set mean_teacher to False')
    parser.add_argument('--ema-update', default=False, action='store_true')
    parser.add_argument('--ema-start', type=int, default=0)
    parser.add_argument('--ema-end', type=int, default=100)
    parser.add_argument('--ema-step', type=int, default=40000)
    parser.add_argument('--ema-decay', type=float, default=0.997)

    # parameter for entropy regularization
    parser.add_argument('--entropy-weight',
                        default=0.1,
                        type=float,
                        help='coefficient for entropy regularization')

    parser.add_argument('--elr-weight',
                        default=5.,
                        type=float,
                        help='coefficient for early-learning regularization')
    parser.add_argument('--elr-rampdown-starts',
                        default=200,
                        type=int,
                        metavar='EPOCHS',
                        help='epoch at which elr regularization ramp-down starts')
    parser.add_argument('--elr-rampdown-ends',
                        default=200,
                        type=int,
                        metavar='EPOCHS',
                        help='epoch at which elr regularization ramp-down ends')

    # parameters for mixup
    # Same situation as --mean-teacher: the default is already True, so
    # --no-mix-option is provided to turn it off.
    parser.add_argument('--mix-option',
                        default=True,
                        action='store_true',
                        help='mix option, whether to mix or not')
    parser.add_argument('--no-mix-option',
                        dest='mix_option',
                        action='store_false',
                        help='set mix_option to False')
    parser.add_argument(
        '--separate-mix',
        default=False,
        action='store_true',
        help='mix separate from labeled data and unlabeled data')
    parser.add_argument(
        '--mix-layer',
        default=-1,
        type=int,
        help='number of layers on which mixup is applied including input layer'
    )
    parser.add_argument('--alpha',
                        default=1.0,
                        type=float,
                        help='alpha for beta distribution')

    # parameters for heuristic mixup
    parser.add_argument('--h-positive', type=int, default=100)
    parser.add_argument('--start-hmix', type=int, default=20)
    parser.add_argument('--p-upper', type=float, default=0.6)
    parser.add_argument('--p-lower', type=float, default=0.4)

    # The default 'cifar10' is deliberately kept although it is not one of the
    # choices: argparse does not check defaults against choices, and the value
    # matches no preset, so the individual options are then used as given.
    parser.add_argument(
        '--preset',
        '-p',
        type=str,
        default='cifar10',
        choices=[
            'mnist', 'mnist-6mlp', 'fashionmnist', 'fashionmnist-lenet-1',
            'fashionmnist-lenet-2', 'fashionmnist-6mlp', 'cifar10-1',
            'cifar10-2', 'stl10-1', 'stl10-2', 'imdb', 'yelp_full', '20ng',
            'yelp', 'amazon'
        ],
        help='Preset of configuration; a selected preset overwrites the '
        'corresponding individual options. The default value matches no '
        'preset, so the individual options are used as given')
    return parser


def _preset_overrides(name):
    """Return a fresh dict of option overrides for preset ``name`` ({} if none)."""
    # Settings shared by all text-classification presets; each preset below
    # only adds its dataset name and unlabeled-set size.
    nlp_common = dict(labeled=1000,
                      batchsize=1000,
                      optim='adagrad',
                      stepsize=1e-3,
                      positive_weight=2.,
                      elr_weight=0.,
                      epoch=100,
                      model='cnn1d_nlp')
    presets = {
        'mnist':
        dict(labeled=100,
             unlabeled=60000,
             dataset='mnist',
             batchsize=30000,
             epoch=100,
             model='3mlp'),
        'mnist-6mlp':
        dict(labeled=1000,
             unlabeled=60000,
             dataset='mnist',
             batchsize=30000,
             epoch=200,
             model='6mlp'),
        'fashionmnist':
        dict(labeled=100,
             unlabeled=60000,
             dataset='fashionmnist',
             batchsize=30000,
             epoch=100,
             model='3mlp'),
        'fashionmnist-6mlp':
        dict(labeled=1000,
             unlabeled=60000,
             dataset='fashionmnist',
             batchsize=30000,
             positive_weight=5.,
             elr_weight=5.,
             epoch=200,
             model='6mlp'),
        'imdb':
        dict(nlp_common, unlabeled=25000, dataset='imdb'),
        'yelp_full':
        dict(nlp_common, unlabeled=50000, dataset='yelp_full'),
        'yelp':
        dict(nlp_common, unlabeled=50000, dataset='yelp'),
        'amazon':
        dict(nlp_common, unlabeled=50000, dataset='amazon'),
        '20ng':
        dict(nlp_common, unlabeled=11314, dataset='20ng'),
        'fashionmnist-lenet-1':
        dict(labeled=1000,
             unlabeled=60000,
             dataset='fashionmnist',
             positive_flag=1,
             batchsize=30000,
             model='lenet',
             stepsize=5e-4,
             weight_decay=5e-4,
             positive_weight=1.,
             elr_weight=5.,
             epoch=200,
             scheduler_gamma=.5,
             milestones=[20, 40, 60, 80, 100],
             start_hmix=5,
             p_lower=0.05,
             p_upper=0.95),
        'cifar10-1':
        dict(labeled=1000,
             unlabeled=50000,
             dataset='cifar10',
             positive_flag=1,
             batchsize=500,
             model='cnn',
             stepsize=5e-4,
             positive_weight=1.,
             elr_weight=5.,
             epoch=200,
             scheduler_gamma=.2,
             milestones=[20, 40, 60, 80, 100],
             start_hmix=5,
             p_lower=0.1,
             p_upper=0.9),
        'stl10-1':
        dict(labeled=1000,
             unlabeled=50000,
             dataset='stl10',
             positive_flag=1,
             batchsize=500,
             model='cnnstl',
             stepsize=1e-3,
             positive_weight=1.,
             elr_weight=2.,
             epoch=200,
             scheduler_gamma=.5,
             milestones=[20, 40, 60, 80, 100, 120, 160],
             p_lower=0.15,
             p_upper=0.85),
        'fashionmnist-lenet-2':
        dict(labeled=1000,
             unlabeled=60000,
             dataset='fashionmnist',
             positive_flag=2,
             batchsize=30000,
             model='lenet',
             optim='adam',
             stepsize=5e-4,
             weight_decay=5e-4,
             positive_weight=1.,
             unlabeled_weight=.5,
             elr_weight=2.,
             epoch=200,
             scheduler_gamma=.5,
             milestones=[20, 40, 60, 80, 100],
             start_hmix=5,
             p_lower=0.05,
             p_upper=0.95),
        'cifar10-2':
        dict(labeled=1000,
             unlabeled=50000,
             dataset='cifar10',
             positive_flag=2,
             batchsize=500,
             model='cnn',
             optim='adam',
             stepsize=5e-4,
             weight_decay=1e-6,
             positive_weight=1.,
             elr_weight=10.,
             elr_rampdown_starts=200,
             epoch=200,
             scheduler_gamma=.5,
             milestones=[20, 40, 60, 80, 100],
             mix_layer=4,
             alpha=1.,
             start_hmix=5,
             p_lower=0.2,
             p_upper=0.8),
        'stl10-2':
        dict(labeled=1000,
             unlabeled=50000,
             dataset='stl10',
             positive_flag=2,
             batchsize=500,
             model='cnnstl',
             optim='adam',
             stepsize=1e-3,
             weight_decay=1e-2,
             positive_weight=1.,
             elr_weight=1.,
             epoch=200,
             scheduler_gamma=.1,
             milestones=[50, 100, 150],
             mix_layer=3),
    }
    return presets.get(name, {})


def _validate_args(parser, args):
    """Report out-of-range option values through ``parser.error``."""
    # Explicit checks instead of ``assert`` so they still run under
    # ``python -O``. Conditions keep the ``not (lo <= x)`` form so that NaN
    # values are rejected as well.
    if not (args.batchsize > 0):
        parser.error('batchsize must be positive')
    if not (args.epoch > 0):
        parser.error('epoch must be positive')
    if not (0 < args.labeled < 30000):
        parser.error('labeled must satisfy 0 < labeled < 30000')
    if args.dataset in ('mnist', 'fashionmnist'):
        max_unlabeled = 60000
    else:
        max_unlabeled = 50000
    if not (0 < args.unlabeled <= max_unlabeled):
        parser.error('unlabeled must satisfy 0 < unlabeled <= {} for '
                     'dataset {!r}'.format(max_unlabeled, args.dataset))
    if not (0. <= args.beta):
        parser.error('beta must be non-negative')
    if not (0. <= args.gamma <= 1.):
        parser.error('gamma must lie in [0, 1]')


def process_args(argv=None):
    """Parse command-line options, apply the selected preset and validate.

    A matching ``--preset`` overwrites the corresponding options after
    parsing, including values given explicitly on the command line.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``.

    Returns:
        The populated ``argparse.Namespace``.

    Raises:
        SystemExit: Raised by argparse (exit status 2) when an option cannot
            be parsed or a value fails the range checks.
    """
    # NOTE(review): with no arguments at all, the defaults (dataset 'cifar10',
    # unlabeled 59900) do not pass the unlabeled-size check below -- confirm
    # which default was intended.
    parser = _build_parser()
    args = parser.parse_args(argv)
    for option, value in _preset_overrides(args.preset).items():
        setattr(args, option, value)
    _validate_args(parser, args)
    return args
