import numpy as np
#import matplotlib.pyplot as plt
from skimage.transform import iradon
import sys
import warnings

def _reconstruct_stack(sinograms, filter_name='ramp'):
    """Reconstruct every slice of a sinogram stack by filtered back projection.

    Parameters
    ----------
    sinograms : numpy.ndarray
        3-D array indexed as ``sinograms[i, :, :]``. Each slice is handed to
        ``skimage.transform.iradon`` as one sinogram: its rows are detector
        positions and its columns are projection angles.
    filter_name : str, optional
        Filter forwarded to ``iradon`` (default ``'ramp'``).

    Returns
    -------
    numpy.ndarray
        Float array of shape ``(n, size, size)`` where ``size`` is
        ``sinograms.shape[1]``; slice ``i`` reconstructs ``sinograms[i]``.

    Raises
    ------
    ValueError
        If ``sinograms`` is not 3-D.
    """
    sinograms = np.asarray(sinograms)
    if sinograms.ndim != 3:
        raise ValueError(
            f"expected a 3-D stack of sinograms, got shape {sinograms.shape}")
    n_slices, size, n_angles = sinograms.shape
    # One projection angle per column, evenly spaced over [0, 180) degrees.
    theta = np.linspace(0., 180., n_angles, endpoint=False)
    # iradon produces a size x size image, so the output cannot reuse the
    # sinogram's (size x n_angles) shape unless the two happen to coincide.
    # output_size is passed explicitly so the shape does not depend on
    # iradon's defaults.
    recon = np.zeros((n_slices, size, size))
    for i, sinogram in enumerate(sinograms):
        recon[i] = iradon(sinogram, theta=theta, output_size=size,
                          filter_name=filter_name)
    return recon


if __name__ == '__main__':
    # Usage: python <script> [INPUT.npy [OUTPUT.npy]]
    # INPUT holds a stack of sinograms. OUTPUT defaults to INPUT, so the
    # sinograms are overwritten in place by their reconstructions --
    # presumably so the calling side (the original note mentions handing the
    # 2D function back to R) can read the result from the same path; confirm
    # before changing that default.
    default_path = "../data/station_radon_rec.npy"
    if len(sys.argv) > 1:
        in_path = sys.argv[1]
    else:
        in_path = default_path
        warnings.warn(f'use default working file {in_path}')
    out_path = sys.argv[2] if len(sys.argv) > 2 else in_path

    station_radon_rec = np.load(in_path)
    station_rec = _reconstruct_stack(station_radon_rec, filter_name='ramp')
    # Note: np.save appends ".npy" when out_path does not already end with it.
    np.save(out_path, station_rec)



###############################
######### Case study    #######
###############################
# NOTE(review): the triple-quoted string below is disabled exploratory code,
# not a docstring -- its contents are never executed. It reconstructs the
# first slice of ../data/station_radon_rec.npy, prints the RMS error against
# ../data/station.npy, and plots images and sinograms side by side.
# Before re-enabling it:
#   - `plt` is only imported in a commented-out line at the top of the file;
#   - `theta` is not defined in this snippet (the script above builds it
#     inside its own __main__ block), so the iradon call would fail.
'''
if __name__ == "__main__":
    station_radon = np.load("../data/station_radon.npy")
    station_radon_rec = np.load("../data/station_radon_rec.npy")
    station = np.load("../data/station.npy")
    sinogram = station_radon_rec[0,:,:]
    image = station[0,:,:] 
    reconstruction_fbp = iradon(sinogram, theta=theta, filter_name='ramp')
    error = reconstruction_fbp - image
    print(f'FBP rms reconstruction error: {np.sqrt(np.mean(error**2)):.3g}')

    imkwargs = dict(vmin=-0.001, vmax=0.001)
    fig, ax = plt.subplots(1,3,figsize = (8,6),sharex=True, sharey = True)
    ax[0].set_title("Original image")
    ax[0].imshow(image, cmap=plt.cm.Greys_r)
    ax[1].set_title("Reconstruction \nFiltered back projection")
    ax[1].imshow(reconstruction_fbp, cmap=plt.cm.Greys_r)
    ax[2].set_title("Reconstruction error\nFiltered back projection")
    pos2 = ax[2].imshow(reconstruction_fbp-image, cmap=plt.cm.Greys_r, **imkwargs)
    #fig.colorbar(pos2, ax=ax[2],location='right', anchor=(0, 0.4), shrink=0.4)
    plt.show()

    imkwargs = dict(vmin=-0.001, vmax=0.001)
    fig, ax = plt.subplots(1,2,figsize = (8,6),sharex=True, sharey = True)
    ax[0].imshow(station_radon_rec[0,:,:], cmap=plt.cm.Greys_r)
    ax[1].imshow(station_radon[0,:,:], cmap=plt.cm.Greys_r)
    #fig.colorbar(pos2, ax=ax[2],location='right', anchor=(0, 0.4), shrink=0.4)
    plt.show()

    

'''
    




