#!/usr/bin/env python3
# -*- coding: utf-8 -*-

from random import randint
import torch
import torch.nn as nn
import torch.nn.functional as F

class Donut3x3(nn.Module):
    """Frozen per-channel filter that averages the four edge-adjacent pixels.

    The 3x3 kernel puts weight 1/4 on the taps above, below, left and right of
    the centre and 0 on the centre and the diagonals, so each output pixel is
    the mean of its four direct neighbours and never sees its own value.

    Each channel is filtered independently (``groups=channel``). Reflect
    padding of 1 keeps the spatial size unchanged. The weights are excluded
    from gradient computation.

    Args:
        channel: number of input/output channels. The input's channel
            dimension must match it.
    """

    def __init__(self, channel):
        super().__init__()
        self.conv = nn.Conv2d(
            channel,
            channel,
            kernel_size=3,
            padding=1,
            padding_mode="reflect",
            bias=False,
            groups=channel,
        )
        weight = self.conv.weight.data  # shape (channel, 1, 3, 3) for a depthwise conv
        # Cross-shaped ("donut") kernel. Built with the weight's dtype/device,
        # then tiled to the full weight shape.
        kernel = weight.new_tensor(
            [[0.0, 0.25, 0.0],
             [0.25, 0.0, 0.25],
             [0.0, 0.25, 0.0]]
        ).view(1, 1, 3, 3)
        self.conv.weight.data = kernel.repeat(weight.size(0), weight.size(1), 1, 1)
        self.conv.requires_grad_(False)

    def forward(self, x):
        """Return the neighbour average of ``x``, with the same shape as ``x``.

        Reflect padding requires both spatial dims of ``x`` to be at least 2.
        """
        smoothed = self.conv(x)
        return smoothed

class Noise2Self(nn.Module):
    """Noise2Self-style blind-spot wrapper around a denoising network.

    Pixels are split into ``n * n`` interleaved subsets: subset ``(i, j)`` is
    every pixel at ``[..., i::n, j::n]``. To denoise one subset, its pixels
    are first replaced by the mean of their four edge neighbours
    (``Donut3x3``). The wrapped ``model`` is then run on that image, and only
    the model's output at the subset's pixels is kept.

    In training mode one random subset is processed per call. In eval mode all
    ``n * n`` subsets are processed and their (disjoint) predictions are summed
    into a full image.

    Args:
        model: network applied to the blind-spot image. Its output is sliced
            with the input's pixel grid, so it presumably must return a tensor
            shaped like its input. TODO confirm for the models used.
        color: if True the neighbour filter is built for 3 channels, otherwise
            for 1. The input's channel count must match.
        n: stride of the subset grid along each spatial axis (default 4).
        pad_multiple: height and width are padded up to a multiple of this
            before processing and cropped back afterwards (default 8).

    Raises:
        ValueError: if ``n`` or ``pad_multiple`` is smaller than 1.
    """

    # Class-level default so instances pickled before ``pad_multiple`` existed
    # still pad to a multiple of 8.
    pad_multiple = 8

    def __init__(self, model, color=False, n=4, pad_multiple=8):
        super().__init__()
        if n < 1:
            raise ValueError(f"n must be a positive integer, got {n!r}")
        if pad_multiple < 1:
            raise ValueError(f"pad_multiple must be a positive integer, got {pad_multiple!r}")
        self.model = model
        self.n = n
        self.pad_multiple = pad_multiple
        self.donut3x3 = Donut3x3(3 if color else 1)

    def one_pass(self, x, random=True, i=0, j=0, x_interp=None):
        """Denoise a single interleaved pixel subset of ``x``.

        Args:
            x: input batch; the last two dims are indexed as rows and columns.
            random: if True, ignore ``i``/``j`` and draw both offsets uniformly
                from ``[0, n - 1]`` with ``random.randint``.
            i, j: row/column offset of the subset when ``random`` is False.
            x_interp: optional precomputed ``self.donut3x3(x)``. Callers that
                process several subsets of the same ``x`` can compute it once.

        Returns:
            A tensor shaped like ``x``. It holds the model's prediction at
            ``[..., i::n, j::n]`` and zeros everywhere else.
        """
        n = self.n
        if random:
            i, j = randint(0, n - 1), randint(0, n - 1)
        if x_interp is None:
            x_interp = self.donut3x3(x)

        mask = torch.zeros_like(x)
        mask[..., i::n, j::n] = 1
        # Blind-spot input: the subset's pixels are replaced by their
        # neighbour mean, so the model cannot copy their noisy values through.
        blind_input = x_interp * mask + (1 - mask) * x
        prediction = self.model(blind_input)

        x_den = torch.zeros_like(x)
        x_den[..., i::n, j::n] = prediction[..., i::n, j::n]
        return x_den

    def forward(self, x, sigma=None):
        """Run the blind-spot denoiser on a batch.

        Args:
            x: 4-D batch ``(B, C, H, W)`` (the ``size()`` unpacking requires
                exactly four dims).
            sigma: ignored here. It is presumably accepted so this wrapper
                shares a call signature with other denoisers. TODO confirm.

        Returns:
            Tensor cropped to the original ``H x W``.

            In training mode it is non-zero only at one randomly drawn subset.
            NOTE(review): a loss against ``x`` over every pixel therefore
            includes a term that does not depend on the model.

            In eval mode it contains predictions for every pixel.
        """
        _, _, h, w = x.size()
        multiple = self.pad_multiple
        pad_h, pad_w = (-h) % multiple, (-w) % multiple
        # Reflect padding requires each pad to be smaller than the dimension it
        # pads. Fall back to replicate padding for images too small to reflect.
        pad_mode = "reflect" if pad_h < h and pad_w < w else "replicate"
        x = F.pad(x, pad=(0, pad_w, 0, pad_h), mode=pad_mode)

        if self.training:
            output = self.one_pass(x, random=True)
        else:
            # The neighbour average does not depend on (i, j); compute it once
            # instead of once per subset.
            x_interp = self.donut3x3(x)
            output = torch.zeros_like(x)
            for i in range(self.n):
                for j in range(self.n):
                    x_den = self.one_pass(x, random=False, i=i, j=j, x_interp=x_interp)
                    output = output + x_den
        return output[..., :h, :w]