import os
import torch
import argparse
from datetime import datetime

from utils.utils import *
from utils.method_DLF import Reg_distill_DLF
    
def main(args):
    """Run the full pipeline for one configuration: pretrain, MMD stage, distill.

    A fresh, timestamped output directory is created for the run and handed
    to ``Reg_distill_DLF`` together with ``args``; the stages are then
    invoked in a fixed order.

    Args:
        args: Parsed command-line namespace. This function itself reads
            ``save_dir``, ``dataset``, ``arch_s``, ``h_vec``, ``latent_dim``,
            ``lamb``, ``seed`` and ``gpu_number``; the whole namespace is
            forwarded to ``Reg_distill_DLF``.
    """
    ### Output layout: <save_dir>/<script name>/<dataset>/<run name>
    script_name = os.path.basename(__file__).split(".")[0]
    dataset_dir = os.path.join(args.save_dir, script_name, args.dataset)
    os.makedirs(dataset_dir, exist_ok=True)

    ### The run name encodes the start time and the main hyper-parameters
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    model_tag = '_'.join([args.arch_s, *map(str, args.h_vec)])
    run_name = f"{timestamp}-model-{model_tag}_latent-{args.latent_dim}_lamb-{args.lamb}_seed-{args.seed}"
    run_dir = os.path.join(dataset_dir, run_name)
    os.makedirs(run_dir, exist_ok=True)

    ### Select the GPU device (skipped when CUDA is unavailable)
    if torch.cuda.is_available():
        torch.cuda.set_device(args.gpu_number)

    ### Build the distillation driver and prepare seed, data, teacher, student
    distiller = Reg_distill_DLF(args, run_dir)
    distiller._fix_seed()
    distiller._make_loaders()
    distiller._define_teacher_model()
    distiller._define_model_and_optimizer()

    ### Pretraining
    distiller._pretrain()
    ### Choice of the initial parameter
    distiller._mmd_train()
    ### Distillation
    distiller._distill()
    

def create_parser():
    """Build the command-line parser for this training script.

    Every option keeps its original flag, type, default and choices. Help
    strings that mention a default render it through argparse's
    ``%(default)s`` placeholder, so the text shown by ``--help`` cannot
    drift away from the value that is actually used.

    Returns:
        argparse.ArgumentParser: parser with all options registered.
    """
    parser = argparse.ArgumentParser(description='PyTorch Training')
    ### Base Info
    parser.add_argument('--explanation', default='', type=str, help='explanation of code')
    parser.add_argument('--save_dir', default='./results', type=str)
    parser.add_argument('--data_dir', default='./dataset', type=str)
    parser.add_argument('--dataset', default='Boston', type=str, metavar='NAME', help='dataset')
    parser.add_argument('--seed', default=1, type=int, metavar='N', help='number of seed')
    parser.add_argument('--data_seed', default=1, type=int, metavar='N', help='number of seed')
    parser.add_argument('--gpu_number', default=1, type=int, metavar='N',
                        help='index of the GPU to use (default: %(default)s)')
    parser.add_argument('--workers', default=0, type=int, metavar='N',
                        help='number of data loading workers (default: %(default)s)')

    ### Learning info
    parser.add_argument('--batch_size', default=100, type=int, metavar='N',
                        help='mini-batch size (default: %(default)s)')
    parser.add_argument('--num_ens', default=50, type=int, help='ensemble_size')
    parser.add_argument('--ratio_valid', default=0.2, type=float, help='ratio-valid')
    parser.add_argument('--weight_decay', default=1e-4, type=float, metavar='WD', help='weight-decay')

    ### Teacher architecture and info
    parser.add_argument('--teacher_dir', default='./teachers/Boston/MLP_100_100', type=str)
    parser.add_argument('--arch_t', metavar='ARCH', default='MLP', help='model architecture corresponding teacher')
    parser.add_argument('--teacher_h_vec', action='store', type=int, nargs='*', default=[100, 100],
                        help='space-separated integers, e.g. --teacher_h_vec 100 100 (default: %(default)s)')

    ### Student architecture
    parser.add_argument('--arch_s', metavar='ARCH', default='MLP_DLF', help='model architecture corresponding student')
    parser.add_argument('--activation', type=str, default='ReLU',
                        choices=['ReLU', 'LeakyReLU', 'SiLU', 'hardtanh', 'Softplus'])
    parser.add_argument('--h_vec', action='store', type=int, nargs='*', default=[50])
    parser.add_argument('--mu_h_vec', action='store', type=int, nargs='*', default=[50])
    parser.add_argument('--phi_h_vec', action='store', type=int, nargs='*', default=[50])
    parser.add_argument('--latent_dim', default=10, type=int, help='latent_dim')

    ### Pretraining
    parser.add_argument('--pre_epochs', default=300, type=int, metavar='N', help='number of total epochs to run')
    parser.add_argument('--pre_lr', default=1e-2, type=float, metavar='LR', help='learning rate')

    ### Choice of the initial parameter
    parser.add_argument('--mmd_epochs', default=300, type=int, metavar='N', help='number of total epochs to run')
    parser.add_argument('--mmd_lr', default=1e-2, type=float, metavar='LR', help='learning rate')
    parser.add_argument('--lamb', default=1., type=float, help='MMD')

    ### Distillation
    parser.add_argument('--lr', default=1e-2, type=float, metavar='LR', help='learning rate')
    parser.add_argument('--lr_schedule', action='store', type=int, nargs='*', default=[300],
                        help='space-separated integers, e.g. --lr_schedule 100 200 (default: %(default)s)')
    parser.add_argument('--epochs', default=400, type=int, metavar='N', help='number of total epochs to run')
    parser.add_argument('--evaluation_epochs', default=1, type=int, metavar='idxs',
                        help='evaluation frequency in epochs')
    parser.add_argument('--print_freq', default=100, type=int, metavar='N',
                        help='print frequency (default: %(default)s)')

    return parser

if __name__ == '__main__':
    # Parse the command line, echo every resolved option so the run
    # configuration is visible in the console output, then start the run.
    args = create_parser().parse_args()
    options = vars(args)
    for key in options:
        print(f'\t [{key}]: {options[key]}')
    main(args)
    