# from utils.essen_plot import plot_sigma_F
# path='/home/xxx/data/sgd/test24/50_50/0.5/'
# plot_sigma_F(path)

# import pickle
# import numpy as np
# from sklearn.decomposition import PCA
# path='/home/xxx/data/sgd/test10/50_50/0.5/129252/objs.pkl'

# with open(path, 'rb') as f1:
#     data = pickle.load(f1)

# X=np.array(data['exploration_para'])
# centered_matrix = X - X.mean(axis=0)
# cov = np.dot(centered_matrix.T, centered_matrix) 
# print(cov.shape)
# eigvals, eigvecs = np.linalg.eig(cov) 
# print(eigvals)
# np.savetxt('/home/xxx/data/sgd/test16/50_50/0.5/129697/eig.txt',np.real(eigvals))

# pca = PCA(n_components=2550)
# pca.fit(X)
# a=pca.singular_values_
# index=[19,21]
# print(a[[19,21]])
# from utils.essen_plot import plot_cov_hessian
# path='/home/xxx/data/sgd/test31/50_50/0.5/'
# plot_cov_hessian(path)
import torch
import copy
from utils.derivatives_of_parameters import one_hot
from utils.loss_landscape import get_loss_lst_for_diff_alpha2
import numpy as np
import torch.nn as nn
from config.config import parse_args
from data.data_loader import data_loader, get_deri_loader
from model.vgg import VGG, VGG9
from utils.essen_plot import plot_several_loss_landscape
import random
import os
# os.environ['CUDA_VISIBLE_DEVICES'] = '2'
def seed_torch(seed=1029):
    """Seed the Python, NumPy and PyTorch RNGs and pin the cuDNN flags.

    Args:
        seed: Value passed to every seeding call below; its string form is
            also written to the ``PYTHONHASHSEED`` environment variable.

    Returns:
        None. The function is called for its side effects only.
    """
    random.seed(seed)

    # Intended to disable hash randomization so experiments are reproducible.
    # NOTE(review): the interpreter reads PYTHONHASHSEED at startup, so this
    # assignment presumably only influences child processes -- confirm.
    os.environ['PYTHONHASHSEED'] = str(seed)

    remaining_seeders = (
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,  # covers the multi-GPU case
    )
    for apply_seed in remaining_seeders:
        apply_seed(seed)

    # Turn off cuDNN benchmarking and request deterministic cuDNN behaviour.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

if __name__ == "__main__":
    # Script entry point: build a grid of alpha values, restore two VGG9
    # checkpoints (plus a third, un-restored VGG9), evaluate the loss over the
    # grid via get_loss_lst_for_diff_alpha2 and plot the result.

    # Grid of alpha values passed to the loss evaluation and to the plot.
    # Sorting is a no-op for a bare linspace; it keeps the grid ordered when
    # extra points (e.g. 0.0 and 1.0, or a finer linspace around a region of
    # interest) are np.append-ed to alpha_grid first.
    alpha_grid = np.linspace(-1.5, 1.5, 70, endpoint=True)
    alpha_lst = np.sort(alpha_grid)
    print(alpha_lst)

    loss_fn = nn.CrossEntropyLoss()

    # Parse the run configuration and pick the device: the configured CUDA
    # rank when CUDA is available, otherwise the CPU.
    args, _ = parse_args()
    if torch.cuda.is_available():
        device_name = "cuda:%s" % (args.device_rank)
    else:
        device_name = "cpu"
    args.device = torch.device(device_name)

    # The seed is fixed here, overriding whatever parse_args() produced.
    args.seed = 1
    seed_torch(args.seed)

    train_loader, test_loader = data_loader(
        training_batch_size=args.training_batch_size,
        test_batch_size=args.test_batch_size,
        training_size=args.training_size,
        data=args.data,
        args=args,
    )

    first_ckpt_path = '/home/xxx/data/sgd/test34/0.pth.tar'
    second_ckpt_path = '/home/xxx/data/sgd/test34/2.pth.tar'

    # Keep the construction order model1 -> model2 -> model3: it happens after
    # seeding, and model3 never receives a checkpoint, so it keeps whatever
    # weights its constructor produced (presumably RNG-dependent -- confirm).
    model1 = VGG9(True, 0.2).to(args.device)
    first_ckpt = torch.load(first_ckpt_path, map_location=args.device)
    model1.load_state_dict(first_ckpt['state_dict'][0])

    model2 = VGG9(True, 0.0).to(args.device)
    second_ckpt = torch.load(second_ckpt_path, map_location=args.device)
    model2.load_state_dict(second_ckpt['state_dict'][0])

    model3 = VGG9(True, 0.0).to(args.device)

    print('model finished')

    # The models are handed over on the CPU; args.device is passed separately.
    cpu_models = (model1.cpu(), model2.cpu(), model3.cpu())
    loss_lst_all = get_loss_lst_for_diff_alpha2(
        alpha_lst, train_loader, args.device, args, loss_fn, *cpu_models)

    output_dir = '/home/xxx/data/sgd/test38/50_50/0.5/'
    plot_several_loss_landscape(output_dir, alpha_lst, loss_lst_all)


