import argparse

import numpy as np
import torch

# Compute device for the whole script: CUDA when PyTorch reports a usable GPU,
# otherwise the CPU. To force CPU execution, use torch.device("cpu") instead.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def process_args():
    parser = argparse.ArgumentParser(
        description=
        'non-negative / unbiased PU learning Pytorch implementation',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--seed', type=int, default=0, help='seed')
    parser.add_argument('--batchsize',
                        '-b',
                        type=int,
                        default=30000,
                        help='Mini batch size')
    parser.add_argument('--dataset',
                        '-d',
                        default='cifar10',
                        type=str,
                        choices=['mnist', 'cifar10', 'imdb', 'yelp', 'amazon'],
                        help='The dataset name')
    parser.add_argument('--positive-flag', default=1, type=int, choices=[1, 2])
    parser.add_argument('--positive-label-list',
                        type=list,
                        default=[0, 1, 2, 3, 4, 5])
    parser.add_argument('--prior', type=float, default=0.5)
    parser.add_argument('--labeled',
                        '-l',
                        default=100,
                        type=int,
                        help='# of labeled data')
    parser.add_argument('--unlabeled',
                        '-u',
                        default=59900,
                        type=int,
                        help='# of unlabeled data')
    parser.add_argument('--epoch',
                        '-e',
                        default=100,
                        type=int,
                        help='# of epochs to learn')
    parser.add_argument('--val-iterations',
                        '-v',
                        default=20,
                        type=int,
                        help='# of iterations per-epoch')
    parser.add_argument('--loss',
                        type=str,
                        default="sigmoid",
                        choices=['sigmoid', 'logistic'],
                        help='The name of a loss function')
    parser.add_argument('--model',
                        '-m',
                        default='3mlp',
                        choices=[
                            'linear',
                            '3mlp',
                            '4mlp',
                            '6mlp',
                            'cnn',
                            'cnn1d_nlp',
                            'cnn_nlp',
                            'lenet',
                            'cnnstl',
                        ],
                        help='The name of a classification model')
    parser.add_argument('--optim',
                        '-O',
                        default='adam',
                        choices=[
                            'adam',
                            'sgd',
                            'adagrad',
                        ],
                        help='The name of a optimizer')
    parser.add_argument('--stepsize',
                        '-s',
                        default=1e-3,
                        type=float,
                        help='Stepsize of gradient method')
    parser.add_argument('--weight-decay',
                        '-w',
                        default=5e-3,
                        type=float,
                        help='Weight decay of gradient method')

    parser.add_argument('--milestones',
                        nargs='+',
                        default=[50, 100],
                        type=int,
                        help='milestones for scheduler')
    parser.add_argument('--scheduler-gamma',
                        default=0.1,
                        type=float,
                        help='gamma for scheduler')

    parser.add_argument('--out',
                        '-o',
                        default='result',
                        help='Directory to output the result')

    # parameters for positive loss
    parser.add_argument('--positive-weight',
                        default=1.0,
                        type=float,
                        help='coefficient for positive loss')

    # parameters for unlabeled loss
    parser.add_argument('--unlabeled-weight',
                        default=1.0,
                        type=float,
                        help='coefficient for unlabeled loss')

    # parameters for mean-teacher
    parser.add_argument('--mean-teacher', default=True, action='store_true')
    parser.add_argument('--ema-update', default=False, action='store_true')
    parser.add_argument('--ema-start', type=int, default=0)
    parser.add_argument('--ema-end', type=int, default=100)
    parser.add_argument('--ema-step', type=int, default=40000)
    parser.add_argument('--ema-decay', type=float, default=0.997)

    # parameter for entropy regularization
    parser.add_argument('--entropy-weight',
                        default=0.1,
                        type=float,
                        help='coefficient for entropy regularization')

    parser.add_argument('--elr-weight',
                        default=0.,
                        type=float,
                        help='coefficient for early-learning regularization')

    # parameters for mixup
    parser.add_argument(
        '--mix-layer',
        default=-1,
        type=int,
        help='number of layers on which mixup is applied including input layer'
    )
    parser.add_argument('--alpha',
                        default=1.0,
                        type=float,
                        help='alpha for beta distribution')

    # parameters for heuristic mixup
    parser.add_argument('--h-positive', type=int, default=100)
    parser.add_argument('--start-hmix', type=int, default=20)
    parser.add_argument('--p-upper', type=float, default=0.6)
    parser.add_argument('--p-lower', type=float, default=0.4)

    parser.add_argument(
        '--preset',
        '-p',
        type=str,
        default='cifar10',
        choices=[
            'fashionmnist-lenet-1',
            'fashionmnist-lenet-2',
            'cifar10-1',
            'cifar10-2',
            'stl10-1',
            'stl10-2',
        ],
        help="Preset of configuration\n" + "mnist: The setting of Figure1\n" +
        "mnist-6mlp: The setting of MNIST experiment in Experiment\n" +
        "cifar10: The setting of CIFAR10 experiment in Experiment")

    args = parser.parse_args()
    if args.preset == "fashionmnist-lenet-1":
        args.labeled = 1000
        args.unlabeled = 60000
        args.dataset = "fashionmnist"
        args.positive_flag = 1
        args.positive_label_list = [1, 4, 7]
        args.prior = 0.3
        args.batchsize = 512
        args.model = "lenet"
        args.stepsize = 5e-4
        args.weight_decay = 5e-3
        args.positive_weight = 1.
        args.elr_weight = 0.
        args.epoch = 200
        args.scheduler_gamma = .25
        args.milestones = [20, 40, 60, 80]
        args.start_hmix = 5
        args.p_lower = 0.05
        args.p_upper = 0.95
    elif args.preset == "fashionmnist-lenet-2":
        args.labeled = 1000
        args.unlabeled = 60000
        args.dataset = "fashionmnist"
        args.positive_flag = 2
        args.positive_label_list = [0, 2, 3, 5, 6, 8, 9]
        args.prior = 0.7
        args.batchsize = 512
        args.model = "lenet"
        args.stepsize = 5e-5
        args.weight_decay = 1e-3
        args.positive_weight = 1.
        args.unlabeled_weight = 0.9
        args.elr_weight = 0.
        args.epoch = 200
        args.scheduler_gamma = .5
        args.milestones = [20, 40, 60, 80, 100]
        args.start_hmix = 5
        args.p_lower = 0.25
        args.p_upper = 0.75
    elif args.preset == "cifar10-1":
        args.labeled = 1000
        args.unlabeled = 50000
        args.dataset = "cifar10"
        args.positive_flag = 1
        args.positive_label_list = [0, 1, 8, 9]
        args.prior = 0.4
        args.batchsize = 512
        args.model = "cnn"
        args.stepsize = 5e-5
        args.weight_decay = 1e-6
        args.positive_weight = 1.
        args.elr_weight = 0.
        args.epoch = 200
        args.scheduler_gamma = .5
        args.milestones = [20, 40, 60, 80, 100]
        args.mix_layer = 3
        args.start_hmix = 10
        args.p_lower = 0.2
        args.p_upper = 0.8
    elif args.preset == "cifar10-2":
        args.labeled = 1000
        args.unlabeled = 50000
        args.dataset = "cifar10"
        args.positive_flag = 2
        args.positive_label_list = [2, 3, 4, 5, 6, 7]
        args.prior = 0.6
        args.batchsize = 512
        args.model = "cnn"
        args.stepsize = 5e-5
        args.weight_decay = 1e-6
        args.positive_weight = 1.
        args.unlabeled_weight = 0.8
        args.elr_weight = 0.
        args.epoch = 200
        args.scheduler_gamma = .25
        args.milestones = [20, 40, 60, 80, 100]
        args.mix_layer = 3
        args.start_hmix = 10
        args.p_lower = 0.2
        args.p_upper = 0.8
    elif args.preset == "stl10-1":
        args.labeled = 1000
        args.unlabeled = 105000
        args.dataset = "stl10"
        args.positive_flag = 1
        args.positive_label_list = [0, 2, 3, 8, 9]
        args.prior = 0.506
        args.batchsize = 512
        args.model = "cnnstl"
        args.stepsize = 1e-4
        args.weight_decay = 1e-3
        args.positive_weight = 1.
        args.elr_weight = 0.
        args.epoch = 200
        args.scheduler_gamma = .5
        args.milestones = [20, 40, 60, 80]
        args.mix_layer = 3
        args.start_hmix = 10
        args.p_lower = 0.2
        args.p_upper = 0.8
    elif args.preset == "stl10-2":
        args.labeled = 1000
        args.unlabeled = 105000
        args.dataset = "stl10"
        args.positive_flag = 2
        args.positive_label_list = [1, 4, 5, 6, 7]
        args.prior = 0.494
        args.batchsize = 512
        args.model = "cnnstl"
        args.stepsize = 3e-4
        args.weight_decay = 1e-5
        args.positive_weight = 1.
        args.unlabeled_weight = 0.8
        args.elr_weight = 0.
        args.epoch = 200
        args.scheduler_gamma = .3
        args.milestones = [20, 40, 60, 80]
        args.mix_layer = 3
        args.start_hmix = 10
        args.p_lower = 0.15
        args.p_upper = 0.85

    assert (args.batchsize > 0)
    assert (args.epoch > 0)
    assert (0 < args.labeled < 30000)
    if (args.dataset == "mnist") or (args.dataset == "fashionmnist"):
        assert (0 < args.unlabeled <= 60000)
    elif args.dataset == "cifar10":
        assert (0 < args.unlabeled <= 50000)
    elif args.dataset == "stl10":
        assert (0 < args.unlabeled <= 105000)
    assert (0. <= args.beta)
    assert (0. <= args.gamma <= 1.)
    return args
