import argparse
import os
import torch
import numpy as np
import random
import string
import datetime

def _build_parser():
    """Create the command-line parser for SparseCL training / linear-evaluation jobs."""
    parser = argparse.ArgumentParser(description='PyTorch version of SparseCLv2.')
    ### job params
    parser.add_argument('--use_random', action='store_true', help='whether to randomly generate seed')
    parser.add_argument('--use_predictor', action='store_true', help='whether to use the predictor')
    parser.add_argument('--seed', type=int, default=12, metavar='S', help='random seed (also job id)')
    parser.add_argument('--model_name', default='SparseCL', help='the name of models', choices=['SparseCL'])
    parser.add_argument('--stage', type=int, default=0, help="0 train model; 1 linear evaluate model")
    parser.add_argument('--dataset', default='CIFAR-10', help='dataset for training', choices=['CIFAR-10', 'CIFAR-100'])
    parser.add_argument('--workers', default=4, type=int, metavar='N', help='number of data loader workers')
    parser.add_argument('--batch_size', type=int, default=256, help="the size of batch samples")
    parser.add_argument('--epochs', type=int, default=1000, help="training epochs")
    parser.add_argument('--weight_decay', default=1e-5, type=float, metavar='W', help='weight decay')
    parser.add_argument('--lr', default=0.2, type=float, metavar='LR', help='initial (base) learning rate', dest='lr')
    parser.add_argument('--linear_lr', default=0.3, type=float, metavar='LR', help='initial (base) learning rate for linear evaluation', dest='linear_lr')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum')
    parser.add_argument('--linear_epochs', type=int, default=100, help="linear evaluation epochs")
    parser.add_argument('--num_class', type=int, default=10, help="the categories of different labels (overridden by --dataset)")
    parser.add_argument('--mlp_dim', type=int, default=2048, help="dimension of contrastive representation")
    parser.add_argument('--feature_dim', type=int, default=2048, help="dimension of contrastive representation")
    parser.add_argument('--warmup_epochs', default=10, type=int, metavar='N', help='number of warmup epochs')
    parser.add_argument('--temperature', type=float, default=0.1, help="the temperature for self-supervised representation")
    parser.add_argument('--threshold', type=float, default=0.6, help="the threshold of cosine similarity for positive pairs")
    parser.add_argument('--alpha', type=float, default=0.02, help="the hyperparameter of sparsity term.")
    parser.add_argument('--eta', type=float, default=0.02, help="lars optimizer trust_coefficient.")
    parser.add_argument('--base_momentum', default=0.995, type=float, help='moco momentum of updating key encoder (default: 0.995)')
    parser.add_argument('--data_dir', default="/data/fxh/Instance_wise_Similarity/data/", type=str, help='the directory of train or test dataset')
    parser.add_argument('--saver_dir', default="./saver/", type=str, help='the directory for logs and saved files')
    parser.add_argument('--checkpoint_dir', default="./checkpoint/", type=str, help='the directory for model checkpoints')
    parser.add_argument('--results', default="./results/", type=str, metavar='DIR', help='path to results directory')
    return parser


def _seed_everything(seed):
    """Seed Python, NumPy and PyTorch RNGs and force deterministic cuDNN kernels.

    Args:
        seed: Integer seed applied to every random number generator.
    """
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    # Safe without CUDA: PyTorch silently ignores this call in that case.
    torch.cuda.manual_seed(seed)
    # Trade cuDNN autotuning speed for reproducible results.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def parse_arg(argv=None):
    """Parse command-line options, derive run settings and seed all RNGs.

    Args:
        argv: Optional list of argument strings. ``None`` (the default)
            parses ``sys.argv[1:]``, matching the previous behavior.

    Returns:
        argparse.Namespace with the parsed options plus derived fields:
        ``save_name_pre`` (run identifier built from key hyperparameters),
        ``data_dir`` (joined with the dataset name), ``num_class`` (set from
        the dataset), ``device`` and ``log_dir`` (equal to ``saver_dir``).

    Side effects:
        Creates ``saver_dir`` if missing, prints the run name and data dir,
        and seeds Python/NumPy/PyTorch RNGs.
    """
    args = _build_parser().parse_args(argv)

    # Resolve the seed FIRST so that save_name_pre records the seed actually
    # used. Integer upper bound: randint() rejects floats on Python 3.12+.
    if args.use_random:
        args.seed = random.randint(0, 10**8)

    args.save_name_pre = '{}_{}_{}_{}_{}_{}_{}_{}_{}'.format(
        args.model_name, args.seed, args.dataset, args.batch_size, args.epochs,
        args.use_predictor, args.lr, args.threshold, args.alpha)
    print("save_name_pre:", args.save_name_pre)

    args.data_dir = os.path.join(args.data_dir, args.dataset)
    print("data dir:", args.data_dir)

    # The dataset choice determines the label count, overriding --num_class.
    if args.dataset == "CIFAR-10":
        args.num_class = 10
    elif args.dataset == "CIFAR-100":
        args.num_class = 100

    # Fall back to CPU so CPU-only hosts do not fail later on first use.
    args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    args.log_dir = args.saver_dir
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(args.saver_dir, exist_ok=True)

    _seed_everything(args.seed)
    return args
