import pywt
import numpy as np
import torch

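"""
waveletdec2d input: `_input` is an image array (anything np.asarray accepts,
e.g. a 2-D grayscale image). The 2-D transform runs over `axes` (default: the
last two axes); any other axes are transformed independently.
"""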
def waveletdec2d(_input, wavelet='db1', level=4, axes=(-2, -1)):
    """Run a multilevel 2-D wavelet decomposition and pack it into one array.

    Args:
        _input: image array (anything ``pywt.wavedec2`` accepts); the 2-D
            transform is computed over ``axes``.
        wavelet: wavelet name passed to pywt (default ``'db1'``).
        level: number of decomposition levels (default 4).
        axes: the two axes treated as the image plane.

    Returns:
        The result of ``pywt.coeffs_to_array``: a ``(coeff_arr, coeff_slices)``
        pair holding the packed coefficients and the slices that unpack them.
    """
    coeffs = pywt.wavedec2(_input, wavelet, level=level, axes=axes)
    return pywt.coeffs_to_array(coeffs, axes=axes)

def waveletrec2d(wavelet_mat, coeff_slices, wavelet='db1', axes=(-2, -1)):
    """Reconstruct an image from a packed wavelet coefficient array.

    Inverse of ``waveletdec2d``. The function unpacks ``wavelet_mat`` with
    ``coeff_slices`` into wavedec2-format coefficients, then runs
    ``pywt.waverec2`` over ``axes``.

    Args:
        wavelet_mat: packed coefficient array.
        coeff_slices: slices describing the layout of ``wavelet_mat``, as
            returned by ``pywt.coeffs_to_array``. They may carry extra leading
            slices for batch axes.
        wavelet: wavelet name passed to pywt (default ``'db1'``).
        axes: the two axes treated as the image plane.

    Returns:
        The reconstructed array from ``pywt.waverec2``.
    """
    coeffs = pywt.array_to_coeffs(wavelet_mat, coeff_slices, output_format='wavedec2')
    return pywt.waverec2(coeffs, wavelet, axes=axes)

def get_batched_coeff_slices(single_coeff_slices):
    """Prefix every coefficient slice with a full slice for a leading batch axis.

    ``single_coeff_slices`` is the layout returned by ``pywt.coeffs_to_array``
    for one image. Its first entry is the approximation slice tuple; every
    later entry is a dict of per-detail slice tuples. The result has the same
    structure, with ``slice(None)`` inserted in front of each tuple, so it can
    index a stacked batch of coefficient arrays.
    """
    whole_axis = slice(None, None, None)
    approx_slices = single_coeff_slices[0]
    detail_levels = single_coeff_slices[1:]

    batched_cs = [(whole_axis,) + tuple(approx_slices)]
    batched_cs.extend(
        {name: (whole_axis,) + tuple(slc) for name, slc in level.items()}
        for level in detail_levels
    )
    return batched_cs

class WaveletDec2dTransform(object):
    """Callable transform mapping an image to its packed wavelet coefficients.

    The coefficient layout (``coeff_slices``) is captured from the first image
    passed to :meth:`encode` and reused by :meth:`decode`. Every image pushed
    through one instance must therefore have the same shape.
    """

    def __init__(self, axes=(-2, -1)):
        # Coefficient-array layout; filled in by the first encode() call.
        self.coeff_slices = None
        # The two input axes treated as the image plane.
        self.axes = axes

    def __call__(self, img):
        return self.encode(img)

    def encode(self, img):
        """Return the packed wavelet coefficients of ``img`` as a float ndarray."""
        wavelet_dat = waveletdec2d(img, axes=self.axes)
        # Initialize coeff slices with the first image.
        # Assumes all images are equally sized.
        if self.coeff_slices is None:
            self.coeff_slices = wavelet_dat[1]
        return np.array(wavelet_dat[0], dtype=float)

    def decode(self, encoding):
        """Reconstruct image(s) from coefficient array(s) produced by :meth:`encode`.

        ``encoding`` may be a single coefficient array (same rank as one
        encoded image) or a stack of them with one extra leading batch axis.

        Raises:
            RuntimeError: if called before any image has been encoded.
        """
        if self.coeff_slices is None:
            raise RuntimeError(
                "decode() called before encode(); the coefficient layout is unknown")
        single_ndim = len(self.coeff_slices[0])
        # Express the transform axes relative to the end, so they still point
        # at the image plane after a batch axis is prepended.
        axes = tuple(ax - single_ndim if ax >= 0 else ax for ax in self.axes)
        if np.ndim(encoding) == single_ndim:
            # Unbatched: the recorded layout applies as-is.
            coeff_slices = self.coeff_slices
        else:
            coeff_slices = get_batched_coeff_slices(self.coeff_slices)
        return waveletrec2d(encoding, coeff_slices, axes=axes)

class DFT2dTransform(object):
    """Callable transform mapping an image to its scaled real 2-D FFT.

    ``encode`` returns ``rfft2(img) / scaling_const``, with the real and
    imaginary parts stacked on a new trailing axis (index 0 = real,
    1 = imaginary). ``decode`` inverts this.
    """

    def __init__(self, axes=(-2, -1), scaling_const=1e4):
        """
        Args:
            axes: the two input axes over which the 2-D FFT is taken.
            scaling_const: divisor applied to the spectrum in ``encode`` and
                multiplied back in ``decode`` (default 1e4).
        """
        self.axes = axes
        self.scaling_const = scaling_const
        # Input sizes along ``axes`` for the most recently encoded image. The
        # half-spectrum alone cannot tell an odd last-axis length from the
        # even length just below it, so decode() uses this to pick the size.
        self.spatial_shape = None

    def __call__(self, img):
        return self.encode(img)

    def encode(self, img):
        """Return the scaled rfft2 of ``img`` as a real array with a trailing (real, imag) axis."""
        img = np.asarray(img)
        self.spatial_shape = tuple(img.shape[ax] for ax in self.axes)
        spectrum = np.fft.rfft2(img, axes=self.axes)
        encoding = np.stack([spectrum.real, spectrum.imag], -1) / self.scaling_const
        return encoding

    def decode(self, encoding, s=None):
        """Invert :meth:`encode`.

        Args:
            encoding: array whose last axis holds (real, imag) parts. Leading
                batch axes are allowed.
            s: optional output sizes along ``axes``, forwarded to
                ``np.fft.irfft2``. When omitted, the shape recorded by the last
                ``encode`` call is used if it is consistent with ``encoding``.
                Otherwise irfft2's default (even last-axis length) applies.
        """
        spectrum = encoding[..., 0] * self.scaling_const + 1j * encoding[..., 1] * self.scaling_const
        if s is None:
            s = self._output_shape_for(spectrum)
        img = np.fft.irfft2(spectrum, s=s, axes=self.axes)
        return img

    def _output_shape_for(self, spectrum):
        """Return the recorded spatial shape if ``spectrum`` could have come from it, else None."""
        if self.spatial_shape is None:
            return None
        *leading, last = self.spatial_shape
        # rfft2 keeps full length on the leading axes and n // 2 + 1 on the last one.
        expected = tuple(leading) + (last // 2 + 1,)
        actual = tuple(spectrum.shape[ax] for ax in self.axes)
        return self.spatial_shape if actual == expected else None

def test_wavelet():
    """Visual check: display the wavelet decomposition of pywt's camera test image."""
    camera = pywt.data.camera()
    h, w = camera.shape
    # Give the image a leading singleton axis: (1, H, W).
    batched = camera.reshape(1, h, w)
    coeff_arr = waveletdec2d(batched)[0]
    import matplotlib.pyplot as plt
    plt.imshow(np.squeeze(coeff_arr), cmap=plt.cm.gray)
    plt.show()

def test_wavelet_mnist():
    """Interactive check: step through MNIST wavelet encodings and their reconstructions.

    Downloads MNIST into ``./data`` if it is not already there. For each
    image it shows the coefficient array next to
    ``WaveletDec2dTransform.decode`` of it, one image per second.
    """
    from torchvision import transforms, datasets
    from torch.utils.data import DataLoader
    import matplotlib.pyplot as plt
    import time

    wavelet_transform_obj = WaveletDec2dTransform()
    transform = transforms.Compose([
        wavelet_transform_obj,
        transforms.ToTensor(),
    ])
    train_ds = datasets.MNIST('data', train=True, download=True, transform=transform)
    loader = DataLoader(train_ds, batch_size=64, shuffle=True, drop_last=True)

    plt.ion()
    fig, axes = plt.subplots(1, 2)
    plt.show()

    def show(ax, data, title):
        # Clear before drawing. Without this, every imshow() adds another
        # image to the axes, so memory use and redraw time grow each frame.
        ax.clear()
        ax.imshow(np.squeeze(data), cmap=plt.cm.gray)
        ax.set_title(title)

    for X, _ in loader:
        # Drop singleton axes (presumably the channel axis added by ToTensor).
        wavelet_mat = np.squeeze(X.numpy())
        img_rec = wavelet_transform_obj.decode(wavelet_mat)
        for coeffs, rec in zip(wavelet_mat, img_rec):
            show(axes[0], coeffs, "Decomposition")
            show(axes[1], rec, "Reconstruction")

            fig.canvas.draw()
            fig.canvas.flush_events()
            time.sleep(1)
        
def test_dft():
    """Visual check: plot the real and imaginary rfft2 parts of the camera image and its inverse."""
    camera = pywt.data.camera()
    h, w = camera.shape
    camera = camera.reshape(1, h, w)
    spectrum = np.fft.rfft2(camera)
    print(spectrum)
    reconstruction = np.fft.irfft2(spectrum)
    import matplotlib.pyplot as plt
    panels = (spectrum.real, spectrum.imag, reconstruction)
    # Three side-by-side subplots: real part, imaginary part, reconstruction.
    for position, panel in zip((131, 132, 133), panels):
        plt.subplot(position)
        plt.imshow(np.squeeze(panel), cmap=plt.cm.gray)
    plt.show()

def test_dft_mnist():
    """Interactive check: step through MNIST DFT encodings and their reconstructions.

    Downloads MNIST into ``./data`` if it is not already there.
    - Top row: the batch-mean encoding (real and imaginary channels) and its decode.
    - Bottom row: the individual images of the batch, one per second.
    """
    from torchvision import transforms, datasets
    from torch.utils.data import DataLoader
    import matplotlib.pyplot as plt
    import time

    dft_transform = DFT2dTransform()
    transform = transforms.Compose([
        dft_transform,
        transforms.ToTensor(),
    ])
    train_ds = datasets.MNIST('data', train=True, download=True, transform=transform)
    loader = DataLoader(train_ds, batch_size=64, shuffle=True, drop_last=True)

    plt.ion()
    fig, axes = plt.subplots(2, 3)
    plt.show()

    def show(ax, data, title):
        # Clear first so images don't pile up on the axes between frames.
        ax.clear()
        ax.imshow(np.squeeze(data), cmap=plt.cm.gray)
        ax.set_title(title)

    for X, _ in loader:
        # ToTensor moved the trailing (real, imag) axis to the channel
        # position; permute it back to the end before decoding.
        spectrum = np.squeeze(X.permute(0, 2, 3, 1).numpy())
        spectrum_mean = np.mean(spectrum, 0)
        img_rec = dft_transform.decode(spectrum)
        mean_img_rec = dft_transform.decode(spectrum_mean)

        print(spectrum.shape)

        # Channels 0 / 1 of the encoding are the real / imaginary parts of
        # the spectrum (DFT2dTransform.encode), not amplitude / phase.
        show(axes[0, 0], spectrum_mean[..., 0], "Mean DFT Real")
        show(axes[0, 1], spectrum_mean[..., 1], "Mean DFT Imag")
        show(axes[0, 2], mean_img_rec.real, "Mean Reconstruction")
        for sample_spec, sample_rec in zip(spectrum, img_rec):
            show(axes[1, 0], sample_spec[..., 0], "DFT Real")
            show(axes[1, 1], sample_spec[..., 1], "DFT Imag")
            show(axes[1, 2], sample_rec.real, "Reconstruction")

            fig.canvas.draw()
            fig.canvas.flush_events()
            time.sleep(1)

if __name__ == "__main__":
    import sys

    # Select a demo by name on the command line. With no argument, the
    # DFT/MNIST demo runs, as before.
    demos = {
        "wavelet": test_wavelet,
        "wavelet_mnist": test_wavelet_mnist,
        "dft": test_dft,
        "dft_mnist": test_dft_mnist,
    }
    choice = sys.argv[1] if len(sys.argv) > 1 else "dft_mnist"
    if choice not in demos:
        sys.exit("unknown demo {!r}; choose from: {}".format(choice, ", ".join(demos)))
    demos[choice]()