# Copyright 2020 - 2022 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import os

import numpy as np
import torch

from monai import data, transforms
from monai.data import load_decathlon_datalist
from .reader import Transpose
import math
import random

# Datalist keys consumed by the transform pipelines in this module; each one
# must be a field of every entry in the decathlon-style JSON datalist.
# ``all_keys``: the four image modalities (presumably the MRI sequences
# T1-native, T1-contrast, T2-weighted, T2-FLAIR — confirm against the
# datalist) plus the segmentation label "seg". The training pipeline loads,
# crops, flips and rotates all of these together so image and label stay
# aligned. Order matters: all_keys[0] ('t1n') is the foreground source used
# for cropping in get_nnunet_loader.
all_keys = ['t1n', 't1c', 't2w',  't2f', 'seg']

# ``modal_keys``: the image modalities only (no label). These are the keys
# that get intensity-normalized, and the only keys the val/test pipelines load.
modal_keys = ['t1n', 't1c', 't2w',  't2f', ]

class Sampler(torch.utils.data.Sampler):
    """Distributed sampler that partitions a dataset's indices across ranks.

    Each rank receives every ``num_replicas``-th index of the (optionally
    shuffled) global index list, starting at its own ``rank``.

    When ``make_even`` is True the global list is padded to
    ``num_samples * num_replicas`` entries so every rank yields exactly
    ``num_samples`` indices. Otherwise trailing ranks may yield one fewer.

    Shuffling and padding are both driven by one ``torch.Generator`` seeded with
    the current epoch (see :meth:`set_epoch`). Every rank therefore rebuilds the
    *same* global list, and the per-rank slices only overlap on padding
    duplicates.

    Args:
        dataset: any object supporting ``len()``.
        num_replicas: number of participating processes; when None, taken from
            ``torch.distributed.get_world_size()``.
        rank: index of the current process; when None, taken from
            ``torch.distributed.get_rank()``.
        shuffle: if True, permute the indices every epoch.
        make_even: if True, pad the indices so all ranks get the same count.

    Raises:
        RuntimeError: if ``num_replicas`` or ``rank`` is None and
            ``torch.distributed`` is not available.
    """

    def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True, make_even=True):
        if num_replicas is None:
            if not torch.distributed.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = torch.distributed.get_world_size()
        if rank is None:
            if not torch.distributed.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = torch.distributed.get_rank()
        self.shuffle = shuffle
        self.make_even = make_even
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        # Per-rank count after padding, and the padded global length.
        self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
        self.total_size = self.num_samples * self.num_replicas
        # Number of real (non-padding) indices owned by this rank.
        self.valid_length = len(range(self.rank, len(self.dataset), self.num_replicas))
        if not self.make_even:
            # Without padding this rank only ever yields its real indices, so
            # report that count immediately instead of after the first epoch.
            self.num_samples = self.valid_length

    def __iter__(self):
        # A single epoch-seeded generator drives both the shuffle and the
        # padding draw, so all ranks reproduce the identical global list.
        g = torch.Generator()
        g.manual_seed(self.epoch)
        if self.shuffle:
            indices = torch.randperm(len(self.dataset), generator=g).tolist()
        else:
            indices = list(range(len(self.dataset)))
        if self.make_even:
            shortfall = self.total_size - len(indices)
            if shortfall > 0:
                if shortfall < len(indices):
                    # Enough real indices: reuse the head of the list as padding.
                    indices += indices[:shortfall]
                else:
                    # Dataset smaller than the padding needed: draw padding
                    # positions with replacement from the seeded generator.
                    extra_ids = torch.randint(
                        low=0, high=len(indices), size=(shortfall,), generator=g
                    ).tolist()
                    indices += [indices[ids] for ids in extra_ids]
            assert len(indices) == self.total_size
        indices = indices[self.rank : self.total_size : self.num_replicas]
        self.num_samples = len(indices)
        return iter(indices)

    def __len__(self):
        return self.num_samples

    def set_epoch(self, epoch):
        """Set the epoch used to seed shuffling/padding for the next iteration."""
        self.epoch = epoch


def _build_train_transform(args):
    """Return the training pipeline: load, add channel, normalize, crop, augment."""
    return transforms.Compose(
        [
            transforms.LoadImaged(keys=all_keys),
            transforms.AddChanneld(keys=all_keys),
            # Only the image modalities are normalized; "seg" is left untouched.
            transforms.NormalizeIntensityd(keys=modal_keys),
            # Crop every key to the foreground of all_keys[0], where foreground
            # means (already normalized) values >= -0.35.
            transforms.CropForegroundD(keys=all_keys, source_key=all_keys[0], select_fn=lambda x: (x >= -0.35)),
            transforms.RandSpatialCropd(
                keys=all_keys, roi_size=(args.roi_x, args.roi_y, args.roi_z), random_size=False
            ),
            transforms.RandFlipd(keys=all_keys, prob=args.RandFlipd_prob, spatial_axis=0),
            transforms.RandFlipd(keys=all_keys, prob=args.RandFlipd_prob, spatial_axis=1),
            transforms.RandFlipd(keys=all_keys, prob=args.RandFlipd_prob, spatial_axis=2),
            transforms.RandRotate90d(keys=all_keys, prob=args.RandRotate90d_prob, max_k=3),
            transforms.ToTensord(keys=all_keys),
        ]
    )


def _build_inference_transform():
    """Return the shared val/test pipeline: load and normalize the modalities only (no label)."""
    return transforms.Compose(
        [
            transforms.LoadImaged(keys=modal_keys),
            transforms.NormalizeIntensityd(keys=modal_keys),
            transforms.ToTensord(keys=modal_keys),
        ]
    )


def get_nnunet_loader(args):
    """Build the data loader(s) described by ``args``.

    The datalist is read from ``os.path.join(args.data_dir, args.json_list)``
    with ``load_decathlon_datalist``, with paths resolved against ``args.data_dir``.

    Args:
        args: namespace-like config. Always reads ``data_dir``, ``json_list``,
            ``test_mode`` and ``distributed``. In training mode it also reads
            ``use_normal_dataset``, ``workers``, ``batch_size``, ``roi_x``,
            ``roi_y``, ``roi_z``, ``RandFlipd_prob`` and ``RandRotate90d_prob``.

    Returns:
        In test mode: a single DataLoader over the "validation" split, with the
        inference pipeline and batch size 1.
        Otherwise: ``[train_loader, val_loader]``. ``train_loader`` iterates the
        "training" split with augmentation. ``val_loader`` iterates the
        "validation" split with the inference pipeline and batch size 1.
    """
    data_dir = args.data_dir
    datalist_json = os.path.join(data_dir, args.json_list)

    if args.test_mode:
        test_files = load_decathlon_datalist(datalist_json, True, "validation", base_dir=data_dir)
        test_ds = data.Dataset(data=test_files, transform=_build_inference_transform())
        test_sampler = Sampler(test_ds, shuffle=False) if args.distributed else None
        return data.DataLoader(
            test_ds,
            batch_size=1,
            shuffle=False,
            # Fixed at a single worker; args.workers is not used for testing.
            num_workers=1,
            sampler=test_sampler,
            pin_memory=True,
            persistent_workers=True,
        )

    datalist = load_decathlon_datalist(datalist_json, True, "training", base_dir=data_dir)
    train_transform = _build_train_transform(args)
    if args.use_normal_dataset:
        train_ds = data.Dataset(data=datalist, transform=train_transform)
    else:
        # Caches at most 24 items (cache_num) of the pipeline's deterministic prefix.
        train_ds = data.CacheDataset(
            data=datalist, transform=train_transform, cache_num=24, cache_rate=1.0, num_workers=args.workers
        )
    train_sampler = Sampler(train_ds) if args.distributed else None
    train_loader = data.DataLoader(
        train_ds,
        batch_size=args.batch_size,
        # A sampler and shuffle=True are mutually exclusive in DataLoader.
        shuffle=(train_sampler is None),
        num_workers=args.workers,
        sampler=train_sampler,
        pin_memory=True,
    )

    val_files = load_decathlon_datalist(datalist_json, True, "validation", base_dir=data_dir)
    val_ds = data.Dataset(data=val_files, transform=_build_inference_transform())
    val_sampler = Sampler(val_ds, shuffle=False) if args.distributed else None
    val_loader = data.DataLoader(
        val_ds,
        batch_size=1,
        shuffle=False,
        num_workers=args.workers,
        sampler=val_sampler,
        pin_memory=True,
    )
    return [train_loader, val_loader]




