import numpy as np
import torch

class Inpainting:
    """Pixel-masking degradation operator for image inpainting.

    A binary mask of shape ``(1, 1, img_dim, img_dim)`` marks *missing*
    pixels with 1 and observed pixels with 0. It broadcasts over the leading
    dimensions of the images given to :meth:`prox` and :meth:`obs`
    (presumably ``(B, C, img_dim, img_dim)`` — confirm with callers).

    Args:
        channels: Number of image channels (stored only; the mask is shared
            across channels via broadcasting).
        img_dim: Height and width of the square image.
        missing_r: Flat, row-major indices into the ``img_dim * img_dim``
            pixel grid that are missing; flat index ``i`` maps to pixel
            ``(i // img_dim, i % img_dim)``. Anything usable as a tensor
            index (list, ``np.ndarray``, ``torch.Tensor``) is accepted.
        device: Device the mask is moved to.
    """

    def __init__(self, channels, img_dim, missing_r, device):
        self.channels = channels
        self.img_dim = img_dim
        flat = torch.zeros(img_dim**2)
        flat[missing_r] = 1
        # Public float mask (1 = missing, 0 = observed); kept as-is so any
        # external code reading ``self.mask`` keeps working.
        self.mask = flat.reshape(1, 1, img_dim, img_dim).to(device)

    def prox(self, x, y):
        """Keep ``x`` at missing pixels and take ``y`` at observed pixels.

        Selection via ``torch.where`` (instead of ``x*m + y*(1-m)``) keeps
        the inputs' dtype — e.g. float16/bfloat16 are not promoted to the
        float32 mask dtype — and prevents inf/NaN in the discarded operand
        from leaking into the result through ``0 * inf = NaN``.

        Args:
            x: Current estimate; its values are kept at missing pixels.
            y: Values used at observed pixels (presumably the measurement).

        Returns:
            Tensor broadcast from ``x``, ``y`` and the mask.
        """
        missing = self.mask.bool()
        return torch.where(missing, x, y)

    def obs(self, x):
        """Return ``x`` with missing pixels set to zero.

        A 0-dim zero of ``x``'s dtype/device is used as the fill value so
        the output dtype matches ``x`` and non-finite values at missing
        pixels become exactly 0 rather than NaN.
        """
        missing = self.mask.bool()
        return torch.where(missing, x.new_zeros(()), x)