#    Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
#    Licensed under the Apache License, Version 2.0 (the "License");
#    you may not use this file except in compliance with the License.
#    You may obtain a copy of the License at
#
#        http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS,
#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#    See the License for the specific language governing permissions and
#    limitations under the License.
import multiprocessing
from multiprocessing.pool import Pool
from typing import Tuple, Union

import numpy as np
import pandas as pd
from batchgenerators.utilities.file_and_folder_operations import *
from nnunetv2.configuration import default_num_processes
from nnunetv2.imageio.base_reader_writer import BaseReaderWriter
from nnunetv2.imageio.reader_writer_registry import determine_reader_writer_from_dataset_json
from nnunetv2.paths import nnUNet_raw, nnUNet_preprocessed
from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name
from nnunetv2.utilities.utils import get_identifiers_from_splitted_dataset_folder, \
    get_filenames_of_train_images_and_targets

# Default overlay colours as 6-digit hex RGB strings (no leading '#'). generate_overlay indexes this tuple with the
# colour index a label is mapped to; with the default mapping the lowest label present (typically background) gets
# index 0, and "000000" (black) adds nothing to the image.
color_cycle = (
    "000000", "4363d8", "f58231", "3cb44b", "e6194B",
    "911eb4", "ffe119", "bfef45", "42d4f4", "f032e6",
    "000075", "9A6324", "808000", "800000", "469990",
)


def hex_to_rgb(hex: str):
    """
    Convert a 6-character hex colour string (e.g. 'f58231', no leading '#') into an (r, g, b) tuple of ints.

    Each component is parsed from two hex digits, so valid input yields values in [0, 255].
    Raises AssertionError if the string is not exactly 6 characters long.
    """
    assert len(hex) == 6
    red, green, blue = (int(hex[start:start + 2], 16) for start in range(0, 6, 2))
    return red, green, blue


def generate_overlay(input_image: np.ndarray, segmentation: np.ndarray, mapping: dict = None,
                     color_cycle: Tuple[str, ...] = color_cycle,
                     overlay_intensity: float = 0.6):
    """
    Blend a colour-coded label map onto a 2D image.

    image can be 2d greyscale (x, y), 2d greyscale with a singleton channel (x, y, 1) or 2d RGB (x, y, 3)
    (color channel in last dimension!). The input array is not modified.

    Segmentation must be label map of same shape as image (w/o color channels)

    mapping can be label_id -> idx_in_cycle or None. If None, every label present in segmentation (including 0)
    is mapped to its rank among the sorted unique labels, so the lowest label gets color_cycle[0].
    Colour indices beyond the end of color_cycle wrap around over color_cycle[1:] (index 0 is reserved for the
    first entry, which is black in the default cycle and therefore invisible).

    The image is first rescaled to [0, 255]; each labelled pixel then gets overlay_intensity * its RGB colour
    added, and the result is rescaled to [0, 255] again.

    returned image is scaled to [0, 255] (uint8)!!!
    """
    # Work on a float64 copy: leaves the input untouched, avoids integer overflow in the min-subtraction below
    # (e.g. int8/int16 images spanning their full range) and allows float accumulation of the overlay.
    image = np.array(input_image, dtype=np.float64)

    if image.ndim == 2:
        image = np.tile(image[:, :, None], (1, 1, 3))
    elif image.ndim == 3:
        if image.shape[2] == 1:
            image = np.tile(image, (1, 1, 3))
        elif image.shape[2] != 3:
            # (x, y, 3) is already RGB and is used as is; any other channel count is unsupported
            raise RuntimeError(f'if 3d image is given the last dimension must be the color channels (3 channels). '
                               f'Only 2D images are supported. Your image shape: {image.shape}')
    else:
        raise RuntimeError("unexpected image shape. only 2D images and 2D images with color channels (color in "
                           "last dimension) are supported")

    # rescale image to [0, 255]. A constant image would give 0 / 0 (NaN), so it is left at 0 instead.
    image = image - image.min()
    max_value = image.max()
    if max_value > 0:
        image = image / max_value * 255

    # create output
    if mapping is None:
        uniques = np.sort(pd.unique(segmentation.ravel()))  # np.unique(segmentation)
        mapping = {i: c for c, i in enumerate(uniques)}

    num_colors = len(color_cycle)
    for label, color_idx in mapping.items():
        if num_colors > 1 and color_idx >= num_colors:
            # more labels than colours: cycle through color_cycle[1:] instead of raising IndexError
            color_idx = 1 + (color_idx - 1) % (num_colors - 1)
        color = np.array(hex_to_rgb(color_cycle[color_idx]))
        image[segmentation == label] += overlay_intensity * color

    # rescale result to [0, 255]; an all-zero result (e.g. black image, no visible labels) stays 0
    max_value = image.max()
    if max_value > 0:
        image = image / max_value * 255
    return image.astype(np.uint8)


def select_slice_to_plot(image: np.ndarray, segmentation: np.ndarray) -> int:
    """
    Return the index along axis 0 of the slice with the most foreground voxels (any non-zero label counts).

    image and segmentation are expected to be 3D. image is not used here; it is part of the signature so this
    function can easily be replaced by a selector that also looks at intensities. Ties go to the lowest index.
    """
    voxels_per_slice = np.count_nonzero(segmentation, axis=(1, 2))
    return int(voxels_per_slice.argmax())


def select_slice_to_plot2(image: np.ndarray, segmentation: np.ndarray) -> int:
    """
    image and segmentation are expected to be 3D (or 1, x, y)

    Selects the slice (index along axis 0) that best represents all foreground classes. For every non-zero
    label, compute the fraction of that label's voxels lying in each slice, then pick the slice with the highest
    mean fraction across labels. Every label therefore contributes equally, regardless of its size.
    Returns 0 if the segmentation contains no foreground.

    we give image so that we can easily replace this function if needed
    """
    classes = [c for c in np.sort(pd.unique(segmentation.ravel())) if c != 0]
    if not classes:
        # no foreground: nothing to prefer, use the first slice (and avoid a mean over an empty axis)
        return 0

    fraction_per_slice = np.zeros((segmentation.shape[0], len(classes)))
    for i, c in enumerate(classes):
        voxels_per_slice = (segmentation == c).sum((1, 2))
        # normalise by this class's own total (not by the whole matrix) so each column sums to 1
        fraction_per_slice[:, i] = voxels_per_slice / voxels_per_slice.sum()
    return int(np.argmax(fraction_per_slice.mean(1)))


def plot_overlay(image_file: str, segmentation_file: str, image_reader_writer: BaseReaderWriter, output_file: str,
                 overlay_intensity: float = 0.6):
    """
    Render one overlay png for a raw image / segmentation pair.

    Both files are read with image_reader_writer (read_images / read_seg), and element 0 of each returned array
    is used. The slice is chosen by select_slice_to_plot2, and the overlay of that slice is written to
    output_file with matplotlib.pyplot.imsave.

    Raises AssertionError if image and segmentation shapes differ or the data is not 3D.
    """
    import matplotlib.pyplot as plt

    image = image_reader_writer.read_images((image_file, ))[0][0]
    seg = image_reader_writer.read_seg(segmentation_file)[0][0]

    assert image.shape == seg.shape, "image and seg do not have the same shape: %s, %s" % (
        image_file, segmentation_file)
    assert image.ndim == 3, 'only 3D images/segs are supported'

    slice_idx = select_slice_to_plot2(image, seg)
    rendered = generate_overlay(image[slice_idx], seg[slice_idx], overlay_intensity=overlay_intensity)
    plt.imsave(output_file, rendered)


def plot_overlay_preprocessed(case_file: str, output_file: str, overlay_intensity: float = 0.6, channel_idx=0):
    """
    Render one overlay png for a preprocessed case.

    case_file is an .npz archive holding 'data' (channel axis first) and 'seg' arrays. Channel channel_idx of
    'data' and channel 0 of 'seg' are used, and negative segmentation values are set to 0 before plotting. The
    slice is chosen by select_slice_to_plot2, and the result is written with matplotlib.pyplot.imsave.

    Raises AssertionError if channel_idx >= number of channels in 'data'.
    """
    import matplotlib.pyplot as plt

    # Open the archive once and close it deterministically; indexing an NpzFile reads the array into memory, so
    # data and seg stay valid after the file is closed.
    with np.load(case_file) as npz:
        data = npz['data']
        seg = npz['seg'][0]

    assert channel_idx < (data.shape[0]), 'This dataset only supports channel index up to %d' % (data.shape[0] - 1)

    image = data[channel_idx]
    seg[seg < 0] = 0

    selected_slice = select_slice_to_plot2(image, seg)

    overlay = generate_overlay(image[selected_slice], seg[selected_slice], overlay_intensity=overlay_intensity)

    plt.imsave(output_file, overlay)


def multiprocessing_plot_overlay(list_of_image_files, list_of_seg_files, image_reader_writer,
                                 list_of_output_files, overlay_intensity,
                                 num_processes=8):
    """
    Run plot_overlay for every (image, seg, output) triple in a pool of num_processes 'spawn' workers.

    The lists are zipped, so the shortest one determines how many overlays are produced. image_reader_writer
    and overlay_intensity are shared by all jobs. Blocks until all jobs finish and re-raises worker exceptions.
    """
    n_jobs = len(list_of_output_files)
    shared_reader_writer = [image_reader_writer] * n_jobs
    shared_intensity = [overlay_intensity] * n_jobs
    job_args = zip(list_of_image_files, list_of_seg_files, shared_reader_writer, list_of_output_files,
                   shared_intensity)

    spawn_ctx = multiprocessing.get_context("spawn")
    with spawn_ctx.Pool(num_processes) as pool:
        pool.starmap(plot_overlay, job_args)


def multiprocessing_plot_overlay_preprocessed(list_of_case_files, list_of_output_files, overlay_intensity,
                                              num_processes=8, channel_idx=0):
    """
    Run plot_overlay_preprocessed for every (case file, output file) pair in a pool of num_processes 'spawn'
    workers.

    overlay_intensity and channel_idx are shared by all jobs. Blocks until all jobs finish and re-raises worker
    exceptions.
    """
    n_jobs = len(list_of_output_files)
    job_args = zip(list_of_case_files, list_of_output_files, [overlay_intensity] * n_jobs,
                   [channel_idx] * n_jobs)

    spawn_ctx = multiprocessing.get_context("spawn")
    with spawn_ctx.Pool(num_processes) as pool:
        pool.starmap(plot_overlay_preprocessed, job_args)


def generate_overlays_from_raw(dataset_name_or_id: Union[int, str], output_folder: str,
                               num_processes: int = 8, channel_idx: int = 0, overlay_intensity: float = 0.6):
    """
    Write one overlay png per case of a raw dataset (located in nnUNet_raw) to output_folder.

    Cases come from get_filenames_of_train_images_and_targets. channel_idx selects which entry of each case's
    'images' list is plotted, and outputs are named <case identifier>.png. The reader/writer is chosen from
    dataset.json and the first image file.

    Raises AssertionError if any image or label file does not exist.
    """
    dataset_name = maybe_convert_to_dataset_name(dataset_name_or_id)
    raw_folder = join(nnUNet_raw, dataset_name)
    dataset_json = load_json(join(raw_folder, 'dataset.json'))
    cases = get_filenames_of_train_images_and_targets(raw_folder, dataset_json)

    image_files = [case['images'][channel_idx] for case in cases.values()]
    seg_files = [case['label'] for case in cases.values()]

    assert all(isfile(f) for f in image_files)
    assert all(isfile(f) for f in seg_files)

    maybe_mkdir_p(output_folder)
    output_files = [join(output_folder, case_id + '.png') for case_id in cases.keys()]

    reader_writer_class = determine_reader_writer_from_dataset_json(dataset_json, image_files[0])
    multiprocessing_plot_overlay(image_files, seg_files, reader_writer_class(), output_files, overlay_intensity,
                                 num_processes)


def _get_data_identifier(configurations: dict, configuration: str) -> str:
    """Return 'data_identifier' of a plans configuration, following 'inherits_from' links if it is not set."""
    visited = set()
    current = configuration
    while True:
        config = configurations[current]
        # stop when the key is present, there is no parent, or we would loop forever on a cyclic chain
        if 'data_identifier' in config or 'inherits_from' not in config or current in visited:
            return config['data_identifier']
        visited.add(current)
        current = config['inherits_from']


def generate_overlays_from_preprocessed(dataset_name_or_id: Union[int, str], output_folder: str,
                                        num_processes: int = 8, channel_idx: int = 0,
                                        configuration: str = None,
                                        plans_identifier: str = 'nnUNetPlans',
                                        overlay_intensity: float = 0.6):
    """
    Write one overlay png per preprocessed case (.npz) of a dataset to output_folder (<identifier>.png).

    The plans file <plans_identifier>.json in nnUNet_preprocessed/<dataset> determines the data folder via the
    configuration's data_identifier. configuration=None selects '3d_fullres' if the plans define it, else '2d'.

    Raises RuntimeError if the dataset has not been preprocessed, the configuration is not defined in the plans,
    or its preprocessed data folder does not exist.
    """
    dataset_name = maybe_convert_to_dataset_name(dataset_name_or_id)
    folder = join(nnUNet_preprocessed, dataset_name)
    if not isdir(folder):
        raise RuntimeError("run preprocessing for that task first")

    plans = load_json(join(folder, plans_identifier + '.json'))
    configurations = plans['configurations']
    if configuration is None:
        configuration = '3d_fullres' if '3d_fullres' in configurations else '2d'
    if configuration not in configurations:
        raise RuntimeError(f"Configuration {configuration} is not defined in plans {plans_identifier} "
                           f"({dataset_name}). Available configurations: {list(configurations.keys())}")
    data_identifier = _get_data_identifier(configurations, configuration)
    preprocessed_folder = join(folder, data_identifier)

    if not isdir(preprocessed_folder):
        raise RuntimeError(f"Preprocessed data folder for configuration {configuration} of plans identifier "
                           f"{plans_identifier} ({dataset_name}) does not exist. Run preprocessing for this "
                           f"configuration first!")

    # strip the '.npz' suffix to get case identifiers
    identifiers = [i[:-4] for i in subfiles(preprocessed_folder, suffix='.npz', join=False)]

    output_files = [join(output_folder, i + '.png') for i in identifiers]
    image_files = [join(preprocessed_folder, i + ".npz") for i in identifiers]

    maybe_mkdir_p(output_folder)
    multiprocessing_plot_overlay_preprocessed(image_files, output_files, overlay_intensity=overlay_intensity,
                                              num_processes=num_processes, channel_idx=channel_idx)


def entry_point_generate_overlay():
    """
    Command line entry point: parse arguments and generate overlays from raw (--use_raw) or preprocessed data.
    """
    import argparse
    # The text must be passed as `description`; passed positionally it would become `prog` and replace the
    # program name in the usage line.
    parser = argparse.ArgumentParser(description="Plots png overlays of the slice with the most foreground. Note "
                                                 "that this disregards spacing information!")
    parser.add_argument('-d', type=str, help="Dataset name or id", required=True)
    parser.add_argument('-o', type=str, help="output folder", required=True)
    parser.add_argument('-np', type=int, default=default_num_processes, required=False,
                        help=f"number of processes used. Default: {default_num_processes}")
    parser.add_argument('-channel_idx', type=int, default=0, required=False,
                        help="channel index used (0 = _0000). Default: 0")
    parser.add_argument('--use_raw', action='store_true', required=False, help="if set then we use raw data. else "
                                                                               "we use preprocessed")
    parser.add_argument('-p', type=str, required=False, default='nnUNetPlans',
                        help='plans identifier. Only used if --use_raw is not set! Default: nnUNetPlans')
    parser.add_argument('-c', type=str, required=False, default=None,
                        help='configuration name. Only used if --use_raw is not set! Default: None = '
                             '3d_fullres if available, else 2d')
    parser.add_argument('-overlay_intensity', type=float, required=False, default=0.6,
                        help='overlay intensity. Higher = brighter/less transparent')

    args = parser.parse_args()

    if args.use_raw:
        generate_overlays_from_raw(args.d, args.o, num_processes=args.np, channel_idx=args.channel_idx,
                                   overlay_intensity=args.overlay_intensity)
    else:
        generate_overlays_from_preprocessed(args.d, args.o, num_processes=args.np, channel_idx=args.channel_idx,
                                            configuration=args.c, plans_identifier=args.p,
                                            overlay_intensity=args.overlay_intensity)


if __name__ == "__main__":
    # allow running this module directly as a command line script
    entry_point_generate_overlay()