import utils.Spectral_Normalize_chen as chen
from operators.single_coil_mri import *
import torch.autograd as autograd


class DnCNN(nn.Module):
    """Residual convolutional denoiser in the DnCNN layout.

    Every convolution is 3x3, bias-free, and wrapped by ``chen.spectral_norm``
    (presumably a spectral-normalisation wrapper from
    ``utils.Spectral_Normalize_chen`` -- verify against that module).
    The stack is::

        conv(channels -> features), ReLU
        (num_of_layers - 2) x [conv(features -> features), BatchNorm2d, ReLU]
        conv(features -> channels)

    The stack's output is treated as a residual: ``forward`` returns
    ``x - dncnn(x)``.

    NOTE(review): the BatchNorm2d layers are not wrapped by spectral_norm, so
    any Lipschitz bound the normalisation is meant to give does not cover
    them -- confirm this is intended.

    Args:
        channels: number of input and output channels.
        num_of_layers: total number of convolutions. Values below 3 give no
            middle layers; the first and last convolutions are always built.
        features: number of channels in the hidden convolutions.
        lip: accepted for interface compatibility; not used by this
            constructor.
    """

    def __init__(self, channels, num_of_layers=17, features=64, lip=1.0):
        super().__init__()

        def sn_conv(in_ch, out_ch):
            # 3x3 kernel with padding 1 (default stride 1) keeps H and W unchanged.
            conv = nn.Conv2d(in_channels=in_ch, out_channels=out_ch,
                             kernel_size=3, padding=1, bias=False)
            return chen.spectral_norm(conv)

        # Layer order is significant: it fixes the ``dncnn.<i>`` indices used
        # as keys in saved state_dicts.
        stack = [sn_conv(channels, features), nn.ReLU(inplace=True)]
        for _ in range(num_of_layers - 2):
            stack += [sn_conv(features, features),
                      nn.BatchNorm2d(features),
                      nn.ReLU(inplace=True)]
        stack.append(sn_conv(features, channels))
        self.dncnn = nn.Sequential(*stack)

    def forward(self, x):
        """Return ``x`` minus the residual predicted by the conv stack."""
        residual = self.dncnn(x)
        return x - residual



class LU_prox(nn.Module):
    """Unrolled gradient-step + learned-regulariser network.

    Each unrolled iteration takes a gradient step and then applies ``R``::

        x <- R(x - exp(eta) * (linear_op.gramian(x) - linear_op.adjoint(y)))

    If ``gramian(z)`` computes ``A^H A z`` (assumed from the name -- confirm in
    the operator implementation), the bracket is the gradient of
    ``0.5 * ||A x - y||^2``.

    The applied step size is ``exp(eta)``, which is always positive.
    ``args.eta`` therefore initialises the *log* of the step size, not the
    step itself.

    Args:
        linear_op: object exposing ``forward``, ``adjoint`` and ``gramian``.
        DnCNN: callable (typically a module) applied after every gradient
            step; stored as ``self.R``.
        args: namespace providing ``eta`` (initial log step size) and
            ``maxiters`` (number of unrolled iterations).
    """

    def __init__(self, linear_op, DnCNN, args):
        super().__init__()
        # Keep this assignment order: it determines the registration order
        # of parameters/submodules (eta, then linear_op, then R).
        self.eta = nn.Parameter(torch.ones(1) * args.eta)
        self.linear_op = linear_op
        self.R = DnCNN
        self.maxLayers = args.maxiters

    def _linear_op(self, x):
        """Apply ``linear_op.forward`` to ``x``."""
        return self.linear_op.forward(x)

    def _linear_adjoint(self, x):
        """Apply ``linear_op.adjoint`` to ``x``."""
        return self.linear_op.adjoint(x)

    def set_initial_point(self, y):
        """Store ``adjoint(y)`` as ``self.initial_point``.

        ``forward`` in this class does not read it; callers may use it as the
        starting iterate.
        """
        self.initial_point = self._linear_adjoint(y)

    def get_gradient(self, z, y):
        """Return ``gramian(z) - adjoint(y)``."""
        gram_term = self.linear_op.gramian(z)
        return gram_term - self._linear_adjoint(y)

    def forward_module(self, x, y):
        """Gradient step ``x - exp(eta) * get_gradient(x, y)``."""
        step_size = torch.exp(self.eta)
        return x - step_size * self.get_gradient(x, y)

    def forward(self, xk, y):
        """Run ``maxLayers`` unrolled iterations starting from ``xk``."""
        for _ in range(self.maxLayers):
            xk = self.forward_singlestep(xk, y)
        return xk

    def forward_singlestep(self, xk, y):
        """One unrolled iteration: gradient step followed by ``R``."""
        stepped = self.forward_module(xk, y)
        return self.R(stepped)


class LU_prox_CT(nn.Module):
    """Unrolled gradient-step + learned-regulariser network (CT variant).

    This is the same scheme as the single-coil version::

        x <- R(x - exp(eta) * (linear_op.gramian(x) - linear_op.adjoint(y)))

    The difference is that ``R`` receives the stepped iterate with a singleton
    channel axis inserted at dim 1, and channel 0 of ``R``'s output is kept.
    The ``[:, None, :, :]`` indexing presumes 3-D iterates (e.g. batch, H, W)
    -- confirm against callers.

    The applied step size is ``exp(eta)``. ``args.eta`` therefore initialises
    the *log* of the step size.

    Args:
        linear_op: object exposing ``forward``, ``adjoint`` and ``gramian``.
        DnCNN: callable (typically a module) taking a channel axis at dim 1;
            stored as ``self.R``.
        args: namespace providing ``eta`` (initial log step size) and
            ``maxiters`` (number of unrolled iterations).
    """

    def __init__(self, linear_op, DnCNN, args):
        super().__init__()
        # Keep this assignment order: it determines the registration order
        # of parameters/submodules (eta, then linear_op, then R).
        self.eta = nn.Parameter(torch.ones(1) * args.eta)
        self.linear_op = linear_op
        self.R = DnCNN
        self.maxLayers = args.maxiters

    def _linear_op(self, x):
        """Apply ``linear_op.forward`` to ``x``."""
        return self.linear_op.forward(x)

    def _linear_adjoint(self, x):
        """Apply ``linear_op.adjoint`` to ``x``."""
        return self.linear_op.adjoint(x)

    def set_initial_point(self, y):
        """Store ``adjoint(y)`` as ``self.initial_point``.

        ``forward`` in this class does not read it; callers may use it as the
        starting iterate.
        """
        self.initial_point = self._linear_adjoint(y)

    def get_gradient(self, z, y):
        """Return ``gramian(z) - adjoint(y)``."""
        gram_term = self.linear_op.gramian(z)
        return gram_term - self._linear_adjoint(y)

    def forward_module(self, x, y):
        """Gradient step ``x - exp(eta) * get_gradient(x, y)``."""
        step_size = torch.exp(self.eta)
        return x - step_size * self.get_gradient(x, y)

    def _denoise(self, z):
        """Apply ``R`` after adding a channel axis; keep output channel 0."""
        with_channel = z[:, None, :, :]
        return self.R(with_channel)[:, 0, :, :]

    def forward(self, xk, y):
        """Run ``maxLayers`` unrolled iterations starting from ``xk``."""
        for _ in range(self.maxLayers):
            xk = self.forward_singlestep(xk, y)
        return xk

    def forward_singlestep(self, xk, y):
        """One unrolled iteration: gradient step followed by ``R``."""
        stepped = self.forward_module(xk, y)
        return self._denoise(stepped)
