from __future__ import print_function
import os
import logging
import argparse
import time

import torch.backends.cudnn as cudnn
import torch.optim.lr_scheduler as lr_scheduler

from utils import *
from models.resnet import *
from models.resnet2 import ResNet18, ResNet34
from optim.optim_adahessian import Adahessian
from optim.adadqh import AdaDQH
from optim.adabelief import AdaBelief
from optim.aida import Aida

import ssl
# NOTE(review): this swaps the stdlib's default HTTPS context factory for the
# unverified one, i.e. certificate verification is turned off for HTTPS
# connections opened through the stdlib defaults for the whole process.
# Presumably a workaround for dataset downloads failing certificate checks --
# confirm, and prefer fixing the local certificate store over keeping this.
ssl._create_default_https_context = ssl._create_unverified_context

# Training settings
# Every hyper-parameter of the run is exposed as a command-line option; the
# parsed namespace is what the rest of the script reads as ``args``.
parser = argparse.ArgumentParser(description='PyTorch Example')

# --- data / run length ---
parser.add_argument('--batch-size', type=int, default=256, metavar='B',
                    help='input batch size for training (default: 256)')
parser.add_argument('--test-batch-size', type=int, default=256, metavar='TB',
                    help='input batch size for testing (default: 256)')
parser.add_argument('--epochs', type=int, default=160, metavar='E',
                    help='number of epochs to train (default: 160)')

# --- learning rate and its schedule ---
parser.add_argument('--lr', type=float, default=0.15, metavar='LR',
                    help='learning rate (default: 0.15)')
# Used as ``eta_min`` of the cosine schedule.
parser.add_argument('--lr-min', type=float, default=0.0, dest='lr_min',
                    help='minimum learning rate (default: 0.0)')
# Used as ``gamma`` of the multistep schedule.
parser.add_argument('--lr-decay', type=float, default=0.1,
                    help='learning rate ratio')
# Milestones (in epochs) of the multistep schedule.
parser.add_argument('--lr-decay-epoch', type=int, nargs='+', default=[80, 120],
                    help='decrease learning rate at these epochs.')
parser.add_argument('--scheduler', type=str, default='multistep',
                    help='choose scheduler')

# --- reproducibility ---
parser.add_argument('--seed', type=int, default=1, metavar='S',
                    help='random seed (default: 1)')

# --- regularisation ---
parser.add_argument('--weight-decay', '--wd', default=5e-4, type=float,
                    metavar='W', help='weight decay (default: 5e-4)')
parser.add_argument('--weight-decouple', action='store_true',
                    help='weight decouple')

# --- model ---
parser.add_argument('--model', type=str, default='resnet18',
                    help='choose model')
# Only read when --model is neither "resnet18" nor "resnet34".
parser.add_argument('--depth', type=int, default=20,
                    help='choose the depth of resnet')

# --- optimizer ---
parser.add_argument('--optimizer', type=str, default='adahessian',
                    help='choose optim')
parser.add_argument('--eps', type=float, default=1e-8,
                    help='choose epsilon')
parser.add_argument('--amsgrad', action='store_true',
                    help='choose amsgrad')

# --- logging ---
parser.add_argument('--log_file', type=str, default="../logs/cifar10",
                    help='log file')

args = parser.parse_args()

# The file handler installed by ``logging.basicConfig`` opens the log file
# right away and raises if its parent directory is missing, so create that
# directory first (the default ``../logs/cifar10`` lives outside the working
# directory and is easy to forget).
log_dir = os.path.dirname(args.log_file)
if log_dir and not os.path.isdir(log_dir):
    os.makedirs(log_dir)

LOG_FORMAT = "%(asctime)s - %(levelname)s - %(message)s"
logging.basicConfig(filename=args.log_file, level=logging.INFO, format=LOG_FORMAT)
# Send messages issued through the ``warnings`` module to the log as well.
logging.captureWarnings(True)

# seed both the CPU and the CUDA generators so a run can be reproduced
run_seed = args.seed
torch.manual_seed(run_seed)
torch.cuda.manual_seed(run_seed)


# dump the full configuration of this run into the log, one option per line
for option, value in vars(args).items():
    logging.info("{}: {}".format(option, value))

# the best checkpoint is written below this directory; create it if needed
checkpoint_dir = 'checkpoint/'
if not os.path.isdir(checkpoint_dir):
    os.makedirs(checkpoint_dir)

# get dataset
loaders = getData(
    name='cifar10',
    train_bs=args.batch_size,
    test_bs=args.test_batch_size,
)
train_loader, test_loader = loaders

# make sure to use cudnn.benchmark for second backprop
cudnn.benchmark = True

# get model and optimizer
# Architectures selectable by name; any other --model value falls back to the
# depth-configurable ``resnet`` builder.
named_models = {
    "resnet18": ResNet18,
    "resnet34": ResNet34,
}
if args.model in named_models:
    model = named_models[args.model]().cuda()
else:
    model = resnet(num_classes=10, depth=args.depth).cuda()

# log the bare architecture before it is wrapped for multi-GPU execution
logging.info(model)
model = torch.nn.DataParallel(model)

# report the parameter count in millions
n_params = sum(p.numel() for p in model.parameters())
logging.info('    Total params: %.2fM' % (n_params / 1000000.0))

criterion = nn.CrossEntropyLoss()

# AdamW and AdaHessian are set up with decoupled weight decay: for those two
# the configured weight decay is divided by the learning rate (and a notice is
# logged) before the optimizer is built.
decoupled_notice = {
    'adamw': 'For AdamW, we automatically correct the weight decay term for you! If this is not what you want, please modify the code!',
    'adahessian': 'For AdaHessian, we use the decoupled weight decay as AdamW. Here we automatically correct this for you! If this is not what you want, please modify the code!',
}
if args.optimizer in decoupled_notice:
    logging.info(decoupled_notice[args.optimizer])
    args.weight_decay = args.weight_decay / args.lr

# keyword arguments every supported optimizer receives
opt_name = args.optimizer
shared_kwargs = {'lr': args.lr, 'weight_decay': args.weight_decay}

if opt_name == 'sgd':
    optimizer = optim.SGD(model.parameters(), momentum=0.9, **shared_kwargs)
elif opt_name == 'adam':
    optimizer = optim.Adam(model.parameters(), **shared_kwargs)
elif opt_name == 'adamw':
    optimizer = optim.AdamW(model.parameters(), **shared_kwargs)
elif opt_name == 'adahessian':
    optimizer = Adahessian(model.parameters(), **shared_kwargs)
elif opt_name == 'adabelief':
    optimizer = AdaBelief(
        model.parameters(),
        weight_decouple=args.weight_decouple,
        rectify=False,
        **shared_kwargs
    )
elif opt_name == 'aida':
    optimizer = Aida(model.parameters(), eps=args.eps, **shared_kwargs)
elif opt_name == 'adadqh':
    optimizer = AdaDQH(
        model.parameters(),
        eps=args.eps,
        weight_decouple=args.weight_decouple,
        amsgrad=args.amsgrad,
        **shared_kwargs
    )
else:
    raise Exception('We do not support this optimizer yet!!')

# learning rate schedule
# The cosine schedule is advanced once per batch by the training loop, so its
# period is expressed in batches; the multistep schedule is advanced once per
# epoch and its milestones are epochs.
sched_name = args.scheduler
if sched_name == "cosine":
    total_steps = args.epochs * len(train_loader)
    scheduler = lr_scheduler.CosineAnnealingLR(
        optimizer,
        T_max=total_steps,
        eta_min=args.lr_min,
        last_epoch=-1,
    )
elif sched_name == "multistep":
    scheduler = lr_scheduler.MultiStepLR(
        optimizer,
        milestones=args.lr_decay_epoch,
        gamma=args.lr_decay,
        last_epoch=-1,
    )
else:
    raise Exception('We do not support this scheduler yet!!')

# Main loop: train for ``args.epochs`` epochs, evaluate on the test set after
# each one and checkpoint the model whenever the test accuracy improves.
best_acc = 0.0

# Loop-invariant switches, evaluated once instead of on every batch.
# AdaHessian is the only optimizer here whose backward pass is run with
# ``create_graph=True`` (it needs to differentiate through the gradients).
needs_grad_graph = args.optimizer == 'adahessian'
step_per_batch = args.scheduler == "cosine"
step_per_epoch = args.scheduler == "multistep"

for epoch in range(1, args.epochs + 1):
    starttime = time.time()
    logging.info('Current Epoch: %d', epoch)
    train_loss = 0.  # sum of per-sample losses seen this epoch
    total_num = 0    # number of training samples seen this epoch

    model.train()
    for data, target in train_loader:
        data, target = data.cuda(), target.cuda()
        output = model(data)
        loss = criterion(output, target)
        loss.backward(create_graph=needs_grad_graph)

        # weight the batch loss by the batch size so the epoch average is
        # taken per sample rather than per batch
        batch_size = target.size(0)
        train_loss += loss.item() * batch_size
        total_num += batch_size

        optimizer.step()
        optimizer.zero_grad()
        if step_per_batch:
            scheduler.step()

    if step_per_epoch:
        scheduler.step()

    # the reported cost covers the training pass only, not the evaluation
    endtime = time.time()
    logging.info('cost: {}'.format(endtime - starttime))
    train_loss /= total_num
    logging.info('Training Loss of Epoch {}: {}'.format(epoch, train_loss))
    acc = test(model, test_loader)
    logging.info("Testing of Epoch {}: {} \n".format(epoch, acc))

    # keep only the checkpoint with the best test accuracy seen so far
    if acc > best_acc:
        best_acc = acc
        torch.save({
            'epoch': epoch,
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'best_accuracy': best_acc,
            }, 'checkpoint/netbest.pkl')

logging.info('Best Acc: {}\n'.format(best_acc))
