from .fastmri_utils import fft2c_new, ifft2c_new
from torch.nn import functional as F
from torchvision import torch
from abc import ABC, abstractmethod




def fft2_m(x):
    """Forward 2-D FFT helper for (multi-coil) complex data.

    Real-valued input is first cast to ``torch.complex64``. The tensor is then
    passed to ``fft2c_new`` through its real/imaginary view (an extra trailing
    dimension of size 2), and the result is turned back into a complex tensor.

    NOTE(review): the exact transform (centering, normalization, which dims)
    is defined by ``fastmri_utils.fft2c_new`` -- confirm there.
    """
    x_c = x if torch.is_complex(x) else x.type(torch.complex64)
    freq_ri = fft2c_new(torch.view_as_real(x_c))
    return torch.view_as_complex(freq_ri)


def ifft2_m(x):
    """Inverse 2-D FFT helper for (multi-coil) complex data.

    Mirror of ``fft2_m``. Real-valued input is cast to ``torch.complex64``,
    routed through ``ifft2c_new`` in its real/imaginary view, and converted
    back to a complex tensor.

    NOTE(review): the exact transform (centering, normalization, which dims)
    is defined by ``fastmri_utils.ifft2c_new`` -- confirm there.
    """
    x_c = x if torch.is_complex(x) else x.type(torch.complex64)
    image_ri = ifft2c_new(torch.view_as_real(x_c))
    return torch.view_as_complex(image_ri)


class NonLinearOperator(ABC):
    """Abstract base for non-linear measurement operators ``y = forward(x)``.

    Subclasses must implement :meth:`forward`. :meth:`project` is provided on
    top of it.
    """

    @abstractmethod
    def forward(self, data, **kwargs):
        """Apply the measurement operator to ``data``.

        Args:
            data: Input signal/image.
            **kwargs: Operator-specific options.

        Returns:
            The simulated measurement of ``data``.
        """

    def project(self, data, measurement, **kwargs):
        """Return ``data + measurement - forward(data, **kwargs)``.

        Shifts ``data`` by the residual between the observed ``measurement``
        and the operator's prediction for ``data``.
        NOTE(review): the sum presumes ``data``, ``measurement`` and
        ``forward(data)`` are broadcast-compatible -- confirm per subclass.

        Args:
            data: Current estimate.
            measurement: Observed measurement.
            **kwargs: Passed through to :meth:`forward`.
        """
        # Forward **kwargs so operator-specific options actually reach
        # forward(), as its abstract signature advertises; they used to be
        # silently dropped here.
        return data + measurement - self.forward(data, **kwargs)

class PhaseRetrievalOperator(NonLinearOperator):
    """Fourier phase-retrieval operator: ``y = |fft2_m(zero_pad(x))|``.

    Args:
        oversample: Controls the zero padding. Each side of the last two
            dimensions is padded by ``int(oversample / 8 * 256)`` pixels.
            NOTE(review): the 256 looks like an assumed 256x256 image size --
            confirm before using other resolutions.
        device: Stored as ``self.device``; not used by the methods here.

    Raises:
        ValueError: If ``oversample`` is negative (``F.pad`` would then crop
            instead of pad and the shapes would no longer line up).
    """

    def __init__(self, oversample, device):
        if oversample < 0:
            raise ValueError(f"oversample must be non-negative, got {oversample!r}")
        self.pad = int((oversample / 8.0) * 256)
        self.device = device

    def forward(self, data, **kwargs):
        """Return the magnitude of ``fft2_m`` of the zero-padded ``data``.

        ``data`` is padded by ``self.pad`` zeros on both sides of its last two
        dimensions before the transform. Extra ``kwargs`` are ignored.
        """
        padded = F.pad(data, (self.pad, self.pad, self.pad, self.pad))
        return fft2_m(padded).abs()

    def prox_by_error_bp(self, x, amplitute, alpha_obs):
        """Impose the Fourier magnitude ``amplitute`` on ``x``, keeping phase.

        Steps: zero-pad ``x`` as in :meth:`forward` and take ``fft2_m``. Then
        replace the spectrum's magnitude by ``amplitute`` while keeping its
        phase, apply ``ifft2_m``, crop back to the original spatial extent of
        ``x``, and return the real part.

        Args:
            x: Current estimate; the last two dims are the spatial dims
                (presumably real-valued -- confirm with callers).
            amplitute: Target magnitude. It must broadcast against the padded
                spectrum, e.g. the output of :meth:`forward`.
            alpha_obs: Unused; kept for interface compatibility.

        Returns:
            A real tensor with the spatial size of ``x``. The full shape
            equals ``x.shape`` when ``amplitute`` does not broadcast to
            extra leading dims.
        """
        p = self.pad
        height, width = x.shape[-2:]
        spectrum = fft2_m(F.pad(x, (p, p, p, p)))
        # Swap magnitude for the target amplitude, keep the phase; the 1e-8
        # avoids division by zero where the spectrum vanishes.
        spectrum = spectrum * amplitute / (spectrum.abs() + 1e-8)
        # Crop with explicit end bounds: the former ``p:-p`` slice selected
        # nothing when p == 0 and returned an empty tensor.
        return ifft2_m(spectrum)[..., p:p + height, p:p + width].real
    

if __name__ == "__main__":
    # Manual smoke check: only constructs the operator; nothing else runs.
    phase_operator = PhaseRetrievalOperator(device="cuda", oversample=2.0)