#    Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
#    Licensed under the Apache License, Version 2.0 (the "License");
#    you may not use this file except in compliance with the License.
#    You may obtain a copy of the License at
#
#        http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS,
#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#    See the License for the specific language governing permissions and
#    limitations under the License.

from copy import deepcopy

from batchgenerators.transforms.abstract_transforms import AbstractTransform
from skimage.morphology import label, ball
from skimage.morphology.binary import binary_erosion, binary_dilation, binary_closing, binary_opening
import numpy as np


class RemoveRandomConnectedComponentFromOneHotEncodingTransform(AbstractTransform):
    """
    Randomly removes one connected component from selected channels of ``data_dict[key]``.

    The array is indexed as ``data[b, c]``, i.e. the first axis is iterated per sample and the second axis is
    addressed through ``channel_idx``. The array is modified in place.
    """

    def __init__(self, channel_idx, key="data", p_per_sample=0.2, fill_with_other_class_p=0.25,
                 dont_do_if_covers_more_than_X_percent=0.25, p_per_label=1):
        """
        :param channel_idx: can be list or int (tuples are accepted as well). Channels of ``data_dict[key]`` to
        work on
        :param key: key of the array in the data dict that is modified
        :param p_per_sample: probability with which a sample (index along the first axis) is processed at all
        :param fill_with_other_class_p: probability with which the removed component is set to 1 in one randomly
        chosen other channel of channel_idx
        :param dont_do_if_covers_more_than_X_percent: dont_do_if_covers_more_than_X_percent=0.25 is 25%! Only
        components that cover less than this fraction of the voxels of the channel are candidates for removal
        :param p_per_label: probability with which each channel of a selected sample is processed
        """
        self.p_per_label = p_per_label
        self.dont_do_if_covers_more_than_X_percent = dont_do_if_covers_more_than_X_percent
        self.fill_with_other_class_p = fill_with_other_class_p
        self.p_per_sample = p_per_sample
        self.key = key
        if not isinstance(channel_idx, (list, tuple)):
            channel_idx = [channel_idx]
        self.channel_idx = channel_idx

    def __call__(self, **data_dict):
        """
        Applies the transform to ``data_dict[self.key]`` and returns the (updated) data dict.
        """
        data = data_dict.get(self.key)
        for b in range(data.shape[0]):
            if np.random.uniform() < self.p_per_sample:
                for c in self.channel_idx:
                    if np.random.uniform() < self.p_per_label:
                        self._remove_random_component(data, b, c)
        data_dict[self.key] = data
        return data_dict

    def _remove_random_component(self, data, b, c):
        """
        Removes one randomly chosen, sufficiently small connected component from ``data[b, c]`` in place.
        """
        # label() does not need a private copy of the channel, it returns a new label map
        lab, num_comp = label(data[b, c], return_num=True)
        if num_comp == 0:
            return

        # a single histogram pass yields the size of every component (index 0 is the background of the label
        # map and is dropped). This replaces one full-volume comparison per component.
        component_sizes = np.bincount(lab.ravel(), minlength=num_comp + 1)[1:]
        max_size = data[b, c].size * self.dont_do_if_covers_more_than_X_percent
        component_ids = [i for i, size in enumerate(component_sizes, start=1) if size < max_size]
        if len(component_ids) == 0:
            return

        random_component = np.random.choice(component_ids)
        component_mask = lab == random_component
        data[b, c][component_mask] = 0

        if np.random.uniform() < self.fill_with_other_class_p:
            other_ch = [i for i in self.channel_idx if i != c]
            if len(other_ch) > 0:
                # hand the freed voxels to a different channel instead of leaving them empty
                other_class = np.random.choice(other_ch)
                data[b, other_class][component_mask] = 1


class MoveSegAsOneHotToData(AbstractTransform):
    """
    Takes one channel of ``data_dict[key_origin]``, converts it to a one-hot encoding (one channel per entry of
    ``all_seg_labels``) and appends these channels to ``data_dict[key_target]`` along axis 1.
    """

    def __init__(self, channel_id, all_seg_labels, key_origin="seg", key_target="data", remove_from_origin=True):
        """
        :param channel_id: index (axis 1) of the channel in the origin array that is one-hot encoded. Negative
        values count from the last channel
        :param all_seg_labels: label values to encode. The i-th value becomes the i-th appended channel
        :param key_origin: key of the array the channel is taken from
        :param key_target: key of the array the one-hot channels are appended to
        :param remove_from_origin: if True, the encoded channel is dropped from the origin array
        """
        self.remove_from_origin = remove_from_origin
        self.all_seg_labels = all_seg_labels
        self.key_target = key_target
        self.key_origin = key_origin
        self.channel_id = channel_id

    def __call__(self, **data_dict):
        """
        Applies the transform and returns the (updated) data dict.
        """
        origin = data_dict.get(self.key_origin)
        target = data_dict.get(self.key_target)

        # Resolve negative ids first. Without this, channel_id=-1 yields the empty slice [-1:0] below and the
        # channel is never matched by the removal filter further down.
        channel_id = self.channel_id
        if channel_id < 0:
            channel_id += origin.shape[1]

        seg = origin[:, channel_id:channel_id + 1]
        seg_onehot = np.zeros((seg.shape[0], len(self.all_seg_labels), *seg.shape[2:]), dtype=seg.dtype)
        for i, l in enumerate(self.all_seg_labels):
            seg_onehot[:, i][seg[:, 0] == l] = 1
        target = np.concatenate((target, seg_onehot), 1)
        data_dict[self.key_target] = target

        if self.remove_from_origin:
            remaining_channels = [i for i in range(origin.shape[1]) if i != channel_id]
            origin = origin[:, remaining_channels]
            data_dict[self.key_origin] = origin
        return data_dict


class ApplyRandomBinaryOperatorTransform(AbstractTransform):
    """
    Applies a randomly chosen binary morphological operation to selected channels of ``data_dict[key]``.

    The array is indexed as ``data[b, c]`` and modified in place. Voxels that an operation adds to a channel are
    set to 0 in all other channels of ``channel_idx``.
    """

    def __init__(self, channel_idx, p_per_sample=0.3, any_of_these=(binary_dilation, binary_erosion, binary_closing,
                                                                    binary_opening),
                 key="data", strel_size=(1, 10), p_per_label=1):
        """
        :param channel_idx: can be int, list or tuple. Channels of ``data_dict[key]`` to work on
        :param p_per_sample: probability with which a sample (index along the first axis) is processed at all
        :param any_of_these: operations to choose from. Each is called as ``operation(image, structuring_element)``
        :param key: key of the array in the data dict that is modified
        :param strel_size: (low, high) range from which the radius of the ball structuring element is drawn
        uniformly
        :param p_per_label: probability with which each channel of a selected sample is processed
        """
        self.p_per_label = p_per_label
        self.strel_size = strel_size
        self.key = key
        self.any_of_these = any_of_these
        self.p_per_sample = p_per_sample

        # __call__ shuffles the channel order, which needs a mutable sequence. Tuples are therefore converted
        # instead of being rejected by an assert (which is stripped when running with -O)
        if isinstance(channel_idx, tuple):
            channel_idx = list(channel_idx)
        elif not isinstance(channel_idx, list):
            channel_idx = [channel_idx]
        self.channel_idx = channel_idx

    def __call__(self, **data_dict):
        """
        Applies the transform to ``data_dict[self.key]`` and returns the (updated) data dict.
        """
        data = data_dict.get(self.key)
        for b in range(data.shape[0]):
            if np.random.uniform() < self.p_per_sample:
                # work on a shuffled copy so that the configured channel order is left untouched
                ch = list(self.channel_idx)
                np.random.shuffle(ch)
                for c in ch:
                    if np.random.uniform() < self.p_per_label:
                        operation = np.random.choice(self.any_of_these)
                        # NOTE(review): ball() creates a 3D structuring element, so data[b, c] is presumably
                        # expected to be a 3D volume - confirm
                        selem = ball(np.random.uniform(*self.strel_size))
                        # astype already returns a fresh array, so data[b, c] is not aliased
                        workon = data[b, c].astype(int)
                        res = operation(workon, selem).astype(workon.dtype)
                        data[b, c] = res

                        # if class was added, we need to remove it in ALL other channels to keep one hot encoding
                        # properties
                        # we modify data
                        other_ch = [i for i in ch if i != c]
                        if len(other_ch) > 0:
                            was_added_mask = (res - workon) > 0
                            for oc in other_ch:
                                data[b, oc][was_added_mask] = 0
                            # if class was removed, leave it at background
        data_dict[self.key] = data
        return data_dict


class ApplyRandomBinaryOperatorTransform2(AbstractTransform):
    """
    Applies a randomly chosen binary morphological operation to selected channels of ``data_dict[key]``.

    The array is indexed as ``data[b, c]`` and modified in place. Voxels that an operation adds to a channel are
    set to 0 in all other channels of ``channel_idx``.
    """

    def __init__(self, channel_idx, p_per_sample=0.3, p_per_label=0.3, any_of_these=(binary_dilation, binary_closing),
                 key="data", strel_size=(1, 10)):
        """
        2019_11_22: I have no idea what the purpose of this was...

        The per-sample processing matches ApplyRandomBinaryOperatorTransform, but here we should use only
        expanding operations (see the default of any_of_these). Expansions will replace other labels

        :param channel_idx: can be int, list or tuple. Channels of ``data_dict[key]`` to work on
        :param p_per_sample: probability with which a sample (index along the first axis) is processed at all
        :param p_per_label: probability with which each channel of a selected sample is processed
        :param any_of_these: operations to choose from. Each is called as ``operation(image, structuring_element)``
        :param key: key of the array in the data dict that is modified
        :param strel_size: (low, high) range from which the radius of the ball structuring element is drawn
        uniformly
        """
        self.strel_size = strel_size
        self.key = key
        self.any_of_these = any_of_these
        self.p_per_sample = p_per_sample
        self.p_per_label = p_per_label

        # __call__ shuffles the channel order, which needs a mutable sequence. Tuples are therefore converted
        # instead of being rejected by an assert (which is stripped when running with -O)
        if isinstance(channel_idx, tuple):
            channel_idx = list(channel_idx)
        elif not isinstance(channel_idx, list):
            channel_idx = [channel_idx]
        self.channel_idx = channel_idx

    def __call__(self, **data_dict):
        """
        Applies the transform to ``data_dict[self.key]`` and returns the (updated) data dict.
        """
        data = data_dict.get(self.key)
        for b in range(data.shape[0]):
            if np.random.uniform() < self.p_per_sample:
                # work on a shuffled copy so that the configured channel order is left untouched
                ch = list(self.channel_idx)
                np.random.shuffle(ch)
                for c in ch:
                    if np.random.uniform() < self.p_per_label:
                        operation = np.random.choice(self.any_of_these)
                        # NOTE(review): ball() creates a 3D structuring element, so data[b, c] is presumably
                        # expected to be a 3D volume - confirm
                        selem = ball(np.random.uniform(*self.strel_size))
                        # astype already returns a fresh array, so data[b, c] is not aliased
                        workon = data[b, c].astype(int)
                        res = operation(workon, selem).astype(workon.dtype)
                        data[b, c] = res

                        # if class was added, we need to remove it in ALL other channels to keep one hot encoding
                        # properties
                        # we modify data
                        other_ch = [i for i in ch if i != c]
                        if len(other_ch) > 0:
                            was_added_mask = (res - workon) > 0
                            for oc in other_ch:
                                data[b, oc][was_added_mask] = 0
                            # if class was removed, leave it at background
        data_dict[self.key] = data
        return data_dict
