import argparse
import multiprocessing
import shutil
from multiprocessing import Pool
from typing import Union, Tuple, List, Callable

import numpy as np
from acvl_utils.morphology.morphology_helper import remove_all_but_largest_component
from batchgenerators.utilities.file_and_folder_operations import load_json, subfiles, maybe_mkdir_p, join, isfile, \
    isdir, save_pickle, load_pickle, save_json
from nnunetv2.configuration import default_num_processes
from nnunetv2.evaluation.accumulate_cv_results import accumulate_cv_results
from nnunetv2.evaluation.evaluate_predictions import region_or_label_to_mask, compute_metrics_on_folder, \
    load_summary_json, label_or_region_to_key
from nnunetv2.imageio.base_reader_writer import BaseReaderWriter
from nnunetv2.paths import nnUNet_raw
from nnunetv2.utilities.file_path_utilities import folds_tuple_to_string
from nnunetv2.utilities.json_export import recursive_fix_for_json_export
from nnunetv2.utilities.plans_handling.plans_handler import PlansManager


def remove_all_but_largest_component_from_segmentation(segmentation: np.ndarray,
                                                       labels_or_regions: Union[int, Tuple[int, ...],
                                                                                List[Union[int, Tuple[int, ...]]]],
                                                       background_label: int = 0) -> np.ndarray:
    """
    Merges the given labels/regions into one binary mask, keeps only the largest connected component of that mask and
    sets all other elements of the mask to background_label.

    labels_or_regions may be a single label (int), a single region (tuple of labels) or a list of those. The input
    array is left untouched; a modified copy is returned.
    """
    # a bare int or a tuple (= one region) is treated as a single target
    targets = labels_or_regions if isinstance(labels_or_regions, list) else [labels_or_regions]

    # union of all requested labels/regions
    combined = np.zeros_like(segmentation, dtype=bool)
    for target in targets:
        combined |= region_or_label_to_mask(segmentation, target)

    largest_only = remove_all_but_largest_component(combined)
    # elements that belong to the requested labels/regions but not to the largest component
    discard = combined & ~largest_only

    result = np.copy(segmentation)
    result[discard] = background_label
    return result


def apply_postprocessing(segmentation: np.ndarray, pp_fns: List[Callable], pp_fn_kwargs: List[dict]):
    """
    Chains the postprocessing functions: each fn in pp_fns is called on the output of the previous one, with the
    matching dict of pp_fn_kwargs unpacked as keyword arguments. Pairing follows zip semantics (stops at the shorter
    list). Returns the result of the last call, or segmentation itself if there is nothing to apply.
    """
    current = segmentation
    for postprocessing_step in zip(pp_fns, pp_fn_kwargs):
        step_fn, step_kwargs = postprocessing_step
        current = step_fn(current, **step_kwargs)
    return current


def load_postprocess_save(segmentation_file: str,
                          output_fname: str,
                          image_reader_writer: BaseReaderWriter,
                          pp_fns: List[Callable],
                          pp_fn_kwargs: List[dict]):
    """
    Reads segmentation_file with image_reader_writer, applies pp_fns (with pp_fn_kwargs) to the first entry of the
    loaded array (index 0 along the leading axis) and writes the result to output_fname, reusing the properties
    returned by read_seg.

    Kept at module level so it can be dispatched to worker processes via Pool.starmap.
    """
    loaded, properties = image_reader_writer.read_seg(segmentation_file)
    processed = apply_postprocessing(loaded[0], pp_fns, pp_fn_kwargs)
    image_reader_writer.write_seg(processed, output_fname, properties)


def _postprocess_and_evaluate(pool,
                              input_folder: str,
                              output_folder: str,
                              file_names: List[str],
                              reader_writer: 'BaseReaderWriter',
                              pp_fn: Callable,
                              pp_kwargs: dict,
                              folder_ref: str,
                              file_ending: str,
                              labels_or_regions: Union[List[int], List[Tuple[int, ...]]],
                              ignore_label: Union[int, None],
                              num_processes: int) -> None:
    """
    Applies a single postprocessing function to every file in input_folder, writes the results to output_folder and
    evaluates them against folder_ref. The metrics are written to output_folder/summary.json.

    pool is used for the per-file postprocessing. compute_metrics_on_folder receives num_processes separately.
    """
    maybe_mkdir_p(output_folder)
    num_files = len(file_names)
    pool.starmap(
        load_postprocess_save,
        zip(
            [join(input_folder, i) for i in file_names],
            [join(output_folder, i) for i in file_names],
            [reader_writer] * num_files,
            [[pp_fn]] * num_files,
            [[pp_kwargs]] * num_files
        )
    )
    compute_metrics_on_folder(folder_ref,
                              output_folder,
                              join(output_folder, 'summary.json'),
                              reader_writer,
                              file_ending,
                              labels_or_regions,
                              ignore_label,
                              num_processes)


def determine_postprocessing(folder_predictions: str,
                             folder_ref: str,
                             plans_file_or_dict: Union[str, dict] = None,
                             dataset_json_file_or_dict: Union[str, dict] = None,
                             num_processes: int = default_num_processes,
                             keep_postprocessed_files: bool = True):
    """
    Determines nnU-Net postprocessing by greedily testing connected-component removal on the predictions in
    folder_predictions (evaluated against the references in folder_ref). Each step is tested on top of the previously
    accepted result:
      1. keep only the largest connected component of the merged foreground (all foreground labels). Accepted if the
         foreground mean Dice improves and no individual label/region gets a lower Dice.
      2. only if there is more than one label/region: for each of them (in the order of labels_or_regions), keep only
         its largest connected component. Accepted if the Dice of that label/region improves.

    Writes to folder_predictions:
      - summary.json: metrics of the raw predictions (only computed if not already present)
      - postprocessing.pkl: (pp_fns, pp_fn_kwargs), which can be used with apply_postprocessing_to_folder
      - postprocessing.json: metrics before/after postprocessing plus the names/kwargs of the chosen functions

    Postprocessed files are saved in folder_predictions/postprocessed. Set keep_postprocessed_files=False to delete
    these files after this function is done (temp files will be created and deleted regardless).

    If plans_file_or_dict or dataset_json_file_or_dict are None, we will look for plans.json / dataset.json in
    folder_predictions.

    Returns:
        (pp_fns, pp_fn_kwargs): the accepted postprocessing functions and their keyword arguments, in order.
    Raises:
        RuntimeError: if plans.json / dataset.json are needed but missing from folder_predictions.
    """
    output_folder = join(folder_predictions, 'postprocessed')
    temp_folder = join(output_folder, 'temp')

    if plans_file_or_dict is None:
        expected_plans_file = join(folder_predictions, 'plans.json')
        if not isfile(expected_plans_file):
            raise RuntimeError(f"Expected plans file missing: {expected_plans_file}. The plans files should have been "
                               f"created while running nnUNetv2_predict. Sadge.")
        plans_file_or_dict = load_json(expected_plans_file)
    plans_manager = PlansManager(plans_file_or_dict)

    if dataset_json_file_or_dict is None:
        expected_dataset_json_file = join(folder_predictions, 'dataset.json')
        if not isfile(expected_dataset_json_file):
            raise RuntimeError(
                f"Expected dataset.json file missing: {expected_dataset_json_file}. The dataset.json should have been "
                f"copied while running nnUNetv2_predict. Sadge.")
        dataset_json_file_or_dict = load_json(expected_dataset_json_file)

    if isinstance(dataset_json_file_or_dict, dict):
        dataset_json = dataset_json_file_or_dict
    else:
        dataset_json = load_json(dataset_json_file_or_dict)

    rw = plans_manager.image_reader_writer_class()
    label_manager = plans_manager.get_label_manager(dataset_json)
    labels_or_regions = label_manager.foreground_regions if label_manager.has_regions else label_manager.foreground_labels
    file_ending = dataset_json['file_ending']
    ignore_label = label_manager.ignore_label

    predicted_files = subfiles(folder_predictions, suffix=file_ending, join=False)
    ref_files = subfiles(folder_ref, suffix=file_ending, join=False)
    # warn if not all files from folder_ref are present in folder_predictions (set for O(1) membership tests)
    predicted_files_set = set(predicted_files)
    if not all(i in predicted_files_set for i in ref_files):
        print('WARNING: Not all files in folder_ref were found in folder_predictions. Determining postprocessing '
              'should always be done on the entire dataset!')

    # before we start we need the metrics of the unprocessed predictions (reused if already present)
    baseline_summary_file = join(folder_predictions, 'summary.json')
    if not isfile(baseline_summary_file):
        compute_metrics_on_folder(folder_ref,
                                  folder_predictions,
                                  baseline_summary_file,
                                  rw,
                                  file_ending,
                                  labels_or_regions,
                                  ignore_label,
                                  num_processes)

    # accepted postprocessing functions and their kwargs, in the order they must be applied
    pp_fns = []
    pp_fn_kwargs = []

    with multiprocessing.get_context("spawn").Pool(num_processes) as pool:
        # step 1: does removing all but the largest foreground region improve the scores?
        output_here = join(temp_folder, 'keep_largest_fg')
        pp_fn = remove_all_but_largest_component_from_segmentation
        kwargs = {
            'labels_or_regions': label_manager.foreground_labels,
        }
        _postprocess_and_evaluate(pool, folder_predictions, output_here, predicted_files, rw, pp_fn, kwargs,
                                  folder_ref, file_ending, labels_or_regions, ignore_label, num_processes)

        # Defensive acceptance: the foreground mean must improve AND no single class may get worse.
        baseline_results = load_summary_json(baseline_summary_file)
        pp_results = load_summary_json(join(output_here, 'summary.json'))
        do_this = pp_results['foreground_mean']['Dice'] > baseline_results['foreground_mean']['Dice'] and \
            not any(pp_results['mean'][class_id]['Dice'] < baseline_results['mean'][class_id]['Dice']
                    for class_id in pp_results['mean'].keys())
        if do_this:
            print(f'Results were improved by removing all but the largest foreground region. '
                  f'Mean dice before: {round(baseline_results["foreground_mean"]["Dice"], 5)} '
                  f'after: {round(pp_results["foreground_mean"]["Dice"], 5)}')
            source = output_here
            pp_fns.append(pp_fn)
            pp_fn_kwargs.append(kwargs)
        else:
            print('Removing all but the largest foreground region did not improve results!')
            source = folder_predictions

        # step 2: per label/region. In the old nnU-Net we could apply all-but-largest component removal to all
        # classes at the same time and then evaluate for each class whether this improved results. This is no longer
        # possible because we now support region-based predictions and regions can overlap, causing interactions.
        # In principle the order in which the postprocessing is applied to the regions matters as well. We stick to
        # the order in which they are declared in dataset.json (see region_class_order).
        if len(labels_or_regions) > 1:
            output_here = join(temp_folder, 'keep_largest_perClassOrRegion')
            current_best = join(temp_folder, 'keep_largest_perClassOrRegion_currentBest')
            for label_or_region in labels_or_regions:
                pp_fn = remove_all_but_largest_component_from_segmentation
                kwargs = {
                    'labels_or_regions': label_or_region,
                }
                # always build on top of the best result found so far
                _postprocess_and_evaluate(pool, source, output_here, predicted_files, rw, pp_fn, kwargs,
                                          folder_ref, file_ending, labels_or_regions, ignore_label, num_processes)
                baseline_results = load_summary_json(join(source, 'summary.json'))
                pp_results = load_summary_json(join(output_here, 'summary.json'))
                dice_before = baseline_results['mean'][label_or_region]['Dice']
                dice_after = pp_results['mean'][label_or_region]['Dice']
                if dice_after > dice_before:
                    print(f'Results were improved by removing all but the largest component for {label_or_region}. '
                          f'Dice before: {round(dice_before, 5)} '
                          f'after: {round(dice_after, 5)}')
                    # promote this result to the new source for the following labels/regions
                    if isdir(current_best):
                        shutil.rmtree(current_best)
                    shutil.move(output_here, current_best)
                    source = current_best
                    pp_fns.append(pp_fn)
                    pp_fn_kwargs.append(kwargs)
                else:
                    print(f'Removing all but the largest component for {label_or_region} did not improve results! '
                          f'Dice before: {round(dice_before, 5)} '
                          f'after: {round(dice_after, 5)}')

    # the best result (including its summary.json) becomes the content of the postprocessed folder
    for file_name in subfiles(source, join=False):
        shutil.copy(join(source, file_name), join(output_folder, file_name))
    save_pickle((pp_fns, pp_fn_kwargs), join(folder_predictions, 'postprocessing.pkl'))

    baseline_results = load_summary_json(baseline_summary_file)
    final_results = load_summary_json(join(output_folder, 'summary.json'))
    # json can't handle tuples (regions) as dict keys, so the per-class means are re-keyed with label_or_region_to_key
    tmp = {
        'input_folder': {
            'foreground_mean': baseline_results['foreground_mean'],
            'mean': {label_or_region_to_key(k): v for k, v in baseline_results['mean'].items()},
        },
        'postprocessed': {
            'foreground_mean': final_results['foreground_mean'],
            'mean': {label_or_region_to_key(k): v for k, v in final_results['mean'].items()},
        },
        'postprocessing_fns': [i.__name__ for i in pp_fns],
        'postprocessing_kwargs': pp_fn_kwargs,
    }
    # fixes values json cannot serialize (e.g. "Object of type int64 is not JSON serializable")
    recursive_fix_for_json_export(tmp)
    save_json(tmp, join(folder_predictions, 'postprocessing.json'))

    shutil.rmtree(temp_folder)

    if not keep_postprocessed_files:
        shutil.rmtree(output_folder)
    return pp_fns, pp_fn_kwargs


def apply_postprocessing_to_folder(input_folder: str,
                                   output_folder: str,
                                   pp_fns: List[Callable],
                                   pp_fn_kwargs: List[dict],
                                   plans_file_or_dict: Union[str, dict] = None,
                                   dataset_json_file_or_dict: Union[str, dict] = None,
                                   num_processes: int = 8) -> None:
    """
    Applies pp_fns (each called with the matching dict of pp_fn_kwargs, in order) to every file in input_folder that
    has the dataset's file ending, and writes the results under the same file names to output_folder.

    pp_fns/pp_fn_kwargs are what determine_postprocessing returns / stores in postprocessing.pkl.

    If plans_file_or_dict or dataset_json_file_or_dict are None, we will look for plans.json / dataset.json in
    input_folder.

    Raises:
        RuntimeError: if plans.json / dataset.json are needed but missing from input_folder.
    """
    if plans_file_or_dict is None:
        expected_plans_file = join(input_folder, 'plans.json')
        if not isfile(expected_plans_file):
            raise RuntimeError(f"Expected plans file missing: {expected_plans_file}. The plans file should have been "
                               f"created while running nnUNetv2_predict. Sadge. If the folder you want to apply "
                               f"postprocessing to was created from an ensemble then just specify one of the "
                               f"plans files of the ensemble members in plans_file_or_dict")
        plans_file_or_dict = load_json(expected_plans_file)
    plans_manager = PlansManager(plans_file_or_dict)

    if dataset_json_file_or_dict is None:
        expected_dataset_json_file = join(input_folder, 'dataset.json')
        if not isfile(expected_dataset_json_file):
            raise RuntimeError(
                f"Expected dataset.json file missing: {expected_dataset_json_file}. The dataset.json should have been "
                f"copied while running nnUNetv2_predict/nnUNetv2_ensemble. Sadge.")
        dataset_json_file_or_dict = load_json(expected_dataset_json_file)

    if isinstance(dataset_json_file_or_dict, dict):
        dataset_json = dataset_json_file_or_dict
    else:
        dataset_json = load_json(dataset_json_file_or_dict)

    rw = plans_manager.image_reader_writer_class()

    maybe_mkdir_p(output_folder)
    # list the files before spawning any workers
    files = subfiles(input_folder, suffix=dataset_json['file_ending'], join=False)
    num_files = len(files)

    with multiprocessing.get_context("spawn").Pool(num_processes) as p:
        p.starmap(load_postprocess_save,
                  zip(
                      [join(input_folder, i) for i in files],
                      [join(output_folder, i) for i in files],
                      [rw] * num_files,
                      [pp_fns] * num_files,
                      [pp_fn_kwargs] * num_files
                  )
                  )


def entry_point_determine_postprocessing_folder():
    """
    Command line entry point for determine_postprocessing: parses the arguments and writes postprocessing.pkl and
    postprocessing.json to the input folder (-i).
    """
    # The description must be passed as a keyword: the first positional argument of ArgumentParser is `prog`
    # (the program name shown in the usage line), not the description.
    parser = argparse.ArgumentParser(description='Writes postprocessing.pkl and postprocessing.json in input_folder.')
    parser.add_argument('-i', type=str, required=True, help='Input folder')
    parser.add_argument('-ref', type=str, required=True, help='Folder with gt labels')
    parser.add_argument('-plans_json', type=str, required=False, default=None,
                        help="plans file to use. If not specified we will look for the plans.json file in the "
                             "input folder (input_folder/plans.json)")
    parser.add_argument('-dataset_json', type=str, required=False, default=None,
                        help="dataset.json file to use. If not specified we will look for the dataset.json file in the "
                             "input folder (input_folder/dataset.json)")
    parser.add_argument('-np', type=int, required=False, default=default_num_processes,
                        help=f"number of processes to use. Default: {default_num_processes}")
    parser.add_argument('--remove_postprocessed', action='store_true', required=False,
                        help="set this if you don't want to keep the postprocessed files")

    args = parser.parse_args()
    determine_postprocessing(args.i, args.ref,
                             plans_file_or_dict=args.plans_json,
                             dataset_json_file_or_dict=args.dataset_json,
                             num_processes=args.np,
                             keep_postprocessed_files=not args.remove_postprocessed)


def entry_point_apply_postprocessing():
    """
    Command line entry point for apply_postprocessing_to_folder: loads (pp_fns, pp_fn_kwargs) from -pp_pkl_file and
    applies them to all segmentations in -i, writing the results to -o.
    """
    # The description must be passed as a keyword: the first positional argument of ArgumentParser is `prog`
    # (the program name shown in the usage line), not the description.
    parser = argparse.ArgumentParser(description='Applies postprocessing specified in pp_pkl_file to input folder.')
    parser.add_argument('-i', type=str, required=True, help='Input folder')
    parser.add_argument('-o', type=str, required=True, help='Output folder')
    parser.add_argument('-pp_pkl_file', type=str, required=True, help='postprocessing.pkl file')
    parser.add_argument('-np', type=int, required=False, default=default_num_processes,
                        help=f"number of processes to use. Default: {default_num_processes}")
    parser.add_argument('-plans_json', type=str, required=False, default=None,
                        help="plans file to use. If not specified we will look for the plans.json file in the "
                             "input folder (input_folder/plans.json)")
    parser.add_argument('-dataset_json', type=str, required=False, default=None,
                        help="dataset.json file to use. If not specified we will look for the dataset.json file in the "
                             "input folder (input_folder/dataset.json)")
    args = parser.parse_args()
    # SECURITY NOTE: postprocessing.pkl contains pickled Python callables. Unpickling can execute arbitrary code,
    # so only pass .pkl files that come from a trusted source (e.g. produced by determine_postprocessing).
    pp_fns, pp_fn_kwargs = load_pickle(args.pp_pkl_file)
    apply_postprocessing_to_folder(args.i, args.o, pp_fns, pp_fn_kwargs,
                                   plans_file_or_dict=args.plans_json,
                                   dataset_json_file_or_dict=args.dataset_json,
                                   num_processes=args.np)


if __name__ == '__main__':
    # Developer debug script with hard-coded paths (Dataset004_Hippocampus, 3d_fullres, 5-fold cross-validation):
    # merge the CV predictions, determine postprocessing on them, apply it and evaluate the result.
    folds = (0, 1, 2, 3, 4)
    n_proc = 8
    trained_model_folder = '/home/fabian/results/nnUNet_remake/Dataset004_Hippocampus/nnUNetTrainer__nnUNetPlans__3d_fullres'
    labelstr = join(nnUNet_raw, 'Dataset004_Hippocampus', 'labelsTr')

    plans_manager = PlansManager(join(trained_model_folder, 'plans.json'))
    dataset_json = load_json(join(trained_model_folder, 'dataset.json'))
    label_manager = plans_manager.get_label_manager(dataset_json)

    # gather the validation predictions of all folds into a single folder
    merged_output_folder = join(trained_model_folder, f'crossval_results_folds_{folds_tuple_to_string(folds)}')
    accumulate_cv_results(trained_model_folder, merged_output_folder, folds, n_proc, False)

    # determine postprocessing, store it next to the model and round-trip it through the pickle file
    pp_pkl_file = join(trained_model_folder, 'postprocessing.pkl')
    fns, kwargs = determine_postprocessing(merged_output_folder, labelstr, plans_manager.plans,
                                           dataset_json, n_proc, keep_postprocessed_files=True)
    save_pickle((fns, kwargs), pp_pkl_file)
    fns, kwargs = load_pickle(pp_pkl_file)

    # apply it to the merged predictions and evaluate the postprocessed result
    pp_output_folder = merged_output_folder + '_pp'
    apply_postprocessing_to_folder(merged_output_folder, pp_output_folder, fns, kwargs,
                                   plans_manager.plans, dataset_json, n_proc)
    evaluation_targets = label_manager.foreground_regions if label_manager.has_regions \
        else label_manager.foreground_labels
    compute_metrics_on_folder(labelstr,
                              pp_output_folder,
                              join(pp_output_folder, 'summary.json'),
                              plans_manager.image_reader_writer_class(),
                              dataset_json['file_ending'],
                              evaluation_targets,
                              label_manager.ignore_label,
                              n_proc)
