# -*- coding: utf-8 -*-
# Author: Shuojue Yang (main contribution) and Xiangde Luo (minor modification for WORD and other datasets).
# Date:   16 Dec. 2021
# Implementation for simulation of the sparse scribble annotation based on the dense annotation for the WORD dataset and other datasets.
# # Reference:
# @article{luo2022scribbleseg,
# title={Scribble-Supervised Medical Image Segmentation via Dual-Branch Network and Dynamically Mixed Pseudo Labels Supervision},
# author={Xiangde Luo, Minhao Hu, Wenjun Liao, Shuwei Zhai, Tao Song, Guotai Wang, Shaoting Zhang},
# journal={Medical Image Computing and Computer Assisted Intervention -- MICCAI 2022},
# year={2022},
# pages={528--538}}

# @article{luo2022word,
# title={{WORD}: A large scale dataset, benchmark and clinical applicable study for abdominal organ segmentation from CT image},
# author={Xiangde Luo, Wenjun Liao, Jianghong Xiao, Jieneng Chen, Tao Song, Xiaofan Zhang, Kang Li, Dimitris N. Metaxas, Guotai Wang, and Shaoting Zhang},
# journal={Medical Image Analysis},
# volume={82},
# pages={102642},
# year={2022},
# publisher={Elsevier}}

# @misc{wsl4mis2020,
# title={{WSL4MIS}},
# author={Luo, Xiangde},
# howpublished={\url{https://github.com/Luoxd1996/WSL4MIS}},
# year={2021}}
# If you have any questions, please contact Xiangde Luo (https://luoxd1996.github.io).


import glob
import math
import random
import sys

import cv2
import numpy as np
import SimpleITK as sitk
from PIL import Image
from scipy import ndimage
from skimage.morphology import skeletonize

sys.setrecursionlimit(1000000)
seed = 2022
np.random.seed(seed)
random.seed(seed)


def random_rotation(image, max_angle=15):
    angle = np.random.uniform(-max_angle, max_angle)
    img = Image.fromarray(image)
    img_rotate = img.rotate(angle)
    return img_rotate


def translate_img(img, x_shift, y_shift):

    (height, width) = img.shape[:2]
    matrix = np.float32([[1, 0, x_shift], [0, 1, y_shift]])
    trans_img = cv2.warpAffine(img, matrix, (width, height))
    return trans_img


def get_largest_two_component_2D(img, print_info=False, threshold=None):
    """
    Get the largest two components of a binary volume
    inputs:
        img: the input 2D volume
        threshold: a size threshold
    outputs:
        out_img: the output volume 
    """
    s = ndimage.generate_binary_structure(2, 2)  # iterate structure
    labeled_array, numpatches = ndimage.label(img, s)  # labeling
    sizes = ndimage.sum(img, labeled_array, range(1, numpatches+1))
    sizes_list = [sizes[i] for i in range(len(sizes))]
    sizes_list.sort()
    if(print_info):
        print('component size', sizes_list)
    if(len(sizes) == 1):
        out_img = [img]
    else:
        if(threshold):
            max_size1 = sizes_list[-1]
            max_label1 = np.where(sizes == max_size1)[0] + 1
            if max_label1.shape[0] > 1:
                max_label1 = max_label1[0]
            component1 = labeled_array == max_label1
            out_img = [component1]
            for temp_size in sizes_list:
                if(temp_size > threshold):
                    temp_lab = np.where(sizes == temp_size)[0] + 1
                    temp_cmp = labeled_array == temp_lab[0]
                    out_img.append(temp_cmp)
            return out_img
        else:
            max_size1 = sizes_list[-1]
            max_size2 = sizes_list[-2]
            max_label1 = np.where(sizes == max_size1)[0] + 1
            max_label2 = np.where(sizes == max_size2)[0] + 1
            if max_label1.shape[0] > 1:
                max_label1 = max_label1[0]
            if max_label2.shape[0] > 1:
                max_label2 = max_label2[0]
            component1 = labeled_array == max_label1
            component2 = labeled_array == max_label2
            if(max_size2*10 > max_size1):
                out_img = [component1, component2]
            else:
                out_img = [component1]
    return out_img


class Cutting_branch(object):
    def __init__(self):
        self.lst_bifur_pt = 0
        self.branch_state = 0
        self.lst_branch_state = 0
        self.direction2delta = {0: [-1, -1], 1: [-1, 0], 2: [-1, 1], 3: [
            0, -1], 4: [0, 0], 5: [0, 1], 6: [1, -1], 7: [1, 0], 8: [1, 1]}

    def __find_start(self, lab):
        y, x = lab.shape
        idxes = np.asarray(np.nonzero(lab))
        for i in range(idxes.shape[1]):
            pt = tuple([idxes[0, i], idxes[1, i]])
            assert lab[pt] == 1
            directions = []
            for d in range(9):
                if d == 4:
                    continue
                if self.__detect_pt_bifur_state(lab, pt, d):
                    directions.append(d)
            if len(directions) == 1:
                start = pt
                self.start = start
                self.output[start] = 1
                return start
        start = tuple([idxes[0, 0], idxes[1, 0]])
        self.output[start] = 1
        self.start = start
        return start

    def __detect_pt_bifur_state(self, lab, pt, direction):

        d = direction
        y = pt[0] + self.direction2delta[d][0]
        x = pt[1] + self.direction2delta[d][1]
        if lab[y, x] > 0:
            return True
        else:
            return False

    def __detect_neighbor_bifur_state(self, lab, pt):
        directions = []
        for i in range(9):
            if i == 4:
                continue
            if self.output[tuple([pt[0] + self.direction2delta[i][0], pt[1] + self.direction2delta[i][1]])] > 0:
                continue
            if self.__detect_pt_bifur_state(lab, pt, i):
                directions.append(i)

        if len(directions) == 0:
            self.end = pt
            return False
        else:
            direction = random.sample(directions, 1)[0]
            next_pt = tuple([pt[0] + self.direction2delta[direction]
                            [0], pt[1] + self.direction2delta[direction][1]])
            if len(directions) > 1 and pt != self.start:
                self.lst_output = self.output*1
                self.previous_bifurPts.append(pt)
            self.output[next_pt] = 1
            pt = next_pt
            self.__detect_neighbor_bifur_state(lab, pt)

    def __detect_loop_branch(self, end):
        for d in range(9):
            if d == 4:
                continue
            y = end[0] + self.direction2delta[d][0]
            x = end[1] + self.direction2delta[d][1]
            if (y, x) in self.previous_bifurPts:
                self.output = self.lst_output * 1
                return True

    def __call__(self, lab, seg_lab, iterations=1):
        self.previous_bifurPts = []
        self.output = np.zeros_like(lab)
        self.lst_output = np.zeros_like(lab)
        components = get_largest_two_component_2D(lab, threshold=15)
        if len(components) > 1:
            for c in components:
                start = self.__find_start(c)
                self.__detect_neighbor_bifur_state(c, start)
        else:
            c = components[0]
            start = self.__find_start(c)
            self.__detect_neighbor_bifur_state(c, start)
        self.__detect_loop_branch(self.end)
        struct = ndimage.generate_binary_structure(2, 2)
        output = ndimage.morphology.binary_dilation(
            self.output, structure=struct, iterations=iterations)
        shift_y = random.randint(-6, 6)
        shift_x = random.randint(-6, 6)
        if np.sum(seg_lab) > 1000:
            output = translate_img(output.astype(np.uint8), shift_x, shift_y)
            output = random_rotation(output)
        output = output * seg_lab
        return output


def scrible_2d(label, iteration=[4, 10]):
    lab = label
    skeleton_map = np.zeros_like(lab, dtype=np.int32)
    for i in range(lab.shape[0]):
        if np.sum(lab[i]) == 0:
            continue
        struct = ndimage.generate_binary_structure(2, 2)
        if np.sum(lab[i]) > 900 and iteration != 0 and iteration != [0] and iteration != None:
            iter_num = math.ceil(
                iteration[0]+random.random() * (iteration[1]-iteration[0]))
            slic = ndimage.morphology.binary_erosion(
                lab[i], structure=struct, iterations=iter_num)
        else:
            slic = lab[i]
        sk_slice = skeletonize(slic, method='lee')
        sk_slice = np.asarray((sk_slice == 255), dtype=np.int32)
        skeleton_map[i] = sk_slice
    return skeleton_map


def scribble4class(label, class_id, class_num, iteration=[4, 10], cut_branch=True):
    label = (label == class_id)
    sk_map = scrible_2d(label, iteration=iteration)
    if cut_branch and class_id != 0:
        cut = Cutting_branch()
        for i in range(sk_map.shape[0]):
            lab = sk_map[i]
            if lab.sum() < 1:
                continue
            sk_map[i] = cut(lab, seg_lab=label[i])
    if class_id == 0:
        class_id = class_num
    return sk_map * class_id


def generate_scribble(label, iterations, cut_branch=True):
    class_num = np.max(label) + 1
    output = np.zeros_like(label, dtype=np.uint8)
    for i in range(class_num):
        it = iterations[i] if isinstance(iterations, list) else iterations
        scribble = scribble4class(
            label, i, class_num, it, cut_branch=cut_branch)
        output += scribble.astype(np.uint8)
    return output


if __name__ == "__main__":
    num = 0
    for i in sorted(glob.glob("../imgs/*_lab.nii.gz")):
        print("{} Begin".format(i.split("/")[-1]))
        itk_data = sitk.ReadImage(i)
        label = sitk.GetArrayFromImage(itk_data)
        num_classes = 3  # total segmentation classes
        output = generate_scribble(label, tuple([1, num_classes-1]))
        # ignore index for partially cross-entropy loss
        output[output == 0] = 255
        output[output == num_classes] = 0
        itk_scr = sitk.GetImageFromArray(output)
        itk_scr.CopyInformation(itk_data)
        sitk.WriteImage(itk_scr, i.replace('_lab.nii.gz', '_scribble.nii.gz'))
        print("{} End".format(i.split("/")[-1]))
        print(num)
        num += 1
