import warnings
warnings.filterwarnings("ignore")

import os
import sys
import argparse
import datetime
import time
import random
import os.path as osp
import torch.backends.cudnn as cudnn

import numpy as np
import pandas as pd
from sklearn.metrics import confusion_matrix
import matplotlib
matplotlib.use('Agg')
from matplotlib import pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages

import torch
import torch.nn as nn
from torch.optim import lr_scheduler
from pyhessian import hessian
import models
import datasets
from utils import AverageMeter, Logger

def _str2bool(value):
    """Convert a command-line string such as "true"/"false" into a bool.

    ``type=bool`` must not be used with argparse: ``bool("False")`` is True
    because every non-empty string is truthy.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got {!r}".format(value))


parser = argparse.ArgumentParser()

# Data loading
parser.add_argument('-d', '--dataset', type=str, default='cifar10', choices=['mnist', 'svhn', 'cifar10', 'cifar100'])
parser.add_argument('-j', '--workers', default=4, type=int, help="number of data loading workers (default: 4)")

# Model and optimisation
parser.add_argument('--model', type=str, default='mlp')
parser.add_argument('--batch_size', type=int, default=5000)
parser.add_argument('--n_samples', type=int, default=5000)
parser.add_argument('--lr', type=float, default=0.1, help="learning rate for model")
parser.add_argument('--start_epoch', default=0, type=int, metavar='N', help='manual epoch number (useful on restarts)')
parser.add_argument('--max_epoch', type=int, default=2000)
parser.add_argument('--stepsize', type=int, default=0)
parser.add_argument('--gamma', type=float, default=0.0, help="learning rate decay")
parser.add_argument('--reg_eta', type=float, default=0.000, help="regularization rate")
parser.add_argument('--momentum', type=float, default=0.9, help="momentum of gradient")
parser.add_argument('--save_dir', type=str, default='/home/xxx/workspace/HBF/results/GD')
###################################################################
# Logging, evaluation and runtime
parser.add_argument('--eval_freq', type=int, default=1)
parser.add_argument('--print_freq', type=int, default=50)
parser.add_argument('--gpu', type=str, default='0')
parser.add_argument('--use_cpu', action='store_true')
parser.add_argument('--seed', type=int, default=1)
# Accepts "--confusion_matrix", "--confusion_matrix True" and
# "--confusion_matrix False"; the last form now really disables the plot.
parser.add_argument('--confusion_matrix', type=_str2bool, nargs='?', const=True, default=False,
                    help="whether to plot confusion_matrix")

args = parser.parse_args()


# Best test accuracy seen so far (module-level state shared between functions).
best_acc = 0.

def loss_plus_regularization(loss, model):
    """Return the scalar ``g^T H g`` for ``loss`` w.r.t. the model parameters.

    ``g`` is the flattened gradient of ``loss`` with respect to
    ``model.parameters()`` and ``H g`` is obtained by differentiating ``g``
    a second time with ``g`` itself as ``grad_outputs``.

    Despite the function name, the loss itself is NOT added to the result;
    only the ``g^T H g`` term is returned.

    Args:
        loss: scalar tensor that is still attached to the autograd graph.
        model: module whose parameters the loss is differentiated against.

    Returns:
        0-dim tensor holding ``dot(g, H g)``.
    """
    params = list(model.parameters())

    # First-order gradient; create_graph=True keeps it differentiable so it
    # can be differentiated once more below.
    first_order = torch.autograd.grad(loss, params, create_graph=True)
    flat_grad = torch.cat([piece.reshape(-1) for piece in first_order])

    # Second differentiation with grad_outputs = g. retain_graph=True leaves
    # the graph of `loss` usable for a later backward() by the caller.
    second_order = torch.autograd.grad(
        flat_grad,
        params,
        grad_outputs=flat_grad,
        retain_graph=True,
    )
    flat_hvp = torch.cat([piece.reshape(-1) for piece in second_order])

    return torch.dot(flat_grad, flat_hvp)

def main(iter):
    """Run one complete training experiment and store its log and metrics.

    The experiment index is used as the random seed and as the name of the
    sub-directory of ``args.save_dir`` that receives the log file and the
    per-epoch metrics CSV.

    Args:
        iter: experiment index. (The name shadows the builtin ``iter``; it is
            kept unchanged so existing callers keep working.)
    """
    print("+"*30 + f"Experiment {iter}" + "+"*20)
    global args, best_acc, global_old_gradients
    args.seed = iter

    # Every experiment trains a new model, so state left behind by a previous
    # experiment in this process must not leak into this one.
    best_acc = 0.
    global_old_gradients = None

    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    use_gpu = torch.cuda.is_available()
    if args.use_cpu:
        use_gpu = False

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # makedirs also creates args.save_dir itself when it does not exist yet
    # (os.mkdir would raise FileNotFoundError in that case).
    dirname = osp.join(args.save_dir, f"{iter}")
    os.makedirs(dirname, exist_ok=True)

    run_tag = f'{args.dataset}-{args.model}-epoch_{args.max_epoch}-momentum_{args.momentum}_reg_{args.reg_eta}'

    # NOTE(review): sys.stdout is replaced here and never restored, so with
    # several experiments per process each new Logger is created while the
    # previous one is still installed. Left as-is because restoring it safely
    # depends on how utils.Logger handles close/__del__ -- confirm there.
    sys.stdout = Logger(osp.join(dirname, run_tag + '-log.txt'))
    print(args)
    if use_gpu:
        print("Currently using GPU: {}".format(args.gpu))
        cudnn.benchmark = True
        torch.cuda.manual_seed_all(args.seed)
    else:
        print("Currently using CPU")

    print("Creating dataset: {}".format(args.dataset))
    dataset = datasets.create(
        name=args.dataset, batch_size=args.batch_size, use_gpu=use_gpu,
        num_workers=args.workers, is_shuffle=True, n_samples=args.n_samples)

    trainloader, testloader = dataset.trainloader, dataset.testloader

    # model
    print("Creating model: {}".format(args.model))
    model = models.create(name=args.model, num_classes=dataset.num_classes)

    if use_gpu:
        model = nn.DataParallel(model).cuda()

    criterion = torch.nn.CrossEntropyLoss()

    print(args.momentum, args.lr)
    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, weight_decay=0, momentum=args.momentum)

    # NOTE: the learning-rate schedules below are only set up; nothing in the
    # epoch loop calls adjust_learning_rate() or scheduler.step(), so the
    # learning rate stays at args.lr for the whole run.
    if args.stepsize > 0:
        if args.dataset == 'cifar10':
            def adjust_learning_rate(optimizer, epoch):
                if epoch < 40:
                    lr = 0.01
                elif epoch < 100:
                    lr = args.lr
                elif epoch < 150:
                    lr = args.lr * args.gamma
                elif epoch < 200:
                    lr = args.lr * args.gamma * args.gamma
                else:
                    lr = args.lr * args.gamma * args.gamma * args.gamma

                for param_group in optimizer.param_groups:
                    param_group['lr'] = lr

        if args.dataset == 'cifar100':
            scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[100, 150, 200], gamma=args.gamma, last_epoch=args.start_epoch - 1)

    start_time = time.time()
    # One entry per *evaluated* epoch. Tracking the epoch numbers explicitly
    # keeps all columns the same length even when eval_freq != 1.
    eval_epochs = []
    train_loss_list = []
    test_acc_list = []
    train_acc_list = []
    ghg_list = []
    smoothness_list = []
    sharpness_list = []
    for epoch in range(args.max_epoch):
        print("==> Epoch {}/{}".format(epoch+1, args.max_epoch))
        train_loss, ghg, smoothness, sharpness, train_acc = train(model, criterion, optimizer, trainloader, use_gpu)

        # Evaluate every eval_freq epochs and always after the final epoch.
        is_eval_epoch = args.eval_freq > 0 and (epoch+1) % args.eval_freq == 0
        is_last_epoch = (epoch+1) == args.max_epoch
        if is_eval_epoch or is_last_epoch:
            print("==> Test")
            test_acc, _ = test(model, testloader, use_gpu, dirname)
            print("Train Loss {}\t Train Accuracy (%): {} \t Test Accuracy (%): {}\t gHg: {}\t smoothness {}\t sharpness {}".format(train_loss, train_acc, test_acc, ghg, smoothness, sharpness))

            eval_epochs.append(epoch + 1)
            train_loss_list.append(train_loss)
            test_acc_list.append(test_acc.item())
            ghg_list.append(ghg)
            smoothness_list.append(smoothness)
            sharpness_list.append(sharpness)
            train_acc_list.append(train_acc.item())

    df = pd.DataFrame({'epoch': eval_epochs,
                       'train_loss': train_loss_list,
                       'train_acc': train_acc_list,
                       'test_acc': test_acc_list,
                       'ghg': ghg_list,
                       'smoothness': smoothness_list,
                       'sharpness': sharpness_list
    })
    df.to_csv(osp.join(dirname, run_tag + '-result.csv'), index=False)

    elapsed = round(time.time() - start_time)
    elapsed = str(datetime.timedelta(seconds=elapsed))
    print("Finished. Total elapsed time (h:m:s): {}".format(elapsed))
    print("\n\n\n\n\n")
    
# Gradients of the most recent optimisation step. Kept at module level so the
# smoothness estimate in train() can compare consecutive steps across calls.
# Code that starts training a different model should reset this to None.
global_old_gradients = None


def train(model, criterion, optimizer, trainloader, use_gpu):
    """Train for one pass over ``trainloader`` and collect curvature statistics.

    For every batch this computes, in addition to the optimisation step:
      * ``g^T H g`` via ``loss_plus_regularization``;
      * the first eigenvalue returned by pyhessian's ``eigenvalues()``
        (reported as "sharpness");
      * a smoothness estimate
        ``<g_old, g_old - g_new> / (lr * ||g_old||^2)`` where ``g_old`` is the
        gradient of the previous step and ``g_new`` the current one.

    NOTE: only the plain loss is back-propagated; ``args.reg_eta`` is not
    applied to the update here.

    Args:
        model: network being trained.
        criterion: loss function called as ``criterion(outputs, labels)``.
        optimizer: optimizer whose first param group supplies the step size.
        trainloader: iterable yielding ``(data, labels)`` batches.
        use_gpu: move batches to CUDA and let pyhessian use CUDA when True.

    Returns:
        Tuple ``(avg_loss, avg_ghg, avg_smoothness, avg_sharpness, acc)`` where
        ``acc`` is the training accuracy in percent (a tensor).
    """
    global global_old_gradients

    correct, total = 0, 0

    model.train()
    losses = AverageMeter()
    ghges = AverageMeter()
    L_values = AverageMeter()
    Sharpness = AverageMeter()

    for batch_idx, (data, labels) in enumerate(trainloader):
        if use_gpu:
            data, labels = data.cuda(), labels.cuda()
        outputs = model(data)
        loss = criterion(outputs, labels)

        ghg = loss_plus_regularization(loss, model)

        # argmax over the class dimension always yields a 1-D tensor, also
        # for a batch that contains a single sample.
        predicted = outputs.argmax(dim=1)
        correct += (predicted.cpu() == labels.cpu()).sum()
        total += labels.size(0)

        hessian_comp = hessian(model, criterion, data=(data, labels), cuda=use_gpu)
        top_eigenvalues, _ = hessian_comp.eigenvalues()
        # NOTE(review): pyhessian's `hessian` appears to switch the model to
        # eval mode; restore training mode so that later batches of this
        # epoch are not silently run in eval mode.
        model.train()

        if global_old_gradients is None:
            # First step ever: there is no previous gradient. Build the zeros
            # from the parameters so this does not depend on .grad having
            # been populated as a side effect of another call.
            old_gradients = [torch.zeros_like(param) for param in model.parameters()]
        else:
            old_gradients = global_old_gradients

        optimizer.zero_grad()
        loss.backward()

        # Snapshot the gradients before the optimizer modifies the weights.
        new_gradients = [param.grad.detach().clone() for param in model.parameters()]
        optimizer.step()

        flattened_old_gradients = torch.cat([grad.view(-1) for grad in old_gradients])
        flattened_new_gradients = torch.cat([grad.view(-1) for grad in new_gradients])

        numerator = torch.dot(flattened_old_gradients, flattened_old_gradients - flattened_new_gradients)
        denominator = torch.norm(flattened_old_gradients) ** 2

        eta = optimizer.param_groups[0]['lr']
        if denominator == 0:
            # Previous gradient is all zeros (e.g. very first step): avoid 0/0.
            smoothnes_value = numerator / (eta * denominator + 1e-6)
        else:
            smoothnes_value = numerator / (eta * denominator)

        global_old_gradients = new_gradients

        losses.update(loss.item(), labels.size(0))
        ghges.update(ghg.item(), labels.size(0))
        L_values.update(smoothnes_value.item(), labels.size(0))
        Sharpness.update(top_eigenvalues[0], labels.size(0))

        if (batch_idx+1) % args.print_freq == 0:
            print("Batch {}/{}\t Loss {:.6f} ({:.6f}) \tSmoothnes {:.6f} ({:.6f})\t Sharpness {:.6f} ({:.6f})".format(
                batch_idx+1,
                len(trainloader),
                losses.val, losses.avg,
                L_values.val, L_values.avg,
                Sharpness.val, Sharpness.avg,
            ))

    acc = correct * 100. / total

    return losses.avg, ghges.avg, L_values.avg, Sharpness.avg, acc
    
    
def test(model, testloader, use_gpu, dirname):
    """Evaluate ``model`` on ``testloader`` and return accuracy and error.

    When ``args.confusion_matrix`` is set and the accuracy improves on the
    module-level ``best_acc``, the confusion matrix is plotted and the true /
    predicted labels are written to ``dirname``.

    Args:
        model: network to evaluate.
        testloader: iterable yielding ``(data, labels)`` batches.
        use_gpu: move batches to CUDA when True.
        dirname: directory receiving the confusion-matrix artefacts.

    Returns:
        Tuple ``(acc, err)`` in percent, with ``err = 100 - acc``.
    """
    # The best accuracy lives in the module-level `best_acc`; `args` (the
    # argparse namespace) has no such attribute.
    global best_acc

    model.eval()
    correct, total = 0, 0
    true_labels, pred_labels = [], []

    with torch.no_grad():
        for data, labels in testloader:
            if use_gpu:
                data, labels = data.cuda(), labels.cuda()
            outputs = model(data)

            # argmax over the class dimension keeps the result 1-D even for
            # a single-sample batch, so np.concatenate below always works.
            predicted = outputs.argmax(dim=1)
            correct += (predicted.cpu() == labels.cpu()).sum()
            total += labels.size(0)

            if args.confusion_matrix:
                true_labels.append(labels.cpu().numpy())
                pred_labels.append(predicted.cpu().numpy())

    acc = correct * 100. / total
    err = 100. - acc

    if args.confusion_matrix and best_acc < acc:
        best_acc = float(acc)
        true_labels = np.concatenate(true_labels, 0)
        pred_labels = np.concatenate(pred_labels, 0)
        cm = confusion_matrix(true_labels, pred_labels)
        plot_confusion_matrix(cm, dirname, prefix='confusion_matrix')
        np.savetxt(osp.join(dirname, 'true_labels.txt'), true_labels, fmt='%d')
        np.savetxt(osp.join(dirname, 'pred_labels.txt'), pred_labels, fmt='%d')

    return acc, err

def plot_confusion_matrix(cm, dirname, prefix):
    """Render ``cm`` as an image and save it as a one-page PDF.

    The file is written to ``<dirname>/<prefix>/confusion_matrix.pdf``; the
    directory is created when missing.

    Args:
        cm: 2-D confusion matrix accepted by ``plt.imshow``.
        dirname: base output directory supplied by the caller.
        prefix: sub-directory name below ``dirname``.
    """
    # Honour the directory passed by the caller instead of overwriting it,
    # so each caller-supplied location gets its own plot.
    out_dir = osp.join(dirname, prefix)
    os.makedirs(out_dir, exist_ok=True)
    save_name = osp.join(out_dir, 'confusion_matrix.pdf')
    with PdfPages(save_name) as pdf:
        fig = plt.figure(figsize=(6, 6), dpi=100)
        # Kept from the original: changes numpy's global print precision.
        np.set_printoptions(precision=2)
        plt.imshow(cm, cmap=plt.cm.jet)
        pdf.savefig(fig)
        plt.close(fig)

def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    """Serialize ``state`` to ``filename`` with ``torch.save``.

    Args:
        state: object to serialize (e.g. a dict with epoch and state_dict).
        is_best: accepted for interface compatibility; currently unused.
        filename: destination path (or file-like object). For a path, missing
            parent directories are created first.
    """
    if isinstance(filename, (str, os.PathLike)):
        parent = os.path.dirname(os.fspath(filename))
        if parent:
            # torch.save does not create missing directories itself.
            os.makedirs(parent, exist_ok=True)
    torch.save(state, filename)


if __name__ == '__main__':
    # Run three experiments back to back; main() receives the run index,
    # which it also assigns to args.seed.
    for i in (0, 1, 2):
        main(i)