from __future__ import print_function
import logging
import os
import sys
import datetime
import time 

import numpy as np
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable

from utils import *
from tqdm import tqdm, trange

from pyhessian.utils import get_params_grad, group_product, normalization
from pyhessian.hessian import hessian


from models.simple_net import simple_net, simpler_net
from models.vgg import vgg11


# Small constant added to denominators in this script to avoid division by zero.
EPS = 1e-24

# Training settings
parser = argparse.ArgumentParser(description='Training ')

# ---- data -------------------------------------------------------------------
parser.add_argument('--dataset-type',
                    type=str,
                    default='cifar10',
                    # only these two datasets have loaders set up in this script
                    choices=['cifar10', 'mnist'],
                    help='choose dataset (default: cifar10)')
parser.add_argument('--dataset-ratio',
                    default=1.0,
                    type=float,
                    help='subsampling ratio from training data (default: 1)')
parser.add_argument('--normalization',
                    action='store_true',
                    help='do we use normalization or not')
parser.add_argument('--data_augmentation',
                    action='store_true',
                    help='do we use data_augmentation or not')

# ---- optimisation -----------------------------------------------------------
parser.add_argument('--batch-size',
                    type=int,
                    default=10000,
                    help='input batch size for training (default: 10000)')
parser.add_argument('--test-batch-size',
                    type=int,
                    default=10000,
                    help='input batch size for testing (default: 10000)')
parser.add_argument('--hessian-bs',
                    type=int,
                    default=2500,
                    help='input batch size for the Hessian computation (default: 2500)')
parser.add_argument('--epochs',
                    type=int,
                    default=4000,
                    help='number of epochs to train (default: 4000)')
parser.add_argument('--lr',
                    type=float,
                    default=0.04,
                    help='learning rate (default: 0.04)')
parser.add_argument('--weight-decay',
                    default=0,
                    type=float,
                    help='weight decay (default: 0)')
parser.add_argument('--momentum',
                    default=0,
                    type=float,
                    help='momentum (default: 0)')
parser.add_argument('--model',
                    type=str,
                    default='6CNN',
                    help='choose model structure (default: 6CNN)')

# ---- hardware ---------------------------------------------------------------
# NOTE: the next two flags are `store_false`: the feature is ON by default and
# passing the flag turns it OFF.
parser.add_argument('--cuda',
                    action='store_false',
                    help='pass this flag to disable the GPU (GPU is used by default)')
parser.add_argument('--parallel',
                    action='store_false',
                    help='pass this flag to disable DataParallel (enabled by default)')

# ---- bookkeeping ------------------------------------------------------------
parser.add_argument('--saving-folder',
                    type=str,
                    default='pretrained/',
                    help='folder in which checkpoints are stored (default: pretrained/)')
parser.add_argument('--savemodels',
                    action='store_true',
                    help='save models')
parser.add_argument('--name',
                    type=str,
                    default='noname',
                    help='choose saving name')
# `store_false` as well: an existing run folder is reused unless this is passed.
parser.add_argument('--overwrite',
                    action='store_false',
                    help='pass this flag to save into a new timestamped folder '
                         'instead of reusing an existing one')
parser.add_argument('--seed',
                    type=int,
                    default=1,
                    help='random seed (default: 1)')

# ---- diagnostics / loss -----------------------------------------------------
parser.add_argument('--sn-iter',
                    type=int,
                    default=200,
                    help='max iteration for power iteration (default: 200)')

parser.add_argument('--criterion',
                    type=str,
                    default='cross-entropy',
                    # the only two criteria this script knows how to build
                    choices=['cross-entropy', 'mse'],
                    help='training loss (default: cross-entropy)')

parser.add_argument('--reg_w',
                    default=0,
                    type=float,
                    help='weight of the Jacobian regulariser (default: 0)')

args = parser.parse_args()

# Launch timestamp; it makes the log file name unique per run.
time_now = datetime.datetime.now().strftime('%Y%m%d%H%M%S')
print(time_now)

# logging.basicConfig opens the log file right away and raises if the folder
# does not exist, so create it first.
os.makedirs('./log', exist_ok=True)
logger = logging.getLogger(__name__)
logging.basicConfig(
    filename='./log/' + args.name + '_' + time_now + '.log',
    format='[%(asctime)s] - %(message)s',
    datefmt='%Y/%m/%d %H:%M:%S',
    level=logging.DEBUG
    )

# Record the full configuration at the top of the log file.
logger.info(args)

# set random seed to reproduce the work
torch.manual_seed(args.seed)
if args.cuda:
    torch.cuda.manual_seed(args.seed)

# Echo every option to stdout, one "name value" pair per line.
for arg, value in vars(args).items():
    print(arg, value)

# get dataset
# Build the (train, test) loaders plus a second training-side loader that uses
# the Hessian batch size. getData comes from the project's utils module and is
# unpacked as a pair of loaders.
_loader_kwargs = dict(test_bs=args.test_batch_size,
                      data_augmentation=args.data_augmentation,
                      normalization=args.normalization)

if args.dataset_type == 'cifar10':
    if args.dataset_ratio == 1.0:
        # Full training set; the Hessian loader is limited to the first
        # `hessian_bs` training indices.
        train_loader, test_loader = getData(name='cifar10',
                                            train_bs=args.batch_size,
                                            **_loader_kwargs)
        hessian_loader, _ = getData(name='cifar10',
                                    train_bs=args.hessian_bs,
                                    subset=list(range(args.hessian_bs)),
                                    **_loader_kwargs)
    else:
        # Sub-sample the first fraction of the training indices; 50000 is the
        # training-set size hard-coded by the original script.
        _n_train = int(50000 * args.dataset_ratio)
        train_loader, test_loader = getData(name='cifar10',
                                            train_bs=args.batch_size,
                                            subset=list(range(_n_train)),
                                            **_loader_kwargs)
        hessian_loader, _ = getData(name='cifar10',
                                    train_bs=args.hessian_bs,
                                    subset=list(range(_n_train)),
                                    **_loader_kwargs)

elif args.dataset_type == 'mnist':
    train_loader, test_loader = getData(name='mnist',
                                        train_bs=args.batch_size,
                                        **_loader_kwargs)
    hessian_loader, _ = getData(name='mnist',
                                train_bs=args.hessian_bs,
                                **_loader_kwargs)

else:
    # Without this branch the loaders stay undefined and the script only fails
    # later with a NameError.
    raise ValueError("Unknown dataset")
    
# get model and optimizer
# Map each --model value to a zero-argument factory so that only the selected
# architecture is ever constructed.
_model_factories = {
    '6CNN': lambda: simple_net(in_channel=3, widen_factor=1, n_fc=100, num_classes=10),
    'VGG': lambda: vgg11(),
    '3FCN': lambda: get_model('fc', dataset_type=args.dataset_type),
}
if args.model not in _model_factories:
    raise ValueError("Unknown model")
model = _model_factories[args.model]()

# Move to the GPU first, then (optionally) wrap in DataParallel.
if args.cuda:
    model = model.cuda()
if args.parallel:
    model = torch.nn.DataParallel(model)
print(model)
print('# params:', count_parameters(model))

# Plain SGD over all model parameters, configured from the command line.
_sgd_options = dict(lr=args.lr,
                    momentum=args.momentum,
                    weight_decay=args.weight_decay)
optimizer = optim.SGD(model.parameters(), **_sgd_options)


# Prepare the checkpoint folder and store the untrained weights as model_0.
if args.savemodels:
    os.makedirs(args.saving_folder, exist_ok=True)

    _run_dir = args.saving_folder + args.name
    if not os.path.isdir(_run_dir):
        os.mkdir(_run_dir)
    elif not args.overwrite:
        # Folder already taken and reuse is disabled: append the launch
        # timestamp to the run name and create a fresh folder.
        args.name = args.name + time_now
        os.mkdir(args.saving_folder + args.name)
    # (existing folder + overwrite enabled: reuse it as is)
    print(args.saving_folder + args.name)

    # The initial checkpoint holds the bare state_dict.
    _initial_checkpoint = args.saving_folder + args.name + '/model_0.pth'
    torch.save(model.state_dict(), _initial_checkpoint)


# One entry per epoch: the top Hessian eigenvalues returned by pyhessian.
evl_list = []

def _one_hot(tensor: torch.Tensor, num_classes: int, default=0):
    """Return a float one-hot encoding of integer class labels.

    Args:
        tensor: integer class indices (anything ``F.one_hot`` accepts).
        num_classes: size of the trailing one-hot dimension.
        default: value written into every "off" position.

    Returns:
        A float tensor with 1.0 at the label position and ``default``
        everywhere else.
    """
    hot = F.one_hot(tensor, num_classes)
    # Convert to float BEFORE filling: F.one_hot yields an integer tensor, and
    # writing a fractional `default` into it would be truncated.
    encoded = hot.float()
    encoded[hot == 0] = default
    return encoded


class SquaredLoss(nn.Module):
    """Half squared error against one-hot targets, averaged over the batch."""

    def forward(self, input: torch.Tensor, target: torch.Tensor):
        # The class count is read from the logit width instead of being fixed
        # to 10, so the loss follows whatever head the model has.
        residual = input - _one_hot(target, input.shape[-1])
        return 0.5 * (residual ** 2).sum() / len(input)


class SquaredLoss_sum(nn.Module):
    """Half squared error against one-hot targets, summed over the batch."""

    def forward(self, input: torch.Tensor, target: torch.Tensor):
        residual = input - _one_hot(target, input.shape[-1])
        return 0.5 * (residual ** 2).sum()
    
# Select the loss pair: `criterion` (batch-averaged) is handed to the Hessian
# computation, `criterion_sum` (batch-summed) is used for the training loss.
if args.criterion == 'cross-entropy':
    criterion = nn.CrossEntropyLoss()
    criterion_sum = nn.CrossEntropyLoss(reduction='sum')
elif args.criterion == 'mse':
    criterion = SquaredLoss()
    criterion_sum = SquaredLoss_sum()
else:
    # Fail here instead of with a NameError once training starts.
    raise ValueError("Unknown criterion")
                           
# Take the first batch of the Hessian loader once and reuse it for every
# Hessian computation. Despite its name, `hessian_dataloader` is a single
# (inputs, targets) pair, not a loader.
_first_hessian_batch = next(iter(hessian_loader), None)
if _first_hessian_batch is None:
    # An empty loader would otherwise only surface later as a NameError.
    raise RuntimeError("hessian_loader yielded no batches")
data_h, target_h = _first_hessian_batch
if args.cuda:
    data_h, target_h = data_h.cuda(), target_h.cuda()
hessian_dataloader = (data_h, target_h)

def _jacobian_sq_norm(output, params, direction, create_graph=False):
    """Return sum_p || grad_p(<direction, output>) / batch_size ||_2^2.

    Args:
        output: model outputs; the first dimension is used as the batch size.
        params: parameters to differentiate with respect to.
        direction: tensor shaped like ``output``, passed as ``grad_outputs``.
        create_graph: keep the result differentiable (needed when it is added
            to the training loss).

    The graph of ``output`` is always retained because the caller still has to
    back-propagate the training loss through it afterwards.
    """
    # One backward pass for all parameters instead of one pass per parameter.
    grads = torch.autograd.grad(outputs=output,
                                inputs=params,
                                grad_outputs=direction,
                                retain_graph=True,
                                create_graph=create_graph)
    batch_size = output.shape[0]
    return sum(torch.norm(grad / batch_size, p=2) ** 2 for grad in grads)


# `whole_fix_u` is never read below, but drawing it advances the global CPU
# RNG; it is kept so the random stream of seeded runs does not shift.
whole_fix_u = torch.randn(1,10)

# Every epoch performs ONE optimizer step: gradients are accumulated over all
# batches of the training loader (each batch loss is divided by the dataset
# size), i.e. full-batch gradient descent. At the first batch of each epoch
# the script also records curvature diagnostics.
#
# Columns written to the log (indices into log_list, unlisted ones stay 0):
#   0 epoch | 1 train seconds | 2 test seconds | 3 lr | 4 train loss
#   5 train acc | 8 test loss | 9 test acc
#   12 |<normalized grad, top Hessian eigenvector>|
#   13 |<previous top eigenvector, current top eigenvector>|
#   14 top Hessian eigenvalue | 15 squared gradient norm
#   16 <previous normalized grad, current normalized grad>
#   17 mean of the Hessian trace estimates
#   18 mean top eigenvalue of the logit Hessian | 22 mean second eigenvalue
#   23 norm of the Jacobian along the all-ones logit direction
#   30 mean trace of the logit Hessian
#   35-38 std / 75% quantile / 25% quantile / median of column 18's per-sample values
for epoch in range(1, args.epochs + 1):
    print('Current Epoch: ', epoch)
    train_loss = 0.
    total_num = 0
    correct = 0

    log_list = [0]*40
    start_time = time.time()
    lr = optimizer.param_groups[0]['lr']

    # model.parameters() yields the same tensors whether or not the model is
    # wrapped in DataParallel, so this also works when --parallel is disabled.
    trainable_params = [param for param in model.parameters() if param.requires_grad]

    optimizer.zero_grad()

    with tqdm(total=len(train_loader.dataset)) as progressbar:

        for batch_idx, (data, target) in enumerate(train_loader):
            model.train()
            if args.cuda:
                data, target = data.cuda(), target.cuda()

            ### compute the Hessian
            if batch_idx == 0:
                # NOTE(review): pyhessian's `hessian` wrapper is believed to
                # switch the model to eval mode; if so, the forward pass of
                # this first batch runs in eval mode -- confirm this is intended.
                model.zero_grad()
                hessian_comp = hessian(model,
                                       criterion,
                                       data=hessian_dataloader,
                                       cuda=args.cuda)
                top_eigenvalues, top_eigenvectors = hessian_comp.eigenvalues()   # -> log column 14

                trace = hessian_comp.trace()
                evl_list.append(top_eigenvalues)
                # Drop any gradients left behind by the Hessian computation so
                # they do not leak into the accumulated training gradient.
                model.zero_grad()

            output = model(data)
            # Width of the logit vector; replaces the hard-coded 10.
            num_classes = output.shape[1]

            # Optional regulariser: squared norm of the Jacobian along one
            # random unit direction in logit space (the same direction is
            # repeated for every sample of the batch).
            reg = 0
            if args.reg_w != 0:
                # Sampled on the CPU as before, then moved to wherever the
                # logits live so that CPU-only runs work too.
                rand_batch = torch.randn(1, num_classes).repeat(output.shape[0], 1).to(output.device)
                rand_batch_nm = rand_batch/(torch.norm(rand_batch,p=2,dim=-1,keepdim=True)+EPS)
                reg = _jacobian_sq_norm(output, trainable_params, rand_batch_nm,
                                        create_graph=True)

            loss = criterion_sum(output, target)/len(train_loader.dataset)+args.reg_w*reg

            if batch_idx == 0:
                # Per-sample softmax probabilities, sorted in descending order.
                p, _ = torch.softmax(output, dim=1).sort(descending=True)
                p_ = p.unsqueeze(-1)
                ppt = p_.bmm(p_.transpose(1,2))
                diagp = torch.diag_embed(p)
                # M = diag(p) - p p^T, one matrix per sample.
                M = diagp-ppt

                ## the top eigenvalue of the logit Hessian
                snM = batch_power_iteration_evl(M,args.sn_iter)   # -> log column 18

                # Direction used to deflate M below; presumably the eigenvector
                # belonging to snM -- TODO confirm.
                q_1 = p/(p-snM+EPS)
                q_1_nm = q_1/(torch.norm(q_1,p=2,dim=-1,keepdim=True)+EPS)

                # Remove the top component so power iteration finds the next one.
                M2= M - snM.unsqueeze(-1)*torch.bmm(q_1_nm.unsqueeze(-1),torch.transpose(q_1_nm.unsqueeze(-1),2,1))

                ## the second largest eigenvalue of the logit Hessian
                snM2 = batch_power_iteration_evl(M2,args.sn_iter)

                # trace(M) = sum_i (p_i - p_i^2) = 1 - sum_i p_i^2
                trM = 1-(p**2).sum(dim=-1,keepdim=True)

                # Jacobian norm along the normalized all-ones logit direction.
                one_nm = torch.ones_like(q_1_nm)/(num_classes**(1/2))
                J_1_norm_sq = _jacobian_sq_norm(output, trainable_params, one_nm)   # -> log column 23

            train_loss += loss.item()
            total_num += target.size()[0]
            _, predicted = output.max(1)
            correct += predicted.eq(target).sum().item()

            # Accumulate; the single optimizer step happens after the loop.
            loss.backward()

            progressbar.set_postfix(loss=train_loss,
                                    acc=100. * correct / total_num)
            progressbar.update(target.size(0))

    # Capture the accumulated gradient before the parameters are updated.
    _, grads = get_params_grad(model)
    optimizer.step()
    ################# Test after update

    gq = np.abs(group_product(normalization(grads),top_eigenvectors[0]).item())

    if epoch!=1:
        qq = np.abs(group_product(tmp_eigenvectors[0],top_eigenvectors[0]).item())
        gg = group_product(normalization(tmp_grads),normalization(grads)).item()
    else:
        # No previous epoch to compare against.
        qq = 0
        gg = 0
    tmp_eigenvectors = top_eigenvectors
    tmp_grads = grads

    e_train_loss = train_loss
    e_train_acc = correct/total_num
    train_time = time.time()

    acc, test_loss = test(model, test_loader)
    test_time = time.time()

    log_list[:10] = [epoch, train_time - start_time, test_time - train_time, lr,
                e_train_loss, e_train_acc,0,0,test_loss, acc]
    log_list[12] = gq
    log_list[13] = qq
    log_list[14] = evl_list[-1][0]
    log_list[15] = group_product(grads,grads).item()
    log_list[16] = gg
    log_list[17] = np.mean(trace) ###
    log_list[18] = snM.mean().item()
    log_list[22] = snM2.mean().item()
    log_list[23] = J_1_norm_sq.item()**(1/2)
    log_list[30] = trM.mean().item()
    log_list[35] = snM.std().item()
    log_list[36] = snM.quantile(0.75).item()
    log_list[37] = snM.quantile(0.25).item()
    log_list[38] = snM.median().item()

    # One tab-separated line per epoch: integer epoch, then 39 floats.
    log_input = tuple(log_list)
    logger.info(('%d'+'\t%.4f'*(len(log_input)-1))%(log_input))

    ### every epoch
    if args.savemodels:
        torch.save({
                'epoch':epoch,
                'train_loss':e_train_loss,
                'model_state_dict':model.state_dict()},
                args.saving_folder + args.name + '/model_' + str(epoch)+ '.pth' )
