import torch
import torch.nn.functional as F
import torch.nn as nn


class JOR_iter(torch.nn.Module):
    """Iterate a fixed 3x3 stencil update until the iterates stop changing.

    Each step cross-correlates the field with the stencil (zero padding),
    optionally adds a source term ``h * h * f / (2 * a)``, and then keeps the
    update where the mask ``G`` is 1 while substituting ``b`` where ``G`` is 0
    (``U * G + (1 - G) * b``). Iteration stops once the largest absolute
    change between consecutive iterates falls below ``error_threshold``, or
    after ``max_iter_num`` steps.

    NOTE(review): ``G`` presumably marks free (interior) points and ``b``
    holds fixed boundary values -- confirm against callers.

    Args:
        error_threshold: Convergence tolerance on ``max |U_{k+1} - U_k|``.
        max_iter_num: Upper bound on the number of update steps.
        use_cuda: Store the base stencil on the GPU when one is available.
            The stencil is moved to the input's device in ``forward``
            regardless, so this only avoids a per-call transfer.
    """

    def __init__(self, error_threshold=0.00001, max_iter_num=20000, use_cuda=True):
        super(JOR_iter, self).__init__()
        self.error_threshold = error_threshold
        self.max_iter_num = max_iter_num
        self.a = 1
        # Shape (1, 1, 3, 3): single input/output channel. F.conv2d is a
        # cross-correlation, so row 0 weights the neighbour at row i-1,
        # row 2 the neighbour at row i+1, and the middle row weights column
        # j-1 with +1 and column j+1 with -1 (scaled by h in `_stencil`).
        self.Kernel = torch.Tensor([[[[0, 2 * self.a, 0], [1.0, 0, -1.0], [0, 2 * self.a, 0]]]])
        # Guarded so the module can still be built on a CPU-only machine.
        if use_cuda and torch.cuda.is_available():
            self.Kernel = self.Kernel.cuda()

    def _stencil(self, U):
        """Return the normalised stencil matching ``U`` and the spacing ``h``."""
        # NOTE(review): assumes the last axis spans a unit interval with
        # shape[-1] points; a width of 1 would divide by zero.
        h = 1.0 / (U.shape[-1] - 1)
        # Follow the input's device/dtype: `Kernel` is not a registered
        # buffer, so `module.to(...)` / `.double()` never convert it.
        kernel = self.Kernel.to(device=U.device, dtype=U.dtype).clone()
        kernel[0, 0, 1] *= h
        return kernel / (4 * self.a), h

    def _step(self, U, G, b, f, kernel, h):
        """Apply one stencil update, add the source term, re-impose ``b``."""
        U_new = F.conv2d(U, kernel, padding=1)
        if f is not None:
            U_new = U_new + h * h * f / (2 * self.a)
        return U_new * G + (1 - G) * b

    def forward(self, U_0, G, b, f=None):
        """Run the iteration from ``U_0`` until convergence or the step cap.

        Args:
            U_0: Initial field. Must have one channel (e.g. ``(N, 1, H, W)``)
                because the stencil kernel has a single input channel.
            G: Mask broadcastable to ``U_0``; the update is kept where it is 1.
            b: Values imposed where ``G`` is 0.
            f: Optional source term broadcastable to ``U_0``.

        Returns:
            ``(U, iter_num)``. ``U`` is the last iterate (``U_0`` itself when
            ``max_iter_num`` is 0). ``iter_num`` is the number of steps taken
            if converged, otherwise ``max_iter_num``.
        """
        kernel, h = self._stencil(U_0)  # loop-invariant, built once
        U_1 = U_0  # returned as-is if the loop body never runs
        iter_num = self.max_iter_num
        for i in range(self.max_iter_num):
            U_1 = self._step(U_0, G, b, f, kernel, h)
            # .item() synchronises with the device each step; needed for the
            # data-dependent early exit.
            error = (U_1 - U_0).abs().max().item()
            if error < self.error_threshold:
                iter_num = i + 1
                break
            U_0 = U_1
        return U_1, iter_num



class Conv_iter(torch.nn.Module):
    """Stencil iteration augmented with a learned convolutional correction.

    Each step applies the same fixed 3x3 stencil update as a plain iteration:
    cross-correlation with zero padding, an optional source term
    ``h * h * f / (2 * a)``, and ``U * G + (1 - G) * b`` to impose ``b``
    where ``G`` is 0. It then adds ``G * H(U_new - U_old)``, where ``H`` is a
    bias-free stack of 3x3 convolutions.

    Args:
        error_threshold: Convergence tolerance used by ``evaluation``.
        max_iter_num: Step cap used by ``evaluation``.
        Hs: Output channel counts of the successive ``H`` layers (the first
            layer takes 1 input channel). The last entry should be 1, or
            ``Hs`` empty (``H`` is then the identity). Otherwise ``H``'s output
            no longer has the field's single channel and the next stencil
            convolution fails.
        use_cuda: Store the base stencil on the GPU when one is available.
            The stencil follows the input's device/dtype at call time anyway.
    """

    def __init__(self, error_threshold=0.00001, max_iter_num=20000, Hs=(1,), use_cuda=True):
        super(Conv_iter, self).__init__()
        self.error_threshold = error_threshold
        self.max_iter_num = max_iter_num
        self.conv_num = len(Hs) + 1
        self.a = 1
        # Shape (1, 1, 3, 3). F.conv2d is a cross-correlation: row 0 weights
        # row i-1, row 2 weights row i+1, and the middle row weights column
        # j-1 with +1 and column j+1 with -1 (scaled by h in `_stencil`).
        self.Kernel = torch.Tensor([[[[0, 2 * self.a, 0], [1.0, 0, -1.0], [0, 2 * self.a, 0]]]])
        # Guarded so the module can still be built on a CPU-only machine.
        if use_cuda and torch.cuda.is_available():
            self.Kernel = self.Kernel.cuda()
        layers = []
        in_channel = 1
        for out_channel in Hs:
            layers.append(nn.Conv2d(in_channel, out_channel, (3, 3), padding=1, bias=False))
            in_channel = out_channel
        # State-dict keys are "H.<idx>.weight".
        self.H = nn.Sequential(*layers)

    def _stencil(self, U):
        """Return the normalised stencil matching ``U`` and the spacing ``h``."""
        # NOTE(review): assumes the last axis spans a unit interval with
        # shape[-1] points; a width of 1 would divide by zero.
        h = 1.0 / (U.shape[-1] - 1)
        # `Kernel` is not a registered buffer, so `module.to(...)` or
        # `.double()` never convert it; follow the input instead.
        kernel = self.Kernel.to(device=U.device, dtype=U.dtype).clone()
        kernel[0, 0, 1] *= h
        return kernel / (4 * self.a), h

    def _corrected_step(self, U, G, b, f, kernel, h):
        """One stencil update followed by the learned correction ``G * H(delta)``."""
        U_new = F.conv2d(U, kernel, padding=1)
        if f is not None:
            U_new = U_new + h * h * f / (2 * self.a)
        U_new = U_new * G + (1 - G) * b
        return U_new + G * self.H(U_new - U)

    def forward(self, U_0, G, b, iter_num, f=None):
        """Run exactly ``iter_num`` corrected steps from ``U_0``.

        Args:
            U_0: Initial field with one channel (e.g. ``(N, 1, H, W)``).
            G: Mask broadcastable to ``U_0``; update and correction are kept
                where it is nonzero.
            b: Values imposed where ``G`` is 0.
            iter_num: Number of steps (``U_0`` is returned if it is 0).
            f: Optional source term broadcastable to ``U_0``.

        Returns:
            The field after ``iter_num`` steps.
        """
        kernel, h = self._stencil(U_0)  # loop-invariant, built once
        for _ in range(iter_num):
            U_0 = self._corrected_step(U_0, G, b, f, kernel, h)
        return U_0

    def evaluation(self, U_0, G, b, f=None):
        """Iterate corrected steps until convergence or ``max_iter_num`` steps.

        When ``H``'s weights require grad, each step extends the autograd
        graph. Call this under ``torch.no_grad()`` if gradients are not needed.

        Returns:
            ``(U, iter_num)``. ``U`` is the last iterate (``U_0`` itself when
            ``max_iter_num`` is 0). ``iter_num`` is the number of steps taken
            if converged, otherwise ``max_iter_num``.
        """
        kernel, h = self._stencil(U_0)
        U_1 = U_0  # returned as-is if the loop body never runs
        iter_num = self.max_iter_num
        for i in range(self.max_iter_num):
            U_1 = self._corrected_step(U_0, G, b, f, kernel, h)
            error = (U_1 - U_0).abs().max().item()
            if error < self.error_threshold:
                iter_num = i + 1
                break
            U_0 = U_1
        return U_1, iter_num
