import os
import sys

import numpy as np
import tomosipo as ts
from imageio import imread
from tomopy.misc.corr import circ_mask
from tomosipo.torch_support import to_autograd

from latent import Latent
from network import bSiren

# Select the GPU to use.
# NOTE(review): this only takes effect if no CUDA context has been created yet;
# the imports above (tomosipo.torch_support, latent, network) are assumed not to
# initialise CUDA at import time — TODO confirm, or move this above the imports.
os.environ["CUDA_VISIBLE_DEVICES"] = '0'

node_num = 10  # number of joint reconstruction nodes
num_proj = 40  # number of projections
img_size = 501  # side length (pixels) of each square input slice
imgs = np.zeros((node_num, 1, img_size, img_size), dtype=np.float32)

# idx specifies the slice number; it is the first command-line argument.
if len(sys.argv) < 2:
    sys.exit(f"usage: {sys.argv[0]} SLICE_INDEX")
idx = int(sys.argv[1])

for i in range(node_num):
    # Load the original image for node i.
    # NOTE(review): the path is a placeholder that uses neither `i` nor `idx`,
    # so every node currently receives the same image — TODO fill in the real
    # per-node / per-slice path pattern.
    imgs[i, 0] = imread('.tiff')

# Normalise the whole stack jointly to [0, 1]. A constant stack would make the
# denominator zero and silently turn every pixel into NaN, so reject it.
lo, hi = imgs.min(), imgs.max()
if hi <= lo:
    raise ValueError("input images are constant; cannot normalise to [0, 1]")
imgs = (imgs - lo) / (hi - lo)

# Apply tomopy's circular mask (ratio=1) to every node's slice, stacked along
# axis 0; presumably this blanks pixels outside the inscribed circle — see the
# tomopy circ_mask docs. Then rotate each slice 90 degrees in the image plane.
# np.rot90 returns a view, so .copy() materialises a contiguous array
# (presumably what tomosipo expects).
imgs[:, 0] = circ_mask(imgs[:, 0], axis=0, ratio=1)
imgs = np.rot90(imgs, k=1, axes=(2, 3)).copy()

# Create the forward operator: a single-slice volume and a parallel-beam
# geometry with num_proj angles evenly spaced over [0, pi).
angles = np.linspace(0, np.pi, num_proj, endpoint=False)
vg = ts.volume(shape=(1, img_size, img_size), size=(1, 1, 1))
pg = ts.parallel(angles=angles, shape=(1, img_size), size=(1, 1))
A = ts.operator(vg, pg)
f = to_autograd(A)  # autograd-capable wrapper of A, handed to Latent below

# Generate the projection data: one (1, num_proj, img_size) block per node.
projs = np.zeros((node_num, 1, num_proj, img_size), dtype=np.float32)
for i, img in enumerate(imgs):
    projs[i] = A(img)

# Wrap up the data for the trainer.
data = {'projs': projs, 'recon': imgs}
# Create the Bayesian network model.
model = bSiren(network_depth=8, network_width=128, network_input_size=256,
               network_output_size=1, weight_scale=0, rho_offset=-10,
               lambda_kl=1e-14, device='cuda')
# Create the latent-variable trainer and run it.
# NOTE(review): 'experiement/' looks like a typo of 'experiment/'; kept as-is so
# existing output locations do not move.
latent = Latent(data=data, model=model, forward_op=f, glob_iters=100,
                local_epochs=300, burnin_epochs=(0, 10000), node_lr=1e-5,
                work_dir='experiement/', device='cuda')
latent.train()