from __future__ import print_function
import logging
import os
import sys
import datetime
import time 

import numpy as np
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable

from utils import *
from tqdm import tqdm, trange

from pyhessian.utils import get_params_grad, group_product, normalization
from pyhessian.hessian import hessian


from models.simple_net import simple_net, simpler_net, simple_cnn
from models.vgg import vgg11

import setproctitle

# Small constant added to denominators to avoid division by zero.
EPS = 1e-24

# Training settings
parser = argparse.ArgumentParser(description='Training ')

# ---- data ----
parser.add_argument('--dataset-type',
                    type=str,
                    default='cifar10',
                    help="dataset to use: 'cifar10' or 'mnist' (default: cifar10)")
parser.add_argument('--dataset-ratio',
                    default=1.0,
                    type=float,
                    help='fraction of the CIFAR-10 training set to use (default: 1.0)')
parser.add_argument('--normalization',
                    action='store_true',
                    help='enable input normalization in the data loaders')
parser.add_argument('--data_augmentation',
                    action='store_true',
                    help='enable data augmentation in the data loaders')

# ---- optimisation ----
parser.add_argument('--batch-size',
                    type=int,
                    default=128,
                    help='input batch size for training (default: 128)')
parser.add_argument('--test-batch-size',
                    type=int,
                    default=10000,
                    help='input batch size for testing (default: 10000)')
parser.add_argument('--hessian-bs',
                    type=int,
                    default=2500,
                    help='input batch size for Hessian computation (default: 2500)')
parser.add_argument('--epochs',
                    type=int,
                    default=4000,
                    help='number of epochs to train (default: 4000)')
parser.add_argument('--lr',
                    type=float,
                    default=0.04,
                    help='learning rate (default: 0.04)')
parser.add_argument('--weight-decay',
                    default=0,
                    type=float,
                    help='weight decay (default: 0)')
parser.add_argument('--momentum',
                    default=0,
                    type=float,
                    help='momentum (default: 0)')
parser.add_argument('--model',
                    type=str,
                    default='6CNN',
                    help="model architecture: '6CNN', 'VGG', '3FCN' or 'SimpleCNN' (default: 6CNN)")

# ---- hardware ----
# NOTE: --cuda and --parallel use action='store_false': both features are ON by
# default and passing the flag turns them OFF.
parser.add_argument('--cuda',
                    action='store_false',
                    help='pass to disable GPU (CUDA) use; the GPU is used by default')
parser.add_argument('--parallel',
                    action='store_false',
                    help='pass to disable torch.nn.DataParallel wrapping (enabled by default)')

# ---- saving / bookkeeping ----
parser.add_argument('--saving-folder',
                    type=str,
                    default='pretrained/',
                    help='root folder for saved models; must end with a path separator '
                         '(default: pretrained/)')
parser.add_argument('--savemodels',
                    action='store_true',
                    help='save models')
parser.add_argument('--name',
                    type=str,
                    default='noname',
                    help='run name used for the log file, process title and save sub-folder '
                         '(default: noname)')
# NOTE: store_false again -- overwriting is the default; passing the flag makes
# the run save under a timestamp-suffixed name instead.
parser.add_argument('--overwrite',
                    action='store_false',
                    help='pass to avoid overwriting an existing save folder (a timestamp is '
                         'appended to --name instead); overwriting is the default')
parser.add_argument('--seed',
                    type=int,
                    default=1,
                    help='random seed (default: 1)')


parser.add_argument('--sn-iter',
                    type=int,
                    default=200,
                    help='max iteration for power iteration (default: 200)')

# ---- loss ----
parser.add_argument('--criterion',
                    type=str,
                    default='cross-entropy',
                    help="loss function: 'cross-entropy' or 'mse' (default: cross-entropy)")

parser.add_argument('--reg_w',
                    default=0,
                    type=float,
                    help='weight of the random-direction Jacobian-norm regularizer; '
                         '0 disables it (default: 0)')

args = parser.parse_args()

# Run timestamp; combined with --name it makes the log file and process title unique.
time_now = datetime.datetime.now().strftime('%Y%m%d%H%M%S')
print(time_now)
logger = logging.getLogger(__name__)
# basicConfig(filename=...) opens the log file immediately, which raises
# FileNotFoundError when the directory is missing, so make sure it exists.
log_dir = './log'
os.makedirs(log_dir, exist_ok=True)
logging.basicConfig(
    filename=os.path.join(log_dir, args.name + '_' + time_now + '.log'),
    format='[%(asctime)s] - %(message)s',
    datefmt='%Y/%m/%d %H:%M:%S',
    level=logging.DEBUG
    )

logger.info(args)

setproctitle.setproctitle(args.name+'_'+time_now)

# set random seed to reproduce the work
torch.manual_seed(args.seed)
if args.cuda:
    torch.cuda.manual_seed(args.seed)

# Echo the full configuration to stdout.
for arg in vars(args):
    print(arg, getattr(args, arg))

# get dataset
# Keyword arguments shared by every getData call below.
loader_kwargs = dict(train_bs=args.batch_size,
                     test_bs=args.test_batch_size,
                     data_augmentation=args.data_augmentation,
                     normalization=args.normalization)
if args.dataset_type == 'cifar10':
    if args.dataset_ratio != 1.0:
        # Indices 0 .. int(50000 * ratio) - 1 of the 50000-image CIFAR-10 training
        # set; presumably getData restricts the training data to these indices.
        loader_kwargs['subset'] = list(range(int(50000 * args.dataset_ratio)))
    train_loader, test_loader = getData(name='cifar10', **loader_kwargs)
elif args.dataset_type == 'mnist':
    # NOTE: --dataset-ratio is not applied for MNIST.
    train_loader, test_loader = getData(name='mnist', **loader_kwargs)
else:
    # Fail fast instead of a NameError on train_loader further down.
    raise ValueError("Unknown dataset type: %s" % args.dataset_type)
    
# get model and optimizer
# Zero-argument builders keyed by the --model value; only the selected one runs.
model_builders = {
    '6CNN': lambda: simple_net(in_channel=3, widen_factor=1, n_fc=100, num_classes=10),
    'VGG': lambda: vgg11(),
    '3FCN': lambda: get_model('fc', dataset_type=args.dataset_type),
    'SimpleCNN': lambda: simple_cnn(in_channel=3, widen_factor=1, n_fc=128, num_classes=10),
}
if args.model not in model_builders:
    raise ValueError("Unknown model")
model = model_builders[args.model]()

# Optionally move to the GPU, then optionally wrap in torch.nn.DataParallel.
model = model.cuda() if args.cuda else model
model = torch.nn.DataParallel(model) if args.parallel else model
print(model)
print('# params:', count_parameters(model))

# Plain SGD; momentum and weight decay both default to 0 on the command line.
optimizer = optim.SGD(
    model.parameters(),
    lr=args.lr,
    momentum=args.momentum,
    weight_decay=args.weight_decay,
)


if args.savemodels:
    # Save paths are built by plain string concatenation,
    # <saving_folder><name>/..., so --saving-folder should end with a separator.
    os.makedirs(args.saving_folder, exist_ok=True)
    run_dir = args.saving_folder + args.name
    if not os.path.isdir(run_dir):
        os.mkdir(run_dir)
    elif not args.overwrite:
        # Leave the existing run folder alone and save under a timestamped name.
        args.name += time_now
        run_dir = args.saving_folder + args.name
        os.mkdir(run_dir)
    print(run_dir)

    # Snapshot of the weights before any training step.
    torch.save(model.state_dict(), run_dir + '/model_0.pth')


evl_list = []

def _one_hot(tensor: torch.Tensor, num_classes: int, default=0):
    """Return a float one-hot encoding of ``tensor`` with ``default`` in the off slots.

    Args:
        tensor: integer class indices; values must lie in ``[0, num_classes)``
            as required by ``F.one_hot``.
        num_classes: size of the appended one-hot dimension.
        default: value written wherever the encoding is 0; the "hot"
            positions stay 1.

    Returns:
        A float32 tensor of shape ``tensor.shape + (num_classes,)``.
    """
    # Cast to float *before* filling: F.one_hot returns an int64 tensor, and
    # writing a non-integer ``default`` into it would silently truncate it.
    encoded = F.one_hot(tensor, num_classes).float()
    encoded[encoded == 0] = default
    return encoded
class SquaredLoss(nn.Module):
    """Half squared error against one-hot targets, averaged over the batch."""

    def forward(self, input: torch.Tensor, target: torch.Tensor):
        """Return ``0.5 * sum((input - one_hot(target)) ** 2) / len(input)``.

        Assumes ``input`` holds per-class scores with the class dimension last,
        e.g. ``(batch, num_classes)``, and ``target`` holds integer class
        indices; the number of classes is read from ``input`` instead of being
        hard-coded, so a mismatched width can no longer broadcast silently.
        """
        num_classes = input.size(-1)
        residual = input - _one_hot(target, num_classes)
        return 0.5 * (residual ** 2).sum() / len(input)
class SquaredLoss_sum(nn.Module):
    """Half squared error against one-hot targets, summed (not averaged) over the batch."""

    def forward(self, input: torch.Tensor, target: torch.Tensor):
        """Return ``0.5 * sum((input - one_hot(target)) ** 2)`` as a scalar.

        Assumes ``input`` holds per-class scores with the class dimension last,
        e.g. ``(batch, num_classes)``, and ``target`` holds integer class
        indices; the number of classes is read from ``input``.
        """
        num_classes = input.size(-1)
        residual = input - _one_hot(target, num_classes)
        return 0.5 * (residual ** 2).sum()
    
# Loss selected by --criterion: a batch-mean version and a sum-reduced version.
if args.criterion == 'cross-entropy':
    criterion = nn.CrossEntropyLoss()
    criterion_sum = nn.CrossEntropyLoss(reduction='sum')
elif args.criterion == 'mse':
    criterion = SquaredLoss()
    criterion_sum = SquaredLoss_sum()
else:
    # Fail fast instead of a NameError the first time `criterion` is used.
    raise ValueError("Unknown criterion: %s" % args.criterion)

# NOTE(review): not referenced in the visible part of this script; presumably
# used later. Drawing it also advances torch's RNG, so it is kept in place.
whole_fix_u = torch.randn(1,10)
for epoch in range(1, args.epochs + 1):
    print('Current Epoch: ', epoch)
    # Per-epoch running statistics (train_clean_loss and clean_correct are not
    # updated in the loop body below).
    train_loss = 0.
    train_clean_loss = 0.
    total_num = 0
    correct = 0
    clean_correct = 0

    log_list = [0]*12
    start_time = time.time()
    # Learning rate of the first parameter group, logged with every step.
    lr = optimizer.param_groups[0]['lr']

    with tqdm(total=len(train_loader.dataset)) as progressbar:

        for batch_idx, (data, target) in enumerate(train_loader):
            # Global step index across epochs, starting at 0.
            step = (epoch-1)*len(train_loader) + batch_idx
            model.train()
            if args.cuda:
                data, target = data.cuda(), target.cuda()

            output = model(data)
            reg = 0
            if args.reg_w != 0:
                # Regulariser: for one random unit direction u shared by every
                # sample, take d<output, u>/d(param) summed over the batch,
                # divide by the batch size and add its squared L2 norm.
                # create_graph=True keeps it differentiable for loss.backward().
                num_outputs = output.shape[1]
                rand_batch = torch.randn(1, num_outputs).repeat(output.shape[0], 1).to(output.device)
                rand_batch_nm = rand_batch/(torch.norm(rand_batch,p=2,dim=-1,keepdim=True)+EPS)
                # model.parameters() yields the same tensors with or without the
                # DataParallel wrapper (model.module only exists when wrapped).
                for param in model.parameters():
                    if not param.requires_grad:
                        continue
                    reg_J = torch.autograd.grad(outputs=output, inputs=param, grad_outputs=rand_batch_nm,
                                                retain_graph=True, create_graph=True)[0]/output.shape[0]
                    reg += torch.norm(reg_J,p=2)**2

            # Use the loss chosen by --criterion so that `--criterion mse`
            # actually changes the training objective.
            loss = criterion(output, target)+args.reg_w*reg

            # Sample-weighted running loss (regulariser included) and accuracy.
            train_loss += target.size()[0]*loss.item()
            total_num += target.size()[0]
            _, predicted = output.max(1)
            correct_in_batch = predicted.eq(target).sum().item()
            correct += correct_in_batch

            optimizer.zero_grad()
            loss.backward()

            progressbar.set_postfix(loss=train_loss/total_num,
                                    acc=100. * correct / total_num)
            progressbar.update(target.size(0))

            optimizer.step()
            optimizer.zero_grad()

            # Per-batch (not running) loss and accuracy for the step log.
            e_train_loss = loss.item()
            e_train_acc = correct_in_batch/target.size()[0]
            train_time = time.time()
            # Evaluate every 10 steps; in between, the previous acc/test_loss are
            # logged again. Step 0 is always evaluated, so both are defined here.
            if step%10==0:
                acc, test_loss = test(model, test_loader, print_opt=False)
            test_time = time.time()

            log_list[:10] = [step, train_time - start_time, test_time - train_time, lr,
                        e_train_loss, e_train_acc,0,0,test_loss, acc]

            # Keep 11 fields: the 10 above plus a trailing 0 placeholder.
            log_list = log_list[:10+1]

            log_input = (*log_list, )
            logger.info(('%d'+'\t%.4f'*(len(log_input)-1))%(log_input))