import os
from glob import glob
import numpy as np
import json
import random
import ants
from tqdm import tqdm
import torch
from torch.utils.data import Dataset
from torch.multiprocessing import Pool
from scipy.ndimage import morphology
from .multidirection_transforms import *


class MultiDirectionDiffusion(Dataset):
    """Per-patient diffusion dataset built from the sub-directories of ``{dataroot}/{stage}/``.

    Each item loads, from one patient directory:
      * the six diffusion-tensor components ``DTI/tensor/tensor_D{xx,xy,xz,yy,yz,zz}.nii.gz``,
      * ``data_0_1000xyz.nii.gz``, whose last axis is split into ``b0``, ``b1000x``,
        ``b1000y`` and ``b1000z``,
      * ``mask.nii.gz`` (cast to ``uint8``),
      * ``bvecs_0_1000xyz.txt`` (first row dropped) and ``minmaxv.npy``.

    The arrays are passed as one dict to the phase transform. 'train' uses
    ``get_train_transforms()``; 'val' and 'test' use ``get_test_transforms()``.
    """

    # Tensor component suffixes; each maps to file tensor_D<c>.nii.gz and key tensor_<c>.
    _TENSOR_COMPONENTS = ('xx', 'xy', 'xz', 'yy', 'yz', 'zz')

    def __init__(self, dataroot, stage):
        """Select the transform for ``stage`` and index the patient directories.

        Args:
            dataroot: Root folder containing one sub-folder per stage.
            stage: One of 'train', 'val' or 'test'.

        Raises:
            ValueError: If ``stage`` is not one of the supported values.
        """
        super().__init__()
        self.dataroot = dataroot
        self.phase = stage
        if self.phase == 'train':
            self.transform = get_train_transforms()
        elif self.phase in ('val', 'test'):
            self.transform = get_test_transforms()
        else:
            # Fail here: otherwise self.transform stays unset and the error
            # only surfaces later as an AttributeError inside __getitem__.
            raise ValueError(f"stage must be one of 'train', 'val', 'test'; got {stage!r}")

        self.dataset = self._list_patient_dirs(self.dataroot, self.phase)
        # basename is portable; splitting on '/' breaks on Windows-style paths.
        self.patient_list = [os.path.basename(p) for p in self.dataset]

        print(f'{self.phase} dataset size - patient: {len(self.patient_list)}, volume: {len(self.dataset)}')

    @staticmethod
    def _list_patient_dirs(dataroot, phase):
        """Return the sorted list of sub-directories of ``dataroot/phase``.

        Sorting makes the sample order independent of the filesystem. Plain
        files are skipped because __getitem__ treats every entry as a directory.
        A missing ``dataroot/phase`` folder yields an empty list, as before.
        """
        from glob import escape  # escape the root so '[', '*', '?' in it are literal

        pattern = os.path.join(escape(str(dataroot)), phase, '*')
        return sorted(p for p in glob(pattern) if os.path.isdir(p))

    def __getitem__(self, index):
        """Load patient ``index``, apply the phase transform, and tag the result.

        Returns the transform output with ``patient_dir`` and ``patient_id`` added.
        The output is presumably a dict, since ``.update`` is called on it.
        """
        patient_dir = self.dataset[index]
        tensor_dir = os.path.join(patient_dir, 'DTI', 'tensor')

        sample = {}
        for comp in self._TENSOR_COMPONENTS:
            image = ants.image_read(os.path.join(tensor_dir, f'tensor_D{comp}.nii.gz'))
            sample[f'tensor_{comp}'] = image.numpy()

        # Volumes are stacked on the last axis as [b0, b1000x, b1000y, b1000z].
        b0_1000xyz = ants.image_read(os.path.join(patient_dir, 'data_0_1000xyz.nii.gz')).numpy()
        mask = ants.image_read(os.path.join(patient_dir, 'mask.nii.gz')).numpy().astype(np.uint8)
        bvecs_0_1000xyz = np.loadtxt(os.path.join(patient_dir, 'bvecs_0_1000xyz.txt'))
        minmaxv = np.load(os.path.join(patient_dir, 'minmaxv.npy'))

        sample.update({
            'b0': b0_1000xyz[..., 0],
            'b1000x': b0_1000xyz[..., 1],
            'b1000y': b0_1000xyz[..., 2],
            'b1000z': b0_1000xyz[..., 3],
            'mask': mask,
            # Row 0 is dropped; presumably it is the b0 entry, so the remaining
            # rows align with b1000x/y/z — confirm against the data pipeline.
            'b1000xyz_dir': bvecs_0_1000xyz[1:],
            'minmaxv': minmaxv,
        })

        data = self.transform(sample)
        data.update({'patient_dir': patient_dir, 'patient_id': os.path.basename(patient_dir)})
        return data

    def __len__(self):
        """Return the total number of images in the dataset."""
        return len(self.dataset)
