import os
import torch
from torch.utils.data import Dataset
import albumentations as A
from albumentations.pytorch import ToTensorV2
import SimpleITK as sitk
import ast
import numpy as np


#----------------------------------------------------
# 1. 有标签数据集
#----------------------------------------------------
class LabeledDataset(Dataset):
    """Labeled landmark dataset: an image plus one Gaussian heatmap per landmark.

    Each row of ``dataframe`` must provide ``'Path'`` (an image file readable by
    SimpleITK) and the landmark columns ``'PS1'``, ``'PS2'`` and ``'FH1'``. Each
    landmark column holds a string that ``ast.literal_eval`` turns into an
    ``(x, y)`` pixel coordinate in the original image. Heatmap channel ``i``
    always corresponds to the ``i``-th entry of ``KEYPOINT_COLUMNS``.

    Args:
        dataframe: table indexed positionally via ``.iloc``.
        image_size: side length the image (and its keypoints) are resized to.
        heatmap_size: side length of each output heatmap.
        sigma: standard deviation, in heatmap pixels, of the Gaussian target.
        train: if True, use the heavy augmentation pipeline and jitter the
            keypoints. Otherwise use resize + CLAHE only.
        perturb_sigma: standard deviation, in resized-image pixels, of the
            Gaussian jitter added to keypoints in train mode.
    """

    # Landmark columns, in heatmap-channel order.
    KEYPOINT_COLUMNS = ("PS1", "PS2", "FH1")

    def __init__(self, dataframe, image_size=512, heatmap_size=128, sigma=3.0, train=True,
                 perturb_sigma=2.0):
        self.data = dataframe
        self.image_size = image_size
        self.heatmap_size = heatmap_size
        self.train = train
        self.sigma = sigma
        self.perturb_sigma = perturb_sigma

        # remove_invisible=False keeps the keypoint list the same length as the
        # input even when augmentation moves a landmark out of frame. With the
        # default (True), such landmarks are dropped silently, and every later
        # landmark would shift into the wrong heatmap channel. Out-of-frame
        # landmarks are handled explicitly in _keypoints_to_heatmaps instead.
        self.train_transform = A.Compose([
            A.Resize(self.image_size, self.image_size),

            # --- Basic geometric augmentation ---
            A.HorizontalFlip(p=0.5),
            A.ShiftScaleRotate(shift_limit=0.08, scale_limit=0.15, rotate_limit=15, p=0.7, border_mode=0),

            # --- Strong geometric distortion ---
            A.ElasticTransform(p=0.5, alpha=120, sigma=120 * 0.05, alpha_affine=120 * 0.03),
            A.GridDistortion(p=0.5),

            # --- Intensity / colour changes ---
            A.OneOf([
                A.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, p=1.0),
                A.RandomGamma(gamma_limit=(80, 120), p=1.0),
            ], p=0.8),

            # --- Noise / artefact simulation ---
            A.GaussNoise(var_limit=(10.0, 50.0), p=0.5),

            # --- Erasing / occlusion ---
            A.CoarseDropout(max_holes=8, max_height=self.image_size // 8, max_width=self.image_size // 8,
                            fill_value=0, p=0.5),

            # --- Final preprocessing ---
            A.CLAHE(clip_limit=4.0, p=1.0),
            ToTensorV2()
        ], p=1.0, keypoint_params=A.KeypointParams(format='xy', remove_invisible=False))

        # Evaluation pipeline: deterministic resize + CLAHE only.
        self.val_transform = A.Compose([
            A.Resize(self.image_size, self.image_size),
            A.CLAHE(clip_limit=4.0, p=1.0),
            ToTensorV2()
        ], p=1.0, keypoint_params=A.KeypointParams(format='xy', remove_invisible=False))

    def __len__(self):
        return len(self.data)

    def generate_heatmap(self, center_x, center_y, height, width):
        """Return a ``(height, width)`` float32 Gaussian peaking (value 1) at the center.

        The center is given in heatmap-pixel coordinates. The spread is ``self.sigma``.
        """
        x = np.arange(0, width, 1, np.float32)
        y = np.arange(0, height, 1, np.float32)[:, np.newaxis]
        heatmap = np.exp(-((x - center_x) ** 2 + (y - center_y) ** 2) / (2 * self.sigma ** 2))
        return heatmap

    @staticmethod
    def _load_image(path):
        """Read an image with SimpleITK, replicating a 2-D array to 3 channels (H, W, 3)."""
        image = sitk.GetArrayFromImage(sitk.ReadImage(path))
        if image.ndim == 2:
            image = np.stack([image] * 3, axis=-1)
        return image

    def _keypoints_to_heatmaps(self, keypoints, height, width):
        """Render one heatmap channel per keypoint, preserving keypoint order.

        Args:
            keypoints: sequence of ``(x, y)`` coordinates in an image of size
                ``height`` x ``width`` (i.e. after augmentation/resizing).
            height, width: size of the image the keypoints live in.

        Returns:
            float32 array of shape ``(len(keypoints), heatmap_size, heatmap_size)``.
            A keypoint outside ``[0, width) x [0, height)`` leaves its channel all zeros.
        """
        heatmaps = np.zeros((len(keypoints), self.heatmap_size, self.heatmap_size), dtype=np.float32)
        scale_x = self.heatmap_size / width
        scale_y = self.heatmap_size / height

        for channel, kp in enumerate(keypoints):
            x, y = float(kp[0]), float(kp[1])
            # Landmark pushed out of frame by augmentation: no valid target.
            if not (0 <= x < width and 0 <= y < height):
                continue

            if self.train:
                # Label jitter as a regulariser against annotation noise.
                x += np.random.normal(0, self.perturb_sigma)
                y += np.random.normal(0, self.perturb_sigma)

            # Clamp so jitter near an edge still yields an in-bounds peak.
            hx = max(0, min(int(x * scale_x), self.heatmap_size - 1))
            hy = max(0, min(int(y * scale_y), self.heatmap_size - 1))
            heatmaps[channel] = self.generate_heatmap(hx, hy, self.heatmap_size, self.heatmap_size)

        return heatmaps

    def __getitem__(self, index):
        """Return one sample.

        Train mode: ``(image_tensor, heatmaps, empty, empty)``. The two empty
        tensors are placeholders.
        Eval mode: ``(image_tensor, heatmaps, landmarks, original_image)``, where
        ``landmarks`` is the flat ``[x1, y1, x2, y2, x3, y3]`` float32 tensor of
        the original keypoints, normalised by the original width/height.
        """
        row = self.data.iloc[index]

        image_np_original = self._load_image(row['Path'])
        keypoints = [ast.literal_eval(row[col]) for col in self.KEYPOINT_COLUMNS]

        transform = self.train_transform if self.train else self.val_transform
        augmented = transform(image=image_np_original.copy(), keypoints=keypoints)
        image_tensor = augmented['image']

        # ToTensorV2 produces channels-first output, so the spatial size is the last two dims.
        out_h, out_w = image_tensor.shape[-2:]
        heatmaps = torch.from_numpy(
            self._keypoints_to_heatmaps(augmented['keypoints'], out_h, out_w)
        )

        if self.train:
            return image_tensor, heatmaps, torch.empty(0), torch.empty(0)

        h_orig, w_orig = image_np_original.shape[:2]
        landmarks_list = [coord for p in keypoints for coord in (p[0] / w_orig, p[1] / h_orig)]
        landmarks_tensor = torch.tensor(landmarks_list, dtype=torch.float32)
        return image_tensor, heatmaps, landmarks_tensor, image_np_original

#----------------------------------------------------
# 2. 无标签数据集
#----------------------------------------------------
class UnlabeledDataset(Dataset):
    """Unlabeled dataset yielding a weakly and a strongly augmented view per image.

    Args:
        dataframe: table indexed positionally via ``.iloc`` with a ``'Path'``
            column pointing at image files readable by SimpleITK.
        image_size: side length both views are resized to.

    NOTE(review): the two views come from independent pipelines. Their random
    flip / ShiftScaleRotate draws are not shared, and no transform parameters
    are returned. If the training loop compares heatmaps predicted on the weak
    view against the strong view, they will not be spatially aligned. Confirm
    this is intended.
    """

    def __init__(self, dataframe, image_size=512):
        self.data = dataframe
        self.image_size = image_size

        side = self.image_size
        hole = int(image_size * 0.1)

        def finishing_steps():
            # Fresh instances per pipeline: contrast equalisation, then tensor conversion.
            return [A.CLAHE(clip_limit=4.0, p=1.0), ToTensorV2()]

        # Weak view: resize, random flip and mild brightness/contrast jitter only.
        weak_steps = [
            A.Resize(side, side),
            A.HorizontalFlip(p=0.5),
            A.ColorJitter(brightness=0.2, contrast=0.2, p=0.3),
        ]
        self.weak_transform = A.Compose(weak_steps + finishing_steps())

        # Strong view: adds affine jitter, stronger colour change, blur-or-noise and cutout.
        strong_steps = [
            A.Resize(side, side),
            A.HorizontalFlip(p=0.5),
            A.ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.1, rotate_limit=15, p=0.7),
            A.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1, p=0.8),
            A.OneOf(
                [
                    A.GaussianBlur(blur_limit=(3, 5), p=0.5),
                    A.GaussNoise(var_limit=(10.0, 30.0), p=0.5),
                ],
                p=0.5,
            ),
            A.CoarseDropout(max_holes=6, max_height=hole, max_width=hole, min_holes=1, p=0.5),
        ]
        self.strong_transform = A.Compose(strong_steps + finishing_steps())

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        """Return ``(weak_view, strong_view)`` built from the same source image."""
        path = self.data.iloc[index]['Path']
        pixels = sitk.GetArrayFromImage(sitk.ReadImage(path))
        if pixels.ndim == 2:
            # Replicate single-channel data into three identical channels (H, W, 3).
            pixels = np.repeat(pixels[:, :, np.newaxis], 3, axis=2)

        # Weak pipeline runs first, then strong, each drawing its own randomness.
        weak_view = self.weak_transform(image=pixels)['image']
        strong_view = self.strong_transform(image=pixels)['image']
        return weak_view, strong_view