#*
# @file Compute Hessian top eigenvalues and trace for saved model checkpoints
# Copyright (c) Zhewei Yao, Amir Gholami
# All rights reserved.
# This file is part of PyHessian library.
#
# PyHessian is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# PyHessian is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with PyHessian.  If not, see <http://www.gnu.org/licenses/>.
#*

from __future__ import print_function

import json
import os
import sys
import time

import numpy as np
import argparse

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import timm

from torchvision import datasets, transforms
from torch.autograd import Variable

from utils import *
from density_plot import get_esd_plot
from models.resnet import resnet
from pyhessian import hessian

# Settings
parser = argparse.ArgumentParser(description='PyTorch Example')
parser.add_argument(
    '--mini-hessian-batch-size',
    type=int,
    default=200,
    help='input batch size for mini-hessian batch (default: 200)')
parser.add_argument('--hessian-batch-size',
                    type=int,
                    default=200,
                    help='input batch size for hessian (default: 200)')
parser.add_argument('--seed',
                    type=int,
                    default=1,
                    help='random seed (default: 1)')
# The three flags below use action='store_false': each defaults to True and
# passing the flag on the command line turns the feature OFF.
parser.add_argument('--batch-norm',
                    action='store_false',
                    help='disable batch norm (enabled by default)')
parser.add_argument('--residual',
                    action='store_false',
                    help='disable residual connections (enabled by default)')

parser.add_argument('--cuda',
                    action='store_false',
                    help='disable the GPU (used by default when available)')
# NOTE(review): --resume is not read anywhere in this file (the loading code in
# cal_hessians is commented out); kept for CLI compatibility.
parser.add_argument('--resume',
                    type=str,
                    default='',
                    help='get the checkpoint')

args = parser.parse_args()

# Fall back to the CPU instead of crashing later in model.cuda() when no GPU
# is present but --cuda was left at its default (True).
if args.cuda and not torch.cuda.is_available():
    print('CUDA is not available; running on CPU')
    args.cuda = False

# set random seed to reproduce the work
torch.manual_seed(args.seed)
if args.cuda:
    torch.cuda.manual_seed(args.seed)

# Echo the effective configuration.
for arg_name, arg_value in vars(args).items():
    print(arg_name, arg_value)


# Per-channel normalization statistics applied by every dataset branch of getData.
_CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
_CIFAR10_STD = (0.2023, 0.1994, 0.2010)

# Dataset names accepted by getData.
_DATASET_NAMES = ('cifar10',
                  'cifar10_without_dataaugmentation',
                  'imagenet-1k_withoutdataaugmentation')


def _cifar10_loaders(normalize, train_bs, test_bs, augment):
    """Return CIFAR-10 (train_loader, test_loader) stored under ../data.

    ``augment`` prepends RandomCrop(32, padding=4) and RandomHorizontalFlip to
    the training transform; the test transform never augments.
    """
    train_steps = []
    if augment:
        train_steps = [transforms.RandomCrop(32, padding=4),
                       transforms.RandomHorizontalFlip()]
    transform_train = transforms.Compose(
        train_steps + [transforms.ToTensor(), normalize])
    transform_test = transforms.Compose([transforms.ToTensor(), normalize])

    trainset = datasets.CIFAR10(root='../data',
                                train=True,
                                download=True,
                                transform=transform_train)
    train_loader = torch.utils.data.DataLoader(trainset,
                                               batch_size=train_bs,
                                               shuffle=True)

    # download=False: the train-split call above has already fetched the archive.
    testset = datasets.CIFAR10(root='../data',
                               train=False,
                               download=False,
                               transform=transform_test)
    test_loader = torch.utils.data.DataLoader(testset,
                                              batch_size=test_bs,
                                              shuffle=False)
    return train_loader, test_loader


def _imagenet_loaders(normalize, train_bs, test_bs):
    """Return ImageNet-1k (train_loader, test_loader) read with ImageFolder.

    Both splits are resized to 224x224 with no augmentation; both loaders
    shuffle and drop the last incomplete batch.
    """
    # NOTE(review): ImageNet images are normalized with the CIFAR-10 mean/std,
    # not the usual ImageNet statistics ((0.485, 0.456, 0.406),
    # (0.229, 0.224, 0.225)). Confirm this matches the preprocessing the
    # evaluated checkpoints were trained with.
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        normalize,
    ])

    trainset = datasets.ImageFolder(root='/workspace/sync/imagenet-1k/train',
                                    transform=transform)
    train_loader = torch.utils.data.DataLoader(trainset,
                                               batch_size=train_bs,
                                               shuffle=True, drop_last=True)

    testset = datasets.ImageFolder(root='/workspace/sync/imagenet-1k/val',
                                   transform=transform)
    # Shuffled so that the first batch(es), which the script uses for the
    # Hessian, are presumably a random sample of the validation set.
    test_loader = torch.utils.data.DataLoader(testset,
                                              batch_size=test_bs,
                                              shuffle=True, drop_last=True)
    return train_loader, test_loader


def getData(name='cifar10', train_bs=128, test_bs=1000):
    """
    Get the train and test dataloaders for a named dataset.

    Args:
        name: one of
            'cifar10' - CIFAR-10 with random crop + horizontal flip on train;
            'cifar10_without_dataaugmentation' - CIFAR-10 without augmentation;
            'imagenet-1k_withoutdataaugmentation' - ImageFolder datasets under
                /workspace/sync/imagenet-1k/{train,val}, resized to 224x224.
        train_bs: batch size of the training loader.
        test_bs: batch size of the test loader.

    Returns:
        (train_loader, test_loader)

    Raises:
        ValueError: if ``name`` is not one of the supported datasets.
    """
    # Validate first: an unknown name previously fell through every branch and
    # failed with an UnboundLocalError at the return statement.
    if name not in _DATASET_NAMES:
        raise ValueError(
            f"unknown dataset {name!r}; expected one of {_DATASET_NAMES}")

    normalize = transforms.Normalize(_CIFAR10_MEAN, _CIFAR10_STD)

    if name == 'imagenet-1k_withoutdataaugmentation':
        return _imagenet_loaders(normalize, train_bs, test_bs)
    return _cifar10_loaders(normalize, train_bs, test_bs,
                            augment=(name == 'cifar10'))
##############
# Get the hessian data
##############
# The Hessian is evaluated on ``hessian_batch_size`` samples, taken as
# ``batch_num`` mini-batches of ``mini_hessian_batch_size`` from test_loader.
# Validate before building the loaders so bad sizes fail fast. An explicit check
# replaces the old ``assert``, which ``python -O`` strips. The old check also
# accepted hessian_batch_size == 0, which made batch_num 0 and consumed the
# entire loader.
if (args.mini_hessian_batch_size <= 0
        or args.hessian_batch_size < args.mini_hessian_batch_size
        or args.hessian_batch_size % args.mini_hessian_batch_size != 0):
    raise ValueError(
        f'--hessian-batch-size ({args.hessian_batch_size}) must be a positive '
        f'multiple of --mini-hessian-batch-size '
        f'({args.mini_hessian_batch_size})')
batch_num = args.hessian_batch_size // args.mini_hessian_batch_size

# get dataset
train_loader, test_loader = getData(name='imagenet-1k_withoutdataaugmentation',
                                    train_bs=args.mini_hessian_batch_size,
                                    test_bs=args.mini_hessian_batch_size)

# range comes first in zip so iteration stops after batch_num batches without
# pulling an extra batch from the loader.
hessian_batches = [(inputs, labels)
                   for _, (inputs, labels) in zip(range(batch_num), test_loader)]
if len(hessian_batches) < batch_num:
    raise ValueError(
        f'test_loader yielded only {len(hessian_batches)} batch(es); '
        f'{batch_num} are needed for --hessian-batch-size '
        f'{args.hessian_batch_size}')

# cal_hessians passes a single (inputs, labels) batch to pyhessian via
# ``data=`` and a list of batches via ``dataloader=``, so keep those shapes.
hessian_dataloader = hessian_batches[0] if batch_num == 1 else hessian_batches


############################################################ model things ############################################################

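# Define and load the template models used by the checkpoint loops at the end of
# this file (model_clip_init, model_distilledclip_init, model_ssl_init,
# model_distilledssl_init); each is deep-copied before a checkpoint is loaded into it.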


@timing_decorator
def cal_hessians(model, args, save_dir=None, model_name=None, *,
                 top_n=2, max_iter=200, tol=1e-4):
    """Compute the top Hessian eigenvalues and the Hessian trace for ``model``.

    The loss is cross-entropy. It is evaluated on the module-level
    ``hessian_dataloader``: a single (inputs, labels) batch when
    ``batch_num == 1``, otherwise a list of such batches.

    Args:
        model: the network to analyse; moved to the GPU when ``args.cuda``.
        args: parsed CLI namespace; only ``args.cuda`` is read here.
        save_dir: directory for the results file (created if missing).
        model_name: file stem; results go to
            ``<save_dir>/<model_name>_hessian_data.pth``.
            Nothing is saved unless both save_dir and model_name are given.
        top_n: number of top eigenvalues/eigenvectors to compute (default 2).
        max_iter: iteration cap passed to pyhessian as ``maxIter`` (default 200).
        tol: convergence tolerance passed to pyhessian (default 1e-4).

    Returns:
        dict with keys 'top_eigenvalues', 'top_eigenvectors' and 'trace'. This
        is the same object that is saved to disk.
    """
    if args.cuda:
        model = model.cuda()

    criterion = nn.CrossEntropyLoss()  # label loss

    # The Hessian is computed with the model in eval mode.
    model.eval()

    # pyhessian takes a single batch through ``data=`` and several batches
    # through ``dataloader=``.
    if batch_num == 1:
        hessian_comp = hessian(model,
                               criterion,
                               data=hessian_dataloader,
                               cuda=args.cuda)
    else:
        hessian_comp = hessian(model,
                               criterion,
                               dataloader=hessian_dataloader,
                               cuda=args.cuda)

    print(
        '********** finish data londing and begin Hessian computation **********')

    top_eigenvalues, top_eigenvectors = hessian_comp.eigenvalues(
        top_n=top_n, maxIter=max_iter, tol=tol)
    # presumably a sequence of per-iteration trace estimates (its mean is
    # reported below) -- confirm against pyhessian.
    trace = hessian_comp.trace(maxIter=max_iter, tol=tol)

    print('\n***Top Eigenvalues: ', top_eigenvalues)
    print('\n***Trace: ', np.mean(trace))

    ###################
    # Saving Results
    ###################
    results = {
        'top_eigenvalues': top_eigenvalues,
        'top_eigenvectors': top_eigenvectors,
        'trace': trace,
    }
    if save_dir is not None and model_name is not None:
        # torch.save does not create missing parent directories.
        os.makedirs(save_dir, exist_ok=True)
        save_path = os.path.join(save_dir, f"{model_name}_hessian_data.pth")
        torch.save(results, save_path)
        print(f"Results saved to {save_path}")

    return results

# Checkpoint ids to evaluate for each model family; an empty list skips that family.
clip_modellist = []
distilledclip_modellist = []
ssl_modellist = []
distilledssl_modellist = []

# Parent directory of the per-family checkpoint folders.
_CHECKPOINT_ROOT = ('/workspace/sync/Feature-Distillation/results/'
                    'hessian-inspired/hessian-contrast')


def _cal_checkpoint_hessians(get_init_model, checkpoint_subdir, name_prefix,
                             checkpoint_ids):
    """Run cal_hessians on each checkpoint of one model family.

    Each checkpoint's ``"model"`` state dict is loaded (strict=True) into a
    fresh deep copy of the template model. Results are saved under
    results/trace as ``<name_prefix>_<id>``.

    Args:
        get_init_model: zero-argument callable returning the template model.
            It is only called when ``checkpoint_ids`` is non-empty, so the
            template may be left undefined for families that are skipped.
        checkpoint_subdir: folder under _CHECKPOINT_ROOT holding the checkpoints.
        name_prefix: prefix of the saved result name.
        checkpoint_ids: ids substituted into ``checkpoint-<id>.pth``.
    """
    # deepcopy is not imported explicitly at the top of this file (it may come
    # from ``from utils import *``); import it here so this block is self-contained.
    from copy import deepcopy

    for ckpt_id in checkpoint_ids:
        model = deepcopy(get_init_model())
        ckpt_path = f'{_CHECKPOINT_ROOT}/{checkpoint_subdir}/checkpoint-{ckpt_id}.pth'
        # Load onto the CPU; cal_hessians moves the model to the GPU when
        # args.cuda is set.
        state_dict = torch.load(ckpt_path, map_location='cpu')["model"]
        model.load_state_dict(state_dict, strict=True)
        cal_hessians(model, args, model_name=f"{name_prefix}_{ckpt_id}",
                     save_dir="results/trace")


# The lambdas defer looking up the *_init templates until a checkpoint of that
# family is actually evaluated.
_cal_checkpoint_hessians(lambda: model_clip_init, 'no-distilled',
                         'model_clip', clip_modellist)
_cal_checkpoint_hessians(lambda: model_distilledclip_init, 'distilled',
                         'model_distilledclip', distilledclip_modellist)
_cal_checkpoint_hessians(lambda: model_ssl_init, 'ssl',
                         'model_ssl', ssl_modellist)
_cal_checkpoint_hessians(lambda: model_distilledssl_init, 'distilledssl',
                         'model_distilledssl', distilledssl_modellist)


