import argparse
import os
import shutil
from datetime import datetime

import torch 
import torch.backends.cudnn as cudnn
import torch.optim
import torch.nn as nn
import numpy as np
import train_util
import data_util
#import eval_loss_util

import models as models


# Alphabetically sorted names of the public architecture constructors exported
# by the local ``models`` package: every attribute whose name is all-lowercase,
# is not a dunder, and whose value is callable.
# NOTE(review): not referenced in the visible part of this script; presumably
# intended as the valid values for --arch (main looks the arch up in
# ``models.__dict__``) — confirm before relying on it.
model_names = sorted(
    attr_name
    for attr_name, attr_value in vars(models).items()
    if attr_name.islower()
    and not attr_name.startswith("__")
    and callable(attr_value)
)


# Learning-rate schedules accepted by --lr_sched. The schedule itself is
# applied by train_util.adjust_lr, which is not part of this file.
lr_scheds = ['wr_default']

# Command-line interface for the training script. Boolean-like options
# (--save, --save_model, --model_random) are plain strings: only the exact
# value "True" enables them.
parser = argparse.ArgumentParser(description='CIFAR-10 Training')
parser.add_argument('--epochs', default=300, type=int,
                    help='number of total epochs to run')
parser.add_argument('--start_epoch', default=0, type=int,
                    help='manual epoch number (useful on restarts)')
parser.add_argument('--frst_ann', '--fa', default=180, type=int,
                    help='first annealing time')
parser.add_argument('--snd_ann', '--sa', default=255, type=int,
                    help='second annealing time')
# The help text previously advertised "(default: 1024)" although the actual
# default is 128.
parser.add_argument('--n_batch_train', '--nbt', default=128, type=int,
                    help='train mini-batch size (default: 128)')
parser.add_argument('--n_batch_test', default=100, type=int,
                    help='test mini-batch size (default: 100)')
parser.add_argument('--path_data', default='./data', type=str,
                    help='path to store data')
parser.add_argument('--optim_choice', default="sgd", type=str,
                    help='choice of optimizer')
parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
                    help='initial learning rate')
parser.add_argument('--arch', default="vgg16", type=str,
                    help='choice of architecture')
parser.add_argument('--save', default="False", type=str,
                    help='Save or not')
parser.add_argument('--momentum', '--m', default=0.0, type=float, help='momentum')
parser.add_argument('--weight_decay', '--wd', default=5e-4, type=float,
                    help='weight decay (default: 5e-4)')
parser.add_argument('--print_freq', '-p', default=10, type=int,
                    help='print frequency (default: 10)')
parser.add_argument('--resume', default='', type=str,
                    help='path to latest checkpoint (default: none)')
parser.add_argument('--name', default='CIFAR-10-VGG16', type=str,
                    help='name of experiment')
parser.add_argument('--dataset', choices=["cifar10", "cifar100"], default="cifar10",
                    help='cifar10 or cifar100')
parser.add_argument('--lr_sched', choices=lr_scheds, default='wr_default',
                    help=' | '.join(lr_scheds))
parser.add_argument('--seed', '-s', default=0, type=int,
                    help='seed (default: 0)')
parser.add_argument('--save_model', default="False", type=str,
                    help='Save Model')
parser.add_argument('--model_random', default="False", type=str,
                    help='Model Random')

# NOTE(review): not referenced anywhere else in this script; the results
# directory paths are built inline. Kept in case external code imports it.
aug = "/results"

def _results_dir(*parts):
    """Return ``<cwd>/results/<parts...>``, creating it (and parents) if absent."""
    path = os.path.join(os.getcwd(), "results", *parts)
    # exist_ok avoids the check-then-create race; makedirs also creates any
    # missing intermediate directories.
    os.makedirs(path, exist_ok=True)
    return path


def _optimizer_label(optimizer, n_batch_train):
    """Return the short optimizer tag used in the summary and in file names.

    The base tag is "GD" when the train batch size is exactly 1024 (a
    hard-coded convention of this script) and "SGD" otherwise. "M" is
    appended when any parameter group has a non-zero 'momentum'. Groups
    without a 'momentum' key count as momentum-free instead of raising
    KeyError.
    """
    label = "GD" if n_batch_train == 1024 else "SGD"
    if any(group.get('momentum', 0) != 0 for group in optimizer.param_groups):
        label += "M"
    return label


def main():
    """Train a CIFAR model configured from the command line.

    Workflow:
    1. Parse and print the CLI arguments, then seed torch and numpy.
    2. Build the data loaders (data_util) and the model (models).
    3. With ``--model_random True``, save the untrained model under
       ``results/<arch>/`` and return.
    4. Otherwise train for ``range(start_epoch, epochs)``, recording the
       values returned by ``train_util.train_loop`` and
       ``train_util.validate`` each epoch, and print a summary.
    5. With ``--save_model True``, save the trained model under
       ``results/<arch>/``.
    6. With ``--save True``, save both per-epoch lists as .npy files under
       ``results/<dataset>/<arch>/{train,test}/``.
    """
    args = parser.parse_args()
    for arg in vars(args):
        print(arg, " : ", getattr(args, arg))

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    # Honour --path_data. It was previously overwritten with "./data/" for
    # both datasets. Joining with "" guarantees a trailing separator, so the
    # default still becomes "./data/" exactly as before.
    args.path_data = os.path.join(args.path_data, "")

    train_loader, val_loader = data_util.load_data(args.n_batch_train,
                                                   args.n_batch_test,
                                                   args.dataset,
                                                   args.path_data)

    print("=> creating model '{}'".format(args.arch))
    num_classes = 10 if args.dataset == "cifar10" else 100
    model = models.__dict__[args.arch](num_classes=num_classes)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = model.to(device)

    # --model_random True: save the freshly initialised model and stop.
    if args.model_random == "True":
        path = _results_dir(args.arch)
        torch.save(model, os.path.join(path, "result_random_" + str(args.seed)))
        return

    cudnn.benchmark = True

    # Put the loss on the same device as the model. The previous
    # unconditional .cuda() crashed on machines without a GPU.
    criterion = nn.CrossEntropyLoss(reduction='mean').to(device)

    optim_hparams = {
        'initial_lr': args.lr,
        'momentum': args.momentum,
        'weight_decay': args.weight_decay,
    }

    lr_hparams = {
        'initial_lr': args.lr,
        'lr_sched': args.lr_sched,
        'frst_ann': args.frst_ann,
        'snd_ann': args.snd_ann,
    }

    optimizer = train_util.create_optimizer(model, args.optim_choice,
                                            optim_hparams)

    train_tab = []
    test_tab = []
    # Stays None when the epoch range is empty (start_epoch >= epochs). This
    # avoids a NameError in the summary below.
    val_acc = None

    for epoch in range(args.start_epoch, args.epochs):
        # Called for its effect on the optimizer; its return value was never
        # used. Presumably it sets this epoch's LR (epochs passed 1-based) —
        # defined in train_util.
        train_util.adjust_lr(optimizer, args.optim_choice,
                             epoch + 1,
                             lr_hparams)

        print("Epoch" + str(epoch))
        for param_group in optimizer.param_groups:
            print("LR: " + str(param_group['lr']))
            # .get: not every torch optimizer exposes a 'momentum' key.
            print("mom: " + str(param_group.get('momentum')))

        train_loss = train_util.train_loop(
            train_loader,
            model,
            criterion,
            args.optim_choice,
            optimizer,
            epoch,
            device)

        val_acc = train_util.validate(
            val_loader,
            model,
            criterion,
            epoch,
            device)

        train_tab.append(train_loss)
        test_tab.append(val_acc)

    opt_label = _optimizer_label(optimizer, args.n_batch_train)

    print("\n")
    print("Final accuracy: {}".format(val_acc))
    print("Seed: {}".format(args.seed))
    print("Dataset: {}".format(args.dataset))
    print("Architecture: {}".format(args.arch))
    print("Optimization algorithm: {}".format(opt_label))
    print("LR: {}; B: {}; M: {}; 1st anneal: {}; 2nd anneal: {}; WD: {}".format(
        args.lr, args.n_batch_train, args.momentum, args.frst_ann,
        args.snd_ann, args.weight_decay))

    # --save_model True: save the trained model. There is deliberately no
    # early return here. One used to sit here and silently skipped the
    # --save True curve export whenever both flags were given.
    if args.save_model == "True":
        path = _results_dir(args.arch)
        torch.save(model, os.path.join(path, "result_trained_" + str(args.seed)))

    # --save True: export the per-epoch train/validation lists as .npy files.
    if args.save == "True":
        save_dir_train = _results_dir(args.dataset, args.arch, "train")
        save_dir_test = _results_dir(args.dataset, args.arch, "test")

        str_save = "{}_LR{}_WD{}_FA{}_SA{}.npy".format(
            opt_label, args.lr, args.weight_decay,
            args.frst_ann, args.snd_ann)

        np.save(os.path.join(save_dir_train, "train_" + str_save),
                np.array(train_tab))
        np.save(os.path.join(save_dir_test, "test_" + str_save),
                np.array(test_tab))

            
# Run training only when executed as a script, not when this module is
# imported (e.g. to reuse `parser` or `main`).
if __name__ == "__main__":
    main()
