import torch
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
from scipy.ndimage import gaussian_filter
from scipy.stats import gaussian_kde
from Trace import TakeGradient

# Plot a heatmap of a function's gradient norm over a 2-D grid
def GradNormHM(func, fig_name, data_name, *, lo=-3, hi=11, num_points=100,
               smooth_sigma=None):
    """Plot a heatmap of the gradient norm of ``func`` over a square 2-D grid.

    For every grid point ``(x[i], y[j])`` the gradient is obtained from
    ``TakeGradient`` (project module ``Trace``). Its Frobenius norm is stored
    in ``Z[i, j]``. ``Z`` is written to ``data_name`` as CSV and rendered as a
    heatmap to ``fig_name``.

    Args:
        func: Passed to ``TakeGradient`` together with a 2-element tensor
            ``[x, y]`` that has ``requires_grad=True``. Presumably a scalar
            function of two variables; confirm against ``Trace.TakeGradient``.
        fig_name: Path given to ``savefig`` for the heatmap image.
        data_name: Path of the CSV file that receives the raw (unsmoothed)
            norms. Row ``i`` corresponds to ``x[i]`` and column ``j`` to ``y[j]``.
        lo: Lower grid bound on both axes (default -3).
        hi: Upper grid bound on both axes (default 11).
        num_points: Number of grid points per axis (default 100).
        smooth_sigma: If not ``None``, the plotted values are smoothed with
            ``scipy.ndimage.gaussian_filter`` using this sigma. The CSV always
            holds raw values. Default ``None`` plots the raw norms.
    """
    x = torch.linspace(lo, hi, num_points)
    y = torch.linspace(lo, hi, num_points)
    # 'ij' indexing gives X[i, j] == x[i] and Y[i, j] == y[j], which matches
    # how Z[i, j] is filled below. Passing it explicitly silences torch's
    # meshgrid warning; 'ij' is also the implicit default, so the layout is
    # unchanged.
    X, Y = torch.meshgrid(x, y, indexing='ij')

    # Gradient norm at every grid point.
    Z = np.zeros((len(x), len(y)))
    for i, xi in enumerate(x):
        for j, yj in enumerate(y):
            point = torch.tensor([xi, yj], requires_grad=True)
            gradient = TakeGradient(point, func)
            # .item() also works for tensors that are not on the CPU.
            Z[i, j] = torch.norm(gradient, p='fro').item()

    # Save the raw norms to a CSV file.
    np.savetxt(data_name, Z, delimiter=',')

    # Smoothing is opt-in; by default the raw norms are plotted.
    if smooth_sigma is None:
        plotted = Z
    else:
        plotted = gaussian_filter(Z, sigma=smooth_sigma)

    fig, ax = plt.subplots(figsize=(8, 8))
    try:
        mesh = ax.pcolormesh(X.numpy(), Y.numpy(), plotted,
                             shading='auto', cmap='jet')
        fig.colorbar(mesh, ax=ax, label='Gradient Norm')
        ax.set_xlabel('X')
        ax.set_ylabel('Y')
        ax.set_title('Gradient Norm Heatmap')
        fig.savefig(fig_name)
    finally:
        # Always release the figure so repeated calls don't accumulate them.
        plt.close(fig)