from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image
from datasets import load_dataset
from io import BytesIO
import os
import torch
import torchvision.transforms as T
from torchvision.transforms import InterpolationMode
from datasets import load_from_disk

def build_clip_transform(
    size=224,
    augmentation=False,
    *,
    mean=(0.48145466, 0.4578275, 0.40821073),
    std=(0.26862954, 0.26130258, 0.27577711),
):
    """Build the CLIP-style preprocessing pipeline for PIL images.

    Args:
        size (int): Side length of the square output crop.
        augmentation (bool): If True, build the training pipeline: a random
            resized crop covering 90-100% of the image area (torchvision also
            jitters the aspect ratio with its default ``ratio`` range),
            followed by a random horizontal flip. If False, build the
            deterministic pipeline: resize the shorter side to ``size`` and
            center-crop to ``size x size``.
        mean (sequence of float): Per-channel normalization mean. Defaults
            to the CLIP statistics this function has always used.
        std (sequence of float): Per-channel normalization std. Defaults to
            the CLIP statistics this function has always used.

    Returns:
        T.Compose: Maps a PIL image to a normalized float tensor whose
        spatial size is ``size x size``.
    """
    # Both pipelines share the same tail, so the normalization statistics are
    # defined in exactly one place.
    to_normalized_tensor = [
        T.ToTensor(),
        T.Normalize(mean=list(mean), std=list(std)),
    ]

    if augmentation:
        spatial = [
            T.RandomResizedCrop(
                size,
                scale=(0.9, 1.0),
                interpolation=InterpolationMode.BICUBIC,
                antialias=True,
            ),
            T.RandomHorizontalFlip(p=0.5),
        ]
    else:
        spatial = [
            T.Resize(size, interpolation=InterpolationMode.BICUBIC, antialias=True),
            T.CenterCrop(size),
        ]

    return T.Compose(spatial + to_normalized_tensor)
        
class CC3MDataset(Dataset):
    """Map-style CC3M image-caption dataset backed by a Hugging Face dataset on disk.

    ``__getitem__`` returns a dict with the tokenized caption (``input_ids``,
    ``attention_mask``), the transformed image (``images``) and the raw caption
    string (``caption``). A sample that fails anywhere in that pipeline
    (missing key, undecodable image, out-of-vocabulary token ids, transform
    error) is skipped by trying the next index, wrapping around at the end,
    for at most ``max_retries`` consecutive failures.
    """

    def __init__(
        self,
        dataset_path_or_name,
        tokenizer,
        split="train",
        max_length=77,
        size=224,
        augmentation=False,
        image_key="jpg",
        caption_key="txt",
        image_transform=None,
        max_retries=100,
    ):
        """
        Args:
            dataset_path_or_name (str): Path to a dataset written with
                ``save_to_disk``; it is opened with ``datasets.load_from_disk``.
                Hub dataset names (e.g. 'laion/cc3m') are NOT downloaded here.
            tokenizer: Tokenizer for text processing (same as COCOCaptionDataset).
                Must be callable with ``max_length``/``padding``/``truncation``/
                ``return_tensors`` keywords and expose ``vocab_size``.
            split (str): Split to select when the saved object is a
                ``DatasetDict`` ('train', 'validation', etc.). Ignored when a
                single ``Dataset`` was saved.
            max_length (int): Maximum token length for captions.
            size (int): Target image size, used only when ``image_transform``
                is None.
            augmentation (bool): If True (and ``image_transform`` is None), use
                a random resized crop plus random horizontal flip.
            image_key (str): Key for image field in dataset.
            caption_key (str): Key for caption field in dataset.
            image_transform (callable, optional): Transform applied to images;
                overrides the default CLIP transform.
            max_retries (int): Number of additional consecutive indices tried
                after a failed sample before ``__getitem__`` raises
                ``RuntimeError``.

        Raises:
            KeyError: ``split`` is not present in a saved ``DatasetDict``.
        """
        # Local import: only needed to recognise a multi-split save.
        from datasets import DatasetDict

        self.tokenizer = tokenizer
        self.max_length = max_length
        self.augmentation = augmentation
        self.max_retries = max(0, int(max_retries))
        self.image_key = image_key
        self.caption_key = caption_key

        # keep_in_memory=False: read rows from the on-disk files instead of
        # copying the whole dataset into RAM.
        dataset = load_from_disk(dataset_path_or_name, keep_in_memory=False)
        if isinstance(dataset, DatasetDict):
            # A DatasetDict cannot be indexed by integer; pick the requested split.
            if split not in dataset:
                raise KeyError(
                    f"Split {split!r} not found in {dataset_path_or_name!r}; "
                    f"available splits: {sorted(dataset.keys())}"
                )
            dataset = dataset[split]
        self.dataset = dataset

        if image_transform is not None:
            self.transform = image_transform
        else:
            self.transform = build_clip_transform(size=size, augmentation=augmentation)

        # Get GPU info for tracking (used only in log messages).
        if torch.cuda.is_available():
            self.device_id = torch.cuda.current_device()
            self.device_name = torch.cuda.get_device_name(self.device_id)
        else:
            self.device_id = "CPU"
            self.device_name = "CPU"

        # Process info for distributed training, read from the launcher's env vars.
        self.rank = int(os.environ.get('RANK', 0))
        self.local_rank = int(os.environ.get('LOCAL_RANK', 0))
        self.world_size = int(os.environ.get('WORLD_SIZE', 1))

        # Debug bookkeeping of attempted indices. The set grows up to
        # len(dataset) and is per-process: DataLoader workers each hold a copy.
        self.accessed_indices = set()
        self.access_count = 0

        print(
            f"[RANK {self.rank}/GPU {self.device_id}] CC3M dataset loaded from disk "
            f"({len(self.dataset)} samples)"
        )

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        """
        Load sample ``idx``, skipping forward over samples that fail to load.

        Args:
            idx (int | 0-d Tensor | numpy integer): Index in ``[-len, len)``.

        Returns format
            input_ids (Tensor): Tokenized caption
            attention_mask (Tensor): Attention mask for tokens
            images (Tensor): Processed image tensor
            caption (str): Original caption string

        Raises:
            TypeError: ``idx`` is not a scalar integer.
            IndexError: ``idx`` is out of range. Raising here, rather than
                wrapping, lets plain ``for`` iteration over the dataset stop.
            RuntimeError: ``max_retries + 1`` consecutive samples failed.
        """
        num_items = len(self)
        idx = self._normalize_index(idx, num_items)

        last_error = None
        # Iterative (not recursive) skipping: a long run of corrupt samples
        # can no longer exhaust the interpreter's recursion limit.
        for _ in range(self.max_retries + 1):
            self.accessed_indices.add(idx)
            self.access_count += 1
            try:
                return self._load_sample(idx)
            except Exception as e:  # deliberate best-effort: skip any bad sample
                last_error = e
                print(f"[RANK {self.rank}/GPU {self.device_id}] Error loading item {idx}: {e}, trying next index")
                idx = (idx + 1) % num_items

        raise RuntimeError(
            f"[RANK {self.rank}/GPU {self.device_id}] {self.max_retries + 1} "
            f"consecutive samples failed to load; giving up"
        ) from last_error

    @staticmethod
    def _normalize_index(idx, num_items):
        """Coerce ``idx`` to a plain int in ``[0, num_items)`` or raise."""
        if torch.is_tensor(idx):
            idx = idx.tolist()
        if hasattr(idx, 'item'):
            # numpy scalar types
            idx = idx.item()
        if not isinstance(idx, int):
            raise TypeError(f"CC3MDataset indices must be integers, got {type(idx).__name__}")
        if not -num_items <= idx < num_items:
            raise IndexError(f"index {idx} out of range for dataset of size {num_items}")
        return idx % num_items

    def _load_sample(self, idx):
        """Build the output dict for a single index; raises on any failure."""
        item = self.dataset[idx]

        # Extract image and caption
        image = self._load_image(item[self.image_key])
        caption = item[self.caption_key]

        # Tokenize the caption
        tok = self.tokenizer(
            caption,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        )
        input_ids = tok['input_ids'].squeeze(0)
        attention_mask = tok['attention_mask'].squeeze(0)

        # Guard against ids that would index outside the text embedding table.
        max_token = input_ids.max().item()
        min_token = input_ids.min().item()
        vocab_size = self.tokenizer.vocab_size
        if max_token >= vocab_size or min_token < 0:
            raise ValueError(
                f"Token out of bounds! ids span [{min_token}, {max_token}] "
                f"but vocab_size={vocab_size}"
            )

        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'images': self.transform(image),
            'caption': caption,
        }

    @staticmethod
    def _open_rgb(source):
        """Decode ``source`` (path or file-like) to RGB, closing the underlying file."""
        # convert() returns a new, fully loaded image, so the result stays
        # valid after the ``with`` block closes the file handle.
        with Image.open(source) as img:
            return img.convert("RGB")

    def _load_image(self, image_data):
        """
        Converts image data to an RGB PIL Image.
        Supports:
        - Raw ``bytes`` / ``bytearray`` holding an encoded image
        - Dicts with a ``"bytes"`` entry; if it is missing or None, the
          dict's ``"path"`` entry is opened instead
        - File paths (``str`` or ``os.PathLike``)
        - Objects with a ``convert`` method, e.g. an already-decoded PIL.Image

        Raises:
            FileNotFoundError: A path that does not exist.
            TypeError: Unsupported value type, or a dict with neither bytes nor path.
        """
        if isinstance(image_data, (bytes, bytearray)):
            return self._open_rgb(BytesIO(image_data))

        if isinstance(image_data, dict):
            raw = image_data.get("bytes")
            if raw is not None:
                return self._open_rgb(BytesIO(raw))
            path = image_data.get("path")
            if not path:
                raise TypeError(f"Unsupported image format: {type(image_data)} without 'bytes' or 'path'")
            return self._open_rgb(path)

        if isinstance(image_data, (str, os.PathLike)):
            # Image.open raises FileNotFoundError for a missing path.
            return self._open_rgb(image_data)

        if hasattr(image_data, 'convert'):
            # Already decoded (e.g. a PIL.Image from an HF Image feature).
            return image_data.convert("RGB")

        raise TypeError(f"Unsupported image format: {type(image_data)}")