import warnings
warnings.filterwarnings("ignore", category=FutureWarning, message="torch.utils._pytree._register_pytree_node is deprecated")

import os
import numpy as np
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
import torch
from functools import partial
from . import Cutout, RandomErasing, GridMask, AugMix, YONA
import uuid

class BirdsBase(Dataset):
    """Image-folder dataset: one sub-directory per class under ``<data_root>/<split>``.

    Class indices are assigned by sorting the class directory names. Every
    ``.jpg``/``.jpeg``/``.png`` file inside a selected class directory becomes
    one sample. Three parallel per-sample lists are maintained:
    ``image_paths``, ``labels["categories"]`` (``labels["file_paths"]``
    mirrors ``image_paths``) and ``sample_weights``.

    Args:
        data_root: Root directory; images are read from ``data_root/split``.
        extra_data_root: Optional second root. Images found under
            ``extra_data_root/split/<class>`` are appended for each selected
            class, all with weight ``weight_factor``.
        split: Name of the sub-directory of the roots to read.
        size: Stored on the instance; not used by the methods of this class.
        interpolation: One of ``"bilinear"``, ``"bicubic"``, ``"lanczos"``;
            mapped to the PIL constant and stored. Any other value raises
            ``KeyError``.
        flip_p: Probability for the ``RandomHorizontalFlip`` stored as
            ``self.flip`` (not applied by ``__getitem__``).
        max_per_class: If truthy, number of images per class that receive
            weight ``1.0``; also the per-class cap when ``more_per_class``
            is not given.
        more_per_class: If truthy together with ``max_per_class``, up to this
            many additional images per class are kept and receive
            ``weight_factor``.
        weight_factor: Weight for the additional and extra-root images.
        transform: Optional callable applied to the RGB PIL image.
        classes_number: If given, only the first ``classes_number`` classes
            (in sorted order) are loaded.
    """

    def __init__(self,
                 data_root,
                 extra_data_root=None,
                 split='train',
                 size=None,
                 interpolation="bicubic",
                 flip_p=0.5,
                 max_per_class=None,
                 more_per_class=None,
                 weight_factor=1.0,
                 transform=None,
                 classes_number=None):
        self.data_root = data_root
        self.split = split
        self.size = size
        self.transform = transform
        self.classes_number = classes_number
        self.max_per_class = max_per_class
        self.more_per_class = more_per_class
        self.interpolation = {"bilinear": Image.BILINEAR,
                              "bicubic": Image.BICUBIC,
                              "lanczos": Image.LANCZOS,
                              }[interpolation]
        self.flip = transforms.RandomHorizontalFlip(p=flip_p)

        self.data_dir = os.path.join(data_root, split)
        self.class_names = sorted(
            d for d in os.listdir(self.data_dir)
            if os.path.isdir(os.path.join(self.data_dir, d))
        )
        # Indices cover ALL class directories, even when only a prefix of them
        # is loaded because of `classes_number`.
        self.class_to_idx = {cls: idx for idx, cls in enumerate(self.class_names)}

        self.image_paths = []
        self.sample_weights = []
        self.labels = {"file_paths": [], "categories": []}

        selected_classes = self.class_names[:self.classes_number]

        # Per-class cap. With both limits set the cap is their sum, so a class
        # holding between `max_per_class` and `max_per_class + more_per_class`
        # images keeps all of them instead of being cut back to
        # `max_per_class`.
        if max_per_class and more_per_class:
            per_class_limit = max_per_class + more_per_class
        elif max_per_class:
            per_class_limit = max_per_class
        else:
            per_class_limit = None

        # Images at position >= max_per_class get `weight_factor`. This only
        # applies when `max_per_class` is actually set; comparing against
        # None would raise TypeError.
        has_extra_band = bool(more_per_class) and max_per_class is not None

        for cls_name in selected_classes:
            cls_dir = os.path.join(self.data_dir, cls_name)
            if not os.path.isdir(cls_dir):
                continue
            category = self.class_to_idx[cls_name]
            images = self._list_images(cls_dir)[:per_class_limit]
            for position, image_name in enumerate(images):
                in_extra_band = has_extra_band and position >= max_per_class
                self._append_sample(
                    os.path.join(cls_dir, image_name),
                    category,
                    weight_factor if in_extra_band else 1.0,
                )

        if extra_data_root:
            for cls_name in selected_classes:
                cls_dir = os.path.join(extra_data_root, split, cls_name)
                if not os.path.isdir(cls_dir):
                    continue
                category = self.class_to_idx[cls_name]
                for image_name in self._list_images(cls_dir):
                    self._append_sample(
                        os.path.join(cls_dir, image_name),
                        category,
                        weight_factor,
                    )

        self._length = len(self.image_paths)

    @staticmethod
    def _list_images(directory):
        """Return the image file names in ``directory``, sorted by name.

        ``os.listdir`` yields entries in arbitrary order; sorting makes the
        subset kept by the per-class cap, and the sample indices, reproducible.
        """
        return sorted(
            name for name in os.listdir(directory)
            if name.lower().endswith(('.jpg', '.jpeg', '.png'))
        )

    def _append_sample(self, image_path, category, weight):
        """Append one sample to all parallel per-sample lists."""
        self.image_paths.append(image_path)
        self.labels["file_paths"].append(image_path)
        self.labels["categories"].append(category)
        self.sample_weights.append(weight)

    def __len__(self):
        """Return the number of samples currently registered."""
        return self._length

    def __getitem__(self, i):
        """Load sample ``i``.

        Returns:
            A tuple ``(image, label, i)``: the RGB image (after ``transform``
            if one was given), the class index as a ``torch.long`` tensor,
            and the sample index itself.
        """
        # `convert` returns a new, fully loaded image, so the underlying file
        # can be closed as soon as the `with` block exits.
        with Image.open(self.image_paths[i]) as source:
            image = source.convert("RGB")

        if self.transform:
            image = self.transform(image)

        label = torch.tensor(self.labels["categories"][i], dtype=torch.long)

        return image, label, i

    def add_samples(self, new_image_paths, new_categories, weight_factor=1.0):
        """Append extra samples, all with the same weight.

        Args:
            new_image_paths: Image paths to add.
            new_categories: Class index for each path, in the same order.
            weight_factor: Weight recorded for every added sample.

        Raises:
            ValueError: If the two sequences differ in length. Nothing is
                added in that case.
        """
        if len(new_image_paths) != len(new_categories):
            raise ValueError(
                "new_image_paths and new_categories must have the same length "
                f"(got {len(new_image_paths)} and {len(new_categories)})"
            )

        self.image_paths.extend(new_image_paths)
        self.labels["file_paths"].extend(new_image_paths)
        self.labels["categories"].extend(new_categories)
        self.sample_weights.extend([weight_factor] * len(new_image_paths))
        self._length = len(self.image_paths)
        
        
class BirdsTrain(BirdsBase):
    """Dataset over the ``'train'`` split with the YONA-based transform pipeline.

    Args:
        data_root: Dataset root directory, forwarded to ``BirdsBase``.
        flip_p: Forwarded to ``BirdsBase`` unchanged.
        **kwargs: Any further ``BirdsBase`` keyword arguments.
    """

    def __init__(self, data_root='data/birds_525_species', flip_p=0.5, **kwargs):
        # Transform handed to YONA. It uses RandomHorizontalFlip's default
        # probability; `flip_p` is only passed on to BirdsBase.
        yona_inner = transforms.RandomHorizontalFlip()

        # Active pipeline, applied in order. Resize((256, 256)),
        # Cutout(mask_size=32), RandomErasing(prob=0.3, scale=(0.02, 0.33))
        # and a plain RandomHorizontalFlip are not part of it.
        steps = [
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            YONA(mask_direction='random', transform=yona_inner),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
        ]
        pipeline = transforms.Compose(steps)

        super().__init__(
            data_root=data_root,
            split='train',
            flip_p=flip_p,
            transform=pipeline,
            **kwargs,
        )



class BirdsValidation(BirdsBase):
    """Dataset over the ``'valid'`` split with a fixed, non-random transform.

    Args:
        data_root: Dataset root directory, forwarded to ``BirdsBase``.
        flip_p: Forwarded to ``BirdsBase`` unchanged (defaults to ``0.0``).
        **kwargs: Any further ``BirdsBase`` keyword arguments.
    """

    def __init__(self, data_root='data/birds_525_species', flip_p=0.0,  **kwargs):
        # Resize to 224x224 with bicubic resampling, convert to a tensor,
        # then normalize per channel.
        resize = transforms.Resize((224, 224), interpolation=Image.BICUBIC)
        to_tensor = transforms.ToTensor()
        normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        )
        pipeline = transforms.Compose([resize, to_tensor, normalize])

        super().__init__(
            data_root=data_root,
            split='valid',
            flip_p=flip_p,
            transform=pipeline,
            **kwargs,
        )
