import torch
import torch.nn.functional as F



class JOR_iter(torch.nn.Module):
    """Jacobi (over-)relaxation solver for a 5-point discretised Helmholtz problem.

    Each sweep solves the stencil equation

        (u[i+1,j] + u[i-1,j] + u[i,j+1] + u[i,j-1] - 4*u[i,j]) / h**2
            + k**2 * u[i,j] = f[i,j]

    for ``u[i,j]`` from the previous iterate's neighbours. It then optionally
    blends the result with the previous iterate (weight ``omega``). Finally it
    re-imposes the values ``b`` through the mask ``G``.

    Iteration stops at the first sweep whose max-abs change drops below
    ``error_threshold``, or after ``max_iter_num`` sweeps.

    Args:
        error_threshold: Convergence tolerance on ``max|U_new - U_old|``.
        max_iter_num: Maximum number of sweeps.
        use_cuda: Place the stencil kernel on the GPU at construction time.
            This is skipped when CUDA is unavailable. ``forward`` moves the
            kernel to the input's device/dtype in any case.
        k: Wavenumber of the Helmholtz term. The default of 3 is the value
            this class always used.
        omega: Relaxation weight. The default of 1.0 is plain Jacobi.
    """

    def __init__(self, error_threshold=0.00001, max_iter_num=20000, use_cuda=True,
                 k=3, omega=1.0):
        super().__init__()
        self.error_threshold = error_threshold
        self.max_iter_num = max_iter_num
        # Neighbour-sum stencil (centre weight 0), shaped
        # (out_channels=1, in_channels=1, 3, 3) for F.conv2d.
        self.Kernel = torch.Tensor([[[[0, 1.0, 0], [1.0, 0, 1.0], [0, 1.0, 0]]]])
        self.k = k
        self.omega = omega
        if use_cuda and torch.cuda.is_available():
            self.Kernel = self.Kernel.cuda()

    def forward(self, U_0, G, b, N, f=None):
        """Iterate from ``U_0`` until convergence or ``max_iter_num`` sweeps.

        Args:
            U_0: Initial guess. Must be single-channel for ``F.conv2d``,
                e.g. ``(batch, 1, H, W)``. The grid border is zero-padded.
            G: Mask broadcastable to ``U_0``. The iterate is kept where ``G``
                is 1 and ``b`` is imposed where ``G`` is 0. Presumably a 0/1
                mask; fractional values blend the two.
            b: Values imposed where ``G`` is 0.
            N: Number of grid intervals. The grid spacing is ``h = 1 / N``.
            f: Optional source term, broadcastable to ``U_0``.

        Returns:
            ``(U, iter_num)``: the last iterate and the number of sweeps run.
            ``iter_num == max_iter_num`` means either non-convergence or
            convergence exactly on the last sweep. If ``max_iter_num <= 0``,
            returns ``(U_0, 0)``.
        """
        # Kept as attributes because external code may read them after a call.
        self.G = G
        self.f = f
        self.b = b
        self.N = N
        self.h = 1 / N

        if self.max_iter_num <= 0:
            # No sweeps requested: hand back the initial guess.
            return U_0, 0

        # Loop invariants, computed once instead of on every sweep.
        h2 = self.h ** 2
        k2h2 = self.k ** 2 * h2
        # Follow the input's device/dtype so CPU, float64 or another GPU work.
        kernel = self.Kernel.to(device=U_0.device, dtype=U_0.dtype) / (4 - k2h2)
        source = None if f is None else f * h2 / (k2h2 - 4)
        # getattr keeps instances pickled before `omega` existed loadable.
        omega = getattr(self, "omega", 1.0)

        iter_num = self.max_iter_num
        for i in range(self.max_iter_num):
            U_1 = F.conv2d(U_0, kernel, padding=1)
            if source is not None:
                U_1 = U_1 + source
            if omega != 1:
                # Relaxed update: blend the Jacobi result with the old iterate.
                U_1 = (1 - omega) * U_0 + omega * U_1
            # Re-impose the fixed values wherever the mask is 0.
            U_1 = U_1 * G + (1 - G) * b
            # .item() forces a device sync, needed for the Python-side stop test.
            error = (U_1 - U_0).abs().max().item()
            if error < self.error_threshold:
                iter_num = i + 1
                break
            U_0 = U_1
        return U_1, iter_num