"""
https://github.com/JuliaWolleb/Diffusion-based-Segmentation/blob/main/guided_diffusion/bratsloader.py
"""
import os
from typing import Union, Tuple

import nibabel
import torch
from jaxtyping import jaxtyped, Float
from beartype import beartype as typechecker
from torch.utils.data import Dataset
from tqdm import tqdm


class BRATSDatasetUnprocessed(Dataset):
    """Dataset over raw BraTS NIfTI volumes, one sample per leaf directory.

    Every directory below ``directory`` that has no sub-directories is treated
    as one datapoint whose files are the individual modalities (and, unless
    ``test_flag`` is set, the segmentation).

    NOTE: samples are indexed in the order ``os.walk`` visits the leaf
    directories; the directory list is not sorted here, so the index -> sample
    mapping follows whatever order the file system reports.
    """

    def __init__(self, directory: str, test_flag: bool):
        """
        directory is expected to contain some folder structure:
                  if some subfolder contains only files, all of these
                  files are assumed to have a name like
                  brats_train_001_XXX_123_w.nii.gz
                  where XXX is one of t1, t1ce, t2, flair, seg
                  we assume these five files belong to the same image
                  seg is supposed to contain the segmentation

        test_flag selects the expected file set: when True only the four
                  modalities t1, t1ce, t2, flair are expected (no seg),
                  otherwise seg is required as well.

        Raises AssertionError if a leaf directory does not contain exactly
        the expected sequence types.
        """
        super().__init__()
        self.directory = os.path.expanduser(directory)
        self.test_flag = test_flag

        # Order matters: __getitem__ stacks channels in this order and treats
        # the last entry as the label when test_flag is False.
        if test_flag:
            self.seqtypes = ['t1', 't1ce', 't2', 'flair']
        else:
            self.seqtypes = ['t1', 't1ce', 't2', 'flair', 'seg']

        self.seqtypes_set = set(self.seqtypes)
        # One dict per datapoint, mapping sequence type -> file path.
        self.database = []
        for root, dirs, files in os.walk(self.directory):
            # if there are no sub-dirs, we have data
            if dirs:
                continue

            files.sort()
            datapoint = dict()
            # extract all files as channels
            for f in files:
                # The sequence type is the 4th '_'-separated token, with
                # everything from the first '.' onwards removed.
                seqtype = f.split('_')[3].split('.')[0]
                datapoint[seqtype] = os.path.join(root, f)
            # Raised explicitly (rather than via `assert`) so the check is
            # not removed when Python runs with -O; the exception type and
            # message are the same as before.
            if set(datapoint.keys()) != self.seqtypes_set:
                raise AssertionError(
                    f'datapoint is incomplete, keys are {datapoint.keys()}, should be {self.seqtypes_set}'
                )
            self.database.append(datapoint)

    @jaxtyped
    @typechecker
    def __getitem__(
        self,
        i: int,
    ) -> Union[
         Tuple[Float[torch.Tensor, '4 240 240 155'], Float[torch.Tensor, '1 240 240 155']],
         Tuple[Float[torch.Tensor, '4 240 240 155'], str],
     ]:
        """Load the volumes of datapoint ``i`` from disk.

        Returns ``(image, label)`` when ``test_flag`` is False, where
        ``image`` stacks the four modalities and ``label`` is the
        segmentation with a leading singleton channel. Returns
        ``(image, path)`` when ``test_flag`` is True, where ``path`` is the
        file path of the last modality in ``self.seqtypes``.
        """
        filedict = self.database[i]
        # Convert each volume to float32 as it is loaded instead of stacking
        # the loaded floating-point arrays and casting the whole stack at the
        # end; the resulting values are the same but the peak memory is lower.
        volumes = [
            torch.tensor(
                nibabel.load(filedict[seqtype]).get_fdata(),
                dtype=torch.float32,
            )
            for seqtype in self.seqtypes
        ]
        if self.test_flag:
            return torch.stack(volumes), filedict[self.seqtypes[-1]]

        # The last sequence type is the segmentation; everything before it
        # forms the image channels.
        image = torch.stack(volumes[:-1])
        label = volumes[-1][None, ...]
        return image, label

    def __len__(self) -> int:
        """Return the number of datapoints found under ``directory``."""
        return len(self.database)


if __name__ == '__main__':
    # Ad-hoc inspection of the raw training data: collects the set of tensor
    # shapes and the bounding box of all non-zero voxels over the whole
    # dataset. Everything is gathered in a single pass so each volume is
    # loaded from disk only once.
    dataset = BRATSDatasetUnprocessed(
        directory=r'MICCAI_BraTS2020_TrainingData',
        test_flag=False,
    )
    print(len(dataset))

    shape_set_x = set()
    shape_set_y = set()

    # Running bounds of the non-zero region, per volume dimension. The
    # minimum starts at the volume extent (one past the largest valid index)
    # and the maximum at 0, so the first non-zero voxel tightens both.
    min_indices = torch.tensor([240, 240, 155])
    max_indices = torch.tensor([0, 0, 0])

    for i in tqdm(range(len(dataset))):
        xs, y = dataset[i]
        shape_set_x.add(xs.shape)
        shape_set_y.add(y.shape)
        # Iterate over the modality channels of the image tensor.
        for x in xs:
            non_zero_indices = torch.nonzero(x)
            # Skip channels that are entirely zero: min/max over an empty
            # index tensor is undefined.
            if non_zero_indices.size(0) > 0:
                min_indices = torch.min(min_indices, torch.min(non_zero_indices, dim=0).values)
                max_indices = torch.max(max_indices, torch.max(non_zero_indices, dim=0).values)

    print(shape_set_x)
    print(shape_set_y)
    # Recorded output:
    # {torch.Size([4, 240, 240, 155])}
    # {torch.Size([1, 240, 240, 155])}

    # Calculate the crop size based on the minimum and maximum indices
    print("Minimum Indices:", min_indices)
    print("Maximum Indices:", max_indices)
    crop_size = max_indices - min_indices + 1

    # Values are reported in the order of the volume's tensor dimensions.
    print("Minimum Crop Size (z, y, x):", crop_size)

    # Recorded output:
    # Minimum Indices: tensor([40, 29,  0])
    # Maximum Indices: tensor([196, 222, 148])
    # Minimum Crop Size (z, y, x): tensor([157, 194, 149])
