#*
# Authors: Anonymous
# This file is part of OASIS library.
#
# This file is based on the AdaHessian repository
# https://github.com/amirgholami/adahessian
#*

from __future__ import print_function
import logging
import os
import sys

import numpy as np
import argparse
from tqdm import tqdm, trange

import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
from torchvision import datasets, transforms
from torch.autograd import Variable

from utils import *
from resnet import *
from optim_adahessian import Adahessian
from adaadahessian_scale_momentum import AdaAdaHessian

# Training settings
def _str2bool(value):
    """Convert a command-line token into a bool for argparse.

    ``type=bool`` cannot be used for this: ``bool('False')`` is ``True``
    because every non-empty string is truthy, so the option could never be
    turned off by name. The empty string keeps mapping to ``False`` so that
    invocations which relied on that behaviour still work.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', 'on', '1'):
        return True
    if token in ('false', 'f', 'no', 'n', 'off', '0', ''):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean value, got %r' % (value,))


parser = argparse.ArgumentParser(description='PyTorch Example')
parser.add_argument('--batch-size', type=int, default=256, metavar='B',
                    help='input batch size for training (default: 256)')
parser.add_argument('--test-batch-size', type=int, default=256, metavar='TB',
                    help='input batch size for testing (default: 256)')
parser.add_argument('--epochs', type=int, default=200, metavar='E',
                    help='number of epochs to train (default: 200)')
parser.add_argument('--warmstart-samples-idbatch', type=int, default=0,
                    help='warmstart samples for hessian diagonal estimation based on independent batches (default: 0)')
parser.add_argument('--idbatch-mul', type=int, default=10,
                    help='number of times each independently sampled batch is itself sampled')
parser.add_argument('--warmstart-samples-fbatch', type=int, default=1,
                    help='warmstart samples for hessian diagonal estimation based on a single batch (default: 1)')
parser.add_argument('--alpha', type=float, default=1e-03,
                    help='alpha for truncation (default: 1e-03)')
parser.add_argument('--beta1', type=float, default=0.9,
                    help='beta_1 for diagonal approximation (default: 0.9)')
parser.add_argument('--beta2', type=float, default=0.999,
                    help='beta_2 for diagonal approximation (default: 0.999)')
parser.add_argument('--lr', type=float, default=0.15, metavar='LR',
                    help='learning rate (default: 0.15)')
# Parsed with _str2bool so that "--lr-decay-use False" really disables it.
parser.add_argument('--lr-decay-use', type=_str2bool, default=True,
                    help='whether to use lr decay scheduler')
parser.add_argument('--lr-decay', type=float, default=0.2,
                    help='learning rate ratio')
parser.add_argument('--lr-decay-epoch', type=int, nargs='+', default=[60, 120, 160],
                    help='decrease learning rate at these epochs.')
parser.add_argument('--lr-decay-gamma', type=float, default=0.97,
                    help='gamma argument for ExponentialLR scheduler (default: 0.97)')
parser.add_argument('--seed', type=int, default=1, metavar='S',
                    help='random seed (default: 1)')
parser.add_argument('--weight-decay', '--wd', default=5e-4, type=float,
                    metavar='W', help='weight decay (default: 5e-4)')
parser.add_argument('--with_scaling', type=int, default=1,
                    help='True if without warmsampling and with scaling')
parser.add_argument('--depth', type=int, default=18,
                    help='choose the depth of resnet')
parser.add_argument('--optimizer', type=str, default='adahessian',
                    help='choose optim')

args = parser.parse_args()

# Seed the CPU and CUDA generators so that a run can be reproduced.
torch.manual_seed(args.seed)
torch.cuda.manual_seed(args.seed)

# Echo every parsed setting, one "name value" pair per line.
for arg_name, arg_value in vars(args).items():
    print(arg_name, arg_value)

# Per-depth folder that receives the training statistics of this run.
logs_folder = 'logs_scale_momentum_no_wd_ResNet_' + str(args.depth) + '/'

# Create the checkpoint and log folders when they do not exist yet.
for folder in ('checkpoint/', logs_folder):
    if not os.path.isdir(folder):
        os.makedirs(folder)

# Build the CIFAR-100 train/test loaders with the requested batch sizes.
train_loader, test_loader = getData(
    name='cifar100',
    train_bs=args.batch_size,
    test_bs=args.test_batch_size,
)

# Turn on cuDNN benchmark mode (wanted for the second backward pass).
cudnn.benchmark = True

# Build the network on the GPU, show it, then wrap it for multi-GPU use.
# NOTE(review): ResNet18 is hard-coded here, so --depth only ends up in the
# log folder and run-tag names, not in the architecture - confirm intended.
model = ResNet18().cuda()
print(model)
model = torch.nn.DataParallel(model)

# Report the parameter count in millions.
param_count = sum(p.numel() for p in model.parameters())
print('    Total params: %.2fM' % (param_count / 1000000.0))

criterion = nn.CrossEntropyLoss()

# One factory per supported --optimizer value. The factories are only called
# after the weight-decay rescaling below, so they see the final value.
optimizer_factories = {
    'sgd': lambda params: optim.SGD(
        params,
        lr=args.lr,
        momentum=0.9,
        weight_decay=args.weight_decay),
    'adam': lambda params: optim.Adam(
        params,
        lr=args.lr,
        weight_decay=args.weight_decay),
    'adamw': lambda params: optim.AdamW(
        params,
        lr=args.lr,
        weight_decay=args.weight_decay),
    'adahessian': lambda params: Adahessian(
        params,
        lr=args.lr,
        weight_decay=args.weight_decay),
    'adaadahessian': lambda params: AdaAdaHessian(
        params,
        lr=args.lr,
        alpha=args.alpha,
        beta_1=args.beta1,
        beta_2=args.beta2,
        weight_decay=args.weight_decay,
        warmstart_samples_fbatch=args.warmstart_samples_fbatch,
        with_scaling=args.with_scaling),
}

if args.optimizer not in optimizer_factories:
    raise Exception('We do not support this optimizer yet!!')

# AdamW, AdaHessian and AdaAdaHessian are given weight_decay / lr rather than
# the raw value; the first two announce the correction before applying it.
if args.optimizer in ('adamw', 'adahessian', 'adaadahessian'):
    if args.optimizer == 'adamw':
        print('For AdamW, we automatically correct the weight decay term for you! If this is not what you want, please modify the code!')
    elif args.optimizer == 'adahessian':
        print('For AdaHessian, we use the decoupled weight decay as AdamW. Here we automatically correct this for you! If this is not what you want, please modify the code!')
    args.weight_decay = args.weight_decay / args.lr

optimizer = optimizer_factories[args.optimizer](model.parameters())

# Optional learning-rate schedule: MultiStepLR with the milestones from
# --lr-decay-epoch and the decay factor from --lr-decay.
# NOTE(review): --lr-decay-gamma is only written into the run tag below; no
# scheduler in this script consumes it.
if args.lr_decay_use:
    scheduler = lr_scheduler.MultiStepLR(
        optimizer,
        milestones=args.lr_decay_epoch,
        gamma=args.lr_decay,
        last_epoch=-1)

# Run tag used in the log and checkpoint file names: the optimizer name
# followed by "<key><value>" for each setting, in this fixed order.
run_tag_fields = [
    ('_lr_', np.around(args.lr, 8)),
    ('_alpha_', args.alpha),
    ('_beta1_', args.beta1),
    ('_beta2_', args.beta2),
    ('_wd_', args.weight_decay),
    ('_ws_idb_', args.warmstart_samples_idbatch),
    ('_idb_m_', args.idbatch_mul),
    ('_lr_d_g_', args.lr_decay_gamma),
    ('_d_', args.depth),
    ('_s_', args.seed),
    ('_ws_', args.with_scaling),
    ('_bs_', args.batch_size),
    ('_ep_', args.epochs),
]
args_train = args.optimizer + ''.join(
    key + str(value) for key, value in run_tag_fields)

# Best test accuracy so far and the per-epoch statistics.
best_acc = 0.0
train_loss_stats, test_acc_stats = [], []

# Warm-start phase, run only for AdaAdaHessian: feed up to
# --warmstart-samples-idbatch training batches to the optimizer before the
# first real update.
if args.optimizer == 'adaadahessian':
    # NOTE(review): this advances the LR schedule once before any training
    # epoch has run - confirm that the extra step is intended.
    if args.lr_decay_use:
        scheduler.step()
    model.train()

    print('Starting warmstart sampling')

    for batch_idx, (data, target) in enumerate(train_loader):
        # Stop once the requested number of batches has been processed.
        if batch_idx >= args.warmstart_samples_idbatch:
            break
        data = data.cuda()
        target = target.cuda()
        output = model(data)
        optimizer.zero_grad()
        loss = criterion(output, target)
        # create_graph=True keeps the graph of the gradients after backward.
        loss.backward(create_graph=True)
        # Call the optimizer's accumulator --idbatch-mul times on this batch
        # (presumably one Hessian-diagonal sample per call - TODO confirm).
        for _ in range(args.idbatch_mul):
            optimizer.accumulate_h_diag()

    print('Completed warmstart sampling')

# The torch.optim optimizers (SGD/Adam/AdamW) only read the gradient values,
# so building the graph of the derivative for them just costs memory. The
# two custom optimizers keep create_graph=True exactly as before.
needs_grad_graph = args.optimizer in ('adahessian', 'adaadahessian')

# Main loop: one pass over the training set per epoch, then evaluation,
# statistics logging and checkpointing of the best model so far.
for epoch in range(1, args.epochs + 1):
    print('Current Epoch: ', epoch)
    train_loss = 0.
    total_num = 0
    model.train()
    with tqdm(total=len(train_loader.dataset)) as progressbar:
        for data, target in train_loader:
            data, target = data.cuda(), target.cuda()
            batch_size = target.size(0)
            output = model(data)
            optimizer.zero_grad()
            loss = criterion(output, target)
            loss.backward(create_graph=needs_grad_graph)
            # Weight the batch loss by the batch size so that the epoch
            # average stays correct when the last batch is smaller.
            train_loss += loss.item() * batch_size
            total_num += batch_size
            optimizer.step()
            progressbar.update(batch_size)

    # PyTorch (>= 1.1.0) expects scheduler.step() after the epoch's
    # optimizer.step() calls; stepping first skips the first value of the
    # schedule and moves every milestone one epoch early.
    if args.lr_decay_use:
        scheduler.step()

    acc = test(model, test_loader)
    train_loss /= total_num

    print(np.around(train_loss, 2))
    print(np.around(acc * 100, 2))

    train_loss_stats.append(train_loss)
    test_acc_stats.append(acc * 100)

    # Rewrite the statistics file every epoch so an interrupted run still
    # leaves its history behind.
    torch.save({
        'train_loss_stats': train_loss_stats,
        'test_acc_stats': test_acc_stats,
    }, logs_folder + args_train + '_log.pkl')

    # Keep a checkpoint of the model with the best test accuracy so far.
    if acc > best_acc:
        best_acc = acc
        torch.save({
            'epoch': epoch,
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'best_accuracy': best_acc,
            }, 'checkpoint/' + args_train + '_netbest.pkl')

# Best test accuracy over all epochs, in percent.
print(np.around(best_acc * 100, 2))
