from torch.utils.data import Dataset, DataLoader
import torch
import numpy as np
from configs.config_setting import setting_config as config
from PIL import Image
from einops.layers.torch import Rearrange
from scipy.ndimage.morphology import binary_dilation
from torch.utils.data import Dataset
from torchvision import transforms
import cv2
from utils import *

## Temporary
class dataset_loader(Dataset):
    """Segmentation dataset backed by pre-split ``.npy`` arrays.

    Originally described as a dataset class for BraTS datasets. Images and
    masks are read from ``path_Data + 'data_<split>.npy'`` and
    ``path_Data + 'mask_<split>.npy'``, where ``<split>`` is ``train``,
    ``val`` or ``test``. ``path_Data`` is used as a plain string prefix, so it
    must end with a path separator when it names a directory.

    Each item is an ``(image, mask)`` pair of ``float32`` tensors: the image
    is scaled by 1/255 and its last axis moved to the front, and the mask is
    one-hot encoded with the class axis first.
    """

    # Raw intensities stored in the mask files; the value at position i is
    # mapped to class index i. Mapped indices are stored as int8, so at most
    # 128 entries are supported.
    MASK_VALUES = (0, 38, 75, 113)

    def __init__(self, path_Data, train=True, Test=False, num_classes=None):
        """Load one split into memory and precompute its one-hot masks.

        Args:
            path_Data: string prefix prepended to the ``.npy`` file names.
            train: load the training split. When True, ``Test`` is ignored.
            Test: when ``train`` is False, load the test split instead of the
                validation split.
            num_classes: number of one-hot channels. Defaults to
                ``config.num_classes``. Must be at least
                ``len(MASK_VALUES)``, otherwise the one-hot lookup raises
                ``IndexError``.
        """
        super().__init__()
        self.train = train

        if train:
            split = 'train'
        elif Test:
            split = 'test'
        else:
            split = 'val'
        self.data = np.load(path_Data + 'data_' + split + '.npy')
        self.mask = np.load(path_Data + 'mask_' + split + '.npy')

        # Normalize to [0, 1]. Working in float32 (the dtype __getitem__
        # returns anyway) with an in-place divide avoids materializing a
        # float64 copy of the whole split.
        data = self.data.astype(np.float32)
        data /= 255.
        self.data = data

        # Map the discrete raw mask values onto contiguous class indices.
        self._preprocess_mask()

        if num_classes is None:
            num_classes = config.num_classes
        # One-hot encode by indexing an identity matrix, which appends a class
        # axis at the end; then move it to axis 1. For masks stored as
        # (N, H, W) this yields (N, C, H, W).
        self.mask_one_hot = np.eye(num_classes, dtype=np.float32)[self.mask]
        self.mask_one_hot = np.moveaxis(self.mask_one_hot, -1, 1)

    def _preprocess_mask(self):
        """Replace ``self.mask`` with class indices derived from ``MASK_VALUES``.

        Pixels whose raw value is not listed in ``MASK_VALUES`` are left at
        class 0.
        """
        mapped = np.zeros_like(self.mask, dtype=np.int8)
        for class_idx, raw_value in enumerate(self.MASK_VALUES):
            mapped[self.mask == raw_value] = class_idx
        self.mask = mapped

    def __getitem__(self, idx):
        """Return ``(image, mask)`` float32 tensors for sample ``idx``.

        The image's last axis is moved to the front (presumably HWC -> CHW;
        confirm against the stored arrays). The mask is the precomputed
        one-hot slice with the class axis first.
        """
        image = torch.tensor(self.data[idx]).float().permute(2, 0, 1)
        mask = torch.tensor(self.mask_one_hot[idx]).float()
        return image, mask

    def __len__(self):
        """Return the number of samples in the loaded split."""
        return len(self.data)
    