import numpy as np
import torch
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
from e3nn import o3


def plot_image(fig, ax,
               image: torch.Tensor,
               vpad=60,
               pad_value: float = 1.0,
              ):
    '''Draw a channels-first image on ``ax`` with padding bands above and below.

    Args:
        fig: unused; kept so existing callers keep working.
        ax: axes to draw on (only ``imshow`` and ``axis`` are used).
        image: image tensor of shape (C, H, W); presumably a float RGB image
            in [0, 1], so the default ``pad_value`` of 1 renders as white.
        vpad: number of rows of padding added to both the top and the bottom.
        pad_value: fill value used for the padded rows.
    '''
    # imshow converts through numpy, which fails on tensors that require grad
    # or live on a GPU, so move to a plain CPU tensor first.
    image = image.detach().cpu()
    # (0, 0, vpad, vpad) pads the last dim (W) by 0 and the second-to-last
    # dim (H) by ``vpad`` on each side.
    image = torch.nn.functional.pad(image, (0, 0, vpad, vpad), 'constant', pad_value)
    # Matplotlib expects channels-last (H, W, C).
    ax.imshow(image.permute(1, 2, 0))
    ax.axis('off')


def plot_so3_distribution(probs: torch.Tensor,
                          rots: torch.Tensor,
                          gt_rotation=None,
                          fig=None,
                          ax=None,
                          display_threshold_probability=0.000005,
                          prob_threshold: float=0.01,
                          show_color_wheel: bool=True,
                          canonical_rotation=None,
                         ):
    '''Render a distribution over SO(3) on a Mollweide map and return it as an image.

    Adapted from https://github.com/google-research/google-research/blob/master/implicit_pdf/evaluation.py

    Every rotation is converted to angles (alpha, beta, gamma) with
    ``o3.matrix_to_angles``. alpha is used as longitude, beta - pi/2 as
    latitude, and gamma is encoded as hue. Marker area is proportional to
    the rotation's probability.

    Args:
        probs: probability of each rotation in ``rots`` (presumably shape (N,)).
        rots: rotation matrices (presumably shape (N, 3, 3)).
        gt_rotation: optional ground-truth rotation, (3, 3) or (1, 3, 3).
            It is drawn as a ring marker coloured by its gamma.
        fig: figure to draw into. It is ignored when ``ax`` is None, because a
            new figure is created then. When ``ax`` is given and ``fig`` is
            None, ``ax.figure`` is used.
        ax: Mollweide axes to draw on. When None, a new figure and axes are created.
        display_threshold_probability: rotations with probability at or below
            this value are not drawn.
        prob_threshold: unused; kept for interface compatibility.
        show_color_wheel: add a small polar legend mapping gamma ("tilt") to hue.
        canonical_rotation: matrix right-multiplied onto every rotation (and
            onto ``gt_rotation``) before plotting. Defaults to the identity.

    Returns:
        The rendered figure as an RGB array, produced by ``plot_to_image``.
        The figure is closed afterwards.
    '''
    cmap = plt.cm.hsv

    def _show_single_marker(ax, rotation, marker):
        alpha, beta, gamma = o3.matrix_to_angles(rotation)
        # there is some issue with the first value being all zeros, so repeat it and take second value
        color = cmap(0.5 + gamma.repeat(2) / 2. / np.pi)[-1]
        latitude = beta - np.pi / 2
        ax.scatter(alpha, latitude, s=2000, edgecolors=color, facecolors='none', marker=marker, linewidth=5)
        # Thin black rings just inside and outside the coloured ring make it stand out.
        ax.scatter(alpha, latitude, s=1500, edgecolors='k', facecolors='none', marker=marker, linewidth=2)
        ax.scatter(alpha, latitude, s=2500, edgecolors='k', facecolors='none', marker=marker, linewidth=2)

    if ax is None:
        fig = plt.figure(figsize=(8, 4), dpi=400)
        fig.subplots_adjust(0.01, 0.08, 0.90, 0.95)
        ax = fig.add_subplot(111, projection='mollweide')
    elif fig is None:
        # The colour wheel and the final rasterisation both need the figure.
        fig = ax.figure

    # Avoid a shared tensor default argument, and match the dtype/device of rots.
    if canonical_rotation is None:
        canonical_rotation = torch.eye(3, dtype=rots.dtype, device=rots.device)

    rots = rots @ canonical_rotation
    scatterpoint_scaling = 3e3
    alpha, beta, gamma = o3.matrix_to_angles(rots)

    # Offset alpha and beta by a small gamma-dependent amount, so that
    # rotations sharing (alpha, beta) but differing in gamma stay visible.
    offset_radius = 0.02
    alpha = alpha + offset_radius * torch.cos(gamma)
    beta = beta + offset_radius * torch.sin(gamma)

    which_to_display = probs > display_threshold_probability

    # Display the distribution
    ax.scatter(alpha[which_to_display],
               beta[which_to_display] - np.pi / 2,
               s=scatterpoint_scaling * probs[which_to_display],
               c=cmap(0.5 + gamma[which_to_display] / 2. / np.pi))
    if gt_rotation is not None:
        if gt_rotation.dim() == 2:
            gt_rotation = gt_rotation.unsqueeze(0)
        gt_rotation = gt_rotation @ canonical_rotation.to(gt_rotation)
        _show_single_marker(ax, gt_rotation, 'o')
    ax.grid()
    ax.set_xticklabels([])
    ax.set_yticklabels([])

    if show_color_wheel:
        # Add a color wheel showing the tilt angle to color conversion.
        wheel_ax = fig.add_axes([0.86, 0.17, 0.12, 0.12], projection='polar')
        theta = np.linspace(-3 * np.pi / 2, np.pi / 2, 200)
        radii = np.linspace(0.4, 0.5, 2)
        _, theta_grid = np.meshgrid(radii, theta)
        colormap_val = 0.5 + theta_grid / np.pi / 2.
        wheel_ax.pcolormesh(theta, radii, colormap_val.T, cmap=cmap)
        wheel_ax.set_yticklabels([])
        wheel_ax.set_xticklabels([r'90$\degree$', None,
                                  r'180$\degree$', None,
                                  r'270$\degree$', None,
                                  r'0$\degree$'], fontsize=14)
        wheel_ax.spines['polar'].set_visible(False)
        # Draw on the wheel axes directly. plt.text would target pyplot's
        # *current* axes, which need not belong to ``fig``.
        wheel_ax.text(0.5, 0.5, 'Tilt', fontsize=14,
                      horizontalalignment='center',
                      verticalalignment='center', transform=wheel_ax.transAxes)

    img = plot_to_image(fig)
    plt.close(fig)
    return img

def plot_predictions(images, probs, rots, gt_rots, num=4, path=None):
    '''Plot input images next to their predicted SO(3) distributions.

    Each figure row holds two (image, distribution) pairs. The distribution
    panels are rendered off-screen by ``plot_so3_distribution`` and pasted in
    as images. Only the first one carries the tilt colour wheel.

    Args:
        images: batch of channels-first images (presumably (B, C, H, W)).
        probs: per-example probabilities over ``rots`` (presumably (B, N)).
        rots: rotation matrices shared by all examples (presumably (N, 3, 3)).
        gt_rots: ground-truth rotation per example.
        num: number of examples to plot. It is clamped to the batch size.
        path: if None, the figure is shown interactively; otherwise it is
            saved to this path.
    '''
    # Plotting converts through numpy, which needs CPU tensors without grad.
    images = images.detach().cpu()
    probs = probs.detach().cpu()
    rots = rots.detach().cpu()
    gt_rots = gt_rots.detach().cpu()

    # Never index past the end of the batch.
    num = min(num, len(images))
    if num <= 0:
        return
    n_rows = int(np.ceil(num / 2))

    fig = plt.figure(figsize=(4.8, n_rows), dpi=300)
    # Four columns per row: image, distribution, image, distribution.
    gs = GridSpec(n_rows, 4, width_ratios=[1, 3, 1, 3], wspace=0, left=0, top=1, bottom=0, right=1)

    for i in range(num):
        image_ax = fig.add_subplot(gs[2 * i])
        plot_image(fig, image_ax, images[i])

        dist_ax = fig.add_subplot(gs[2 * i + 1])
        dist_img = plot_so3_distribution(probs[i], rots, gt_rotation=gt_rots[i],
                                         show_color_wheel=i == 0)
        dist_ax.imshow(dist_img)
        dist_ax.axis('off')

    # Use ``fig`` explicitly: the helpers above create and close their own
    # figures, so pyplot's "current figure" is not a reliable target.
    if path is None:
        plt.show()
    else:
        fig.savefig(path)

    plt.close(fig)


def plot_to_image(fig):
    '''Rasterise a Matplotlib figure into an RGB array and close the figure.

    Args:
        fig: figure whose canvas provides ``buffer_rgba()``; Agg-based
            canvases (the default non-interactive backend) do.

    Returns:
        A uint8 array of shape (height, width, 3).
    '''
    fig.canvas.draw()
    # buffer_rgba() replaces tostring_rgb(), which was deprecated in
    # Matplotlib 3.8 and removed in 3.10. The buffer also carries its real
    # pixel shape, so no reshape via get_width_height() is needed (that
    # reshape is wrong on HiDPI canvases).
    rgba = np.asarray(fig.canvas.buffer_rgba())
    # Drop alpha; slicing is non-contiguous, so this makes an owned copy that
    # stays valid after the figure is closed.
    rgb_array = np.ascontiguousarray(rgba[..., :3])
    plt.close(fig)
    return rgb_array

def rotate_s2(s2_signal, alpha=0, beta=0, gamma=0):
    '''Rotate a spherical signal given by its spherical-harmonic coefficients.

    The last dimension of ``s2_signal`` is assumed to hold the coefficients
    for degrees 0..lmax, i.e. (lmax + 1) ** 2 entries; lmax is inferred from it.

    Args:
        s2_signal: coefficient tensor; leading dimensions are treated as batch.
        alpha, beta, gamma: rotation angles in radians (Python numbers or
            tensors), passed to ``irreps.D_from_angles``.

    Returns:
        ``D @ coefficients`` applied along the last dimension, in the dtype and
        device of ``s2_signal``.
    '''
    import math
    from spherical_conv.predictor import s2_irreps

    # Exact integer square root; avoids float rounding in n ** 0.5.
    lmax = math.isqrt(s2_signal.shape[-1]) - 1
    irreps = s2_irreps(lmax)
    # as_tensor also accepts tensors, without the copy-construct warning
    # that torch.tensor emits for them.
    alpha = torch.as_tensor(alpha, dtype=torch.float32)
    beta = torch.as_tensor(beta, dtype=torch.float32)
    gamma = torch.as_tensor(gamma, dtype=torch.float32)
    wigner_d = irreps.D_from_angles(alpha, beta, gamma)
    # Match the signal so float64 or GPU signals do not hit a dtype/device mismatch.
    wigner_d = wigner_d.to(dtype=s2_signal.dtype, device=s2_signal.device)
    return torch.einsum("ij,...j->...i", wigner_d, s2_signal)

def show_projection(fmap_size=8):
    '''Visualise how a square feature map is projected onto the sphere.

    Demo/debug helper. It builds a colour-coded ``fmap_size x fmap_size`` test
    image and samples it (through Gaussian windows) at the orthographic
    projection of the kernel grid returned by ``s2_healpix_grid``. The samples
    are then converted to spherical-harmonic coefficients (lmax = 6) and
    rotated by beta = -pi/2.

    Three panels are shown side by side:

    1. the test image;
    2. the sample locations drawn over a blank white image;
    3. the resulting spherical signal in a Mollweide projection, with each
       RGB channel rendered separately and re-stacked.

    Args:
        fmap_size: side length, in pixels, of the square test image.
    '''
    import colorsys
    from e3nn import o3
    from spherical_conv.predictor import s2_healpix_grid
    from matplotlib.gridspec import GridSpec

    # Test image: hue encodes the angle around (roughly) the image centre.
    # Saturation encodes the distance from the centre and saturates at 60% of
    # the maximum distance.
    x,y = np.meshgrid(np.arange(fmap_size), np.arange(fmap_size))
    x = x.flatten()
    y = y.flatten()
    h = np.arctan2(y-fmap_size/2,x-fmap_size/2)/(2*np.pi) + 0.5
    s = np.linalg.norm(np.stack((x-fmap_size/2-0.5,y-fmap_size/2-0.5), axis=0), axis=0)
    s = np.clip(s/(0.6*s.max()),0,1)

    img = np.zeros((fmap_size, fmap_size,3), dtype=float)
    for i in range(fmap_size):
        for j in range(fmap_size):
            # NOTE(review): h/s were flattened row-major from meshgrid (x is the
            # column index), so h[i + j*fmap_size] is the value at x=i, y=j.
            # img is therefore filled transposed. Presumably intentional; it is
            # displayed below with a horizontal flip.
            img[i,j] = colorsys.hsv_to_rgb(h[i+j*fmap_size], s[i+j*fmap_size], 1)

    # Kernel grid presumably limited to beta <= max_beta (90 deg). Samples with
    # beta between taper_beta and max_beta are faded out by the mask below.
    max_beta = np.radians(90)
    taper_beta = np.radians(75)
    kernel_grid = s2_healpix_grid(max_beta=max_beta, rec_level=2)
    xyz = o3.angles_to_xyz(*kernel_grid)

    # Orthographic projection of the grid points onto the image plane, using
    # their x/z coordinates. Scaling puts the outermost point at radius `coverage`.
    coverage = 0.9
    max_radius = torch.linalg.norm(xyz[:,[0,2]], dim=1).max()
    sample_x = coverage * xyz[:,2] / max_radius
    sample_y = coverage * xyz[:,0] / max_radius

    f = plt.figure(figsize=(6,4), dpi=250)
    gs = GridSpec(1,3, width_ratios=[1,1,2], wspace=0.1)
    # Panel 1: the test image.
    ax = f.add_subplot(gs[0])
    ax.imshow(img[:,::-1])
    ax.set_xticks([],[])
    ax.set_yticks([],[])

    # Each sample reads the image through a Gaussian window with std sigma,
    # measured in normalised [-1, 1] image coordinates. Weights are normalised
    # to sum to 1 per sample. data has shape (fmap_size, fmap_size, n_samples).
    sigma=0.2
    gridx, gridy = torch.meshgrid(2*[torch.linspace(-1,1,fmap_size)], indexing='ij')
    scale = 1 / np.sqrt(2 * np.pi * sigma**2)
    data = scale * torch.exp(-((gridx.unsqueeze(-1) - sample_x).pow(2) \
                               +(gridy.unsqueeze(-1) - sample_y).pow(2)) / (2*sigma**2) )
    data = data / data.sum((0,1), keepdims=True)
    betas = kernel_grid[1]
    # Linear taper: weight 1 for beta <= taper_beta, falling to 0 at max_beta.
    if taper_beta < max_beta:
        mask = ((betas - max_beta)/(taper_beta - max_beta)).clamp(max=1).view(1,1,-1)
    else:
        mask = torch.ones_like(data)

    # Weighted sum over pixels gives per-sample RGB values of shape (3, n_samples).
    data = (mask * data).unsqueeze(0).unsqueeze(0).to(torch.float32)
    x = torch.from_numpy(img).permute(2,0,1)
    x = (x.unsqueeze(-1) * data).sum((2,3)).squeeze(0)

    # Project the per-sample values onto spherical harmonics of degree 0..lmax
    # (a sum over samples scaled by 1/sqrt(n_samples)). Then rotate by
    # beta = -pi/2 and evaluate on a (100, 101) grid. After the permute, signal
    # is presumably (100, 101, 3).
    lmax=6
    Y = o3.spherical_harmonics_alpha_beta(range(lmax+1), *kernel_grid, normalization='component')
    harmonics = torch.einsum('ni,yn->yi', Y, x.float()) / Y.shape[0]**0.5
    harmonics = rotate_s2(harmonics, 0, -np.pi/2)
    signal = o3.ToS2Grid(lmax, res=(100,101))(harmonics).permute(1,2,0)


    # Panel 2: sample locations, mapped from [-1, 1] to pixel coordinates,
    # drawn over a white image.
    ax = f.add_subplot(gs[1])
    ax.imshow(1+0*img)
    ax.plot(0.5*fmap_size*(sample_y+1)-0.5, 0.5*fmap_size*(sample_x+1)-0.5, 'k.', markersize=3)
    ax.set_xticks([],[])
    ax.set_yticks([],[])

    # Panel 3: render each channel as a grey Mollweide map on a throw-away
    # figure. Rasterise it, average RGB to a grey level in [0, 1], and stack
    # the three channels back into one RGB image.
    lon, lat = np.meshgrid(np.linspace(-np.pi, np.pi, 101), np.linspace(-np.pi/2, np.pi/2,100))
    imgs = []
    for c in range(3):
        tmp_fig = plt.figure(figsize=(2,1), dpi=400)
        ax = tmp_fig.add_subplot(111, projection='mollweide')
        # NOTE(review): no vmin/vmax is passed, so pcolormesh autoscales the
        # colour limits. This positive affine rescale then appears to have no
        # visible effect.
        ax.pcolormesh(lon, lat, 0.2*(signal[...,c].numpy()-0.6), cmap='gray')
        ax.set_xticks([],[])
        ax.set_yticks([],[])
        plt.tight_layout(pad=0.1)
        tmp_fig.canvas.draw()
        imgs.append(np.array(tmp_fig.canvas.renderer.buffer_rgba())[...,:3].mean(2)/255)
        plt.close(tmp_fig)
    imgs = np.stack(imgs, axis=2)
    # Debug output: shape of the stacked RGB rendering.
    print(imgs.shape)

    ax = f.add_subplot(gs[2])
    ax.imshow(imgs)
    ax.axis('off')
    plt.show()



if __name__ == "__main__":
    # Running this module directly shows the feature-map-to-sphere projection demo.
    show_projection()
