import os

import torch
import numpy as np
from tqdm import tqdm
import torchio as tio

from data.brats2020.BRATS2020_Unprocessed import BRATSDatasetUnprocessed


def process2numpy(dataset: BRATSDatasetUnprocessed, save_dir: str):
    """
    Rescale each sample's ``x`` channels and save every slice (last axis) as its own ``.npy`` file.

    Each item of ``dataset`` is a pair ``(x, y)`` whose concatenation along dim 0 must be a 4D
    tensor ``(channels, height, width, slices)``. Every channel of ``x`` is rescaled
    independently to ``[-1, 1]`` with ``tio.RescaleIntensity``. ``y`` is stored unchanged
    (presumably the segmentation labels -- confirm against ``BRATSDatasetUnprocessed``).

    The stacked ``[x, y]`` array for slice ``s`` of sample ``i`` is written to
    ``save_dir/{i}_{s}.npy``. ``i`` is 1-based, and both indices are zero-padded so that a
    lexicographic file listing matches numeric order.

    :param dataset: sized iterable yielding ``(x, y)`` tensor pairs.
    :param save_dir: output directory; created if it does not exist yet.
    """
    # np.save does not create missing directories, so make sure the target exists up front.
    os.makedirs(save_dir, exist_ok=True)

    # Width of the zero-padded, 1-based sample index; it depends only on the dataset size.
    sample_width = len(str(len(dataset)))

    pbar = tqdm(dataset)
    for i, (x, y) in enumerate(pbar, start=1):
        for c, channel in enumerate(x):
            # A fresh transform per channel: each channel is mapped to [-1, 1] from its own
            # intensity range, with nothing carried over from a previously processed channel.
            rescale = tio.RescaleIntensity(out_min_max=(-1, 1))
            # torchio expects (channels, x, y, z), so add and then drop a singleton channel axis.
            # NOTE(review): the result is written back into x in place and therefore takes x's
            # dtype; if x were an integer tensor the [-1, 1] values would be truncated.
            # Confirm that x is floating point.
            x[c] = rescale(channel.unsqueeze(0)).squeeze(0)

        data = torch.cat([x, y], dim=0)
        # The 4-way unpack also acts as a check that data is 4D.
        _, _, _, slices = data.shape
        slice_width = len(str(slices))

        for s in range(slices):
            pbar.set_description(f'processing slice {s}/{slices}')
            name = f'{i:0{sample_width}}_{s:0{slice_width}}'
            # np.save appends the ".npy" extension to the file name.
            np.save(os.path.join(save_dir, name), data[..., s].numpy())


def _main_process2numpy(
    unprocessed_path: str = r'MICCAI_BraTS2020_TrainingData',
    processed_path: str = r'brats2020_processed_run/',
):
    """
    Convert the unprocessed BraTS 2020 training data into per-slice numpy files.

    Builds ``BRATSDatasetUnprocessed`` with ``test_flag=False``. As a quick sanity check it then
    prints the dataset length and the shapes of the first ``(x, y)`` sample. Finally it
    delegates the conversion to :func:`process2numpy`.

    :param unprocessed_path: directory holding the unprocessed BraTS 2020 data.
    :param processed_path: directory that receives the ``.npy`` slice files.
    """
    dataset = BRATSDatasetUnprocessed(unprocessed_path, test_flag=False)
    print(len(dataset))

    # Inspect the first sample before starting the (long) full conversion.
    x, y = dataset[0]
    print(x.shape)
    print(y.shape)

    process2numpy(dataset, processed_path)

    print('done')


if __name__ == '__main__':
    # Run the conversion only when executed as a script, never on import.
    _main_process2numpy()
