import os
from glob import glob
import numpy as np
import json
import random
import ants
from tqdm import tqdm
import torch
from torch.utils.data import Dataset
from torch.multiprocessing import Pool
from scipy.ndimage import morphology
from .transforms_joint import *


class DiffusionJoint(Dataset):
    """Dataset pairing DTI tensor components with b0 / b1000 diffusion volumes.

    One sample per ``<dataroot>/<stage>/<patient>/DTI/tensor`` directory found
    on disk. Each sample is a dict of numpy arrays (six tensor components,
    b0 and three b1000 channels, a uint8 mask, the b1000 gradient directions and
    the per-patient ``minmaxv`` array) passed through the stage's transform.
    """

    # Suffixes of the tensor component files (``tensor_D<suffix>.nii.gz``),
    # in the order their ``tensor_<suffix>`` keys appear in each sample.
    _TENSOR_COMPONENTS = ('xx', 'yy', 'zz', 'xy', 'xz', 'yz')

    def __init__(self, dataroot, stage):
        """Index all patient tensor directories for one stage.

        Args:
            dataroot: Root directory containing one sub-directory per stage.
            stage: ``'train'`` (uses ``get_train_transforms()``), or
                ``'val'`` / ``'test'`` (both use ``get_test_transforms()``).

        Raises:
            ValueError: If ``stage`` is not one of the recognised values.
                Without this check the instance would be built without a
                transform and fail only later, inside ``__getitem__``.
        """
        super().__init__()
        print(dataroot)
        self.dataroot = dataroot
        self.phase = stage
        if self.phase == 'train':
            self.transform = get_train_transforms()
        elif self.phase in ('val', 'test'):
            self.transform = get_test_transforms()
        else:
            raise ValueError(
                f"stage must be 'train', 'val' or 'test', got {stage!r}"
            )

        # glob() returns entries in arbitrary (filesystem-dependent) order;
        # sorting keeps the index -> sample mapping reproducible.
        pattern = os.path.join(self.dataroot, self.phase, '*', 'DTI', 'tensor')
        self.dataset = sorted(glob(pattern))
        self.patient_list = [
            os.path.basename(self._patient_dir(p)) for p in self.dataset
        ]

        print(f'{self.phase} dataset size - patient: {len(self.patient_list)}, volume: {len(self.dataset)}')

    @staticmethod
    def _patient_dir(tensor_dir):
        """Return the patient directory two levels above ``<patient>/DTI/tensor``."""
        return os.path.dirname(os.path.dirname(tensor_dir))

    def __getitem__(self, index):
        """Load, assemble and transform the sample at ``index``.

        Returns:
            Whatever the stage transform returns when given the sample dict.
        """
        tensor_dir = self.dataset[index]
        patient_dir = self._patient_dir(tensor_dir)

        sample = {}
        for comp in self._TENSOR_COMPONENTS:
            comp_path = os.path.join(tensor_dir, f'tensor_D{comp}.nii.gz')
            sample[f'tensor_{comp}'] = ants.image_read(comp_path).numpy()

        # NOTE(review): assumed to be a 4D volume whose last axis stacks
        # [b0, b1000x, b1000y, b1000z] — confirm against data preparation.
        b0_1000xyz = ants.image_read(
            os.path.join(patient_dir, 'data_0_1000xyz.nii.gz')
        ).numpy()
        sample['b0'] = b0_1000xyz[..., 0]
        sample['b1000x'] = b0_1000xyz[..., 1]
        sample['b1000y'] = b0_1000xyz[..., 2]
        sample['b1000z'] = b0_1000xyz[..., 3]

        sample['mask'] = ants.image_read(
            os.path.join(patient_dir, 'mask.nii.gz')
        ).numpy().astype(np.uint8)

        bvecs_0_1000xyz = np.loadtxt(os.path.join(patient_dir, 'bvecs_0_1000xyz.txt'))
        # Row 0 presumably belongs to the b0 volume; the remaining rows are the
        # b1000 directions — TODO confirm the file layout.
        sample['b1000xyz_dir'] = bvecs_0_1000xyz[1:]
        sample['minmaxv'] = np.load(os.path.join(patient_dir, 'minmaxv.npy'))

        return self.transform(sample)

    def __len__(self):
        """Return the total number of images in the dataset."""
        return len(self.dataset)
