#*
# @file Different utility functions
# Copyright (c) Zhewei Yao, Amir Gholami
# All rights reserved.
# This file is part of PyHessian library.
#
# PyHessian is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# PyHessian is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with PyHessian.  If not, see <http://www.gnu.org/licenses/>.
#*

from __future__ import print_function

import json
import os
import sys
import time
import re
import numpy as np
import argparse

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import timm

from torchvision import datasets, transforms
from torch.autograd import Variable

from utils import *
from density_plot import get_esd_plot
from models.resnet import resnet
from pyhessian import hessian

# Settings: command-line options for the loss-landscape script.
parser = argparse.ArgumentParser(description='PyTorch Example')
parser.add_argument(
    '--mini-hessian-batch-size',
    type=int,
    default=200,
    help='input batch size for mini-hessian batch (default: 200)')
parser.add_argument('--hessian-batch-size',
                    type=int,
                    default=200,
                    help='input batch size for hessian (default: 200)')
parser.add_argument('--seed',
                    type=int,
                    default=1,
                    help='random seed (default: 1)')
# NOTE: the next three switches use ``store_false``: the option is True by
# default and *passing the flag turns the feature off*. The help strings say
# so explicitly because the flag names alone suggest the opposite.
parser.add_argument('--batch-norm',
                    action='store_false',
                    help='pass this flag to DISABLE batch norm '
                         '(enabled by default)')
parser.add_argument('--residual',
                    action='store_false',
                    help='pass this flag to DISABLE residual connections '
                         '(enabled by default)')

parser.add_argument('--cuda',
                    action='store_false',
                    help='pass this flag to DISABLE gpu usage '
                         '(enabled by default)')
parser.add_argument('--resume',
                    type=str,
                    default='',
                    help='get the checkpoint')
# distributed training parameters
parser.add_argument('--world_size', default=1, type=int,
                    help='number of distributed processes')
parser.add_argument('--local-rank', default=-1, type=int,
                    help='local rank of this process (default: -1)')

args = parser.parse_args()
# Set the random seeds to reproduce the work. NumPy is seeded as well because
# getData() draws the ImageNet validation subset with np.random.choice.
torch.manual_seed(args.seed)
np.random.seed(args.seed)
if args.cuda:
    torch.cuda.manual_seed(args.seed)

# Echo the effective configuration.
for arg_name, arg_value in vars(args).items():
    print(arg_name, arg_value)





### load models you need


import numpy as np
from pyhessian.utils import get_params_grad
from tqdm import tqdm  # 引入进度条库
import matplotlib.pyplot as plt

from tqdm import tqdm
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

def _landscape_title(plot_path):
    """Return the figure title for the model tag contained in ``plot_path``."""
    # Tags are tested in this order; the first one found in the path wins.
    known_titles = {
        'model_distilledclip': "Distilled CLIP Loss Landscape Last Epoch",
        'model_clip': "CLIP Loss Landscape Last Epoch",
        'model_distilledssl': "Distilled SSL Loss Landscape Last Epoch",
        'model_ssl': "SSL Loss Landscape Last Epoch",
    }
    path_text = str(plot_path)
    for tag, title in known_titles.items():
        if tag in path_text:
            return title
    return "Unknown Model"


def _interpolate_landscape(loss_landscape, alpha_range, factor=2):
    """Cubic-interpolate a square loss grid onto a ``factor``-times finer grid.

    Returns ``(fine_a1, fine_a2, fine_losses)``; all three arrays share the
    layout of ``loss_landscape`` (axis 0 <-> alpha1, axis 1 <-> alpha2).
    """
    from scipy.interpolate import griddata

    steps = loss_landscape.shape[0]
    coarse_axis = np.linspace(-alpha_range, alpha_range, steps)
    # indexing='ij' makes coarse_a1[i, j] == coarse_axis[i], matching how
    # loss_landscape[i, j] was filled. The default 'xy' indexing would swap the
    # two directions relative to the axis labels of the plot.
    coarse_a1, coarse_a2 = np.meshgrid(coarse_axis, coarse_axis, indexing='ij')
    coarse_points = np.column_stack([coarse_a1.ravel(), coarse_a2.ravel()])

    fine_steps = steps * factor
    fine_axis = np.linspace(-alpha_range, alpha_range, fine_steps)
    fine_a1, fine_a2 = np.meshgrid(fine_axis, fine_axis, indexing='ij')
    fine_points = np.column_stack([fine_a1.ravel(), fine_a2.ravel()])

    fine_values = griddata(coarse_points, loss_landscape.ravel(), fine_points,
                           method='cubic')
    return fine_a1, fine_a2, fine_values.reshape((fine_steps, fine_steps))


def visualize_loss_landscape(model, dataloader, eigenvectors, alpha_range, steps=10, criterion=nn.CrossEntropyLoss(), device='cuda:0', save_path=None, plot_path=None):
    """Evaluate and optionally plot the loss on a 2-D slice of parameter space.

    The parameters returned by ``get_params_grad(model)`` are shifted by
    ``alpha1 * eigenvectors[0] + alpha2 * eigenvectors[1]`` for every pair
    ``(alpha1, alpha2)`` on a ``steps x steps`` grid spanning
    ``[-alpha_range, alpha_range]``. For each grid point the mean per-batch
    loss over ``dataloader`` is recorded. The grid is then cubic-interpolated
    to ``2 * steps`` points per axis.

    The model is switched to eval mode for the sweep so the forward passes are
    deterministic and do not update normalisation statistics. The original
    parameters and the previous train/eval mode are restored afterwards, even
    if an error occurs.

    Args:
        model: ``torch.nn.Module`` to probe; it is moved to ``device``.
        dataloader: sized iterable yielding ``(inputs, labels)`` batches.
        eigenvectors: sequence whose first two entries are the perturbation
            directions, each a list of tensors aligned one-to-one with the
            parameters returned by ``get_params_grad(model)``.
        alpha_range: half-width of the perturbation interval along both axes.
        steps: number of grid points per axis (at least 2).
        criterion: loss module called as ``criterion(outputs, labels)``; it is
            moved to ``device``.
        device: device used for the model, criterion, directions and batches.
        save_path: accepted for backward compatibility; currently unused.
        plot_path: if truthy, a 3-D surface plot is written to this path.

    Returns:
        ``numpy.ndarray`` of shape ``(2 * steps, 2 * steps)`` holding the
        interpolated losses; entry ``[r, c]`` corresponds to the ``r``-th
        alpha1 value and the ``c``-th alpha2 value.

    Raises:
        ValueError: if ``steps < 2``, the dataloader is empty, or the
            directions do not match the number of model parameters.
    """
    if steps < 2:
        raise ValueError(f"steps must be at least 2 to span a 2-D grid, got {steps}")
    num_batches = len(dataloader)
    if num_batches == 0:
        raise ValueError("dataloader yields no batches; cannot average the loss")

    model.to(device)
    criterion.to(device)

    # Top-1 and top-2 directions, moved to the evaluation device.
    top1_vector = [v.to(device) for v in eigenvectors[0]]
    top2_vector = [v.to(device) for v in eigenvectors[1]]

    # Parameters to perturb plus detached copies used to undo the perturbation.
    params, _ = get_params_grad(model)
    if not (len(params) == len(top1_vector) == len(top2_vector)):
        raise ValueError(
            "eigenvectors must hold one tensor per model parameter: got "
            f"{len(top1_vector)} and {len(top2_vector)} tensors for "
            f"{len(params)} parameters")
    backups = [p.detach().clone() for p in params]

    alpha_values = np.linspace(-alpha_range, alpha_range, steps)
    loss_landscape = np.zeros((steps, steps))

    was_training = model.training
    model.eval()
    try:
        with tqdm(total=steps * steps * num_batches, desc='Processing', position=0) as pbar, torch.no_grad():
            for i, alpha1 in enumerate(alpha_values):
                for j, alpha2 in enumerate(alpha_values):
                    # Plain Python floats keep the arithmetic in the tensors' dtype.
                    a1, a2 = float(alpha1), float(alpha2)
                    # Start from the pristine weights so perturbations never accumulate.
                    for param, backup, v1, v2 in zip(params, backups, top1_vector, top2_vector):
                        param.copy_(backup).add_(a1 * v1 + a2 * v2)

                    # Mean of the per-batch losses at this grid point.
                    running_loss = 0.0
                    for inputs, labels in dataloader:
                        inputs, labels = inputs.to(device), labels.to(device)
                        running_loss += criterion(model(inputs), labels).item()
                        pbar.update(1)
                    loss_landscape[i, j] = running_loss / num_batches
    finally:
        # Always hand the model back exactly as it was received.
        with torch.no_grad():
            for param, backup in zip(params, backups):
                param.copy_(backup)
        model.train(was_training)

    # Densify the grid (2x per axis) for a smoother surface.
    new_X, new_Y, new_loss_landscape = _interpolate_landscape(loss_landscape, alpha_range, factor=2)

    # Draw the 3-D surface when an output path was requested.
    if plot_path:
        fig = plt.figure()
        try:
            ax = fig.add_subplot(111, projection='3d')
            ax.plot_surface(new_X, new_Y, new_loss_landscape, cmap='viridis')
            title = _landscape_title(plot_path)
            print(f"The title is: {title}")
            ax.set_xlabel(r"$\epsilon_{1}$")
            ax.set_ylabel(r"$\epsilon_{2}$")
            fig.savefig(plot_path, bbox_inches='tight')
        finally:
            # Release the figure; this function is typically called in a loop.
            plt.close(fig)

    return new_loss_landscape

def getData(name='cifar10', train_bs=128, test_bs=1000,
            data_root='../data', imagenet_root='/workspace/sync/imagenet-1k'):
    """Build the train and test data loaders for a named dataset.

    Args:
        name: one of ``'cifar10'`` (random crop + horizontal flip on the
            training split), ``'cifar10_without_dataaugmentation'`` or
            ``'imagenet-1k_withoutdataaugmentation'``.
        train_bs: batch size of the training loader.
        test_bs: batch size of the test loader.
        data_root: directory where CIFAR-10 is stored / downloaded.
        imagenet_root: directory containing the ImageNet ``train`` and ``val``
            image folders.

    Returns:
        ``(train_loader, test_loader)``. For ImageNet the test loader covers a
        random 1/64 subset of the validation folder (drawn with
        ``np.random.choice``) and both loaders drop the last partial batch.

    Raises:
        ValueError: if ``name`` is not a supported dataset.
    """
    supported = ('cifar10',
                 'cifar10_without_dataaugmentation',
                 'imagenet-1k_withoutdataaugmentation')
    if name not in supported:
        # Previously an unknown name fell through to the return statement and
        # failed with an UnboundLocalError.
        raise ValueError(
            f"Unknown dataset name {name!r}; expected one of {supported}")

    # NOTE(review): the ImageNet branch reuses the same mean/std constants as
    # the CIFAR-10 branches; confirm this matches the normalisation the
    # evaluated checkpoints were trained with.
    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465),
                                     (0.2023, 0.1994, 0.2010))

    if name == 'imagenet-1k_withoutdataaugmentation':
        from torch.utils.data import Subset

        transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            normalize,
        ])

        trainset = datasets.ImageFolder(root=os.path.join(imagenet_root, 'train'),
                                        transform=transform)
        train_loader = torch.utils.data.DataLoader(trainset,
                                                   batch_size=train_bs,
                                                   shuffle=True, drop_last=True)

        testset = datasets.ImageFolder(root=os.path.join(imagenet_root, 'val'),
                                       transform=transform)
        # Keep a random 1/64 of the validation images to bound evaluation cost.
        subset_size = int(len(testset) * (2 ** -6))
        indices = np.random.choice(len(testset), subset_size, replace=False)
        testset = Subset(testset, indices.tolist())
        test_loader = torch.utils.data.DataLoader(testset,
                                                  batch_size=test_bs,
                                                  shuffle=False, drop_last=True)
        return train_loader, test_loader

    # CIFAR-10: the two variants differ only in the training transform.
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])
    if name == 'cifar10':
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
    else:
        transform_train = transform_test

    trainset = datasets.CIFAR10(root=data_root,
                                train=True,
                                download=True,
                                transform=transform_train)
    train_loader = torch.utils.data.DataLoader(trainset,
                                               batch_size=train_bs,
                                               shuffle=True)

    # The training split above already triggered the download if needed.
    testset = datasets.CIFAR10(root=data_root,
                               train=False,
                               download=False,
                               transform=transform_test)
    test_loader = torch.utils.data.DataLoader(testset,
                                              batch_size=test_bs,
                                              shuffle=False)

    return train_loader, test_loader
# Driver: render the loss landscape of every listed checkpoint.
# The initial models (model_clip_init, model_distilledclip_init,
# model_ssl_init, model_distilledssl_init) are not defined in the visible part
# of this file; they must exist before the corresponding list is populated.
from copy import deepcopy  # used below to clone the initial models

train_loader, test_loader = getData(name='imagenet-1k_withoutdataaugmentation',
                                    train_bs=args.mini_hessian_batch_size,
                                    test_bs=args.mini_hessian_batch_size)

eigendata_dir = ""
vis_save_dir = ""
device = 'cuda:0'  # change this to the CUDA device you want to use
alpha_range = 1e-2
steps = 10

# Checkpoint identifiers to process for each model family.
clip_modellist = []
distilledclip_modellist = []
ssl_modellist = []
distilledssl_modellist = []

checkpoint_root = '/workspace/sync/Feature-Distillation/results/hessian-inspired/hessian-contrast'
plot_root = '/workspace/sync/Feature-Distillation/PyHessian/results/vis'


def _plot_checkpoint_landscape(tag, checkpoint_subdir, init_model, checkpoint_id):
    """Load one checkpoint and its stored eigenvectors, then plot its landscape.

    ``tag`` selects the eigen-data file and the plot file name,
    ``checkpoint_subdir`` the folder below ``checkpoint_root`` holding the
    checkpoint, and ``init_model`` is deep-copied so it is never modified.
    """
    # Load onto the CPU first: load_state_dict copies the values into the model
    # and visualize_loss_landscape moves the eigenvectors to ``device`` itself.
    eigendata = torch.load(
        os.path.join(eigendata_dir, f"model_{tag}_{checkpoint_id}_hessian_data.pth"),
        map_location='cpu')
    eigenvectors = eigendata["top_eigenvectors"]

    model = deepcopy(init_model)
    checkpoint = torch.load(
        f'{checkpoint_root}/{checkpoint_subdir}/checkpoint-{checkpoint_id}.pth',
        map_location='cpu')
    model.load_state_dict(checkpoint["model"], strict=True)

    visualize_loss_landscape(model, test_loader, eigenvectors, alpha_range, steps,
                             device=device,
                             plot_path=f"{plot_root}/model_{tag}_{checkpoint_id}.pdf")


# The *_init models are referenced only inside the loop bodies, so an empty
# list never touches a model that has not been defined.
for i in clip_modellist:
    _plot_checkpoint_landscape("clip", "no-distilled", model_clip_init, i)

for i in distilledclip_modellist:
    _plot_checkpoint_landscape("distilledclip", "distilled", model_distilledclip_init, i)

for i in ssl_modellist:
    _plot_checkpoint_landscape("ssl", "ssl", model_ssl_init, i)

for i in distilledssl_modellist:
    _plot_checkpoint_landscape("distilledssl", "distilledssl", model_distilledssl_init, i)