import warnings
warnings.filterwarnings("ignore", category=FutureWarning, message="torch.utils._pytree._register_pytree_node is deprecated")

import os
import random
import numpy as np
import torch
import torchvision.datasets as D
import torchvision.transforms as T
from torch.utils.data import random_split, DataLoader, Dataset, Subset, WeightedRandomSampler
import pytorch_lightning as pl
from utils.util import instantiate_from_config
from PIL import Image
from functools import partial
import torchvision.transforms as transforms
def set_seed(seed):
    """Seed the Python, NumPy and PyTorch global RNGs and build a seeded generator.

    Args:
        seed (int): Seed applied to every random source.

    Returns:
        torch.Generator: A fresh generator whose state was seeded with ``seed``.
    """
    for seed_fn in (torch.manual_seed, np.random.seed, random.seed):
        seed_fn(seed)
    # Generator.manual_seed returns the generator itself, so it can be chained.
    return torch.Generator().manual_seed(seed)


# Module-level generator referenced by the sampler / data loader code in this file.
generator = set_seed(42)

class WrappedDataset(Dataset):
    """Adapter that exposes any sized, indexable object as a PyTorch ``Dataset``.

    The wrapped object is kept on ``self.data``; length queries and item
    lookups are forwarded to it unchanged.
    """

    def __init__(self, dataset):
        # Only a reference is stored; the wrapped object is not copied.
        self.data = dataset

    def __getitem__(self, idx):
        wrapped = self.data
        return wrapped[idx]

    def __len__(self):
        wrapped = self.data
        return len(wrapped)


class DataModuleFromConfig(pl.LightningDataModule):
    """Lightning data module whose datasets are built from config objects.

    Every non-``None`` split config (``train``, ``validation``, ``test``,
    ``predict``) is handed to ``instantiate_from_config`` to build a dataset.
    The training split is iterated through a ``WeightedRandomSampler`` driven
    by the ``sample_weights`` attribute of the underlying dataset.
    """

    def __init__(self, batch_size, train=None, validation=None, test=None, predict=None,
                 wrap=False, num_workers=24, shuffle_test_loader=False, use_worker_init_fn=False,
                 shuffle_val_dataloader=False):
        """
        Args:
            batch_size (int): Batch size used for every data loader.
            train, validation, test, predict: Optional dataset configs passed to
                ``instantiate_from_config``; ``None`` skips that split.
            wrap (bool): If True, wrap each instantiated dataset in ``WrappedDataset``.
            num_workers (int or None): Stored on the instance (``batch_size * 2``
                when ``None``); it is not forwarded to the data loaders here.
            shuffle_test_loader (bool): Accepted but not used by this class.
            use_worker_init_fn (bool): Stored on the instance; not otherwise used here.
            shuffle_val_dataloader (bool): Accepted but not used by this class.
        """
        super().__init__()
        self.batch_size = batch_size
        self.dataset_configs = dict()
        self.num_workers = num_workers if num_workers is not None else batch_size * 2
        self.use_worker_init_fn = use_worker_init_fn
        if train is not None:
            self.dataset_configs["train"] = train
        if validation is not None:
            self.dataset_configs["validation"] = validation
        if test is not None:
            self.dataset_configs["test"] = test
        if predict is not None:
            self.dataset_configs["predict"] = predict
        self.wrap = wrap

    def prepare_data(self):
        """Instantiate every configured dataset once and discard the result."""
        for data_cfg in self.dataset_configs.values():
            instantiate_from_config(data_cfg)

    def setup(self, stage=None):
        """Instantiate the datasets and build one data loader per split.

        Args:
            stage: Unused; kept for the Lightning hook signature.
        """
        self.datasets = {
            k: instantiate_from_config(cfg)
            for k, cfg in self.dataset_configs.items()
        }
        if self.wrap:
            for k in self.datasets:
                self.datasets[k] = WrappedDataset(self.datasets[k])
        self.dataloaders = {}
        for key, dataset in self.datasets.items():
            if key == "train":
                self.dataloaders[key] = self._make_train_dataloader(dataset)
            else:
                self.dataloaders[key] = DataLoader(dataset, batch_size=self.batch_size, shuffle=False)

    def _make_train_dataloader(self, dataset):
        """Build the weighted-sampling training loader shared by setup() and train_dataloader()."""
        # NOTE(review): relies on ``dataset.data`` carrying ``sample_weights``; that
        # holds for WrappedDataset (wrap=True) -- confirm for unwrapped datasets.
        sampler = self._create_weighted_sampler(dataset.data)
        return DataLoader(dataset, batch_size=self.batch_size,
                          sampler=sampler, generator=generator)

    def _create_weighted_sampler(self, dataset):
        """Create a sampler that visits each sample once, ordered by weighted draws.

        Args:
            dataset: Object exposing ``sample_weights`` (sequence or tensor) and ``__len__``.

        Returns:
            WeightedRandomSampler: Draws ``len(dataset)`` indices without replacement.

        Raises:
            ValueError: If ``dataset.sample_weights`` is missing or ``None``.
        """
        if not hasattr(dataset, "sample_weights") or dataset.sample_weights is None:
            raise ValueError("dataset.sample_weights is null!")

        sample_weights = dataset.sample_weights

        if not isinstance(sample_weights, torch.Tensor):
            sample_weights = torch.tensor(sample_weights, dtype=torch.float32)

        sampler = WeightedRandomSampler(weights=sample_weights,
                                        num_samples=len(dataset),
                                        replacement=False,
                                        generator=generator)
        return sampler

    def train_dataloader(self):
        """Rebuild, cache and return the training data loader.

        The sampler is recreated on every call so that the current
        ``sample_weights`` of the dataset are picked up.

        Returns:
            DataLoader: The freshly built training loader.
        """
        loader = self._make_train_dataloader(self.datasets["train"])
        self.dataloaders['train'] = loader
        print("Train dataloader updated.")
        # The loader must be returned; previously this hook implicitly returned None.
        return loader


class Cutout(object):
    """Zero out one randomly positioned square patch of an image."""

    def __init__(self, mask_size):
        """
        Args:
            mask_size (int): The size of the square mask to apply to the image.
        """
        self.mask_size = mask_size

    def __call__(self, img):
        """
        Apply cutout to the input image.

        Args:
            img (PIL Image, Tensor or ndarray): Image laid out as H x W or
                H x W x C. Tensors are viewed through ``Tensor.numpy()`` and
                ndarrays are used directly, so both are modified in place.

        Returns:
            numpy.ndarray: The image with one square patch set to 0.
        """
        if isinstance(img, torch.Tensor):
            # NOTE(review): a C x H x W tensor would be read as H x W x C here --
            # confirm that callers pass channel-last data.
            img = img.numpy()
        elif not isinstance(img, np.ndarray):
            # PIL images (and other array-likes) are copied into a new array.
            img = np.array(img)

        h, w = img.shape[0], img.shape[1]

        # Clamp the patch so a mask larger than the image blanks the whole axis
        # instead of making random.randint fail on an empty range.
        mask_h = min(self.mask_size, h)
        mask_w = min(self.mask_size, w)

        mask_x = random.randint(0, w - mask_w)
        mask_y = random.randint(0, h - mask_h)

        # The ellipsis covers both H x W and H x W x C layouts.
        img[mask_y:mask_y + mask_h, mask_x:mask_x + mask_w, ...] = 0

        return img


class BatchCutMix(object):
    """Batch-level CutMix: paste a random box from shuffled images into the batch."""

    def __init__(self, alpha=1.0, prob=0.5):
        """
        Args:
            alpha (float): Beta distribution parameter; ``alpha <= 0`` fixes lam to 1.
            prob (float): The transform is skipped when a uniform draw exceeds this value.
        """
        self.alpha = alpha
        self.prob = prob

    def __call__(self, batch):
        """
        Args:
            batch (tuple): ``(images, labels)``; images are indexed as (B, C, H, W).

        Returns:
            The untouched ``batch`` when skipped, otherwise
            ``(images, (labels, shuffled_labels, lam))`` where ``images`` is the
            input tensor modified in place and ``lam`` is the kept-area ratio.
        """
        skip = random.random() > self.prob
        if skip:
            return batch

        images, labels = batch
        lam = np.random.beta(self.alpha, self.alpha) if self.alpha > 0 else 1
        perm = torch.randperm(images.size(0))

        height, width = images.size(2), images.size(3)
        shrink = np.sqrt(1. - lam)
        half_w = int(width * shrink) // 2
        half_h = int(height * shrink) // 2

        # Box centre: x is drawn before y.
        center_x = np.random.randint(width)
        center_y = np.random.randint(height)

        x_lo = np.clip(center_x - half_w, 0, width)
        y_lo = np.clip(center_y - half_h, 0, height)
        x_hi = np.clip(center_x + half_w, 0, width)
        y_hi = np.clip(center_y + half_h, 0, height)

        # Advanced indexing copies the donor patch before the in-place write.
        donor_patch = images[perm, :, y_lo:y_hi, x_lo:x_hi]
        images[:, :, y_lo:y_hi, x_lo:x_hi] = donor_patch

        # Recompute lam from the clipped box actually pasted.
        pasted_area = (x_hi - x_lo) * (y_hi - y_lo)
        lam = 1 - (pasted_area / (width * height))

        return images, (labels, labels[perm], lam)


class BatchMixUp(object):
    """Batch-level MixUp: blend every image with a randomly chosen partner image."""

    def __init__(self, alpha=1.0, prob=0.5):
        """
        Args:
            alpha (float): Beta distribution parameter; ``alpha <= 0`` fixes lam to 1.
            prob (float): The transform is skipped when a uniform draw exceeds this value.
        """
        self.alpha = alpha
        self.prob = prob

    def __call__(self, batch):
        """
        Args:
            batch (tuple): ``(images, labels)`` indexed along the first dimension.

        Returns:
            The untouched ``batch`` when skipped, otherwise
            ``(mixed_images, (labels, partner_labels, lam))``.
        """
        draw = random.random()
        if draw > self.prob:
            return batch

        images, labels = batch
        lam = np.random.beta(self.alpha, self.alpha) if self.alpha > 0 else 1

        perm = torch.randperm(images.size(0))
        partner_images = images[perm]
        blended = lam * images + (1 - lam) * partner_images

        return blended, (labels, labels[perm], lam)


class Mosaic(object):
    """Mosaic augmentation: tile consecutive groups of four images into one."""

    def __init__(self, prob=0.5):
        """
        Args:
            prob (float): The transform is skipped when a uniform draw exceeds this value.
        """
        self.prob = prob

    def __call__(self, batch):
        """
        Build one mosaic per consecutive group of four images.

        Each group is arranged as a 2 x 2 grid, resized back to (H, W) with
        bilinear interpolation, and labelled with the label of the group's
        first image. Images left over after the last full group are dropped.

        Args:
            batch (tuple): (images, labels) where images is a tensor of shape (B, C, H, W)
                          and labels is a tensor of shape (B,)

        Returns:
            tuple: (mosaic_images, mosaic_labels), or the unchanged ``batch``
            when the transform is skipped or fewer than four images are given.
        """
        if random.random() > self.prob:
            return batch

        images, labels = batch
        count = images.size(0)
        if count < 4:
            return batch

        out_size = (images.shape[2], images.shape[3])
        tiles, tile_labels = [], []

        for start in range(0, count - 3, 4):
            top_left, top_right, bottom_left, bottom_right = images[start:start + 4]

            # Each image is (C, H, W): join along width first, then along height.
            upper = torch.cat([top_left, top_right], dim=2)
            lower = torch.cat([bottom_left, bottom_right], dim=2)
            canvas = torch.cat([upper, lower], dim=1)

            resized = torch.nn.functional.interpolate(
                canvas.unsqueeze(0), size=out_size, mode='bilinear', align_corners=False
            ).squeeze(0)

            tiles.append(resized)
            tile_labels.append(labels[start])

        if not tiles:
            return batch
        return torch.stack(tiles), torch.stack(tile_labels)


class RandomErasing(object):
    """Random Erasing: overwrite one random rectangle of an image with a fill value."""

    def __init__(self, prob=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0):
        """
        Random Erasing data augmentation - Compatible with transforms.Compose

        Args:
            prob (float): The transform is skipped when a uniform draw exceeds this value
            scale (tuple): Range of proportion of erased area against input image
            ratio (tuple): Range of aspect ratio of erased area
            value (float or tuple): Erasing value
        """
        self.prob = prob
        self.scale = scale
        self.ratio = ratio
        self.value = value

    def __call__(self, img):
        """
        Apply Random Erasing to the input image.

        Args:
            img (PIL Image, Tensor or ndarray): The input image to be transformed.

        Returns:
            The image in the same container type; tensors and arrays are
            modified in place, PIL images are rebuilt from an array copy.
        """
        if random.random() > self.prob:
            return img

        # Dispatch on the container type; anything else is treated as an array.
        if isinstance(img, torch.Tensor):
            return self._erase_tensor(img)
        if isinstance(img, Image.Image):
            return self._erase_pil(img)
        return self._erase_numpy(img)

    def _sample_box(self, h, w):
        """Draw an erase box for an ``h`` x ``w`` image.

        Returns:
            tuple or None: ``(rows, cols)`` slices, or ``None`` when the sampled
            box is not strictly smaller than the image in both dimensions.
        """
        target_area = random.uniform(self.scale[0], self.scale[1]) * (h * w)
        aspect_ratio = random.uniform(self.ratio[0], self.ratio[1])

        box_h = int(round(np.sqrt(target_area * aspect_ratio)))
        box_w = int(round(np.sqrt(target_area / aspect_ratio)))
        if not (box_h < h and box_w < w):
            return None

        # The horizontal offset is drawn before the vertical one.
        left = random.randint(0, w - box_w)
        top = random.randint(0, h - box_h)
        return slice(top, top + box_h), slice(left, left + box_w)

    def _erase_tensor(self, img):
        """Erase tensor image"""
        # Only 3-D (C, H, W) tensors are handled; other ranks pass through untouched.
        if len(img.shape) != 3:
            return img

        _, h, w = img.shape
        box = self._sample_box(h, w)
        if box is not None:
            rows, cols = box
            img[:, rows, cols] = self.value
        return img

    def _erase_pil(self, img):
        """Erase PIL image"""
        erased = self._erase_numpy(np.array(img))
        return Image.fromarray(erased)

    def _erase_numpy(self, img_array):
        """Erase numpy array image"""
        if len(img_array.shape) == 3:
            h, w, _ = img_array.shape
        else:
            h, w = img_array.shape

        box = self._sample_box(h, w)
        if box is not None:
            # Indexing only the two leading axes covers H x W and H x W x C alike.
            img_array[box] = self.value
        return img_array


class GridMask(object):
    """GridMask: multiply an image by a regular grid of kept / dropped squares."""

    def __init__(self, d1=96, d2=224, rotate=1, ratio=0.5, mode=1, prob=0.5):
        """
        GridMask data augmentation - Compatible with transforms.Compose

        Args:
            d1 (int): Minimum grid size
            d2 (int): Maximum grid size
            rotate (int): Rotation angle range (currently not used)
            ratio (float): Grid ratio; each square has side ``int(d * ratio)``
            mode (int): 1 drops the grid squares, any other value keeps only them
            prob (float): The transform is skipped when a uniform draw exceeds this value
        """
        self.d1 = d1
        self.d2 = d2
        self.rotate = rotate
        self.ratio = ratio
        self.mode = mode
        self.prob = prob

    def __call__(self, img):
        """
        Apply GridMask to the input image.

        Args:
            img (PIL Image, Tensor or ndarray): The input image to be transformed.

        Returns:
            The masked image in the same container type as the input.
        """
        if random.random() > self.prob:
            return img

        if isinstance(img, torch.Tensor):
            return self._grid_mask_tensor(img)
        elif isinstance(img, Image.Image):
            return self._grid_mask_pil(img)
        else:
            return self._grid_mask_numpy(img)

    def _grid_mask_tensor(self, img):
        """Apply grid mask to a (C, H, W) tensor in place; other ranks pass through."""
        if len(img.shape) == 3:
            _, h, w = img.shape
            mask = torch.from_numpy(self._create_mask(h, w)).float()
            # Keep the mask on the image's device so non-CPU tensors also work.
            mask = mask.to(img.device)
            # Broadcasting over channels; slice assignment casts back to img's dtype.
            img[:] = img * mask

        return img

    def _grid_mask_pil(self, img):
        """Apply grid mask to PIL image"""
        img_array = np.array(img)
        img_array = self._grid_mask_numpy(img_array)
        return Image.fromarray(img_array.astype(np.uint8))

    def _grid_mask_numpy(self, img_array):
        """Apply grid mask to an H x W x C or H x W array, preserving its dtype.

        H x W x C arrays are modified in place; H x W arrays yield a new array.
        """
        if len(img_array.shape) == 3:
            h, w, _ = img_array.shape
            mask = self._create_mask(h, w)
            # In-place write casts the float product back to the array's dtype.
            img_array[...] = img_array * mask[:, :, None]
        else:
            h, w = img_array.shape
            mask = self._create_mask(h, w)
            # Cast back so the 2-D path keeps the input dtype like the 3-D path
            # does (the bare product would be promoted to float).
            img_array = (img_array * mask).astype(img_array.dtype)

        return img_array

    def _create_mask(self, h, w):
        """Create an ``h`` x ``w`` float32 grid mask of zeros and ones.

        The grid period ``d`` is drawn uniformly from ``[d1, d2]`` and a square
        of side ``int(d * ratio)`` sits at the top-left of every grid cell.
        """
        d = random.randint(self.d1, self.d2)
        l = int(d * self.ratio)

        # mode 1: ones everywhere with the squares zeroed out;
        # otherwise: zeros everywhere with only the squares kept.
        drop_squares = self.mode == 1
        background, square = (1, 0) if drop_squares else (0, 1)
        mask = np.full((h, w), background, dtype=np.float32)

        for i in range(0, h, d):
            for j in range(0, w, d):
                mask[i:min(i + l, h), j:min(j + l, w)] = square

        return mask


class AugMix(object):
    """AugMix: blend an image with a weighted mixture of randomly augmented copies."""

    def __init__(self, prob=0.5, alpha=1.0, width=3, depth=1, severity=1):
        """
        AugMix data augmentation.

        Args:
            prob (float): The transform is skipped when a uniform draw exceeds this value
            alpha (float): Parameter of the Dirichlet and Beta distributions
            width (int): Number of augmentation chains
            depth (int): Depth of augmentation chains
            severity (int): Stored on the instance; not used by the operations below
        """
        self.prob = prob
        self.alpha = alpha
        self.width = width
        self.depth = depth
        self.severity = severity

        # Define augmentation operations
        self.ops = [
            transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
            transforms.RandomRotation(degrees=30),
            transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
            transforms.RandomPerspective(distortion_scale=0.2, p=0.5),
        ]

    def __call__(self, img):
        """
        Apply AugMix to the input image.

        Args:
            img (PIL Image, Tensor or ndarray): The input image to be transformed.

        Returns:
            The mixed image in the same kind of container as the input:
            a tensor for tensor input, a uint8 H x W x C array for ndarray
            input, and a PIL image otherwise.
        """
        if random.random() > self.prob:
            return img

        # Remember the input kind with isinstance checks: PIL images loaded from
        # files are subclasses of Image.Image, so an exact type comparison would
        # miss them and hand back a tensor instead of a PIL image.
        came_as_tensor = isinstance(img, torch.Tensor)
        came_as_array = isinstance(img, np.ndarray)
        if came_as_tensor:
            img = transforms.ToPILImage()(img)
        elif came_as_array:
            img = Image.fromarray(img)

        # Sample mixing weights
        ws = np.random.dirichlet([self.alpha] * self.width)
        m = np.random.beta(self.alpha, self.alpha)

        to_tensor = transforms.ToTensor()
        base = to_tensor(img)
        mix = torch.zeros_like(base)

        for i in range(self.width):
            # Each chain starts from a fresh copy of the original image.
            img_aug = img.copy()
            for _ in range(self.depth):
                op = random.choice(self.ops)
                img_aug = op(img_aug)

            # Plain Python float keeps the product in the tensor's own dtype.
            mix += float(ws[i]) * to_tensor(img_aug)

        # Final mixing between the original image and the augmented mixture.
        mixed = (1 - m) * base + m * mix

        # Return in the same format as input
        if came_as_tensor:
            return mixed
        if came_as_array:
            return (mixed.permute(1, 2, 0).numpy() * 255).astype(np.uint8)
        return transforms.ToPILImage()(mixed)

class YONA(object):
    """
    YONA augmentation:
    - Zero out one half of the image (left half or top half).
    - Optionally apply a transform to the remaining half only.
    """

    def __init__(self, mask_direction='random', transform=None):
        """
        Args:
            mask_direction (str): 'horizontal' (left half zeroed), 'vertical'
                (top half zeroed), or 'random' to pick one of the two per call.
            transform (callable): Transform to apply only on the unmasked region.
        """
        self.mask_direction = mask_direction
        self.transform = transform

    def __call__(self, img):
        """
        Args:
            img (Tensor): C x H x W image.

        Returns:
            Tensor: New tensor with one half zeroed and the other half
            optionally transformed; the input is left untouched.

        Raises:
            TypeError: If ``img`` is not a ``torch.Tensor``.
            ValueError: If the mask direction is not a supported value.
        """
        if not isinstance(img, torch.Tensor):
            raise TypeError("YONA expects torch.Tensor input")

        _, height, width = img.shape
        out = img.clone()  # work on a copy so the caller's tensor is preserved

        direction = self.mask_direction
        if direction == 'random':
            direction = random.choice(['horizontal', 'vertical'])

        everything = slice(None)
        if direction == 'horizontal':
            dropped = (everything, everything, slice(None, width // 2))
            kept = (everything, everything, slice(width // 2, None))
        elif direction == 'vertical':
            dropped = (everything, slice(None, height // 2), everything)
            kept = (everything, slice(height // 2, None), everything)
        else:
            raise ValueError("Invalid mask direction")

        mask = torch.ones_like(out)
        mask[dropped] = 0

        if self.transform:
            # The transform sees a detached copy of the surviving half.
            out[kept] = self.transform(out[kept].clone())

        return out * mask