import numpy as np
from scipy import signal

import jax
import jax.numpy as jnp
import jaxopt

from flax import linen as nn
from utils.utils import gkern


class uPDNet(nn.Module):
    """Unrolled primal-dual network for linear inverse problems.

    Each of the ``num_iter`` layers performs one primal-dual iteration for the
    objective ``0.5 * ||A x - y||^2 + ths * ||W x||_1``, where ``W`` is a learned
    3x3 convolution (parameter ``weight_<it>``) and ``ths`` a learned threshold
    (parameter ``ths_<it>``). The measurement operator ``A`` is chosen at call
    time by ``type_operator[0]``: 0 -> identity, 1 -> mask, 2 -> blur.
    Images are handled in NCHW layout.
    """
    mask: jnp.ndarray  # Multiplicative mask applied when type_operator[0] == 1.
    im_size: int  # Necessary for the meas. op.
    num_iter: int  # Number of unrolled primal-dual iterations (layers).
    channels: int  # Number of image channels C.
    features = 40  # Output channels of each learned convolution W.
    # NOTE: at this point of the class body ``gkern`` still refers to the imported
    # ``utils.utils.gkern``; the static ``gkern`` method below is defined later.
    blur_kernel = gkern(kernlen=7, std=0.5)
    # Kernel used by ``blur_backward``. The adjoint of a cross-correlation is the
    # cross-correlation with the spatially flipped kernel; reusing ``blur_kernel``
    # assumes it is symmetric (presumably a centred Gaussian -- confirm utils.gkern).
    blur_kernel_adjoint = blur_kernel

    def initconv(self, rng, *init_args):
        """Initialise an HWIO kernel of shape (3, 3, channels, features) with small Gaussian entries."""
        return jax.random.normal(rng, (3, 3, self.channels, self.features)) / 50

    def initths(self, rng, *init_args):
        """Initialise a threshold of shape (1, 1, 1, 1) to 1/20."""
        return jnp.ones((1, 1, 1, 1)) / 20.

    def setup(self):
        """Register one convolution kernel and one threshold per unrolled iteration."""
        for it in range(self.num_iter):
            self.param('weight_' + str(it), self.initconv, (3, 3, self.channels, self.features))
            self.param('ths_' + str(it), self.initths, (1, 1, 1, 1))

    @staticmethod
    def gkern(kernlen=21, std=3):
        """Return a 2D Gaussian kernel array of shape (kernlen, kernlen, 1, 3).

        The kernel is repeated over 3 output channels and each channel sums to 1.
        Because the square root of the outer product is taken, the effective
        standard deviation is ``std * sqrt(2)``.
        """
        # ``signal.gaussian`` was removed from SciPy (1.13); the window lives in
        # ``scipy.signal.windows``.
        gkern1d = signal.windows.gaussian(kernlen, std=std).reshape(kernlen, 1)
        gkern2d = np.outer(gkern1d, gkern1d)
        kernel = jnp.array(np.sqrt(gkern2d[..., np.newaxis, np.newaxis]))
        kernel = jnp.repeat(kernel, 3, axis=3)
        # The total sum covers 3 identical channels, so "* 3" normalises each channel to 1.
        return kernel / kernel.sum() * 3

    def _blur(self, x, kernel):
        """Depthwise 'SAME' cross-correlation (stride 1) of NCHW ``x`` with an HWIO ``kernel``."""
        return jax.lax.conv_general_dilated(x,  # lhs = NCHW image tensor
                                            jnp.transpose(kernel[..., :self.channels], [3, 2, 0, 1]),  # HWIO -> OIHW
                                            (1, 1),
                                            'SAME',
                                            feature_group_count=self.channels,
                                            batch_group_count=1)

    def blur_forward(self, x):
        """Apply the blur operator to an NCHW tensor."""
        return self._blur(x, self.blur_kernel)

    def blur_backward(self, x):
        """Apply the adjoint of the blur operator (uses ``blur_kernel_adjoint``)."""
        return self._blur(x, self.blur_kernel_adjoint)

    def prox_adjoint(self, x, gamma, lamb):
        """Prox of ``gamma * g^*`` for ``g = lamb * ||.||_1``, via the Moreau identity.

        ``prox_{gamma g^*}(x) = x - gamma * prox_{g / gamma}(x / gamma)``; for the
        l1 norm this amounts to clipping ``x`` to ``[-lamb, lamb]``.
        """
        return x - gamma * jaxopt.prox.prox_lasso(x / gamma, lamb / gamma)

    def compute_conv_lip(self, h, Nf=(32, 32)):
        r'''
        Computes the Lipschitz constant of the convolution with kernel h.
        More precisely, let L:x \mapsto h*x, then this function returns
        ||L||^2 (the squared spectral norm).

        ``h`` is an HWIO kernel ``(kh, kw, C_in, C_out)`` as stored in the
        ``weight_<it>`` parameters. The norm is computed for the circular
        convolution on an ``Nf`` grid: at every frequency the transfer function is
        a ``C_in x C_out`` matrix, and ||L|| is its largest singular value over
        all frequencies.
        '''
        # FFT over the spatial axes (0, 1) only; the trailing channel axes form
        # the per-frequency transfer matrices that the batched SVD acts on.
        H = jnp.fft.fft2(h, Nf, axes=(0, 1))
        s = jnp.linalg.svd(H, compute_uv=False)
        return s.max() ** 2

    def apply_mask(self, x):
        """Multiply ``x`` elementwise by ``self.mask``."""
        return self.mask * x

    def identity_fun(self, x):
        """Return ``x`` unchanged."""
        return x

    def forward_op(self, x, type_operator):
        """Apply the measurement operator A selected by ``type_operator[0]``.

        0 -> identity, 1 -> mask, 2 -> blur. ``jax.lax.switch`` is used so the
        selector can be a traced value (dict-based dispatch would be nicer but is
        awkward in JAX).
        """
        branches = [self.identity_fun, self.apply_mask, self.blur_forward]
        return jax.lax.switch(type_operator[0], branches, x)

    def backward_op(self, x, type_operator):
        """Apply the adjoint A^T of the operator selected by ``type_operator[0]``."""
        branches = [self.identity_fun, self.apply_mask, self.blur_backward]
        return jax.lax.switch(type_operator[0], branches, x)

    @staticmethod
    def _conv_w(x, w_conv):
        """Apply W: 'SAME' 2D convolution of NCHW ``x`` with the HWIO kernel ``w_conv``."""
        return jax.lax.conv(x,  # lhs = NCHW image tensor
                            jnp.transpose(w_conv, [3, 2, 0, 1]),  # HWIO -> OIHW
                            (1, 1),  # window strides
                            'SAME')

    @staticmethod
    def _conv_w_adjoint(u, w_conv):
        """Apply W^T, the adjoint of ``_conv_w``, to the NCHW dual variable ``u``."""
        # transpose_kernel=True flips the spatial axes and swaps I/O, i.e. the
        # gradient of ``_conv_w`` w.r.t. its input.
        return jax.lax.conv_transpose(u,
                                      jnp.transpose(w_conv, [3, 2, 0, 1]),  # HWIO -> OIHW
                                      (1, 1),  # window strides
                                      'SAME',
                                      transpose_kernel=True,
                                      dimension_numbers=('NCHW', 'OIHW', 'NCHW'))

    def PD_step(self, w_conv, ths_param, y, x_prev, u, type_operator):
        """Perform one unrolled primal-dual iteration.

        Args:
            w_conv: HWIO kernel (3, 3, channels, features) defining W.
            ths_param: l1 threshold, shape (1, 1, 1, 1).
            y: measurements (NCHW).
            x_prev: current primal estimate (NCHW).
            u: current dual variable, with ``features`` channels.
            type_operator: array whose first entry selects A (see ``forward_op``).

        Returns:
            ``(x, u, cost)``: the updated primal and dual variables, and the
            objective ``0.5 * ||A x - y||^2 + ths * ||W x||_1`` at the new ``x``.
        """
        lip_conv = self.compute_conv_lip(w_conv)  # ||W||^2
        gamma = 1 / (1.2 * lip_conv)  # dual step size
        tau = 0.5  # primal step size

        # Primal step: gradient of the data term plus the W^T u coupling term.
        data_grad = self.backward_op(self.forward_op(x_prev, type_operator) - y, type_operator)
        x = x_prev - tau * data_grad - tau * self._conv_w_adjoint(u, w_conv)

        # Dual step at the extrapolated point 2 x - x_prev (kernel should be odd for 'SAME').
        z = u + gamma * self._conv_w(2 * x - x_prev, w_conv)
        u = self.prox_adjoint(z, gamma, ths_param)

        # Objective at the new primal iterate: regulariser on W x, not W(2x - x_prev).
        data_fit = 0.5 * jnp.sum((self.forward_op(x, type_operator) - y) ** 2)
        cost = data_fit + ths_param * jnp.sum(jnp.abs(self._conv_w(x, w_conv)))

        return x, u, cost

    def debug(self, y, type_operator):
        """Return ``A^T A y`` for the selected operator (sanity check)."""
        return self.backward_op(self.forward_op(y, type_operator), type_operator)

    def _unroll(self, params, y, type_operator):
        """Run all ``num_iter`` PD steps using the 'weight_<it>'/'ths_<it>' entries of ``params``."""
        # Dual variable initialised as W_0 y, primal variable as y.
        u = self._conv_w(y, params['weight_0'])
        x = y
        for it in range(self.num_iter):
            x, u, _ = self.PD_step(params['weight_' + str(it)], params['ths_' + str(it)],
                                   y, x, u, type_operator)
        return x

    @nn.compact
    def __call__(self, y, type_operator, check=False):
        """Reconstruct from measurements ``y`` using the module's bound parameters.

        ``check`` is currently unused.
        """
        return self._unroll(self.variables['params'], y, type_operator)

    def apply(self, params, y, type_operator, check=False):
        """
        Necessary to rewrite for appropriate backprop through the unfolded parameters.

        ``params`` is a variables dict whose ``'params'`` entry holds the
        ``'weight_<it>'`` / ``'ths_<it>'`` arrays; ``check`` is currently unused.
        """
        return self._unroll(params['params'], y, type_operator)
