## Load station density images and replace each one with its Radon transform (sinogram)
import numpy as np
#import matplotlib.pyplot as plt
from skimage.transform import radon
from skimage.transform import iradon
import sys
import warnings

if __name__ == '__main__':
    # Usage: python <this script> [INPUT.npy [OUTPUT.npy]]
    #
    # Loads a stack of 2-D images indexed along axis 0, replaces every image
    # by its Radon transform (sinogram) and saves the resulting stack so the
    # 2-D functions can be read from R.
    if len(sys.argv) == 1:
        in_path = "../data/station.npy"
        warnings.warn('use default working file ../data/station.npy')
    else:
        in_path = sys.argv[1]
    # Optional second argument selects where the result is written. It
    # defaults to the input path, i.e. the input file is overwritten in place
    # as before -- running the script twice on the same file would transform
    # the sinograms again, so prefer passing an explicit output path.
    out_path = sys.argv[2] if len(sys.argv) > 2 else in_path

    station = np.load(in_path)
    if station.shape[0] == 0:
        raise ValueError(f'{in_path} contains no images to transform')

    # Projection angles: one per pixel along the longer image side, evenly
    # spaced over [0, 180) degrees.
    theta = np.linspace(0., 180., max(station.shape[1:]), endpoint=False)

    # NOTE(review): radon() defaults to circle=True, which assumes each image
    # is zero outside its inscribed circle (skimage warns otherwise) --
    # confirm this holds for the density images.
    # Each sinogram has shape (min(image.shape), len(theta)): axis 0 is the
    # projection position, axis 1 the projection angle. That equals the image
    # shape only when rows <= columns, so the sinograms are collected and
    # stacked instead of being written into a buffer shaped like `station`.
    sinograms = [radon(image, theta=theta) for image in station]
    # Keep the float64 dtype the previous np.zeros-based buffer produced.
    station_radon = np.stack(sinograms).astype(np.float64, copy=False)

    np.save(out_path, station_radon)



###############################
######### Case study    #######
###############################

def reconstruction_errors(image, reconstruction):
    """Return the RMS and relative L1 error of *reconstruction* against *image*.

    Parameters
    ----------
    image : array_like
        Reference image.
    reconstruction : array_like
        Reconstructed image; must have the same shape as *image*.

    Returns
    -------
    tuple of float
        ``(rms, rel_l1)`` with ``rms = sqrt(mean((reconstruction - image)**2))``
        and ``rel_l1 = sum(|reconstruction - image|) / sum(|image|)``.
        ``rel_l1`` is ``nan``/``inf`` (with a numpy RuntimeWarning) when
        *image* is identically zero.
    """
    image = np.asarray(image, dtype=float)
    reconstruction = np.asarray(reconstruction, dtype=float)
    error = reconstruction - image
    rms = float(np.sqrt(np.mean(error ** 2)))
    rel_l1 = float(np.sum(np.abs(error)) / np.sum(np.abs(image)))
    return rms, rel_l1


def case_study(station, index=0, *, show=True, savepath=None):
    """Plot one station's sinogram and check its filtered-back-projection inverse.

    Takes ``station[index]`` (index 0 is the FRESNO station according to the
    original notes). It computes the sinogram with ``radon`` and plots the
    image next to the sinogram. It then reconstructs the image by filtered
    back projection (``iradon`` with the ramp filter) and plots the
    reconstruction and its error. The RMS and relative L1 reconstruction
    errors are printed.

    Parameters
    ----------
    station : numpy.ndarray
        Stack of 2-D images indexed along axis 0, e.g. the array the script
        entry point loads.
    index : int, optional
        Which image of the stack to analyse (default 0).
    show : bool, optional
        Call ``plt.show()`` once both figures are drawn (default True).
    savepath : str or None, optional
        If given, the image/sinogram figure is saved to this path.

    Returns
    -------
    tuple of numpy.ndarray
        ``(sinogram, reconstruction_fbp)``.
    """
    # Local import: matplotlib is only needed for this optional diagnostic.
    import matplotlib.pyplot as plt

    image = station[index, :, :]
    theta = np.linspace(0., 180., max(image.shape), endpoint=False)
    sinogram = radon(image, theta=theta)

    # Figure 1: original image and its sinogram.
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8, 4.5))
    ax1.set_title("Original")
    ax1.imshow(image, cmap=plt.cm.Greys_r)

    # dx is half an angular step, so each sinogram column is centred on its
    # angle. NOTE(review): dy is 0.5 / n_positions rather than half a pixel
    # (0.5) -- check whether that was intended.
    dx, dy = 0.5 * 180.0 / max(image.shape), 0.5 / sinogram.shape[0]
    ax2.set_title("Radon transform\n(Sinogram)")
    ax2.set_xlabel("Projection angle (deg)")
    ax2.set_ylabel("Projection position (pixels)")
    ax2.imshow(sinogram, cmap=plt.cm.Greys_r,
               extent=(-dx, 180.0 + dx, -dy, sinogram.shape[0] + dy),
               aspect='auto')
    fig.tight_layout()
    if savepath is not None:
        fig.savefig(savepath)

    # Filtered back projection with the ramp filter.
    # NOTE(review): with circle=True, iradon returns a square image whose side
    # is sinogram.shape[0]; the comparison below assumes square input images.
    reconstruction_fbp = iradon(sinogram, theta=theta, filter_name='ramp', circle=True)
    rms, rel_l1 = reconstruction_errors(image, reconstruction_fbp)
    print(f'FBP rms reconstruction error: {rms:.3g}')

    # Figure 2: reconstruction and reconstruction error.
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8, 4.5),
                                   sharex=True, sharey=True)
    ax1.set_title("Reconstruction\nFiltered back projection")
    ax1.imshow(reconstruction_fbp, cmap=plt.cm.Greys_r)
    ax2.set_title("Reconstruction error\nFiltered back projection")
    ax2.imshow(reconstruction_fbp - image, cmap=plt.cm.Greys_r)
    if show:
        plt.show()
    print(f'FBP relative L1 reconstruction error: {rel_l1:.3g}')
    return sinogram, reconstruction_fbp

