import os
import torch
import argparse
from datetime import datetime

from utils.utils import *
from utils.method_DLF import Cls_distill_DLF
    
def main(args):
    """Create the run's output directory and execute the DLF distillation pipeline.

    The run directory is
    ``<args.save_dir>/<script name>/<args.dataset>/<timestamp>-model-<arch_s>_latent-<latent_dim>_lamb-<lamb>_seed-<seed>``.
    If CUDA is available, the device ``args.gpu_number`` is selected. Then a
    ``Cls_distill_DLF`` instance runs its setup steps, followed by the
    pretraining, MMD and distillation stages, in that order.

    Args:
        args: Parsed command-line namespace (see ``create_parser``).
    """
    # Output root is keyed by this script's file name (text before the first dot).
    script_name = os.path.basename(__file__).split(".")[0]
    dataset_dir = os.path.join(args.save_dir, script_name, args.dataset)
    os.makedirs(dataset_dir, exist_ok=True)

    # One timestamped sub-directory per run, tagged with the key hyperparameters.
    stamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    run_tag = f"{stamp}-model-{args.arch_s}_latent-{args.latent_dim}_lamb-{args.lamb}_seed-{args.seed}"
    save_dir = os.path.join(dataset_dir, run_tag)
    os.makedirs(save_dir, exist_ok=True)

    # Select the requested GPU only when CUDA is present.
    if torch.cuda.is_available():
        torch.cuda.set_device(args.gpu_number)

    distiller = Cls_distill_DLF(args, save_dir)

    # Setup steps followed by the three training stages; the order matters.
    # Each method is looked up at call time, exactly as a direct call would be.
    pipeline = (
        '_fix_seed',
        '_make_loaders',
        '_define_teacher_model',
        '_define_model_and_optimizer',
        '_pretrain',    # pretraining stage
        '_mmd_train',   # MMD stage
        '_distill',     # distillation stage
    )
    for stage in pipeline:
        getattr(distiller, stage)()
    
    
def create_parser():
    """Build the command-line parser for the DLF distillation script.

    Options are grouped into base info, learning info, teacher settings,
    student settings, and the pretraining, MMD and distillation stages.

    Returns:
        argparse.ArgumentParser: a fresh parser. Every call builds new
        default lists, so parsers never share mutable defaults.
    """
    parser = argparse.ArgumentParser(description='PyTorch Training')
    ### Base info
    parser.add_argument('--explanation', default='', type=str, help='explanation of code')
    parser.add_argument('--save_dir', default='./results', type=str)
    parser.add_argument('--data_dir', default='./dataset', type=str)
    parser.add_argument('--dataset', default='cifar10', type=str, metavar='NAME', help='dataset name (default: cifar10)')
    parser.add_argument('--seed', default=1, type=int, metavar='N', help='random seed (default: 1)')
    parser.add_argument('--gpu_number', default=0, type=int, metavar='N', help='number of gpu (default: 0)')
    parser.add_argument('--workers', default=1, type=int, metavar='N', help='number of data loading workers (default: 1)')

    ### Learning info
    parser.add_argument('--batch_size', default=128, type=int, metavar='N', help='mini-batch size (default: 128)')
    parser.add_argument('--num_ens', type=int, default=4)
    # Previously mislabelled as 'learning rate' / metavar 'LR'.
    parser.add_argument('--weight_decay', default=5e-4, type=float, metavar='WD', help='weight decay (default: 5e-4)')
    parser.add_argument('--perturb_method', default='TDiv-SDiv', type=str)

    ### Teacher architecture and info
    # The default is a list of teacher checkpoint directories. nargs='+' keeps the
    # type consistent when the option is given on the command line: without it,
    # `--teacher_dir X` would produce a plain string instead of a list.
    parser.add_argument('--teacher_dir', nargs='+', type=str,
                        default=['./teachers/cifar10/WRN_28_1(200)_seed(0)',
                                 './teachers/cifar10/WRN_28_1(200)_seed(1)',
                                 './teachers/cifar10/WRN_28_1(200)_seed(2)',
                                 './teachers/cifar10/WRN_28_1(200)_seed(3)'],
                        help='one or more teacher directories (space-separated)')
    parser.add_argument('--arch_t', metavar='ARCH', default='WRN_28_1', help='model architecture')

    ### Student architecture and info
    parser.add_argument('--arch_s', metavar='ARCH', default='WRN_16_1_DLF', help='model architecture')
    parser.add_argument('--mu_h_vec', action='store', type=int, nargs='*', default=[50],
                        help='space-separated list of ints, e.g. --mu_h_vec 50 50 (default: 50)')
    parser.add_argument('--phi_h_vec', action='store', type=int, nargs='*', default=[50],
                        help='space-separated list of ints, e.g. --phi_h_vec 50 50 (default: 50)')
    parser.add_argument('--latent_dim', default=8, type=int, help='latent dim')
    parser.add_argument('--n_samples', default=100, type=int, help='number of samples')
    parser.add_argument('--droprate', default=0, type=float, help='droprate')
    parser.add_argument('--T', default=1, type=float, help='temperature')

    ### Pretraining
    parser.add_argument('--pre_epochs', default=40, type=int, metavar='N', help='number of pretraining epochs (default: 40)')
    parser.add_argument('--pre_lr', default=0.1, type=float, metavar='LR', help='learning rate for pretraining (default: 0.1)')

    ### MMD stage (choice of the initial parameter)
    parser.add_argument('--mmd_lr', default=0.01, type=float, metavar='LR', help='learning rate for the MMD stage (default: 0.01)')
    parser.add_argument('--mmd_epochs', default=40, type=int, metavar='N', help='number of epochs for the MMD stage (default: 40)')
    parser.add_argument('--lamb', default=1, type=float, help='regularization for mmd loss')

    ### Distillation
    parser.add_argument('--lr', default=0.1, type=float, metavar='LR', help='learning rate')
    parser.add_argument('--lr_schedule', action='store', type=int, nargs='*', default=[150, 180],
                        help='space-separated list of epochs, e.g. --lr_schedule 150 180 (default: 150 180)')
    parser.add_argument('--epochs', default=200, type=int, metavar='N', help='number of total epochs to run')
    parser.add_argument('--evaluation_epochs', default=1, type=int, metavar='EPOCHS', help='evaluation frequency in epochs, 0 to turn evaluation off (default: 1)')

    return parser

    
if __name__ == '__main__':
    # Parse the CLI, echo every option (in parser order) for the run log, then train.
    options = create_parser().parse_args()
    for opt_name, opt_value in vars(options).items():
        print(f'\t [{opt_name}]: {opt_value}')
    main(options)