import utils.Spectral_Normalize_chen as chen
from operators.single_coil_mri import *


class single_layer(nn.Module):
    """Convolutional update network that fuses an injected feature map with features of ``xk``.

    ``data_layer`` maps ``xk`` (``num_channel`` channels) to ``num_features // 2`` channels.
    That output is concatenated with ``inj`` along dim 1 and passed to ``mixing_layer``, whose
    first convolution expects ``num_features`` input channels. So ``inj`` must supply the
    remaining ``num_features - num_features // 2`` channels. ``mixing_layer`` ends with a
    projection back to ``num_channel`` channels.

    Every convolution is wrapped in ``nn.utils.spectral_norm``. ``GroupNorm`` uses 4 groups, so
    ``num_features`` and ``num_features // 2`` must be divisible by 4.
    """

    def __init__(self, num_channel=1, num_features=64, kernel_size=7, stride=1, padding=3, img_dim=320):
        super().__init__()
        half = num_features // 2

        def sn_conv(in_channels, out_channels):
            # Spectrally-normalised conv sharing this layer's kernel_size / stride / padding.
            return nn.utils.spectral_norm(nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding))

        # Feature extractor for the current iterate xk. Layer creation order determines RNG
        # consumption, so keep it stable to reproduce seeded initialisations.
        self.data_layer = nn.Sequential(
            sn_conv(num_channel, half),
            nn.GroupNorm(4, half),
            nn.LeakyReLU(0.1, inplace=True),
            sn_conv(half, half),
            nn.GroupNorm(4, half),
            nn.LeakyReLU(0.1, inplace=True),
        )

        # Mixing layer: two conv -> GroupNorm -> LeakyReLU stages, then a projection to num_channel.
        stages = []
        for _ in range(2):
            stages.extend([
                sn_conv(num_features, num_features),
                nn.GroupNorm(4, num_features),
                nn.LeakyReLU(0.1, inplace=True),
            ])
        stages.append(sn_conv(num_features, num_channel))
        self.mixing_layer = nn.Sequential(*stages)

        self._initialize_weights()
        self.img_dim = img_dim  # stored only; not read by this class
        self.nc = num_channel

    def _initialize_weights(self):
        """Kaiming-normal init for every Conv2d weight; ones/zeros for BatchNorm2d affine params."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                # NOTE(review): under spectral_norm, ``weight`` is a tensor sharing storage with
                # ``weight_orig`` in current PyTorch, so this in-place init should reach the
                # trainable parameter -- confirm on the torch version in use.
                nn.init.kaiming_normal_(module.weight)
            elif isinstance(module, nn.BatchNorm2d):
                # This class builds GroupNorm, not BatchNorm2d, so this branch never fires here.
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)

    def forward(self, inj, xk):
        """Return ``mixing_layer(cat([data_layer(xk), inj], dim=1))``."""
        xk_features = self.data_layer(xk)
        fused = torch.cat((xk_features, inj), dim=1)
        return self.mixing_layer(fused)


class DEQIPFixedPoint(nn.Module):
    """Deep-equilibrium wrapper that solves ``x = g(x, y)`` for a given ``y``.

    ``anderson`` is the fixed-point solver. It is called as ``anderson(f, x_init, **kwargs)``
    and must return a 2-tuple whose first item is the solution; the second item is discarded.
    The solve runs under ``torch.no_grad()``. When ``train`` is true, one extra application of
    ``g`` re-attaches the result to the autograd graph.
    """

    def __init__(self, g, anderson, **kwargs):
        super().__init__()
        self.g = g
        self.solver = anderson
        self.kwargs = kwargs  # forwarded verbatim to the solver on every call

    def forward(self, y, x0, train=True, zero_init=False):
        """Return the (approximate) fixed point of ``x -> g(x, y)``.

        Args:
            y: second argument passed to ``g`` on every evaluation.
            x0: initial iterate for the solver (replaced by zeros of the same shape if
                ``zero_init``).
            train: if true, apply ``g`` once more outside ``no_grad`` so gradients flow.
            zero_init: start the solver from ``torch.zeros_like(x0)`` instead of ``x0``.
        """
        with torch.no_grad():
            start = torch.zeros_like(x0) if zero_init else x0
            solution, _residual = self.solver(lambda x: self.g(x, y), start, **self.kwargs)

        if not train:
            return solution
        # One differentiable step from the detached equilibrium to attach gradients.
        return self.g(solution, y)


class DEQFixedPoint(nn.Module):
    """Deep-equilibrium layer that solves ``x = g(inj, x)`` with ``inj = inputInj(x0)``.

    ``inputInj`` is a two-layer spectrally-normalised conv stack (no GroupNorm) mapping
    ``in_channels`` to ``out_channels``. The solver ``anderson`` is called as
    ``anderson(f, torch.zeros_like(x0), **kwargs)`` under ``torch.no_grad()``. It must return a
    2-tuple; the second item is kept on ``self.forward_res``.
    """

    def __init__(self, g, anderson, in_channels=4, out_channels=32, kernel_size=3, stride=1, padding=1, **kwargs):
        super().__init__()
        # ``g`` is registered before ``inputInj``. If ``g`` is an nn.Module, the
        # ``self.modules()`` walk in ``_initialize_weights`` therefore re-initialises its Conv2d
        # layers first.
        self.g = g
        self.solver = anderson
        self.kwargs = kwargs

        def sn_conv(c_in, c_out):
            return nn.utils.spectral_norm(nn.Conv2d(c_in, c_out, kernel_size, stride, padding))

        # No GroupNorm in this injection network (unlike DEQFixedPoint_MRI.inputInj).
        self.inputInj = nn.Sequential(
            sn_conv(in_channels, out_channels),
            nn.LeakyReLU(0.1, inplace=True),
            sn_conv(out_channels, out_channels),
            nn.LeakyReLU(0.1, inplace=True),
        )
        self._initialize_weights()

    def _initialize_weights(self):
        """Kaiming-normal init for every Conv2d weight (including any inside ``g``)."""
        for layer in self.modules():
            if isinstance(layer, nn.Conv2d):
                nn.init.kaiming_normal_(layer.weight)
            elif isinstance(layer, nn.BatchNorm2d):
                # No BatchNorm2d is built by this class itself; only reachable through ``g``.
                nn.init.ones_(layer.weight)
                nn.init.zeros_(layer.bias)

    def forward(self, x0, train):
        """Solve for the equilibrium starting from zeros. If ``train``, add one differentiable ``g`` step."""
        with torch.no_grad():
            injection = self.inputInj(x0)
            solution, self.forward_res = self.solver(
                lambda x: self.g(injection, x), torch.zeros_like(x0), **self.kwargs
            )

        if train:
            # Recompute the injection with autograd enabled and take one step of g.
            solution = self.g(self.inputInj(x0), solution)
        return solution


class inverse_block_mri_prox(nn.Module):
    """One unrolled gradient step followed by a residual learned correction.

    ``forward(xk, y, k, train)`` computes::

        z       = xk - exp(eta_k) * (linear_op.gramian(xk) - linear_op.adjoint(y))
        x_{k+1} = R(z, train) + z

    So the learned module ``R`` is applied *after* the data-consistency step, as a residual
    correction of ``z``. ``eta`` is stored in log-space (the step size is ``exp(eta)``). It is a
    single scalar when ``args.shared_eta`` is true; otherwise it holds ``args.maxiters`` values
    and ``eta[k]`` is used for iteration ``k``.
    """

    def __init__(self, linear_op, deq_inner, args):
        super(inverse_block_mri_prox, self).__init__()
        self.shared_eta = args.shared_eta
        if args.shared_eta:
            self.eta = nn.Parameter(torch.ones(1) * args.eta)
        else:
            # One log-step per unrolled iteration; indexed by ``k`` in ``forward_module``.
            self.eta = nn.Parameter(torch.ones(args.maxiters) * args.eta)
        self.linear_op = linear_op
        self.R = deq_inner

    def _linear_op(self, x):
        """Apply the forward operator ``linear_op.forward``."""
        return self.linear_op.forward(x)

    def _linear_adjoint(self, x):
        """Apply the adjoint operator ``linear_op.adjoint``."""
        return self.linear_op.adjoint(x)

    def set_initial_point(self, y):
        """Store ``adjoint(y)`` on ``self.initial_point`` (not read by any method of this class)."""
        self.initial_point = self._linear_adjoint(y)

    def get_gradient(self, z, y):
        """Return ``gramian(z) - adjoint(y)``.

        Presumably the gradient of ``0.5 * ||A z - y||^2`` when ``gramian`` is ``A^H A``;
        confirm against ``linear_op``.
        """
        return self.linear_op.gramian(z) - self._linear_adjoint(y)

    def forward_module(self, x, y, k):
        """Gradient step ``x - exp(eta_k) * get_gradient(x, y)``; ``k`` selects eta when not shared."""
        eta = self.eta[0] if self.shared_eta else self.eta[k]
        return x - torch.exp(eta) * self.get_gradient(x, y)

    def forward(self, xk, y, k, train):
        """Return ``R(z, train) + z`` where ``z = forward_module(xk, y, k)``."""
        forward_res = self.forward_module(xk, y, k)
        return self.R(forward_res, train) + forward_res


class inverse_block_mri(nn.Module):
    """Gradient step whose gradient includes a learned term.

    ``forward(xk, y)`` returns ``xk - eta * (gramian(xk) - adjoint(y) - R(xk))``. Unlike the
    other blocks in this module, ``eta`` is used directly as the step size (not ``exp(eta)``).
    """

    def __init__(self, linear_op, DnCNN, args):
        super().__init__()
        self.eta = nn.Parameter(torch.ones(1) * args.eta)
        self.linear_op = linear_op
        self.R = DnCNN  # called with a single argument in get_gradient

    def _linear_op(self, x):
        """Apply ``linear_op.forward``."""
        return self.linear_op.forward(x)

    def _linear_adjoint(self, x):
        """Apply ``linear_op.adjoint``."""
        return self.linear_op.adjoint(x)

    def set_initial_point(self, y):
        """Store ``adjoint(y)`` on ``self.initial_point``."""
        self.initial_point = self._linear_adjoint(y)

    def get_gradient(self, z, y):
        """Return ``gramian(z) - adjoint(y) - R(z)``."""
        data_term = self.linear_op.gramian(z) - self._linear_adjoint(y)
        return data_term - self.R(z)

    def forward(self, xk, y):
        """One step ``xk - eta * get_gradient(xk, y)``."""
        step = self.eta * self.get_gradient(xk, y)
        return xk - step


class single_layer_CT(nn.Module):
    """CT variant of the fused convolutional update network (same layer layout as ``single_layer``).

    ``xk`` (``num_channel`` channels) goes through ``data_layer`` and becomes
    ``num_features // 2`` channels. That output is concatenated with ``inj`` on dim 1 and mixed
    back to ``num_channel`` channels, so ``inj`` must carry
    ``num_features - num_features // 2`` channels. All convolutions are spectrally normalised.
    """

    def __init__(self, num_channel=1, num_features=64, kernel_size=7, stride=1, padding=3, img_dim=320):
        super().__init__()
        width = num_features
        half_width = width // 2

        def conv(c_in, c_out):
            return nn.utils.spectral_norm(nn.Conv2d(c_in, c_out, kernel_size, stride, padding))

        def norm_act(channels):
            # GroupNorm with 4 groups followed by LeakyReLU(0.1); channels must be divisible by 4.
            return [nn.GroupNorm(4, channels), nn.LeakyReLU(0.1, inplace=True)]

        self.data_layer = nn.Sequential(
            conv(num_channel, half_width), *norm_act(half_width),
            conv(half_width, half_width), *norm_act(half_width),
        )

        # Two conv -> GroupNorm -> LeakyReLU stages (author's note: 3 stages may give better
        # results), followed by a projection to num_channel.
        mixing = []
        for _ in range(2):
            mixing.append(conv(width, width))
            mixing.extend(norm_act(width))
        mixing.append(conv(width, num_channel))
        self.mixing_layer = nn.Sequential(*mixing)

        self._initialize_weights()
        self.img_dim = img_dim  # stored only; not read by this class
        self.nc = num_channel

    def _initialize_weights(self):
        """Kaiming-normal init for Conv2d weights; ones/zeros for BatchNorm2d (none are built here)."""
        for sub in self.modules():
            if isinstance(sub, nn.Conv2d):
                nn.init.kaiming_normal_(sub.weight)
                continue
            if isinstance(sub, nn.BatchNorm2d):
                nn.init.ones_(sub.weight)
                nn.init.zeros_(sub.bias)

    def forward(self, inj, xk):
        """Return ``mixing_layer(cat([data_layer(xk), inj], dim=1))``."""
        return self.mixing_layer(torch.cat((self.data_layer(xk), inj), dim=1))


class LUSER_SW_CT(nn.Module):
    """Unrolled gradient scheme with a single shared regulariser ``R`` (CT setting).

    Each of the ``args.maxiters`` iterations computes
    ``z = x - exp(eta) * (gramian(x) - adjoint(y))``. It inserts a channel axis at dim 1 if
    ``z`` has fewer than 4 dims, then keeps channel 0 of ``R(z, train) + z``.
    """

    def __init__(self, linear_op, deq, args):
        super().__init__()
        self.eta = nn.Parameter(torch.ones(1) * args.eta)  # log step size
        self.linear_op = linear_op
        self.R = deq
        self.maxiters = args.maxiters

    def _linear_op(self, x):
        """Apply ``linear_op.forward``."""
        return self.linear_op.forward(x)

    def _linear_adjoint(self, x):
        """Apply ``linear_op.adjoint``."""
        return self.linear_op.adjoint(x)

    def set_initial_point(self, y):
        """Store ``adjoint(y)`` on ``self.initial_point``."""
        self.initial_point = self._linear_adjoint(y)

    def get_gradient(self, z, y):
        """Return ``gramian(z) - adjoint(y)``."""
        normal_term = self.linear_op.gramian(z)
        return normal_term - self._linear_adjoint(y)

    def forward_module(self, x, y):
        """Gradient step ``x - exp(eta) * get_gradient(x, y)``."""
        step_size = torch.exp(self.eta)
        return x - step_size * self.get_gradient(x, y)

    def forward(self, x0, y, train):
        """Run ``maxiters`` single steps starting from ``x0``."""
        xk = x0
        for step in range(self.maxiters):
            xk = self.forward_singlestep(xk, y, step, train)
        return xk

    def forward_singlestep(self, x0, y, i, train):
        """One gradient step plus residual regularisation; ``i`` is accepted but unused."""
        z = self.forward_module(x0, y)
        if z.dim() < 4:
            z = z.unsqueeze(1)
        refined = self.R(z, train) + z
        return refined[:, 0, :, :]


class DEQFixedPoint_MRI(nn.Module):
    """MRI deep-equilibrium layer: solves ``x = g(inj, x)`` with ``inj = inputInj(x0)``.

    ``inputInj`` is spectral-normalised conv -> GroupNorm(4) -> LeakyReLU, applied twice
    (``in_channels`` -> ``out_channels`` -> ``out_channels``). ``out_channels`` must therefore be
    divisible by 4. The solver is called under ``torch.no_grad()`` starting from
    ``torch.zeros_like(x0)``; its second return value is kept on ``self.forward_res``.
    """

    def __init__(self, g, anderson, in_channels=4, out_channels=32, kernel_size=3, stride=1, padding=1, **kwargs):
        super().__init__()
        # Registered before inputInj, so any Conv2d inside ``g`` (if it is an nn.Module) is
        # visited first by ``_initialize_weights``.
        self.g = g
        self.solver = anderson
        self.kwargs = kwargs

        # The GroupNorm layers are enabled; the author's note associates them with LUSER-DW-MRI.
        layers = [
            nn.utils.spectral_norm(nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)),
            nn.GroupNorm(4, out_channels),
            nn.LeakyReLU(0.1, inplace=True),
            nn.utils.spectral_norm(nn.Conv2d(out_channels, out_channels, kernel_size, stride, padding)),
            nn.GroupNorm(4, out_channels),
            nn.LeakyReLU(0.1, inplace=True),
        ]
        self.inputInj = nn.Sequential(*layers)
        self._initialize_weights()

    def _initialize_weights(self):
        """Kaiming-normal init for every Conv2d weight reachable from this module."""
        for mod in self.modules():
            if isinstance(mod, nn.Conv2d):
                nn.init.kaiming_normal_(mod.weight)
            elif isinstance(mod, nn.BatchNorm2d):
                # This class builds GroupNorm, so this only applies to BatchNorm2d inside ``g``.
                nn.init.ones_(mod.weight)
                nn.init.zeros_(mod.bias)

    def forward(self, x0, train):
        """Return the equilibrium; if ``train``, add one differentiable ``g`` step."""
        with torch.no_grad():
            injected = self.inputInj(x0)
            fixed_point, self.forward_res = self.solver(
                lambda state: self.g(injected, state), torch.zeros_like(x0), **self.kwargs
            )

        if not train:
            return fixed_point
        # Re-run the injection with autograd enabled, then one step of g to attach gradients.
        return self.g(self.inputInj(x0), fixed_point)


class inverse_block_full(nn.Module):
    """Unrolled gradient scheme with one regulariser per iteration (``R_list[i]`` at step ``i``).

    Each step computes ``z = x - exp(eta) * (gramian(x) - adjoint(y))`` and then
    ``x = R_list[i](z, train) + z``. ``deq_list`` needs at least ``args.maxiters`` entries.
    """

    def __init__(self, linear_op, deq_list, args):
        super().__init__()
        self.eta = nn.Parameter(torch.ones(1) * args.eta)  # log step size
        self.linear_op = linear_op
        self.R_list = nn.ModuleList(deq_list)
        self.maxiters = args.maxiters

    def _linear_op(self, x):
        """Apply ``linear_op.forward``."""
        return self.linear_op.forward(x)

    def _linear_adjoint(self, x):
        """Apply ``linear_op.adjoint``."""
        return self.linear_op.adjoint(x)

    def set_initial_point(self, y):
        """Store ``adjoint(y)`` on ``self.initial_point``."""
        self.initial_point = self._linear_adjoint(y)

    def get_gradient(self, z, y):
        """Return ``gramian(z) - adjoint(y)``."""
        return self.linear_op.gramian(z) - self._linear_adjoint(y)

    def forward_module(self, x, y):
        """Gradient step ``x - exp(eta) * get_gradient(x, y)``."""
        scale = torch.exp(self.eta)
        return x - scale * self.get_gradient(x, y)

    def forward(self, x0, y, train):
        """Cache ``adjoint(y)`` on ``self.initial_point``, then run ``maxiters`` steps from ``x0``."""
        self.set_initial_point(y)
        estimate = x0
        for idx in range(self.maxiters):
            estimate = self.forward_singlestep(estimate, y, idx, train)
        return estimate

    def forward_singlestep(self, x0, y, i, train):
        """One gradient step followed by the residual regulariser ``R_list[i]``."""
        z = self.forward_module(x0, y)
        return self.R_list[i](z, train) + z


class inverse_block_reverse_full(nn.Module):
    """Eight-stage unrolled gradient scheme with regularisers ``R1`` ... ``R8``.

    Stage ``j`` computes ``z = x - exp(eta) * (gramian(x) - adjoint(y))`` and then
    ``x = Rj(z, train) + z``.
    """

    def __init__(self, linear_op, deq1, deq2, deq3, deq4, deq5, deq6, deq7, deq8, args):
        super().__init__()
        self.eta = nn.Parameter(torch.ones(1) * args.eta)  # log step size
        self.linear_op = linear_op
        # Register the stage regularisers as attributes R1..R8, in order. The attribute names
        # form the state_dict keys, so they must not change.
        stage_modules = (deq1, deq2, deq3, deq4, deq5, deq6, deq7, deq8)
        for idx, deq in enumerate(stage_modules, start=1):
            setattr(self, f"R{idx}", deq)

    def _linear_op(self, x):
        """Apply ``linear_op.forward``."""
        return self.linear_op.forward(x)

    def _linear_adjoint(self, x):
        """Apply ``linear_op.adjoint``."""
        return self.linear_op.adjoint(x)

    def set_initial_point(self, y):
        """Store ``adjoint(y)`` on ``self.initial_point``."""
        self.initial_point = self._linear_adjoint(y)

    def get_gradient(self, z, y):
        """Return ``gramian(z) - adjoint(y)``."""
        return self.linear_op.gramian(z) - self._linear_adjoint(y)

    def forward_module(self, x, y):
        """Gradient step ``x - exp(eta) * get_gradient(x, y)``."""
        return x - torch.exp(self.eta) * self.get_gradient(x, y)

    def forward(self, x0, y, train):
        """Cache ``adjoint(y)``, then apply the eight stages R1..R8 in order.

        Author's shape notes (examples, not enforced): x0 ~ [8, 2, 320, 320],
        y ~ [8, 320, 320, 2].
        """
        self.set_initial_point(y)
        xk = x0
        for regulariser in (self.R1, self.R2, self.R3, self.R4, self.R5, self.R6, self.R7, self.R8):
            z = self.forward_module(xk, y)
            xk = regulariser(z, train) + z
        return xk


class inverse_block_aux(nn.Module):
    """Unrolled gradient scheme with per-iteration regularisers and an overridable depth.

    Like ``inverse_block_full``, but ``forward`` does not cache ``initial_point``. The number of
    iterations is ``maxiters`` when given, otherwise ``args.maxiters``.
    """

    def __init__(self, linear_op, deq_list, args, maxiters=None):
        super().__init__()
        self.eta = nn.Parameter(torch.ones(1) * args.eta)  # log step size
        self.linear_op = linear_op
        self.R_list = nn.ModuleList(deq_list)
        self.maxiters = maxiters if maxiters is not None else args.maxiters

    def _linear_op(self, x):
        """Apply ``linear_op.forward``."""
        return self.linear_op.forward(x)

    def _linear_adjoint(self, x):
        """Apply ``linear_op.adjoint``."""
        return self.linear_op.adjoint(x)

    def set_initial_point(self, y):
        """Store ``adjoint(y)`` on ``self.initial_point``."""
        self.initial_point = self._linear_adjoint(y)

    def get_gradient(self, z, y):
        """Return ``gramian(z) - adjoint(y)``."""
        gram = self.linear_op.gramian(z)
        return gram - self._linear_adjoint(y)

    def forward_module(self, x, y):
        """Gradient step ``x - exp(eta) * get_gradient(x, y)``."""
        return x - torch.exp(self.eta) * self.get_gradient(x, y)

    def forward(self, x0, y, train):
        """Run ``self.maxiters`` steps, using ``R_list[i]`` at step ``i``."""
        current = x0
        for it in range(self.maxiters):
            current = self.forward_singlestep(current, y, it, train)
        return current

    def forward_singlestep(self, x0, y, i, train):
        """One gradient step followed by the residual regulariser ``R_list[i]``."""
        z = self.forward_module(x0, y)
        correction = self.R_list[i](z, train)
        return correction + z


class LUSER_DW_CT(nn.Module):
    """Unrolled gradient scheme with one regulariser per iteration (CT setting).

    Step ``i`` computes ``z = x - exp(eta) * (gramian(x) - adjoint(y))``. It inserts a channel
    axis at dim 1 if ``z`` has fewer than 4 dims, then keeps channel 0 of
    ``R_list[i](z, train) + z``.
    """

    def __init__(self, linear_op, deq_list, args):
        super().__init__()
        self.eta = nn.Parameter(torch.ones(1) * args.eta)  # log step size
        self.linear_op = linear_op
        self.R_list = nn.ModuleList(deq_list)
        self.maxiters = args.maxiters

    def _linear_op(self, x):
        """Apply ``linear_op.forward``."""
        return self.linear_op.forward(x)

    def _linear_adjoint(self, x):
        """Apply ``linear_op.adjoint``."""
        return self.linear_op.adjoint(x)

    def set_initial_point(self, y):
        """Store ``adjoint(y)`` on ``self.initial_point``."""
        self.initial_point = self._linear_adjoint(y)

    def get_gradient(self, z, y):
        """Return ``gramian(z) - adjoint(y)``."""
        return self.linear_op.gramian(z) - self._linear_adjoint(y)

    def forward_module(self, x, y):
        """Gradient step ``x - exp(eta) * get_gradient(x, y)``."""
        return x - torch.exp(self.eta) * self.get_gradient(x, y)

    def forward(self, x0, y, train):
        """Run ``maxiters`` single steps from ``x0``, using ``R_list[i]`` at step ``i``."""
        xk = x0
        for idx in range(self.maxiters):
            xk = self.forward_singlestep(xk, y, idx, train)
        return xk

    def forward_singlestep(self, x0, y, i, train):
        """One gradient step plus ``R_list[i]`` residual; returns channel 0 of the result."""
        z = self.forward_module(x0, y)
        if z.dim() < 4:
            z = z.unsqueeze(1)  # add a channel axis for the conv regulariser
        return (self.R_list[i](z, train) + z)[:, 0, :, :]
