import numpy as np
import torch
from .fastmri_utils import fft2c_new, ifft2c_new
from torch.nn import functional as F
import math
from abc import ABC, abstractmethod
from torch import nn
import scipy
import yaml

def fft2_m(x):
    """2-D FFT for multi-coil data.

    Real-valued inputs are first cast to ``torch.complex64``. The transform
    itself is delegated to ``fft2c_new``, which is applied to the real view
    (trailing dimension of size 2) of the complex tensor; the result is
    converted back to a complex tensor.
    """
    complex_x = x if torch.is_complex(x) else x.type(torch.complex64)
    spectrum = fft2c_new(torch.view_as_real(complex_x))
    return torch.view_as_complex(spectrum)


def ifft2_m(x):
    """2-D inverse FFT for multi-coil data.

    Real-valued inputs are first cast to ``torch.complex64``. The transform
    itself is delegated to ``ifft2c_new``, which is applied to the real view
    (trailing dimension of size 2) of the complex tensor; the result is
    converted back to a complex tensor.
    """
    complex_x = x if torch.is_complex(x) else x.type(torch.complex64)
    image = ifft2c_new(torch.view_as_real(complex_x))
    return torch.view_as_complex(image)


class H_functions:
    """
    A class replacing the SVD of a matrix H, perhaps efficiently.
    All input vectors are of shape (Batch, ...).
    All output vectors are of shape (Batch, DataDimension).

    Subclasses describe ``H = U diag(s) V^T`` by implementing ``V``, ``Vt``,
    ``U``, ``Ut``, ``singulars`` and ``add_zeros``; the generic ``H``, ``Ht``
    and ``H_pinv`` below are composed from those primitives.
    """

    def V(self, vec):
        """
        Multiplies the input vector by V
        """
        raise NotImplementedError()

    def Vt(self, vec):
        """
        Multiplies the input vector by V transposed
        """
        raise NotImplementedError()

    def U(self, vec):
        """
        Multiplies the input vector by U
        """
        raise NotImplementedError()

    def Ut(self, vec):
        """
        Multiplies the input vector by U transposed
        """
        raise NotImplementedError()

    def singulars(self):
        """
        Returns a vector containing the singular values. The shape of the vector should be the same as the smaller dimension (like U)
        """
        raise NotImplementedError()

    def add_zeros(self, vec):
        """
        Adds trailing zeros to turn a vector from the small dimension (U) to the big dimension (V)
        """
        raise NotImplementedError()

    def H(self, vec):
        """
        Multiplies the input vector by H = U diag(s) V^T.
        """
        temp = self.Vt(vec)
        singulars = self.singulars()
        # Only the leading entries of V^T x have a matching singular value.
        return self.U(singulars * temp[:, :singulars.shape[0]])

    def Ht(self, vec):
        """
        Multiplies the input vector by H transposed = V diag(s) U^T.
        """
        temp = self.Ut(vec)
        singulars = self.singulars()
        return self.V(self.add_zeros(singulars * temp[:, :singulars.shape[0]]))

    def H_pinv(self, vec):
        """
        Multiplies the input vector by the pseudo inverse of H.

        Non-zero singular values are inverted and zero singular values map
        to zero. The tensor returned by ``singulars()`` is left untouched.
        """
        temp = self.Ut(vec)
        singulars = self.singulars()
        n_sing = singulars.shape[0]
        nonzero = singulars != 0
        # Build the inverse in a fresh tensor: singulars() may hand back a
        # subclass's internal state, and inverting it in place would corrupt
        # every later H / Ht / H_pinv call.
        singular_inverse = torch.zeros_like(singulars)
        singular_inverse[nonzero] = 1 / singulars[nonzero]
        temp[:, :n_sing] = temp[:, :n_sing] * singular_inverse
        return self.V(self.add_zeros(temp))


#Denoising
class Denoising(H_functions):
    """
    Identity degradation operator used for pure denoising.

    All singular values are 1 and every SVD factor just returns a flattened
    (batch, -1) copy of its input. ``H`` and ``H_pinv`` return their input
    unchanged, and ``proj`` returns the observation.
    """

    def __init__(self, channels, img_dim, device):
        size = channels * img_dim**2
        self._singulars = torch.ones(size, device=device)

    @staticmethod
    def _flatten_copy(vec):
        """Return a copy of ``vec`` reshaped to (batch, -1)."""
        batch = vec.shape[0]
        return vec.clone().reshape(batch, -1)

    def V(self, vec):
        return self._flatten_copy(vec)

    def Vt(self, vec):
        return self._flatten_copy(vec)

    def U(self, vec):
        return self._flatten_copy(vec)

    def Ut(self, vec):
        return self._flatten_copy(vec)

    def singulars(self):
        return self._singulars

    def H(self, x):
        # Identity operator: the input is returned as-is (not flattened).
        return x

    def H_pinv(self, x):
        return x

    def add_zeros(self, vec):
        # Small and big dimensions coincide, so nothing is padded.
        return self._flatten_copy(vec)

    def forward(self, x):
        return self.H(x)

    def proj(self, x, y):
        # The consistency projection for H = I is the observation itself.
        return y
class SuperResolution:
    """
    Block-averaging super-resolution operator.

    ``H`` averages every non-overlapping ``ratio x ratio`` block of pixels.
    ``H_pinv`` (``upsampling``) replicates each low-resolution pixel over its
    block, so ``H(H_pinv(y)) == y``.
    """

    def __init__(self, channels, img_dim, ratio, device): #ratio = 2 or 4
        self.channels = channels
        self.img_dim = img_dim
        self.ratio = ratio
        self.device = device

    @staticmethod
    def _zero_buffer(img, height, width):
        """Zero tensor (batch, channel, height, width) on ``img``'s device.

        Floating-point inputs keep their dtype; other dtypes fall back to
        torch's default floating dtype (the previous behaviour for all inputs).
        """
        dtype = img.dtype if img.is_floating_point() else torch.get_default_dtype()
        return torch.zeros((img.shape[0], img.shape[1], height, width),
                           dtype=dtype, device=img.device)

    def downsampling(self, img):
        """Average ``ratio x ratio`` blocks of ``img``.

        Args:
            img: tensor indexed as (batch, channel, height, width); height and
                width are assumed to be divisible by ``ratio``.

        Returns:
            Tensor of shape (batch, channel, height // ratio, width // ratio).

        Raises:
            ValueError: if ``img`` does not have ``self.channels`` channels.
        """
        if img.shape[1] != self.channels:
            raise ValueError(
                f"expected {self.channels} channels, got {img.shape[1]}")
        r = self.ratio
        down_img = self._zero_buffer(img, img.shape[2] // r, img.shape[3] // r)
        # Sum the r*r strided sub-grids, then divide to get the block mean.
        for k in range(r):
            for j in range(r):
                down_img += img[:, :, k::r, j::r]
        down_img /= r**2
        return down_img

    def upsampling(self, img):
        """Replicate every pixel of ``img`` over a ``ratio x ratio`` block."""
        r = self.ratio
        up_img = self._zero_buffer(img, img.shape[2] * r, img.shape[3] * r)
        for k in range(r):
            for j in range(r):
                up_img[:, :, k::r, j::r] = img
        return up_img

    def forward(self, x):
        return self.downsampling(x)

    def H(self, x):
        return self.downsampling(x)

    def Ht(self, y):
        # Adjoint of block averaging: replicate, then scale by 1 / ratio^2.
        return 1/self.ratio**2 * self.upsampling(y)

    def H_pinv(self, y):
        return self.upsampling(y)

    def proj(self, x, y, alpha_obs=1.0):
        """Make ``x`` consistent with ``y * sqrt(alpha_obs)`` under ``H``."""
        y = y * math.sqrt(alpha_obs)
        return x + self.upsampling(y - self.downsampling(x))

    def eq_var(self, var):
        return self.ratio ** 2 * var

    def get_type(self):
        return 'simple'

#Convolution-based super-resolution
class SRConv(H_functions):
    """
    Convolution-based super-resolution operator expressed through its SVD.

    Each (img_dim x img_dim) channel image X is mapped to H_small X H_small^T,
    where H_small (small_dim x img_dim) is the strided 1-D convolution matrix
    built in ``__init__`` (with reflective boundary handling). The SVD factors
    of H_small are applied from the left and right to realise U, V and their
    transposes without forming the big matrix.
    """

    def mat_by_img(self, M, v, dim):
        """Compute M @ X for every (dim x dim) channel image X in ``v``."""
        return torch.matmul(M, v.reshape(v.shape[0] * self.channels, dim,
                        dim)).reshape(v.shape[0], self.channels, M.shape[0], dim)

    def img_by_mat(self, v, M, dim):
        """Compute X @ M for every (dim x dim) channel image X in ``v``."""
        return torch.matmul(v.reshape(v.shape[0] * self.channels, dim,
                        dim), M).reshape(v.shape[0], self.channels, dim, M.shape[1])

    def __init__(self, kernel, channels, img_dim, device, stride = 1, ZERO = 3e-2):
        """
        Args:
            kernel: 1-D kernel tensor, indexed as kernel[t].
            channels: number of image channels.
            img_dim: side length of the square high-resolution image.
            device: device on which the 1-D matrices are built.
            stride: subsampling factor; outputs have side img_dim // stride.
            ZERO: 1-D singular values below this threshold are set to 0
                (default equals the previously hard-coded value).
        """
        self.img_dim = img_dim
        self.channels = channels
        self.ratio = stride
        small_dim = img_dim // stride
        self.small_dim = small_dim
        #build 1D conv matrix
        H_small = torch.zeros(small_dim, img_dim, device=device)
        # NOTE(review): the exclusive range end means odd-length kernels never
        # use their last tap — confirm this is intended.
        for i in range(stride//2, img_dim + stride//2, stride):
            for j in range(i - kernel.shape[0]//2, i + kernel.shape[0]//2):
                j_effective = j
                #reflective padding
                if j_effective < 0: j_effective = -j_effective-1
                if j_effective >= img_dim: j_effective = (img_dim - 1) - (j_effective - img_dim)
                #matrix building
                H_small[i // stride, j_effective] += kernel[j - i + kernel.shape[0]//2]
        #get the svd of the 1D conv
        self.U_small, self.singulars_small, self.V_small = torch.svd(H_small, some=False)
        self.singulars_small[self.singulars_small < ZERO] = 0
        #calculate the singular values of the big matrix
        self._singulars = torch.matmul(self.singulars_small.reshape(small_dim, 1), self.singulars_small.reshape(1, small_dim)).reshape(small_dim**2)
        #permutation for matching the singular values. See P_1 in Appendix D.5.
        self._perm = torch.Tensor([self.img_dim * i + j for i in range(self.small_dim) for j in range(self.small_dim)] + \
                                  [self.img_dim * i + j for i in range(self.small_dim) for j in range(self.small_dim, self.img_dim)]).to(device).long()

    def V(self, vec):
        #invert the permutation
        temp = torch.zeros(vec.shape[0], self.img_dim**2, self.channels, device=vec.device)
        temp[:, self._perm, :] = vec.clone().reshape(vec.shape[0], self.img_dim**2, self.channels)[:, :self._perm.shape[0], :]
        temp[:, self._perm.shape[0]:, :] = vec.clone().reshape(vec.shape[0], self.img_dim**2, self.channels)[:, self._perm.shape[0]:, :]
        temp = temp.permute(0, 2, 1)
        #multiply the image by V from the left and by V^T from the right
        out = self.mat_by_img(self.V_small, temp, self.img_dim)
        out = self.img_by_mat(out, self.V_small.transpose(0, 1), self.img_dim).reshape(vec.shape[0], -1)
        return out

    def Vt(self, vec):
        #multiply the image by V^T from the left and by V from the right
        temp = self.mat_by_img(self.V_small.transpose(0, 1), vec.clone(), self.img_dim)
        temp = self.img_by_mat(temp, self.V_small, self.img_dim).reshape(vec.shape[0], self.channels, -1)
        #permute the entries
        temp[:, :, :self._perm.shape[0]] = temp[:, :, self._perm]
        # output layout is (batch, pixel, channel), flattened
        temp = temp.permute(0, 2, 1)
        return temp.reshape(vec.shape[0], -1)

    def U(self, vec):
        #invert the permutation
        temp = torch.zeros(vec.shape[0], self.small_dim**2, self.channels, device=vec.device)
        temp[:, :self.small_dim**2, :] = vec.clone().reshape(vec.shape[0], self.small_dim**2, self.channels)
        temp = temp.permute(0, 2, 1)
        #multiply the image by U from the left and by U^T from the right
        out = self.mat_by_img(self.U_small, temp, self.small_dim)
        out = self.img_by_mat(out, self.U_small.transpose(0, 1), self.small_dim).reshape(vec.shape[0], -1)
        return out

    def Ut(self, vec):
        #multiply the image by U^T from the left and by U from the right
        temp = self.mat_by_img(self.U_small.transpose(0, 1), vec.clone(), self.small_dim)
        temp = self.img_by_mat(temp, self.U_small, self.small_dim).reshape(vec.shape[0], self.channels, -1)
        #permute the entries
        temp = temp.permute(0, 2, 1)
        return temp.reshape(vec.shape[0], -1)

    def singulars(self):
        """Singular values aligned with the (pixel, channel) layout of Vt/Ut."""
        # Each singular value is shared by all channels of one pixel, so it is
        # repeated consecutively once per channel (not a hard-coded 3).
        return self._singulars.repeat_interleave(self.channels).reshape(-1)

    def add_zeros(self, vec):
        """Zero-pad a flattened small-dimension vector to ratio**2 times its length."""
        reshaped = vec.clone().reshape(vec.shape[0], -1)
        temp = torch.zeros((vec.shape[0], reshaped.shape[1] * self.ratio**2), device=vec.device)
        temp[:, :reshaped.shape[1]] = reshaped
        return temp

class Inpainting:
    """
    Pixel-masking (inpainting) operator.

    ``self.mask`` is 1 at the flattened pixel indices listed in ``missing_r``
    and 0 elsewhere, reshaped to (1, 1, img_dim, img_dim) so it broadcasts
    over batch and channels. Observed pixels are those where the mask is 0.
    """

    def __init__(self, channels, img_dim, missing_r, device):
        self.channels = channels
        self.img_dim = img_dim
        flat_mask = torch.zeros(img_dim**2)
        flat_mask[missing_r] = 1
        self.mask = flat_mask.reshape([img_dim, img_dim])[None, None].to(device)

    def forward(self, x):
        """Zero out the missing pixels of ``x``."""
        observed = 1 - self.mask
        return x * observed

    def H_pinv(self, y):
        observed = 1 - self.mask
        return y * observed

    def proj(self, x, y, alpha_obs=1.0):
        """Keep ``x`` on missing pixels and the scaled observation elsewhere."""
        scaled_y = y * math.sqrt(alpha_obs)
        return x * self.mask + scaled_y * (1 - self.mask)

    def eq_var(self, var):
        return var

    def get_type(self):
        return 'simple'

class Deblurring(H_functions):
    """
    Separable (isotropic) deblurring operator expressed through its SVD.

    Each (img_dim x img_dim) channel image X is mapped to H_small X H_small^T,
    where H_small is the 1-D convolution matrix built in ``__init__`` (taps
    falling outside the image are dropped). The 2-D singular values are the
    pairwise products of the 1-D ones, sorted in descending order.
    """

    def mat_by_img(self, M, v):
        """Compute M @ X for every channel image X in ``v``."""
        return torch.matmul(M, v.reshape(v.shape[0] * self.channels, self.img_dim,
                        self.img_dim)).reshape(v.shape[0], self.channels, M.shape[0], self.img_dim)

    def img_by_mat(self, v, M):
        """Compute X @ M for every channel image X in ``v``."""
        return torch.matmul(v.reshape(v.shape[0] * self.channels, self.img_dim,
                        self.img_dim), M).reshape(v.shape[0], self.channels, self.img_dim, M.shape[1])

    def __init__(self, kernel, channels, img_dim, device, ZERO = 3e-2):
        """
        Args:
            kernel: 1-D kernel tensor, indexed as kernel[t].
            channels: number of image channels.
            img_dim: side length of the square image.
            device: device on which the 1-D matrix is built.
            ZERO: 1-D singular values below this threshold are set to 0.
        """
        self.img_dim = img_dim
        self.channels = channels
        #build 1D conv matrix
        H_small = torch.zeros(img_dim, img_dim, device=device)
        # NOTE(review): the exclusive range end means odd-length kernels never
        # use their last tap — confirm this is intended.
        for i in range(img_dim):
            for j in range(i - kernel.shape[0]//2, i + kernel.shape[0]//2):
                if j < 0 or j >= img_dim: continue
                H_small[i, j] = kernel[j - i + kernel.shape[0]//2]
        #get the svd of the 1D conv
        self.U_small, self.singulars_small, self.V_small = torch.svd(H_small, some=False)
        self.singulars_small[self.singulars_small < ZERO] = 0
        #calculate the singular values of the big matrix
        self._singulars = torch.matmul(self.singulars_small.reshape(img_dim, 1), self.singulars_small.reshape(1, img_dim)).reshape(img_dim**2)
        #sort the big matrix singulars and save the permutation
        self._singulars, self._perm = self._singulars.sort(descending=True) #, stable=True)

    def V(self, vec):
        #invert the permutation
        temp = torch.zeros(vec.shape[0], self.img_dim**2, self.channels, device=vec.device)
        temp[:, self._perm, :] = vec.clone().reshape(vec.shape[0], self.img_dim**2, self.channels)
        temp = temp.permute(0, 2, 1)
        #multiply the image by V from the left and by V^T from the right
        out = self.mat_by_img(self.V_small, temp)
        out = self.img_by_mat(out, self.V_small.transpose(0, 1)).reshape(vec.shape[0], -1)
        return out

    def Vt(self, vec):
        #multiply the image by V^T from the left and by V from the right
        temp = self.mat_by_img(self.V_small.transpose(0, 1), vec.clone())
        temp = self.img_by_mat(temp, self.V_small).reshape(vec.shape[0], self.channels, -1)
        #permute the entries according to the singular values;
        #output layout is (batch, pixel, channel), flattened
        temp = temp[:, :, self._perm].permute(0, 2, 1)
        return temp.reshape(vec.shape[0], -1)

    def U(self, vec):
        #invert the permutation
        temp = torch.zeros(vec.shape[0], self.img_dim**2, self.channels, device=vec.device)
        temp[:, self._perm, :] = vec.clone().reshape(vec.shape[0], self.img_dim**2, self.channels)
        temp = temp.permute(0, 2, 1)
        #multiply the image by U from the left and by U^T from the right
        out = self.mat_by_img(self.U_small, temp)
        out = self.img_by_mat(out, self.U_small.transpose(0, 1)).reshape(vec.shape[0], -1)
        return out

    def Ut(self, vec):
        #multiply the image by U^T from the left and by U from the right
        temp = self.mat_by_img(self.U_small.transpose(0, 1), vec.clone())
        temp = self.img_by_mat(temp, self.U_small).reshape(vec.shape[0], self.channels, -1)
        #permute the entries according to the singular values
        temp = temp[:, :, self._perm].permute(0, 2, 1)
        return temp.reshape(vec.shape[0], -1)

    def singulars(self):
        """Singular values aligned with the (pixel, channel) layout of Vt/Ut."""
        # Vt/Ut place the channels of one pixel next to each other, so each
        # singular value must repeat consecutively per channel. The previous
        # repeat(1, 3) tiled the whole vector instead, misaligning values.
        return self._singulars.repeat_interleave(self.channels).reshape(-1)

    def add_zeros(self, vec):
        # H is square, so no padding is needed.
        return vec.clone().reshape(vec.shape[0], -1)

    def forward(self, x):
        return self.H(x)

    def proj(self, x, y, alpha_obs=1.0):
        """Correct ``x`` by the pseudo-inverse of the measurement residual."""
        return x + self.H_pinv(y - self.H(x)).view(y.shape[0], self.channels, x.shape[2], x.shape[3])

    def eq_var(self, var):
        print('This function should not be called')
        return

    def get_type(self):
        return 'SVD'

class PhaseRetrievalOperator:
    """
    Fourier phase retrieval: the measurement is |FFT| of a zero-padded image.

    ``self.pad`` zeros are added on every side before the transform and
    removed again by ``undo_padding``.
    """

    def __init__(self, oversample, device):
        # Per-side padding scales linearly with ``oversample`` (512-pixel reference).
        self.pad = int((oversample / 8.0) * 512)
        self.device = device

    def forward(self, data, **kwargs):
        padded = F.pad(data, (self.pad, self.pad, self.pad, self.pad))
        return fft2_m(padded).abs()

    def H_pinv(self, x):
        x = ifft2_m(x).abs()
        return self.undo_padding(x, self.pad, self.pad, self.pad, self.pad)

    def undo_padding(self, tensor, pad_left, pad_right, pad_top, pad_bottom):
        """Crop padding from the last two dimensions of a 4-D tensor.

        Raises:
            ValueError: if ``tensor`` is not 4-D.
        """
        if tensor.dim() != 4:
            raise ValueError("Input tensor should have 4 dimensions.")
        height, width = tensor.shape[-2], tensor.shape[-1]
        # Explicit end indices: a slice end of ``-0`` is ``0`` and would return
        # an empty tensor whenever a pad amount is zero.
        return tensor[:, :, pad_top:height - pad_bottom, pad_left:width - pad_right]

    def proj(self, x, y, alpha_obs=1.0):
        """Replace the Fourier magnitude of padded ``x`` by the (scaled) data.

        The phase of FFT(pad(x)) is kept, its magnitude is set to
        ``y * sqrt(alpha_obs)``, and the real part of the cropped inverse
        transform is returned.
        """
        y = y * math.sqrt(alpha_obs)
        x_pad = F.pad(x, (self.pad, self.pad, self.pad, self.pad))
        fx = fft2_m(x_pad)
        epsilon = 1e-8  # guards the division where |FFT| vanishes
        fx_prox = fx * y / (fx.abs() + epsilon)
        prox_x = self.undo_padding(ifft2_m(fx_prox), self.pad, self.pad, self.pad, self.pad)
        return prox_x.real

    def eq_var(self, var):
        # NOTE(review): uses a 256-pixel reference size while self.pad is
        # derived from 512 — confirm which image size is intended.
        return var * (256+2*self.pad)**2/256**2

    def get_type(self):
        return 'simple'

class HDR(H_functions):
    """
    Saturation operator for HDR reconstruction.

    The forward model doubles the input and clips it to [-1, 1]; the
    pseudo-inverse halves its input.
    """

    def __init__(self):
        pass

    def forward(self, image):
        # Scale by 2 and clip; no range check on ``image`` is performed.
        return torch.clip(image / 0.5, -1, 1)

    def H_pinv(self, x):
        return x * 0.5

    def proj(self, x, y, alpha_obs=1.0):
        """Overwrite entries of ``x`` with ``y / 2`` where ``y`` is trusted.

        Two masks select the entries replaced by ``y / 2``:
        * |y| >= 1 while |x| < 0.5;
        * |y| < 1 when ``alpha_obs == 1.0``, otherwise |y| < 0.5.
        All other entries of ``x`` are kept.
        """
        output = torch.zeros_like(x) + x
        threshold = 1.0
        y_mag = torch.abs(y)
        clipped = torch.logical_and(y_mag >= threshold, torch.abs(x) < threshold/2)
        unclipped = y_mag < 1 if alpha_obs == 1.0 else y_mag < threshold/2
        output[clipped] = y[clipped] / 2
        output[unclipped] = y[unclipped] / 2
        return output

    def eq_var(self, var):
        return var / 4

    def get_type(self):
        return 'simple'

class LinearOperator(ABC):
    """Abstract linear measurement operator A with a transpose A^T."""

    @abstractmethod
    def forward(self, data, **kwargs):
        """Return A @ data."""
        pass

    @abstractmethod
    def transpose(self, data, **kwargs):
        """Return A^T @ data."""
        pass

    def ortho_project(self, data, **kwargs):
        """Return (I - A^T A) data."""
        round_trip = self.transpose(self.forward(data, **kwargs), **kwargs)
        return data - round_trip

    def project(self, data, measurement, **kwargs):
        """Return (I - A^T A) measurement - A data."""
        residual_part = self.ortho_project(measurement, **kwargs)
        return residual_part - self.forward(data, **kwargs)


class GaussianBlurOperator(LinearOperator):
    """Gaussian blur implemented by a fixed ``Blurkernel`` convolution."""

    def __init__(self, kernel_size, intensity, device):
        self.device = device
        self.kernel_size = kernel_size
        blur = Blurkernel(blur_type='gaussian',
                          kernel_size=kernel_size,
                          std=intensity,
                          device=device)
        self.conv = blur.to(device)
        self.kernel = self.conv.get_kernel()
        # Re-load the weights from a float32 copy of the generated kernel.
        self.conv.update_weights(self.kernel.type(torch.float32))

    def forward(self, data, **kwargs):
        return self.conv(data)

    def transpose(self, data, **kwargs):
        # NOTE(review): returns the input unchanged rather than the adjoint of
        # the blur, so the inherited ortho_project/project treat A^T as I.
        return data

    def get_kernel(self):
        """Return the kernel reshaped to (1, 1, kernel_size, kernel_size)."""
        return self.kernel.view(1, 1, self.kernel_size, self.kernel_size)

    def H_pinv(self, x):
        return x

class Blurkernel(nn.Module):
    """
    Fixed depthwise 2-D blur applied independently to 3 channels.

    A reflection pad of ``kernel_size // 2`` is followed by a grouped,
    bias-free ``Conv2d`` whose weights are set to the generated kernel.
    """

    def __init__(self, blur_type='gaussian', kernel_size=31, std=3.0, device=None):
        super().__init__()
        self.blur_type = blur_type
        self.kernel_size = kernel_size
        self.std = std
        self.device = device
        self.seq = nn.Sequential(
            nn.ReflectionPad2d(self.kernel_size//2),
            nn.Conv2d(3, 3, self.kernel_size, stride=1, padding=0, bias=False, groups=3)
        )

        self.weights_init()

    def forward(self, x):
        return self.seq(x)

    def _load_kernel(self, k):
        """Copy ``k`` into every parameter (broadcast over the channel axes)."""
        for _, param in self.named_parameters():
            param.data.copy_(k)

    def weights_init(self):
        """Generate the kernel for ``blur_type`` and load it into the conv.

        Unknown ``blur_type`` values leave the conv weights untouched and do
        not set ``self.k``.
        """
        if self.blur_type == "gaussian":
            # Import the submodule explicitly: a bare ``import scipy`` does not
            # expose ``scipy.ndimage`` on older SciPy releases.
            import scipy.ndimage
            delta = np.zeros((self.kernel_size, self.kernel_size))
            delta[self.kernel_size // 2, self.kernel_size // 2] = 1
            k = scipy.ndimage.gaussian_filter(delta, sigma=self.std)
        elif self.blur_type == "motion":
            # NOTE(review): ``Kernel`` is not imported in this module, so this
            # branch raises NameError unless it is provided elsewhere.
            k = Kernel(size=(self.kernel_size, self.kernel_size), intensity=self.std).kernelMatrix
        else:
            return
        k = torch.from_numpy(k)
        self.k = k
        self._load_kernel(k)

    def update_weights(self, k):
        """Load ``k`` (tensor or NumPy array) into the conv weights."""
        if not torch.is_tensor(k):
            k = torch.from_numpy(k)
            # copy_ handles cross-device copies; only move when a device is set.
            if self.device is not None:
                k = k.to(self.device)
        self._load_kernel(k)

    def get_kernel(self):
        return self.k


class WalshHadamardCS(H_functions):
    """
    Compressed sensing with a permuted Walsh-Hadamard transform.

    ``V`` applies the (scaled, self-inverse) Walsh-Hadamard transform per
    channel after scattering coefficients through ``perm``; ``U`` is the
    identity and every retained singular value is 1. Only the first
    ``channels * img_dim**2 // ratio`` coefficients are kept.
    """

    def fwht(self, vec):
        """Fast Walsh-Hadamard transform over each channel's img_dim**2 entries.

        The butterfly loop doubles the block size each pass, so img_dim**2 is
        presumably a power of two (TODO confirm at call sites). Dividing by
        img_dim = sqrt(img_dim**2) makes the transform its own inverse.
        Returns a tensor of shape (batch, channels, img_dim**2).
        """
        B = vec.shape[0]
        a = vec.reshape(B, self.channels, self.img_dim**2)

        h = 1
        while h < self.img_dim**2:
            a = a.reshape(B, self.channels, -1, h * 2)
            x1 = a[:, :, :, :h]
            x2 = a[:, :, :, h:2*h]
            # Out-of-place butterfly (no in-place writes on views).
            a = torch.cat([x1 + x2, x1 - x2], dim=-1)
            h *= 2

        return a.reshape(B, self.channels, self.img_dim**2) / self.img_dim

    def __init__(self, channels, img_dim, ratio, perm, device):
        self.channels = channels
        self.img_dim = img_dim
        self.ratio = ratio
        self.perm = perm
        self._singulars = torch.ones(channels * img_dim**2 // ratio, device=device)

    def V(self, vec):
        temp = torch.zeros(vec.shape[0], self.channels, self.img_dim**2, device=vec.device)
        # input layout is (batch, pixel, channel), flattened
        temp[:, :, self.perm] = vec.clone().reshape(vec.shape[0], -1, self.channels).permute(0, 2, 1)
        return self.fwht(temp).reshape(vec.shape[0], -1)

    def Vt(self, vec):
        return self.fwht(vec.clone())[:, :, self.perm].permute(0, 2, 1).reshape(vec.shape[0], -1)

    def U(self, vec):
        return vec.clone().reshape(vec.shape[0], -1)

    def Ut(self, vec):
        return vec.clone().reshape(vec.shape[0], -1)

    def singulars(self):
        return self._singulars

    def add_zeros(self, vec):
        """Zero-pad the kept coefficients to the full channels * img_dim**2 length."""
        out = torch.zeros(vec.shape[0], self.channels * self.img_dim**2, device=vec.device)
        out[:, :self.channels * self.img_dim**2 // self.ratio] = vec.clone().reshape(vec.shape[0], -1)
        return out

    def forward(self, vec):
        return self.H(vec)

    def proj(self, x, y, alpha_obs=1.0):
        """Correct ``x`` by the pseudo-inverse of the residual ``y - H(x)``."""
        with torch.autocast("cuda", dtype=torch.float32):
            return x + self.H_pinv(y - self.H(x)).view(y.shape[0], self.channels, x.shape[2], x.shape[3])

    def proj_y0(self, x, predicted_y0, y, alpha_obs=1.0):
        """Like ``proj`` but with a precomputed prediction of ``H(x)``."""
        with torch.autocast("cuda", dtype=torch.float32):
            return x + self.H_pinv(y - predicted_y0).view(y.shape[0], self.channels, x.shape[2], x.shape[3])


#Anisotropic Deblurring
class Deblurring2D(H_functions):
    """
    Anisotropic separable deblurring expressed through its SVD.

    Each (img_dim x img_dim) channel image X is mapped to H1 X H2^T, where H1
    and H2 are 1-D convolution matrices built from ``kernel1`` and ``kernel2``
    (taps outside the image are dropped). The 2-D singular values are the
    pairwise products of the 1-D ones, sorted in descending order.
    """

    def mat_by_img(self, M, v):
        """Compute M @ X for every channel image X in ``v``."""
        with torch.autocast("cuda", dtype=torch.float32):
            return torch.matmul(M, v.reshape(v.shape[0] * self.channels, self.img_dim,
                        self.img_dim)).reshape(v.shape[0], self.channels, M.shape[0], self.img_dim)

    def img_by_mat(self, v, M):
        """Compute X @ M for every channel image X in ``v``."""
        with torch.autocast("cuda", dtype=torch.float32):
            return torch.matmul(v.reshape(v.shape[0] * self.channels, self.img_dim,
                        self.img_dim), M).reshape(v.shape[0], self.channels, self.img_dim, M.shape[1])

    def __init__(self, kernel1, kernel2, channels, img_dim, device, ZERO = 3e-2):
        """
        Args:
            kernel1: 1-D kernel applied through the left factor (H1).
            kernel2: 1-D kernel applied through the right factor (H2^T).
            channels: number of image channels.
            img_dim: side length of the square image.
            device: device on which the 1-D matrices are built.
            ZERO: 1-D singular values below this threshold are set to 0
                (default equals the previously hard-coded value).
        """
        with torch.autocast("cuda", dtype=torch.float32):
            self.img_dim = img_dim
            self.channels = channels
            # NOTE(review): the exclusive range end means odd-length kernels
            # never use their last tap — confirm this is intended.
            #build 1D conv matrix - kernel1
            H_small1 = torch.zeros(img_dim, img_dim, device=device)
            for i in range(img_dim):
                for j in range(i - kernel1.shape[0]//2, i + kernel1.shape[0]//2):
                    if j < 0 or j >= img_dim: continue
                    H_small1[i, j] = kernel1[j - i + kernel1.shape[0]//2]
            #build 1D conv matrix - kernel2
            H_small2 = torch.zeros(img_dim, img_dim, device=device)
            for i in range(img_dim):
                for j in range(i - kernel2.shape[0]//2, i + kernel2.shape[0]//2):
                    if j < 0 or j >= img_dim: continue
                    H_small2[i, j] = kernel2[j - i + kernel2.shape[0]//2]
            #get the svd of the 1D conv
            self.U_small1, self.singulars_small1, self.V_small1 = torch.svd(H_small1, some=False)
            self.U_small2, self.singulars_small2, self.V_small2 = torch.svd(H_small2, some=False)
            self.singulars_small1[self.singulars_small1 < ZERO] = 0
            self.singulars_small2[self.singulars_small2 < ZERO] = 0
            #calculate the singular values of the big matrix
            self._singulars = torch.matmul(self.singulars_small1.reshape(img_dim, 1), self.singulars_small2.reshape(1, img_dim)).reshape(img_dim**2)
            #sort the big matrix singulars and save the permutation
            self._singulars, self._perm = self._singulars.sort(descending=True) #, stable=True)

    def V(self, vec):
        with torch.autocast("cuda", dtype=torch.float32):
            #invert the permutation
            temp = torch.zeros(vec.shape[0], self.img_dim**2, self.channels, device=vec.device).to(torch.float32)
            temp[:, self._perm, :] = vec.clone().reshape(vec.shape[0], self.img_dim**2, self.channels).to(torch.float32)
            temp = temp.permute(0, 2, 1)
            #multiply the image by V1 from the left and by V2^T from the right
            out = self.mat_by_img(self.V_small1, temp)
            out = self.img_by_mat(out, self.V_small2.transpose(0, 1)).reshape(vec.shape[0], -1)
            return out

    def Vt(self, vec):
        #multiply the image by V1^T from the left and by V2 from the right
        temp = self.mat_by_img(self.V_small1.transpose(0, 1), vec.clone())
        temp = self.img_by_mat(temp, self.V_small2).reshape(vec.shape[0], self.channels, -1)
        #permute the entries according to the singular values;
        #output layout is (batch, pixel, channel), flattened
        temp = temp[:, :, self._perm].permute(0, 2, 1)
        return temp.reshape(vec.shape[0], -1)

    def U(self, vec):
        #invert the permutation
        temp = torch.zeros(vec.shape[0], self.img_dim**2, self.channels, device=vec.device).to(torch.float32)
        temp[:, self._perm, :] = vec.clone().reshape(vec.shape[0], self.img_dim**2, self.channels).to(torch.float32)
        temp = temp.permute(0, 2, 1)
        #multiply the image by U1 from the left and by U2^T from the right
        out = self.mat_by_img(self.U_small1, temp)
        out = self.img_by_mat(out, self.U_small2.transpose(0, 1)).reshape(vec.shape[0], -1)
        return out

    def Ut(self, vec):
        with torch.autocast("cuda", dtype=torch.float32):
            #multiply the image by U1^T from the left and by U2 from the right
            temp = self.mat_by_img(self.U_small1.transpose(0, 1), vec.clone())
            temp = self.img_by_mat(temp, self.U_small2).reshape(vec.shape[0], self.channels, -1)
            #permute the entries according to the singular values
            temp = temp[:, :, self._perm].permute(0, 2, 1)
            return temp.reshape(vec.shape[0], -1)

    def singulars(self):
        """Singular values aligned with the (pixel, channel) layout of Vt/Ut."""
        with torch.autocast("cuda", dtype=torch.float32):
            # Each singular value is shared by all channels of one pixel, so it
            # repeats consecutively per channel; repeat(1, 3) misaligned them.
            return self._singulars.repeat_interleave(self.channels).reshape(-1)

    def add_zeros(self, vec):
        with torch.autocast("cuda", dtype=torch.float32):
            return vec.clone().reshape(vec.shape[0], -1)

    def forward(self, vec):
        with torch.autocast("cuda", dtype=torch.float32):
            return self.H(vec)

    def H_pinv(self, vec):
        # Deliberately the identity here (not the SVD pseudo-inverse).
        return vec

    def proj(self, x, y, alpha_obs=1.0):
        """Correct ``x`` by ``H_pinv`` of the residual ``y - H(x)``."""
        with torch.autocast("cuda", dtype=torch.float32):
            return x + self.H_pinv(y - self.H(x)).view(y.shape[0], self.channels, x.shape[2], x.shape[3])

class NonlinearBlurOperator(H_functions):
    """
    Nonlinear blur produced by a pretrained bkse ``KernelWizard`` model.

    One random latent kernel is drawn at construction and reused (repeated
    over the batch) for every ``forward`` call.
    """

    def __init__(self, device, opt_yml_path='./bkse/options/generate_blur/default.yml'):
        self.device = device
        self.blur_model = self.prepare_nonlinear_blur_model(opt_yml_path)
        # Fixed latent kernel of shape (1, 512, 2, 2), scaled by 1.2.
        self.random_kernel = torch.randn(1, 512, 2, 2).to(self.device) * 1.2

    def prepare_nonlinear_blur_model(self, opt_yml_path):
        '''
        Nonlinear deblur requires external codes (bkse).

        Reads the "KernelWizard" section of the YAML options file, builds the
        model, loads the checkpoint named by its "pretrained" entry and moves
        the model to ``self.device``.
        '''
        from bkse.models.kernel_encoding.kernel_wizard import KernelWizard

        with open(opt_yml_path, "r") as f:
            opt = yaml.safe_load(f)["KernelWizard"]
        model_path = opt["pretrained"]
        blur_model = KernelWizard(opt)
        blur_model.eval()
        # The model is still on the CPU here; map_location="cpu" lets
        # checkpoints saved from a GPU load on CPU-only machines too.
        # NOTE: torch.load unpickles the file — only load trusted checkpoints.
        state_dict = torch.load(model_path, map_location="cpu")
        blur_model.load_state_dict(state_dict)
        blur_model = blur_model.to(self.device)
        return blur_model

    def forward(self, data, **kwargs):
        """Blur ``data`` (values in [-1, 1]) with the fixed latent kernel.

        The input is mapped to [0, 1], resized to 256x256 for the blur model,
        resized to 512x512 afterwards, and mapped back to [-1, 1] (clamped).
        """
        kernel = self.random_kernel.repeat(data.shape[0], 1, 1, 1)
        data = (data + 1.0) / 2.0  #[-1, 1] -> [0, 1]
        data = F.interpolate(
            data,
            size=(256, 256),
            mode="bilinear",
            align_corners=False
        )
        blurred = self.blur_model.adaptKernel(data, kernel=kernel)
        blurred = F.interpolate(
            blurred,
            size=(512, 512),
            mode="bilinear",
            align_corners=False
        )
        blurred = (blurred * 2.0 - 1.0).clamp(-1, 1) #[0, 1] -> [-1, 1]
        return blurred

    def H_pinv(self, x):
        return x

    def is_linear(self):
        return False

    def H(self, data, **kwargs):
        return self.forward(data)