import os
from pathlib import Path
from argparse import ArgumentParser

import torch
from matplotlib import pyplot as plt
import matplotlib as mlp
import math

from torchvision import models, transforms as T
import numpy as np
import pandas as pd
from pyhessian import hessian

import PIL

TORCH_VERSION = torch.__version__  # Version string of the installed PyTorch build.

from tqdm import tqdm

from copy import deepcopy
import utils

from utils import get_train_dataloader_for_eval_with_inds, get_test_dataloader_with_inds




def main(args, path, model, device = "cpu"):
    """Estimate the Hessian eigenvalue spectral density of ``model`` on one training batch.

    Args:
        args: Parsed CLI namespace; ``pretrained``, ``cuda``, ``imagenet`` and
            ``batch_size`` are read.
        path: Results directory. Unless ``args.pretrained`` is set, the weights in
            ``<path>/final_model.pt`` are loaded into ``model`` first.
        model: The ``torch.nn.Module`` to analyse.
        device: Device the model and the batch are moved to ("cpu" or "cuda").

    Returns:
        None. ``eigen_density.csv`` and ``density_eigen.pdf`` are written to
        ``<path>/hessian_pretrained`` or ``<path>/hessian_attacked``.

    Raises:
        ValueError: If the training dataloader yields no batches.
    """
    if args.pretrained:
        result_path = f"{path}/hessian_pretrained"
    else:
        # map_location lets a checkpoint saved from a GPU run load on a CPU-only run.
        model.load_state_dict(torch.load(f"{path}/final_model.pt", map_location=device))
        result_path = f"{path}/hessian_attacked"
    # Make the function usable on its own, not only via the __main__ entry point.
    Path(result_path).mkdir(parents=True, exist_ok=True)

    # DataParallel (and the batch moved to `device` below) need the parameters on
    # that device; nothing else in this function moves the model there.
    model = model.to(device)
    if args.cuda:
        print("on GPU")
        model = torch.nn.DataParallel(model)
    model.eval()

    train_set = get_train_dataloader_for_eval_with_inds(imagenet_dir=args.imagenet, batch_size=args.batch_size)
    # The Hessian is estimated from a single batch only.
    batch = next(iter(train_set), None)
    if batch is None:
        raise ValueError("training dataloader yielded no batches")
    X, y, _ = batch

    criterion = torch.nn.CrossEntropyLoss()
    hessian_comp = hessian(model, criterion, data=(X.to(device), y.to(device)), cuda=args.cuda)

    density_eigen, density_weight = hessian_comp.density()

    # Only the first entry is saved to CSV; the plot averages over all entries
    # (see density_generate, which treats axis 0 as independent runs).
    pd.DataFrame({"eigen": density_eigen[0], "weights": density_weight[0]}).to_csv(f"{result_path}/eigen_density.csv")

    get_esd_plot(density_eigen, density_weight, result_path)

    return None

#From Pyhessian Lib
def get_esd_plot(eigenvalues, weights, path):
    """Plot the smoothed eigenvalue spectral density and save it as a PDF.

    Adapted from the PyHessian library.

    Args:
        eigenvalues: Per-run eigenvalue sequences, passed to ``density_generate``.
        weights: Matching per-run weights, passed to ``density_generate``.
        path: Output directory; the figure is written to ``<path>/density_eigen.pdf``.
    """
    density, grids = density_generate(eigenvalues, weights)

    # A dedicated figure, closed afterwards, so repeated calls neither draw onto a
    # previous plot nor accumulate open figures.
    fig, ax = plt.subplots()
    try:
        # The small offset keeps the log scale finite where the density is 0.
        ax.semilogy(grids, density + 1.0e-7)
        ax.set_ylabel('Density (Log Scale)', fontsize=14, labelpad=10)
        ax.set_xlabel('Eigenvalue', fontsize=14, labelpad=10)
        ax.tick_params(axis='both', labelsize=12)
        ax.set_xlim(np.min(eigenvalues) - 1, np.max(eigenvalues) + 1)
        fig.tight_layout()
        fig.savefig(f'{path}/density_eigen.pdf')
    finally:
        plt.close(fig)

#From Pyhessian Lib
def density_generate(eigenvalues,
                     weights,
                     num_bins=10000,
                     sigma_squared=1e-5,
                     overhead=0.01):
    """Smooth eigenvalue/weight pairs into a normalised spectral density on a grid.

    Adapted from the PyHessian library. Each run's eigenvalues are replaced by
    Gaussian bumps scaled by the corresponding weights. The runs are averaged, and
    the result is normalised so that its Riemann sum over the grid equals 1.

    Args:
        eigenvalues: Array-like of shape ``(num_runs, k)``; a single 1-D run of
            length ``k`` is also accepted.
        weights: Array-like with the same shape as ``eigenvalues``.
        num_bins: Number of evenly spaced grid points (must be at least 2).
        sigma_squared: Base kernel variance, scaled by
            ``max(1, lambda_max - lambda_min)``.
        overhead: Margin added to both ends of the grid.

    Returns:
        ``(density, grids)``: two 1-D arrays of length ``num_bins``.
    """
    eigenvalues = np.atleast_2d(np.asarray(eigenvalues, dtype=float))
    weights = np.atleast_2d(np.asarray(weights, dtype=float))

    lambda_max = np.mean(np.max(eigenvalues, axis=1), axis=0) + overhead
    lambda_min = np.mean(np.min(eigenvalues, axis=1), axis=0) - overhead

    grids = np.linspace(lambda_min, lambda_max, num=num_bins)
    # Despite its name, `sigma` is a variance and is passed to gaussian() as one.
    sigma = sigma_squared * max(1, (lambda_max - lambda_min))

    num_runs = eigenvalues.shape[0]
    density_output = np.zeros((num_runs, num_bins))
    for i in range(num_runs):
        # kernel[j, k] = gaussian(eigenvalues[i, k], grids[j], sigma); broadcasting
        # replaces the former per-grid-point Python loop.
        kernel = gaussian(eigenvalues[i, np.newaxis, :], grids[:, np.newaxis], sigma)
        density_output[i] = kernel @ weights[i]

    density = np.mean(density_output, axis=0)
    normalization = np.sum(density) * (grids[1] - grids[0])
    density = density / normalization
    return density, grids


def gaussian(x, x0, sigma_squared):
    """Return exp(-(x0 - x)^2 / (2 * sigma_squared)) / sqrt(2 * pi * sigma_squared).

    This is the normal density with variance ``sigma_squared``, evaluated for the
    gap between ``x`` and ``x0``. It works elementwise on NumPy arrays via
    broadcasting.
    """
    variance = sigma_squared
    squared_gap = (x0 - x) ** 2
    exponent = -squared_gap / (2.0 * variance)
    normaliser = np.sqrt(2 * np.pi * variance)
    return np.exp(exponent) / normaliser


if __name__ == "__main__":
    # arch name -> (message printed on selection, factory building the pretrained
    # torchvision model). Factories are lazy so only the chosen weights are loaded.
    architectures = {
        "vgg19": ('Using VGG19!', lambda: models.vgg19(weights=models.VGG19_Weights.IMAGENET1K_V1)),
        "alexnet": ('using AlexNet!', lambda: models.alexnet(weights=models.AlexNet_Weights.IMAGENET1K_V1)),
        "resnet50": ('using Resnet50!', lambda: models.resnet50(weights=models.ResNet50_Weights.DEFAULT)),
        "resnet18": ('using Resnet18!', lambda: models.resnet18(weights=models.ResNet18_Weights.DEFAULT)),
        # `weights` must be a weights enum member, not the model builder function.
        "resnet34": ('using Resnet34!', lambda: models.resnet34(weights=models.ResNet34_Weights.DEFAULT)),
        "efficientnet": ('using EfficientNet!', lambda: models.efficientnet_b0(weights=models.EfficientNet_B0_Weights.IMAGENET1K_V1)),
    }

    parser = ArgumentParser()
    parser.add_argument("--results-directory", type=str, required=True)
    parser.add_argument("--pretrained", action="store_true")
    parser.add_argument("--imagenet", type=str, default="/home/shared_data/imagenet")
    parser.add_argument("--arch", type=str, default="alexnet", choices=list(architectures))
    parser.add_argument("--batch_size", type=int, default=256)
    parser.add_argument('--cuda', action='store_true')

    args = parser.parse_args()

    if args.cuda and not torch.cuda.is_available():
        parser.error("--cuda was given but CUDA is not available")

    message, build_model = architectures[args.arch]
    print(message)
    model = build_model()

    results_directory = args.results_directory
    for subdirectory in ("hessian_pretrained", "hessian_attacked"):
        Path(results_directory, subdirectory).mkdir(parents=True, exist_ok=True)

    main(args, results_directory, model, device="cuda" if args.cuda else "cpu")
