from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR

import time

from myhessian import hessian

import numpy as np
from sklearn import datasets
from sklearn.preprocessing import StandardScaler

def gather_flat_grad(params):
    """Concatenate the gradients of ``params`` into one 1-D tensor.

    Args:
        params: iterable of tensors whose ``.grad`` attribute is read.

    Returns:
        A 1-D tensor made of each parameter's flattened gradient, in
        iteration order.  A parameter without a gradient contributes
        ``p.numel()`` zeros; a sparse gradient is densified first.
    """
    def _flat(p):
        grad = p.grad
        if grad is None:
            # No gradient recorded: stand in with zeros of matching size/type.
            return p.new(p.numel()).zero_()
        if grad.is_sparse:
            return grad.to_dense().view(-1)
        return grad.view(-1)

    return torch.cat([_flat(p) for p in params], 0)
    
def gather_flat_data(params):
    """Concatenate the values of ``params`` into one 1-D tensor.

    Args:
        params: iterable of tensors; each one's ``.data`` is flattened.

    Returns:
        A 1-D tensor holding every parameter's values in iteration order.
    """
    return torch.cat([p.data.view(-1) for p in params], 0)


class Model(nn.Module):
    """Logistic regression: a single double-precision linear unit plus sigmoid."""

    def __init__(self, n_input_features):
        """Create the ``n_input_features -> 1`` linear layer in float64."""
        super().__init__()
        layer = nn.Linear(n_input_features, 1)
        self.linear = layer.double()

    def forward(self, x):
        """Return ``sigmoid(linear(x))`` for the input batch ``x``."""
        logits = self.linear(x)
        return torch.sigmoid(logits)


def getPartData(data_loader):
    """Return the first batch of ``data_loader`` as ``(sample, target)`` pairs.

    Args:
        data_loader: iterable yielding ``(data, target)`` batches.

    Returns:
        A list of ``(d, t)`` tuples built by zipping the first batch's data
        with its targets, or ``None`` when the loader yields no batches.
    """
    for data, target in data_loader:
        # Return straight away: waiting for a second batch (as the previous
        # implementation did) loses the data when the loader has only one.
        return list(zip(data, target))
    return None


import pickle

import math

# Wall-clock time captured when the module is loaded.
time_stamp = time.time()

# Metric histories filled in by train()/test() and pickled by main();
# each key gets its own independent list.
results = {
    metric: []
    for metric in ('train_loss', 'train_prec', 'train_gd',
                   'test_loss', 'test_prec', 'time')
}

# Mean-reduced BCE is the training objective (used in train()).
criterion = torch.nn.BCELoss(reduction='mean')

# Sum-reduced BCE lets test() accumulate per-batch losses before averaging.
criterion2 = torch.nn.BCELoss(reduction='sum')


def train(model, X_train, y_train, optimizer, epoch, wd, start_time):
    """Perform one full-batch optimisation step and record training metrics.

    Args:
        model: module mapping ``X_train`` to predictions accepted by the
            module-level ``criterion``.
        X_train: training inputs passed to ``model`` as a single batch.
        y_train: training targets compared against the model output.
        optimizer: optimiser to step; NCG and LBFGS instances are stepped
            with a closure, everything else without one.
        epoch: 1-based epoch counter, used for logging only.
        wd: weight-decay coefficient added to the reported gradient as
            ``wd * x``.
        start_time: reference timestamp for the elapsed-time record.

    Side effects:
        Prints loss, gradient norm and accuracy, and appends to the
        ``train_loss``, ``train_gd``, ``train_prec`` and ``time`` lists of
        the module-level ``results`` dict.  Loss and gradient norm are
        measured before the step, accuracy after it.
    """
    model.train()

    def closure():
        # Re-evaluates the full-batch loss and refreshes the gradients.
        y_pred = model(X_train)
        loss = criterion(y_pred, y_train)
        optimizer.zero_grad()
        loss.backward()
        return loss

    # Snapshot the iterate and the regularised gradient before stepping.
    xk = gather_flat_data(model.parameters())
    with torch.enable_grad():
        loss = closure()
    gk = gather_flat_grad(model.parameters()) + wd * xk

    if isinstance(optimizer, (optim.NCG, optim.LBFGS)):
        optimizer.step(closure)
    else:
        optimizer.step()

    if epoch == 1:
        print('|x_k|', float(torch.norm(xk)), 'g_k', float(torch.norm(gk)))

    print(f'epoch:{epoch}, loss={loss.item():.4f}, |gd|={gk.norm().item():.4e}')

    with torch.no_grad():
        y_predicted = model(X_train)
        y_predicted_cls = y_predicted.round()
        acc = y_predicted_cls.eq(y_train).sum() / float(y_train.shape[0])
        print(f'accuracy:{acc.item():.4f}')

    results['train_loss'].append(float(loss))
    # Previously the accuracy was only printed, leaving 'train_prec' empty
    # in the dumped results.
    results['train_prec'].append(float(acc))
    results['train_gd'].append(float(gk.norm()))
    results['time'].append(time.time() - start_time)
        


def test(model, device, test_loader, optimizer):
    """Evaluate ``model`` on ``test_loader`` and record loss and accuracy.

    Args:
        model: module producing one probability per sample.
        device: device each batch is moved to before the forward pass.
        test_loader: iterable of ``(data, target)`` batches exposing a
            ``dataset`` attribute whose length is the number of samples.
        optimizer: unused; kept so existing call sites keep working.

    Side effects:
        Prints the average loss and accuracy and appends them to the
        ``test_loss`` / ``test_prec`` lists of the module-level ``results``.
    """
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += criterion2(output, target).item()  # sum up batch loss

            # Threshold at 0.5 in one vectorised comparison instead of a
            # per-element Python loop.
            pred = (output >= 0.5).to(target.dtype)
            correct += pred.eq(target.view_as(pred)).sum().item()

    n_samples = len(test_loader.dataset)
    test_loss /= n_samples

    print('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
        test_loss, correct, n_samples,
        100. * correct / n_samples))
    results['test_loss'].append(float(test_loss))
    results['test_prec'].append(float(correct / n_samples))

def main():
    """Train logistic regression on the Madelon data set with STBFGS.

    Parses the command-line options, seeds the random number generators,
    loads ``madelon_train.mat`` / ``madelon_test.mat`` from the working
    directory, standardises the features, runs ``args.epochs`` full-batch
    training steps, optionally saves the model weights, and pickles the
    module-level ``results`` dict to ``args.dump_data``.

    When ``flag_eval`` is switched on and the optimiser is an STBFGS
    instance, the optimiser's eigenvalue estimates are additionally compared
    with Ritz values from the Lanczos-based ``hessian`` helper, plotted, and
    pickled.
    """
    # Training settings.  Only --epochs, --lr, --no-cuda, --seed,
    # --save-model and --dump-data are read below; the remaining options are
    # still accepted so existing command lines keep working.
    parser = argparse.ArgumentParser(description='PyTorch Madelon logistic regression example')
    parser.add_argument('--train-part-size', type=int, default=1000, metavar='N',
                        help='part size for training (default: 1000)')
    parser.add_argument('--test-part-size', type=int, default=100, metavar='N',
                        help='part size for testing (default: 100)')
    parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size', type=int, default=64, metavar='N',
                        help='input batch size for testing (default: 64)')
    parser.add_argument('--epochs', type=int, default=30, metavar='N',
                        help='number of epochs to train (default: 30)')
    parser.add_argument('--lr', type=float, default=1.0, metavar='LR',
                        help='learning rate (default: 1.0)')
    parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
                        help='Learning rate step gamma (default: 0.7)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--dry-run', action='store_true', default=False,
                        help='quickly check a single pass')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--log-interval', type=int, default=1, metavar='N',
                        help='how many batches to wait before logging training status')
    parser.add_argument('--save-model', action='store_true', default=False,
                        help='For Saving the current Model')
    # The backslash is escaped explicitly; a bare "\c" is an invalid escape
    # sequence.  The displayed help text is unchanged.
    parser.add_argument('--dump-data', default='output.ser', type=str, metavar='PATH',
                        help='path to save loss\\correct data (default: output.ser)')
    parser.add_argument('--optim', default='sgdm', type=str,
                        help='optimizer (default: SGDM)')

    args = parser.parse_args()
    use_cuda = not args.no_cuda and torch.cuda.is_available()

    # Seed every RNG in play so that runs are repeatable.
    import random
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)
    torch.backends.cudnn.deterministic = True

    device = torch.device("cuda" if use_cuda else "cpu")

    import scipy.io as io
    train_matr = io.loadmat('madelon_train.mat')
    train_fea = train_matr['train_features']
    train_gnd = train_matr['train_labels']
    test_matr = io.loadmat('madelon_test.mat')
    test_fea = test_matr['test_features']
    test_gnd = test_matr['test_labels']

    # The feature matrices are densified before scaling.
    X_train = train_fea.toarray()
    y_train = train_gnd
    X_test = test_fea.toarray()
    y_test = test_gnd

    # Fit the scaler on the training split only and reuse it for the test split.
    sc = StandardScaler()
    X_train = sc.fit_transform(X_train)
    X_test = sc.transform(X_test)

    X_train = torch.from_numpy(X_train.astype(np.float64))
    X_test = torch.from_numpy(X_test.astype(np.float64))
    y_train = torch.from_numpy(y_train.astype(np.float64))
    y_test = torch.from_numpy(y_test.astype(np.float64))
    # Presumably the stored labels are -1/+1 and this maps them to the 0/1
    # targets BCELoss expects -- TODO confirm against the .mat files.
    y_train = (y_train + 1) / 2
    y_test = (y_test + 1) / 2

    # Column vectors, matching the (N, 1) model output.
    y_train = y_train.view(y_train.shape[0], 1)
    y_test = y_test.view(y_test.shape[0], 1)

    X_train = X_train.to(device)
    y_train = y_train.to(device)
    X_test = X_test.to(device)
    y_test = y_test.to(device)

    model = Model(X_train.shape[1]).to(device)

    wd = 0

    sgd = optim.SGD(model.parameters(), lr=args.lr, weight_decay=wd, momentum=0)

    #lbfgs = optim.LBFGS(model.parameters(),lr=1,max_iter=1,history_size=20,tolerance_grad=0,tolerance_change=0,weight_decay=wd)
    #mlbfgs = optim.LBFGS(model.parameters(),lr=1,max_iter=1,history_size=1,tolerance_grad=0,tolerance_change=0,weight_decay=wd)

    stbfgs = optim.STBFGS(sgd, beta=args.lr, tao=1e-16)
    # stbfgs = optim.STBFGS(sgd,beta=1.,tao=1e-16,m=50)    # also restart at 10th iteration, for the case wd=0 or 1e-4
    #
    # compare
    #anderson1 = optim.Anderson1(sgd,beta=1,hist_length=20)
    #anderson2 = optim.Anderson2(sgd,beta=1,hist_length=20)

    #nag = optim.NAG(model.parameters(), mu=0.10132827,L=1.4479628, weight_decay=wd)    #wd=1e-1
    #nag = optim.NAG(model.parameters(), mu=0.00111464,L=1.0993723, weight_decay=wd)    #wd=1e-4
    #nag = optim.NAG(model.parameters(), mu=0.01109528,L=1.1740165, weight_decay=wd)     #wd=1e-2
    #nag = optim.NAG(model.parameters(), mu=1.0112892e-03,L=1.0967612e+00, weight_decay=wd)      #wd=0
    #nag = optim.NAG(model.parameters(), mu=0.1,L=2, weight_decay=wd)
    #ncg = optim.NCG(model.parameters(), lr=args.lr, weight_decay=wd,line_search_fn = "strong_wolfe")
    #stam2 = optim.STAM2(sgd,beta=1.0,tao=1e-16)
    #optimizer = ncg

    optimizer = stbfgs

    # scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
    start_time = time.time()
    for epoch in range(1, args.epochs + 1):
        train(model, X_train, y_train, optimizer, epoch, wd, start_time)
        #test(model, device, test_loader,optimizer)
        # scheduler.step()

    if isinstance(optimizer, optim.NCG):
        print('func_evals:', optimizer.get_func_evals())

    if args.save_model:
        torch.save(model.state_dict(), "madelon_solution.pt")

    # Context manager so the file is flushed and closed deterministically.
    with open(args.dump_data, 'wb') as dump_file:
        pickle.dump(results, dump_file)

    flag_eval = False               # whether to evaluate Ritz values
    if isinstance(optimizer, optim.STBFGS) and flag_eval:

        eig = optimizer.geteig()

        def my_criterion(a, b, w, model):
            # Mean BCE plus the L2 penalty 0.5 * w * ||params||^2.
            tmp_sum = 0.
            for p in model.parameters():
                tmp_sum += torch.sum(p * p)
            return torch.nn.BCELoss(reduction='mean')(a, b) + 0.5 * w * tmp_sum

        model.eval()

        inputs, targets = X_train, y_train

        # Compute the Ritz values using Lanczos algorithm
        # We used the implementation from https://github.com/amirgholami/pyhessian.git (under an MIT license)
        # The cuda flag follows the device actually in use instead of being
        # hard-coded to True, so CPU runs are not asked to use CUDA.
        hessian_comp = hessian(model, my_criterion, weight_decay=wd,
                               data=(inputs, targets), cuda=use_cuda)
        density_eigen, _ = hessian_comp.density()

        eig_tmp = density_eigen[0]

        print('****************')
        print(np.sort(eig))
        print(np.sort(eig_tmp))

        import matplotlib.pyplot as plt
        print(len(eig_tmp), len(eig))
        plt.tick_params(labelsize='xx-large')
        plt.scatter(eig_tmp, np.zeros(len(eig_tmp)), label='Lanczos', color='darkorange', marker='o')
        plt.scatter(eig.real, eig.imag, label='Min-AM', color='blue', marker='+')
        plt.legend(fontsize='x-large')
        plt.show()
        with open('minam_eig_lr' + str(args.lr) + '.ser', 'wb') as eig_file:
            pickle.dump((eig, eig_tmp), eig_file)

        #plt.savefig('madelon_eigen_'+str(wd)+'_'+str(args.lr)+'_'+ str(time_stamp)+'.pdf',format='pdf',dpi=120,bbox_inches='tight')


# Run the experiment only when this file is executed as a script, not when
# it is imported.
if __name__ == '__main__':
    main()
