import numpy as np
import torch
class Masker_color():
    """Hide one phase of a regular pixel grid at a time and stitch predictions.

    The image plane is tiled with ``width x width`` cells; mask ``i`` selects
    one position inside every cell, so ``width ** 2`` masks together cover
    every pixel exactly once.

    Args:
        width: Side length of the grid cell (stored as ``grid_size``).
        mode: ``'zero'`` replaces masked pixels with 0; ``'interpolate'``
            replaces them with a weighted average of their 8 neighbours.
        infer_single_pass: Stored on the instance; not consulted by any
            method of this class.
        include_mask_as_input: When true, the mask is appended to the masked
            image as an extra channel of the network input.
    """

    def __init__(self, width=3, mode='zero', infer_single_pass=False, include_mask_as_input=False):
        self.grid_size = width
        self.n_masks = width ** 2

        self.mode = mode
        self.infer_single_pass = infer_single_pass
        self.include_mask_as_input = include_mask_as_input

    def mask(self, X, i):
        """Build the ``i``-th grid mask and apply it to ``X``.

        Args:
            X: Batched image tensor; indexed as ``X[0, 0]`` to obtain the
                spatial shape, so it is expected to be (N, C, H, W).
            i: Mask index; it is decomposed into an x/y phase inside the
                grid cell (wrapping modulo ``grid_size``).

        Returns:
            ``(net_input, mask)`` where ``mask`` is an (H, W) tensor with 1
            at the hidden pixels and ``net_input`` is the masked image,
            optionally concatenated with the mask along dim 1.

        Raises:
            NotImplementedError: If ``self.mode`` is not ``'zero'`` or
                ``'interpolate'``.
        """
        phasex = i % self.grid_size
        phasey = (i // self.grid_size) % self.grid_size
        mask = pixel_grid_mask(X[0, 0].shape, self.grid_size, phasex, phasey)
        mask = mask.to(X.device)

        # Complement of the mask: 1 where the pixel stays visible.
        mask_inv = 1 - mask

        if self.mode == 'interpolate':
            masked = interpolate_mask_color(X, mask, mask_inv)
        elif self.mode == 'zero':
            masked = X * mask_inv
        else:
            raise NotImplementedError(f"unsupported masking mode: {self.mode!r}")

        if self.include_mask_as_input:
            # One copy of the mask per batch item, as a single extra channel.
            net_input = torch.cat((masked, mask.repeat(X.shape[0], 1, 1, 1)), dim=1)
        else:
            net_input = masked

        return net_input, mask

    def __len__(self):
        """Return the number of distinct masks (``width ** 2``)."""
        return self.n_masks

    def infer_full_image(self, X, model):
        """Assemble a full prediction from one forward pass per mask.

        For every mask the model sees the masked input, and only its output
        at the hidden pixels is kept; the per-mask pieces are summed on the
        CPU.

        Args:
            X: Batched image tensor accepted by :meth:`mask`.
            model: Callable mapping the network input to an output tensor
                that broadcasts against the (H, W) mask.

        Returns:
            CPU tensor with the shape of the model output, or ``None`` when
            there are no masks to iterate over.
        """
        acc_tensor = None
        for i in range(self.n_masks):
            # ``mask`` returns exactly two values: the input and the mask.
            net_input, mask = self.mask(X, i)
            net_output = model(net_input).detach()
            piece = (net_output * mask).cpu()
            # The accumulator takes its shape from the first piece, so no
            # extra forward pass is needed just to size it.
            acc_tensor = piece if acc_tensor is None else acc_tensor + piece

        return acc_tensor

def pixel_grid_mask(shape, patch_size, phase_x, phase_y):
    """Return an (H, W) mask with 1 at one phase of a regular pixel grid.

    Entry ``[i, j]`` is 1 when ``i % patch_size == phase_x`` and
    ``j % patch_size == phase_y``, and 0 everywhere else. Note that
    ``phase_x`` is matched against the row index and ``phase_y`` against the
    column index.

    Args:
        shape: Sequence whose last two entries are the height and width.
        patch_size: Period of the grid in both directions.
        phase_x: Row offset inside each period.
        phase_y: Column offset inside each period.

    Returns:
        CPU tensor of shape ``shape[-2:]`` in the default floating dtype.
        A phase outside ``range(patch_size)`` yields an all-zero mask.
    """
    height, width = shape[-2], shape[-1]
    # Evaluate the modulo test once per row and once per column instead of
    # once per pixel, then combine with a broadcasted AND.
    row_hit = torch.arange(height) % patch_size == phase_x
    col_hit = torch.arange(width) % patch_size == phase_y
    grid = row_hit[:, None] & col_hit[None, :]
    return grid.to(torch.get_default_dtype())



def interpolate_mask(tensor, mask, mask_inv):
    """Replace masked pixels with a weighted average of their 8 neighbours.

    Each channel is filtered independently with a normalised 3x3 kernel whose
    centre weight is 0 (edge neighbours weigh 1, corner neighbours 0.5), using
    zero padding at the border. The filtered value is used where ``mask`` is
    1 and the original value where ``mask_inv`` is 1.

    Args:
        tensor: Floating-point image batch of shape (N, C, H, W).
        mask: Tensor broadcastable to ``tensor``; 1 at pixels to replace.
        mask_inv: Complement of ``mask``; 1 at pixels to keep.

    Returns:
        Tensor with the shape of ``tensor``.
    """
    device = tensor.device
    # Build the kernel in the input's dtype so float64/float16 inputs do not
    # hit a dtype mismatch inside conv2d.
    dtype = tensor.dtype if tensor.is_floating_point() else torch.float32

    mask = mask.to(device)
    mask_inv = mask_inv.to(device)

    kernel = torch.tensor(
        [[0.5, 1.0, 0.5], [1.0, 0.0, 1.0], [0.5, 1.0, 0.5]],
        dtype=dtype,
        device=device,
    )
    kernel = kernel / kernel.sum()

    # One copy of the kernel per channel plus groups=C filters every channel
    # on its own, so any channel count is supported.
    channels = tensor.shape[1]
    weight = kernel.repeat(channels, 1, 1, 1)
    filtered_tensor = torch.nn.functional.conv2d(
        tensor, weight, stride=1, padding=1, groups=channels
    )

    return filtered_tensor * mask + tensor * mask_inv



def interpolate_mask_color(tensor, mask, mask_inv):
    """Per-channel neighbour interpolation of masked pixels in a colour image.

    Every channel is filtered independently with the same normalised 3x3
    kernel (centre weight 0, edge neighbours 1, corner neighbours 0.5, zero
    padding at the border). The filtered value is used where ``mask`` is 1
    and the original value where ``mask_inv`` is 1.

    Args:
        tensor: Floating-point image batch of shape (N, C, H, W); any number
            of channels is accepted.
        mask: Tensor broadcastable to ``tensor``; 1 at pixels to replace.
        mask_inv: Complement of ``mask``; 1 at pixels to keep.

    Returns:
        Tensor with the shape of ``tensor``.
    """
    device = tensor.device
    # Build the kernel in the input's dtype so float64/float16 inputs do not
    # hit a dtype mismatch inside conv2d.
    dtype = tensor.dtype if tensor.is_floating_point() else torch.float32

    mask = mask.to(device)
    mask_inv = mask_inv.to(device)

    kernel = torch.tensor(
        [[0.5, 1.0, 0.5], [1.0, 0.0, 1.0], [0.5, 1.0, 0.5]],
        dtype=dtype,
        device=device,
    )
    kernel = kernel / kernel.sum()

    # A grouped convolution with one kernel copy per channel replaces the
    # three hand-written per-channel convolutions and their concatenation.
    channels = tensor.shape[1]
    weight = kernel.repeat(channels, 1, 1, 1)
    filtered_tensor = torch.nn.functional.conv2d(
        tensor, weight, stride=1, padding=1, groups=channels
    )

    return filtered_tensor * mask + tensor * mask_inv
