
import argparse
import os
import sys

def _str2bool(value):
    """Parse a command-line boolean such as 'true'/'false', 'yes'/'no' or '1'/'0'.

    argparse's ``type=bool`` calls ``bool(str)``, which is True for every
    non-empty string (including "False"), so an explicit converter is needed.

    Raises:
        argparse.ArgumentTypeError: if *value* is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


parser = argparse.ArgumentParser()

# Generic
parser.add_argument("--method", type=str, default='hybrid',
                    help="training method: 'sl', 'scl' or 'hybrid'")
parser.add_argument("--device", type=str, default="cuda:0")
# NOTE(review): both --workers and --num_workers are defined; confirm which one
# the data loaders actually read.
parser.add_argument("--workers", type=int, default=0)
parser.add_argument("--bs", type=int, default=128)
parser.add_argument("--epochs", type=int, default=50)
parser.add_argument("--epochs_clf", type=int, default=50,
                    help='number of epochs for the linear classifier of scl and hybrid')

parser.add_argument('--epoch_t', type=int, default=50,
                    help='epoch threshold N_0 (default: 50)')

# Model
parser.add_argument("--backbone", type=str, default='resnet18')
parser.add_argument("--scratch", action='store_true')
parser.add_argument("--test_mode", action='store_true')

# Optimizer
parser.add_argument("--wd", type=float, default=1e-4)
parser.add_argument("--lr", type=float, default=1e-4)
parser.add_argument("--lr2", type=float, default=1e-1)

# Parameters
parser.add_argument("--tau", type=float, default=0.06)
parser.add_argument("--alpha", type=float, default=0.5)

# Dataset
# 'path' selects a custom dataset located in --data_folder. It must be an
# accepted choice; otherwise argparse rejects it and any handling of
# `args.dataset == 'path'` after parsing can never run.
parser.add_argument('--dataset', type=str, default='cifar10',
                    choices=['cifar10', 'cifar100', 'tiny200', 'mnist', 'BuS', 'path'],
                    help="dataset name, or 'path' for a custom dataset in --data_folder "
                         "(then --mean and --std are required too)")
parser.add_argument('--mean', type=str, help='mean of dataset in path in form of str tuple')
parser.add_argument('--std', type=str, help='std of dataset in path in form of str tuple')
parser.add_argument('--data_folder', type=str, default=None, help='path to custom dataset')
parser.add_argument('--size', type=int, default=224, help='parameter for RandomResizedCrop')
parser.add_argument('--num_workers', type=int, default=4)
# type=bool would turn every non-empty string (even "False") into True.
parser.add_argument('--drop_train_last', type=_str2bool, default=True,
                    help='boolean flag (true/false, yes/no, 1/0)')

parser.add_argument('--save_freq', type=int, default=5, help='save frequency')

parser.add_argument('--out_dir', type=str, default="results")


# Jacobian regularization
parser.add_argument('--jac_sel', type=int, default=1, help='select the implemention of jacobian regularization')
parser.add_argument('--jac_reg', action="store_true", default=False, help='use jacobian regularization')
parser.add_argument('--lambda_JR', type=float, default=0.01)
parser.add_argument('--jac_reg_projector', action="store_true", default=False, help='use jacobian regularization for projector')

# Attack
parser.add_argument('--attack', action="store_true", default=False)
parser.add_argument('--nb_iters', type=int, default=50, help="number of iterations for attack")
parser.add_argument('--nb_tans', type=int, default=1, help='number of tangent recomputation')
parser.add_argument('--eps_iter', type=float, default=1.0, help="step size for attack")
parser.add_argument('--adv_rate', type=float, default=0.8, help='rate of adversarial examples')

parser.add_argument('--ckpt', type=str, default=None)

print("Command line arguments:", sys.argv[1:])

args = parser.parse_args()

# A custom dataset ('--dataset path') needs its folder plus the dataset mean
# and std. parser.error() is used instead of assert so the check still runs
# under `python -O` and the user is told which flags are missing.
if args.dataset == 'path':
    missing_flags = [
        flag
        for flag, value in (('--data_folder', args.data_folder),
                            ('--mean', args.mean),
                            ('--std', args.std))
        if value is None
    ]
    if missing_flags:
        parser.error(f"--dataset path requires {', '.join(missing_flags)}")
    # Name the dataset after the folder's last path component; normpath drops
    # a trailing separator so 'data/custom/' yields 'custom' rather than ''.
    dataset_name = os.path.basename(os.path.normpath(args.data_folder))
else:
    dataset_name = args.dataset


# Default data location when --data_folder was not supplied.
args.data_folder = './datasets/' if args.data_folder is None else args.data_folder

# MNIST always uses the dedicated 'Net4mnist' backbone, overriding --backbone.
args.backbone = 'Net4mnist' if args.dataset == "mnist" else args.backbone

# Epoch tag for the run name: only the 'scl' method also records --epochs_clf.
if args.method == 'scl':
    ep = f"{args.epochs}+{args.epochs_clf}"
else:
    ep = f"{args.epochs}"

# Jacobian-regularization tag: 'True_lambdaJR<lambda_JR>' when enabled,
# otherwise 'False'; '+jacreg4proj' is appended when --jac_reg_projector is set.
jac_reg = str(args.jac_reg)
if args.jac_reg:
    jac_reg += f"_lambdaJR{args.lambda_JR}"
if args.jac_reg_projector:
    jac_reg += "+jacreg4proj"

# Run name encoding the main hyperparameters.
args.trial_name = (
    f"{dataset_name}_size{args.size}_lr{args.lr}_lr2{args.lr2}"
    f"_decay_{args.wd}_bsz{args.bs}_temp{args.tau}"
    f"_epoch{ep}_jacreg{jac_reg}_sel{args.jac_sel}"
)

# Outputs go under ./<out_dir>-<method>-<backbone>/<trial_name>; attack runs
# get an extra subdirectory keyed by the adversarial settings.
args.save_folder = f"./{args.out_dir}-{args.method}-{args.backbone}/{args.trial_name}"
if args.attack:
    args.save_folder = (
        f"{args.save_folder}/adv{args.adv_rate}_iter{args.nb_iters}_eps{args.eps_iter}"
    )

# Create the output directory. exist_ok=True already makes this a no-op when
# the directory exists, so no separate os.path.exists() pre-check is needed.
os.makedirs(args.save_folder, exist_ok=True)
print("args.save_folder:", args.save_folder)