# Copyright (c) 2021, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import glob
import json
import sys

import skimage.io

import numpy as np
import torch
from torchvision import datasets
import random
import nibabel as nib

from monai.networks import one_hot
from monai.utils.type_conversion import convert_to_tensor
from monai.transforms import LoadImage, EnsureChannelFirst, Orientation

# Debug switch read by SplitNNMedicalDataset: when True, `_split_data` keeps only the
# first 500 shuffled files and `__getitem__` prints shapes and writes each fetched
# sample as NIfTI files to ./DEBUG.
DEBUG = False


class SplitNNMedicalDataset(object):  # TODO: use torch.utils.data.Dataset with batch sampling
    """In-memory medical segmentation dataset that can be queried by lists of indices.

    On construction, all image files below `img_root` are discovered, shuffled and
    split into training/validation/testing parts (70%/10%/20%). The images of the
    selected part and their labels (files with the same basename in `label_root`)
    are loaded into memory. `get_batch` then stacks arbitrary index lists into
    mini-batches, which is what SplitNN training needs.
    """

    def __init__(self, img_root,
                 label_root,
                 dataset="training",
                 img_transform=None,
                 search_ext=".nii.gz",
                 to_onehot_y: bool = True,
                 num_classes: int = 2,
                 spatial_transforms=None,
                 save_dataset_list=None,
                 seed=0):
        """Medical segmentation dataset with index to extract a mini-batch based on given batch indices
        Useful for SplitNN training

        Note: seeds the global `numpy`, `random` and `torch` random number generators with `seed`.

        Args:
            img_root: image root; searched recursively for files ending in `search_ext`
            label_root: label root; must contain one file per image with the same basename
            dataset: which set to use: "training", "validation" or "testing"
            img_transform: optional callable applied to the image only when an item is fetched
            search_ext: file name suffix of the images to look for
            to_onehot_y: if True, labels are one-hot encoded along dim 0 when loaded
            num_classes: number of classes used for the one-hot encoding
            spatial_transforms: optional callable applied to image and label jointly
                (they are concatenated along dim 0 first) when an item is fetched
            save_dataset_list: optional path of a JSON file to which the file split is written
            seed: seed for the random number generators (controls the split)
        Raises:
            ValueError: if `dataset` is not a string or not one of the supported names
            AssertionError: if no images are found or an image has no unique matching label
        """
        # Validate before any expensive I/O. (An `assert (value, str), msg` style check
        # would test a non-empty tuple and therefore could never fail.)
        if not isinstance(dataset, str):
            raise ValueError("Choose a string to select the dataset")

        np.random.seed(seed)
        random.seed(seed)
        torch.manual_seed(seed)

        self.img_root = img_root
        self.label_root = label_root
        self.dataset = dataset
        self.img_transform = img_transform
        self.search_ext = search_ext
        self.to_onehot_y = to_onehot_y
        self.num_classes = num_classes
        self.spatial_transforms = spatial_transforms
        self.save_dataset_list = save_dataset_list

        self.train_files = None
        self.valid_files = None
        self.test_files = None

        if ".nii" in self.search_ext:
            self.img_loader = LoadImage()  # TODO: use in next monai release (ensure_channel_first=True)
            self.channel_first = EnsureChannelFirst()
            self.orienter = Orientation(axcodes="RAS")
        else:
            self.img_loader = None
            self.channel_first = None
            self.orienter = None

        self._split_data()
        self.data, self.target = self._load_data()

    @staticmethod
    def _split_index(fraction, n_files):
        """Return floor(fraction * n_files), tolerant of floating-point representation error.

        E.g. (0.7 + 0.1) * 10 evaluates to 7.999999999999999; plain `int()` would
        truncate that to 7 instead of the intended 8.
        """
        return int(round(fraction * n_files, 8))

    def _split_data(self):
        """Find all images, shuffle them and split them into train/validation/test file lists.

        Sets `self.train_files`, `self.valid_files` and `self.test_files` and, if
        `self.save_dataset_list` is set, writes the three lists to that JSON file.

        Raises:
            AssertionError: if no file matching `self.search_ext` is found below `self.img_root`
        """
        pattern = os.path.join(self.img_root, "**", "*" + self.search_ext)
        # glob returns files in arbitrary (filesystem dependent) order; sort first so that
        # the seeded shuffle below yields the same split on every machine.
        files = sorted(glob.glob(pattern, recursive=True))
        assert len(files) > 0, f"No images found in {self.img_root}"
        random.shuffle(files)

        # DEBUG
        if DEBUG:
            files = files[0:500]

        # TODO: use fixed datalist?
        splits = [0.7, 0.1, 0.2]  # train, val, test
        idx1 = self._split_index(splits[0], len(files))
        idx2 = self._split_index(splits[0] + splits[1], len(files))

        self.train_files = files[0:idx1]
        self.valid_files = files[idx1:idx2]
        self.test_files = files[idx2:]

        # sanity check: the three parts must not share any file
        assert len(np.intersect1d(self.train_files, self.valid_files)) == 0
        assert len(np.intersect1d(self.train_files, self.test_files)) == 0
        assert len(np.intersect1d(self.valid_files, self.test_files)) == 0

        if self.save_dataset_list is not None:
            dataset_list = {"training": self.train_files,
                            "validation": self.valid_files,
                            "testing": self.test_files}
            # a bare file name has an empty dirname, for which os.makedirs would raise
            list_dir = os.path.dirname(self.save_dataset_list)
            if list_dir:
                os.makedirs(list_dir, exist_ok=True)
            with open(self.save_dataset_list, "w") as f:
                json.dump(dataset_list, f, indent=4)
            print("Saved data list at", self.save_dataset_list)

        print(f"split image into {len(self.train_files)}/{len(self.valid_files)}/{len(self.test_files)} "
              f"parts using ratios {splits}.")

    def _load_image(self, file):
        """Load a single image file.

        PNG files are read with `skimage.io.imread`; only the first channel is kept if the
        array has more than two dimensions, and values are scaled by 1/255 if the maximum
        exceeds 1.0. NIfTI files are read with `self.img_loader`, made channel-first and
        have their last axis squeezed away.

        Args:
            file: path of the image; its type is detected from ".png" / ".nii" in the path
        Returns:
            The loaded image array.
        Raises:
            ValueError: if the path contains neither ".png" nor ".nii"
        """
        if ".png" in file:
            img = skimage.io.imread(file)
            if len(np.shape(img)) > 2:
                img = img[:, :, 0]  # use only first channel if rgb
            img = np.asarray(img, dtype=np.float32)
            if np.max(img) > 1.0:
                img = img / 255.0
        elif ".nii.gz" in file or ".nii" in file:
            # NOTE(review): relies on LoadImage returning an (image, meta) pair - confirm
            # against the installed MONAI version.
            img, meta = self.img_loader(file)
            img = self.channel_first(img, meta)
            # img = self.orienter(img, affine=meta["affine"])  # not used for 2D...
            # remove last dim if 2D (presumably 2D slices stored with a trailing singleton axis)
            img = np.squeeze(img, axis=-1)
        else:
            raise ValueError(f"Loading {file} not supported")
        return img

    def _load_data(self):
        """Load all images and labels of the selected split into memory.

        Returns:
            Tuple `(data, target)` of two equally long lists of float32 tensors.
        Raises:
            ValueError: if `self.dataset` is not "training", "validation" or "testing"
            AssertionError: if an image does not have exactly one matching label file
        """
        data, target = [], []

        if self.dataset == "training":
            files = self.train_files
        elif self.dataset == "validation":
            files = self.valid_files
        elif self.dataset == "testing":
            files = self.test_files
        else:
            raise ValueError(f"No such dataset supported {self.dataset}")

        for i, img_file in enumerate(files):
            if i % 1000 == 0:  # progress message every 1000 files
                print(f"{self.dataset} loading {i+1} of {len(files)}: {img_file}")

            # find target: the label has the same file name as the image
            tar_file = glob.glob(os.path.join(self.label_root, os.path.basename(img_file)))
            assert len(tar_file) == 1, f"No matching label found for {os.path.basename(img_file)} in {self.label_root}"
            tar_file = tar_file[0]

            img = convert_to_tensor(self._load_image(img_file), dtype=torch.float32)
            tar = convert_to_tensor(self._load_image(tar_file), dtype=torch.float32)

            if self.to_onehot_y:
                tar = one_hot(tar, dim=0, num_classes=self.num_classes)

            data.append(img)
            target.append(tar)
        return data, target

    def __getitem__(self, index, img_channel=None):
        """ Get batch item

        Args:
            index: index of data entry
            img_channel: int for channel or list of channels to return. If None, return all channels.
        Returns:
            Tuple of img and target for the given index.
        Raises:
            ValueError: if `img_channel` is neither None, a list nor an int
        """

        img, target = self.data[index], self.target[index]
        if self.img_transform is not None:
            img = self.img_transform(img)
        if self.spatial_transforms is not None:
            # assumes spatial transforms are applied to both image and label:
            # stack them along dim 0 so that both receive the identical transform,
            # then split them again at the image's channel count
            nch_img = img.shape[0]
            combined = torch.cat([img, target], dim=0)  # TODO: this is a  real hack
            combined = self.spatial_transforms(combined)
            img = combined[0:nch_img, ...]
            target = combined[nch_img:, ...]

        if img_channel is not None:
            if isinstance(img_channel, list):
                img = img[img_channel, ...]
            elif isinstance(img_channel, int):
                img = img[[img_channel], ...]  # index with a list to keep the channel dim
            else:
                raise ValueError(f"`img_channel` should be `list` or `int` but was {type(img_channel)}")

        if DEBUG:
            os.makedirs("./DEBUG", exist_ok=True)
            print("@@@ img", np.shape(img))
            print("@@@ target", np.shape(target))
            nib.save(nib.Nifti1Image(img.numpy(), np.eye(4)), f"./DEBUG/img_{index}.nii.gz")
            nib.save(nib.Nifti1Image(target.numpy(), np.eye(4)), f"./DEBUG/target_{index}.nii.gz")

        return img, target

    # TODO: this can probably made more efficient using batch_sampler
    def get_batch(self, batch_indices, returns="all", img_channel=None):
        """Stack the items at `batch_indices` into a mini-batch.

        Args:
            batch_indices: non-empty sequence of data indices
            returns: "all" for (images, labels), "image" for images only, "label" for labels only
            img_channel: int for channel or list of channels to return. If None, return all channels.
        Returns:
            The stacked image and/or label tensors with the batch in dim 0.
        Raises:
            AssertionError: if `batch_indices` is empty
            ValueError: if `returns` is not one of the supported values
        """
        assert len(batch_indices) > 0, "empty batch indices provided!"
        img_batch = []
        target_batch = []
        for idx in batch_indices:
            img, target = self.__getitem__(index=idx, img_channel=img_channel)
            img_batch.append(img)
            target_batch.append(target)
        img_batch = torch.stack(img_batch, dim=0)
        target_batch = torch.stack(target_batch, dim=0)

        if returns == "all":
            return img_batch, target_batch
        elif returns == "image":
            return img_batch
        elif returns == "label":
            return target_batch
        else:
            raise ValueError(f"Expected `returns` to be 'all', 'image', or 'label', but got '{returns}'")

    def __len__(self):
        """Return the number of loaded samples of the selected split."""
        return len(self.data)

    def get_data(self):
        """Return the lists of all loaded images and labels as `(data, target)`."""
        return self.data, self.target
