import pywt
import torch
import torch.nn as nn
import torch.nn.functional as F

from tqdm import tqdm


def get_jacobian(net, x, noutputs):
    """Compute the Jacobian of ``net`` at one input using a single backward pass.

    The sample is squeezed and tiled ``noutputs`` times along a new leading
    axis.  Back-propagating an identity matrix through the batched output then
    leaves, in row ``i`` of the input gradient, the gradient of output ``i``
    of row ``i`` with respect to the tiled input.

    NOTE: this equals the true Jacobian only when ``net`` treats the rows of
    its input independently of each other.

    Args:
        net: Callable mapping the tiled ``(noutputs, N)`` tensor to a tensor
            of shape ``(noutputs, noutputs)``.
        x: Tensor holding a single sample; singleton dimensions are squeezed.
        noutputs: Number of output features of ``net``.

    Returns:
        The gradient accumulated on the tiled input, shape ``(noutputs, N)``.
    """
    # Detach before tiling: if the caller's tensor already requires grad, the
    # tiled copy would be a non-leaf tensor and its ``.grad`` would stay None.
    x = x.detach().squeeze().repeat(noutputs, 1)
    x.requires_grad_(True)
    y = net(x)
    # Seed the backward pass with an identity that matches the output's dtype
    # and device instead of assuming float32.
    y.backward(torch.eye(noutputs, dtype=y.dtype, device=y.device))
    return x.grad.detach()


def avg2d(x):
    """Average each non-overlapping 2x2 block and resize back to the input size.

    Every channel is convolved on its own (``groups=c``) with a constant
    0.25-valued 2x2 kernel at stride 2, and the result is resized to the
    original two trailing dimensions with nearest-neighbour interpolation.

    Args:
        x: 4-D floating-point tensor; dimension 1 is the channel axis.

    Returns:
        Tensor with the same shape as ``x``.
    """
    _, c, w, h = x.shape
    # Build the kernel with the input's dtype/device so non-float32 inputs
    # (e.g. float64) do not hit a dtype mismatch inside conv2d.
    kernel = torch.full((c, 1, 2, 2), 0.25, dtype=x.dtype, device=x.device)
    pooled = F.conv2d(x, kernel, stride=2, groups=c)
    return F.interpolate(pooled, [w, h], mode='nearest')


def new(x):
    """Return ``x`` minus twice its 2x2 block average as computed by ``avg2d``."""
    block_average = avg2d(x)
    return x - 2 * block_average


class Laplacian(nn.Module):
    """Blur / down-sample / up-sample transform on flattened image batches.

    ``forward`` takes a ``(B, c*w*h)`` tensor, reshapes it to ``(B, c, w, h)``
    and, depending on ``mode``, returns one of:

    * ``'mean_sub_add'``: the tuple ``(X - UDBx, Bx, DBx, UDBx)`` where the
      blur pipeline runs on the channel-mean-subtracted input but the
      difference is taken against the original ``X``.
    * ``'mean_sub'``: ``x - UDBx`` with ``x`` the channel-mean-subtracted input.
    * ``'laplacian'``: ``X - UDBx``.
    * ``'up'``: ``UDBx`` only.
    * ``'avg'``: ``avg2d(x)``.
    * ``'new'``: ``new(x)``.

    Here ``Bx`` is the input filtered per channel with the 3x3 stencil
    ``[[0, tau/4, 0], [tau/4, 1-tau, tau/4], [0, tau/4, 0]]`` (reflect
    padding), ``DBx`` is ``Bx`` resized with nearest-neighbour interpolation
    to ``(int(w*tau), int(h*tau))`` and ``UDBx`` is ``DBx`` resized back to
    ``(w, h)`` with bicubic interpolation and clamped to ``[0, 1]``.
    All single-tensor results are flattened back to ``(B, c*w*h)``.

    Args:
        tau: Stencil weight and resize factor, ``0 < tau <= 1``.
        sh: ``(c, w, h)`` shape of one un-flattened sample.
        device: Device the stencil buffer is created on.
        mode: One of the mode names listed above.
    """

    def __init__(self, tau=0.5, sh=(3, 8, 8), device=torch.device('cpu'), mode='mean_sub'):
        super().__init__()
        assert 0 < tau <= 1.0
        assert mode in ['mean_sub_add', 'mean_sub', 'laplacian', 'up', 'avg', 'new']
        self.tau = tau
        # Copy into a fresh list so the instance shares no state with the
        # caller's sequence (or with a default shared across instances).
        self.sh = list(sh)
        self.device = device
        self.mode = mode
        self.c, self.w, self.h = self.sh
        stencil = torch.tensor([[0, tau/4, 0],
                                [tau/4, 1-tau, tau/4],
                                [0, tau/4, 0]], dtype=torch.float32)
        # repeat() (rather than expand()) gives the buffer its own storage;
        # an expanded view has overlapping memory, which in-place copies such
        # as load_state_dict cannot write into.
        L = stencil.repeat(self.c, 1, 1, 1).to(device)
        self.register_buffer('L', L)

    def forward(self, X):
        # X : BxN => B,c,w,h
        assert X.shape[1] == self.c * self.w * self.h
        batch = len(X)
        x = X.reshape([batch, *self.sh])
        if self.mode in ('mean_sub_add', 'mean_sub'):
            # Remove the per-pixel mean across channels.
            x = x - x.mean(dim=1, keepdim=True)
        # These two modes never use the blur pipeline, so skip computing it.
        if self.mode == 'avg':
            return avg2d(x).reshape(batch, -1)
        if self.mode == 'new':
            return new(x).reshape(batch, -1)
        Bx = F.conv2d(F.pad(x, (1, 1, 1, 1), mode='reflect'), self.L, groups=self.c)
        DBx = F.interpolate(Bx, size=[int(self.w*self.tau), int(self.h*self.tau)], mode='nearest')
        UDBx = F.interpolate(DBx, size=[self.w, self.h], mode='bicubic', align_corners=False).clamp(0, 1)
        if self.mode == 'mean_sub_add':
            return X - UDBx.reshape(batch, -1), Bx, DBx, UDBx
        if self.mode == 'mean_sub':
            return (x - UDBx).reshape(batch, -1)
        if self.mode == 'laplacian':
            return X - UDBx.reshape(batch, -1)
        if self.mode == 'up':
            return UDBx.reshape(batch, -1)
        # Only reachable if ``mode`` was changed after construction.
        raise ValueError(f"unknown mode: {self.mode!r}")


# https://github.com/t-vi/pytorch-tvmisc/blob/master/misc/2D-Wavelet-Transform.ipynb
class MyWaveletTx(nn.Module):
    """2-D discrete wavelet transform built from strided convolutions.

    Four separable 2-D analysis filters (and four synthesis filters) are
    formed as products of the 1-D filters provided by ``pywt.Wavelet``.
    ``wt`` applies them per channel with a stride-2 convolution and tiles the
    four resulting sub-bands back into a map of the input's spatial size;
    ``iwt`` undoes the tiling and applies a stride-2 transposed convolution.

    Inputs are zero-padded by 2 for wavelets whose name contains ``'bior'``
    and by 1 for ``'sym2'``; other wavelets are not padded.  The inverse crops
    by the same amounts.

    Args:
        sh: ``(c, h, w)`` shape of one un-flattened sample.
        device: Stored on the instance; the filter buffers are not moved to
            it here.
        levels: Number of decomposition levels (recursion on the first
            sub-band).
        inv: If true, ``forward`` runs the inverse transform.
        wavelet: Wavelet name passed to ``pywt.Wavelet``.
    """

    def __init__(self, sh=(3, 8, 8), device=torch.device('cpu'), levels=1, inv=False, wavelet='bior2.2'):
        super().__init__()
        # Copy so the instance shares no state with the caller's sequence.
        self.sh = list(sh)
        self.device = device
        self.levels = levels
        self.inv = inv
        self.wavelet = wavelet
        w = pywt.Wavelet(wavelet)
        # The decomposition filters are reversed before use; the
        # reconstruction filters are used as given.
        dec_hi = torch.tensor(w.dec_hi[::-1])
        dec_lo = torch.tensor(w.dec_lo[::-1])
        rec_hi = torch.tensor(w.rec_hi)
        rec_lo = torch.tensor(w.rec_lo)
        filters = torch.stack([dec_lo.unsqueeze(0)*dec_lo.unsqueeze(1),
                               dec_lo.unsqueeze(0)*dec_hi.unsqueeze(1),
                               dec_hi.unsqueeze(0)*dec_lo.unsqueeze(1),
                               dec_hi.unsqueeze(0)*dec_hi.unsqueeze(1)], dim=0)
        inv_filters = torch.stack([rec_lo.unsqueeze(0)*rec_lo.unsqueeze(1),
                                   rec_lo.unsqueeze(0)*rec_hi.unsqueeze(1),
                                   rec_hi.unsqueeze(0)*rec_lo.unsqueeze(1),
                                   rec_hi.unsqueeze(0)*rec_hi.unsqueeze(1)], dim=0)
        self.register_buffer('filters', filters)
        self.register_buffer('inv_filters', inv_filters)

    def wt(self, img, levels):
        """Forward transform of a ``(B, C, h, w)`` tensor; output has the same shape."""
        h = img.size(2)
        w = img.size(3)
        if 'bior' in self.wavelet:
            img = F.pad(img, (2, 2, 2, 2))
        elif self.wavelet == 'sym2':
            img = F.pad(img, (1, 1, 1, 1))
        final = []
        for ch in range(img.shape[1]):
            res = F.conv2d(img[:, ch:ch+1], self.filters[:, None], stride=2)
            if levels > 1:
                # Recurse on the first sub-band.  Rebuilding with cat (rather
                # than assigning into ``res`` in place) keeps tensors that
                # autograd saved for the backward pass unmodified.
                coarse = self.wt(res[:, :1], levels=levels-1)
                res = torch.cat([coarse, res[:, 1:]], dim=1)
            # Tile the four half-size sub-bands into one (h, w) map.
            res = res.reshape(-1, 2, h//2, w//2).transpose(1, 2).reshape(-1, 1, h, w)
            final.append(res)
        return torch.cat(final, dim=1)

    def iwt(self, img, levels):
        """Inverse transform of a ``(B, C, h, w)`` coefficient map produced by ``wt``."""
        h = img.size(2)
        w = img.size(3)
        final = []
        for ch in range(img.shape[1]):
            # Undo the tiling done at the end of ``wt``.
            res = img[:, ch:ch+1].reshape(-1, h//2, 2, w//2).transpose(1, 2).reshape(-1, 4, h//2, w//2)
            if levels > 1:
                coarse = self.iwt(res[:, :1], levels=levels-1)
                res = torch.cat([coarse, res[:, 1:]], dim=1)
            res = F.conv_transpose2d(res, self.inv_filters[:, None], stride=2)
            # Crop the border that corresponds to the padding added in ``wt``.
            if 'bior' in self.wavelet:
                res = res[:, :, 2:-2, 2:-2]
            elif self.wavelet == 'sym2':
                res = res[:, :, 1:-1, 1:-1]
            final.append(res)
        return torch.cat(final, dim=1)

    def forward(self, img):
        """Transform a flattened ``(B, N)`` batch and return it flattened again."""
        x = img.reshape([len(img), *self.sh])
        if self.inv:
            return self.iwt(x, self.levels).reshape(len(img), -1)
        return self.wt(x, self.levels).reshape(len(img), -1)


from pytorch_wavelets import DWTForward, DWTInverse
class WaveletTx(nn.Module):
    """Wavelet transform of flattened image batches via ``DWTForward``.

    ``forward`` reshapes a ``(B, N)`` input to ``(B, *sh)``, runs the wrapped
    ``DWTForward`` module, stacks the low-pass output with the first entry of
    the high-pass list along a new axis 2, and flattens the result back to
    ``(B, -1)``.

    NOTE(review): only ``Yh[0]`` is used, so this presumably works for
    ``levels == 1`` only; with more levels the low-pass output and ``Yh[0]``
    likely differ in spatial size and the concatenation would fail. Confirm
    against the pytorch_wavelets documentation.

    Args:
        sh: ``(c, h, w)`` shape of one un-flattened sample.
        device: Stored on the instance; not otherwise used here.
        levels: Passed to ``DWTForward`` as ``J``.
        wavelet: Passed to ``DWTForward`` as ``wave``.
    """

    def __init__(self, sh=(3, 8, 8), device=torch.device('cpu'), levels=1, wavelet='haar'):
        super().__init__()
        # Copy so the instance shares no state with the caller's sequence
        # (the default is an immutable tuple for the same reason).
        self.sh = list(sh)
        self.device = device
        self.levels = levels
        self.wavelet = wavelet
        self.dwt = DWTForward(J=levels, mode='zero', wave=wavelet)

    def wt(self, img):
        """Return the low-pass band and ``Yh[0]`` concatenated along axis 2."""
        Yl, Yh = self.dwt(img)
        # Give the low-pass band its own axis so it lines up with Yh[0].
        return torch.cat([Yl.unsqueeze(2), Yh[0]], dim=2)

    def forward(self, img):
        """Transform a flattened ``(B, N)`` batch and return it flattened again."""
        x = img.reshape([len(img), *self.sh])
        return self.wt(x).reshape(len(img), -1)


# WT: Jacobian determinants of the DWTForward-based transform (WaveletTx)
# on random single-channel 64x64 inputs.
device = torch.device('cpu')
n, c, w, h = 10, 1, 64, 64
levels = 1
wavelet = 'haar'
Wt = WaveletTx(sh=[c, w, h], device=device, levels=levels, wavelet=wavelet)
# n random samples, each flattened to one row of length c*w*h.
X = torch.rand(n, c, w, h).reshape(n, -1)
# One determinant per sample; this list is reassigned by the next section.
dets = []
for x in tqdm(X):
    # Restore the leading batch axis: (c*w*h,) -> (1, c*w*h).
    x = x.unsqueeze(0).to(device)
    # The transform keeps the feature count, so the Jacobian is square.
    dets.append(torch.det(get_jacobian(Wt, x, x.shape[-1])).item())

# MyWT: same experiment as above, using the convolution-based MyWaveletTx
# in forward (inv=False) mode.
device = torch.device('cpu')
n, c, w, h = 10, 1, 64, 64
levels = 1
inv = False
wavelet = 'haar'
Wt = MyWaveletTx(sh=[c, w, h], device=device, levels=levels, inv=inv, wavelet=wavelet)
# n random samples, each flattened to one row of length c*w*h.
X = torch.rand(n, c, w, h).reshape(n, -1)
# One determinant per sample; this list is reassigned by the next section.
dets = []
for x in tqdm(X):
    # Restore the leading batch axis: (c*w*h,) -> (1, c*w*h).
    x = x.unsqueeze(0).to(device)
    dets.append(torch.det(get_jacobian(Wt, x, x.shape[-1])).item())


# Lap NEW: Jacobian determinants of Laplacian in 'new' mode, i.e. the map
# x -> x - 2*avg2d(x), on random single-channel 64x64 inputs.
device = torch.device('cpu')
tau = 0.5
mode = 'new'
n, c, w, h = 10, 1, 64, 64
Lap = Laplacian(tau=tau, sh=[c, w, h], device=device, mode=mode)
# n random samples, each flattened to one row of length c*w*h.
X = torch.rand(n, c, w, h).reshape(n, -1)
# One determinant per sample; this list is reassigned by the next section.
dets = []
for x in tqdm(X):
    # Restore the leading batch axis: (c*w*h,) -> (1, c*w*h).
    x = x.unsqueeze(0).to(device)
    dets.append(torch.det(get_jacobian(Lap, x, x.shape[-1])).item())

# Lap mean_sub_add: Jacobian determinant of Laplacian in 'mean_sub_add' mode.
# Use the GPU when one is present; fall back to the CPU so the script still
# runs on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
tau = 0.5
mode = 'mean_sub_add'
n, c, w, h = 1, 1, 64, 64
Lap = Laplacian(tau=tau, sh=[c, w, h], device=device, mode=mode)
X = torch.rand(n, c, w, h).reshape(n, -1)
dets = []
for x in tqdm(X):
    # Restore the leading batch axis: (c*w*h,) -> (1, c*w*h).
    x = x.unsqueeze(0).to(device)
    # In this mode the module returns (output, Bx, DBx, UDBx); only the first
    # element is the transformed input, so differentiate that one.
    dets.append(torch.det(get_jacobian(lambda t: Lap(t)[0], x, x.shape[-1])).item())


# Print the Jacobian determinant of the Laplacian transform for several modes,
# reusing `tau` and `device` from the section above.
modes = ['mean_sub_add', 'mean_sub', 'laplacian']
n, c, w, h = 1, 1, 64, 64
for mode in modes:
    Lap = Laplacian(tau=tau, sh=[c, w, h], device=device, mode=mode)
    x = torch.rand(n, c, w, h).reshape(n, -1).to(device)
    if mode == 'mean_sub_add':
        # This mode returns (output, Bx, DBx, UDBx); differentiate the output
        # only.  `lap=Lap` binds the current module to the lambda.
        net = lambda t, lap=Lap: lap(t)[0]
    else:
        net = Lap
    print(mode, torch.det(get_jacobian(net, x, x.shape[-1])).item())


# Image: compare the transform with and without channel-mean subtraction on a
# real image, reusing `tau` and `device` from the sections above.
# NOTE(review): `imageio` and `plt` (presumably matplotlib.pyplot) are not
# imported anywhere in the visible file; both must be in scope for this
# section to run.
im = torch.from_numpy(imageio.imread('/home/voletiv/EXPERIMENTS/msflow/retriever256.png')).unsqueeze(0).permute(0, 3, 1, 2).float().div(255)
im64 = F.interpolate(im, size=[64, 64], mode='nearest')
# Take the sample shape from the image itself; the `c = 1` left over from the
# experiments above would not match a multi-channel image.
c, w, h = im64.shape[1:]
# With mean_sub ('mean_sub' mode subtracts the channel mean first).
Lap = Laplacian(tau=tau, sh=[c, w, h], device=device, mode='mean_sub')
lap_ms = Lap(im64.reshape(1, -1).to(device)).cpu().reshape(im64.shape)
# WithOUT mean_sub ('laplacian' mode works on the raw input).
Lap = Laplacian(tau=tau, sh=[c, w, h], device=device, mode='laplacian')
lap_wms = Lap(im64.reshape(1, -1).to(device)).cpu().reshape(im64.shape)

# Show both results side by side; (v + 1) / 2 maps values in [-1, 1] to [0, 1].
plt.subplot(1, 2, 1)
plt.imshow((lap_ms.permute(0, 2, 3, 1)[0].numpy() + 1)/2)
plt.axis('off')
plt.subplot(1, 2, 2)
plt.imshow((lap_wms.permute(0, 2, 3, 1)[0].numpy() + 1)/2)
plt.axis('off')
plt.savefig('/home/voletiv/a.png', bbox_inches='tight', pad_inches=0.0)

print(lap_ms.min(), lap_ms.max())
print(lap_wms.min(), lap_wms.max())


#######################

class MeanSub(nn.Module):
    """Subtract the mean of the whole input tensor from every element.

    Args:
        mode: Optional label stored on the instance; ``forward`` does not
            read it.
    """

    def __init__(self, mode=None):
        super().__init__()
        # Take the label as an argument instead of reading a module-level
        # `mode` variable, which fails when no such global is defined.
        self.mode = mode

    def forward(self, X):
        # X.mean() reduces over every dimension, including the batch axis.
        return X - X.mean()

# Jacobian determinant of plain mean subtraction on one random 64-vector.
MS = MeanSub()
X = torch.rand(1, 64)
print(torch.det(get_jacobian(MS, X, X.shape[-1])).item())
