import os
from glob import glob
import numpy as np
import json
import random
import ants
from tqdm import tqdm
import torch
from torch.utils.data import Dataset
from torch.multiprocessing import Pool
from scipy.ndimage import morphology
from .unidirection_transforms import *


# Tensor component name -> integer sample id stored as ``data_idx``.
# Ids start at 4 so they never collide with the b-value channel ids 0-3
# below (``UnidirectionDiffusion.__getitem__`` branches on ``data_idx < 4``).
data_mapping = {
    component: sample_id
    for sample_id, component in enumerate(
        ('Dxx', 'Dyy', 'Dzz', 'Dxy', 'Dxz', 'Dyz'), start=4
    )
}

# Channel index within ``data_0_1000xyz.nii.gz`` -> label reported as
# ``data_type`` for b-value samples.
data_type_mapping = dict(enumerate(('b0', 'b1000x', 'b1000y', 'b1000z')))

# Fixed ``[min, max]`` value window per tensor component, passed to the
# transform as ``minmaxv``. Dxx/Dyy/Dzz get a non-negative window, while the
# cross terms get a window symmetric about zero. Each entry is its own list.
tensor_minmaxv_mapping = {
    component: (
        [0.0, 0.002]
        if component in ('Dxx', 'Dyy', 'Dzz')
        else [-0.0005, 0.0005]
    )
    for component in data_mapping
}

class UnidirectionDiffusion(Dataset):
    """Dataset of single diffusion-MRI volumes paired with a per-patient mask.

    Files are discovered under ``<dataroot>/<stage>/<patient>/``:

    * ``data_0_1000xyz.nii.gz`` -- multi-channel volume whose channels 0-3
      are exposed as separate samples (labels in ``data_type_mapping``).
      These are loaded together with ``minmaxv.npy`` from the same directory.
    * ``DTI/tensor/*`` -- one file per tensor component. The component name
      is parsed as ``name.split('_')[1].split('.')[0]`` (e.g.
      ``dti_Dxx.nii.gz`` -> ``'Dxx'``) and must be a key of ``data_mapping``.
    * ``mask.nii.gz`` -- mask returned alongside every sample.

    The b0/b1000 channel samples are only included for ``stage == 'train'``.
    ``'val'`` and ``'test'`` contain tensor-component samples only.

    Args:
        dataroot: Root directory containing one sub-directory per stage.
        stage: One of ``'train'``, ``'val'`` or ``'test'``.

    Raises:
        ValueError: If ``stage`` is not one of the supported stages.
    """

    def __init__(self, dataroot, stage):
        super().__init__()
        if stage not in ('train', 'val', 'test'):
            # Fail fast: an unknown stage used to leave ``self.transform``
            # unset and only crash later inside ``__getitem__``.
            raise ValueError(
                f"stage must be one of 'train', 'val', 'test'; got {stage!r}"
            )
        self.dataroot = dataroot
        self.phase = stage
        self.transform = (
            get_train_transforms() if stage == 'train' else get_test_transforms()
        )

        stage_root = f'{self.dataroot}/{self.phase}'
        # glob() returns entries in filesystem-dependent order; sort so the
        # index -> volume mapping is reproducible across runs and machines.
        b_paths = sorted(glob(f'{stage_root}/*/data_0_1000xyz.nii.gz'))
        b_dataset = [[p, idx] for p in b_paths for idx in range(4)]
        tensor_paths = sorted(glob(f'{stage_root}/*/DTI/tensor/*'))
        tensor_dataset = [
            [p, data_mapping[self._tensor_type(p)]] for p in tensor_paths
        ]
        # One entry per tensor volume, so patient names repeat. For val/test
        # this list is index-aligned with ``self.dataset``.
        self.patient_list = [
            os.path.basename(self._tensor_patient_dir(p))
            for p, _ in tensor_dataset
        ]

        if self.phase == 'train':
            self.dataset = b_dataset + tensor_dataset
        else:
            self.dataset = tensor_dataset

        # Count distinct patients; ``patient_list`` has one entry per tensor
        # file, so its length over-counts patients.
        num_patients = len(set(self.patient_list))
        print(f'{self.phase} dataset size - patient: {num_patients}, volume: {len(self.dataset)}')

    @staticmethod
    def _tensor_type(path):
        """Return the tensor component name (e.g. ``'Dxx'``) from a file path."""
        return os.path.basename(path).split('_')[1].split('.')[0]

    @staticmethod
    def _tensor_patient_dir(path):
        """Return ``<patient>`` for a ``<patient>/DTI/tensor/<file>`` path."""
        return os.path.dirname(os.path.dirname(os.path.dirname(path)))

    def __getitem__(self, index):
        """Load sample ``index`` and its mask, then apply ``self.transform``.

        The transform receives a dict with keys ``'data'``, ``'mask'``
        (``np.uint8``), ``'data_idx'``, ``'data_type'`` and ``'minmaxv'``.
        Its return value is returned unchanged.
        """
        data_path, data_idx = self.dataset[index]
        if data_idx < 4:
            # b-value sample: channel ``data_idx`` of data_0_1000xyz.nii.gz.
            patient_dir = os.path.dirname(data_path)
            data_type = data_type_mapping[data_idx]

            data = ants.image_read(data_path).numpy()[..., data_idx]
            stored = np.load(os.path.join(patient_dir, 'minmaxv.npy'))
            # NOTE(review): assumes minmaxv.npy is laid out as
            # [b0_min, b0_max, b1000_min, b1000_max] -- confirm. The lower
            # bound is always pinned to 0.0 here.
            minmaxv = [0.0, stored[1]] if data_idx == 0 else [0.0, stored[3]]
        else:
            # Tensor-component sample: whole volume, fixed value window.
            patient_dir = self._tensor_patient_dir(data_path)
            data_type = self._tensor_type(data_path)

            data = ants.image_read(data_path).numpy()
            # Copy so a transform that mutates ``minmaxv`` in place cannot
            # corrupt the shared module-level table.
            minmaxv = list(tensor_minmaxv_mapping[data_type])

        mask_path = os.path.join(patient_dir, 'mask.nii.gz')
        mask = ants.image_read(mask_path).numpy().astype(np.uint8)

        return self.transform({
            'data': data,
            'mask': mask,
            'data_idx': data_idx,
            'data_type': data_type,
            'minmaxv': minmaxv,
        })

    def __len__(self):
        """Return the total number of images in the dataset."""
        return len(self.dataset)
