from __future__ import print_function
import logging
import os
import sys
import datetime
import time 
import random
import numpy as np
import argparse
import copy
import pickle
from tqdm import tqdm, trange
from collections import defaultdict

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable

from utils import *
from ekfac import *

from asam import ASAM, SAM, IGSreg

# Tiny positive constant. It is not referenced by any code visible in this
# file; presumably a numerical-stability epsilon used elsewhere -- TODO confirm
# before removing.
EPS = 1e-24

def seed_everything(seed: int = 42):
    """Seed the Python, NumPy and PyTorch random number generators.

    Besides seeding, the function exports ``PYTHONHASHSEED`` to the process
    environment and switches the cuDNN autotuner (``cudnn.benchmark``) on.

    Args:
        seed: Value handed to every generator (default: 42).
    """
    os.environ["PYTHONHASHSEED"] = str(seed)
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)
    # cuDNN's deterministic flag is deliberately left untouched here; only the
    # autotuner is enabled.
    torch.backends.cudnn.benchmark = True  # type: ignore
    
def set_logger(loggingFileName):
    """Reset the root logger so that it writes INFO records to a single file.

    Handlers already attached to the root logger are detached *and closed*
    before the new file handler is installed, so calling this function
    repeatedly does not leave earlier log files open.

    Args:
        loggingFileName: Path of the log file to write to.

    Returns:
        The configured root logger.
    """
    logger = logging.getLogger()
    # Iterate over a copy: removeHandler() mutates logger.handlers.
    for stale_handler in list(logger.handlers):
        logger.removeHandler(stale_handler)
        # Clearing the list alone would leak the file descriptor of a
        # previously installed FileHandler.
        stale_handler.close()

    logger.setLevel(logging.INFO)

    formatter = logging.Formatter('[%(asctime)s] - %(message)s', '%Y/%m/%d %H:%M:%S')

    file_handler = logging.FileHandler(loggingFileName)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    return logger

class SGD(SAM):
    """Wrapper whose descent step is a plain step of the wrapped optimizer.

    Only ``descent_step`` is overridden; everything else is inherited from
    ``SAM`` (defined in the ``asam`` module, not visible here).
    """

    @torch.no_grad()
    def descent_step(self):
        """Step the wrapped optimizer once, then clear its gradients."""
        base_optimizer = self.optimizer
        base_optimizer.step()
        base_optimizer.zero_grad()
        
def get_optimizer(method, model, lr, momentum, weight_decay, rho, eta, lr_scheduler, preconditioner, normalize):
    """Build the optimizer wrapper that implements the requested training method.

    A ``torch.optim.SGD`` base optimizer is created over ``model.parameters()``
    and wrapped in the class matching ``method``.

    Args:
        method: One of ``'SAM'``, ``'ASAM'``, ``'SGD'``, ``'GR'`` or ``'IGSreg'``.
            ``'SGD'`` and ``'GR'`` both use the plain ``SGD`` wrapper.
        model: Model whose parameters are optimized.
        lr: Target learning rate.
        momentum: Momentum of the base SGD optimizer.
        weight_decay: Weight decay of the base SGD optimizer.
        rho: Forwarded to ``SAM`` / ``ASAM`` / ``IGSreg``.
        eta: Forwarded to ``SAM`` / ``ASAM`` / ``IGSreg``.
        lr_scheduler: Scheduler name; ``'cosine_warmup'`` makes the base
            optimizer start at ``0.01 * lr`` instead of ``lr``.
        preconditioner: Forwarded to ``IGSreg`` only; ignored otherwise.
        normalize: Forwarded to ``IGSreg`` only; ignored otherwise.

    Returns:
        The wrapper instance around the base optimizer.

    Raises:
        ValueError: If ``method`` is not one of the supported names.
    """
    # With the warm-up schedule the optimizer starts at 1% of the target rate.
    initial_lr = 0.01*lr if lr_scheduler == 'cosine_warmup' else lr
    base_optimizer = optim.SGD(model.parameters(), lr=initial_lr,
                               momentum=momentum, weight_decay=weight_decay)

    # Direct class references instead of eval(): same classes, no string
    # evaluation, and an unsupported name fails with a clear error.
    if method == 'SAM':
        return SAM(base_optimizer, model, rho=rho, eta=eta)
    if method == 'ASAM':
        return ASAM(base_optimizer, model, rho=rho, eta=eta)
    if method in ('SGD', 'GR'):
        return SGD(base_optimizer, model)
    if method == 'IGSreg':
        return IGSreg(base_optimizer,
                      model,
                      preconditioner,
                      rho=rho,
                      eta=eta,
                      normalize=normalize)
    raise ValueError("Unknown method: %r (expected SAM/ASAM/SGD/GR/IGSreg)" % (method,))

def train(args):
    """Run one complete training job described by ``args`` and save the weights.

    Builds the data loaders, model, optimizer wrapper, criterion and
    learning-rate scheduler from ``args``, trains for ``args.epochs`` epochs,
    writes one tab-separated log record per training step and finally stores
    the model ``state_dict`` inside ``args.saving_folder``.

    Update rule per step, selected by ``args.method``:

    * ``SGD``: one backward pass followed by ``descent_step``.
    * ``GR``: like ``SGD``; from ``args.reg_start_epoch`` on, ``args.rho``
      times the squared gradient norm is added to the loss.
    * ``SAM`` / ``ASAM`` / ``IGSreg``: from ``args.reg_start_epoch`` on, a
      backward pass and ``ascent_step`` followed by a second forward/backward
      pass and ``descent_step``; before that epoch ``descent_step_sgd``.

    Columns of each log record, in order: global step, seconds since the epoch
    started, seconds spent in evaluation for this step, learning rate read at
    the start of the epoch, batch loss, batch accuracy, two constant ``0``
    placeholders, most recent test loss, most recent test accuracy and the
    last ``g_nat_sqnorm`` of the preconditioner (``0`` unless IGSreg is active).

    Note:
        The function reads the module-level ``time_now`` string, which is only
        assigned in the ``__main__`` block of this file.

    Args:
        args: Parsed command-line namespace produced by the ``__main__`` block.
    """
    # logger
    os.makedirs('./log/', exist_ok=True)
    run_tag = args.name+'_'+time_now+'_seed'+str(args.seed)
    logger = set_logger(os.path.join('./log/', run_tag+'.log'))
    logger.info(args)

    # set random seed to reproduce the work
    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)

    # get dataset / model / optimizer / criterion / lr_scheduler
    train_loader, test_loader = get_data(dataset=args.dataset,
                                         train_bs=args.batch_size,
                                         test_bs=args.test_batch_size,
                                         data_augmentation=args.data_aug,
                                         normalization=True,
                                         shuffle=True,
                                         cutout=args.cutout
                                        )

    model = get_model(args.model,
                      dataset=args.dataset,
                      num_classes=args.num_classes
                     )
    if args.cuda:
        model = model.cuda()
    if args.parallel:
        model = torch.nn.DataParallel(model)

    # Only IGSreg uses a preconditioner.
    if args.method == 'IGSreg':
        preconditioner = EKFAC(model, args.eta, update_freq=args.update_freq,
                               ra=args.ra, alpha=args.alpha, sua=args.sua,
                               save_freq=args.save_freq, sqrt_eps=args.sqrt_eta,
                               topk=args.topk, kfe_eps=args.kfe_eps)
    else:
        preconditioner = None

    minimizer = get_optimizer(args.method,
                              model,
                              lr=args.lr,
                              momentum=args.momentum,
                              weight_decay=args.weight_decay,
                              rho=args.rho,
                              eta=args.eta, lr_scheduler=args.lr_scheduler,
                              preconditioner=preconditioner,
                              normalize=args.grad_normalize
                             )

    criterion = get_criterion(args.criterion, args.smoothing)

    lr_scheduler = get_lr_scheduler(args.lr_scheduler,
                                    optimizer=minimizer.optimizer,
                                    milestones=args.milestones,
                                    gamma=args.gamma,
                                    epochs=args.epochs,
                                    T_0=args.T_0, T_mult=args.T_mult, T_up=args.T_up, eta_max=args.lr
                                   )
    print("initial lr: ", minimizer.optimizer.param_groups[0]['lr'])
    max_test_acc = 0
    max_train_acc = 0
    timer_train_reg = 0
    # Chance-level placeholders logged until the first evaluation has run.
    acc = 1/args.num_classes
    test_loss = np.log(args.num_classes)
    # Per-parameter copies of the true-label gradient (triple-backward mode).
    grad_dict = defaultdict(dict)
    for epoch in range(1, args.epochs + 1):
        print('Current Epoch: ', epoch)
        start_time = time.time()
        train_loss = 0.
        total_num = 0
        correct = 0
        lr = minimizer.optimizer.param_groups[0]['lr']

        # These conditions depend only on the epoch, so evaluate them once.
        reg_active = epoch >= args.reg_start_epoch
        igs_active = args.method == 'IGSreg' and reg_active
        triple_active = igs_active and args.triple_backward

        with tqdm(total=len(train_loader.dataset)) as progressbar:
            for batch_idx, (data, target) in enumerate(train_loader):
                temp_start = time.time()
                step = (epoch-1)*len(train_loader) + batch_idx
                if args.cuda:
                    data, target = data.cuda(), target.cuda()

                # test() may have been called in the previous iteration, so
                # training mode is re-asserted on every step.
                model.train()

                enable_running_stats(model)
                output = model(data)
                if igs_active:
                    if args.triple_backward:
                        # Extra backward pass on the true labels; its gradients
                        # are stashed and later handed to ascent_step.
                        c_loss0 = criterion(output, target)
                        c_loss0.backward(retain_graph=True)
                        for p in model.parameters():
                            grad_dict[p] = p.grad.data.clone()
                        model.zero_grad()
                    if args.triple_backward or args.sample_y:
                        # Loss against labels sampled from the model's own
                        # predictive distribution.
                        K = torch.distributions.categorical.Categorical(logits=output)
                        sampled_target = K.sample()
                        c_loss = criterion(output, sampled_target)
                    else:
                        c_loss = criterion(output, target)
                elif args.method == 'GR' and reg_active:
                    std_loss = criterion(output, target)
                    # create_graph=True keeps the gradient differentiable so
                    # the penalty below can itself be back-propagated.
                    grad_for_reg = torch.autograd.grad(std_loss, model.parameters(),
                                                       retain_graph=True, create_graph=True)
                    loss_R = 0
                    for g in grad_for_reg:
                        loss_R += g.norm()**2
                    c_loss = criterion(output, target) + args.rho*loss_R
                else:
                    c_loss = criterion(output, target)

                # In triple-backward mode the reported loss is the true-label
                # loss, not the sampled-label one. Read it once per step so the
                # device is synchronised only once.
                batch_size = target.size()[0]
                batch_loss = (c_loss0 if triple_active else c_loss).item()
                train_loss += batch_size*batch_loss
                total_num += batch_size
                _, predicted = output.max(1)
                correct_in_batch = predicted.eq(target).sum().item()
                correct += correct_in_batch

                progressbar.set_postfix(loss=train_loss/total_num,
                                        acc=100. * correct / total_num,
                                        epoch=epoch,
                                        lr=minimizer.optimizer.param_groups[0]['lr'])
                progressbar.update(batch_size)

                if args.method == 'SGD' or args.method == 'GR':
                    c_loss.backward()
                    minimizer.descent_step()
                elif reg_active:
                    # args.method is 'SAM', 'ASAM' or 'IGSreg'
                    # first forward-backward step
                    c_loss.backward()
                    alg_input = grad_dict if triple_active else None
                    minimizer.ascent_step(alg_input=alg_input)

                    # second forward-backward step
                    disable_running_stats(model)
                    output = model(data)
                    (criterion(output, target)).backward()
                    minimizer.descent_step()
                else:
                    # Regularization has not started yet: plain update.
                    c_loss.backward()
                    minimizer.descent_step_sgd()

                if igs_active:
                    g_nat_sqnorm = preconditioner.g_nat_sqnorm_traj[-1]
                else:
                    g_nat_sqnorm = 0.

                e_train_loss = batch_loss
                e_train_acc = correct_in_batch / batch_size
                e_reg = 0
                train_time = time.time()
                timer_train_reg += train_time - temp_start
                if step != 0 and args.test and step % args.test_every_n_steps == 0:
                    acc, test_loss = test(model,
                                          test_loader,
                                          print_opt=False
                                         )
                    max_test_acc = max(acc, max_test_acc)
                test_time = time.time()
                log_input = (step,
                             train_time - start_time,
                             test_time - train_time,
                             lr,
                             e_train_loss,
                             e_train_acc,
                             e_reg,
                             0,
                             test_loss,
                             acc,
                             g_nat_sqnorm,
                            )
                logger.info(('%d'+'\t%.4f'*(len(log_input)-1))%(log_input))

        max_train_acc = max(max_train_acc, correct / total_num)

        if args.lr_scheduler:
            lr_scheduler.step()
    logger.info("max test acc: %.4f"%(max_test_acc))
    logger.info("max train acc: %.4f"%(max_train_acc))
    logger.info("train reg time: %.4f"%(timer_train_reg))
    # save model
    model_state = copy.deepcopy(model.state_dict())
    save_dir = './'+args.saving_folder
    os.makedirs(save_dir, exist_ok=True)
    # os.path.join keeps the file inside save_dir even when the folder name was
    # given without a trailing separator.
    PATH = os.path.join(save_dir, run_tag+'.pth')
    print(PATH)
    torch.save(model_state, PATH)
    
if __name__ == "__main__":
    # Command-line entry point: parse the options, derive the positive flags
    # from their ``--no-*`` switches, then train once per requested seed.

    def _str2bool(value):
        """Parse a command-line boolean such as ``True``/``false``/``1``/``no``.

        ``type=bool`` would turn every non-empty string (including ``False``)
        into ``True``; this converter interprets the text instead.
        """
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ('true', 't', 'yes', 'y', '1'):
            return True
        # The empty string stays False, as it was under ``type=bool``.
        if lowered in ('false', 'f', 'no', 'n', '0', ''):
            return False
        raise argparse.ArgumentTypeError("expected a boolean value, got %r" % (value,))

    parser = argparse.ArgumentParser(description='Training ')

    # --- method / data / basic optimization ---------------------------------
    parser.add_argument('--gpu', type=int, default=0,
                        help='choose gpu number')
    parser.add_argument('--method', type=str, default='SAM',
                        help='SAM/ASAM/SGD/IGSreg/GR')
    parser.add_argument('--dataset', type=str, default='cifar10',
                        help='cifar10/mnist/cifar100')
    parser.add_argument('--cutout', action='store_true',
                        help='do we use cutout or not')
    parser.add_argument('--batch-size', type=int, default=128,
                        help='input batch size for training (default: 128)')
    parser.add_argument('--test-batch-size', type=int, default=500,
                        help='input batch size for testing (default: 500)')
    parser.add_argument('--epochs', type=int, default=200,
                        help='number of epochs to train (default: 200)')
    parser.add_argument('--lr', type=float, default=0.1,
                        help='learning rate (default: 0.1)')
    parser.add_argument('--weight-decay', default=0.0005, type=float,
                        help='weight decay (default: 0.0005)')
    parser.add_argument('--momentum', default=0.9, type=float,
                        help='momentum (default: 0.9)')
    parser.add_argument('--model', type=str, default='Simple100',
                        help='vgg11_bn/3FCN/resnet20/resnet56/WRN164/WRN168')
    parser.add_argument('--data_aug', default=True, type=_str2bool,
                        help='data augmentation (default: True)')

    # --- hardware -----------------------------------------------------------
    parser.add_argument('--no-cuda', action='store_true',
                        help='do we use gpu or not')
    parser.add_argument('--no-parallel', action='store_true',
                        help='do we use parallel or not')

    # --- bookkeeping --------------------------------------------------------
    parser.add_argument('--saving-folder', type=str, default='pretrained/',
                        help='choose saving name')
    parser.add_argument('--savemodels', action='store_true',
                        help='save models')
    parser.add_argument('--name', type=str, default='noname',
                        help='choose saving name')
    parser.add_argument('--no-overwrite', action='store_true',
                        help='do we rewrite or not')
    parser.add_argument('--seed', type=int, default=1,
                        help='random seed (default: 1)')

    # --- loss / schedule ----------------------------------------------------
    parser.add_argument('--criterion', type=str, default='label_smoothing',
                        help='cross-entropy/mse/label_smoothing')
    parser.add_argument('--lr-scheduler', type=str, default='cosine',
                        help='cosine/multistep/')
    parser.add_argument("--milestones", nargs='*', type=int, default=None)
    parser.add_argument("--gamma", type=float, default=0.2)
    parser.add_argument('--test-every-n-steps', type=int, default=200)

    parser.add_argument("--rho", type=float, default=2.)
    parser.add_argument('--no-grad-normalize', action='store_true',
                        help='use the unnormlized rho (with the gradient norm)')
    parser.add_argument('--no-label-smoothing', action='store_true',
                        help='label smoothing')
    parser.add_argument("--smoothing", type=float, default=0.1)
    parser.add_argument("--eta", default=0.01, type=float, help="Eta for ASAM.")

    parser.add_argument("--reg_start_epoch", type=int, default=1,
                        help="the regularization starting epoch")
    parser.add_argument("--T_0", type=int, default=200,
                        help="the initial period of cosine warmup scheduler")
    parser.add_argument("--T_up", type=int, default=10,
                        help="the epochs to warm up for cosine warmup scheduler")
    parser.add_argument("--T_mult", type=int, default=1,
                        help="the constant multiplied to the period after each period for cosine warmup scheduler")
    parser.add_argument("--no-test", action='store_true',
                        help='test')
    parser.add_argument("--seeds", nargs='*', type=int, default=None)

    # --- EKFAC / IGSreg -----------------------------------------------------
    parser.add_argument("--update_freq", type=int, default=100,
                        help="the update frequency for EKFAC eigenvectors")
    parser.add_argument("--save_freq", type=int, default=20,
                        help="the frequency to save eigenvalue stats (EKFAC)")
    parser.add_argument("--sua", action='store_true',
                        help="use SUA for convnets or not (EKFAC)")
    parser.add_argument("--no-ra", action='store_true',
                        help="use running average or not (EKFAC)")
    parser.add_argument("--alpha", type=float, default=0.75,
                        help="the running average parameter (EKFAC)")
    parser.add_argument("--no-sample_y", action='store_true',
                        help="use sampled y in calculating regularizer or not (EKFAC)")
    parser.add_argument("--sqrt_eta", action='store_true',
                        help="use sqrt of eta in stabilizing eigenvalue problem of EKFAC")
    parser.add_argument("--topk", type=int, default=100,
                        help="the number of top eigenvalues to save")
    parser.add_argument("--no-kfe_eps", action='store_true',
                        help="add eps for the kfe eigenproblem or not (EKFAC)")
    parser.add_argument("--triple_backward", action='store_true',
                        help="perform backward three times (one for g, one for F_hat, one for descend)")

    args = parser.parse_args()

    ## default: ON -- each positive flag is the negation of its --no-* switch
    args.grad_normalize = not args.no_grad_normalize
    args.label_smoothing = not args.no_label_smoothing
    args.cuda = not args.no_cuda
    args.overwrite = not args.no_overwrite
    args.parallel = not args.no_parallel
    args.test = not args.no_test

    args.ra = not args.no_ra
    args.sample_y = not args.no_sample_y
    # Triple-backward mode always works with sampled labels.
    if args.triple_backward:
        args.sample_y = True
    args.kfe_eps = not args.no_kfe_eps

    ## modification
    # Unless --no-label-smoothing is given, this overrides any --criterion value.
    if args.label_smoothing:
        args.criterion = 'label_smoothing'

    if args.dataset in ('cifar10', 'mnist'):
        args.num_classes = 10
    elif args.dataset == 'cifar100':
        args.num_classes = 100
    else:
        raise ValueError("Unknown dataset")

    # time -- module-level stamp that train() reads to build its file names
    time_now = datetime.datetime.now().strftime('%Y%m%d%H%M%S')
    print(time_now)

    for arg in vars(args):
        print(arg, getattr(args, arg))

    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)

    # set random seed to reproduce the work; --seeds, when given, replaces --seed
    if args.seeds is None:
        seeds_to_run = [args.seed]
    else:
        print(args.seeds)
        seeds_to_run = args.seeds
    for seed in seeds_to_run:
        args.seed = seed
        seed_everything(args.seed)
        train(args)