'''
This first python function maps tomographic data from sinogram space to reconstruction space according
to the Fourier slice theorem. This theorem states that the Fourier transform of the reconstruction
(expressed in polar coordinates) equals the 1D Fourier transform of the sinogram:
FFT2[r](k, theta) = FFT1[s](k, theta) where r labels the reconstruction and s the sinogram.
(k, theta) are the polar coordinates in Fourier-reconstruction space, i.e.
(k_x, k_y) = k (cos theta, sin theta) where (k_x, k_y) are the cartesian coordinates of the
FFT of the reconstruction. The result of this gridRec-like algorithm is, however, not the full
reconstruction (in real space) but its FFT in cartesian coordinates, sampled along specific angles
theta_s only, i.e. FFT2[r](k, theta_s) interpolated onto the cartesian grid. This
interpolation/convolution implicitly high-pass filters the sinogram ('radially').

The second python function is standard gridrec

The main section is to a) exemplify the usage of these functions, b) demonstrate how
undersampling a sinogram affects reconstruction quality and c) demonstrate the effect of the (missing)
interpolation correction in Fourier space (the correction can only easily be applied in real space).
Angular undersampling has the effect of 'shattering' the reconstruction, whereas omitting the real
space correction/division overlays a 'brightness-gradient' onto the reconstruction.

Finally note that the S&L sinogram used herein is an analytically generated (noise-free) sinogram.
'''
import numpy as np

from chip.gridrec.rt import gridRecFFT
from chip.gridrec.postCorrectRec import postCorrect


def gridRecAUS(sinogram, angles, subsampledAngles, centerOfRotation = 0.0):
    '''
    Map an (angularly subsampled) sinogram to the Fourier space of its reconstruction.
    The actual gridding is delegated to gridRecFFT; this wrapper only resolves the
    default center of rotation.
    Params:
          sinogram: 2d float64 np array of shape (numberOfAngles, numberOfPixels)
          angles:   1d float64 np array of shape (numberOfAngles)
                    Angular positions of the rotated sample (in radians)
          subsampledAngles: 1d int32 array of shape (numberOfSubsampledAngles)
                    Indices into the angles array selecting the projections to be used
          centerOfRotation: (float) The default value 0.0 acts as "not specified"; in that
                    case numberOfPixels/2 is used (ideal alignment)
    Returns:
            The subsampled sinogram mapped to Fourier space (zero frequency in center).
            Type: 2d complex128 np array of shape (numberOfPixels, numberOfPixels)
            This equals the Fourier transform of the angularly subsampled reconstruction,
            up to the correction factor for the interpolation function.
    '''
    # 0.0 is the sentinel for "use the middle of the detector row".
    rotationCenter = sinogram.shape[1]/2 if centerOfRotation == 0.0 else centerOfRotation
    # NOTE(review): the meaning of the leading argument 1 is defined by gridRecFFT - confirm there.
    fourierReconstruction = gridRecFFT(1, sinogram, angles, subsampledAngles, rotationCenter)
    return fourierReconstruction
    
    
def gridRec(sinogram, angles, subsampledAngles = None, centerOfRotation = 0.0, doPostCorrect = True, doNormalize = True, zeroPaddingFactor = 1):
    '''
    This function implements the inverse Radon transform (gridRec reconstruction).
    A subset of angles can be chosen from the overall sinogram.
    Params:
          sinogram: 2d float64 np array of shape (numberOfAngles, numberOfPixels)
          angles:   1d float64 np array of shape (numberOfAngles)
                    This array contains the angular positions of the rotated sample (in radians)
          subsampledAngles: 1d int32 array of shape (numberOfSubsampledAngles)
                    This array contains the indices to be considered in angles array.
                    If None, all angular positions are used.
          centerOfRotation: (float) The default value 0.0 acts as "not specified"; in that
                    case numberOfPixels/2 is chosen (ideal alignment)
          doPostCorrect: Boolean defining if division by IFFT[kernel] should be applied to reconstruction
          doNormalize: Boolean defining if reconstruction should be normalized (values in [0..1]).
                    Normalizing also zeroes everything outside the inscribed reconstruction circle.
          zeroPaddingFactor: (int) The sinogram width is first rounded up to the next power of two
                    and then multiplied by 2**zeroPaddingFactor (so the default of 1 doubles it).
    Returns:
            Reconstruction
            Type: 2d float64 np array of shape (numberOfPixels, numberOfPixels)
    '''
    def applyDiskMask(reconstruction):
        # Zero (in place) every pixel outside the circle inscribed around the image center.
        numberOfRows, numberOfColumns = reconstruction.shape
        radius = min(numberOfRows, numberOfColumns)//2
        rows, columns = np.ogrid[:numberOfRows, :numberOfColumns]
        dist = (rows - numberOfRows//2)**2 + (columns - numberOfColumns//2)**2
        reconstruction[dist > radius**2] = 0
        return reconstruction

    numberOfAngularPositions = sinogram.shape[0]
    numberOfOriginalXPixels = sinogram.shape[1]
    if centerOfRotation == 0.0:
        centerOfRotation = numberOfOriginalXPixels/2
    if subsampledAngles is None:
        subsampledAngles = np.arange(numberOfAngularPositions).astype(np.int32)

    # Always pad to the next power of 2 (good for the FFT algorithm), then apply the extra factor.
    numberOfZeroPaddedXPixels = int(2**np.ceil(np.log2(numberOfOriginalXPixels)))
    numberOfZeroPaddedXPixels *= 2**zeroPaddingFactor
    isPadded = numberOfZeroPaddedXPixels != numberOfOriginalXPixels
    if isPadded:
        xStart = (numberOfZeroPaddedXPixels - numberOfOriginalXPixels)//2
        # Derive the stop index from the data width rather than mirroring xStart: for an odd
        # number of pixels the padding is asymmetric and a mirrored slice is one column too wide.
        xStop = xStart + numberOfOriginalXPixels
        zeroPaddedSinogram = np.zeros((numberOfAngularPositions, numberOfZeroPaddedXPixels))
        zeroPaddedSinogram[:, xStart:xStop] = sinogram
        # NOTE(review): only the padded sinogram is cast to float32; an unpadded one keeps its
        # dtype - confirm which dtype gridRecFFT expects.
        sinogram = zeroPaddedSinogram.astype(np.float32)
        # The rotation axis moves together with the data inside the padded row.
        centerOfRotation += xStart

    recFFT = gridRecAUS(sinogram, angles, subsampledAngles, centerOfRotation) #FFT of reconstruction (along subs angles)
    recFFT = np.fft.ifftshift(recFFT) #standard ordering of frequencies
    rec = np.fft.ifft2(recFFT).real #real part of IFFT is reconstruction (imag part = 0 in theory)
    reconstruction = np.fft.fftshift(rec) #reorder values to be visualized
    if doPostCorrect:
        reconstruction = postCorrect(reconstruction) #Apply correction filter (divide by IFFT of interpolation kernel)
    if isPadded:
        # Cut the region occupied by the original data back out of the padded reconstruction.
        reconstruction = reconstruction[xStart:xStop, xStart:xStop]
    if doNormalize:
        reconstruction = applyDiskMask(reconstruction)
        reconstruction -= reconstruction.min()
        peak = reconstruction.max()
        if peak != 0: #a constant image stays all-zero instead of turning into NaNs
            reconstruction /= peak
    return reconstruction