import torch, pywt
from torch import nn
import numpy as np
from pytorch_wavelets import DWTInverse, DWTForward
from .operations import translate



def idwt2(arg, mode='periodization', wave="db4"):
    """Single-level inverse 2D discrete wavelet transform (thin wrapper around ``pywt.idwt2``).

    :param arg: coefficients as ``(cA, (cH, cV, cD))`` -- approximation then the three detail bands.
    :param mode: pywt signal-extension mode.
    :param wave: pywt wavelet name.
    :return: the reconstructed array returned by ``pywt.idwt2``.
    """
    approx, details = arg
    horizontal, vertical, diagonal = details  # enforce exactly three detail bands
    return pywt.idwt2((approx, (horizontal, vertical, diagonal)), wavelet=wave, mode=mode)


def dwt2(fields, mode='periodization', wave="db4"):
    """Single-level forward 2D discrete wavelet transform (thin wrapper around ``pywt.dwt2``).

    :param fields: input array, transformed with pywt's default axes.
    :param mode: pywt signal-extension mode.
    :param wave: pywt wavelet name.
    :return: ``(cA, (cH, cV, cD))`` as returned by ``pywt.dwt2``.
    """
    coefficients = pywt.dwt2(fields, wavelet=wave, mode=mode)
    return coefficients

def compute_bar_G(L):
    """Return the three detail synthesis atoms, stacked as an array of shape (3, L, L).

    Slice ``k`` is the inverse DWT (default ``idwt2`` mode/wavelet) of a Dirac placed at
    position (0, 0) of the ``k``-th detail band (0: H, 1: V, 2: D), with a zero approximation.

    :param L: side length of the output; the coefficient grids have side ``L // 2``.
    :return: ndarray of shape (3, L, L).
    """
    half = L // 2
    zeros = np.zeros((half, half))
    dirac = np.zeros((half, half))
    dirac[0, 0] = 1.

    # One detail tuple per band: the Dirac sits in exactly one of (H, V, D).
    detail_slots = (
        (dirac, zeros, zeros),
        (zeros, dirac, zeros),
        (zeros, zeros, dirac),
    )
    bar_G = np.zeros((3, L, L))
    for k, details in enumerate(detail_slots):
        bar_G[k] = idwt2((zeros, details))
    return bar_G
    

def reconstruction_fields(fields, mode="periodization", wave="db4", gamma=1., gammabar=1., normalize=False):
    r"""Inverse one-level wavelet transform of a 4-channel coefficient array.

    Channel 0 of ``fields`` holds the low-pass (approximation) coefficients and
    channels 1, 2, 3 hold the H, V, D detail coefficients.

    :param fields: ndarray of shape (..., 4, L, L).
    :param mode: pywt signal-extension mode.
    :param wave: pywt wavelet name.
    :param gamma: factor applied to channel 0 when ``normalize`` is True.
    :param gammabar: factor applied to channels 1-3 when ``normalize`` is True.
    :param normalize: if True, rescale the channels (on a copy) before reconstructing.
    :return: ndarray of shape (..., 1, 2L, 2L).
    """
    if normalize:
        # Copy first so the caller's array is never modified in place.
        fields = fields.copy()
        fields[..., 0, :, :] *= gamma
        fields[..., 1:, :, :] *= gammabar

    approx = fields[..., 0, :, :]
    details = tuple(fields[..., band, :, :] for band in (1, 2, 3))
    image = idwt2((approx, details), mode=mode, wave=wave)
    # Re-insert a singleton channel axis: (..., 2L, 2L) -> (..., 1, 2L, 2L).
    return np.expand_dims(image, axis=-3)


def zero_lifting(lower_freq, gamma):
    """Lift a field to twice its resolution by adding zero detail coefficients.

    The input is used as the low-pass channel, scaled by ``gamma``. The three detail
    channels are zero. The result is then inverse-wavelet-transformed.

    :param lower_freq: array of shape (..., 1, l, m); typically (n, 1, l, l).
    :param gamma: factor applied to the low-pass channel before reconstruction.
    :return: array of shape (..., 1, 2l, 2m).
    """
    lower_freq = np.asarray(lower_freq)
    # Detail channels: same leading/batch dims and spatial size as the input, 3 channels.
    detail_shape = lower_freq.shape[:-3] + (3,) + lower_freq.shape[-2:]
    decomposed = np.concatenate((lower_freq, np.zeros(detail_shape)), axis=-3)
    # Previously ``gamma`` was passed positionally and landed in ``mode``, which pywt rejects.
    return reconstruction_fields(decomposed, gamma=gamma, normalize=True)


def decomposition_fields(fields, mode='periodization', wave="db4", gamma=1., gammabar=1., normalize=False):
    """One-level wavelet decomposition of single-channel fields.

    :param fields: ndarray of shape (..., 1, L, L); only channel 0 is transformed.
    :param mode: pywt signal-extension mode.
    :param wave: pywt wavelet name.
    :param gamma: divisor applied to the low-pass channel when ``normalize`` is True.
    :param gammabar: divisor applied to the detail channels when ``normalize`` is True.
    :param normalize: if True, divide the output channels by ``gamma`` / ``gammabar``.
    :return: ndarray of shape (..., 4, L/2, L/2); channel 0 is low-pass, channels 1-3 are H, V, D.
    """
    approx, details = dwt2(fields[..., 0, :, :], mode=mode, wave=wave)
    coeffs = np.stack((approx, *details), axis=-3)
    if normalize:
        # ``coeffs`` is a freshly stacked array, so in-place division is safe.
        coeffs[..., 0, :, :] /= gamma
        coeffs[..., 1:, :, :] /= gammabar
    return coeffs



def _check_filter_matrix_unitary(filter_matrix, tol=1e-14):
    """Raise ``ValueError`` unless ``F F^H`` is the 4x4 identity at every frequency.

    :param filter_matrix: complex array of shape (4, 4, L1, L2); ``F`` is the slice
        ``filter_matrix[:, :, k1, k2]``.
    :param tol: maximum allowed sum of absolute deviations from the identity, per frequency.
    :raises ValueError: at the first frequency whose deviation exceeds ``tol``.
    """
    identity = np.identity(4).astype(complex)  # loop-invariant, built once
    n1, n2 = filter_matrix.shape[-2:]
    for k1 in range(n1):
        for k2 in range(n2):
            block = filter_matrix[:, :, k1, k2]
            product = np.dot(block, np.conjugate(np.transpose(block)))
            error = np.sum(np.abs(product - identity))
            if error > tol:
                raise ValueError(
                    f"Wavelet filter matrix is not unitary at frequency ({k1}, {k2}): "
                    f"sum|F F^H - I| = {error:.3e} > {tol:.0e}"
                )


def _wavelet_filter_matrix(L):
    """Build the (4, 4, L, L) Fourier-domain filter matrix used by the operator transforms.

    Column 0 holds the FFTs of the four synthesis atoms (A, H, V, D). Each atom is obtained
    with ``idwt2`` from a Dirac at (0, 0) of one sub-band of side ``L // 2``. Columns 1-3 are
    column 0 passed through ``translate`` with offsets (L/2, 0), (0, L/2) and (L/2, L/2).
    The matrix is divided by 2 and then checked for unitarity.

    :raises ValueError: if the matrix is not unitary at some frequency.
    """
    half = L // 2
    delta = np.zeros((half, half))
    delta[0, 0] = 1
    zeros_array = np.zeros((half, half))

    atoms = (
        idwt2((delta, (zeros_array, zeros_array, zeros_array))),  # A
        idwt2((zeros_array, (delta, zeros_array, zeros_array))),  # H
        idwt2((zeros_array, (zeros_array, delta, zeros_array))),  # V
        idwt2((zeros_array, (zeros_array, zeros_array, delta))),  # D
    )

    filter_matrix = np.zeros((4, 4, L, L), dtype=complex)
    filter_matrix[:, 0, :, :] = np.stack([np.fft.fft2(atom) for atom in atoms])
    filter_matrix[:, 1, :, :] = translate(filter_matrix[:, 0, :, :], half, 0)
    filter_matrix[:, 2, :, :] = translate(filter_matrix[:, 0, :, :], 0, half)
    filter_matrix[:, 3, :, :] = translate(filter_matrix[:, 0, :, :], half, half)
    filter_matrix = filter_matrix / 2.

    _check_filter_matrix_unitary(filter_matrix)
    return filter_matrix


def reconstruct_operator(operator):
    """Map a (4, 4, L/2, L/2) sub-band operator back to a single (L, L) operator.

    The operator is upsampled by zero insertion to (4, 4, L, L) and Fourier transformed.
    At every frequency it is conjugated as ``F^H A F`` with the wavelet filter matrix ``F``.
    Only the (0, 0) entry of that product is kept and inverse-transformed.

    :param operator: array broadcastable into ``[:, :, ::2, ::2]`` of a (4, 4, L, L) array,
        where ``L = 2 * operator.shape[-1]``.
    :return: real ndarray of shape (L, L).
    :raises ValueError: if the filter matrix is not unitary, or if the spatial result has an
        imaginary part larger than 1e-13.
    """
    L = int(operator.shape[-1] * 2)
    filter_matrix = _wavelet_filter_matrix(L)

    # Upsample by zero insertion on the last two axes, then move to Fourier space.
    tilde_operator = np.zeros((4, 4, L, L))
    tilde_operator[:, :, ::2, ::2] = operator
    fft_operator = np.fft.fft2(tilde_operator)

    # Only entry (0, 0) of F^H A F is used downstream, so only that entry is stored.
    fft_reconstructed_operator = np.zeros((L, L), dtype=complex)
    for k1 in range(L):
        for k2 in range(L):
            block = filter_matrix[:, :, k1, k2]
            fft_reconstructed_operator[k1, k2] = np.dot(
                np.dot(np.conjugate(np.transpose(block)), fft_operator[:, :, k1, k2]),
                block,
            )[0, 0]

    max_imag_fft = np.max(np.abs(np.imag(fft_reconstructed_operator)))
    if max_imag_fft > 1e-13:
        # Diagnostic only: a complex spectrum is tolerated here, only the spatial result is checked.
        print("Reconstruction: fft of result is not real")
        print(max_imag_fft)

    reconstructed = np.fft.ifft2(fft_reconstructed_operator)  # computed once, reused below
    max_imag = np.max(np.abs(np.imag(reconstructed)))
    if max_imag > 1e-13:
        print("Reconstruction: ifft of result is not real")
        print(max_imag)
        raise ValueError(
            f"reconstruct_operator: result has imaginary part up to {max_imag:.3e} (> 1e-13)"
        )

    return np.real(reconstructed)


def deconstruct_operator(operator):
    """Map an (L, L) operator to its (4, 4, L/2, L/2) sub-band representation.

    The FFT of ``operator`` is passed through ``translate`` with offsets (0, 0), (L/2, 0),
    (0, L/2) and (L/2, L/2). The results are placed on the diagonal of a 4x4 block. At every
    frequency the block is conjugated as ``F D F^H`` with the wavelet filter matrix ``F``.
    The spatial result is then subsampled by 2 on the last two axes.

    :param operator: array of shape (L, L).
    :return: real ndarray of shape (4, 4, L/2, L/2).
    :raises ValueError: if the filter matrix is not unitary, or if the spatial result has an
        imaginary part larger than 1e-8.
    """
    L = int(operator.shape[-1])
    half = L // 2
    filter_matrix = _wavelet_filter_matrix(L)

    fft_operator = np.fft.fft2(operator)
    fft_tilde_operator = np.zeros((4, 4, L, L), dtype=complex)
    shifts = ((0, 0), (half, 0), (0, half), (half, half))
    for m, (shift1, shift2) in enumerate(shifts):
        fft_tilde_operator[m, m] = translate(fft_operator, shift1, shift2)

    fft_result = np.zeros((4, 4, L, L), dtype=complex)
    for k1 in range(L):
        for k2 in range(L):
            block = filter_matrix[:, :, k1, k2]
            fft_result[:, :, k1, k2] = np.dot(
                np.dot(block, fft_tilde_operator[:, :, k1, k2]),
                np.transpose(np.conjugate(block)),
            )

    # Back to space and subsample; the former fft2/ifft2 round trip here was a no-op.
    result = np.fft.ifft2(fft_result)[:, :, ::2, ::2]
    max_imag = np.max(np.abs(np.imag(result)))
    if max_imag > 1e-8:
        print("Deconstruction: ifft of result is not real")
        print(max_imag)
        raise ValueError(
            f"deconstruct_operator: result has imaginary part up to {max_imag:.3e} (> 1e-8)"
        )

    return np.real(result)




def load_and_rescale_fields(training_dataset, nb_scales):
    """Load a dataset of fields, build its low-pass pyramid and standardize every scale.

    Scale 0 is the loaded field. Scale j+1 is the low-pass (channel 0) output of one
    ``decomposition_fields`` step applied to scale j, without normalization.

    :param training_dataset: path/file passed to ``np.load``; presumably an array of shape
        (N, L, L) -- TODO confirm against the data files.
    :param nb_scales: number of wavelet decompositions to apply.
    :return: ``(fields_per_scale, b_j, theta_j)`` as returned by ``rescale_fields``.
    """
    fields = np.expand_dims(np.load(training_dataset), 1)  # (N, 1, L, L)
    fields_per_scale = [fields[:, 0]]

    for _ in range(nb_scales):
        # decomposition_fields expects a channel axis (..., 1, l, l). It has no ``b_j``
        # argument; the former ``b_j=1.`` (no rescaling) maps to the default normalize=False.
        decomposed = decomposition_fields(fields)
        fields = decomposed[:, :1]  # keep the low-pass channel with its channel axis
        fields_per_scale.append(fields[:, 0])

    return rescale_fields(fields_per_scale)

def rescale_fields(fields_per_scale):
    """Standardize each field by its own standard deviation.

    :param fields_per_scale: sequence of arrays, one per scale.
    :return: ``(rescaled, b_j, theta)``:
        ``theta[j]`` is ``np.std(fields_per_scale[j])``;
        ``rescaled[j]`` is ``fields_per_scale[j] / theta[j]``;
        ``b_j[j]`` is ``theta[j] / theta[j - 1]``, with ``theta[-1]`` taken as 1.
    """
    stds = [np.std(field) for field in fields_per_scale]
    previous_stds = [1.] + stds[:-1]
    ratios = [current / previous for current, previous in zip(stds, previous_stds)]
    rescaled = [field / std for field, std in zip(fields_per_scale, stds)]
    return rescaled, ratios, stds


def get_filters(L, transpose=True):
    """Return the four (L, L) filters psi_i used in the wavelet decomposition, stacked as (4, L, L).

    The wavelet coefficients of x are subsampled(psi_i * x). Each filter is built by
    reconstructing a Dirac placed at (0, 0) of a single sub-band (i = 0: low-pass, 1-3: H, V, D).

    :param L: side length of the filters.
    :param transpose: if True, conjugate the filters in Fourier space; the direct
        reconstructions are the transposed filters.
    :return: real ndarray of shape (4, L, L).
    """
    half = int(L / 2)
    # Sample i of the batch carries a Dirac in channel i only.
    impulses = np.zeros((4, 4, half, half))
    impulses[np.arange(4), np.arange(4), 0, 0] = 1.
    filters = reconstruction_fields(impulses)[:, 0, :, :]  # these are the transposed filters
    if not transpose:
        return filters
    # Conjugation in Fourier space = circular reversal of a real signal.
    return np.real(np.fft.ifft2(np.conjugate(np.fft.fft2(filters))))



######################################################################
#                       Torch Wavelets                               #
######################################################################

class Wavelet(nn.Module):
    """
    Pytorch_wavelets implementation of the one-level wavelet decomposition from T's code.
    Supports GPU and autodiff on the pytorch_wavelets path. The small-size numpy
    fallback goes through ``.numpy()`` and is therefore not differentiable.
    """
    def __init__(self, J=1, mode='periodization', wave='db4', gamma=1., gammabar=1.):
        super().__init__()
        if J != 1:
            # Explicit check (an ``assert`` would be stripped under ``python -O``).
            raise ValueError(f"Wavelet only supports J=1, got J={J}")
        self.decompose = DWTForward(J=J, mode=mode, wave=wave)
        self.recompose = DWTInverse(mode=mode, wave=wave)
        self.mode = mode
        self.wave = wave
        self.gamma = gamma
        self.gammabar = gammabar

    def forward(self, tensor, normalize=False):
        """Wavelet decomposition at the first level.

        :param tensor: tensor of shape (N, 1, L, L).
        :param normalize: if True, divide the low-pass channel by ``gamma`` and the detail
            channels by ``gammabar``.
        :return: tensor of shape (N, 4, L/2, L/2), same dtype/device as the input.
            Low frequencies are in channel 0.
        """
        if min(tensor.shape[-2:]) <= 4:
            # Pytorch-wavelets is bugged and has border effects, use pywt (numpy) implementation instead.
            decomposed = decomposition_fields(
                tensor.cpu().numpy(), mode=self.mode, wave=self.wave,
                gamma=self.gamma, gammabar=self.gammabar, normalize=normalize,
            )
            # Match the torch path: same device AND dtype as the input (numpy yields float64).
            return torch.from_numpy(decomposed).to(device=tensor.device, dtype=tensor.dtype)

        low, high = self.decompose(tensor)
        details = high[0][..., 0, :, :, :]  # (N, 1, 3, l, l) -> (N, 3, l, l)
        if normalize:
            low = low / self.gamma
            details = details / self.gammabar
        return torch.concat((low, details), dim=-3)

    def inverse(self, tensor, normalize=False):
        r"""Wavelet reconstruction at one level.

        :param tensor: tensor of shape (N, 4, L, L) with low frequencies in channel 0.
        :param normalize: if True, multiply the low-pass channel by ``gamma`` and the detail
            channels by ``gammabar`` before reconstructing,
            i.e. result = G (gamma x_j) + \bar G (gammabar \bar x_j).
        :return: tensor of shape (N, 1, 2L, 2L).
        """
        if min(tensor.shape[-2:]) <= 2:
            # Pytorch-wavelets is bugged and has border effects, use pywt (numpy) implementation instead.
            reconstructed = reconstruction_fields(
                tensor.cpu().numpy(), mode=self.mode, wave=self.wave,
                gamma=self.gamma, gammabar=self.gammabar, normalize=normalize,
            )
            return torch.from_numpy(reconstructed).to(device=tensor.device, dtype=tensor.dtype)

        low = tensor[..., :1, :, :]
        details = tensor[..., None, 1:4, :, :]  # pytorch_wavelets expects (N, C, 3, l, l)
        if normalize:
            low = self.gamma * low
            details = self.gammabar * details
        return self.recompose((low, [details]))
    
def preprocess_data(data, normalize=True):
    """Wavelet-decompose a batch of 2D fields and optionally standardize the channels.

    :param data: tensor of shape (N, 1, W, H); cast to float32 before decomposition.
    :param normalize: bool, default True. If True, channel 0 is divided by the std of the
        low-pass coefficients and channels 1-3 by the pooled std of the detail coefficients.
    :return: ``(decomposed, std_phi, std_psi)``: the float32 decomposition of shape
        (N, 4, W/2, H/2), the std of the low-pass channel and the std of the detail channels
        (``torch.Tensor.std`` defaults), computed before any normalization.
    """
    wav = Wavelet()
    decomposed = wav(data.to(torch.float32))
    low = decomposed[:, :1, :, :]
    high = decomposed[:, 1:4, :, :]
    std_phi = low.std()
    std_psi = high.std()
    if normalize:
        # Out-of-place so autograd does not see in-place edits of tensors used by std().
        decomposed = torch.cat((low / std_phi, high / std_psi), dim=1)
    return decomposed.to(torch.float32), std_phi, std_psi