import matplotlib.pyplot as plt
import numpy as np
import seaborn
# Shorthand for NumPy's matrix inverse (not referenced in the visible part of this script).
inv = np.linalg.inv
import matplotlib
# Global matplotlib font settings applied to every figure created below.
font = {'size': 20, 'family' : 'Times New Roman'}
matplotlib.rc("font", **font)
# Route all figure text through LaTeX.
# NOTE(review): this presumably needs a working LaTeX installation on the machine — confirm.
matplotlib.rcParams["text.usetex"] = True

# seaborn's 'colorblind' palette (not referenced in the visible part of this script).
COLORS = seaborn.color_palette('colorblind')

import torch

# for visualization
# Shaft width handed to ax.quiver(width=...) for every arrow field.
arrow_width = 0.004
# Side length of the square figure, handed to plt.figure(figsize=...).
fig_size = 6

def log_gaussian_fn(theta_1, theta_2, samples):
    """Evaluate a univariate Gaussian log-density at ``samples``.

    The density is parameterized by ``theta_1`` (the mean) and ``theta_2``
    (the log of the standard deviation), matching the axis labels used in the
    figure at the end of the script.

    Args:
        theta_1: Tensor holding the mean.
        theta_2: Tensor holding log(sigma).
        samples: Tensor of points at which the log-density is evaluated.

    Returns:
        Tensor of log-densities, broadcast over ``samples``.
    """
    # -0.5 * log(2 * pi) - log(sigma): the log normalizing constant.
    log_normalizer = -0.5 * torch.log(torch.tensor(2 * torch.pi)) - theta_2
    # exp(-2 * log(sigma)) == 1 / sigma^2.
    inv_variance = torch.exp(-2 * theta_2)
    squared_deviation = (samples - theta_1) ** 2
    return log_normalizer - 0.5 * inv_variance * squared_deviation

# Number of samples drawn below.
N = 2000
# Draws from a standard normal via torch.randn. No RNG seed is set in the visible code,
# so the samples (and therefore the figure) differ from run to run.
x_samples = torch.randn(N)
# Log-density of every sample at theta = (0, 0), i.e. mean 0 and log(sigma) 0.
logp_samples = log_gaussian_fn(torch.tensor(0.0), torch.tensor(0.0), x_samples)

# Damping strength: the gradient functions below add damping_coef times the identity
# to the Fisher / Gram matrix before inverting it.
damping_coef = 0.01

def analytical_grad_fn(theta_1, theta_2):
    """Closed-form gradient of the mean Gaussian log-likelihood over ``x_samples``.

    Uses the hand-derived derivatives of ``log_gaussian_fn`` with respect to
    ``theta_1`` (mean) and ``theta_2`` (log standard deviation), averaged over
    the module-level ``x_samples``.

    Args:
        theta_1: Tensor holding the mean.
        theta_2: Tensor holding log(sigma).

    Returns:
        NumPy array ``[d/d theta_1, d/d theta_2]`` of length 2.
    """
    residual = x_samples - theta_1
    inv_variance = torch.exp(-2 * theta_2)

    # Per-sample scores: d logp / d mu and d logp / d log(sigma).
    score_mu = inv_variance * residual
    score_log_sigma = -1.0 + inv_variance * residual ** 2

    return np.array(
        [score.mean().detach().numpy() for score in (score_mu, score_log_sigma)]
    )

def natural_grad_fn(theta_1, theta_2):
    """Closed-form natural gradient of the mean log-likelihood over ``x_samples``.

    The per-sample terms are the analytical score rescaled by the inverse of
    diag(exp(-2 * theta_2), 2): the mean component loses its 1 / sigma^2
    factor and the log-sigma component is halved.

    Args:
        theta_1: Tensor holding the mean.
        theta_2: Tensor holding log(sigma).

    Returns:
        NumPy array of length 2 with the natural-gradient components.
    """
    residual = x_samples - theta_1
    inv_variance = torch.exp(-2 * theta_2)

    nat_mu = residual
    nat_log_sigma = -0.5 + 0.5 * inv_variance * residual ** 2

    return np.array([term.mean().numpy() for term in (nat_mu, nat_log_sigma)])

def regularized_natural_grad_fn(theta_1, theta_2):
    """Closed-form damped natural gradient of the mean log-likelihood.

    Same as the undamped natural gradient, except that ``damping_coef`` is
    added to each diagonal Fisher entry before inversion: the mean component
    is divided by ``1 + damping_coef * exp(2 * theta_2)`` and the log-sigma
    component is scaled by ``1 / (2 + damping_coef)``.

    Args:
        theta_1: Tensor holding the mean.
        theta_2: Tensor holding log(sigma).

    Returns:
        NumPy array of length 2 with the damped natural-gradient components.
    """
    residual = x_samples - theta_1

    # Mean direction: exp(-2 t2) / (exp(-2 t2) + damping) == 1 / (1 + damping * exp(2 t2)).
    mu_denominator = 1 + damping_coef * torch.exp(2 * theta_2)
    reg_mu = residual / mu_denominator

    # Log-sigma direction: score divided by (2 + damping).
    log_sigma_scale = 1.0 / (2. + damping_coef)
    reg_log_sigma = -log_sigma_scale + log_sigma_scale * torch.exp(-2 * theta_2) * residual ** 2

    return np.array([term.mean().numpy() for term in (reg_mu, reg_log_sigma)])

def empirical_fisher_grad_fn(theta_1, theta_2):
    """Damped natural gradient using the empirical Fisher matrix.

    Per-sample gradients of ``log_gaussian_fn`` with respect to
    ``(theta_1, theta_2)`` are computed over the module-level ``x_samples``
    with ``vmap(grad(...))``. Their averaged outer product forms the empirical
    Fisher, which is damped by ``damping_coef`` times the identity, inverted,
    and applied to the mean per-sample gradient.

    Args:
        theta_1: Mean parameter; a tensor or anything ``torch.as_tensor`` accepts.
        theta_2: Log-standard-deviation parameter; same accepted types.

    Returns:
        NumPy array of length 2 with the preconditioned gradient.
    """
    from torch.func import vmap, grad

    # The callers pass tensors. torch.tensor(<tensor>) emits a copy-construct
    # UserWarning on every call, so build fresh float32 leaves explicitly.
    theta_1 = torch.as_tensor(theta_1, dtype=torch.float32).detach().clone().requires_grad_(True)
    theta_2 = torch.as_tensor(theta_2, dtype=torch.float32).detach().clone().requires_grad_(True)

    def compute_logp(params, per_sample):
        return log_gaussian_fn(params[0], params[1], per_sample)

    # Differentiate w.r.t. the params tuple and map over the sample axis only.
    per_sample_grad_fn = vmap(grad(compute_logp), in_dims=(None, 0))
    score_theta_1, score_theta_2 = per_sample_grad_fn((theta_1, theta_2), x_samples)

    with torch.no_grad():
        H = torch.stack([score_theta_1, score_theta_2], dim=1)  # num_samples x num_params
        # Normalize by the actual number of rows instead of the global N so the
        # average stays correct if the sample tensor is ever resized.
        fisher = H.t() @ H / H.shape[0]

    fisher_inv = torch.inverse(fisher + damping_coef * torch.eye(2))

    # Mean score; detached because only its value is needed from here on.
    log_grad = torch.stack([score_theta_1.mean().detach(), score_theta_2.mean().detach()])
    natural_grad = fisher_inv @ log_grad
    return natural_grad.numpy()

def adv_grad_fn(theta_1, theta_2):
    """Gradient of a per-sample-weighted log-likelihood, plus the vanilla gradient.

    Per-sample gradients ``H`` (num_samples x 2) of ``log_gaussian_fn`` are
    computed over the module-level ``x_samples``. The per-sample weights are
    ``(H H^T + damping_coef * I)^{-1} 1``, treated as constants, and the
    surrogate loss ``mean(logp * weights)`` is differentiated with autograd.
    By the push-through identity this gradient equals
    ``(H^T H + damping_coef * I)^{-1}`` applied to the mean per-sample gradient.

    The Gram matrix is num_samples x num_samples and is inverted explicitly,
    so the cost of one call grows quickly with the number of samples.

    Args:
        theta_1: Mean parameter; a tensor or anything ``torch.as_tensor`` accepts.
        theta_2: Log-standard-deviation parameter; same accepted types.

    Returns:
        Tuple ``(weighted_grad, vanilla_grad)`` of NumPy arrays of length 2:
        the gradient of the weighted surrogate and the plain mean per-sample
        gradient.
    """
    from torch.func import vmap, grad

    # The callers pass tensors. torch.tensor(<tensor>) emits a copy-construct
    # UserWarning on every call, so build fresh float32 leaves explicitly.
    theta_1 = torch.as_tensor(theta_1, dtype=torch.float32).detach().clone().requires_grad_(True)
    theta_2 = torch.as_tensor(theta_2, dtype=torch.float32).detach().clone().requires_grad_(True)

    def compute_logp(params, per_sample):
        return log_gaussian_fn(params[0], params[1], per_sample)

    # Differentiate w.r.t. the params tuple and map over the sample axis only.
    per_sample_grad_fn = vmap(grad(compute_logp), in_dims=(None, 0))
    score_theta_1, score_theta_2 = per_sample_grad_fn((theta_1, theta_2), x_samples)

    # Log-densities with the autograd graph attached; differentiated below.
    logp = log_gaussian_fn(theta_1, theta_2, x_samples)

    with torch.no_grad():
        H = torch.stack([score_theta_1, score_theta_2], dim=1)  # num_samples x num_params
        num_samples = H.shape[0]
        HHT = H @ H.t()
        # Weights are computed without gradient tracking so they act as constants.
        pseudo_adv = torch.mv(
            torch.inverse(HHT + damping_coef * torch.eye(num_samples)),
            torch.ones(num_samples),
        )

    ng_loss = (logp * pseudo_adv).mean()
    final_grad1, final_grad2 = torch.autograd.grad(ng_loss, (theta_1, theta_2))
    final_grad = torch.stack([final_grad1, final_grad2])

    # Unweighted mean of the per-sample gradients, returned for comparison.
    vanilla_final_grad = torch.stack(
        [score_theta_1.mean().detach(), score_theta_2.mean().detach()]
    )

    return final_grad.numpy(), vanilla_final_grad.numpy()

# Plotted ranges of theta_1 (x axis) and theta_2 (y axis), and the resolution
# of the dense grid used for the contour lines.
xlimits = [-1.0, 1.0]
ylimits = [-0.1, 1.0]
numticks = 100

x_mesh = np.linspace(xlimits[0], xlimits[1], num=numticks)
y_mesh = np.linspace(ylimits[0], ylimits[1], num=numticks)
X_mesh, Y_mesh = np.meshgrid(x_mesh, y_mesh)

# Flattened grid coordinates as tensors so they can be fed to log_gaussian_fn.
x_coords = torch.from_numpy(X_mesh.ravel())
y_coords = torch.from_numpy(Y_mesh.ravel())

# Mean log-likelihood of x_samples at every grid node, one node at a time.
zs = np.array(
    [
        log_gaussian_fn(node_theta_1, node_theta_2, x_samples).mean().numpy()
        for node_theta_1, node_theta_2 in zip(x_coords, y_coords)
    ]
)
Z = zs.reshape(X_mesh.shape)

# Coarse 10 x 10 lattice over the same ranges; one arrow is drawn per node.
grad_x_mesh = np.linspace(xlimits[0], xlimits[1], num=10)
grad_y_mesh = np.linspace(ylimits[0], ylimits[1], num=10)
grad_X_mesh, grad_Y_mesh = np.meshgrid(grad_x_mesh, grad_y_mesh)
grad_x_coords = torch.from_numpy(grad_X_mesh.ravel())
grad_y_coords = torch.from_numpy(grad_Y_mesh.ravel())

# Closed-form gradient at every lattice node, one row per node.
analytical_grad_xy = np.array(
    [
        analytical_grad_fn(node_theta_1, node_theta_2)
        for node_theta_1, node_theta_2 in zip(grad_x_coords, grad_y_coords)
    ]
)

# Preconditioned gradient at every lattice node. natural_grad_fn or
# regularized_natural_grad_fn can be substituted for empirical_fisher_grad_fn
# here to plot the closed-form variants instead.
natural_grad_xy = np.array(
    [
        empirical_fisher_grad_fn(node_theta_1, node_theta_2)
        for node_theta_1, node_theta_2 in zip(grad_x_coords, grad_y_coords)
    ]
)

# Single square axes on a 1 x 1 grid.
fig = plt.figure(constrained_layout=True, figsize=(fig_size, fig_size * 1), dpi=300)
gs = fig.add_gridspec(1, 1, width_ratios=[1], height_ratios=[1])

ax = fig.add_subplot(gs[0])

# Blue arrows: the field stored in natural_grad_xy (filled from empirical_fisher_grad_fn above).
ax.quiver(grad_x_coords, grad_y_coords, natural_grad_xy[:, 0], natural_grad_xy[:, 1], color='blue', width=arrow_width, headwidth=7)

# Evaluate adv_grad_fn at every lattice node; it returns the weighted-surrogate
# gradient and the plain mean per-sample gradient, collected row by row.
adv_grad_xy, grad_xy = [], []
for (x_c, y_c) in zip(grad_x_coords, grad_y_coords):
    adv_grad, normal_grad = adv_grad_fn(x_c, y_c)
    adv_grad_xy.append(adv_grad)
    grad_xy.append(normal_grad)
adv_grad_xy = np.array(adv_grad_xy)
grad_xy = np.array(grad_xy)

# Sanity check: the autograd-based vanilla gradient must point in (almost) the same
# direction as the closed-form gradient at every node.
# NOTE: this is an assert statement, so it is skipped when Python runs with -O.
for (estimated_grad, analytical_grad) in zip(grad_xy, analytical_grad_xy):
    cos_sim = np.dot(estimated_grad, analytical_grad) / ( np.linalg.norm(estimated_grad) * np.linalg.norm(analytical_grad) )
    assert cos_sim > 0.99, f"Cosine similarity check failed: {cos_sim}"

# Red arrows: weighted-surrogate gradient; grey arrows: vanilla gradient.
ax.quiver(grad_x_coords, grad_y_coords, adv_grad_xy[:, 0], adv_grad_xy[:, 1], color='red', width=arrow_width)
ax.quiver(grad_x_coords, grad_y_coords, grad_xy[:, 0], grad_xy[:, 1], color='grey', width=arrow_width)
# Faint contour lines of the mean log-likelihood surface Z.
ax.contour(X_mesh, Y_mesh, Z, levels=50, alpha=0.3)
# Star at theta = (0, 0): mean 0 and log(sigma) 0, the parameters of the standard
# normal that x_samples was drawn from.
ax.scatter(0.0, 0.0, s=40, marker='*') 

# ax.set_title(f"Interp. Coef.: {val}", size=20)
ax.set_xlabel(r"$\theta_1 = \mu$")
ax.set_ylabel(r"$\theta_2 = \log \sigma$")

# Write the figure to a PDF in the current working directory.
plt.savefig('vis_adv_ng.pdf', bbox_inches='tight')