# Copyright (c) 2021, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import glob

import skimage.io

import numpy as np
import torch
from torchvision import datasets
import random


class SplitNNCXRDataset(object):  # TODO: use torch.utils.data.Dataset with batch sampling
    """In-memory image classification dataset with mini-batch extraction by index.

    All ``*.png`` files found recursively below ``root`` are shuffled, split
    into training/validation/testing parts (70%/10%/20%) and the images of the
    selected part are loaded eagerly into memory. The label of an image is the
    index of the entry of ``target_names`` that occurs in its file path.

    Besides ``__getitem__``/``__len__`` the class offers :meth:`get_batch` to
    build a mini-batch from explicit sample indices, which is useful for
    SplitNN training.
    """

    # Class names used when the caller does not pass ``target_names``.
    DEFAULT_TARGET_NAMES = ("COVID", "Lung_Opacity", "Normal", "Viral Pneumonia")

    def __init__(self, root, dataset="training", transform=None,
                 target_names=None, seed=0, max_files=500):
        """Split the files below ``root`` and load the selected part.

        Args:
            root: data root; searched recursively for ``*.png`` files.
            dataset: which set to use: "training", "validation" or "testing".
            transform: optional callable applied to each image in
                ``__getitem__``.
            target_names: class names; the label of a file is the index of
                the name contained in its path. Defaults to
                ``DEFAULT_TARGET_NAMES``.
            seed: seed passed to ``random.seed`` before shuffling the files.
                Note that this seeds the global ``random`` generator.
            max_files: keep at most this many files (after shuffling) before
                splitting. ``None`` uses all files. The default of 500 keeps
                the previous hard-coded behavior.
        Raises:
            ValueError: if ``dataset`` is not a string or not one of the
                supported set names.
            AssertionError: if no images are found or a file matches none of
                the target names.
        """
        random.seed(seed)

        self.root = root
        self.dataset = dataset
        self.transform = transform
        # Copy so that neither the shared default nor the caller's list is aliased.
        if target_names is None:
            target_names = self.DEFAULT_TARGET_NAMES
        self.target_names = list(target_names)
        self.max_files = max_files

        self.train_files = None
        self.valid_files = None
        self.test_files = None
        self._split_data()

        # Validate before the (potentially slow) image loading. ValueError is
        # what an unsupported `dataset` value already raises in `_load_data`.
        if not isinstance(self.dataset, str):
            raise ValueError("Choose a string to select the dataset")

        self.data, self.target = self._load_data()

    def _split_data(self):
        """Shuffle the image files and split them into train/valid/test lists.

        Uses the global ``random`` generator (seeded in ``__init__``) and
        stores the result in ``self.train_files``, ``self.valid_files`` and
        ``self.test_files``.
        """
        files = glob.glob(os.path.join(self.root, "**", "*.png"), recursive=True)
        assert len(files) > 0, f"No images found in {self.root}"
        # glob returns files in arbitrary (filesystem dependent) order; sort
        # first so that the seeded shuffle yields the same split everywhere.
        files.sort()
        random.shuffle(files)

        if self.max_files is not None:
            files = files[:self.max_files]

        # TODO: use fixed datalist?
        splits = [0.7, 0.1, 0.2]  # train, val, test
        n_files = len(files)
        # The small tolerance guards against float round-off: 0.7 + 0.1 is
        # 0.7999999999999999, so plain int() truncation would e.g. turn the
        # boundary 16 into 15 for 20 files.
        eps = 1e-9
        idx1 = int(splits[0] * n_files + eps)
        idx2 = int((splits[0] + splits[1]) * n_files + eps)

        self.train_files = files[:idx1]
        self.valid_files = files[idx1:idx2]
        self.test_files = files[idx2:]

        # Sanity check: the three parts must not share any file.
        assert set(self.train_files).isdisjoint(self.valid_files)
        assert set(self.train_files).isdisjoint(self.test_files)
        assert set(self.valid_files).isdisjoint(self.test_files)

        print(f"split image into {len(self.train_files)}/{len(self.valid_files)}/{len(self.test_files)} "
              f"parts using ratios {splits}.")

    def _find_target(self, file):
        """Return the label index for ``file`` based on its path.

        Every entry of ``self.target_names`` is tested as a substring of the
        full path. If several names match, the last one wins (e.g. when a
        parent directory name also contains an earlier class name).

        Raises:
            AssertionError: if none of the target names occurs in ``file``.
        """
        tar = None
        for t, target_name in enumerate(self.target_names):
            if target_name in file:
                tar = t  # no break: the last matching name wins
        # Check only after all names were tried; otherwise a file whose path
        # does not contain the first name would be rejected immediately.
        assert tar is not None, f"No matching target found for {file}!"
        return tar

    def _load_data(self):
        """Load all images and labels of the selected set into memory.

        Returns:
            A tuple ``(data, target)`` of two lists: 2D float32 image arrays
            and the corresponding integer labels.
        Raises:
            ValueError: if ``self.dataset`` is not a supported set name.
        """
        if self.dataset == "training":
            files = self.train_files
        elif self.dataset == "validation":
            files = self.valid_files
        elif self.dataset == "testing":
            files = self.test_files
        else:
            raise ValueError(f"No such dataset supported {self.dataset}")

        data, target = [], []
        for i, file in enumerate(files):
            if i % 1000 == 0:
                print(f"loading {i+1} of {len(files)}: {file}")
            img = skimage.io.imread(file)
            if len(np.shape(img)) > 2:
                img = img[:, :, 0]  # use only first channel if rgb
            img = np.asarray(img, dtype=np.float32)
            # Rescale images that are not already within [0, 1].
            if np.max(img) > 1.0:
                img = img/255.0
            data.append(img)
            target.append(self._find_target(file))
        return data, target

    def __getitem__(self, index):
        """Return the (optionally transformed) image and its label at ``index``."""
        img, target = self.data[index], self.target[index]
        if self.transform is not None:
            img = self.transform(img)
        return img, target

    # TODO: this can probably made more efficient using batch_sampler
    def get_batch(self, batch_indices, returns="all"):
        """Build a mini-batch from the samples at ``batch_indices``.

        Args:
            batch_indices: non-empty sequence of sample indices.
            returns: "all" for ``(images, labels)``, "image" for the images
                only or "label" for the labels only.
        Returns:
            The images stacked along a new first dimension and/or the labels
            as a 1D ``torch.long`` tensor, depending on ``returns``.
        Raises:
            AssertionError: if ``batch_indices`` is empty.
            ValueError: if ``returns`` is not one of the supported values.
        """
        assert len(batch_indices) > 0, "empty batch indices provided!"
        img_batch = []
        target_batch = []
        for idx in batch_indices:
            img, target = self[idx]
            # torch.stack only accepts tensors; without a tensor-producing
            # transform the images are still numpy arrays.
            if not torch.is_tensor(img):
                img = torch.as_tensor(img)
            img_batch.append(img)
            target_batch.append(torch.tensor(target, dtype=torch.long))
        img_batch = torch.stack(img_batch, dim=0)
        target_batch = torch.stack(target_batch, dim=0)

        if returns == "all":
            return img_batch, target_batch
        elif returns == "image":
            return img_batch
        elif returns == "label":
            return target_batch
        else:
            raise ValueError(f"Expected `returns` to be 'all', 'image', or 'label', but got '{returns}'")

    def __len__(self):
        """Return the number of loaded samples."""
        return len(self.data)
