import torch
import torch.nn as nn
import torch.nn.functional as F

def pad_for_fft(X, Ysz):
    """Zero-pad the last two dimensions of ``X`` up to the trailing sizes in ``Ysz``.

    Padding is appended only after the existing entries (to the right of the
    last dimension and below the second-to-last one), so the original data
    stays at the start of both axes.

    Args:
        X: tensor whose last two dimensions are padded.
        Ysz: target size; only its last two entries (height, width) are read.

    Returns:
        The tensor produced by ``F.pad`` with its default fill value.
    """
    src_h, src_w = X.size()[-2:]
    dst_h, dst_w = Ysz[-2:]

    # F.pad expects (left, right, top, bottom) for the last two dimensions.
    amounts = (0, dst_w - src_w, 0, dst_h - src_h)
    return F.pad(X, amounts)

def to_complex(X_re, X_im=None):
    """Pack real and imaginary parts into one tensor with a trailing size-2 axis.

    Args:
        X_re: tensor holding the real part.
        X_im: optional tensor holding the imaginary part; when omitted the
            imaginary part is all zeros.

    Returns:
        A tensor with one extra trailing dimension: index 0 holds ``X_re`` and
        index 1 holds the imaginary part.
    """
    real = X_re.unsqueeze(-1)
    imag = torch.zeros_like(real) if X_im is None else X_im.unsqueeze(-1)
    return torch.cat([real, imag], -1)

def conjugate(X_cplx):
    """Return the complex conjugate of ``X_cplx``.

    The real slice (from ``re``) is kept, the imaginary slice (from ``im``) is
    negated, and the two are concatenated again along the last dimension.
    """
    real_part = re(X_cplx)
    neg_imag_part = -im(X_cplx)
    return torch.cat([real_part, neg_imag_part], -1)

def re(X_cplx):
    """Return the real component: entry 0 of the last dimension.

    The last dimension is kept (with size 1) rather than dropped, so the
    result has the same number of dimensions as ``X_cplx``.
    """
    # Ellipsis selects every leading dimension in full.
    return X_cplx[..., 0:1]

def im(X_cplx):
    """Return the imaginary component: entry 1 of the last dimension.

    The last dimension is kept (with size 1) rather than dropped, so the
    result has the same number of dimensions as ``X_cplx``.
    """
    # Ellipsis selects every leading dimension in full.
    return X_cplx[..., 1:2]

def transpose(X):
    """Swap the first two dimensions of ``X``; all later dimensions keep their order."""
    order = [1, 0]
    order.extend(range(2, X.dim()))
    return X.permute(order)

def permute2forward(X):
    """Move the first two dimensions of ``X`` to the end, keeping their order.

    A tensor with dimensions ``(a, b, ...)`` becomes ``(..., a, b)``.
    """
    order = list(range(2, X.dim()))
    order.extend((0, 1))
    return X.permute(order)

def permute2backward(X):
    """Move the last two dimensions of ``X`` to the front, keeping their order.

    A tensor with dimensions ``(..., a, b)`` becomes ``(a, b, ...)``.
    """
    ndim = X.dim()
    order = [ndim - 2, ndim - 1]
    order.extend(range(ndim - 2))
    return X.permute(order)

def matmul(X, Y):
    """Matrix-multiply ``X`` and ``Y`` over their first two dimensions.

    ``torch.matmul`` contracts the last two dimensions, so both operands have
    their leading two dimensions moved to the end first; the product is then
    permuted back so the matrix dimensions lead again. All remaining
    dimensions are handled by ``torch.matmul`` as batch dimensions.
    """
    lhs = permute2forward(X)
    rhs = permute2forward(Y)
    product = torch.matmul(lhs, rhs)
    return permute2backward(product)

# complex matrix arithmetic
# (A + Bi)(C + Di) = AC - BD + (AD + BC)i
def matmul_cplx(X, Y):
    """Complex matrix product of ``X`` and ``Y`` over their first two dimensions.

    Both operands carry (real, imag) in their last dimension. Writing
    ``X = A + Bi`` and ``Y = C + Di``, the result is
    ``(AC - BD) + (AD + BC)i``, returned in the same trailing-axis layout.
    """
    A, B = re(X), im(X)
    C, D = re(Y), im(Y)
    real_part = matmul(A, C) - matmul(B, D)
    imag_part = matmul(A, D) + matmul(B, C)
    return torch.cat([real_part, imag_part], -1)

class FFTConv2d(nn.Module):
    """Fourier-domain operators built from the weight of an ``nn.Conv2d``.

    The constructor flips the kernel, zero-pads it to the spatial size given
    by ``Xsz``, takes its 2-D FFT and precomputes three tensors, each stored
    as a real tensor whose last dimension holds (real, imag):

    * ``weight_fft_t``   -- W, used by :meth:`forward`;
    * ``weight_t_fft_t`` -- the conjugate of the untransposed kernel FFT,
      i.e. W*, used by :meth:`t`;
    * ``Hinv_fft``       -- inv(I + W*W), used by :meth:`ridge_inverse`.

    Only ``conv.weight`` is read from ``conv``; its bias, stride, padding,
    dilation and groups are not consulted. The precomputed tensors are plain
    attributes computed once in ``__init__``, so they do not track later
    changes to ``conv.weight``.

    Args:
        conv: the ``nn.Conv2d`` whose weight defines the operator.
        Xsz: size whose last two entries give the (height, width) the kernel
            is padded to. Inputs passed to :meth:`forward`, :meth:`t` and
            :meth:`ridge_inverse` are expected to have that spatial size.
    """

    def __init__(self, conv, Xsz):
        super().__init__()
        assert isinstance(conv, nn.Conv2d)
        self.Xsz = Xsz
        self.conv = conv

        # Flip the kernel to be a convolution since PyTorch implements a cross
        # correlation, and pad/roll the weight so that output is centered
        ph, pw = (k//2 for k in conv.weight.size()[-2:])
        conv_weight = torch.flip(conv.weight, (2, 3))
        weight_pad = pad_for_fft(conv_weight, Xsz)
        weight_pad = torch.roll(weight_pad, (-ph, -pw), (-2, -1))

        # Precompute fourier representation of W and W*
        weight_fft = self._fft2(to_complex(weight_pad))
        self.weight_fft_t = transpose(weight_fft)
        self.weight_t_fft_t = conjugate(weight_fft)

        # Precompute inv(I + W*W)
        # Compute W*W
        WHW_fft = matmul_cplx(self.weight_fft_t, self.weight_t_fft_t)
        # Add 1 on the diagonal, but only to the real component.
        # torch.diagonal returns a view, so the in-place add updates WHW_fft.
        torch.diagonal(WHW_fft, dim1=0, dim2=1)[:, :, 0, :] += 1

        # Move dimensions for batch inverse and separate real and imag parts.
        # The complex inverse is computed with real matrices: for M = A + Bi,
        # inverting the real block matrix [[A, -B], [B, A]] yields
        # [[X, -Y], [Y, X]] where inv(M) = X + Yi.
        A, B = (permute2forward(a) for a in (re(WHW_fft), im(WHW_fft)))
        i = A.size(-1)
        block = torch.cat([torch.cat([A, -B], -1),
                           torch.cat([B,  A], -1)], -2)
        block_inv = permute2backward(torch.inverse(block))
        # Top-left block is the real part, bottom-left block the imaginary part.
        X, Y = block_inv[:i, :i], block_inv[i:, :i]

        self.Hinv_fft = torch.cat([X, Y], -1)

    @staticmethod
    def _fft2(X_cplx):
        """Unnormalized 2-D FFT of a tensor with a trailing (real, imag) axis.

        The legacy callable ``torch.fft(x, 2)`` is used when it still exists;
        otherwise (``torch.fft`` is a module) the transform goes through
        ``torch.fft.fftn`` on a complex view and is converted back to the
        real (real, imag) layout.
        """
        if callable(torch.fft):
            return torch.fft(X_cplx, 2)
        # view_as_complex needs a unit-stride trailing axis of size 2.
        z = torch.view_as_complex(X_cplx.contiguous())
        return torch.view_as_real(torch.fft.fftn(z, dim=(-2, -1)))

    @staticmethod
    def _ifft2(X_cplx):
        """Inverse of :meth:`_fft2` (scaled by 1/N), in the same (real, imag) layout."""
        if callable(torch.fft):
            return torch.ifft(X_cplx, 2)
        z = torch.view_as_complex(X_cplx.contiguous())
        return torch.view_as_real(torch.fft.ifftn(z, dim=(-2, -1)))

    def _mul(self, X, Y_fft):
        """Apply the Fourier-domain operator ``Y_fft`` to the real tensor ``X``.

        ``X`` is transformed with a 2-D FFT, multiplied by ``Y_fft`` as a
        complex matrix product over the first two dimensions, transformed
        back, and the real part is returned.
        """
        X_fft = self._fft2(to_complex(X))
        YX_fft = matmul_cplx(X_fft, Y_fft)
        return re(self._ifft2(YX_fft)).squeeze(-1)

    def forward(self, X):
        """Apply W (the precomputed ``weight_fft_t``) to ``X``."""
        return self._mul(X, self.weight_fft_t)

    def t(self, X):
        """Apply W* (the precomputed ``weight_t_fft_t``) to ``X``."""
        return self._mul(X, self.weight_t_fft_t)

    def ridge_inverse(self, X):
        """Apply the precomputed inv(I + W*W) (``Hinv_fft``) to ``X``."""
        return self._mul(X, self.Hinv_fft)

if __name__ == "__main__":
    # Self-check: sweep image sizes, kernel sizes and paddings, and compare the
    # FFT-based operators against nn.Conv2d / nn.ConvTranspose2d.
    torch.random.manual_seed(0)
    min_img_sz = 28
    max_img_sz = 30
    batch_size = 10
    inch = 16
    outch = 32
    max_kernel = 10
    tol = 1e-5  # maximum absolute error accepted by every check below

    for img_sz in range(min_img_sz, max_img_sz+1):
        for kernel_size in range(1, max_kernel+1):
            for padding in range(0, kernel_size//2+1):
                print(f"Testing: image size={img_sz} | kernel size={kernel_size} | padding={padding}")
                X = torch.randn(batch_size, inch, img_sz, img_sz)
                conv = nn.Conv2d(inch, outch, kernel_size, padding=padding, bias=False)
                Y = conv(X)
                # The transposed convolution maps outch -> inch channels and
                # shares the weight of `conv`, so it is constructed with the
                # channel counts that match that weight's shape.
                convT = nn.ConvTranspose2d(outch, inch, kernel_size, padding=padding, bias=False)
                convT.weight = conv.weight
                XT = convT(Y)

                k = kernel_size//2
                # reduced padding when forward convolution shrinks in size
                r = k - padding
                # extra padding for even kernel, applied on the left/top only
                even = int(kernel_size % 2 == 0)
                lead = 2*k - r + even
                trail = 2*k - r

                # F.pad order is (left, right, top, bottom)
                X_pad = F.pad(X, (lead, trail, lead, trail))
                Y_pad = F.pad(Y, (2*k, 2*k, 2*k, 2*k))

                fft_conv = FFTConv2d(conv, X_pad.size())

                # Forward: crop the FFT result down to the size of conv(X).
                Y0 = fft_conv(X_pad)
                diff = Y0.size(-1) - Y.size(-1)
                crop = diff//2
                offset = diff % 2 # 1 if non-even padding, zero otherwise
                n = Y0.size(-1) - offset
                err = (Y0[:, :, crop:n-crop, crop:n-crop] - Y).abs().max().item()
                print(f"Forward pass error: {err:.4e} ({'pass' if err < tol else '**fail**'})")
                assert err < tol

                # Transpose: same cropping, shifted by one for even kernels.
                X0 = fft_conv.t(Y_pad)
                diff = X0.size(-1) - XT.size(-1)
                crop = diff//2
                offset = diff % 2
                n = X0.size(-1) - offset
                lo, hi = crop + even, n - crop + even
                err = (X0[:, :, lo:hi, lo:hi] - XT).abs().max().item()
                print(f"Transpose pass error: {err:.4e} ({'pass' if err < tol else '**fail**'})")
                assert err < tol

                # Ridge inverse: (I + W*W) applied to inv(I + W*W) x must return x.
                HinvX_pad = fft_conv.ridge_inverse(X_pad)
                err = (X_pad - (HinvX_pad + fft_conv.t(fft_conv(HinvX_pad)))).abs().max().item()
                print(f"Ridge inverse error: {err:.4e} ({'pass' if err < tol else '**fail**'})")
                assert err < tol

                print("")