import warnings
warnings.filterwarnings("ignore", category=FutureWarning, message="torch.utils._pytree._register_pytree_node is deprecated")

import os
import numpy as np
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
import torch
from functools import partial
from . import Cutout, RandomErasing, GridMask, AugMix, YONA
import uuid

class FlowerBase(Dataset):
    """Image dataset read from ``<data_root>/<split>/<class_id>/<image file>``.

    Class directories must have names parseable by ``int``; they are ordered
    numerically and mapped to contiguous label indices in ``class_to_idx``.
    Image files inside each class directory are listed in sorted order so that
    sample indices and ``max_per_class`` truncation are reproducible
    (``os.listdir`` itself returns entries in arbitrary order).

    Args:
        data_root: Root directory containing one sub-directory per split.
        split: Name of the split sub-directory to read.
        extra_data_root: Optional second root with the same
            ``<split>/<class_id>`` layout; its images are appended after those
            of ``data_root``. Only classes present under ``data_root`` are read.
        size: Side length of the square output produced when no ``transform``
            is given. It is passed straight to ``Image.resize`` in that case,
            so it must be set when ``transform`` is omitted.
        interpolation: One of ``"bilinear"``, ``"bicubic"``, ``"lanczos"``;
            any other value raises ``KeyError``. Used only without ``transform``.
        flip_p: Probability of the horizontal flip applied when no
            ``transform`` is given.
        max_per_class: Maximum number of images taken from each class
            directory (applied to each root separately). A falsy value
            (``None`` or ``0``) means no limit.
        transform: Optional callable applied to the RGB PIL image. When set,
            it replaces the built-in crop/resize/flip/scale pipeline entirely.
        classes_number: If given, only the first ``classes_number`` classes
            (in numeric order) are loaded. ``class_to_idx`` still covers every
            class directory found.
    """

    # Lower-case file extensions accepted as images.
    _IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')

    def __init__(self, 
                 data_root, 
                 split='train', 
                 extra_data_root=None,
                 size=None, 
                 interpolation = "bicubic",
                 flip_p=0.5,
                 max_per_class=None, 
                 transform=None,
                 classes_number=None):
        self.data_root = data_root
        self.split = split
        self.size = size
        self.transform = transform
        self.classes_number = classes_number
        self.interpolation = {"bilinear": Image.BILINEAR,
                              "bicubic": Image.BICUBIC,
                              "lanczos": Image.LANCZOS,
                              }[interpolation]
        self.flip = transforms.RandomHorizontalFlip(p=flip_p)

        self.data_dir = os.path.join(data_root, split)
        self.class_names = sorted(
            (d for d in os.listdir(self.data_dir)
             if os.path.isdir(os.path.join(self.data_dir, d))),
            key=int,
        )
        self.class_to_idx = {cls: idx for idx, cls in enumerate(self.class_names)}

        self.image_paths = []
        self.labels = {"file_paths": [], "categories": []}

        # The primary root is scanned first, then the optional extra root, so
        # extra samples always come after all primary samples.
        split_dirs = [self.data_dir]
        if extra_data_root:
            split_dirs.append(os.path.join(extra_data_root, split))

        selected_classes = self.class_names[:self.classes_number]
        for split_dir in split_dirs:
            for cls_name in selected_classes:
                category = self.class_to_idx[cls_name]
                class_dir = os.path.join(split_dir, cls_name)
                for image_path in self._list_images(class_dir, max_per_class):
                    self.image_paths.append(image_path)
                    self.labels["file_paths"].append(image_path)
                    self.labels["categories"].append(category)

        self.sample_weights = [1.0] * len(self.image_paths)
        self._length = len(self.image_paths)

    @classmethod
    def _list_images(cls, class_dir, max_per_class=None):
        """Return the sorted image paths in ``class_dir``, capped at ``max_per_class``.

        A missing directory yields an empty list. A falsy ``max_per_class``
        disables the cap.
        """
        if not os.path.isdir(class_dir):
            return []
        # Sort because os.listdir order is arbitrary; without it the subset
        # kept by max_per_class would differ between machines/runs.
        names = sorted(
            name for name in os.listdir(class_dir)
            if name.lower().endswith(cls._IMAGE_EXTENSIONS)
        )
        if max_per_class:
            names = names[:max_per_class]
        return [os.path.join(class_dir, name) for name in names]

    def __len__(self):
        """Return the current number of samples (including added ones)."""
        return self._length

    def __getitem__(self, i):
        """Load sample ``i`` and return ``(image, label, i)``.

        ``image`` is the output of ``self.transform`` when one is set.
        Otherwise it is a ``float32`` NumPy array in channel-first layout with
        values scaled to ``[-1, 1]``, obtained by center-cropping to a square,
        resizing to ``size`` x ``size`` and randomly flipping horizontally.
        ``label`` is a scalar ``torch.long`` tensor.
        """
        # convert() returns an independent in-memory image, so the underlying
        # file can be closed as soon as the block exits.
        with Image.open(self.image_paths[i]) as source:
            image = source.convert("RGB")

        if self.transform:
            image = self.transform(image)
        else:
            img = np.array(image).astype(np.uint8)
            h, w = img.shape[0], img.shape[1]
            crop = min(h, w)
            # Center crop to the largest square that fits.
            img = img[(h - crop) // 2:(h + crop) // 2, (w - crop) // 2:(w + crop) // 2]
            image = Image.fromarray(img).resize((self.size, self.size), resample=self.interpolation)
            image = self.flip(image)
            image = np.array(image).astype(np.uint8)
            # Map [0, 255] to [-1, 1], then HWC -> CHW.
            image = (image / 127.5 - 1.0).astype(np.float32)
            image = np.transpose(image, (2, 0, 1))

        label = torch.tensor(self.labels["categories"][i], dtype=torch.long)

        return image, label, i

    def add_samples(self, new_image_paths, new_categories, weight_factor=1.0):
        """Append extra samples to the dataset in place.

        Args:
            new_image_paths: Paths of the images to add.
            new_categories: Label index for each new image, in the same order.
            weight_factor: Value stored in ``sample_weights`` for every added
                sample.

        Raises:
            ValueError: If the two sequences differ in length; nothing is
                modified in that case.
        """
        if len(new_image_paths) != len(new_categories):
            raise ValueError(
                "new_image_paths and new_categories must have the same length "
                f"(got {len(new_image_paths)} and {len(new_categories)})"
            )

        self.image_paths.extend(new_image_paths)
        self.labels["file_paths"].extend(new_image_paths)
        self.labels["categories"].extend(new_categories)
        self.sample_weights.extend([weight_factor] * len(new_image_paths))
        self._length = len(self.image_paths)
        
class FlowerTrain(FlowerBase):
    """Training split (``split='train'``) with a fixed augmentation pipeline.

    Pipeline: resize to 256x256, center-crop to 224, convert to tensor, apply
    the project-local ``YONA`` augmentation (random mask direction, wrapping a
    horizontal flip), then per-channel normalization.

    NOTE(review): ``flip_p`` is only forwarded to ``FlowerBase``, which uses it
    solely when no ``transform`` is set. The flip handed to ``YONA`` is built
    with torchvision's default probability, so ``flip_p`` does not influence
    this pipeline.

    Remaining keyword arguments are forwarded to ``FlowerBase``.
    """

    def __init__(self, data_root='data/flower102/dataset', flip_p=0.5, **kwargs):
        yona_inner_flip = transforms.RandomHorizontalFlip()
        # Currently disabled alternatives for this pipeline:
        #   RandomErasing(prob=0.3, scale=(0.02, 0.33)), Cutout(mask_size=32),
        #   transforms.RandomHorizontalFlip(p=flip_p)
        train_steps = [
            transforms.Resize((256, 256)),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            YONA(mask_direction='random', transform=yona_inner_flip),
            # Presumably the usual ImageNet channel mean/std - confirm.
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
        super().__init__(
            data_root=data_root,
            split='train',
            flip_p=flip_p,
            transform=transforms.Compose(train_steps),
            **kwargs
        )



class FlowerValidation(FlowerBase):
    """Validation split (``split='valid'``) with a deterministic pipeline.

    Pipeline: resize to 256x256, center-crop to 224, convert to tensor, then
    per-channel normalization. No random augmentation is applied.

    ``flip_p`` and any remaining keyword arguments are forwarded to
    ``FlowerBase``.
    """

    def __init__(self, data_root='data/flower102/dataset', flip_p=0.0,  **kwargs):
        eval_steps = [
            transforms.Resize((256, 256)),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            # Presumably the usual ImageNet channel mean/std - confirm.
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
        super().__init__(
            data_root=data_root,
            split='valid',
            flip_p=flip_p,
            transform=transforms.Compose(eval_steps),
            **kwargs
        )
