# -*- coding: utf-8 -*-

from load_data import MNIST
from resnet_shallow import ResNetSmall

import os
import argparse
import numpy as np
import torch
from torch.autograd import grad as torchgrad
import copy

#from GPUtil import showUtilization as gpu_usage
from utils import accuracy
#
# Command-line options as (flag, type, default), registered in this order.
# --decaylr is an int used as a boolean switch: any non-zero value enables
# the one-off step-size decay in the training loop.
parser = argparse.ArgumentParser()
for _flag, _type, _default in (
    ('--seed', int, 1234),
    ('--iters', int, int(1e4)),
    ('--batchsize', int, 10),
    ('--lr', float, 0.1),
    ('--decaylr', int, 1),
    ('--datadir', str, 'data'),
    ('--logdir', str, 'logs_resnet'),
):
    parser.add_argument(_flag, type=_type, default=_default)

# Parse sys.argv once; every later section reads its settings from `args`.
args = parser.parse_args()
# Reproducibility: seed both RNGs and ask cuDNN for deterministic kernels
# (autotuning off) before anything random is constructed.
torch.manual_seed(args.seed)
np.random.seed(args.seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
torch.cuda.empty_cache()

# Data: the test set is pre-split into a list of batches; the argument to
# load_test_list is the batch size.
dataset = MNIST(args.datadir)
test_list = dataset.load_test_list(1000)

# Model, loss and plain SGD optimizer, all placed on the GPU.
# ResNetSmall is a ResNet with only 2 residual blocks.
model = ResNetSmall().cuda()
lr = args.lr
criterion = torch.nn.CrossEntropyLoss().cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=lr)

# Output CSV path; the run's hyper-parameters are encoded in the file name.
filename = "{}/lr_{}_lrdecay_{}_iters_{}_batchsize_{}_seed_{}.csv".format(
    args.logdir,
    args.lr,
    bool(args.decaylr),
    args.iters,
    args.batchsize,
    args.seed,
)

# One row of diagnostics is appended here at every evaluation step.
result = []

print("Start training, {} training data and {} testing data.".format(dataset.n_train,dataset.n_test) )
# Schedule constants for the training loop below.
LR_DECAY_STEP = 5000     # iteration at which the step size is multiplied by 0.1
EVAL_INTERVAL = 500      # evaluate and record diagnostics every this many steps
POWER_ITERATIONS = 5     # power-iteration steps used to approximate the top eigenvector


def _dot(tensors_a, tensors_b):
    """Return sum_k <a_k, b_k> over two parallel sequences of tensors."""
    return sum(torch.sum(a * b) for a, b in zip(tensors_a, tensors_b))


def _hessian_vector_product(grads, params, vecs):
    """Return the detached Hessian-vector product H.v, one tensor per parameter.

    grads: gradients of the loss w.r.t. `params`, obtained with
        create_graph=True so that they can be differentiated a second time.
    params: the parameters the Hessian is taken with respect to.
    vecs: the direction v, one detached tensor per parameter.
    """
    grad_dot_v = _dot(grads, vecs)
    # retain_graph=True: the same first-order graph is reused for every
    # subsequent Hessian-vector product.
    hv = torchgrad(grad_dot_v, params, retain_graph=True)
    return [h.detach() for h in hv]


def _rayleigh_quotient(grads, params, vecs):
    """Return (v^T H v) / (v^T v) as a float; nan when v has zero norm."""
    hv = _hessian_vector_product(grads, params, vecs)
    numerator = _dot(hv, vecs).item()
    denominator = _dot(vecs, vecs).item()
    return numerator / denominator if denominator != 0 else float('nan')


for i in range(args.iters):
    # Change the step size once, at LR_DECAY_STEP, if lr decay is enabled.
    if i == LR_DECAY_STEP and args.decaylr:
        lr *= 0.1
        print("Step size changed to {}".format(lr))
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

    # One SGD step on a freshly sampled mini-batch.
    model.train()
    optimizer.zero_grad()
    x, y = dataset.sample_train_data(args.batchsize, True)
    loss = criterion(model(x), y)
    loss.backward()
    optimizer.step()

    if i % EVAL_INTERVAL != 0:
        continue

    print("Training: {} steps finished out of {}.".format(i, args.iters))

    # Switch to eval mode *before* touching the test set and disable autograd:
    # the metrics below need no graph, and train-mode layers (e.g. batch norm
    # or dropout, if the model has any) must not see test data.
    model.eval()
    with torch.no_grad():
        # Test loss / accuracy, averaged over the test batches.
        test_loss, test_acc = 0, 0
        # Distinct loop names so the training batch (x, y) is not overwritten.
        for x_test, y_test in test_list:
            out = model(x_test)
            test_loss += criterion(out, y_test).item()
            test_acc += accuracy(out, y_test).item()
        test_loss /= len(test_list)
        test_acc /= len(test_list)

        # Training loss / accuracy on the mini-batch used for this step.
        out = model(x)
        train_acc = accuracy(out, y).item()
        train_loss = criterion(out, y).item()

    # Loss on the sample used for the curvature diagnostics.
    # NOTE(review): the call requests dataset.n_train examples; whether that
    # is a subsample depends on sample_train_data -- confirm.
    X, Y = dataset.sample_train_data(dataset.n_train)
    params = list(model.parameters())
    loss = criterion(model(X), Y)
    # A single differentiable backward pass yields every first-order gradient;
    # all Hessian-vector products below reuse it.
    grads = torchgrad(loss, params, create_graph=True)

    # Rayleigh quotient of the Hessian along the current parameter vector.
    # Detached views suffice: nothing below modifies the parameters in place.
    direction = [p.detach() for p in params]
    rayleigh = _rayleigh_quotient(grads, params, direction)

    # Power iteration v <- H v / ||H v||, started from the parameter vector,
    # to approximate the dominant eigenvector of the Hessian. Each step uses
    # only the current v (no accumulation across steps).
    v = direction
    for _ in range(POWER_ITERATIONS):
        hv = _hessian_vector_product(grads, params, v)
        norm = torch.sqrt(_dot(hv, hv))
        v = [h / norm for h in hv]

    # Largest-eigenvalue estimate: Rayleigh quotient of the final iterate.
    max_eig = _rayleigh_quotient(grads, params, v)
    ratio = rayleigh / max_eig if max_eig != 0 else float('nan')

    result.append([i, train_acc, train_loss, test_acc, test_loss, rayleigh, max_eig, ratio])

    # Drop the double-backward graph now instead of keeping it alive through
    # the training steps until the next evaluation.
    del grads, loss

# Persist the collected diagnostics: one CSV row per evaluation step with
# columns step, train_acc, train_loss, test_acc, test_loss, rayleigh,
# max_eig, rayleigh/max_eig.
print("Finished training, saving result to " + filename)
os.makedirs(args.logdir, exist_ok=True)
result_table = np.array(result)
np.savetxt(filename, result_table, delimiter=',')

# Release cached GPU memory held by the allocator before exiting.
torch.cuda.empty_cache()
