import os
import argparse


# RMIA reference probability matrices. The clip/noBN variant is substituted
# only when --method clip --mode noBN is requested AND --reference_mat was not
# given explicitly, so an explicit CLI value is never silently overwritten.
_DEFAULT_REFERENCE_MAT = '/class_unlearn/logs/correct/scratch/cifar_unnorm/unlearn/reference/ResNet18_orig__m128_d60000_s0_160_prob_matrix_logits_onehot.pt'
_CLIP_NOBN_REFERENCE_MAT = '/class_unlearn/logs/correct/scratch/cifar_unnorm/unlearn/reference/ResNet18_fastclip_cs100_noBN_m128_d60000_s0_100_prob_matrix_logits_onehot.pt'

parser = argparse.ArgumentParser(description='Submit SLURM (sbatch) jobs for training, unlearning and MIA experiments')
# Restricting job_type up front turns a typo into an immediate usage error
# instead of a per-job failure later on.
parser.add_argument('--job_type', default='train',
                    choices=['train', 'conf', 'RMIA', 'RMIA_ref', 'unlearn', 'ulira'],
                    help='type of job to submit')
parser.add_argument('--dataset', default='cifar', help='dataset')
parser.add_argument('--model', default='ResNet18', help='Deep Learning model to train')
parser.add_argument('--method', default='orig', help='clipping method (use orig for no clipping)')
parser.add_argument('--mode', default='wBN', help="'wBN', 'noBN', or 'all' for both")
parser.add_argument('--seed', default=-1, type=int,
                    help='-1: seeds 1, 10, 100; -32/-64/-128: seeds 0..N-1; otherwise that single seed')
parser.add_argument('--epoch', default=10, type=int)
parser.add_argument('--lr', default=0.01, type=float)
parser.add_argument('--LRsteps', default=40, type=int, help='LR scheduler step')
parser.add_argument('--widen_factor', default=1, type=int, help='widen factor for WideResNet')

parser.add_argument('--norm_cond', default='unnorm', help='unnorm or norm for transform')
parser.add_argument('--remain', default='use', type=str)
parser.add_argument('--salun_ratio', default='0.5', type=str, help='ratio of masking in salun')

parser.add_argument('--unlearn_method', default='retrain', type=str)
parser.add_argument('--unlearn_indices', default=None, type=str)
parser.add_argument('--model_path', default=None, type=str)
# Read by the RMIA / RMIA_ref submission branches; previously undeclared,
# which made those branches fail with AttributeError.
parser.add_argument('--mask_path', default=None, type=str,
                    help='mask path forwarded to the RMIA / RMIA_ref job scripts (None if no mask)')
parser.add_argument('--model_count', default=1, type=int,
                    help='1 uses --model_path as given; N>1 sweeps the first N source seeds of 1, 10, 100')
parser.add_argument('--trials', default=1, type=int)

# '~' is not expanded here; it is left in the command string for the shell
# that os.system() starts.
parser.add_argument('--inclusion_mat_path', default='~/keep_files/keep_m128_d60000_s0.csv', type=str)

parser.add_argument('--reference_mat', default=None, type=str,
                    help='RMIA reference matrix (default depends on --method/--mode)')

args = parser.parse_args()
print('!!!!!!!! norm cond: ', args.norm_cond)
print('!!!!!!!! use_remain: ', args.remain)

# Forwarded verbatim as the last argument of job_rmia_submit.slurm; the
# meaning of -1 is defined by that script (presumably "no limit" — confirm).
unlearn_count = -1

# Resolve the reference matrix only when the user did not supply one.
if args.reference_mat is None:
    if args.method == 'clip' and args.mode == 'noBN':
        args.reference_mat = _CLIP_NOBN_REFERENCE_MAT
        print('clip noBN reference mat in use!')
    else:
        args.reference_mat = _DEFAULT_REFERENCE_MAT

# Job types this script knows how to turn into an sbatch command.
_KNOWN_JOB_TYPES = ('train', 'conf', 'RMIA', 'RMIA_ref', 'unlearn', 'ulira')

# Models accepted by --model.
_KNOWN_MODELS = ('ResNet18', 'simpleConv', 'wideResnet', 'VGG')

# Special --seed values that expand to seeds 0..N-1 instead of one seed.
_SEED_SWEEPS = {-32: 32, -64: 64, -128: 128}

# Source-model seeds available for --model_count > 1.
_SOURCE_MODEL_SEEDS = [1, 10, 100]

# Per-script templates for the two job types that share an argument layout.
_UNLEARN_SCRIPTS = {
    'unlearn': 'job_unlearn_submit.slurm',
    'ulira': 'job_ulira_submit.slurm',
}


def _expand_seeds(seed):
    """Return the list of seeds to submit for the --seed value *seed*.

    -1 selects [1, 10, 100]; -32 / -64 / -128 select 0..31 / 0..63 / 0..127;
    any other value is submitted as the only seed.
    """
    if seed == -1:
        return [10 ** i for i in range(3)]
    if seed in _SEED_SWEEPS:
        return list(range(_SEED_SWEEPS[seed]))
    return [seed]


def _expand_methods(method):
    """Return the list of clipping methods to submit for --method *method*."""
    if method == 'all':
        return ['orig', 'fastclip_tlower_cs100']
    if method == 'clip':
        return ['clip']
    if method.startswith('fast'):
        return ['fastclip_tlower_cs100', 'fastclip_tlower_cs50']
    return [method]


def _rmia_model_path(args, sms, mask_path):
    """Rewrite args.model_path so it points at the model for source seed *sms*.

    The unlearning method is the path component right after 'unlearn/'.
    For 'retrain', the last path component is taken as the model name and
    '<sms>/' is appended directly after it (anything following it is dropped).
    Otherwise the model name is the 3rd-from-last component (no mask) or the
    4th-from-last (with mask); its last character is replaced by *sms* and
    the remainder of the path is kept.

    Raises:
        AttributeError: if args.model_path is None.
        IndexError: if args.model_path does not contain 'unlearn/' or has
            too few components.
        ValueError: if the selected model name is empty (str.split('')).
    """
    # Local instead of mutating args.unlearn_method as a side effect.
    unlearn_method = args.model_path.split('unlearn/')[1].split('/')[0]

    if unlearn_method == 'retrain':
        print('model is retrain!')
        model_name = args.model_path.split('/')[-1]
    elif mask_path is None:
        print('mask path is None')
        model_name = args.model_path.split('/')[-3]
    else:
        model_name = args.model_path.split('/')[-4]

    print('model name: ', model_name)

    parts = args.model_path.split(model_name)
    if unlearn_method == 'retrain':
        model_path = parts[0] + model_name + f'{sms}/'
    else:
        # Assumes the model directory name ends in a single-character seed
        # suffix that is swapped for *sms* — TODO confirm naming scheme.
        model_path = parts[0] + model_name[:-1] + f'{sms}/' + parts[1]

    print('model path: ', model_path)
    return model_path


def _build_command(args, method, mode, seed, convsn, sms, mask_path, unlearn_count):
    """Return the sbatch command line for one (method, mode, seed, convsn, sms) job.

    *sms* < 0 means "use --model_path as given"; otherwise the source-model
    seed is spliced into (RMIA) or appended to (unlearn/ulira) the path.
    Values that are None are rendered as the literal string 'None'.

    Raises:
        ValueError: if args.job_type is not one of _KNOWN_JOB_TYPES.
        Exception: whatever _rmia_model_path raises for malformed paths.
    """
    job_type = args.job_type

    if job_type == 'train':
        return f"sbatch job_submit.slurm {method} {mode} {seed} {convsn} {args.model} {args.lr} {args.dataset} {args.unlearn_method}"

    if job_type == 'conf':
        return f"sbatch job_conf_submit.slurm {method} {mode} {convsn} {args.widen_factor} {args.model} {args.lr} {args.dataset} {args.model_path} {args.epoch} {args.unlearn_indices}"

    if job_type == 'RMIA':
        # NOTE(review): the sms < 0 form omits mask_path while the sms >= 0
        # form passes it before lr, so the positional arguments received by
        # job_rmia_submit.slurm differ between the two — confirm which layout
        # the script expects.
        if sms < 0:
            return f"sbatch job_rmia_submit.slurm {method} {mode} {seed} {args.model} {args.dataset} {args.unlearn_indices} {args.model_path} {args.epoch} {args.lr} {args.LRsteps} {args.trials} {args.inclusion_mat_path} {args.reference_mat} {args.norm_cond} {unlearn_count}"
        model_path = _rmia_model_path(args, sms, mask_path)
        return f"sbatch job_rmia_submit.slurm {method} {mode} {seed} {args.model} {args.dataset} {args.unlearn_indices} {model_path} {args.epoch} {mask_path} {args.lr} {args.LRsteps} {args.trials} {args.inclusion_mat_path} {args.reference_mat} {args.norm_cond} {unlearn_count}"

    if job_type == 'RMIA_ref':
        return f"sbatch job_rmia_ref_submit.slurm {method} {mode} {seed} {args.model} {args.dataset} {args.unlearn_indices} {args.model_path} {args.epoch} {mask_path} {args.lr} {args.LRsteps} {args.trials} {args.inclusion_mat_path} {args.norm_cond}"

    if job_type in _UNLEARN_SCRIPTS:
        script = _UNLEARN_SCRIPTS[job_type]
        model_path = args.model_path if sms < 0 else f'{args.model_path}{sms}'
        return f"sbatch {script} {method} {mode} {seed} {convsn} {args.model} {args.lr} {args.dataset} {args.unlearn_indices} {model_path} {args.unlearn_method} {args.LRsteps} {args.norm_cond} {args.epoch} {args.remain} {args.salun_ratio}"

    raise ValueError(f'unknown job_type: {job_type!r}')


if __name__ == '__main__':
    # Fail fast instead of failing once per job inside the submission loop.
    if args.job_type not in _KNOWN_JOB_TYPES:
        raise ValueError(f'job_type must be one of {", ".join(_KNOWN_JOB_TYPES)}')

    seed_list = _expand_seeds(args.seed)

    # convsn values to sweep for clipped methods (passed straight to the
    # job scripts; presumably the conv spectral-norm target — confirm).
    convsn_list = [1.]

    print('method: ', args.method)
    methods = _expand_methods(args.method)

    if args.model not in _KNOWN_MODELS:
        raise ValueError(f'model must be one of {", ".join(_KNOWN_MODELS)}')

    # [-1] means "use --model_path as given". Counts above the number of
    # available seeds are capped by the slice.
    if args.model_count == 1:
        source_model_seeds = [-1]
    else:
        source_model_seeds = _SOURCE_MODEL_SEEDS[:args.model_count]

    modes = ['wBN', 'noBN'] if args.mode == 'all' else [args.mode]

    # getattr keeps this working even if the parser does not declare --mask_path.
    mask_path = getattr(args, 'mask_path', None)

    for mode in modes:
        for method in methods:
            if method == 'orig':
                # Unclipped runs have no convsn to sweep.
                convsn_values = [1.0]
            else:
                convsn_values = convsn_list
                print('method: ', method)
            for convsn in convsn_values:
                for seed in seed_list:
                    for sms in source_model_seeds:
                        try:
                            command = _build_command(args, method, mode, seed, convsn,
                                                     sms, mask_path, unlearn_count)
                        except Exception as e:
                            # Best effort: one malformed job must not stop the sweep,
                            # but say which job was skipped and why.
                            print(f'skipping job (mode={mode}, method={method}, '
                                  f'convsn={convsn}, seed={seed}, sms={sms}): {e!r}')
                            continue

                        print(command)
                        status = os.system(command)
                        if status != 0:
                            print(f'command exited with status {status}: {command}')
