import sys
import os
from argparse import ArgumentParser

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')))


import cv2
from joblib import Parallel, delayed
from skimage import io
from skimage.transform import rescale, iradon
import numpy as np
from tqdm import tqdm
import torch
from chip.utils.sinogram import compute_sinogram


def process_image(img, img_path, num_projections, downsample_factor=8, skip_existing=False):
    """Build one (degraded reconstruction, original image) pair and save both as .npy.

    The file ``img`` inside ``img_path`` is read as grayscale, passed through
    ``rescale(scale=1)``, converted to a torch tensor and handed to
    ``compute_sinogram``. The sinogram is shrunk along its width (axis 1) by
    ``downsample_factor`` with bicubic interpolation and then resized back to
    its exact original shape. This discards detail without changing the array
    size. The degraded sinogram is reconstructed with ``iradon`` (Shepp-Logan
    filter) and flipped vertically.

    Outputs:
      * ``data/low_res_{num_projections}/<stem>.npy``: the reconstruction.
      * ``data/high_res_{num_projections}/<stem>.npy``: the input image.

    Args:
        img: File name of the image, relative to ``img_path``.
        img_path: Directory containing ``img``.
        num_projections: Number of angles, evenly spaced over [0, 180) degrees,
            given to ``iradon``. It must equal the number of rows produced by
            ``compute_sinogram``; otherwise ``iradon`` raises ValueError. It
            also names the output directories.
        downsample_factor: Width reduction applied to the sinogram before it
            is resized back. Defaults to 8, the previously hard-coded value.
        skip_existing: If True, return early when both output files already
            exist. Defaults to False, which always recomputes.
    """
    stem = os.path.splitext(img)[0]
    lr_dir = f"data/low_res_{num_projections}"
    hr_dir = f"data/high_res_{num_projections}"
    path_lr = os.path.join(lr_dir, stem + '.npy')
    path_hr = os.path.join(hr_dir, stem + '.npy')

    if skip_existing and os.path.exists(path_lr) and os.path.exists(path_hr):
        return

    sc_gray_w = io.imread(os.path.join(img_path, img), as_gray=True)
    image = torch.tensor(rescale(sc_gray_w, scale=1, mode='reflect', channel_axis=None))

    hr_sino = compute_sinogram(image).numpy()
    height, width = hr_sino.shape[:2]

    # Mimic lowering the resolution by `downsample_factor` along the width
    # (axis 1) while keeping the array the same size. The up-resize uses an
    # explicit dsize: scaling back with fx=factor would give
    # round(width / factor) * factor columns, which differs from `width`
    # whenever width is not a multiple of the factor. When it is a multiple,
    # the result is identical to the fx form.
    lowres = cv2.resize(hr_sino, None, fx=1. / downsample_factor, fy=1.,
                        interpolation=cv2.INTER_CUBIC)
    lowres_big = cv2.resize(lowres, (width, height), interpolation=cv2.INTER_CUBIC)

    theta = np.linspace(0., 180., num_projections, endpoint=False)

    # iradon expects angles along the columns, hence the transpose.
    reconstruction_lowres = iradon(lowres_big.T, theta=theta, filter_name='shepp-logan')
    # Reverse the row order. Presumably this matches the orientation convention
    # of compute_sinogram (TODO confirm).
    reconstruction_lowres = reconstruction_lowres[::-1, :]

    # exist_ok=True: several joblib workers may try to create the same
    # directory at once. A separate exists() check followed by makedirs()
    # can then fail with FileExistsError.
    os.makedirs(lr_dir, exist_ok=True)
    os.makedirs(hr_dir, exist_ok=True)

    np.save(path_lr, reconstruction_lowres)
    np.save(path_hr, image.numpy())


def main():
    """Parse CLI options and run ``process_image`` on every file in ``--img_path``.

    The files are processed in parallel with joblib, using ``--num_cores`` jobs.
    """
    parser = ArgumentParser()
    parser.add_argument("--num_cores", type=int, default=10)
    parser.add_argument("--img_path", type=str, default='data/imgs_synthetic')
    parser.add_argument("--num_projections", type=int, default=90,
                        help="Number of projection angles passed to process_image.")
    args = parser.parse_args()

    img_path = args.img_path
    # Keep only regular files, since io.imread cannot read a subdirectory.
    # Sort the names so the processing order is deterministic.
    files = sorted(
        name for name in os.listdir(img_path)
        if os.path.isfile(os.path.join(img_path, name))
    )

    # Parallelizing using joblib
    Parallel(n_jobs=args.num_cores)(
        delayed(process_image)(img, img_path, args.num_projections)
        for img in tqdm(files)
    )


# Guard the entry point, as joblib recommends for process-based backends.
# Importing this module no longer starts the processing run.
if __name__ == "__main__":
    main()
