import numpy as np
import torch

class SuperResolution:
    """Block-average downsampling, pixel-replication upsampling, and a
    correction step (``prox``) built from the two.

    All methods index their input as a 4-D tensor whose dim 1 is the channel
    axis and whose dims 2 and 3 are the spatial axes being resized.
    """

    def __init__(self, channels, img_dim, ratio, device):  # ratio = 2 or 4
        """Store the operator configuration.

        Args:
            channels: Channel count that ``downsampling`` requires on dim 1.
            img_dim: Stored as-is; not read by any method of this class.
            ratio: Positive integer scale factor applied to both spatial dims.
            device: Device on which result tensors are allocated.
        """
        self.channels = channels
        self.img_dim = img_dim
        self.ratio = ratio
        self.device = device

    def _result_dtype(self, img):
        """Return the dtype used for result buffers.

        Promoting with the default dtype keeps the historical output dtype
        for float32, float16 and integer inputs, while float64 inputs are no
        longer silently rounded to float32.
        """
        return torch.promote_types(img.dtype, torch.get_default_dtype())

    def downsampling(self, img):
        """Average every non-overlapping ``ratio x ratio`` spatial block.

        Args:
            img: 4-D tensor with ``self.channels`` entries on dim 1 and
                spatial dims (2 and 3) divisible by ``self.ratio``.

        Returns:
            A new tensor on ``self.device`` whose spatial dims are those of
            ``img`` divided by ``self.ratio``.

        Raises:
            ValueError: If the channel count differs from ``self.channels``
                or a spatial dim is not a multiple of ``self.ratio``.
        """
        ratio = self.ratio
        if img.shape[1] != self.channels:
            raise ValueError(
                f"expected {self.channels} channels, got {img.shape[1]}"
            )
        height, width = img.shape[2], img.shape[3]
        if height % ratio or width % ratio:
            # Otherwise the strided slices below would not all share the
            # accumulator's shape and the in-place add would fail obscurely.
            raise ValueError(
                f"spatial size ({height}, {width}) is not divisible by "
                f"ratio {ratio}"
            )
        down_img = torch.zeros(
            [img.shape[0], img.shape[1], height // ratio, width // ratio],
            dtype=self._result_dtype(img),
            device=self.device,
        )
        # Each (k, j) offset selects one pixel out of every block; summing
        # over all offsets accumulates the full block sum.
        for k in range(ratio):
            for j in range(ratio):
                down_img += img[:, :, k::ratio, j::ratio]
        down_img /= ratio ** 2
        return down_img

    def upsampling(self, img):
        """Replicate every pixel into a ``ratio x ratio`` spatial block.

        Args:
            img: 4-D tensor; dims 2 and 3 are enlarged.

        Returns:
            A new tensor on ``self.device`` whose spatial dims are those of
            ``img`` multiplied by ``self.ratio``.
        """
        ratio = self.ratio
        # Every element is written by the loop below, so no zero-fill needed.
        up_img = torch.empty(
            [img.shape[0], img.shape[1], img.shape[2] * ratio, img.shape[3] * ratio],
            dtype=self._result_dtype(img),
            device=self.device,
        )
        for k in range(ratio):
            for j in range(ratio):
                up_img[:, :, k::ratio, j::ratio] = img
        return up_img

    def prox(self, img, y_standard):
        """Return ``img - upsampling(downsampling(img)) + y_standard``.

        The block means of ``img`` are removed and ``y_standard`` is added.
        When ``y_standard`` is constant on each ``ratio x ratio`` block (for
        instance an output of ``upsampling``), downsampling the result gives
        the block values of ``y_standard`` up to floating-point rounding.

        Args:
            img: Tensor accepted by ``downsampling``.
            y_standard: Tensor broadcastable against ``img``.
                NOTE(review): presumably the upsampled low-resolution
                observation -- confirm against callers.
        """
        return img - self.upsampling(self.downsampling(img)) + y_standard