#*
# Authors: Anonymous
# This file is part of OASIS library.
#
# This file is based on the AdaHessian repository
# https://github.com/amirgholami/adahessian
#*

from __future__ import print_function
import logging
import os
import sys
import copy

import numpy as np
import argparse
from tqdm import tqdm, trange

import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
from torchvision import datasets, transforms
from torch.autograd import Variable

from utils import *
from resnet import *
from optim_adahessian import Adahessian
from adaadahessian_ada_lr_d import AdaAdaHessian

# Training settings
def _str2bool(value):
    """Convert a command-line token into a bool.

    ``type=bool`` cannot be used with argparse: ``bool('False')`` is True,
    so every non-empty value would switch the option on. The accepted
    spellings are therefore parsed explicitly.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised
            boolean spelling.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', '1', 'on'):
        return True
    # The empty string stays falsy, as it was under ``bool('')``.
    if token in ('false', 'f', 'no', 'n', '0', 'off', ''):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean value, got %r' % (value,))


parser = argparse.ArgumentParser(description='PyTorch Example')
parser.add_argument('--batch-size', type=int, default=256, metavar='B',
                    help='input batch size for training (default: 256)')
parser.add_argument('--test-batch-size', type=int, default=256, metavar='TB',
                    help='input batch size for testing (default: 256)')
parser.add_argument('--epochs', type=int, default=160, metavar='E',
                    help='number of epochs to train (default: 160)')
parser.add_argument('--warmstart-samples-idbatch', type=int, default=0,
                    help='warmstart samples for hessian diagonal estimation based on independent batches (default: 0)')
parser.add_argument('--idbatch-mul', type=int, default=10,
                    help='number of times each independendtly sampled batch is itself sampled')
parser.add_argument('--warmstart-samples-fbatch', type=int, default=1,
                    help='warmstart samples for hessian diagonal estimation based on a single batch (default: 1)')
parser.add_argument('--alpha', type=float, default=1e-03,
                    help='alpha for truncation (default: 1e-03)')
parser.add_argument('--beta', type=float, default=0.999,
                    help='beta for diagonal approximation (default: 0.999)')
parser.add_argument('--zeta', type=float, default=0.5,
                    help='zeta for eta update (default: 0.5)')
parser.add_argument('--gamma', type=float, default=1.0,
                    help='gamma for eta update (default: 1.0)')
parser.add_argument('--rho', type=float, default=0.5,
                    help='regulates the amount of decrease in damping')
parser.add_argument('--lr', type=float, default=1e-6, metavar='LR',
                    help='learning rate (default: 1e-6)')
# Accepts "--lr-decay-use True/False" as before (now parsed correctly) and
# also the bare flag "--lr-decay-use", which means True.
parser.add_argument('--lr-decay-use', type=_str2bool, nargs='?', const=True,
                    default=False,
                    help='whether to use lr decay scheduler')
parser.add_argument('--lr-decay', type=float, default=0.1,
                    help='learning rate ratio')
parser.add_argument('--lr-decay-epoch', type=int, nargs='+', default=[80, 120],
                    help='decrease learning rate at these epochs.')
parser.add_argument('--seed', type=int, default=1, metavar='S',
                    help='random seed (default: 1)')
parser.add_argument('--weight-decay', '--wd', default=5e-4, type=float,
                    metavar='W', help='weight decay (default: 5e-4)')
parser.add_argument('--depth', type=int, default=20,
                    help='choose the depth of resnet')
parser.add_argument('--optimizer', type=str, default='adahessian',
                    help='choose optim')

# Read the command-line options, then fix the torch CPU and CUDA seeds from
# --seed so that a run can be reproduced.
args = parser.parse_args()
for seed_rng in (torch.manual_seed, torch.cuda.manual_seed):
    seed_rng(args.seed)

# Echo every option together with the value in effect for this run.
for option_name, option_value in vars(args).items():
    print(option_name, option_value)

# Output locations: 'checkpoint/' receives the best-model snapshots and
# logs_folder (one per ResNet depth) receives the per-epoch statistics.
# exist_ok=True replaces the former isdir()-then-makedirs() sequence, which
# could raise FileExistsError when several runs were launched at the same
# time and another process created the directory in between.
os.makedirs('checkpoint/', exist_ok=True)

logs_folder = 'logs_ada_lr_d_ResNet_' + str(args.depth) + '/'
os.makedirs(logs_folder, exist_ok=True)

# get dataset
train_loader, test_loader = getData(
    name='cifar10', train_bs=args.batch_size, test_bs=args.test_batch_size)

# make sure to use cudnn.benchmark for second backprop
cudnn.benchmark = True

# Build the network on the GPU, show its architecture, and wrap it in
# DataParallel.
model = resnet(num_classes=10, depth=args.depth).cuda()
print(model)
model = torch.nn.DataParallel(model)

# The 'adaadahessian' optimizer works with a second, independent copy of
# the wrapped network.
if args.optimizer == 'adaadahessian':
    prev_model = copy.deepcopy(model).cuda()

# Report the parameter count in millions.
param_count = sum(weight.numel() for weight in model.parameters())
print('    Total params: %.2fM' % (param_count / 1000000.0))

# Loss function and optimizer selection.
# For 'adamw' and 'adahessian' the weight decay stored in args is rescaled
# by 1 / lr before the optimizer is built; every later use of
# args.weight_decay sees the rescaled value.
criterion = nn.CrossEntropyLoss()

optimizer_name = args.optimizer
if optimizer_name == 'sgd':
    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9,
                          weight_decay=args.weight_decay)
elif optimizer_name == 'adam':
    optimizer = optim.Adam(model.parameters(), lr=args.lr,
                           weight_decay=args.weight_decay)
elif optimizer_name == 'adamw':
    print('For AdamW, we automatically correct the weight decay term for you! If this is not what you want, please modify the code!')
    args.weight_decay /= args.lr
    optimizer = optim.AdamW(model.parameters(), lr=args.lr,
                            weight_decay=args.weight_decay)
elif optimizer_name == 'adahessian':
    print('For AdaHessian, we use the decoupled weight decay as AdamW. Here we automatically correct this for you! If this is not what you want, please modify the code!')
    args.weight_decay /= args.lr
    optimizer = Adahessian(model.parameters(), lr=args.lr,
                           weight_decay=args.weight_decay)
elif optimizer_name == 'adaadahessian':
    # Both optimizers (for the live model and for its copy) are configured
    # with exactly the same hyper-parameters.
    ada_settings = dict(
        lr=args.lr,
        alpha=args.alpha,
        beta=args.beta,
        zeta=args.zeta,
        gamma=args.gamma,
        weight_decay=args.weight_decay,
        warmstart_samples_fbatch=args.warmstart_samples_fbatch,
    )
    optimizer = AdaAdaHessian(model.parameters(), **ada_settings)
    prev_optimizer = AdaAdaHessian(prev_model.parameters(), **ada_settings)
else:
    raise Exception('We do not support this optimizer yet!!')

# if args.lr_decay_use:
#     # learning rate schedule
#     scheduler = lr_scheduler.MultiStepLR(
#         optimizer,
#         args.lr_decay_epoch,
#         gamma=args.lr_decay,
#         last_epoch=-1)

# A multi-step LR schedule is created only when decay was requested and the
# optimizer is not 'adaadahessian'.
wants_lr_schedule = args.lr_decay_use and args.optimizer != 'adaadahessian'
if wants_lr_schedule:
    scheduler = lr_scheduler.MultiStepLR(
        optimizer,
        milestones=args.lr_decay_epoch,
        gamma=args.lr_decay,
        last_epoch=-1)

# args_train = args.optimizer + \
#              '_lr_' + str(np.around(args.lr, 8)) + \
#              '_alpha_' + str(args.alpha) + \
#              '_decay_' + str(int(args.lr_decay_use))
            
# Run tag used in the log and checkpoint file names: the optimizer name
# followed by one "_<key>_<value>" pair per hyper-parameter, in this order.
run_tag_fields = (
    ('lr', np.around(args.lr, 8)),
    ('alpha', args.alpha),
    ('beta', args.beta),
    ('zeta', args.zeta),
    ('gamma', args.gamma),
    ('rho', args.rho),
    ('wd', args.weight_decay),
    ('ws_idb', args.warmstart_samples_idbatch),
    ('d', args.depth),
    ('s', args.seed),
    ('bs', args.batch_size),
)
args_train = args.optimizer + ''.join(
    '_' + key + '_' + str(value) for key, value in run_tag_fields)

# Running statistics filled in by the training loop.
train_loss_stats, test_acc_stats = [], []
best_acc = 0.0

# Training.
#
# Two loops share the same per-epoch bookkeeping (mean training loss, test
# accuracy, a log file rewritten every epoch, and a checkpoint whenever the
# test accuracy improves):
#   * 'adaadahessian' keeps a second model/optimizer pair (prev_model /
#     prev_optimizer) and records the learning rate of every parameter
#     group before each step;
#   * every other optimizer uses a plain loop with an optional
#     MultiStepLR schedule.
if args.optimizer == 'adaadahessian':

    # Learning rate of each parameter group, appended once per training step.
    eta_stats = []

    print('Starting warmstart sampling')

    # Warm start: for the first `warmstart_samples_idbatch` batches, take a
    # double-backprop gradient and let the optimizer accumulate its Hessian
    # diagonal estimate `idbatch_mul` times per batch. No parameter update
    # happens here.
    batch_counter = 0
    for batch_idx, (data, target) in enumerate(train_loader):
        if batch_counter >= args.warmstart_samples_idbatch:
            break
        data, target = data.cuda(), target.cuda()
        output = model(data)
        optimizer.zero_grad()
        loss = criterion(output, target)
        loss.backward(create_graph=True)
        for _ in range(args.idbatch_mul):
            optimizer.accumulate_h_diag()
        batch_counter += 1

    print('Completed warmstart sampling')

    for epoch in range(1, args.epochs + 1):
        print('Current Epoch: ', epoch)
        train_loss = 0.
        total_num = 0
        correct = 0
        model.train()
        prev_model.train()

        # The LR-decay milestones are reused as the epochs at which the
        # damping of both optimizers is updated with factor rho.
        if epoch in args.lr_decay_epoch:
            optimizer.update_damping(args.rho)
            prev_optimizer.update_damping(args.rho)

        with tqdm(total=len(train_loader.dataset)) as progressbar:
            for batch_idx, (data, target) in enumerate(train_loader):
                data, target = data.cuda(), target.cuda()

                optimizer.zero_grad()
                prev_optimizer.zero_grad()

                # Gradient of the copy's weights on the current batch
                # (first order only).
                prev_output = prev_model(data)
                prev_loss = criterion(prev_output, target)
                prev_loss.backward()

                # Gradient of the live weights; the graph is kept so the
                # optimizer can differentiate through it again.
                output = model(data)
                loss = criterion(output, target)
                loss.backward(create_graph=True)

                batch_size = target.size(0)
                train_loss += loss.item() * batch_size
                total_num += batch_size
                _, predicted = output.max(1)
                correct += predicted.eq(target).sum().item()

                # if not the very first step
                if 'first_step' in optimizer.param_groups[-1]:
                    optimizer.compute_dif_norms(prev_optimizer)
                # Snapshot the weights before they are updated below.
                prev_model.load_state_dict(model.state_dict())

                for group in optimizer.param_groups:
                    if isinstance(group['lr'], torch.Tensor):
                        eta_stats.append(group['lr'].item())
                    else:
                        eta_stats.append(group['lr'])

                optimizer.step()

                progressbar.update(batch_size)

        acc = test(model, test_loader)
        train_loss /= total_num

        print(np.around(train_loss, 2))
        print(np.around(acc * 100, 2))

        train_loss_stats.append(train_loss)
        test_acc_stats.append(acc * 100)

        # The log file is rewritten in full after every epoch.
        torch.save({
            'train_loss_stats': train_loss_stats,
            'test_acc_stats': test_acc_stats,
            'eta_stats': eta_stats
        }, logs_folder + args_train + '_log.pkl')

        if acc > best_acc:
            best_acc = acc
            torch.save({
                'epoch': epoch,
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'best_accuracy': best_acc,
                }, 'checkpoint/' + args_train + '_netbest.pkl')

else:
    # Only AdaHessian needs the backward graph (second backprop). The
    # torch.optim optimizers (sgd / adam / adamw) use first-order gradients
    # only, so building the graph for them just costs memory and time.
    needs_grad_graph = args.optimizer == 'adahessian'

    for epoch in range(1, args.epochs + 1):
        print('Current Epoch: ', epoch)
        train_loss = 0.
        total_num = 0
        correct = 0
        # One scheduler step per epoch. An additional step used to be taken
        # before this loop, which made every milestone fire one epoch early
        # (e.g. at epoch 79 instead of 80).
        if args.lr_decay_use:
            scheduler.step()
        model.train()
        with tqdm(total=len(train_loader.dataset)) as progressbar:
            for batch_idx, (data, target) in enumerate(train_loader):
                data, target = data.cuda(), target.cuda()
                output = model(data)
                optimizer.zero_grad()
                loss = criterion(output, target)
                loss.backward(create_graph=needs_grad_graph)

                batch_size = target.size(0)
                train_loss += loss.item() * batch_size
                total_num += batch_size
                _, predicted = output.max(1)
                correct += predicted.eq(target).sum().item()
                optimizer.step()
                progressbar.update(batch_size)

        acc = test(model, test_loader)
        train_loss /= total_num

        print(np.around(train_loss, 2))
        print(np.around(acc * 100, 2))

        train_loss_stats.append(train_loss)
        test_acc_stats.append(acc * 100)

        # The log file is rewritten in full after every epoch.
        torch.save({
            'train_loss_stats': train_loss_stats,
            'test_acc_stats': test_acc_stats,
        }, logs_folder + args_train + '_log.pkl')

        if acc > best_acc:
            best_acc = acc
            torch.save({
                'epoch': epoch,
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'best_accuracy': best_acc,
                }, 'checkpoint/' + args_train + '_netbest.pkl')

# print(f'Best Acc: {np.around(best_acc * 100, 2)}')

# Final report: the value of best_acc as a percentage, rounded to two decimals.
best_acc_percent = np.around(best_acc * 100, 2)
print(best_acc_percent)
