import argparse
import multiprocessing
import os
from functools import partial

from torchvision import models

from datasets import DS_LIST
from methods import METHOD_LIST


def get_cfg(args=None):
    """Build the run configuration from command-line arguments.

    Args:
        args: optional list of argument strings to parse instead of
            ``sys.argv[1:]`` (handy for tests, notebooks and scripted runs).
            ``None`` (the default) parses the process command line, exactly
            as before.

    Returns:
        argparse.Namespace with one attribute per option defined below
        (``store_false`` flags are exposed under their ``dest`` names:
        ``lr_warmup``, ``add_bn``, ``norm``).
    """
    parser = argparse.ArgumentParser(
        description="Configuration for self-supervised training and evaluation"
    )
    parser.add_argument(
        "--method", type=str, choices=METHOD_LIST, default="w_mse", help="loss type",
    )
    parser.add_argument(
        "--wandb",
        type=str,
        default="ssl-sota",
        help="name of the project for logging at https://wandb.ai",
    )
    parser.add_argument(
        "--byol_tau", type=float, default=0.99, help="starting tau for byol loss"
    )
    parser.add_argument(
        "--num_samples",
        type=int,
        default=2,
        help="number of samples (d) generated from each image",
    )

    # Augmentation hyper-parameters are all floats; share the type via partial.
    addf = partial(parser.add_argument, type=float)
    addf("--cj0", default=0.4, help="color jitter brightness")
    addf("--cj1", default=0.4, help="color jitter contrast")
    addf("--cj2", default=0.4, help="color jitter saturation")
    addf("--cj3", default=0.1, help="color jitter hue")
    addf("--cj_p", default=0.8, help="color jitter probability")
    addf("--gs_p", default=0.1, help="grayscale probability")
    addf("--crop_s0", default=0.2, help="crop size from")
    addf("--crop_s1", default=1.0, help="crop size to")
    addf("--crop_r0", default=0.75, help="crop ratio from")
    addf("--crop_r1", default=(4 / 3), help="crop ratio to")
    addf("--hf_p", default=0.5, help="horizontal flip probability")

    # "--no_*" switches default to True and are turned off when given.
    parser.add_argument(
        "--no_lr_warmup",
        dest="lr_warmup",
        action="store_false",
        help="do not use learning rate warmup",
    )
    parser.add_argument(
        "--no_add_bn", dest="add_bn", action="store_false", help="do not use BN in head"
    )
    parser.add_argument("--knn", type=int, default=5, help="k in k-nn classifier")
    parser.add_argument("--fname", type=str, help="load model from file")
    parser.add_argument(
        "--lr_step",
        type=str,
        choices=["cos", "step", "none"],
        default="step",
        help="learning rate schedule type",
    )
    parser.add_argument("--lr", type=float, default=1e-3, help="learning rate")
    parser.add_argument(
        "--eta_min", type=float, default=0, help="min learning rate (for --lr_step cos)"
    )
    parser.add_argument(
        "--adam_l2", type=float, default=1e-6, help="weight decay (L2 penalty)"
    )
    parser.add_argument("--T0", type=int, help="period (for --lr_step cos)")
    parser.add_argument(
        "--Tmult", type=int, default=1, help="period factor (for --lr_step cos)"
    )
    parser.add_argument(
        "--w_eps", type=float, default=1e-4, help="eps for stability for whitening"
    )
    parser.add_argument(
        "--head_layers", type=int, default=2, help="number of FC layers in head"
    )
    parser.add_argument(
        "--head_size", type=int, default=1024, help="size of FC layers in head"
    )

    parser.add_argument(
        "--w_size", type=int, default=128, help="size of sub-batch for W-MSE loss"
    )
    parser.add_argument(
        "--w_iter",
        type=int,
        default=1,
        help="iterations for whitening matrix estimation",
    )

    parser.add_argument(
        "--no_norm", dest="norm", action="store_false", help="don't normalize latents",
    )
    parser.add_argument(
        "--tau", type=float, default=0.5, help="contrastive loss temperature"
    )

    parser.add_argument("--epoch", type=int, default=200, help="total epoch number")
    parser.add_argument(
        "--eval_every_drop",
        type=int,
        default=5,
        help="how often to evaluate after learning rate drop",
    )
    parser.add_argument(
        "--eval_every", type=int, default=20, help="how often to evaluate"
    )
    parser.add_argument("--emb", type=int, default=64, help="embedding size")
    parser.add_argument(
        "--bs", type=int, default=384, help="number of original images in batch N",
    )
    parser.add_argument(
        "--drop",
        type=int,
        nargs="*",
        default=[50, 25],
        help="milestones for learning rate decay (0 = last epoch)",
    )
    parser.add_argument(
        "--drop_gamma",
        type=float,
        default=0.2,
        help="multiplicative factor of learning rate decay",
    )

    # Offer only callable ResNet-family constructors. A plain substring match
    # on dir(models) may also pick up non-callable entries (e.g. a ``resnet``
    # submodule) that cannot be used to build an encoder.
    arch_choices = [
        name
        for name in dir(models)
        if "resn" in name and callable(getattr(models, name))
    ]
    parser.add_argument(
        "--arch",
        type=str,
        choices=arch_choices,
        default="resnet18",
        help="encoder architecture",
    )
    parser.add_argument(
        "--dataset", type=str, choices=DS_LIST, default="cifar10", help="dataset name"
    )
    parser.add_argument(
        "--num_workers",
        type=int,
        default=0,
        help="dataset workers number",
    )
    parser.add_argument(
        "--clf",
        type=str,
        default="sgd",
        choices=["sgd", "knn", "lbfgs"],
        help="classifier for test.py",
    )
    parser.add_argument(
        "--eval_head", action="store_true", help="eval head output instead of model",
    )
    # argparse also applies ``type`` to string defaults, so the default
    # "~/IN100/" is expanded too; filesystem calls do not expand "~" themselves.
    parser.add_argument(
        "--imagenet_path",
        type=os.path.expanduser,
        default="~/IN100/",
        help="root directory of the ImageNet data ('~' is expanded)",
    )
    return parser.parse_args(args)
