import os
import sys
import warnings
import pickle
import torch
import torch.nn as nn
import math
from data.data_loader import data_loader, get_deri_loader
from model.linear import Linear
import model.vgg as vgg
from utils.essen_plot import plot_loss, plot_model_output, plot_eig_vs_var, plot_eig_vs_mean, plot_loss_landscape
from utils.save_path import CheckpointSaver, create_save_dir, save_code_and_config
from utils.derivatives_of_parameters import derivatives, get_hessian_eig
from utils.get_weight_matrix_and_pca import Get_weight_matrix_and_pca, get_loss_for_weight_matrix
import shutil
import platform
import numpy as np
from config.config import parse_args
import copy
import random
warnings.simplefilter("ignore")  # ignore all warnings from here on


# Expose GPUs 0-7 to this process. CUDA reads this variable when it is
# initialised, so it has to be set before the first CUDA call.
os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(map(str, range(8)))


def one_hot(x, class_count):
    """Return float one-hot rows for the integer label(s) ``x``.

    Args:
        x: integer label, list of labels, or integer tensor of labels.
        class_count: number of classes (width of each one-hot row).

    Returns:
        A float tensor of shape ``(*x.shape, class_count)``, on the same
        device as ``x``.
    """
    x = torch.as_tensor(x)
    # Build the identity on x's device: indexing a CPU tensor with CUDA
    # indices raises, and callers pass labels that already live on the GPU.
    return torch.eye(class_count, device=x.device)[x, :]


def seed_torch(seed=1029):
    """Seed every RNG the experiment uses and force deterministic cuDNN.

    Seeds Python's ``random``, NumPy, and torch (CPU, current CUDA device and
    all CUDA devices), exports ``PYTHONHASHSEED`` (this only affects child
    processes started afterwards), and turns off cuDNN autotuning in favour
    of deterministic kernels so runs are reproducible.
    """
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    os.environ['PYTHONHASHSEED'] = str(seed)
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,  # multi-GPU
    )
    for seeder in seeders:
        seeder(seed)


def main():
    """Entry point: parse the config, build model/data/optimizer, then train.

    Side effects: seeds all RNGs, creates the run directory ``args.path``,
    copies code/config into it (non-Windows only), and sets ``args.seed``,
    ``args.device``, ``args.model_name`` and ``args.path``. It also sets the
    raw input/target tensors on ``args`` when ``args.data == '1Dpro'``.

    Raises:
        ValueError: if ``args.network_type`` or ``args.data`` is not supported.
    """
    args, _ = parse_args()
    argsy = {'test_outputs': []}
    print(args.hidden_layers_width)

    # Fixed seed so runs are reproducible; substitute
    # np.random.randint(1000000) here for a fresh random seed per run.
    args.seed = 0
    seed_torch(args.seed)

    # Running results; train() pickles this dict periodically.
    R = {
        'loss_train': [],
        'exploration_para': [],
        'ini_trace': [],
        'iden_trace': [],
    }
    args.device = torch.device("cuda:%s" % (
        args.device_rank) if torch.cuda.is_available() else "cpu")
    args.model_name = '_'.join(map(str, args.hidden_layers_width))
    args.path = create_save_dir(args.ini_output_dir, args.model_name, args.t)
    print(args.path)

    if platform.system() != 'Windows':
        # Snapshot the code/config next to the results.
        for i in args.save_dir:
            print(i)
            save_code_and_config(i, args.path)
        shutil.copy(__file__, '%scode/%s' %
                    (args.path, os.path.basename(__file__)))

    if args.network_type == 'linear':
        model = Linear(args.t, args.hidden_layers_width, args.input_dim,
                       args.output_dim, nn.ReLU(), args.initialization, args.dropout, args.dropout_pro, args.bias).to(args.device)
    elif args.network_type == 'vgg':
        model = vgg.VGG9(args.dropout, args.dropout_pro).to(args.device)
    else:
        raise ValueError('unsupported network_type: %r' % (args.network_type,))
    print(model)
    if args.change_dropout_pro == True:
        # Start with this dropout layer disabled; train() sets it back to
        # args.dropout_pro at epoch 3000.
        model.features[4].p = 0.0

    # Print the parameter names (state_dict keys).
    for param_tensor in model.state_dict():
        print(param_tensor)

    # Second loader with a different batch size, swapped in by train() at
    # epoch 80. Bound up front so every data branch defines it.
    train_batch_loader = None
    if args.data == '1Dpro':
        train_loader, test_loader, test_inputs, train_inputs, test_targets, train_targets = data_loader(
            training_batch_size=args.training_batch_size, test_batch_size=args.test_batch_size, training_size=args.training_size,  data=args.data,  args=args)
        args.train_inputs, args.test_inputs = train_inputs, test_inputs
        args.train_targets, args.test_targets = train_targets, test_targets
    elif args.data in ('MNIST', 'cifar10'):
        train_loader, test_loader = data_loader(
            training_batch_size=args.training_batch_size, test_batch_size=args.test_batch_size, training_size=args.training_size,  data=args.data,  args=args)
        if args.change_loader == True:
            train_batch_loader, _ = data_loader(
                training_batch_size=args.changed_training_batch_size, test_batch_size=args.test_batch_size, training_size=args.training_size,  data=args.data,  args=args)
    else:
        raise ValueError('unsupported data: %r' % (args.data,))

    loss_fn = nn.CrossEntropyLoss()
    if args.use_nesterov:
        optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=0.9, weight_decay=0.0, nesterov=True)
    else:
        optimizer = torch.optim.SGD(model.parameters(), lr=args.lr)

    train(model, optimizer, loss_fn, R, train_loader, test_loader, args, argsy, train_batch_loader)


def train_one_step(epoch, model, optimizer, loss_fn, R, train_loader, args, save_para):
    """Run one training epoch over ``train_loader``.

    Args:
        epoch: current epoch index (not used here; kept for the call signature).
        model: network being trained. ``model.features[2]`` is read when
            ``save_para`` is set.
        optimizer: optimizer stepping ``model``'s parameters.
        loss_fn: criterion called as ``loss_fn(outputs, targets)``.
        R: results dict. This epoch's mean loss is appended to
            ``R['loss_train']``. If ``save_para`` is set, a flattened copy of
            ``features[2]``'s weight (and bias) is appended to
            ``R['exploration_para']`` before every batch's update.
        train_loader: iterable of ``(data, target)`` batches.
        args: namespace providing ``device``, ``softmax``, ``one_hot``,
            ``output_dim``, ``add_tru_on_weight``, ``add_tru_on_grad`` and
            ``turblence`` (Gaussian noise scale for the ``features.2`` params).
        save_para: whether to record ``features[2]`` parameters per batch.

    Returns:
        ``(mean_loss, accuracy_percent)`` over all samples seen this epoch.

    Raises:
        ValueError: if ``train_loader`` yields no samples.
    """
    device = args.device
    running_loss = 0.0
    correct = 0
    total = 0
    for data, target in train_loader:
        model.train()
        data, target = data.to(device), target.to(device)
        batch_size = data.size(0)

        if save_para == True:
            model.eval()
            layer = model.features[2]
            weight = layer.weight.detach().cpu().numpy().reshape((-1, 1)).squeeze()
            if layer.bias is not None:
                bias = layer.bias.detach().cpu().numpy().reshape((-1, 1)).squeeze()
                R['exploration_para'].append(np.concatenate((weight, bias)))
            else:
                R['exploration_para'].append(weight)
            model.train()
        optimizer.zero_grad()

        # Optionally evaluate the gradient at noise-perturbed weights; the
        # clean weights are restored before the optimizer step.
        saved = {}
        if args.add_tru_on_weight == True:
            for name, p in model.named_parameters():
                if 'features.2' in name:
                    saved[name] = p.data.clone()
                    p.data = p.data + \
                        torch.randn_like(p.data, device=args.device)*args.turblence

        outputs = model(data)
        if args.softmax:
            outputs = torch.nn.functional.softmax(outputs, dim=1)
        if args.one_hot:
            # Class-probability targets must be float (CrossEntropyLoss and
            # MSELoss both reject integer one-hot targets). F.one_hot also
            # keeps the targets on the labels' device.
            target_onehot = torch.nn.functional.one_hot(
                target.long(), args.output_dim).float()
            loss = loss_fn(outputs, target_onehot)
        else:
            loss = loss_fn(outputs, target.long())
        loss.backward()

        if args.add_tru_on_weight == True:
            for name, p in model.named_parameters():
                if name in saved:
                    p.data = saved[name]

        if args.add_tru_on_grad == True:
            for name, p in model.named_parameters():
                if 'features.2' in name:
                    p.grad.data = p.grad.data + \
                        torch.randn_like(p.data, device=args.device)*args.turblence
        optimizer.step()

        running_loss += loss.item()*batch_size
        _, predicted = torch.max(outputs.detach(), dim=1)
        total += batch_size
        correct += (predicted == target).sum().item()

    if total == 0:
        raise ValueError('train_loader yielded no samples')
    mean_loss = running_loss/total
    R['loss_train'].append(mean_loss)
    return mean_loss, 100*correct/total


def test(model, test_loader, loss_fn, args, argsy):
    """Evaluate ``model`` on ``test_loader`` in eval mode without gradients.

    Args:
        model: network to evaluate.
        test_loader: iterable of ``(images, labels)`` batches.
        loss_fn: criterion called as ``loss_fn(outputs, targets)``.
        args: namespace providing ``device``, ``plot_output``, ``softmax``,
            ``one_hot`` and ``output_dim``.
        argsy: dict whose ``'test_outputs'`` list receives the raw (pre-softmax)
            outputs of every batch as numpy arrays when ``args.plot_output``.

    Returns:
        ``(mean_loss, accuracy_percent)`` over the whole loader.

    Raises:
        ValueError: if ``test_loader`` yields no samples.
    """
    model.eval()
    device = args.device
    running_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            batch_size = labels.size(0)
            outputs = model(images)
            if args.plot_output:
                argsy['test_outputs'].append(outputs.detach().cpu().numpy())
            if args.softmax:
                outputs = torch.nn.functional.softmax(outputs, dim=1)
            if args.one_hot:
                # Probability targets must be float; F.one_hot stays on the
                # labels' device.
                labels_onehot = torch.nn.functional.one_hot(
                    labels.long(), args.output_dim).float()
                loss = loss_fn(outputs, labels_onehot)
            else:
                loss = loss_fn(outputs, labels.long())
            running_loss += loss.item()*batch_size

            _, predicted = torch.max(outputs, dim=1)
            total += batch_size
            correct += (predicted == labels).sum().item()
    if total == 0:
        raise ValueError('test_loader yielded no samples')
    return running_loss/total, 100*correct/total


def _dump_results(R, args):
    """Pickle the results dict ``R`` to ``<args.path>/objs.pkl`` (protocol 4)."""
    with open('%s/objs.pkl' % (args.path), 'wb') as f:
        pickle.dump(R, f, protocol=4)


def _features2_vector(model):
    """Return ``model.features[2]``'s weight, then its bias if present, flattened into one numpy vector."""
    layer = model.features[2]
    weight = layer.weight.detach().cpu().numpy().reshape((-1, 1)).squeeze()
    if layer.bias is None:
        return weight
    bias = layer.bias.detach().cpu().numpy().reshape((-1, 1)).squeeze()
    return np.concatenate((weight, bias))


def _record_noise_traces(model, loss_fn, train_loader, R, args):
    """Append tr(H C) to ``R['ini_trace']`` and tr(C) tr(H) / d to ``R['iden_trace']``.

    H is the full Hessian from ``calcu_full_hessain`` (model as-is). C is the
    sample covariance of ``args.Sampling_times`` full gradients taken in train
    mode with ``model.features[4].p = args.dropout_pro``. Dropout is then reset
    to 0.0, and both trace series are written to text files in ``args.path``.
    """
    hessian_calc = get_hessian_eig(
        model, full_dataloader=train_loader, loss_fn=loss_fn, dropout_size=0, dropout_times=args.Sampling_times, args=args)
    hessian = hessian_calc.calcu_full_hessain().detach().cpu().numpy()

    model.features[4].p = args.dropout_pro
    grads = []
    for _ in range(args.Sampling_times):
        model.train()
        grad_calc = get_hessian_eig(
            model, full_dataloader=train_loader, loss_fn=loss_fn, dropout_size=0, dropout_times=args.Sampling_times, args=args)
        grads.append(grad_calc.calcu_full_grad().detach().cpu().numpy())

    grads = np.stack(grads)
    grads = grads - np.mean(grads, axis=0)  # centre each gradient sample
    covariance = np.matmul(grads.T, grads)/(args.Sampling_times-1)

    R['ini_trace'].append(np.trace(np.matmul(hessian, covariance)))
    R['iden_trace'].append(
        np.trace(covariance) * np.trace(hessian)/hessian.shape[0])
    model.features[4].p = 0.0
    np.savetxt('%sini_trace.txt' % (args.path), R['ini_trace'])
    np.savetxt('%siden_trace.txt' % (args.path), R['iden_trace'])


def _collect_single_step_samples(model, optimizer, loss_fn, para_dict, R, args):
    """Record one-step SGD updates taken from the snapshot ``para_dict``.

    For each of ``args.Sampling_times`` draws, iterate a loader from
    ``get_deri_loader``. The loader is the per-draw one when
    ``get_deri_loader`` returns any, otherwise the full loader. For every
    batch: reset the model to ``para_dict`` and take one optimizer step. The
    optional weight/grad noise is applied to the ``features.2`` params. The
    resulting ``features[2]`` parameters are appended to
    ``R['exploration_para']``.

    Returns:
        The ``full_dataloader`` returned by ``get_deri_loader``.
    """
    train_loader_lst, full_dataloader = get_deri_loader(args)
    for ind in range(args.Sampling_times):
        print(ind)
        dataloader_sample = full_dataloader if train_loader_lst == [] else train_loader_lst[ind]
        for data, target in dataloader_sample:
            model.load_state_dict(para_dict)
            model.train()
            saved = {}

            data, target = data.to(args.device), target.to(args.device)
            optimizer.zero_grad()

            if args.add_tru_on_weight == True:
                for name, p in model.named_parameters():
                    if 'features.2' in name:
                        saved[name] = p.data.clone()
                        p.data = p.data + \
                            torch.randn_like(p.data, device=args.device)*args.turblence

            outputs = model(data)
            # Local name kept distinct so the caller's epoch loss is not clobbered.
            sample_loss = loss_fn(outputs, target)
            sample_loss.backward()
            if args.add_tru_on_grad == True:
                for name, p in model.named_parameters():
                    if 'features.2' in name:
                        p.grad.data = p.grad.data + \
                            torch.randn_like(p.grad.data, device=args.device)*args.turblence

            if args.add_tru_on_weight == True:
                for name, p in model.named_parameters():
                    if name in saved:
                        p.data = saved[name]

            optimizer.step()

            model.eval()
            R['exploration_para'].append(_features2_vector(model))
            model.train()
    return full_dataloader


def _run_pca_analysis(R, args, loss_fn, full_dataloader, para_dict):
    """Run ``Get_weight_matrix_and_pca`` on the recorded parameters and save its outputs.

    Writes ``sigma.txt`` (real part of the first value returned by
    ``get_pca_matrix``). Then, for each index into ``sample_index``, it calls
    ``get_theta`` and rewrites the accumulated ``theta_posi.txt`` /
    ``theta_nage.txt`` in ``args.path``, so partial results survive an
    interrupted run.
    """
    sample_index = [2, 4, 6, 8, 10, 30, 50, 70,
                    90, 100, 300, 400, 600, 700, 900, 1000]
    get_weight_matrix_and_pca = Get_weight_matrix_and_pca(
        R, args, loss_fn, full_dataloader, para_dict, sample_index)
    sigma2, _ = get_weight_matrix_and_pca.get_pca_matrix()
    np.savetxt('%ssigma.txt' % (args.path), np.real(sigma2))
    loss_ini = get_weight_matrix_and_pca.getloss(0, 0)
    print(loss_ini)
    theta_vector_posi = []
    theta_vector_nage = []
    for i in range(len(sample_index)):
        print('i=%s' % (i))
        theta_posi, theta_nega = get_weight_matrix_and_pca.get_theta(
            i, loss_ini, ini_a=-3, ini_b=3)
        theta_vector_posi.append(theta_posi)
        theta_vector_nage.append(theta_nega)
        np.savetxt('%stheta_posi.txt' % (args.path), theta_vector_posi)
        np.savetxt('%stheta_nage.txt' % (args.path), theta_vector_nage)


def _run_hessian_analysis(model, para_dict, full_dataloader, loss_fn, R, args, epoch, relative_to_reference):
    """Compute Hessian eigenvalues vs. noise variance at ``para_dict``; save and plot them.

    The model is reloaded to ``para_dict`` first. If ``relative_to_reference``
    is set, the recorded parameters are passed as displacements from that
    point; otherwise they are passed raw. ``var.txt`` / ``eig.txt`` are written
    into ``args.path``.
    """
    print('calculating hessian')
    model.load_state_dict(para_dict)
    hessian_calc = get_hessian_eig(
        model, full_dataloader=full_dataloader, loss_fn=loss_fn, dropout_size=0, dropout_times=args.Sampling_times, args=args)
    if relative_to_reference:
        first_derivative_lst = R['exploration_para'] - _features2_vector(model)
    else:
        # NOTE(review): unlike the no_training path, the exploration path
        # passes the raw recorded parameters (no reference subtracted) —
        # confirm this asymmetry is intended.
        first_derivative_lst = R['exploration_para']
    var, w = hessian_calc.get_eigvalue_and_var(
        first_derivative_lst=first_derivative_lst)
    print(var.shape)
    np.savetxt('%svar.txt' % (args.path), var)
    np.savetxt('%seig.txt' % (args.path), w)
    plot_eig_vs_var(args.path, var, w, epoch)


def train(model, optimizer, loss_fn, R, train_loader, test_loader, args, argsy, train_batch_loader):
    """Main training loop plus the optional post-convergence analyses.

    Every epoch:
      * plot and pickle ``R`` every ``args.plot_epoch`` epochs;
      * checkpoint every ``args.save_epoch`` epochs;
      * run one :func:`train_one_step` pass and evaluate with :func:`test`;
      * append to ``train.log`` / ``valid.log``;
      * stop once the training loss drops below ``args.stop_loss``.

    When ``args.cal_pca`` is set and the epoch loss first drops below 2e-4, the
    parameters are snapshotted into ``para_dict``. Then one of two paths runs:
      * ``args.no_training``: one-step SGD samples from the snapshot are
        recorded, then the ``'pca'`` / ``'hessian'`` analysis runs;
      * otherwise: an exploration phase starts. During it,
        :func:`train_one_step` records ``features[2]`` every batch for
        ``int(Sampling_times * training_batch_size / training_size)`` epochs.
        The same analysis then runs.
    With ``args.method == 'zhuzhanxing'``, Hessian/noise-covariance traces are
    recorded every 10 epochs, and training stops when the snapshot is taken.
    """
    log_train_file = os.path.join(args.path, 'train.log')
    log_valid_file = os.path.join(args.path, 'valid.log')
    exploration_step = -10000000  # sentinel: no exploration phase running
    full_dataloader = None  # built by get_deri_loader when an analysis needs it
    para_dict = None  # parameter snapshot taken when the loss falls below 2e-4

    # Truncate both logs and write their CSV headers.
    with open(log_train_file, 'w') as log_tf, open(log_valid_file, 'w') as log_vf:
        log_tf.write('epoch,loss,accuracy\n')
        log_vf.write('epoch,loss,accuracy\n')
    save_para = False
    print('creating data loader')

    for epoch in range(args.training_steps):
        if epoch % (args.plot_epoch) == 0:
            plot_loss(args.path, R, x_log=True)
            plot_loss(args.path, R, x_log=False)
            if args.plot_output:
                plot_model_output(args.path, args, argsy, epoch)
            _dump_results(R, args)
            print(args.path)
        if epoch % (args.save_epoch) == 0:
            # Epoch 0 always lands here, so `saver` exists when reused below.
            saver = CheckpointSaver(
                model=model, optimizer=optimizer, args=args, path=args.path, extension='%s.pth.tar' % (epoch))
            saver.save_checkpoint(epoch)

        model.train()
        loss, acc = train_one_step(
            epoch, model, optimizer, loss_fn, R, train_loader, args, save_para)
        loss_val, acc_val = test(model, test_loader, loss_fn, args, argsy)
        if args.method == 'zhuzhanxing' and epoch % 10 == 0:
            _record_noise_traces(model, loss_fn, train_loader, R, args)
        print("[%d] loss: %.6f acc: %.2f valloss: %.6f valacc: %.2f " %
              (epoch + 1, loss, acc, loss_val, acc_val))

        if epoch == 3000 and args.change_dropout_pro == True:
            model.features[4].p = args.dropout_pro

        if epoch == 80 and args.change_loader == True:
            train_loader = train_batch_loader
            args.training_batch_size = args.changed_training_batch_size

        if loss < 2e-4 and save_para == False and args.cal_pca == True:
            model.eval()
            para_dict = copy.deepcopy(model.state_dict())
            saver.save_checkpoint(epoch)
            if args.method == 'zhuzhanxing':
                break

            if args.no_training:
                full_dataloader = _collect_single_step_samples(
                    model, optimizer, loss_fn, para_dict, R, args)
                _dump_results(R, args)
                print('finish saving objs')
                print(args.method)
                if args.method == 'pca':
                    print('calculating pca')
                    _run_pca_analysis(
                        R, args, loss_fn, full_dataloader, para_dict)
                    break
                if args.method == 'hessian':
                    _run_hessian_analysis(
                        model, para_dict, full_dataloader, loss_fn, R, args, epoch,
                        relative_to_reference=True)
                    break
            else:
                # Start the exploration phase: train_one_step records params.
                exploration_step = epoch
                if args.change_lr:
                    for param_group in optimizer.param_groups:
                        param_group["lr"] = args.changed_lr
                save_para = True

        if epoch == (exploration_step+int(args.Sampling_times*args.training_batch_size/args.training_size)):
            model.eval()
            save_para = False
            _dump_results(R, args)
            if args.method in ('pca', 'hessian') and full_dataloader is None:
                # The exploration path never built the full loader; without
                # this the analyses below would hit an unbound name.
                _, full_dataloader = get_deri_loader(args)
            if args.method == 'pca':
                _run_pca_analysis(R, args, loss_fn, full_dataloader, para_dict)
                break
            if args.method == 'hessian':
                _run_hessian_analysis(
                    model, para_dict, full_dataloader, loss_fn, R, args, epoch,
                    relative_to_reference=False)
                break

        with open(log_train_file, 'a') as log_tf, open(log_valid_file, 'a') as log_vf:
            log_tf.write('{epoch},{loss: 8.5f},{accu:3.3f}\n'.format(
                epoch=epoch, loss=loss, accu=acc))
            log_vf.write('{epoch},{loss: 8.5f},{accu:3.3f}\n'.format(
                epoch=epoch, loss=loss_val, accu=acc_val))

        if R['loss_train'][-1] < args.stop_loss:
            plot_loss(args.path, R, x_log=True)
            plot_loss(args.path, R, x_log=False)
            break
        if args.network_type == 'vgg' and epoch in (150, 225, 275):
            # Step-decay schedule for VGG: lr *= 0.1 at fixed epochs.
            for param_group in optimizer.param_groups:
                param_group['lr'] *= 0.1



if __name__ == '__main__':
    # Run training only when executed as a script, not when imported.
    main()
