import warnings
warnings.filterwarnings("ignore")

import os
import sys
import argparse
import datetime
import time
import random
import os.path as osp
import torch.backends.cudnn as cudnn

import numpy as np
import pandas as pd
from sklearn.metrics import confusion_matrix
import matplotlib
matplotlib.use('Agg')
from matplotlib import pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages

import torch
import torch.nn as nn
from torch.optim import lr_scheduler
from pyhessian import hessian
import models
import datasets
from utils import AverageMeter, Logger

def _str2bool(value):
    """Convert a command-line token to a real boolean.

    ``type=bool`` does not work with argparse: ``bool("False")`` is ``True``
    because every non-empty string is truthy, so such a flag can never be
    turned off by passing a value. This parser accepts the usual spellings
    and rejects anything else instead of silently treating it as ``True``.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    if token in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


# Command-line interface. Parsing happens at import time so that ``args`` is
# available as a module-level global to ``main`` and ``train``.
parser = argparse.ArgumentParser()

# Data.
parser.add_argument('-d', '--dataset', type=str, default='mnist', choices=['mnist', 'svhn', 'cifar10', 'cifar100'])
parser.add_argument('-j', '--workers', default=4, type=int, help="number of data loading workers (default: 4)")

# Model and optimisation.
parser.add_argument('--model', type=str, default='mlp_bn')
parser.add_argument('--batch_size', type=int, default=60000)
parser.add_argument('--n_samples', type=int, default=None)
# Only 0..3 are implemented by loss_plus_regularization; rejecting other
# values here fails fast instead of raising inside the first training step.
parser.add_argument('--alpha', type=int, default=3, choices=[0, 1, 2, 3], help="alpha for regularization")
parser.add_argument('--lr', type=float, default=0.001, help="learning rate for model")
# NOTE(review): --start_epoch, --stepsize and --gamma are accepted but are not
# read anywhere in this file.
parser.add_argument('--start_epoch', default=0, type=int, metavar='N', help='manual epoch number (useful on restarts)')
parser.add_argument('--max_epoch', type=int, default=200)
parser.add_argument('--stepsize', type=int, default=0)
parser.add_argument('--gamma', type=float, default=0.0, help="learning rate decay")
parser.add_argument('--momentum', type=float, default=0.9, help="momentum of gradient")
parser.add_argument('--trial_num', type=int, default=2, help="number of trials")
parser.add_argument('--save_dir', type=str, default='/home/zhangxj/workspace/HB/hb2/outputs')

# Runtime / miscellaneous.
# NOTE(review): --eval_freq, --print_freq and --confusion_matrix are likewise
# not read anywhere in this file.
parser.add_argument('--eval_freq', type=int, default=1)
parser.add_argument('--print_freq', type=int, default=50)
parser.add_argument('--gpu', type=str, default='0')
parser.add_argument('--use_cpu', action='store_true')
parser.add_argument('--seed', type=int, default=1)
# ``nargs='?'`` with ``const=True`` lets the bare flag mean "on" while still
# accepting an explicit ``--confusion_matrix True/False`` as before.
parser.add_argument('--confusion_matrix', type=_str2bool, nargs='?', const=True, default=False,
                    help="whether to plot confusion_matrix")

args = parser.parse_args()


best_acc = 0.

def loss_plus_regularization(loss, model, momentum=0.9, eta=0.000, alpha=1):
    """Return ``loss`` rescaled and augmented with gradient-based penalty terms.

    ``alpha`` selects how many terms are included (``g`` is the gradient of
    ``loss`` with respect to the trainable parameters, ``H`` its Hessian)::

        alpha = 0:  loss
        alpha = 1:  loss / (1 - momentum)
        alpha = 2:  ... + eta (1 + momentum) / (4 (1 - momentum)^3) * ||g||^2
        alpha = 3:  ... + eta^2 (1 + momentum)^2 / (4 (1 - momentum)^5) * g^T H g

    Args:
        loss: Scalar tensor still attached to the autograd graph of ``model``.
        model: Module whose parameters the penalties are computed over. Only
            parameters with ``requires_grad=True`` are differentiated, so
            models with frozen layers are supported.
        momentum: Momentum coefficient used in the scaling factors.
        eta: Coefficient of the penalty terms (``train`` in this file passes
            the learning rate).
        alpha: Term selector, one of 0, 1, 2, 3.

    Returns:
        A scalar tensor. For ``alpha >= 2`` it is differentiable through ``g``
        (the gradient is built with ``create_graph=True``).

    Raises:
        ValueError: If ``alpha`` is not 0, 1, 2 or 3.
    """
    if alpha == 0:
        return loss
    if alpha not in (1, 2, 3):
        raise ValueError(f"Invalid alpha: {alpha}")

    scaled_loss = loss / (1 - momentum)
    if alpha == 1:
        return scaled_loss

    # autograd.grad raises if any input does not require grad, so frozen
    # parameters must be excluded. The list is materialised once because it is
    # reused for the Hessian-vector product below.
    params = [p for p in model.parameters() if p.requires_grad]

    # create_graph=True keeps the gradient itself differentiable (and implies
    # retain_graph=True), so the caller can still backprop through the result.
    grads = torch.autograd.grad(loss, params, create_graph=True)

    # reshape (not view) so non-contiguous gradients are handled too.
    grads_flattened = torch.cat([g.reshape(-1) for g in grads])

    # ||g||^2 as a plain dot product: avoids the sqrt/square round trip of
    # torch.norm(g) ** 2.
    grad_norm_squared = torch.dot(grads_flattened, grads_flattened)

    regularized = scaled_loss + \
        eta * (1 + momentum) / (4 * (1 - momentum) ** 3) * grad_norm_squared
    if alpha == 2:
        return regularized

    # Hessian-vector product H g, obtained by differentiating g with g itself
    # as the upstream gradient.
    # NOTE(review): this call does not pass create_graph=True, so H g is a
    # constant during backprop and the g^T H g term is only differentiated
    # through its left factor. Confirm whether that truncation is intended.
    Hv = torch.autograd.grad(
        grads_flattened,
        params,
        grad_outputs=grads_flattened,
        retain_graph=True,
    )
    Hv_flattened = torch.cat([h.reshape(-1) for h in Hv])
    grad_norm_squared_hessian = torch.dot(grads_flattened, Hv_flattened)

    return regularized + \
        eta ** 2 * (1 + momentum) ** 2 / (4 * (1 - momentum) ** 5) * grad_norm_squared_hessian


def main():
    """Train the configured model and periodically checkpoint its weights.

    Reads every setting from the module-level ``args``. Side effects:

    * sets ``CUDA_VISIBLE_DEVICES`` and seeds ``random``, NumPy and PyTorch;
    * creates ``<save_dir>/<trial_num>/`` and a ``checkpoints_alpha_<alpha>``
      sub-directory (including missing parents);
    * replaces ``sys.stdout`` with a ``Logger`` writing into the run directory;
    * writes a checkpoint every ``factor`` epochs.
    """
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    use_gpu = torch.cuda.is_available() and not args.use_cpu

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # makedirs(exist_ok=True) also creates a missing save_dir and avoids the
    # check-then-create race of exists() followed by mkdir().
    run_dir = osp.join(args.save_dir, f"{args.trial_num}")
    os.makedirs(run_dir, exist_ok=True)

    # From here on everything printed goes through the project Logger.
    sys.stdout = Logger(osp.join(run_dir, f'{args.dataset}-{args.model}-epoch_{args.max_epoch}-momentum_{args.momentum}-alpha_{args.alpha}.txt'))
    print(args)
    if use_gpu:
        print("Currently using GPU: {}".format(args.gpu))
        print("Using GPU: {}".format(torch.cuda.get_device_name(0)))
        cudnn.benchmark = True
        torch.cuda.manual_seed_all(args.seed)
    else:
        print("Currently using CPU")

    print("Creating dataset: {}".format(args.dataset))
    dataset = datasets.create(
        name=args.dataset, batch_size=args.batch_size, use_gpu=use_gpu,
        num_workers=args.workers, is_shuffle=True, n_samples=args.n_samples)

    trainloader, testloader = dataset.trainloader, dataset.testloader

    # model
    print("Creating model: {}".format(args.model))
    model = models.create(name=args.model, num_classes=dataset.num_classes)

    if use_gpu:
        # NOTE: with DataParallel the saved state_dict keys carry a
        # "module." prefix.
        model = nn.DataParallel(model).cuda()

    criterion = torch.nn.CrossEntropyLoss()

    print("momentum: {}, lr: {}, alpha: {}".format(args.momentum, args.lr, args.alpha))
    if args.alpha > 0:
        # Regularized runs use plain SGD (no momentum) with a step size ten
        # times smaller, and checkpoint only every `factor` epochs.
        lr_euler = args.lr * 0.1
        # The +0.1 presumably keeps int() from truncating a ratio that lands
        # just below the integer because of floating-point rounding.
        factor = int(args.lr / lr_euler + 0.1)
        learning_rate = lr_euler
        momentum = 0
        print("lr_euler: {}, factor: {}".format(lr_euler, factor))
    else:
        factor = 1
        learning_rate = args.lr
        momentum = args.momentum

    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum)

    # The checkpoint directory does not change between epochs, so create it once.
    checkpoint_dir = osp.join(run_dir, f'checkpoints_alpha_{args.alpha}')
    os.makedirs(checkpoint_dir, exist_ok=True)

    start_time = time.time()
    for epoch in range(args.max_epoch):
        print("==> Epoch {}/{}".format(epoch+1, args.max_epoch))
        train_loss, acc = train(model, criterion, optimizer, trainloader, use_gpu)
        print("Train Loss {}, Train Accuracy {}".format(train_loss, acc))

        if epoch % factor == 0:
            save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
            }, filename=os.path.join(checkpoint_dir, f'checkpoint_step_{epoch}.tar'))

    elapsed = round(time.time() - start_time)
    elapsed = str(datetime.timedelta(seconds=elapsed))
    print("Finished. Total elapsed time (h:m:s): {}".format(elapsed))
    
def train(model, criterion, optimizer, trainloader, use_gpu):
    """Run one pass over ``trainloader``, updating ``model`` after every batch.

    Each batch's criterion loss is passed through ``loss_plus_regularization``
    (configured from the module-level ``args``) and that regularized value is
    what gets back-propagated and recorded.

    Args:
        model: Network to train; it is switched to training mode.
        criterion: Callable mapping ``(outputs, labels)`` to a scalar loss.
        optimizer: Optimizer stepped once per batch.
        trainloader: Iterable yielding ``(data, labels)`` pairs.
        use_gpu: When true, each batch is moved to CUDA before the forward pass.

    Returns:
        ``(loss_meter.avg, accuracy)``: the meter's running average of the
        regularized loss, and ``correct * 100 / total``. ``accuracy`` is a
        tensor, because the correct-count accumulates a tensor sum.
    """
    model.train()
    loss_meter = AverageMeter()
    n_correct = 0
    n_seen = 0

    for inputs, targets in trainloader:
        if use_gpu:
            inputs = inputs.cuda()
            targets = targets.cuda()

        logits = model(inputs)
        objective = loss_plus_regularization(
            criterion(logits, targets), model, args.momentum, args.lr, args.alpha)

        # Parameter update on the regularized objective.
        optimizer.zero_grad()
        objective.backward()
        optimizer.step()

        # Bookkeeping: predictions come from the logits computed before the
        # step, so doing this after the update does not change the numbers.
        batch_size = targets.size(0)
        guesses = logits.max(dim=1).indices.cpu()
        n_correct += (guesses == targets.cpu()).sum()
        n_seen += batch_size
        loss_meter.update(objective.item(), batch_size)

    accuracy = n_correct * 100. / n_seen
    return loss_meter.avg, accuracy
    
def save_checkpoint(state, filename='checkpoint.pth.tar'):
    """Serialize ``state`` to ``filename`` with ``torch.save``.

    Args:
        state: Object to persist (``main`` passes a dict holding the epoch
            number and the model's ``state_dict``).
        filename: Destination path; its directory must already exist.
    """
    torch.save(obj=state, f=filename)


# Start training only when this file is executed as a script.
if __name__ == '__main__':
    main()