import numpy as np
from scipy import signal

import jax
import jax.numpy as jnp
import jaxopt

from flax import linen as nn
from utils.utils import gkern


class uPDNetMRI(nn.Module):
    """Unfolded primal-dual network for MRI reconstruction.

    Each of the ``num_iter`` unrolled iterations performs one primal-dual
    update (see ``PD_step``) for the objective

        0.5 * ||A x - y||^2 + lambda_k * ||W_k x||_1,

    where ``A`` is the masked Fourier operator (``MRI_forward``) and ``W_k`` is
    a learned 3x3 convolution with ``features`` output channels. Iteration
    ``k`` owns its own kernel ``weight_k`` and threshold ``ths_k``.

    Images are NCHW tensors: the ``jax.lax`` convolutions below use their
    default NCHW / OIHW dimension numbers.
    """
    mask: jnp.array  # k-space sampling mask, applied after fftshift (presumably centred k-space — confirm)
    im_size: int  # Necessary for the meas. op.
    num_iter: int  # number of unrolled primal-dual iterations
    channels: int  # number of image channels (C in NCHW)
    features = 40  # output channels of every learned convolution W_k
    blur_kernel = gkern(kernlen=7, std=0.5)  # module-level utils.gkern (the static method below is defined later)
    blur_kernel_adjoint = blur_kernel

    def initconv(self, rng, *init_args):
        """Initialise an HWIO kernel of shape (3, 3, channels, features) with N(0, 1) / 50 entries."""
        return jax.random.normal(rng, (3, 3, self.channels, self.features)) / 50

    def initths(self, rng, *init_args):
        """Initialise a (1, 1, 1, 1) threshold to the constant 1/20 (``rng`` is unused)."""
        return jnp.ones((1, 1, 1, 1)) / 20.

    def setup(self):
        """Declare one kernel ``weight_<k>`` and one threshold ``ths_<k>`` per iteration."""
        for it in range(self.num_iter):
            self.param('weight_' + str(it), self.initconv, (3, 3, self.channels, self.features))
            self.param('ths_' + str(it), self.initths, (1, 1, 1, 1))

    @staticmethod
    def gkern(kernlen=21, std=3):
        """Return a normalised 2D Gaussian kernel in HWIO layout, shape (kernlen, kernlen, 1, 3).

        The kernel is the element-wise square root of the outer product of a
        1D Gaussian window, repeated over 3 output channels and scaled so that
        each channel slice sums to 1.

        Declared ``@staticmethod`` because it takes no ``self``: as a plain
        method, an instance call would have bound the module to ``kernlen``.
        """
        # scipy.signal.gaussian was removed in SciPy 1.13; the window lives in signal.windows.
        gkern1d = signal.windows.gaussian(kernlen, std=std).reshape(kernlen, 1)
        gkern2d = np.outer(gkern1d, gkern1d)
        kernel = jnp.array(np.sqrt(gkern2d[..., np.newaxis, np.newaxis]))
        kernel = jnp.repeat(kernel, 3, axis=3)
        return kernel / (kernel).sum() * 3

    def blur_forward(self, x):
        """Depthwise 'SAME' correlation of an NCHW image with ``blur_kernel`` (one filter per channel)."""
        conv_output = jax.lax.conv_general_dilated(x,  # lhs = NCHW image tensor
                                                   jnp.transpose(self.blur_kernel[..., :self.channels], [3, 2, 0, 1]),
                                                   (1, 1),
                                                   'SAME',
                                                   feature_group_count=self.channels,
                                                   batch_group_count=1)
        return conv_output

    def blur_backward(self, x):
        """Adjoint of ``blur_forward``: depthwise 'SAME' correlation with ``blur_kernel_adjoint``.

        This is the exact adjoint only when ``blur_kernel_adjoint`` is the
        spatially flipped ``blur_kernel``. With the default
        ``blur_kernel_adjoint = blur_kernel``, that requires a symmetric kernel
        (presumably true for the Gaussian from ``utils.gkern`` — confirm).
        """
        conv_output = jax.lax.conv_general_dilated(x,  # lhs = NCHW image tensor
                                                   jnp.transpose(self.blur_kernel_adjoint[..., :self.channels],
                                                                 [3, 2, 0, 1]),
                                                   (1, 1),
                                                   'SAME',
                                                   feature_group_count=self.channels,
                                                   batch_group_count=1)
        return conv_output

    def prox_adjoint(self, x, gamma, lamb):
        """Prox of ``gamma * f^*`` with ``f = lamb * ||.||_1``, computed via the Moreau identity.

        prox_{gamma f^*}(x) = x - gamma * prox_{f / gamma}(x / gamma).
        Algebraically this equals clipping ``x`` to ``[-lamb, lamb]``.
        """
        return x - gamma * jaxopt.prox.prox_lasso(x / gamma, lamb / gamma)

    def compute_conv_lip(self, h, Nf=(32, 32)):
        r'''
        Estimates the operator norm of the multi-channel convolution with kernel h.

        ``h`` is stored in HWIO layout (kh, kw, in_channels, out_channels), the
        layout ``PD_step`` receives. Let L: x \mapsto h * x. At each spatial
        frequency, L acts as an (in_channels x out_channels) matrix. This
        function returns the largest singular value over an ``Nf`` frequency
        grid, i.e. an estimate of ||L|| (the norm itself, NOT its square).
        The grid models circular boundaries, whereas the network uses zero
        ('SAME') padding, so the value is an approximation.
        '''
        # FFT over the two spatial axes of the HWIO kernel. The previous
        # PyTorch-style permute (2, 3, 1, 0) moved the channel/feature axes to
        # 0/1, so the FFT ran over channels and truncated the 40 features to Nf[1].
        H = jnp.fft.fft2(h, s=Nf, axes=(0, 1))  # (Nf[0], Nf[1], in_channels, out_channels)
        s = jnp.linalg.svd(H, compute_uv=False)
        return s.max()

    def apply_mask(self, x):
        """Multiply ``x`` element-wise by the sampling mask."""
        return self.mask * x

    def MRI_forward(self, x, norm="ortho"):
        """Masked Fourier operator A: 2D FFT over the last two axes, then mask in fftshift-ed k-space.

        The result is ifftshift-ed back, so it is returned in the un-shifted FFT layout.
        """
        x = jnp.fft.fftn(  # type: ignore
            x, axes=(-2, -1), norm=norm
        )
        x = jnp.fft.fftshift(x, axes=[-2, -1])
        out = x * self.mask
        out = jnp.fft.ifftshift(out, axes=[-2, -1])
        return out

    def MRI_backward(self, x, norm="ortho"):
        """Adjoint A^H of ``MRI_forward`` (restricted to real images): returns the real part.

        ``MRI_forward`` applies fftshift -> mask -> ifftshift, so the adjoint
        undoes the final ifftshift first (fftshift), masks, then ifftshifts
        back before the inverse FFT. The previous order
        (ifftshift -> mask -> fftshift) matched this only for even image sizes.
        """
        x = jnp.fft.fftshift(x, axes=[-2, -1])
        x = self.mask * x
        data = jnp.fft.ifftshift(x, axes=[-2, -1])
        out = jnp.fft.ifftn(  # type: ignore
            data, axes=(-2, -1), norm=norm
        )
        return jnp.real(out)

    def identity_fun(self, x):
        """Return ``x`` unchanged."""
        return x

    def forward_op(self, x, type_operator):
        """Measurement operator; currently always ``MRI_forward`` (``type_operator`` is ignored)."""
        return self.MRI_forward(x)

    def backward_op(self, x, type_operator):
        """Adjoint measurement operator; currently always ``MRI_backward`` (``type_operator`` is ignored)."""
        return self.MRI_backward(x)

    def PD_step(self, w_conv, ths_param, y, x_prev, u, type_operator):
        """Run one primal-dual iteration.

        Args:
            w_conv: learned kernel in HWIO layout (3, 3, channels, features).
            ths_param: l1 weight lambda; shape (1, 1, 1, 1) as created by ``initths``.
            y: measured k-space data, presumably in the layout returned by ``forward_op``.
            x_prev: current primal estimate (NCHW, ``channels`` channels).
            u: current dual variable (NCHW, ``features`` channels).
            type_operator: forwarded to ``forward_op`` / ``backward_op``.

        Returns:
            ``(x, u, cost)``: the updated primal and dual variables, and the
            objective ``0.5 * ||A x - y||^2 + lambda * ||W x||_1`` evaluated at
            the new ``x``.
        """
        w_oihw = jnp.transpose(w_conv, [3, 2, 0, 1])  # HWIO -> OIHW expected by jax.lax

        lip_conv = self.compute_conv_lip(w_conv)
        gamma = 1 / (1.2 * lip_conv)  # dual step size
        tau = 0.5  # primal step size

        # Primal step: gradient of the data term plus W^T u (transpose of the correlation W).
        grad_data = self.backward_op(self.forward_op(x_prev, type_operator) - y, type_operator)
        Wt_u = jax.lax.conv_transpose(u,  # lhs = NCHW tensor
                                      w_oihw,
                                      (1, 1),  # window strides
                                      'SAME',
                                      transpose_kernel=True,
                                      dimension_numbers=('NCHW', 'OIHW', 'NCHW'))
        x = x_prev - tau * grad_data - tau * Wt_u

        # Dual step on the extrapolated primal point 2x - x_prev.
        W_extrap = jax.lax.conv(2 * x - x_prev,  # lhs = NCHW image tensor
                                w_oihw,
                                (1, 1),  # window strides
                                'SAME')
        z = u + gamma * W_extrap  # kernel should be odd-sized so 'SAME' padding stays centred
        u = self.prox_adjoint(z, gamma, ths_param)

        # Objective at the new x. The residual is complex, so use |r|^2 (r ** 2 would be
        # complex), and evaluate W at x itself rather than at the extrapolated point.
        residual = self.forward_op(x, type_operator) - y
        Wx = jax.lax.conv(x, w_oihw, (1, 1), 'SAME')
        cost = 0.5 * jnp.sum(jnp.abs(residual) ** 2) + ths_param * jnp.sum(jnp.abs(Wx))

        return x, u, cost

    def debug(self, y, type_operator):
        """Return ``backward_op(forward_op(y))``, i.e. A^H A y, for checking the operator pair."""
        out = self.backward_op(self.forward_op(y, type_operator), type_operator)

        return out

    @nn.compact
    def __call__(self, y, type_operator, x_init=None, u_init=None):
        """Run ``num_iter`` primal-dual iterations using the module's bound parameters.

        Args:
            y: measured k-space data.
            type_operator: forwarded to the measurement operators.
            x_init: initial image (NCHW). Defaults to ``backward_op(y)``.
            u_init: initial dual variable. Defaults to ``W_0 x_init``
                (correlation with ``weight_0``). Previously this argument was
                silently ignored.

        Returns:
            ``(x, u)``: the final primal and dual variables.
        """
        params = self.variables['params']
        x = self.backward_op(y, type_operator) if x_init is None else x_init
        if u_init is None:
            u = jax.lax.conv(x,  # lhs = NCHW image tensor
                             jnp.transpose(params['weight_0'], [3, 2, 0, 1]),
                             (1, 1),  # window strides
                             'SAME')
        else:
            u = u_init

        for it in range(self.num_iter):
            x, u, _ = self.PD_step(params['weight_' + str(it)],
                                   params['ths_' + str(it)], y, x, u, type_operator)

        return x, u

    def apply(self, params, y, type_operator, x_init=None, u_init=None):
        """
        Necessary to rewrite for appropriate backprop through the unfolded parameters.

        Shadows ``nn.Module.apply``. It runs the same iterations as ``__call__``
        but reads the kernels and thresholds from the explicitly passed
        ``params`` (``params['params']['weight_<k>']`` / ``['ths_<k>']``).
        ``x_init`` and ``u_init`` default as in ``__call__``: ``backward_op(y)``
        and ``W_0 x_init``. Passing ``None`` for either previously crashed.

        Returns:
            ``(x, u)``: the final primal and dual variables.
        """
        weights = params['params']
        x = self.backward_op(y, type_operator) if x_init is None else x_init
        if u_init is None:
            u = jax.lax.conv(x,  # lhs = NCHW image tensor
                             jnp.transpose(weights['weight_0'], [3, 2, 0, 1]),
                             (1, 1),  # window strides
                             'SAME')
        else:
            u = u_init

        for it in range(self.num_iter):
            x, u, _ = self.PD_step(weights['weight_' + str(it)], weights['ths_' + str(it)], y, x,
                                   u, type_operator)

        return x, u

