import torch
import torch.nn.functional as F
import torch.nn as nn


class Multi_grid_iter(torch.nn.Module):
    """Multigrid-style fixed-point iteration built on a 4-neighbour averaging sweep.

    One sweep (``base_step``) replaces every value by the mean of its four
    axis neighbours plus ``h**2 * f / 4`` and then blends the result with
    ``b`` through the mask ``G`` (``x * G + (1 - G) * b``), so entries where
    ``G == 1`` keep the update and entries where ``G == 0`` take ``b``.
    ``multi_grid_step`` wraps that sweep in a recursive coarse-grid
    correction.

    The stencils are ``(1, 1, 3, 3)`` convolution kernels, so fields must be
    single-channel tensors accepted by ``F.conv2d`` (e.g. ``(N, 1, H, W)``).
    NOTE(review): ``h`` is derived from the last dimension only and the
    interpolation produces square outputs, so square grids of size
    ``2**k + 1`` are presumably assumed -- confirm with callers.
    """

    def __init__(self, error_threshold=0.0001, max_iter_num=200000, down_layer_num=2, smoothing_num=2, use_cuda=True):
        """Store the iteration settings and build the two fixed stencils.

        Args:
            error_threshold: ``forward`` stops once the largest absolute
                change between two successive iterates drops below this.
            max_iter_num: upper bound on the number of ``forward`` iterations.
            down_layer_num: number of grid levels used by ``multi_grid_step``
                (``Gs`` must provide one mask per level).
            smoothing_num: number of ``base_step`` sweeps before and after
                the coarse correction on every level.
            use_cuda: if true, the stencils are moved to the GPU up front.
        """
        super(Multi_grid_iter, self).__init__()
        self.error_threshold = error_threshold
        self.max_iter_num = max_iter_num
        self.down_layer_num = down_layer_num
        self.smoothing_num = smoothing_num
        # Mean of the four axis neighbours (centre weight is zero).
        self.Kernel = torch.Tensor([[[[0, 1 / 4, 0], [1 / 4, 0, 1 / 4], [0, 1 / 4, 0]]]])
        # Weighted average of the centre and its four axis neighbours; the
        # weights sum to one.
        self.restriction_Kernel = torch.Tensor([[[[0, 1, 0],
                                                  [1, 4, 1],
                                                  [0, 1, 0]]]]) / 8.0

        self.use_cuda = use_cuda
        if use_cuda:
            self.Kernel = self.Kernel.cuda()
            self.restriction_Kernel = self.restriction_Kernel.cuda()

    @staticmethod
    def _match(kernel, x):
        """Return ``kernel`` on the device and in the dtype of ``x``.

        The stencils are plain attributes (not parameters or buffers), so
        ``Module.to`` / ``Module.double`` never touch them; converting at the
        point of use keeps ``conv2d`` working for any input.  ``Tensor.to``
        returns the tensor itself when nothing has to change.
        """
        return kernel.to(device=x.device, dtype=x.dtype)

    def base_step(self, x, G, b, f):
        """Apply one averaging sweep followed by the ``G``/``b`` blend.

        Returns ``(avg4(x) + h**2 * f / 4) * G + (1 - G) * b`` where
        ``h = 1 / (x.shape[-1] - 1)`` and ``avg4`` is the zero-padded
        4-neighbour mean.
        """
        h = 1.0 / (x.shape[-1] - 1)
        x = F.conv2d(x, self._match(self.Kernel, x), padding=1)
        x = x + h * h * f * 0.25
        return x * G + (1 - G) * b

    def residual(self, x, G, b, f):
        """Return the change one ``base_step`` would make: ``base_step(x) - x``."""
        return self.base_step(x, G, b, f) - x

    def restriction(self, x, G):
        """Transfer ``x`` to the next coarser grid and mask it with ``G``.

        The outermost ring of ``x`` is cropped, the remainder is filtered
        with the restriction stencil at stride 2, and the result is
        zero-padded by one cell on every side before the mask is applied.
        """
        inner = x[:, :, 1:-1, 1:-1]
        coarse = F.conv2d(inner, self._match(self.restriction_Kernel, x), stride=2)
        coarse = F.pad(coarse, (1, 1, 1, 1))
        return coarse * G

    def interpolation(self, x, G):
        """Bilinearly upsample ``x`` to size ``2 * n - 1`` and mask it with ``G``.

        ``n`` is the last dimension of ``x``; the same size is used for both
        spatial dimensions.
        """
        new_size = x.size()[-1] * 2 - 1
        fine = F.interpolate(x, size=new_size, mode='bilinear', align_corners=True)
        return fine * G

    def multi_grid_step(self, x, Gs, b, f, layer):
        """Run one recursive cycle starting at grid level ``layer``.

        The cycle is: ``smoothing_num`` sweeps, then (unless this is the
        coarsest level) a correction obtained by recursing on the restricted
        residual with a zero initial guess and zero ``b``, then another
        ``smoothing_num`` sweeps.

        Args:
            x: current iterate on level ``layer``.
            Gs: sequence of masks, one per level (index 0 is the finest).
            b: values blended in where the mask is zero.
            f: right-hand-side term of the sweep.
            layer: index of the current level.
        """
        G = Gs[layer]
        for _ in range(self.smoothing_num):
            x = self.base_step(x, G, b, f)

        if layer + 1 < self.down_layer_num:
            res = self.restriction(self.residual(x, G, b, f), Gs[layer + 1])
            # zeros_like keeps the device and dtype of the residual and does
            # not draw from the global random number generator.
            coarse_x = torch.zeros_like(res)
            coarse_b = torch.zeros_like(res)
            # NOTE(review): the restricted residual is used unscaled as the
            # coarse right-hand side, where base_step multiplies it by
            # h_coarse**2 / 4.  A textbook coarse-grid correction would scale
            # it differently -- confirm this is intended.
            ek = self.multi_grid_step(coarse_x, Gs, coarse_b, res, layer + 1)
            x = x + self.interpolation(ek, G)

        for _ in range(self.smoothing_num):
            x = self.base_step(x, G, b, f)
        return x

    def forward(self, U_0, Gs, bs, fs):
        """Iterate until the update is below ``error_threshold``.

        Every iteration performs one ``base_step`` followed by one
        ``multi_grid_step`` on the finest level, using ``bs[0]`` and
        ``fs[0]``.

        Returns:
            ``(U, iter_num)``: the last iterate and the number of iterations
            performed (``max_iter_num`` if the threshold was never reached).
            When ``max_iter_num`` is 0 the input ``U_0`` is returned as is.
        """
        G, b, f = Gs[0], bs[0], fs[0]
        iter_num = self.max_iter_num
        # Bind the result up front so an empty loop does not leave it undefined.
        U_1 = U_0
        for i in range(self.max_iter_num):
            U_1 = self.base_step(U_0, G, b, f)
            U_1 = self.multi_grid_step(U_1, Gs, b, f, 0)
            error = (U_1 - U_0).abs().max().item()
            if error < self.error_threshold:
                iter_num = i + 1
                break
            U_0 = U_1
        return U_1, iter_num



class Unet_iter(torch.nn.Module):
    """Averaging sweep combined with a trainable U-shaped correction ``H``.

    Every iteration applies the fixed 4-neighbour averaging sweep (the same
    update as ``U <- (avg4(U) + h**2 * f / 4) * G + (1 - G) * b``), feeds the
    resulting change ``w = U_1 - U_0`` through the convolutional stack ``H``,
    blends that output into a running term ``z`` with weight ``alpha`` and
    adds ``G * z`` to the swept field.

    All convolutions are single-channel 3x3 layers without bias, so fields
    must be single-channel tensors accepted by ``conv2d``
    (e.g. ``(N, 1, H, W)``).
    NOTE(review): ``h`` uses the last dimension only and the upsampling in
    ``H`` produces square outputs, so square grids of size ``2**k + 1`` are
    presumably assumed -- confirm with callers.
    """

    def __init__(self, error_threshold=0.0001, max_iter_num=200000, down_layer_num=2, smoothing_num=2, use_cuda=True, alpha=0.8):
        """Store the settings and create the convolution layers of ``H``.

        Args:
            error_threshold: ``evaluation`` stops once the largest absolute
                change between successive iterates drops below this.
            max_iter_num: upper bound on the iterations of ``evaluation``.
            down_layer_num: number of resolution levels in ``H`` (``Gs`` must
                provide one mask per level).
            smoothing_num: number of 3x3 convolutions per level on each of
                the two paths of ``H``.
            use_cuda: if true, the fixed averaging stencil is moved to the
                GPU up front (the trainable layers are not moved here).
            alpha: weight of the new ``H`` output when updating ``z``.
        """
        super(Unet_iter, self).__init__()
        self.error_threshold = error_threshold
        self.max_iter_num = max_iter_num
        self.down_layer_num = down_layer_num
        self.smoothing_num = smoothing_num
        self.alpha = alpha
        # Mean of the four axis neighbours (centre weight is zero).
        self.Kernel = torch.Tensor([[[[0, 1 / 4, 0], [1 / 4, 0, 1 / 4], [0, 1 / 4, 0]]]])
        self.use_cuda = use_cuda
        if use_cuda:
            self.Kernel = self.Kernel.cuda()

        layers_per_path = down_layer_num * smoothing_num
        # Size-preserving convolutions used while going to coarser levels.
        self.first_layers = nn.ModuleList(
            [nn.Conv2d(1, 1, 3, stride=1, padding=1, bias=False) for _ in range(layers_per_path)]
        )
        # Stride-2 convolutions that move from one level to the next coarser one.
        self.pooling_layers = nn.ModuleList(
            [nn.Conv2d(1, 1, 3, stride=2, padding=0, bias=False) for _ in range(down_layer_num - 1)]
        )
        # Size-preserving convolutions used while going back to finer levels.
        self.second_layers = nn.ModuleList(
            [nn.Conv2d(1, 1, 3, stride=1, padding=1, bias=False) for _ in range(layers_per_path)]
        )

    def H(self, r, Gs):
        """Apply the U-shaped convolution stack to ``r``.

        Downward path, for each level ``i`` from finest to coarsest:
        ``smoothing_num`` convolutions each masked by ``Gs[i]``; the result
        is remembered; then (except on the coarsest level) it is zero-padded
        by one cell, passed through a stride-2 convolution and masked by
        ``Gs[i + 1]``.

        Upward path, for each level from coarsest to finest:
        ``smoothing_num`` convolutions each masked by that level's mask; the
        remembered tensor of the same level is added; then (except on the
        finest level) the result is bilinearly upsampled to ``2 * n - 1``
        and masked by the next finer mask.

        Returns a tensor with the spatial size of the finest level.
        """
        levels = self.down_layer_num
        sweeps = self.smoothing_num

        intermediates = []
        for i in range(levels):
            for j in range(sweeps):
                r = self.first_layers[i * sweeps + j](r) * Gs[i]
            intermediates.append(r)
            if i + 1 < levels:
                r = F.pad(r, (1, 1, 1, 1))
                r = self.pooling_layers[i](r) * Gs[i + 1]

        for i in range(levels):
            # second_layers is indexed in processing order (coarsest first),
            # while masks and skip tensors are indexed by level.
            level = levels - i - 1
            for j in range(sweeps):
                r = self.second_layers[i * sweeps + j](r) * Gs[level]
            r = r + intermediates[level]
            if i + 1 < levels:
                new_size = r.size(-1) * 2 - 1
                r = F.interpolate(r, size=new_size, mode='bilinear', align_corners=True) * Gs[level - 1]
        return r

    def _corrected_step(self, U_0, z, Gs, bs, fs, h):
        """Return ``(U_next, z_next)`` for one iteration.

        The averaging stencil is converted to the device and dtype of
        ``U_0`` at the point of use: it is a plain attribute, so
        ``Module.to`` / ``Module.double`` do not move it together with the
        trainable layers.
        """
        kernel = self.Kernel.to(device=U_0.device, dtype=U_0.dtype)
        U_1 = F.conv2d(U_0, kernel, padding=1)
        U_1 = U_1 + h * h * fs[0] * 0.25
        U_1 = U_1 * Gs[0] + (1 - Gs[0]) * bs[0]
        w = U_1 - U_0
        z = self.H(w, Gs) * self.alpha + (1 - self.alpha) * z
        return U_1 + Gs[0] * z, z

    def forward(self, U_0, Gs, bs, fs, iter_num):
        """Run exactly ``iter_num`` corrected iterations and return the result.

        Args:
            U_0: initial field.
            Gs: sequence of masks, one per level of ``H`` (index 0 is the
                finest and is also used for the sweep).
            bs: only ``bs[0]`` is read; blended in where ``Gs[0]`` is zero.
            fs: only ``fs[0]`` is read; right-hand-side term of the sweep.
            iter_num: number of iterations; with 0 the input is returned.
        """
        h = 1.0 / (U_0.shape[-1] - 1)
        # zeros_like already matches the device and dtype of U_0.
        z = torch.zeros_like(U_0)
        for _ in range(iter_num):
            U_0, z = self._corrected_step(U_0, z, Gs, bs, fs, h)
        return U_0

    def evaluation(self, U_0, Gs, bs, fs):
        """Iterate until the update is below ``error_threshold``.

        Returns:
            ``(U, iter_num)``: the last iterate and the number of iterations
            performed (``max_iter_num`` if the threshold was never reached).
            When ``max_iter_num`` is 0 the input ``U_0`` is returned as is.
        """
        h = 1.0 / (U_0.shape[-1] - 1)
        iter_num = self.max_iter_num
        z = torch.zeros_like(U_0)
        # Bind the result up front so an empty loop does not leave it undefined.
        U_1 = U_0
        for i in range(self.max_iter_num):
            U_1, z = self._corrected_step(U_0, z, Gs, bs, fs, h)
            error = (U_1 - U_0).abs().max().item()
            if error < self.error_threshold:
                iter_num = i + 1
                break
            U_0 = U_1
        return U_1, iter_num
