import matplotlib.pyplot as plt
import numpy as np
import torch

from torch.nn import functional as F


def make_pyramid_new(img, max_scales=3):
    """Build an average-pooling image pyramid.

    Each new level halves the spatial size of the previous one by averaging
    non-overlapping 2x2 blocks (a trailing odd row/column is dropped).

    Args:
        img: input tensor of shape [B, ch, H, W].
        max_scales: total number of levels returned, including ``img``.

    Returns:
        list of tensors, finest first; ``imgs[0]`` is ``img`` itself.
    """
    imgs = [img]
    for _ in range(max_scales - 1):
        # avg_pool2d follows the input's dtype/device, unlike a hand-built
        # float32 conv kernel (which fails for float64/half inputs).
        imgs.append(F.avg_pool2d(imgs[-1], kernel_size=2, stride=2))
    return imgs


def make_pyramid_cf(img, max_scales=3, heat_iters=1, tau=0.5, laplacian=True):
    """Build a heat-equation pyramid and, optionally, its Laplacian bands.

    Every coarser level is ``downsample(finer, iters=heat_iters, tau=tau)``.

    Args:
        img: input tensor of shape [B, ch, H, W].
        max_scales: total number of pyramid levels, including ``img``.
        heat_iters, tau: forwarded to ``downsample``.
        laplacian: when True, also return the detail bands.

    Returns:
        ``imgs`` (finest first) when ``laplacian`` is False, otherwise
        ``(imgs, laps)`` where ``laps[i] = imgs[i] - upsample(imgs[i + 1])``
        (one entry fewer than ``imgs``).
    """
    imgs = [img]
    laps = []
    for _ in range(max_scales - 1):
        finer = imgs[-1]
        coarser = downsample(finer, iters=heat_iters, tau=tau)
        if laplacian:
            laps.append(finer - upsample(coarser))
        imgs.append(coarser)
    return (imgs, laps) if laplacian else imgs


def make_pyramid_vv(img, max_scales=3, kernel=None, gauss_size=3, gauss_sigma=1.0, tau=0.5, laplacian=True):
    """Build a Gaussian pyramid and, optionally, its Laplacian pyramid.

    For ``max_scales == 3``::

        imgs = [img,                 ds(img),                     ds(ds(img))]
        laps = [img - up(ds(img)),   ds(img) - up(ds(ds(img))),   None]

    where ``ds`` is ``lap_downsample`` and ``up`` is ``lap_upsample``.

    Args:
        img: input tensor of shape [B, ch, H, W].
        max_scales: total number of levels, including ``img``.
        kernel: blur kernel; built with ``gauss_kernel`` when None.
        gauss_size, gauss_sigma: passed to ``gauss_kernel`` when ``kernel`` is None.
        tau: unused here; kept for signature compatibility with ``make_pyramid_cf``.
        laplacian: when True, also return the Laplacian bands.

    Returns:
        ``imgs`` when ``laplacian`` is False, otherwise ``(imgs, laps)``
        where ``laps`` has the same length as ``imgs`` and ends with ``None``.
    """
    if kernel is None:
        kernel = gauss_kernel(img.shape[1], img.device, gauss_size, gauss_sigma)

    imgs = [img]
    laps = []
    for _ in range(max_scales - 1):
        finer = imgs[-1]
        coarser = lap_downsample(finer, kernel)
        if laplacian:
            laps.append(finer - lap_upsample(coarser, kernel))
        imgs.append(coarser)

    if not laplacian:
        return imgs
    # The coarsest level has no finer-minus-coarser residual.
    laps.append(None)
    return imgs, laps


def visualize_scales(imgs, mean=0, std=1, max_n=10):
    """Show every pyramid level side by side, one subplot per level.

    Within a level, up to ``max_n`` samples are un-normalised
    (``x * std + mean``) and stacked vertically.

    Args:
        imgs: sequence of tensors [B, ch, H, W], one per level; ``None``
            entries are skipped (their subplot slot stays empty).
        mean, std: de-normalisation applied before plotting.
        max_n: maximum number of samples drawn per level.
    """
    levels = len(imgs)
    for level, imgs_l in enumerate(imgs):
        if imgs_l is None:
            continue
        plt.subplot(1, levels, level + 1)
        _, ch, _, width = imgs_l.shape
        # [n, ch, H, W] -> [n, H, W, ch] -> [n*H, W, ch]: samples stacked vertically.
        # Reshaping by the width (not the height) keeps non-square images intact.
        tile = (imgs_l.detach() * std + mean)[:max_n].permute(0, 2, 3, 1).reshape(-1, width, ch)
        # matplotlib needs host memory and no autograd history.
        tile = tile.cpu().numpy()
        if ch == 1:
            plt.imshow(tile[..., 0], cmap='gray')
        else:
            plt.imshow(tile)
    plt.show()


def vis_imgs_laps(imgs, laps, max_n=10, save_path=None, vmin=-1, vmax=1):
    """Plot each pyramid level next to its Laplacian band, one subplot per level.

    Args:
        imgs: sequence of tensors [B, ch, H, W], one per level.
        laps: Laplacian bands, either a mapping ``{level: tensor}`` or a
            sequence aligned with ``imgs`` (e.g. the list returned by
            ``make_pyramid_vv``). Missing levels and ``None`` entries are
            drawn without a Laplacian; ``laps=None`` disables them entirely.
        max_n: maximum number of samples per level, stacked vertically.
        save_path: if given, save the figure there instead of showing it.
        vmin, vmax: colour limits for ``imshow``; also the values of the
            separator stripes between image and Laplacian.
    """
    from collections.abc import Mapping

    plt.figure(figsize=(10, 10))
    levels = len(imgs)
    for level, imgs_l in enumerate(imgs):
        # Accept both dict-by-level and list-aligned Laplacians.
        if laps is None:
            laps_l = None
        elif isinstance(laps, Mapping):
            laps_l = laps.get(level)
        else:
            laps_l = laps[level] if level < len(laps) else None

        plt.subplot(1, levels, level + 1)
        _, ch, _, width = imgs_l.shape
        # [n, ch, H, W] -> [n*H, W, ch]; reshape by width so non-square images survive.
        tile = imgs_l.detach()[:max_n].permute(0, 2, 3, 1).reshape(-1, width, ch)
        if laps_l is not None:
            lap_tile = laps_l.detach()[:max_n].permute(0, 2, 3, 1).reshape(-1, laps_l.shape[-1], ch)
            rows = tile.shape[0]
            # 3-pixel separator (vmin | vmax | vmin) created on the images'
            # own device/dtype so torch.cat works for GPU tensors too.
            stripes = [tile.new_full((rows, 1, ch), value) for value in (vmin, vmax, vmin)]
            tile = torch.cat([tile, *stripes, lap_tile], dim=1)
        tile = tile.cpu().numpy()
        if ch == 1:
            plt.imshow(tile[..., 0], cmap='gray', vmin=vmin, vmax=vmax)
        else:
            plt.imshow(tile, vmin=vmin, vmax=vmax)
    # Colour bar refers to the last subplot; all share vmin/vmax anyway.
    plt.colorbar()
    if save_path is None:
        plt.show()
    else:
        plt.savefig(save_path, bbox_inches='tight', pad_inches=0.1)
        plt.clf()
        plt.close()


def heat(X, time=0, iters=0, tau=0.5):
    """Smooth X by explicitly time-stepping the 2D heat equation.

    Each step convolves every channel with the 5-point stencil
    ``[[0, tau/4, 0], [tau/4, 1 - tau, tau/4], [0, tau/4, 0]]``
    (identity plus ``tau/4`` times the discrete Laplacian) after
    reflect-padding by one pixel. The grid spacing is ``dx = 1 / min(H, W)``
    and each step advances time by ``dt = tau * dx**2``.

    Exactly one of ``time`` and ``iters`` must be positive.

    Args:
        X: tensor of shape [B, C, H, W]; the stencil is built in ``X.dtype``
            on ``X.device``.
        time: total diffusion time; converted to ``ceil(time / dt)`` steps.
        iters: number of explicit steps to take.
        tau: step size relative to ``dx**2``; must satisfy ``0 < tau <= 1``
            so the centre weight ``1 - tau`` stays non-negative.

    Returns:
        ``(Xn, T)``: the smoothed tensor and the elapsed time ``iters * dt``.

    Raises:
        AssertionError: on invalid ``time``/``iters``/``tau`` combinations.
    """
    assert time > 0 or iters > 0, 'one of time or iters must be nonzero'
    assert not (time > 0 and iters > 0), 'only one of time or iters may be nonzero'
    assert 0 < tau <= 1.

    _, c, height, width = X.shape
    n = min(height, width)
    dx = 1 / n
    dt = tau * dx ** 2

    if time > 0:
        iters = int(np.ceil(time / dt))

    edge = tau / 4
    # Match the input dtype so float64/half tensors work with conv2d.
    stencil = torch.tensor([[0, edge, 0],
                            [edge, 1 - tau, edge],
                            [0, edge, 0]], dtype=X.dtype, device=X.device)
    stencil = stencil.reshape(1, 1, 3, 3).expand(c, 1, 3, 3)

    Xn = X.clone()
    for _ in range(iters):
        Xn = F.conv2d(F.pad(Xn, (1, 1, 1, 1), mode='reflect'), stencil, groups=c)

    # Computed directly rather than by repeated `+= dt`, which drifts.
    return Xn, iters * dt


def pick(X):
    """Subsample X by 2 by keeping the top-left pixel of each 2x2 block.

    Implemented as a stride-2 depthwise convolution with a one-hot 2x2
    kernel, so the output has ``floor(H/2) x floor(W/2)`` pixels (a trailing
    odd row/column is dropped).

    Args:
        X: tensor of shape [B, C, H, W]; the kernel is built in ``X.dtype``
            on ``X.device``.

    Returns:
        tensor of shape [B, C, H // 2, W // 2].
    """
    c = X.size(1)
    # One-hot selector at position (0, 0) of every 2x2 window, per channel.
    selector = torch.zeros(c, 1, 2, 2, dtype=X.dtype, device=X.device)
    selector[:, :, 0, 0] = 1
    return F.conv2d(X, selector, stride=2, groups=c)


def place(X):
    """Upsample X by 2 with nearest-neighbour replication.

    Each input pixel becomes a 2x2 block of the same value. This is done
    with a stride-2 depthwise transposed convolution whose kernel is all
    ones.

    Args:
        X: tensor of shape [B, C, H, W]; the kernel is built in ``X.dtype``
            on ``X.device``.

    Returns:
        tensor of shape [B, C, 2H, 2W].
    """
    c = X.size(1)
    # conv_transpose2d weight layout with groups=c: [in=c, out/groups=1, kH, kW].
    ones = torch.ones(c, 1, 2, 2, dtype=X.dtype, device=X.device)
    return F.conv_transpose2d(X, ones, stride=2, groups=c)


def downsample(X, iters=1, tau=0.5):
    """Heat-smooth X for ``iters`` steps, then keep every other pixel.

    Args:
        X: tensor of shape [B, C, H, W].
        iters, tau: forwarded to ``heat``.

    Returns:
        the smoothed tensor subsampled by ``pick``.
    """
    smoothed, _elapsed = heat(X, iters=iters, tau=tau)
    return pick(smoothed)


# Counterpart of `downsample`: nearest-neighbour 2x upsampling (each pixel
# becomes a 2x2 block). Kept as a plain alias, so `upsample is place`.
upsample = place


# https://gist.github.com/alper111/b9c6d80e2dba1ee0bfac15eb7dad09c8
# https://github.com/mtyka/laploss/blob/master/laploss.py
def gauss_kernel(channels, device=torch.device('cpu'), size=5, sigma=1.0):
    """Build a normalised 2D Gaussian kernel for depthwise convolution.

    The kernel is the outer product of a 1D Gaussian
    ``exp(-d**2 / (2 * sigma**2))``, where ``d`` is the offset from the
    centre index ``size // 2``. It is normalised to sum to 1.

    Note: the previous implementation squared the 1D exponential (which
    shrinks sigma by sqrt(2)) and *added* the x and y profiles instead of
    multiplying them, producing a plus-shaped rather than Gaussian kernel.

    Args:
        channels: number of channels; the kernel is repeated once per channel.
        device: device of the returned tensor.
        size: spatial size of the (square) kernel.
        sigma: standard deviation in pixels.

    Returns:
        float32 tensor of shape [channels, 1, size, size].
    """
    offsets = np.arange(size, dtype=np.float32) - size // 2
    profile = np.exp(-offsets ** 2 / (2 * sigma ** 2))
    # Separable 2D Gaussian: G(x, y) = g(x) * g(y).
    kernel = np.outer(profile, profile).astype(np.float32)
    kernel /= kernel.sum()
    return torch.from_numpy(kernel).repeat(channels, 1, 1, 1).to(device)


def conv_gauss(img, kernel, pad_size=2):
    """Depthwise-convolve img with kernel after reflect-padding every side.

    Args:
        img: tensor of shape [B, C, H, W].
        kernel: depthwise kernel of shape [C, 1, k, k].
        pad_size: reflect padding applied to each border; ``k // 2`` keeps
            the spatial size unchanged for odd ``k``.

    Returns:
        the filtered tensor.
    """
    padded = F.pad(img, (pad_size,) * 4, mode='reflect')
    channels = padded.shape[1]
    return F.conv2d(padded, kernel, groups=channels)


def lap_upsample(x, kernel=None):
    """Upsample x by 2 (nearest neighbour) and then blur it.

    Args:
        x: tensor of shape [B, ch, H, W]; H and W need not be equal.
        kernel: depthwise blur kernel; a default ``gauss_kernel`` for
            ``ch`` channels is built when None.

    Returns:
        tensor of shape [B, ch, 2H, 2W] (for an odd-sized kernel).
    """
    if kernel is None:
        kernel = gauss_kernel(channels=x.shape[1], device=x.device)
    # Duplicate every row and every column:
    # [[1, 2],  -> [[1, 1, 2, 2],
    #  [3, 4]]      [1, 1, 2, 2],
    #               [3, 3, 4, 4],
    #               [3, 3, 4, 4]]
    # repeat_interleave works for non-square inputs, unlike a reshape chain
    # built around a single `size`.
    x_up = x.repeat_interleave(2, dim=2).repeat_interleave(2, dim=3)
    # Blur; padding k // 2 keeps the upsampled size.
    return conv_gauss(x_up, kernel, kernel.shape[-1] // 2)


def lap_downsample(x, kernel=None):
    """Blur x, then keep every other row and column.

    Args:
        x: tensor of shape [B, ch, H, W].
        kernel: depthwise blur kernel; a default ``gauss_kernel`` for
            ``ch`` channels is built when None.

    Returns:
        the blurred tensor subsampled by 2 along both spatial axes.
    """
    kernel = gauss_kernel(channels=x.shape[1], device=x.device) if kernel is None else kernel
    blurred = conv_gauss(x, kernel, kernel.shape[-1] // 2)
    return blurred[..., ::2, ::2]


def test_updown_sample_on_noise(noise='randn', up=True, kernel_size=3, dim=64, plot_noise=True):
    """Histogram random noise before and after lap_upsample/lap_downsample.

    Draws 100 three-channel ``dim x dim`` noise images and resamples them
    with a ``kernel_size`` Gaussian kernel. It then overlays density
    histograms of the resampled values and, optionally, of the raw noise.
    The figure is not shown; call ``plt.show()`` afterwards.

    Args:
        noise: 'randn' (standard normal) or 'rand' (uniform on [0, 1)).
        up: upsample when True, downsample otherwise.
        kernel_size: size passed to ``gauss_kernel``.
        dim: spatial size of the noise images.
        plot_noise: also plot the histogram of the raw noise.

    Raises:
        ValueError: if ``noise`` is not 'randn' or 'rand'.
    """
    n = 100
    samplers = {'randn': torch.randn, 'rand': torch.rand}
    if noise not in samplers:
        # Previously fell through and failed with a NameError on x1.
        raise ValueError(f"noise must be 'randn' or 'rand', got {noise!r}")
    x1 = samplers[noise](n, 3, dim, dim)

    k = gauss_kernel(3, size=kernel_size)
    x11 = lap_upsample(x1, k) if up else lap_downsample(x1, k)

    if plot_noise:
        plt.hist(x1.reshape(-1), bins=100, alpha=0.5, density=True, label="noise")
    label = f"upsampled k={kernel_size}" if up else f"downsampled k={kernel_size}"
    plt.hist(x11.reshape(-1), bins=100, alpha=0.5, density=True, label=label)
    plt.legend()


def make_pyramid_from_image(image_path, save_dir=None, max_scales=3, heat_iters=1, tau=0.5):
    """Build a heat pyramid for an image file and save each level as PNG.

    Levels are written as ``img_{i}.png``. Laplacian bands are written as
    ``lap_{i}.png`` after min-max normalising all bands jointly to [0, 1].

    Args:
        image_path: path of the image to load.
        save_dir: output directory; defaults to the image's directory.
        max_scales, heat_iters, tau: forwarded to ``make_pyramid_cf``.

    Returns:
        ``(imgs, laps)``: lists of numpy arrays. ``imgs`` holds uint8 HxWxC
        images (HxW for single-channel inputs). ``laps`` holds the
        un-normalised float Laplacian bands.

    Note:
        Pixel values are divided by 255, so 8-bit images are presumably
        expected. Odd spatial sizes at any level likely break the Laplacian
        subtraction in ``make_pyramid_cf``; H and W should probably be
        divisible by ``2 ** (max_scales - 1)``.
    """
    import os
    try:
        import imageio.v2 as imageio  # imageio >= 2.16 keeps the classic imread here
    except ImportError:
        import imageio

    if save_dir is None:
        save_dir = os.path.dirname(image_path)

    arr = np.asarray(imageio.imread(image_path))
    if arr.ndim == 2:
        # Grayscale: add a channel axis so the HWC -> CHW permute works.
        arr = arr[..., None]
    img = torch.from_numpy(arr).float().div(255).permute(2, 0, 1).unsqueeze(0)
    pyr, lap_tensors = make_pyramid_cf(img, max_scales=max_scales, heat_iters=heat_iters, tau=tau, laplacian=True)

    def _to_hwc(t):
        # [1, C, H, W] tensor -> HxWxC numpy array (HxW when C == 1, which imshow needs).
        out = t.detach().permute(0, 2, 3, 1)[0].cpu().numpy()
        return out[..., 0] if out.shape[-1] == 1 else out

    def _save(array, name):
        # Render one array with pyplot and write it to save_dir/name without axes.
        plt.imshow(array, **({'cmap': 'gray'} if array.ndim == 2 else {}))
        plt.axis('off')
        plt.savefig(os.path.join(save_dir, name), bbox_inches='tight', pad_inches=0)
        plt.clf()
        plt.close()

    imgs = [(np.clip(_to_hwc(level) * 255.0, 0, 255)).astype('uint8') for level in pyr]
    for i, level in enumerate(imgs):
        _save(level, f'img_{i}.png')

    laps = [_to_hwc(lap) for lap in lap_tensors]
    if laps:  # empty when max_scales == 1
        lap_min = min(lap.min() for lap in laps)
        lap_max = max(lap.max() for lap in laps)
        span = lap_max - lap_min
        if span == 0:
            # All bands constant: avoid 0/0 and save them as black.
            span = 1.0
        for i, lap in enumerate(laps):
            _save((lap - lap_min) / span, f'lap_{i}.png')
    return imgs, laps
