import os
import argparse
import numpy as np
import torch
import time
from torch.autograd import grad as torchgrad
# from tensorboardX import SummaryWriter
import copy
from mnist import FashionMNIST
from net import ConvSmall
from utils import *

parser = argparse.ArgumentParser()
# (flag, type, default) for every command-line option, registered in this order.
_CLI_OPTIONS = (
    ('--seed', int, 1234),
    ('--iters', int, int(1e4+1)),
    ('--batchsize', int, 25),
    ('--lr', float, 0.1),
    ('--wd', float, 0),
    ('--momentum', float, 0),
    ('--resume', str, None),
    ('--datadir', str, 'dataset'),
    ('--logdir', str, 'logs/SGD'),
)
for _flag, _type, _default in _CLI_OPTIONS:
    parser.add_argument(_flag, type=_type, default=_default)





args = parser.parse_args()
logger = LogSaver(args.logdir)
logger.save(str(args), 'args')

# Seed the NumPy and torch RNGs, then force deterministic, non-autotuned cuDNN
# kernels so runs with the same --seed are reproducible.
for _seed_fn in (np.random.seed, torch.manual_seed):
    _seed_fn(args.seed)
_cudnn = torch.backends.cudnn
_cudnn.deterministic, _cudnn.benchmark = True, False

# data
dataset = FashionMNIST(args.datadir)
_dataset_summary = str(dataset)
logger.save(_dataset_summary, 'dataset')
# Held-out evaluation batches. The (10000, True) arguments are passed straight
# through; presumably sample count and a device/shuffle flag — confirm in mnist.py.
test_list = dataset.getTestList(10000, True)

# model
model = ConvSmall().cuda()
start_iter = 0
lr = args.lr
criterion = torch.nn.CrossEntropyLoss().cuda()
# --momentum used to be parsed but never forwarded here, so it was silently
# ignored. The default of 0 keeps plain SGD.
optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=args.momentum,
                            weight_decay=args.wd)
if args.resume:
    # NOTE: torch.load unpickles the file; only resume from trusted checkpoints.
    checkpoint = torch.load(args.resume)
    start_iter = checkpoint['iter'] + 1
    # The optimizer state below also restores per-group lr/momentum. lr is kept
    # in sync so the step schedule in the training loop decays from it.
    lr = checkpoint['lr']
    model.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    logger.save("=> loaded checkpoint '{}'".format(args.resume))
logger.save(str(model), 'classifier')
logger.save(str(optimizer), 'optimizer')

def _hessian_vector_product(grads, params, vecs):
    """Return H @ vecs as a list of detached tensors, one per parameter.

    grads must be the output of torchgrad(loss, params, create_graph=True).
    The graph is retained, so this can be called repeatedly for one loss.
    """
    hv = torchgrad(grads, params, grad_outputs=vecs, retain_graph=True)
    return [h.detach() for h in hv]


# writer
# Tab-separated metrics log, one row per evaluation. On --resume, append instead
# of truncating, so rows logged before the checkpoint are kept.
log_path = os.path.join(
    args.logdir,
    "lr_" + str(args.lr) + "_batchsize-" + str(args.batchsize)
    + "_wd_" + str(args.wd) + "_seed_" + str(args.seed) + ".txt",
)
with open(log_path, "a" if args.resume else "w") as writer:
    start_time = time.time()
    # optimization
    for i in range(start_iter, args.iters):

        # Step schedule: divide the lr by 10 at fixed iterations.
        if i in (2500, 7000):
            lr *= 0.1
            logger.save('update lr: %f' % (lr))
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

        # train
        model.train()
        optimizer.zero_grad()
        x, y = dataset.getTrainBatch(args.batchsize, True)
        out = model(x)
        loss = criterion(out, y)
        loss.backward()
        optimizer.step()

        # evaluate
        if i % 100 == 0:
            model.eval()
            with torch.no_grad():
                out = model(x)
                train_acc = accuracy(out, y).item()
                train_loss = criterion(out, y).item()

            # Curvature of the full-training-set loss. The first-order graph is
            # kept (create_graph=True), so each Hessian-vector product costs one
            # extra backward pass instead of one pass per parameter tensor.
            params = list(model.parameters())
            X, Y = dataset.getTrainBatch(dataset.n_train, True)
            loss = criterion(model(X), Y)
            grads = torchgrad(loss, params, create_graph=True)

            # Rayleigh quotient w^T H w / w^T w along the current parameter vector w.
            w = [p.detach().clone() for p in params]
            Hw = _hessian_vector_product(grads, params, w)
            rayleigh = (sum(torch.sum(h * v) for h, v in zip(Hw, w))
                        / sum(torch.sum(v * v) for v in w))

            # Power iteration (5 steps) started from w. It converges toward the
            # largest-magnitude eigenvalue of H.
            v = w
            for _ in range(5):
                Hv = _hessian_vector_product(grads, params, v)
                norm = torch.sqrt(sum(torch.sum(h ** 2) for h in Hv))
                v = [h / norm for h in Hv]
            Hv = _hessian_vector_product(grads, params, v)
            max_eig = (sum(torch.sum(h * u) for h, u in zip(Hv, v))
                       / sum(torch.sum(u ** 2) for u in v))

            # Release the full-batch graph before running the test evaluation.
            del grads, loss

            test_loss, test_acc = 0, 0
            with torch.no_grad():
                for x_test, y_test in test_list:
                    out = model(x_test)
                    test_loss += criterion(out, y_test).item()
                    test_acc += accuracy(out, y_test).item()
            test_loss /= len(test_list)
            test_acc /= len(test_list)

            rayleigh_np = rayleigh.detach().cpu().numpy()
            max_eig_np = max_eig.detach().cpu().numpy()
            print('Iter:%d, Test [acc: %.2f, loss: %.4f], Train [acc: %.2f, loss: %.4f]'
                  % (i, test_acc, test_loss, train_acc, train_loss))
            print("rayleigh/max_eig: %.4f" % (rayleigh_np / max_eig_np))
            # NOTE(review): train_loss is written twice (columns 3 and 4). It is kept
            # so the column layout matches existing logs; confirm intent.
            row = (i, train_acc, train_loss, train_loss, test_loss, test_acc,
                   rayleigh_np, max_eig_np)
            writer.write("\t".join(str(c) for c in row) + "\n")
            writer.flush()

            # Elapsed wall time since training (re)started.
            print('Testing using time:', time.time() - start_time)
        if i % 1000 == 0:
            state = {'iter': i, 'lr': lr, 'model': model.state_dict(),
                     'optimizer': optimizer.state_dict()}
            torch.save(state, os.path.join(
                args.logdir, 'iter-' + str(i) + "-lr-" + str(args.lr) + '.pth.tar'))
