import os
import json
from PIL import Image
import torch
from torch.utils.data import Dataset
import torchvision.transforms as transforms
from matplotlib import pyplot as plt

import torchvision.transforms as T
from torchvision.transforms import InterpolationMode

def build_clip_transform(size=224, augmentation=False):
    """Build the image preprocessing pipeline used by the caption dataset.

    Args:
        size: Output side length in pixels. Used for the crop, and for the
            shorter-side resize in the non-augmented path.
        augmentation: When truthy, use a random resized crop (scale 0.9-1.0)
            plus a random horizontal flip. Otherwise use a deterministic
            bicubic resize followed by a center crop.

    Returns:
        A ``torchvision.transforms.Compose`` that maps a PIL image to a
        normalized float tensor.
    """
    # Per-channel normalization constants; these are the values published
    # with OpenAI CLIP, matching this function's name.
    clip_mean = [0.48145466, 0.4578275, 0.40821073]
    clip_std = [0.26862954, 0.26130258, 0.27577711]

    if not augmentation:
        spatial_ops = [
            T.Resize(size, interpolation=InterpolationMode.BICUBIC, antialias=True),
            T.CenterCrop(size),
        ]
    else:
        spatial_ops = [
            T.RandomResizedCrop(
                size,
                scale=(0.9, 1.0),
                interpolation=InterpolationMode.BICUBIC,
                antialias=True,
            ),
            T.RandomHorizontalFlip(p=0.5),
        ]

    # Both branches share the same tensor-conversion + normalization tail.
    tensor_ops = [T.ToTensor(), T.Normalize(mean=clip_mean, std=clip_std)]
    return T.Compose(spatial_ops + tensor_ops)

class COCOCaptionDataset(Dataset):
    """
    Dataset class for COCO 2014 caption dataset.

    Every entry of the annotation file's ``annotations`` list becomes one
    sample, so an image with several captions appears once per caption.
    """

    # Maximum number of *additional* indices ``__getitem__`` tries when an
    # image fails to load, before giving up with a RuntimeError.
    max_load_retries = 100

    def __init__(
            self,
            root_dir,
            ann_file,
            tokenizer,
            max_length=77,
            size=224,
            augmentation=False,
            image_transform=None):
        """
        Args:
            root_dir: Directory that each image ``file_name`` is joined onto.
            ann_file: Path to a COCO-style captions JSON containing an
                ``images`` list (``id``, ``file_name``) and an ``annotations``
                list (``image_id``, ``caption``).
            tokenizer: Callable invoked as ``tokenizer(text, max_length=...,
                padding='max_length', truncation=True, return_tensors='pt')``
                that returns a mapping with ``input_ids`` and
                ``attention_mask`` tensors (presumably a Hugging Face
                tokenizer; confirm against callers).
            max_length: Length captions are padded/truncated to.
            size: Image side length passed to ``build_clip_transform``.
            augmentation: Whether the default transform applies random
                crop/flip augmentation.
            image_transform: Optional transform used instead of the default
                ``build_clip_transform(size, augmentation)`` pipeline.
        """
        self.root_dir = root_dir
        self.augmentation = augmentation
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.transform = image_transform or build_clip_transform(size=size, augmentation=augmentation)

        # Device info is only used to tag log messages.
        if torch.cuda.is_available():
            self.device_id = torch.cuda.current_device()
            self.device_name = torch.cuda.get_device_name(self.device_id)
        else:
            self.device_id = "CPU"
            self.device_name = "CPU"

        # Process info for distributed training (defaults: single process).
        self.rank = int(os.environ.get('RANK', 0))
        self.local_rank = int(os.environ.get('LOCAL_RANK', 0))
        self.world_size = int(os.environ.get('WORLD_SIZE', 1))

        # JSON is UTF-8 by spec; don't depend on the platform default encoding.
        with open(ann_file, 'r', encoding='utf-8') as f:
            self.annotations = json.load(f)

        # image_id -> file name (relative to root_dir).
        self.id_to_filename = {
            image['id']: image['file_name'] for image in self.annotations['images']
        }

        # One entry per caption, referencing its image by id.
        self.captions = [
            {'image_id': annotation['image_id'], 'caption': annotation['caption']}
            for annotation in self.annotations['annotations']
        ]

        # Debug bookkeeping of which indices were requested/attempted.
        # NOTE: with a multi-worker DataLoader each worker process mutates its
        # own copy of these, so they are not aggregated in the main process.
        self.accessed_indices = set()
        self.access_count = 0

    def __len__(self):
        """Return the number of caption samples."""
        return len(self.captions)

    @staticmethod
    def _load_rgb(img_path):
        """Open ``img_path``, return an RGB copy, and close the file handle."""
        with Image.open(img_path) as img:
            # convert() loads pixel data into a new image, so it stays
            # usable after the source file is closed.
            return img.convert('RGB')

    def __getitem__(self, idx):
        """
        Return the tokenized caption and transformed image for sample ``idx``.

        If the image for ``idx`` cannot be loaded, the following indices
        (wrapping around to 0) are tried instead. Each index is tried at most
        once, with at most ``max_load_retries`` extra attempts. The returned
        pair may therefore belong to a different sample than requested.

        Returns:
            dict with ``input_ids`` and ``attention_mask`` (tokenizer output
            with dim 0 squeezed), ``images`` (the transformed image) and
            ``caption`` (the raw caption string).

        Raises:
            IndexError: if ``idx`` itself is out of range, as with list indexing.
            RuntimeError: if no image could be loaded within the retry budget.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        num_items = len(self.captions)
        # One attempt per distinct index at most; always at least one attempt
        # so an out-of-range idx (or empty dataset) still raises IndexError.
        max_attempts = max(1, min(self.max_load_retries + 1, num_items))

        last_err = None
        for attempt in range(max_attempts):
            # First try idx as given (keeps list-indexing semantics incl.
            # negative indices); fallbacks wrap instead of running off the end.
            cur_idx = idx if attempt == 0 else (idx + attempt) % num_items

            self.accessed_indices.add(cur_idx)
            self.access_count += 1

            caption_data = self.captions[cur_idx]
            img_filename = self.id_to_filename[caption_data['image_id']]
            img_path = os.path.join(self.root_dir, img_filename)

            try:
                image = self._load_rgb(img_path)
            except Exception as err:  # best-effort skip of unreadable images
                last_err = err
                print(f"[RANK {self.rank}/GPU {self.device_id}] Error loading image: {img_path} ({err!r}), trying next index")
                continue
            break
        else:
            raise RuntimeError(
                f"[RANK {self.rank}/GPU {self.device_id}] Could not load any image "
                f"for index {idx} after {max_attempts} attempt(s)"
            ) from last_err

        # Tokenize only the sample actually returned.
        caption = caption_data['caption']
        tok = self.tokenizer(
            caption,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        )

        # Keys consumed by the downstream model.
        return {
            'input_ids': tok['input_ids'].squeeze(0),
            'attention_mask': tok['attention_mask'].squeeze(0),
            'images': self.transform(image),
            'caption': caption,
        }
