import torch
import torch.nn.functional as F
from torch.nn.utils import spectral_norm
import math


def gen_heatmaps(input, heatmap_size=16, tau=0.02):
    """
    Render an isotropic Gaussian heatmap for every keypoint.

    The sampling grid is heatmap_size x heatmap_size points evenly spaced over
    [-1, 1]^2. Output index [..., i, j] corresponds to the grid point
    (linspace[i], linspace[j]): row i follows input[..., 0] and column j
    follows input[..., 1].

    :param input: (batch_size, n_points, 2) keypoint coordinates, in the same
                  frame as the [-1, 1] grid
    :param heatmap_size: number of grid samples per side
    :param tau: bandwidth; heatmap = exp(-||p - g||^2 / tau), smaller tau gives
                a sharper peak
    :return: (batch_size, n_points, heatmap_size, heatmap_size)
    """
    batch_size, n_points, _ = input.shape
    coords = torch.linspace(-1, 1, heatmap_size).to(input)
    # (heatmap_size**2, 2); row k = i * heatmap_size + j holds (coords[i], coords[j]),
    # the same ordering as torch.meshgrid with 'ij' indexing.
    grid = torch.stack(
        (coords.repeat_interleave(heatmap_size), coords.repeat(heatmap_size)), dim=-1
    )
    # ||p - g||^2 = ||p||^2 + ||g||^2 - 2 p.g avoids materialising every pairwise difference.
    input_norm = (input ** 2).sum(dim=-1, keepdim=True)  # (batch_size, n_points, 1)
    grid_norm = (grid ** 2).sum(dim=-1)                  # (heatmap_size**2,)
    dist = input_norm + grid_norm - 2 * torch.matmul(input, grid.t())
    # The expansion can round to tiny negative values when a keypoint sits on a
    # grid point; clamp so the heatmap never exceeds 1.
    dist = dist.clamp(min=0)
    heatmaps = torch.exp(-dist / tau)
    return heatmaps.reshape(batch_size, n_points, heatmap_size, heatmap_size)


def heatmap2keypoints(heatmap):
    """
    Locate the peak of each heatmap as integer grid indices.

    :param heatmap: (..., height, width) tensor
    :return: (..., 2) float tensor holding [row_index, col_index] of the maximum.
             These are grid indices, not [-1, 1] coordinates. When several cells
             share the maximum value, the first one in row-major order is returned.
    """
    width = heatmap.shape[-1]
    # A single argmax over the flattened map guarantees that the returned
    # (row, col) is a true maximum. Taking row and column argmaxes independently
    # can pair the row of one tied peak with the column of another.
    flat_index = heatmap.flatten(start_dim=-2).argmax(dim=-1)
    col_index = flat_index % width
    # flat_index - col_index is an exact multiple of width, so this division is exact.
    row_index = (flat_index - col_index).float() / width
    return torch.stack([row_index, col_index.float()], dim=-1)


def gen_gabor_heatmaps(input, tau_inv, theta, sigma_inv, lam_inv, psi, heatmap_size=16):
    """
    Render a Gaussian heatmap, a Gabor pattern, and their product for every keypoint.

    The sampling grid is heatmap_size x heatmap_size points evenly spaced over
    [-1, 1]^2. Output index [..., i, j] corresponds to the grid point
    (linspace[i], linspace[j]): row i follows input[..., 0] and column j
    follows input[..., 1].

    :param input: (batch_size, n_points, 2) keypoint coordinates
           tau_inv: inverse bandwidth of the Gaussian heatmap; must broadcast against
                    (batch_size, n_points, heatmap_size**2), e.g. a scalar or
                    (batch_size, n_points, 1)
           theta: (batch_size, n_points, 1) rotation angle in radians
           sigma_inv: (batch_size, n_points, 2) inverse envelope sigmas along the rotated
                      x and y axes (x and y do NOT share the same sigma)
           lam_inv: (batch_size, n_points, 1) spatial frequency of the cosine carrier
           psi: (batch_size, n_points, 1) phase offset of the cosine carrier
           heatmap_size: number of grid samples per side
    :return: tuple (heatmaps, gabor, gabor_heatmaps), each of shape
             (batch_size, n_points, heatmap_size, heatmap_size), where
             heatmaps = exp(-||p - g||^2 * tau_inv),
             gabor = exp(-(sx^2 x'^2 + sy^2 y'^2) / 2) * cos(2*pi*lam_inv*x' + psi),
             gabor_heatmaps = gabor * heatmaps
    """
    batch_size, n_points, _ = input.shape
    coords = torch.linspace(-1, 1, heatmap_size).to(input)
    # (heatmap_size**2, 2); row k = i * heatmap_size + j holds (coords[i], coords[j]),
    # the same ordering as torch.meshgrid with 'ij' indexing.
    grid = torch.stack(
        (coords.repeat_interleave(heatmap_size), coords.repeat(heatmap_size)), dim=-1
    )
    diff = input.unsqueeze(-2) - grid  # (batch_size, n_points, heatmap_size**2, 2)
    dx, dy = diff[..., 0], diff[..., 1]

    # Rotate offsets into each keypoint's local frame (x', y').
    cos_t, sin_t = torch.cos(theta), torch.sin(theta)
    roted_x = dx * cos_t + dy * sin_t    # (batch_size, n_points, heatmap_size**2)
    roted_y = -dx * sin_t + dy * cos_t   # (batch_size, n_points, heatmap_size**2)
    # Both squared terms are negated in the exponent so the envelope decays along
    # both rotated axes; subtracting them would make it grow along y'.
    envelope = torch.exp(
        -(sigma_inv[..., 0:1] ** 2 * roted_x ** 2 + sigma_inv[..., 1:2] ** 2 * roted_y ** 2) / 2
    )
    carrier = torch.cos(2 * math.pi * lam_inv * roted_x + psi)
    gabor = (envelope * carrier).reshape(batch_size, n_points, heatmap_size, heatmap_size)

    dist = dx ** 2 + dy ** 2
    heatmaps = torch.exp(-dist * tau_inv).reshape(batch_size, n_points, heatmap_size, heatmap_size)

    return heatmaps, gabor, gabor * heatmaps


if __name__ == '__main__':
    import matplotlib.pyplot as plt

    # Demo: sample random keypoints in [-1, 1)^2 and display the first keypoint's heatmap.
    demo_size, demo_points = 4, 8
    demo_keypoints = torch.rand((1, demo_points, 2)) * 2 - 1

    demo_heatmaps = gen_heatmaps(demo_keypoints, heatmap_size=demo_size, tau=0.01)
    first_map = demo_heatmaps[0, 0]
    plt.imshow(first_map)
    plt.colorbar()
    plt.show()
