import argparse
import os

import SimpleITK as sitk

join = os.path.join
import cc3d
import numpy as np
from skimage import transform
from tqdm import tqdm


def get_parser():
    """Build the command-line parser for the NIfTI-to-npy conversion script.

    Returns:
        argparse.ArgumentParser: parser exposing ``--nii_path``, ``--gt_path``,
        ``--npy_path`` (all required), ``--img_name_suffix``,
        ``--gt_name_suffix`` and ``--proportion``.
    """
    parser = argparse.ArgumentParser()
    # The three paths have no sensible default: without them the script would
    # fail later with an opaque TypeError (or list the current directory), so
    # let argparse reject a missing value up front with a clear message.
    parser.add_argument(
        '--nii_path', type=str, required=True,
        help='Path to folder with nii images',
    )
    parser.add_argument(
        '--gt_path', type=str, required=True,
        help='Path to folder with nii ground truth masks (labels)',
    )
    parser.add_argument(
        '--img_name_suffix', type=str, default="_0000.nii.gz",
        help='Suffix appended to the case name to form the image file name',
    )
    parser.add_argument(
        '--gt_name_suffix', type=str, default=".nii.gz",
        help='Suffix stripped from a ground truth file name to get the case name',
    )
    parser.add_argument(
        '--npy_path', type=str, required=True,
        help='Path to save npy files (path to dataset)',
    )
    parser.add_argument(
        '--proportion', type=float, default=1.0,
        help='Proportion of slices to be sampled from each CT',
    )
    return parser


def main():
    """Convert paired NIfTI CT volumes and label masks into per-slice npy files.

    For every ground truth file in ``--gt_path`` that has a matching image in
    ``--nii_path``:

    1. Small objects are removed from the mask (3D threshold of 1000 voxels,
       then a per-slice 2D threshold of 100 pixels).
    2. Only slices that still contain labels are kept.
    3. The image is clipped to the CT window (level 40, width 400) and
       rescaled to ``uint8`` in [0, 255].
    4. The cropped image and mask are written as ``*_img.nii.gz`` /
       ``*_gt.nii.gz`` for sanity checking.
    5. Each kept slice is selected with probability ``--proportion``, resized
       to 1024x1024, and saved as ``.npy`` under ``imgs/`` (3-channel image
       normalized to [0, 1]) and ``gts/`` (``uint8`` mask).

    Output goes to ``<npy_path>/CT_Abd``.
    """
    args = get_parser().parse_args()

    # convert nii images to npy files, including original image and corresponding masks
    modality = "CT"
    anatomy = "Abd"  # anatomy + dataset name
    img_name_suffix = args.img_name_suffix
    gt_name_suffix = args.gt_name_suffix
    prefix = modality + "_" + anatomy + "_"

    nii_path = args.nii_path
    gt_path = args.gt_path
    # Join with a path separator so the dataset lands in "<npy_path>/CT_Abd"
    # whether or not --npy_path was given with a trailing slash.
    npy_path = join(args.npy_path, prefix[:-1])
    os.makedirs(join(npy_path, "gts"), exist_ok=True)
    os.makedirs(join(npy_path, "imgs"), exist_ok=True)

    image_size = 1024
    voxel_num_thre2d = 100
    voxel_num_thre3d = 1000

    names = sorted(os.listdir(gt_path))
    print(f"ori # files {len(names)=}")
    # keep only the ground truth files that have a matching image file
    names = [
        name
        for name in names
        if os.path.exists(join(nii_path, name.split(gt_name_suffix)[0] + img_name_suffix))
    ]
    print(f"after sanity check # files {len(names)=}")
    tumor_id = None  # only set this when there are multiple tumors; convert semantic masks to instance masks
    # set window level and width
    # https://radiopaedia.org/articles/windowing-ct
    WINDOW_LEVEL = 40  # only for CT images
    WINDOW_WIDTH = 400  # only for CT images

    # %% save preprocessed images and masks as npy files
    for gt_name in tqdm(names):  # use all cases
        case_id = gt_name.split(gt_name_suffix)[0]
        image_name = case_id + img_name_suffix
        out_stem = prefix + case_id  # shared stem of every file written for this case

        gt_sitk = sitk.ReadImage(join(gt_path, gt_name))
        gt_data_ori = np.uint8(sitk.GetArrayFromImage(gt_sitk))
        # label tumor masks as instances and remove from gt_data_ori
        if tumor_id is not None:
            tumor_bw = np.uint8(gt_data_ori == tumor_id)
            gt_data_ori[tumor_bw > 0] = 0
            # label tumor masks as instances
            tumor_inst, tumor_n = cc3d.connected_components(
                tumor_bw, connectivity=26, return_N=True
            )
            # put the tumor instances back to gt_data_ori
            # NOTE(review): gt_data_ori is uint8, so instance ids above 255
            # cannot be represented here -- confirm before enabling tumor_id.
            gt_data_ori[tumor_inst > 0] = (
                tumor_inst[tumor_inst > 0] + np.max(gt_data_ori) + 1
            )

        # exclude the objects with less than 1000 pixels in 3D
        gt_data_ori = cc3d.dust(
            gt_data_ori, threshold=voxel_num_thre3d, connectivity=26, in_place=True
        )
        # remove small objects with less than 100 pixels in 2D slices
        # reason: for such small objects, the main challenge is detection rather than segmentation
        for slice_i in range(gt_data_ori.shape[0]):
            gt_i = gt_data_ori[slice_i, :, :]
            gt_data_ori[slice_i, :, :] = cc3d.dust(
                gt_i, threshold=voxel_num_thre2d, connectivity=8, in_place=True
            )

        # find non-zero slices (axis 0 is the slice axis)
        z_index, _, _ = np.where(gt_data_ori > 0)
        z_index = np.unique(z_index)
        if len(z_index) == 0:
            # nothing left to segment in this case after small-object removal
            continue

        # crop the ground truth with non-zero slices
        gt_roi = gt_data_ori[z_index, :, :]
        # load image and preprocess
        img_sitk = sitk.ReadImage(join(nii_path, image_name))
        image_data = sitk.GetArrayFromImage(img_sitk)

        # nii preprocess start: pick the intensity range to keep
        if modality == "CT":
            lower_bound = WINDOW_LEVEL - WINDOW_WIDTH / 2
            upper_bound = WINDOW_LEVEL + WINDOW_WIDTH / 2
        else:
            foreground = image_data[image_data > 0]
            lower_bound, upper_bound = np.percentile(foreground, [0.5, 99.5])

        # clip to that range and rescale to [0, 255]
        image_data_pre = np.clip(image_data, lower_bound, upper_bound)
        intensity_min = np.min(image_data_pre)
        # Floor the range so a constant volume yields zeros instead of NaNs
        # (same 1e-8 guard as the per-slice normalization below).
        intensity_range = max(float(np.max(image_data_pre) - intensity_min), 1e-8)
        image_data_pre = (image_data_pre - intensity_min) / intensity_range * 255.0
        if modality != "CT":
            # keep the background at zero for non-CT modalities
            image_data_pre[image_data == 0] = 0

        image_data_pre = np.uint8(image_data_pre)
        img_roi = image_data_pre[z_index, :, :]

        # save the image and ground truth as nii files for sanity check;
        # they can be removed
        sitk.WriteImage(
            sitk.GetImageFromArray(img_roi),
            join(npy_path, out_stem + "_img.nii.gz"),
        )
        sitk.WriteImage(
            sitk.GetImageFromArray(gt_roi),
            join(npy_path, out_stem + "_gt.nii.gz"),
        )

        # save each CT slice as npy file based on a proportion to be saved
        for i in range(img_roi.shape[0]):
            if np.random.uniform(0, 1) >= args.proportion:
                continue

            # replicate the grayscale slice to 3 channels, then resize
            img_3c = np.repeat(img_roi[i, :, :][:, :, None], 3, axis=-1)
            resize_img_skimg = transform.resize(
                img_3c,
                (image_size, image_size),
                order=3,
                preserve_range=True,
                mode="constant",
                anti_aliasing=True,
            )
            resize_img_skimg_01 = (resize_img_skimg - resize_img_skimg.min()) / np.clip(
                resize_img_skimg.max() - resize_img_skimg.min(), a_min=1e-8, a_max=None
            )  # normalize to [0, 1], (H, W, 3)

            # order=0 (nearest neighbour) so label values are not interpolated
            resize_gt_skimg = transform.resize(
                gt_roi[i, :, :],
                (image_size, image_size),
                order=0,
                preserve_range=True,
                mode="constant",
                anti_aliasing=False,
            )
            resize_gt_skimg = np.uint8(resize_gt_skimg)
            assert resize_img_skimg_01.shape[:2] == resize_gt_skimg.shape

            # i is the index within the cropped (non-zero) slices
            slice_file = f"{out_stem}-{i:03d}.npy"
            np.save(join(npy_path, "imgs", slice_file), resize_img_skimg_01)
            np.save(join(npy_path, "gts", slice_file), resize_gt_skimg)


# Run the conversion only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
