import argparse
import os
import torch
import numpy as np
import random
import string
import datetime

def parse_arg() -> argparse.Namespace:
    """Parse command-line options and prepare the process for training.

    Besides the options declared below, the returned namespace carries three
    derived attributes:

    * ``save_name_pre`` -- underscore-joined tag built from the main
      hyper-parameters (printed to stdout).
    * ``device`` -- ``torch.device("cuda")``.
    * ``log_dir`` -- alias of ``saver_dir``.

    Side effects: prints ``save_name_pre``, creates ``saver_dir`` when
    ``local_rank`` is 0 or -1 (the default), seeds ``numpy``, ``random`` and
    ``torch`` (CPU and CUDA) with ``args.seed``, and switches cuDNN to
    deterministic, non-benchmark mode.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(description='PyTorch version of SparseCLv2.')

    # --- job parameters ---
    parser.add_argument('--use_random', action='store_true', help='whether to randomly generate seed')
    parser.add_argument('--seed', type=int, default=12, metavar='S', help='random seed (also job id)')
    parser.add_argument('--model_name', default='SparseCL', help='the name of models', choices=['SparseCL'])
    parser.add_argument('--workers', default=8, type=int, metavar='N', help='number of data loader workers')

    # --- optimisation ---
    parser.add_argument('--batch_size', type=int, default=2048, help="the size of batch samples")
    parser.add_argument('--epochs', type=int, default=100, help="training epochs")
    parser.add_argument('--weight_decay', default=1e-6, type=float, metavar='W', help='weight decay')
    parser.add_argument('--lr', default=0.05, type=float, metavar='LR', help='initial (base) learning rate', dest='lr')
    parser.add_argument('--linear_lr', default=0.1, type=float, metavar='LR', help='learning rate for linear evaluation', dest='linear_lr')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum')
    parser.add_argument('--linear_epochs', type=int, default=100, help="training epochs for linear evaluation")
    parser.add_argument('--warmup_epochs', default=10, type=int, metavar='N', help='number of warmup epochs')
    parser.add_argument('--eta', type=float, default=0.02, help="the hyperparameter of lars optimizer.")

    # --- model / loss ---
    parser.add_argument('--num_class', type=int, default=1000, help="the categories of different labels")
    # NOTE(review): this help text is identical to --feature_dim's; confirm
    # what --mlp_dim actually controls and reword.
    parser.add_argument('--mlp_dim', type=int, default=4096, help="dimension of contrastive representation")
    parser.add_argument('--feature_dim', type=int, default=4096, help="dimension of contrastive representation")
    parser.add_argument('--crop_min', default=0.08, type=float, help='minimum scale for random cropping (default: 0.08)')
    parser.add_argument('--temperature', type=float, default=0.1, help="the temperature for self-supervised representation")
    parser.add_argument('--threshold', type=float, default=0.7, help="the threshold of cosine similarity for positive pairs")
    parser.add_argument('--alpha', type=float, default=0.02, help="the hyperparameter of sparsity term.")
    parser.add_argument('--base_momentum', default=0.995, type=float, help='moco momentum of updating key encoder (default: 0.995)')

    # --- paths ---
    parser.add_argument('--train_data_dir', default="/root/data/ImageNet-1k/train/", type=str, help='the directory of train dataset')
    parser.add_argument('--val_data_dir', default="/root/data/ImageNet-1k/val/", type=str, help='the directory of val dataset')
    parser.add_argument('--saver_dir', default="/home/sunset/code/SparseCLv2/ImageNet-1k/saver/", type=str, help='the directory for saved logs')
    parser.add_argument('--checkpoint_dir', default="/home/sunset/code/SparseCLv2/ImageNet-1k/checkpoint/", type=str, help='the directory for model checkpoints')
    parser.add_argument('--results', default="/home/sunset/code/SparseCLv2/ImageNet-1k/results/", type=str, metavar='DIR', help='path to results directory')

    # --- distributed training ---
    parser.add_argument('--world-size', default=-1, type=int, help='number of nodes for distributed training')
    parser.add_argument('--local_rank', default=-1, type=int, help='local rank for distributed training')
    parser.add_argument('--nnodes', default=-1, type=int, help='number of nodes for distributed training')
    parser.add_argument('--node_rank', default=-1, type=int, help='node rank for distributed training')
    parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str, help='url used to set up distributed training')
    parser.add_argument('--dist-backend', default='nccl', type=str, help='distributed backend')
    parser.add_argument('--master_address', default='127.0.0.1', type=str, help='address of the master node for distributed training')
    parser.add_argument('--master_port', default='29500', type=str, help='port of the master node for distributed training')

    args = parser.parse_args()

    # NOTE(review): the tag is built before --use_random replaces the seed, so
    # it always carries the seed given on the command line. Kept as-is so that
    # every process derives the same name; confirm this is intended.
    args.save_name_pre = '{}_{}_{}_{}_{}_{}_{}_{}_{}_{}_{}_{}'.format(args.model_name, args.seed, args.mlp_dim, args.feature_dim, args.batch_size, args.epochs, args.lr, args.crop_min, args.eta, args.temperature, args.threshold, args.alpha)
    print("save_name_pre:", args.save_name_pre)

    # --- post-process args ---
    if args.use_random:
        # Integer bound: random.randint rejects float arguments such as 1e8
        # on Python 3.12+ (deprecated since 3.10).
        args.seed = random.randint(0, 10**8)
    args.device = torch.device("cuda")

    args.log_dir = args.saver_dir
    # Rank 0 creates the directory in distributed runs; -1 is the default for
    # a non-distributed run, which needs the directory as well. exist_ok
    # avoids the check-then-create race when several processes start together.
    if args.local_rank in (-1, 0):
        os.makedirs(args.saver_dir, exist_ok=True)

    # --- fix the seed for every RNG in use ---
    np.random.seed(args.seed)
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    # Trade cuDNN autotuning for run-to-run determinism.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    return args
