import torch
import copy
from utils.derivatives_of_parameters import one_hot
from utils.loss_landscape import get_loss_lst_for_diff_alpha
import numpy as np
import torch.nn as nn
from config.config import parse_args
from data.data_loader import data_loader, get_deri_loader
from model.vgg import VGG, VGG9
from utils.essen_plot import plot_several_loss_landscape
import random
import os

def seed_torch(seed=1029):
    """Seed every random number generator used by the experiments.

    Seeds Python's ``random`` module, NumPy and PyTorch (CPU, current CUDA
    device and all CUDA devices). It also disables cuDNN benchmark-based
    algorithm selection and forces deterministic cuDNN kernels.

    Note: assigning ``PYTHONHASHSEED`` here only affects Python child
    processes started afterwards. The hash seed of the already-running
    interpreter is fixed at startup.

    Args:
        seed: integer seed applied to all generators (default 1029).
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)

    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # also covers multi-GPU setups

    # cuDNN: no autotuned algorithm choice, deterministic kernels only.
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True

if __name__ == "__main__":
    # Interpolation coefficients at which the loss is evaluated. The grid has
    # four parts:
    # - a coarse grid over [-1.5, 1.5];
    # - the points 0.0 and 1.0 explicitly;
    # - two increasingly fine grids around alpha ~= 1.0909 for a zoomed-in view.
    alpha_lst = np.sort(np.concatenate([
        np.linspace(-1.5, 1.5, 70, endpoint=True),
        np.array([0.0, 1.0]),
        np.linspace(1.09, 1.091, 30, endpoint=True),
        np.linspace(1.0908571428571427, 1.090877551020408, 20, endpoint=True),
    ]))
    print(alpha_lst)

    loss_fn = nn.CrossEntropyLoss()
    args, _ = parse_args()
    args.device = torch.device("cuda:%s" % (
        args.device_rank) if torch.cuda.is_available() else "cpu")
    # NOTE(review): the seed is hard-coded and overrides any CLI value.
    args.seed = 1
    seed_torch(args.seed)
    train_loader, test_loader = data_loader(
        training_batch_size=args.training_batch_size, test_batch_size=args.test_batch_size,
        training_size=args.training_size, data=args.data, args=args)

    path128 = '/home/xxx/data/sgd/test38/50_50/0.5/20210923124441528322/model/tmp290.pth.tar'
    path2048 = '/home/xxx/data/sgd/test38/50_50/0.5/20210923133959682181/model/tmp290.pth.tar'

    def _load_first_state_dict(model, ckpt_path):
        """Load ``checkpoint['state_dict'][0]`` from *ckpt_path* into *model*.

        Tensors are mapped to CPU. The models are handed to the landscape
        computation on CPU anyway, so building/loading them on ``args.device``
        first would only add a GPU round trip and GPU memory use.
        """
        # The checkpoint's 'state_dict' entry appears to hold a list of state
        # dicts; only the first one is used.
        checkpoint = torch.load(ckpt_path, map_location="cpu")
        model.load_state_dict(checkpoint['state_dict'][0])
        return model

    # Construction order is kept: each VGG9 construction may consume RNG state.
    # NOTE(review): model1 is built with 0.2 as the second VGG9 argument while
    # model2/model3 use 0.0 -- confirm this mismatch is intended.
    model1 = _load_first_state_dict(VGG9(True, 0.2), path128)
    model2 = _load_first_state_dict(VGG9(True, 0.0), path2048)
    # NOTE(review): model3 is not loaded from any checkpoint, so it keeps the
    # weights produced by VGG9's constructor (seeded above).
    model3 = VGG9(True, 0.0)

    print('model finished')

    loss_lst_all = get_loss_lst_for_diff_alpha(
        alpha_lst, train_loader, args.device, args, loss_fn,
        model1.cpu(), model2.cpu(), model3.cpu())
    path = '/home/xxx/data/sgd/test38/50_50/0.5/'

    plot_several_loss_landscape(path, alpha_lst, loss_lst_all)