from typing import List
from attr import NOTHING
from tqdm import tqdm
import numpy as np
import torch
from torch import nn
from abc import ABC, abstractmethod
from models.activations import Exptest, GaussianActivations, AdaptativeSplineActivations, QuarticWindowedActivation, compute_quantile_centers
from utils.linear_solver import quad_solver
from utils.wavelets import reconstruction_fields, decomposition_fields, get_filters, Wavelet
from utils.operations import fold2d


#####################################################
#      Linear potential for score matching
#####################################################


class AbstractPotential(nn.Module, ABC):
    """ Interface for a collection of M potentials U_1..U_M acting on inputs of shape (N, C, L, L).

    A conditional potential computes U(y|x) on the channel-wise concatenation of a conditioning part x
    (K channels) and a varying part y (V channels), so that C = K + V.
    Gradients and Laplacians are taken with respect to y only. """

    def __init__(self, num_potentials: int):
        super().__init__()
        self.num_potentials = num_potentials  # M

    @abstractmethod
    def forward(self, x):
        """ Evaluate every potential: (N, C, L, L) -> (N, M). """
        raise NotImplementedError

    @abstractmethod
    def gradient(self, x):
        """ Gradient of every potential w.r.t. the varying channels: (N, C, L, L) -> (N, M, V, L, L). """
        raise NotImplementedError

    @abstractmethod
    def laplacian(self, x):
        """ Laplacian of every potential: (N, C, L, L) -> (N, M). """
        raise NotImplementedError

    def covariance_gradient(self, x):
        """ Per-sample Gram matrix of the gradients grad_U(x) grad_U(x)^T: (N, C, L, L) -> (N, M, M). """
        n_samples = x.shape[0]
        flat_grad = self.gradient(x).reshape((n_samples, self.num_potentials, -1))  # (N, M, V*L*L)
        return torch.matmul(flat_grad, flat_grad.transpose(1, 2))  # (N, M, M)

    def learned_potential(self, x, theta, **kwargs):
        """ Weighted potential theta^T U(x) with theta of shape (M,): (N, C, L, L) -> (N,).
        Extra keyword arguments are forwarded to forward(). """
        return torch.matmul(self(x, **kwargs), theta)


class ScalarPotential(AbstractPotential):
    """
    Collection of M scalar potentials U_m(x) = sum_i rho_m(x_i). Here rho is either a single windowed quartic
    activation (quartic=True, M = 1) or a bank of Gaussian activations centered on an evenly spaced grid of
    [-domain, domain].

    Subclasses may work in another orthogonal basis by overriding reconstruct (change of basis) and decompose
    (its inverse). In this class only forward() calls reconstruct; gradient() and laplacian() evaluate rho
    directly on their argument, so such a subclass must override them as well.
    """
    def __init__(self, domain, window_std=None, stride=None, quartic=True):
        """
        :param domain: centers are equally spaced in [-domain, domain] (quartic: the window has size 2*domain)
        :param window_std: standard deviation of each Gaussian window (not used when quartic)
        :param stride: number of standard deviations between consecutive centers (not used when quartic)
        :param quartic: if True, a single potential equal to x^4 restricted to a window of size 2*domain
        """
        if quartic:
            count = 1
        else:
            count = int(2 * domain / (stride * window_std))
        super().__init__(num_potentials=count)
        self.domain = domain
        self.window_std = window_std
        self.stride = stride
        if quartic:
            self.rho = QuarticWindowedActivation(domain)
            self.centers = torch.zeros(1)
        else:
            self.centers = torch.linspace(-domain, domain, self.num_potentials)  # (M,)
            self.rho = GaussianActivations(self.centers, self.window_std)

    def reconstruct(self, x):
        """ Identity change of basis: (*, C, L, L) to (*, C', L', L') with C' = C and L' = L. """
        return x

    def decompose(self, x):
        """ Identity inverse change of basis: (*, C', L', L') to (*, C, L, L). """
        return x

    def forward(self, x, reconstruct=True):
        """ Sum of rho over channels and pixels: (N, C, L, L) to (N, M). """
        z = self.reconstruct(x) if reconstruct else x
        return self.rho(z).sum((-1, -2, -3))  # (N, M, C, L, L) -> (N, M)

    def gradient(self, x):
        """ Pointwise rho'(x): (N, C, L, L) to (N, M, C, L, L). """
        return self.rho.derivative(x)

    def laplacian(self, x):
        """ Sum of rho''(x) over channels and pixels: (N, C, L, L) to (N, M). """
        per_pixel = self.rho.laplacian(x)  # (N, M, C, L, L)
        return per_pixel.sum((-1, -2, -3))


class ReconstructionScalarPotential(ScalarPotential):
    """ Collection of scalar potentials that are applied at a finer scale by inverting the wavelet decomposition.
    The input is of shape (4, L, L), but the scalar potential is applied to the reconstruction of shape (1, 2L, 2L).
    Gradients and Laplacians are taken with respect to the 3 high-frequency channels (channels 1:4) only. """
    def __init__(self, domain, window_std=None, stride=None, quartic=True, gamma=1., gammabar=1.):
        super().__init__(domain, window_std, stride, quartic)
        self.wavelets = Wavelet(gamma=gamma, gammabar=gammabar)
        self.gamma = gamma
        self.gammabar = gammabar
        # Precompute the magic constants for common sizes; they are registered as buffers in magic_constant.
        for l in [8, 16, 32, 64, 128, 256, 512]:
            self.magic_constant(torch.Size([l, l]))

    def reconstruct(self, x, normalize=False):
        """ (*, 4, L, L) to (*, 1, 2L, 2L). """
        y = self.wavelets.inverse(x, normalize=normalize)
        return y

    def decompose(self, x, normalize=False):
        """ (*, 1, 2L, 2L) to (*, 4, L, L). """
        y = self.wavelets(x, normalize=normalize)
        return y

    def magic_constant(self, spatial_shape):
        """ Returns the diagonal values of B B^T for the given spatial shape, as a (2, 2) tensor indexed by the
        parity (row % 2, col % 2) of the fine-grid position. Results are cached as buffers. """
        name = f"magic_L{spatial_shape[0]}_{spatial_shape[1]}"
        if not hasattr(self, name):
            # NOTE(review): sizes not precomputed in __init__ are built here on the default device; if the module
            # was already moved to another device this may fail or mismatch — confirm before using other sizes.
            # Compute the magic constants as <B B^T e_i, e_i>.
            x = torch.zeros((4, 1,) + spatial_shape)  # (4, 1, L, L)
            # Build dirac at the 4 possible parity positions.
            i = torch.arange(4)
            x[i, :, i // 2, i % 2] = 1

            x = self.decompose(x, normalize=False)  # (4, 4, L/2, L/2)
            x[:, 0, :, :] = 0  # Set low frequencies to zero.
            x = self.reconstruct(x, normalize=False)  # (4, 1, L, L)

            # Extract values at corresponding positions for each dirac.
            y = x[i, :, i // 2, i % 2].reshape((2, 2))  # (2, 2)
            #  [[[0, 0], [0, 1]],
            #   [[1, 0], [1, 1]]]

            self.register_buffer(name, y)
        return getattr(self, name)

    def forward(self, x):
        """ Compute sum v(x) on the reconstruction, (N, 4, L, L) to (N, M). """
        x_minus = self.reconstruct(x, normalize=True)  # (N, 1, 2L, 2L)
        return self.rho(x_minus).sum((-1, -2, -3))  # (N, M, 1, 2L, 2L) to (N, M)

    def gradient(self, x):
        """ (N, 4, L, L) to (N, M, 3, L, L). """
        # grad_y[V(A x + B y)] = B^T grad[V](A x + B y).
        x_minus = self.reconstruct(x, normalize=True)  # (N, 1, 2L, 2L)
        grad = super().gradient(x_minus)  # (N, M, 1, 2L, 2L)
        s = x.shape
        # new_zeros keeps grad's dtype and device; a default-dtype torch.zeros would downcast float64 gradients.
        ret = grad.new_zeros((s[0], self.num_potentials, 3, s[-2], s[-1]))
        for i in range(self.num_potentials):
            # Apply B^T: decompose, then keep only the 3 high-frequency channels.
            ret[:, i, ...] = self.decompose(grad[:, i, ...], normalize=False)[:, 1:, ...]
        return self.gammabar * ret

    def laplacian(self, x):
        """ (N, 4, L, L) to (N, M). """
        # Delta_y[V(A x + B y)] = <diag(B B^T), V''(A x + B y)>
        #                       = C Delta[V](A x + B y) because B B^T is a convolution operator.
        x_minus = self.reconstruct(x, normalize=True)  # (N, 1, 2L, 2L)
        c = self.magic_constant(x_minus.shape[-2:]).reshape((1, 2, 1, 2))
        hessian = self.rho.laplacian(x_minus)  # (N, M, 1, 2L, 2L), diagonal representation of the Hessian of V.
        L = x.shape[-1]
        # Split each fine-grid axis into (coarse index, parity) so c can be broadcast by parity.
        hessian = hessian.reshape(hessian.shape[:2] + (L, 2) * 2)  # (N, M, L, 2, L, 2)
        return (self.gammabar ** 2) * (c * hessian).sum((-1, -2, -3, -4))  # (N, M)


class AdaptativeScalarPotential(AbstractPotential):
    """ Collection of scalar potentials U_m(x) = sum_i rho_m(x_i) whose centers are adapted to the data
    (rho is an Exptest activation built from `data`). reconstruct/decompose are identities; forward() may apply
    reconstruct(), while gradient() and laplacian() evaluate rho directly on their argument. """
    def __init__(self, data, num_quantiles):
        """
        :param data: samples used by Exptest to place the adaptive centers
        :param num_quantiles: number of potentials M
        """
        super().__init__(num_potentials=num_quantiles)
        # Exptest is used rather than AdaptativeSplineActivations, which the original note says works worse.
        self.rho = Exptest(data, num_quantiles)
        self.centers = self.rho.centers_list
        self.domain = torch.max(torch.abs(data))  # largest absolute value in data

    def reconstruct(self, x):
        """ Identity change of basis: (*, C, L, L) to (*, C', L', L') with C' = C and L' = L. """
        return x

    def decompose(self, x):
        """ Identity inverse change of basis: (*, C', L', L') to (*, C, L, L). """
        return x

    def forward(self, x, reconstruct=True):
        """ Sum of rho over channels and pixels: (N, C, L, L) to (N, M). """
        z = self.reconstruct(x) if reconstruct else x
        return self.rho(z).sum((-1, -2, -3))  # (N, M, C, L, L) -> (N, M)

    def gradient(self, x):
        """ Pointwise rho'(x): (N, C, L, L) to (N, M, C, L, L). """
        return self.rho.derivative(x)

    def laplacian(self, x):
        """ Sum of rho''(x) over channels and pixels: (N, C, L, L) to (N, M). """
        per_pixel = self.rho.laplacian(x)  # (N, M, C, L, L)
        return per_pixel.sum((-1, -2, -3))


class AdaptativeReconstructionScalarPotential(AdaptativeScalarPotential):
    """ Collection of scalar potentials that are applied at a finer scale by inverting the wavelet decomposition.
    The input is of shape (4, L, L), but the scalar potential is applied to the reconstruction of shape (1, 2L, 2L).
    Gradients and Laplacians are taken with respect to the 3 high-frequency channels (channels 1:4) only. """

    def __init__(self, data, num_quantiles, gamma=1., gammabar=1.):
        """
        With adaptative centers.
        :param data: samples used to place the adaptive centers
        :param num_quantiles: number of potentials M
        :param gamma, gammabar: wavelet normalization factors passed to Wavelet
        """
        super().__init__(data=data, num_quantiles=num_quantiles)
        self.wavelets = Wavelet(gamma=gamma, gammabar=gammabar)
        self.gamma = gamma
        self.gammabar = gammabar
        # Precompute the magic constants for common sizes; they are registered as buffers in magic_constant.
        for l in [8, 16, 32, 64, 128, 256, 512]:
            self.magic_constant(torch.Size([l, l]))

    def reconstruct(self, x, normalize=False):
        """ (*, 4, L, L) to (*, 1, 2L, 2L). """
        y = self.wavelets.inverse(x, normalize=normalize)
        return y

    def decompose(self, x, normalize=False):
        """ (*, 1, 2L, 2L) to (*, 4, L, L). """
        y = self.wavelets(x, normalize=normalize)
        return y

    def magic_constant(self, spatial_shape):
        """ Returns the diagonal values of B B^T for the given spatial shape, as a (2, 2) tensor indexed by the
        parity (row % 2, col % 2) of the fine-grid position. Results are cached as buffers. """
        name = f"magic_L{spatial_shape[0]}_{spatial_shape[1]}"
        if not hasattr(self, name):
            # NOTE(review): sizes not precomputed in __init__ are built here on the default device; if the module
            # was already moved to another device this may fail or mismatch — confirm before using other sizes.
            # Compute the magic constants as <B B^T e_i, e_i>.
            x = torch.zeros((4, 1,) + spatial_shape)  # (4, 1, L, L)
            # Build dirac at the 4 possible parity positions.
            i = torch.arange(4)
            x[i, :, i // 2, i % 2] = 1

            x = self.decompose(x, normalize=False)  # (4, 4, L/2, L/2)
            x[:, 0, :, :] = 0  # Set low frequencies to zero.
            x = self.reconstruct(x, normalize=False)  # (4, 1, L, L)

            # Extract values at corresponding positions for each dirac.
            y = x[i, :, i // 2, i % 2].reshape((2, 2))  # (2, 2)
            #  [[[0, 0], [0, 1]],
            #   [[1, 0], [1, 1]]]

            self.register_buffer(name, y)
        return getattr(self, name)

    def forward(self, x):
        """ Compute sum v(x) on the reconstruction, (N, 4, L, L) to (N, M). """
        x_minus = self.reconstruct(x, normalize=True)  # (N, 1, 2L, 2L)
        return self.rho(x_minus).sum((-1, -2, -3))  # (N, M, 1, 2L, 2L) to (N, M)

    def gradient(self, x):
        """ (N, 4, L, L) to (N, M, 3, L, L). """
        # grad_y[V(A x + B y)] = B^T grad[V](A x + B y).
        x_minus = self.reconstruct(x, normalize=True)  # (N, 1, 2L, 2L)
        grad = super().gradient(x_minus)  # (N, M, 1, 2L, 2L)
        s = x.shape
        # new_zeros keeps grad's dtype and device; a default-dtype torch.zeros would downcast float64 gradients.
        ret = grad.new_zeros((s[0], self.num_potentials, 3, s[-2], s[-1]))
        for i in range(self.num_potentials):
            # Apply B^T: decompose, then keep only the 3 high-frequency channels.
            ret[:, i, ...] = self.decompose(grad[:, i, ...], normalize=False)[:, 1:, ...]
        return self.gammabar * ret

    def laplacian(self, x):
        """ (N, 4, L, L) to (N, M). """
        # Delta_y[V(A x + B y)] = <diag(B B^T), V''(A x + B y)>
        #                       = C Delta[V](A x + B y) because B B^T is a convolution operator.
        x_minus = self.reconstruct(x, normalize=True)  # (N, 1, 2L, 2L)
        c = self.magic_constant(x_minus.shape[-2:]).reshape((1, 2, 1, 2))
        hessian = self.rho.laplacian(x_minus)  # (N, M, 1, 2L, 2L), diagonal representation of the Hessian of V.
        L = x.shape[-1]
        # Split each fine-grid axis into (coarse index, parity) so c can be broadcast by parity.
        hessian = hessian.reshape(hessian.shape[:2] + (L, 2) * 2)  # (N, M, L, 2, L, 2)
        return (self.gammabar ** 2) * (c * hessian).sum((-1, -2, -3, -4))  # (N, M)


class GaussianPotential(AbstractPotential):
    """ Quadratic (covariance) potentials: each potential is a stationary moment sum_i x[c, i] x[d, i - s] / 2
    for a pair of channels (c, d) and a spatial shift s. """
    def __init__(self, num_varying_channels=1, num_conditioning_channels=0, shifts=((0, 1), (1, 0))):
        """ Stationary Gaussian (covariance) potentials.
        The input is composed of num_conditioning_channels + num_varying_channels channels, in this order.
        We consider the covariance of the last num_varying_channels, given the first num_conditioning_channels.
        Covariances are computed in a stationary manner, and are restricted to the given shifts.

        :param num_varying_channels: V, channels the gradient/Laplacian are taken with respect to
        :param num_conditioning_channels: K, leading channels that are only conditioned on
        :param shifts: "positive" 2D shifts; the zero shift and the opposite shifts are added automatically """
        self.num_varying_channels = num_varying_channels  # denoted as V in shapes
        self.num_conditioning_channels = num_conditioning_channels  # denoted as K in shapes
        self.num_channels = self.num_varying_channels + self.num_conditioning_channels  # denoted as C in shapes

        pos_shifts = torch.tensor(shifts)
        # Shift table layout: [0, positive shifts..., negative shifts...]; S = 1 + 2 * len(shifts) rows.
        # It is a float tensor (from torch.zeros), hence the .int() conversion in get_rolls.
        all_shifts = torch.cat((torch.zeros((1, 2)), pos_shifts, -pos_shifts))

        def pos_to_neg(i):
            """ Convert an index of a shift to the index of its opposite. """
            if i == 0:
                return 0
            elif i <= len(shifts):
                return i + len(shifts)
            else:
                return i - len(shifts)

        # We compute moments of the form sum_i x[c, i] x[d, i - s] / 2 for several channels c, d and shifts s.
        # We now build various lists of indices to have fast batched implementations of the covariances.
        indices = []  # (c, d, s_pos, s_neg, is_quad) indices

        for c in range(self.num_channels):
            # Do not consider covariances between conditioning channels.
            min_channel = self.num_conditioning_channels if c < self.num_conditioning_channels else c
            for d in range(min_channel, self.num_channels):
                # Positive shifts only when c = d, add negative shifts when c != d.
                num_shifts = 1 + len(pos_shifts) if c == d else len(all_shifts)
                for s_pos in range(num_shifts):
                    s_neg = pos_to_neg(s_pos)
                    # Zero-shift moment of a channel with itself: sum_i x[c, i]^2 / 2.
                    is_quad = c == d and s_pos == 0
                    indices.append((c, d, s_pos, s_neg, is_quad))

        super().__init__(num_potentials=len(indices))

        self.register_buffer("shifts", all_shifts)  # (S, 2)

        indices = torch.tensor(indices)  # (M, 5), tensor for convenient indexing
        self.register_buffer("first_channel_indices", indices[:, 0])  # (M,) with values in [0:C]
        self.register_buffer("second_channel_indices", indices[:, 1])  # (M,) with values in [K:C]
        self.register_buffer("pos_shift_indices", indices[:, 2])  # (M,) with values in [0:S]
        self.register_buffer("neg_shift_indices", indices[:, 3])  # (M,) with values in [0:S]
        self.register_buffer("quad_indices", indices[:, 4].nonzero()[:, 0])  # (M',) with values in [0:M]

    def get_rolls(self, x):
        """ Stack x circularly rolled by every shift in self.shifts: (N, C, L, L) to (N, S, C, L, L). """
        # torch.roll(x, s)[i] = x[i - s]
        return torch.stack([torch.roll(x, (shift[0].int(), shift[1].int()), (-2, -1)) for shift in self.shifts], dim=1)

    def forward(self, x):
        """ (N, C, L, L) to (N, M). """
        # The m-th moment is defined as sum_i x[c[m], i] x[d[m], i - s[m]] / 2.
        # NOTE: slightly inefficient when we have a single channel, as computing negative shifts is superfluous.
        # Only the varying channels are rolled, hence the "- K" offset on the second channel index below.
        x_roll = self.get_rolls(x[:, self.num_conditioning_channels:])  # (N, S, V, L, L)
        x_prod = x[:, self.first_channel_indices] * x_roll[:, self.pos_shift_indices, self.second_channel_indices - self.num_conditioning_channels]  # (N, M, L, L)
        return x_prod.sum((-1, -2)) / 2  # (N, M)

    def gradient(self, x):
        """ (N, C, L, L) to (N, M, V, L, L). """
        x_roll = self.get_rolls(x) / 2  # (N, S, C, L, L)
        potentials = torch.arange(self.num_potentials, device=x.device)
        # Moments whose first channel is varying also get a gradient contribution through that channel.
        subset = self.first_channel_indices >= self.num_conditioning_channels

        # The gradient of the m-th moment with respect to x[k, j] is: (k is only a varying channel)
        # ret[m, k, j] = (delta[k, c[m]] x[d[m], j - s[m]] + delta[k, d[m]] x[c[m], j + s[m]]) / 2.
        # Each statement below writes at most once per (m, k) pair, so the advanced-indexing += accumulates
        # correctly; for quadratic moments (c = d) the two statements together give x[c, j].
        ret = torch.zeros((x.shape[0], self.num_potentials, self.num_varying_channels) + x.shape[2:], device=x.device)  # (N, M, V, L, L)
        ret[:, potentials[subset], self.first_channel_indices[subset] - self.num_conditioning_channels] += x_roll[:, self.pos_shift_indices[subset], self.second_channel_indices[subset]]  # (N, M', L, L)
        ret[:, potentials, self.second_channel_indices - self.num_conditioning_channels] += x_roll[:, self.neg_shift_indices, self.first_channel_indices]  # (N, M, L, L)
        return ret

    def laplacian(self, x):
        """ (N, C, L, L) to (N, M). """
        # The laplacian of the m-th moment is L^2 for quadratic moments, and zero for others.
        ret = torch.zeros(x.shape[0], self.num_potentials, device=x.device)
        ret[:, self.quad_indices] = np.prod(x.shape[2:])
        return ret


class DecompositionGaussianPotential(GaussianPotential):
    """ Gaussian potential evaluated on the coarser grid obtained by one more wavelet decomposition.
    The low frequencies at scale j + 1 are discarded, which keeps the problem well-conditioned. """
    def __init__(self, shifts=((0, 1), (1, 0)), mode="periodization", wave="db4"):
        # Each of the 3 high-frequency channels at scale j splits into 4 channels at scale j + 1: 12 varying channels.
        # The single low-frequency channel at scale j contributes its 3 detail channels at scale j + 1 as conditioning.
        super().__init__(num_varying_channels=12, num_conditioning_channels=3, shifts=shifts)
        self.wavelets = Wavelet(J=1, mode=mode, wave=wave)

    def decompose(self, x):
        """ (N, 4, L, L) to (N, 15, L/2, L/2). """
        per_channel = x.reshape((-1, 1) + x.shape[-2:])  # (N*4, 1, L, L)
        coeffs = self.wavelets(per_channel)  # (N*4, 4, L/2, L/2); coeffs[i, j] = (x * psi_i) * psi_j
        # Drop the very first (low-frequency of low-frequency) channel. The next three are the details at
        # scale j + 1 of the low-pass x_j, i.e. the conditioning channels expected first by GaussianPotential.
        return coeffs.reshape((-1, 16) + coeffs.shape[-2:])[:, 1:]

    def reconstruct(self, x):
        """ (*, 4, L/2, L/2) to (*, L, L). """
        lead_shape = x.shape[:-3]
        flat = x.reshape((-1,) + x.shape[-3:])
        fine = self.wavelets.inverse(flat)  # (prod(lead_shape), 1, L, L)
        return fine.reshape(lead_shape + fine.shape[-2:])

    def forward(self, x):
        """ (N, 4, L, L) to (N, M). """
        coarse = self.decompose(x)  # (N, 15, L/2, L/2)
        return super().forward(coarse)

    def gradient(self, x):
        """ (N, 4, L, L) to (N, M, 3, L, L). """
        coarse_grad = super().gradient(self.decompose(x))  # (N, M, 12, L/2, L/2)
        # Group the 12 varying channels as (3 high-frequency channels at scale j) x (4 sub-bands at scale j + 1);
        # reconstructing each group expresses the gradient back at scale j.
        grouped = coarse_grad.reshape(coarse_grad.shape[:2] + (3, 4) + coarse_grad.shape[-2:])  # (N, M, 3, 4, L/2, L/2)
        return self.reconstruct(grouped) / self.wavelets.gammabar ** 2  # (N, M, 3, L, L)

    def laplacian(self, x):
        """ (N, 4, L, L) to (N, M). """
        # V(x) = U(A x) with A the decomposition, so Laplacian[V](x) = Tr(A A^T Hessian[U](A x)).
        # Reconstruct-then-decompose is the identity up to the gamma factor, hence the division.
        return super().laplacian(self.decompose(x)) / self.wavelets.gammabar ** 2


class UnionPotential(AbstractPotential):
    """ Concatenation of several potentials along the potential axis (M is the sum of the members' M). """
    def __init__(self, *potentials):
        total = sum(p.num_potentials for p in potentials)
        super().__init__(num_potentials=total)
        self.potentials = nn.ModuleList(potentials)

    def forward(self, x):
        """ Members' potentials in order, concatenated along dim 1: (N, C, L, L) to (N, M). """
        return torch.cat([p(x) for p in self.potentials], dim=1)

    def gradient(self, x):
        """ Members' gradients in order, concatenated along dim 1. """
        return torch.cat([p.gradient(x) for p in self.potentials], dim=1)

    def laplacian(self, x):
        """ Members' Laplacians in order, concatenated along dim 1. """
        return torch.cat([p.laplacian(x) for p in self.potentials], dim=1)

def get_wavelets(potential):
    """ Return the Wavelet module associated with a potential, or attr's NOTHING sentinel when there is none.

    - DecompositionGaussianPotential: its own wavelets.
    - UnionPotential whose first member is a DecompositionGaussianPotential: the wavelets of its LAST member.
    - Anything else (including other unions): NOTHING. Previously such unions fell through and returned None,
      unlike every other "no wavelets" case.
    """
    if isinstance(potential, UnionPotential):
        if isinstance(potential.potentials[0], DecompositionGaussianPotential):
            return potential.potentials[-1].wavelets
    elif isinstance(potential, DecompositionGaussianPotential):
        return potential.wavelets
    return NOTHING

def gradient_laplacian_checks(potential, cuda=False, gradcheck=True):
    """ Numerical checks for analytical gradient and laplacian of a potential. Prints the relative error between
    the analytical computation and the backpropagation with PyTorch auto-differentiation.

    The test inputs have shape (1, 4, 16, 16). Autograd derivatives are restricted to channels 1: (channel 0 is
    treated as conditioning), so this assumes potential.gradient returns derivatives w.r.t. channels 1:4 only —
    TODO confirm for potentials with a different channel layout.

    :param potential: potential to check; it is moved in place to the chosen dtype/device.
    :param cuda: run on GPU if True, otherwise on CPU.
    :param gradcheck: if True, work in float64 and additionally run torch.autograd.gradcheck / gradgradcheck;
        otherwise work in float32 and only print the relative errors.
    """
    N = 1
    C = 4
    L = 16
    shape = (N, C, L, L)
    num_tries = 5
    device = torch.device("cuda" if cuda else "cpu")
    dtype = torch.float64 if gradcheck else torch.float32
    kwargs = dict(dtype=dtype, device=device)
    potential.to(**kwargs)

    def get_func():
        # Random weights z turn the (N, M) potential into a scalar function sum(z * U(x)).
        z = torch.randn((N, potential.num_potentials), **kwargs)

        def func(x):
            y = potential(x)
            return torch.sum(z * y)

        return func, z

    # Relative error ||true - other|| / ||true|| over all entries.
    n = lambda x: torch.norm(torch.flatten(x))
    rel_err = lambda true, other: (n(true - other) / n(true)).item()

    # Check wavelets, just in case...
    # Compares the PyTorch round trip with the reconstruction_fields / decomposition_fields round trip.
    if hasattr(potential, "wavelets"):
        for _ in range(num_tries):
            x = torch.rand((N, C, L//2, L//2), **kwargs)
            x_rec = potential.wavelets(potential.wavelets.inverse(x))
            y_rec = torch.from_numpy(decomposition_fields(reconstruction_fields(x.cpu().numpy()))).to(dtype=dtype, device=device)
            print("Pytorch wavelets", rel_err(x, x_rec))
            print("Pywavelets", rel_err(x, y_rec))

    for _ in range(num_tries):
        f, z = get_func()
        x = torch.randn(shape, requires_grad=True, **kwargs)
        if gradcheck:
            torch.autograd.gradcheck(f, x)

        y = f(x)  # scalar
        numeric = torch.autograd.grad(y, x)[0][:, 1:]  # (N, C-1, L, L), channel 0 dropped
        analytic = torch.sum(potential.gradient(x.detach()) * z[..., None, None, None], dim=1)  # presumably (N, C-1, L, L)
        print("Gradient", rel_err(numeric, analytic))

    # Check Laplacian
    for _ in range(num_tries):
        f, z = get_func()
        x = torch.randn(shape, requires_grad=True, **kwargs)
        if gradcheck:
            torch.autograd.gradgradcheck(f, x)

        # Full autograd Hessian restricted to channels 1:, then its trace.
        numeric_full = torch.autograd.functional.hessian(f, x)[:, 1:, :, :, :, 1:]  # (N, C-1, L, L, N, C-1, L, L)
        numeric = numeric_full.reshape((N * (C - 1) * L * L, N * (C - 1) * L * L)).diag().sum()
        analytic_full = potential.laplacian(x)  # (N, M)
        analytic = torch.sum(z * analytic_full)  # scalar
        print("Laplacian", rel_err(numeric, analytic))


def wavelet_checks(wavelets, L, cuda=False):
    """ Check that decomposing a reconstruction gives back the input.

    Draws 5 random fields of shape (1, 4, L, L) in float32, computes wavelets(wavelets.inverse(x)) and prints the
    relative error with respect to x. The wavelets module is moved in place to the chosen dtype/device.
    """
    num_tries = 5
    shape = (1, 4, L, L)
    kwargs = dict(dtype=torch.float32, device=torch.device("cuda" if cuda else "cpu"))
    wavelets.to(**kwargs)

    def rel_err(true, other):
        diff_norm = torch.norm(torch.flatten(true - other))
        return (diff_norm / torch.norm(torch.flatten(true))).item()

    for _ in range(num_tries):
        x = torch.rand(shape, **kwargs)
        round_trip = wavelets(wavelets.inverse(x))
        print("Pytorch wavelets", rel_err(x, round_trip))


if __name__ == "__main__":
    # Self-check of the wavelet round trip; use the GPU only when one is available so this also runs on CPU.
    use_cuda = torch.cuda.is_available()
    for L in [1, 2, 4, 8, 16]:
        wavelet_checks(Wavelet(gamma=2., gammabar=2.), L=L, cuda=use_cuda)

    # gradient_laplacian_checks(GaussianPotential(num_conditioning_channels=1, num_varying_channels=3).cuda())
    # NOTE(review): DecompositionGaussianPotential takes no `gamma` argument; update before re-enabling.
    # gradient_laplacian_checks(DecompositionGaussianPotential(gamma=2).cuda())

    # gradcheck = False
    # if gradcheck:
    #     torch.set_default_dtype(torch.float64)
    #
    # S = ReconstructionScalarPotential(30, gamma=3., gammabar=2.)
    # G = DecompositionGaussianPotential(shifts=())
    # F = UnionPotential(S, G)
    # gradient_laplacian_checks(F, gradcheck=gradcheck)


#####################################################
#   Compute the parameters with score matching
#####################################################

def compute_theta_with_GD(x, potential: AbstractPotential, init=None, lr=0.1, n_epochs=1000, batch_size=1000, diagonal_precondition=False, verbose=False):
    """
    Fit the score-matching parameters theta of a linear potential theta^T U with quad_solver.

    The quadratic problem is built from the empirical matrix C = E[grad U grad U^T] (M, M) and the vector
    w = E[Laplacian U] (M,). Both are accumulated over mini-batches of size batch_size.

    :param x: (N, C, L, L) samples
    :param potential: potential to learn parameters for
    :param init: optional, (M,) initial value of theta
    :param lr: learning rate forwarded to quad_solver
    :param n_epochs: number of epochs forwarded to quad_solver
    :param batch_size: number of samples per batch when accumulating C and w
    :param diagonal_precondition: if True, solve for phi with theta = D phi, D = diag(C)^{-1/2}, so that the
        rescaled matrix D C D has a unit diagonal
    :param verbose: if True, also return the loss history of quad_solver and the matrix C (M, M)
    :return: learned parameters from score matching (M,); plus (history, C) if verbose
    :raises ValueError: if x contains no samples
    """
    n = x.shape[0]
    if n == 0:
        raise ValueError("compute_theta_with_GD needs at least one sample.")

    # Accumulate both statistics batch by batch so memory stays bounded by batch_size.
    covariance = 0
    laplacian_sum = 0
    for batch in torch.split(x, batch_size):
        covariance = covariance + torch.sum(potential.covariance_gradient(batch), dim=0)
        laplacian_sum = laplacian_sum + torch.sum(potential.laplacian(batch), dim=0)
    covariance = covariance / n
    w = laplacian_sum / n

    if diagonal_precondition:
        d = 1 / torch.sqrt(torch.diagonal(covariance))
        D = torch.diag_embed(d)
        Dinv = torch.diag_embed(1 / d)
    else:
        D = torch.eye(covariance.shape[-1], device=covariance.device, dtype=covariance.dtype)
        Dinv = D

    # Substituting theta = D phi turns C theta = w into (D C D) phi = D w; hence theta = D phi below
    # (not D^{-1} phi), and an initial theta maps to phi = D^{-1} theta.
    A = D @ covariance @ D
    b = D @ w
    phi_init = init
    if diagonal_precondition and init is not None:
        phi_init = Dinv @ init

    if verbose:
        phi, hist = quad_solver(A, b, init=phi_init, lr=lr, n_epochs=n_epochs, return_history=True)
        return D @ phi, hist, covariance
    phi = quad_solver(A, b, init=phi_init, lr=lr, n_epochs=n_epochs)
    return D @ phi


#####################################################
#             Diffusion process
#####################################################

def batched_indices(n, batch_size):
    """ Split range(n) into consecutive ranges of at most batch_size indices (the last one may be shorter). """
    count = -(-n // batch_size)  # ceil(n / batch_size)
    return [range(k * batch_size, min(n, (k + 1) * batch_size)) for k in range(count)]

def one_diffusion_step(x, step_size, std=1.):
    """
    One naive Euler-Maruyama step of the Ornstein-Uhlenbeck diffusion
        dX_t = -X_t dt + sqrt(2*std) dB_t.
    Despite its name, `std` plays the role of a variance: the invariant measure is a centered Gaussian with
    variance std (not std^2), since the stationary variance of dX = -X dt + sigma dB is sigma^2 / 2.

    :param x: current state (tensor of any shape)
    :param step_size: time step dt
    :param std: variance of the invariant Gaussian
    :return: the state after one step, same shape as x
    """
    w = torch.randn_like(x)
    delta = -x*step_size + np.sqrt(2*std*step_size)*w
    return x + delta

def diffuse(init, steps_list, std=1., device=None):
    """
    Forward Ornstein-Uhlenbeck diffusion of `init` with a naive Euler-Maruyama scheme, using one step per entry of
    `steps_list` (entry t is the step size of step t).

    :param init: initial state, any shape
    :param steps_list: sequence of step sizes
    :param std: diffusion parameter forwarded to one_diffusion_step
    :param device: optional device on which the diffusion is computed (None: default device)
    :return: CPU tensor of shape (len(steps_list) + 1, *init.shape); output[t] is the state after t steps,
             output[0] = init
    """
    diffusion_history = torch.zeros(len(steps_list) + 1, *init.shape, device=device)
    diffusion_history[0] = init

    for t, step_size in enumerate(steps_list):
        diffusion_history[t + 1] = one_diffusion_step(diffusion_history[t], step_size, std)

    return diffusion_history.to(torch.device('cpu'))

def diffuse_fast(init, steps_list, device, batch_size=10, std=1.):
    """
    Forward Ornstein-Uhlenbeck diffusion of `init`, processed in mini-batches along the first axis on `device`.
    Useful for e.g. 10k samples of 32x32. Choose small batch sizes: only the current state of one batch lives
    on the device, and each step is copied back to the CPU history.

    :param init: initial states, tensor of shape (N, ...)
    :param steps_list: sequence of step sizes
    :param device: device on which the diffusion steps are computed
    :param batch_size: number of samples diffused together
    :param std: diffusion parameter forwarded to one_diffusion_step
    :return: CPU tensor of shape (len(steps_list) + 1, *init.shape); output[t] is the state after t steps
    """
    dims = init.shape
    diffusion_history = torch.zeros(len(steps_list) + 1, *dims)
    print("Starting the diffusion...")
    for batch in tqdm(batched_indices(dims[0], batch_size)):
        idx = slice(batch.start, batch.stop)
        # Previously this called diffuse(..., device=device), which raised a TypeError; run the steps here instead.
        x = init[idx].to(device)
        diffusion_history[0, idx] = x.cpu()
        for t, step_size in enumerate(steps_list):
            x = one_diffusion_step(x, step_size, std)
            diffusion_history[t + 1, idx] = x.cpu()
    print("End.")
    return diffusion_history

#      the backward ODE diffusion

def one_ODE_step(x, potential: "AbstractPotential", params, step_size=0.01, std=1.):
    """
    One explicit-Euler step of the probability-flow ODE, integrated backward in time.

    For the forward OU process dX = -X dt + sqrt(2*std) dB, the probability-flow ODE is
    dx/dt = -x - std * score(x). With the learned density p ∝ exp(-params^T U), the score is
    -sum_m params[m] grad U_m, so going from t to t - step_size gives
        x <- x + step_size * (x - std * sum_m params[m] grad U_m(x)).
    This is one_SDE_step's drift with the score term halved and no noise. The previous version used -x
    instead of +x, which shrinks the stationary Gaussian towards 0 instead of leaving it invariant.

    :param x: (N, C, L, L) current state
    :param potential: potential providing gradient(x) of shape (N, M, C, L, L) and num_potentials = M
    :param params: (M,) learned weights theta
    :param step_size: time step
    :param std: diffusion parameter of the forward process
    :return: (N, C, L, L) state one step earlier in diffusion time
    """
    gradient = potential.gradient(x)  # (N, M, C, L, L)
    t = torch.reshape(params, (1, potential.num_potentials, 1, 1, 1))
    score_term = std * torch.sum(gradient * t, dim=1)  # = -std * score(x)
    return x + (x - score_term) * step_size

def backward_ODE(init, potential:AbstractPotential, params_list, steps_list, std=1.):
    """ Integrate the backward probability-flow ODE with one_ODE_step, one step per entry of steps_list.

    params_list[t] are the weights used at step t. Returns a tensor of shape (len(steps_list) + 1, *init.shape)
    whose entry 0 is a copy of init and entry t + 1 the state after t + 1 steps. """
    history = torch.zeros(len(steps_list) + 1, *init.shape)
    history[0] = torch.clone(init)
    for k, step in enumerate(steps_list):
        history[k + 1] = one_ODE_step(history[k], potential, params_list[k], step, std)
    return history

# the backward SDE diffusion

def one_SDE_step(x, potential:AbstractPotential, params, step_size, std=1.):
    """ One Euler-Maruyama step of the reverse-time SDE of the OU process dX = -X dt + sqrt(2*std) dB:
        x <- x + step_size * (x - 2*std * sum_m params[m] grad U_m(x)) + sqrt(2*step_size*std) * noise.
    x is (N, C, L, L); potential.gradient(x) is (N, M, C, L, L); params is (M,). """
    weights = torch.reshape(params, (1, potential.num_potentials, 1, 1, 1))
    weighted_grad = (potential.gradient(x) * weights).sum(dim=1)  # (N, C, L, L)
    drift = x - 2 * std * weighted_grad
    noise = np.sqrt(2 * step_size * std) * torch.randn_like(x)
    return x + drift * step_size + noise

def backward_SDE(init, potential_list, params_list, steps_list, std=1.):
    """ Run the reverse-time SDE with one_SDE_step, using potential_list[t] and params_list[t] at step t.

    Returns a tensor of shape (len(steps_list) + 1, *init.shape) whose entry 0 is a copy of init. """
    history = torch.zeros(len(steps_list) + 1, *init.shape)
    history[0] = torch.clone(init)
    for k, step in enumerate(steps_list):
        history[k + 1] = one_SDE_step(history[k], potential_list[k], params_list[k], step, std)
    return history


def one_Langevin_correction_step(x, potential, params, step_size):
    """ One unadjusted Langevin step on the potential params^T U:
        x <- x - step_size * sum_m params[m] grad U_m(x) + sqrt(2*step_size) * noise.
    params is a 1D tensor of length M. """
    drift = (potential.gradient(x) * params[None, :, None, None, None]).sum(dim=1)
    kick = np.sqrt(2 * step_size) * torch.randn_like(x)
    return x - step_size * drift + kick

def one_conditional_SDE_step(x, potential: "AbstractPotential", params, step_size=0.01, std=1.0):
    """
    One reverse-time SDE step applied to the last channels only; the first channel (conditioning) is untouched.
    The input is not modified: the update is applied to a copy.

    :param x: (N, C, L, L) current state
    :param potential: potential whose gradient(x) is (N, M, C-1, L, L), w.r.t. the varying channels
    :param params: (M,) learned weights
    :param step_size: time step
    :param std: diffusion parameter of the forward process
    :return: (N, C, L, L) new state
    """
    gradient = potential.gradient(x)  # (N, M, C-1, L, L)
    t = torch.reshape(params, (1, potential.num_potentials, 1, 1, 1))  # (1, M, 1, 1, 1)
    # Copy: `ret = x` followed by an in-place update would also overwrite the caller's tensor
    # (e.g. the history entry passed in by conditional_backward_SDE).
    ret = x.clone()
    delta_drift = step_size*x[:,1:,:,:] - 2*step_size*std*torch.sum(gradient*t, dim=1)  # (N, C-1, L, L)
    delta_wiener = np.sqrt(2*step_size*std)*torch.randn_like(delta_drift)  # (N, C-1, L, L)
    ret[:, 1:, :, :] += delta_drift + delta_wiener
    return ret

def one_conditional_Langevin_correction_step(x, potential, params, step_size):
    """
    One unadjusted Langevin step on the last channels only; the first channel (conditioning) is untouched.
    Returns a new tensor; the input x is left unchanged.

    :param x: (N, C, L, L) current state
    :param potential: potential whose gradient(x) is (N, M, C-1, L, L)
    :param params: (M,) learned weights
    :param step_size: Langevin step size
    :return: (N, C, L, L) corrected state
    """
    gradient = torch.sum(potential.gradient(x) * params[None,:,None,None,None], dim=1)  # (N, C-1, L, L)
    # Noise is drawn for all channels; only the varying ones are used.
    noise = np.sqrt(2*step_size)*torch.randn_like(x)
    ret = x.clone()  # avoid silently modifying the caller's tensor
    ret[:, 1:, :, :] += noise[:, 1:, :, :] - step_size * gradient
    return ret


def conditional_backward_SDE(init, potential_list, params_list, steps_list, std=1.):
    """
    Run the conditional reverse-time SDE with one_conditional_SDE_step, using potential_list[t] and
    params_list[t] at step t.

    :return: tensor of shape (len(steps_list) + 1, *init.shape); entry 0 is a copy of init and entry t + 1
             the state after t + 1 steps
    """
    dims = init.shape
    SDE_history = torch.zeros(len(steps_list)+1, *dims)
    SDE_history[0] = torch.clone(init)
    for t in range(len(steps_list)):
        # Hand the step a copy: a step function that updates its input in place would otherwise
        # overwrite SDE_history[t] with the next state.
        SDE_history[t+1] = one_conditional_SDE_step(
            SDE_history[t].clone(),
            potential_list[t],
            params_list[t],
            steps_list[t],
            std
        )

    return SDE_history


    ## conditional covariance

class Covariance():
    """ Per-frequency 4x4 covariance of the wavelet channels (channel 0: low-pass, channels 1:4: high-pass).

    It is built from a spatial power spectrum and the Fourier transforms of the filters returned by
    get_filters(L), then folded by fold2d onto an (l, l) frequency grid with l = L // 2. Calling the object with
    a time t gives the covariance after an Ornstein-Uhlenbeck diffusion of the high-frequency channels,
    with channel 0 held fixed.
    """

    def __init__(self, filter):
        """ :param filter: power spectral density whose last axis has length L — presumably (L, L); TODO confirm. """
        self.L = filter.shape[-1]
        self.l = int(self.L/2)
        self.psd = filter

        psi_fft = torch.fft.fft2(torch.from_numpy(get_filters(self.L)))  # (4, L, L)
        isp_fft = torch.conj(psi_fft)
        # Sigma[i, j](omega) = psi_i(omega) conj(psi_j(omega)) psd(omega), folded onto the coarse grid.
        self.Sigma_wav_fft = fold2d(
                (psi_fft[:,None,...] * isp_fft[None,...]) * self.psd  # (4, 4, L, L)
            )  # (4, 4, l, l)

        # Identity on the 3 high-frequency channels, zero on the low-frequency one.
        self.identity_on_wavelets = torch.zeros(4,4,self.l,self.l)
        for i in range(1,4):
            self.identity_on_wavelets[i,i] = torch.ones(self.l,self.l)

        self.std = torch.ones(4)

    def normalize(self):
        """
        Computes the standard deviation of the process across each of the 4 channels, then normalize the covariance.
        """
        self.std = torch.real(torch.sqrt(torch.diag(self.Sigma_wav_fft.mean((-2,-1)))))
        self.Sigma_wav_fft = (1/self.std[:,None,None,None]) * self.Sigma_wav_fft * (1/self.std[None,:,None,None])

    def denormalize(self):
        """ Undo normalize() using the stored per-channel standard deviations. """
        self.Sigma_wav_fft = (self.std[:,None,None,None]) * self.Sigma_wav_fft * (self.std[None,:,None,None])

    def to(self, device):
        """ Move every stored tensor (including the normalization std) to `device` and return self. """
        self.identity_on_wavelets = self.identity_on_wavelets.to(device)
        self.Sigma_wav_fft = self.Sigma_wav_fft.to(device)
        # std must follow, otherwise denormalize() mixes devices after a normalize() done before to().
        self.std = self.std.to(device)
        return self

    def __call__(self, t=0):
        """
        Covariance at diffusion time t:
        high/high blocks decay as exp(-2t), low/high cross blocks as exp(-t), and (1 - exp(-2t)) times the
        identity is added on the high-frequency channels.
        - output (4,4,l,l)
        """
        ret = (self.Sigma_wav_fft).clone()
        ret[1:,1:,...] *= np.exp(-2*t)
        ret[:1,1:,...] *= np.exp(-t)
        ret[1:,:1,...] *= np.exp(-t)
        ret += (1 - np.exp(-2*t))*self.identity_on_wavelets
        return ret

    def inverse(self, t=0):
        """
        Per-frequency inverse of the covariance at time t (computed in complex64).
        - output (4,4,l,l)
        """
        sigma = self(t)
        ret = torch.inverse(sigma.permute(2,3,0,1).cfloat())
        return ret.permute(2,3,0,1)

    @staticmethod
    def _contract(matrix, input, fourier):
        """ Multiply each frequency of `input` (B, 4, l, l) by the 4x4 `matrix` (1, 4, 4, l, l).
        If fourier is False, input is first transformed with fft2 and the result brought back with ifft2. """
        if not fourier:
            x = torch.fft.fft2(input)
            ret = (matrix * x[:,None,...]).sum(2)
            return torch.fft.ifft2(ret)
        return (matrix * input[:,None,...]).sum(2)

    def apply(self, input, t=0, fourier=True):
        """
        Given an input (B,4,l,l) in Fourier basis, applies Sigma(t) to the input.
        If the input is not in Fourier basis, specify fourier=False.

        - input (B, 4, l, l)
        - output (B,4,l,l)
        """
        return self._contract(self(t)[None,...], input, fourier)

    def apply_inverse(self, input, t=0, fourier=True):
        """
        Given an input (B,4,l,l) in Fourier basis, apply the inverse of Sigma(t) to the input.
        If the input is not in Fourier basis, specify fourier=False.
        - input (B, 4, l, l)
        - output (B,4,l,l)
        """
        return self._contract(self.inverse(t)[None,...], input, fourier)

def compute_empirical_psd(input_w_f):
    """
    Empirical power spectral density of a field given in the wavelet-Fourier domain.

    The wavelet coefficients are brought back to space (ifft2 then reconstruction_fields). The periodogram
    |fft2|^2 / n^2, with n the last spatial size, is then averaged over the batch, and channel 0 is returned.
    """
    spatial = torch.from_numpy(reconstruction_fields(torch.fft.ifft2(input_w_f).numpy()))
    spectrum = torch.fft.fft2(spatial)
    side = spectrum.shape[-1]
    periodogram = spectrum.abs() ** 2 / side ** 2
    return periodogram.mean(0)[0]
    

def compute_empirical_psd_across_channels(input, fourier=True):
    """
    For a (B,C,W,H) array, computes the stationary covariances (cross power spectra) between the C
    channels, then averages over the whole batch.

    :param input: (B, C, W, H), already in Fourier domain if fourier is True, in space otherwise
    :param fourier: whether input is already Fourier-transformed
    :return: (C, C, W, H) complex tensor; entry [c, d] is mean_b F_c conj(F_d) / (W * H)
    """
    if fourier:
        input_f = input
    else:
        input_f = torch.fft.fft2(input)

    # Periodogram normalization by the number of pixels W * H (identical to l^2 for square fields, and
    # correct for rectangular ones), so that the PSD sums to the mean energy (Parseval).
    num_pixels = input_f.shape[-2] * input_f.shape[-1]
    empirical_psd = input_f[:,:,None,...] * torch.conj(input_f[:,None,...])  # (B, C, C, W, H)
    return empirical_psd.mean(0) / num_pixels