import warnings
warnings.filterwarnings("ignore", category=FutureWarning, message="torch.utils._pytree._register_pytree_node is deprecated")

import os
import numpy as np
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
import torch
from functools import partial
from . import Cutout, RandomErasing, GridMask, AugMix, YONA
import uuid

class AFHQBase(Dataset):
    """Image-classification dataset over an AFHQ-style directory tree.

    Images are read from ``<root>/<split>/<category>/`` for each category in
    ``('cat', 'dog', 'wild')``. ``data_root`` is scanned first, then
    ``extra_data_root`` if one is given. Directory listings are sorted so
    that sample indices, and the subset kept by the per-category caps, are
    reproducible across runs and machines.

    Args:
        data_root: Primary dataset root directory.
        extra_data_root: Optional second root with the same layout.
        split: Sub-directory of each root to read (e.g. ``'train'``).
        size: Stored as ``self.size``; not used by this class itself.
        interpolation: One of ``'bilinear'``, ``'bicubic'``, ``'lanczos'``;
            stored as the matching PIL constant in ``self.interpolation``.
            Any other value raises ``KeyError``.
        flip_p: Probability used to build ``self.flip``. The flip is only
            stored; ``__getitem__`` does not apply it.
        cat_num, dog_num, wild_num: Optional cap on the number of images
            taken per category. The cap is applied to each root separately.
        transform: Optional callable applied to the RGB PIL image.
    """

    # Accepted image file suffixes, compared case-insensitively.
    _IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')

    def __init__(self,
                 data_root,
                 extra_data_root=None,
                 split='train',
                 size=None,
                 interpolation="bicubic",
                 flip_p=0.5,
                 cat_num=None,
                 dog_num=None,
                 wild_num=None,
                 transform=None):
        self.data_root = data_root
        self.split = split
        self.size = size
        self.interpolation = {"bilinear": Image.BILINEAR,
                              "bicubic": Image.BICUBIC,
                              "lanczos": Image.LANCZOS,
                              }[interpolation]
        self.flip = transforms.RandomHorizontalFlip(p=flip_p)
        self.transform = transform
        self.categories = ['cat', 'dog', 'wild']
        self.class_to_idx = {cls: idx for idx, cls in enumerate(self.categories)}
        max_samples_per_category = {
            "cat": cat_num,
            "dog": dog_num,
            "wild": wild_num,
        }

        self.image_paths = []
        self.labels = {"file_paths": [], "categories": []}

        # The primary root comes first so its samples keep the lowest indices.
        roots = [data_root]
        if extra_data_root:
            roots.append(extra_data_root)

        for root in roots:
            for category in self.categories:
                category_dir = os.path.join(root, split, category)
                paths = self._list_images(
                    category_dir, max_samples_per_category[category]
                )
                self.image_paths.extend(paths)
                self.labels["file_paths"].extend(paths)
                self.labels["categories"].extend(
                    [self.class_to_idx[category]] * len(paths)
                )

        self.sample_weights = [1.0] * len(self.image_paths)
        self._length = len(self.image_paths)

    @staticmethod
    def _list_images(category_dir, limit=None):
        """Return sorted image paths in ``category_dir``, at most ``limit``.

        A missing (or non-directory) path yields an empty list. ``limit=None``
        means no cap; a limit of zero or less yields an empty list.
        """
        if not os.path.isdir(category_dir):
            return []
        # os.listdir order is arbitrary; sort so the capped subset is stable.
        names = sorted(
            name for name in os.listdir(category_dir)
            if name.lower().endswith(AFHQBase._IMAGE_EXTENSIONS)
        )
        if limit is not None:
            names = names[:max(limit, 0)]
        return [os.path.join(category_dir, name) for name in names]

    def __len__(self):
        """Return the current number of samples."""
        return self._length

    def __getitem__(self, i):
        """Return ``(image, label, i)`` for sample ``i``.

        ``image`` is the RGB image after ``self.transform`` (if set),
        ``label`` is a ``torch.long`` tensor holding the class index, and
        ``i`` is echoed back so callers can identify the sample.
        """
        image_path = self.image_paths[i]

        # convert() fully loads the pixel data, so the file handle can be
        # closed as soon as the block exits.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")

        if self.transform:
            image = self.transform(image)

        label = torch.tensor(self.labels["categories"][i], dtype=torch.long)

        return image, label, i

    def add_samples(self, new_image_paths, new_categories, weight_factor=1.0):
        """Append extra samples to the dataset.

        Args:
            new_image_paths: Image file paths to add.
            new_categories: Class indices, one per path.
            weight_factor: Value stored in ``self.sample_weights`` for every
                added sample.

        Raises:
            ValueError: If the two inputs differ in length.
        """
        # Materialise first so generators and other iterables are accepted.
        new_image_paths = list(new_image_paths)
        new_categories = list(new_categories)
        if len(new_image_paths) != len(new_categories):
            raise ValueError("len(new_image_paths) !== (new_categories)")

        self.image_paths.extend(new_image_paths)
        self.labels["file_paths"].extend(new_image_paths)
        self.labels["categories"].extend(new_categories)
        self.sample_weights.extend([weight_factor] * len(new_image_paths))
        self._length = len(self.image_paths)
        
class AfhqTrain(AFHQBase):
    """Training split with a resize / to-tensor / normalize pipeline.

    Args:
        data_root: Dataset root directory.
        flip_p: Forwarded to ``AFHQBase``. The active pipeline below does not
            include a horizontal flip (it is one of the disabled options).
        size: Side length images are resized to; also forwarded to the base
            class so ``self.size`` reflects the resize actually applied.
        **kwargs: Remaining ``AFHQBase`` options (e.g. ``extra_data_root``,
            ``cat_num``). ``split`` and ``transform`` are fixed here.
    """

    def __init__(self, data_root='data/afhq/afhq', flip_p=0.5, size=224, **kwargs):
        transform = transforms.Compose([
            transforms.Resize((size, size)),
            # Optional augmentations, currently disabled. The single-image
            # ones can be dropped straight into this Compose:
            # transforms.RandomHorizontalFlip(p=flip_p),
            # Cutout(mask_size=32),                         # the original Cutout
            # RandomErasing(prob=0.3, scale=(0.02, 0.33)),  # random erasing
            # GridMask(prob=0.2, d1=32, d2=96),             # grid mask
            # AugMix(prob=0.1, alpha=1.0, width=3),         # AugMix augmentation
            transforms.ToTensor(),
            # YONA(mask_direction='random',
            #      transform=transforms.RandomHorizontalFlip()),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])
        super().__init__(data_root=data_root, split='train', size=size,
                         flip_p=flip_p, transform=transform, **kwargs)


class AfhqTrainMinimal(AFHQBase):
    """Training split with minimal augmentation.

    The pipeline is: resize, random horizontal flip, ``Cutout``, to-tensor,
    normalize.

    Args:
        data_root: Dataset root directory.
        flip_p: Probability of the random horizontal flip; also forwarded to
            ``AFHQBase``.
        size: Side length images are resized to; also forwarded to the base
            class so ``self.size`` reflects the resize actually applied.
        **kwargs: Remaining ``AFHQBase`` options. ``split`` and ``transform``
            are fixed here.
    """

    def __init__(self, data_root='data/afhq/afhq', flip_p=0.5, size=256, **kwargs):
        transform = transforms.Compose([
            transforms.Resize((size, size)),
            transforms.RandomHorizontalFlip(p=flip_p),
            Cutout(mask_size=32),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])
        super().__init__(data_root=data_root, split='train', size=size,
                         flip_p=flip_p, transform=transform, **kwargs)


class AfhqTrainHeavy(AFHQBase):
    """Training split with heavy augmentation.

    The pipeline is: resize, random horizontal flip, colour jitter,
    ``Cutout``, ``RandomErasing``, ``GridMask``, ``AugMix``, to-tensor,
    normalize.

    Args:
        data_root: Dataset root directory.
        flip_p: Probability of the random horizontal flip; also forwarded to
            ``AFHQBase``.
        size: Side length images are resized to; also forwarded to the base
            class so ``self.size`` reflects the resize actually applied.
        **kwargs: Remaining ``AFHQBase`` options. ``split`` and ``transform``
            are fixed here.
    """

    def __init__(self, data_root='data/afhq/afhq', flip_p=0.5, size=256, **kwargs):
        transform = transforms.Compose([
            transforms.Resize((size, size)),
            transforms.RandomHorizontalFlip(p=flip_p),
            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
            Cutout(mask_size=32),
            RandomErasing(prob=0.5, scale=(0.02, 0.33)),
            GridMask(prob=0.3, d1=32, d2=96),
            AugMix(prob=0.2, alpha=1.0, width=3),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])
        super().__init__(data_root=data_root, split='train', size=size,
                         flip_p=flip_p, transform=transform, **kwargs)



class AfhqValidation(AFHQBase):
    """Validation split with a resize / to-tensor / normalize pipeline.

    No random augmentation is applied.

    Args:
        data_root: Dataset root directory.
        flip_p: Forwarded to ``AFHQBase``; the pipeline below has no flip.
        size: Side length images are resized to; also forwarded to the base
            class so ``self.size`` reflects the resize actually applied.
        **kwargs: Remaining ``AFHQBase`` options. ``split`` and ``transform``
            are fixed here.
    """

    def __init__(self, data_root='data/afhq/afhq', flip_p=0.0, size=224, **kwargs):
        transform = transforms.Compose([
            transforms.Resize((size, size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])
        super().__init__(data_root=data_root, split='val', size=size,
                         flip_p=flip_p, transform=transform, **kwargs)
