import os
import json
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from transformers import AutoTokenizer, BertForSequenceClassification

class JsonlDataset(Dataset):
    """Image + text classification dataset backed by a ``{mode}.jsonl`` file.

    Every non-blank line of the file is one JSON object. The keys read from
    each object are ``label``, ``text`` and ``img`` (an image path that is
    joined onto ``root_dir``).
    """

    def __init__(self, root_dir, mode='train', label_mapping=None, transform=None):
        """
        Args:
            root_dir (string): Directory with the dataset JSONL file. Image
                paths found in the records are resolved relative to it.
            mode (string): Selects the file to load: ``<root_dir>/<mode>.jsonl``
                (e.g. 'train' or 'test').
            label_mapping (dict): A dictionary mapping labels to integers. When
                omitted, one is built from the sorted distinct labels found in
                the loaded file.
            transform (callable, optional): Optional transform to be applied
                on the loaded image.

        Raises:
            OSError: If the JSONL file cannot be opened.
            json.JSONDecodeError: If a non-blank line is not valid JSON.
        """
        self.root_dir = root_dir
        self.transform = transform

        # Determine the JSONL file to load based on the mode
        jsonl_path = os.path.join(root_dir, f"{mode}.jsonl")
        self.data = self._read_jsonl(jsonl_path)

        # Order samples by label so indexing is deterministic per file content
        self.data.sort(key=lambda x: x['label'])

        # Create a label mapping if not provided
        if label_mapping is None:
            labels = sorted(set(item['label'] for item in self.data))
            self.label_mapping = {label: idx for idx, label in enumerate(labels)}
        else:
            self.label_mapping = label_mapping

        # Output label mapping for verification
        print("Label Mapping (text to integer):")
        for label, idx in self.label_mapping.items():
            print(f"{label}: {idx}")

        self.tokenizer = AutoTokenizer.from_pretrained("textattack/bert-base-uncased-yelp-polarity")

    @staticmethod
    def _read_jsonl(path):
        """Parse the JSONL file at ``path`` into a list, skipping blank lines."""
        # Explicit UTF-8 keeps decoding independent of the platform locale;
        # blank / whitespace-only lines would otherwise crash json.loads.
        with open(path, 'r', encoding='utf-8') as f:
            return [json.loads(line) for line in f if line.strip()]

    def __len__(self):
        """Return the number of loaded records."""
        return len(self.data)

    def __getitem__(self, idx):
        """Load one sample.

        Args:
            idx: Position of the record; a tensor index is converted with
                ``tolist()`` first.

        Returns:
            tuple: ``(label, image, text_encoding)`` where ``label`` is the
            value from ``label_mapping``, ``image`` is the RGB image (after
            ``transform`` when one was given) and ``text_encoding`` is the
            tokenizer output with PyTorch tensors.

        Raises:
            KeyError: If the record's label is missing from ``label_mapping``.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        sample = self.data[idx]
        label = self.label_mapping[sample['label']]
        img_path = os.path.join(self.root_dir, sample['img'])

        # Load image; the context manager releases the file handle as soon as
        # the pixel data has been converted.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        if self.transform:
            image = self.transform(image)

        # NOTE: with a single text, padding=True pads only to that text's own
        # length, so encodings of different samples can differ in sequence
        # length (capped at 512 tokens by truncation). Batching code has to
        # bring them to a common length.
        text_encoding = self.tokenizer(
            sample['text'],
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512,
        )

        return label, image, text_encoding

def collate_fn(batch, pad_token_id=0):
    """Merge ``(label, image, encoding)`` samples into one padded batch.

    Per-sample token sequences may have different lengths, so they are
    right-padded to the longest sequence in the batch instead of being
    concatenated directly (which fails on mismatched lengths).

    Args:
        batch: Sequence of ``(label, image, encoding)`` tuples. ``image`` must
            be a tensor with the same shape for every sample; ``encoding``
            must provide 2-D ``'input_ids'`` and ``'attention_mask'`` tensors
            of shape ``(rows, seq_len)``.
        pad_token_id (int): Value used to pad ``input_ids``. The default of 0
            is the ``[PAD]`` id of the BERT uncased vocabulary; pass the
            tokenizer's ``pad_token_id`` when using a different model.

    Returns:
        tuple: ``(labels, images, batch_encodings)`` where ``labels`` is a 1-D
        tensor, ``images`` is the stacked image tensor and
        ``batch_encodings`` is a dict with ``'input_ids'`` and
        ``'attention_mask'`` tensors of shape ``(total_rows, max_seq_len)``.
    """
    # Local import keeps this function self-contained.
    from torch.nn.utils.rnn import pad_sequence

    labels, images, encodings = zip(*batch)

    # Stack images
    images = torch.stack(images)

    # Split every encoding into its 1-D rows so sequences of different
    # lengths can be padded to a common length.
    id_rows = [row for encoding in encodings for row in encoding['input_ids']]
    mask_rows = [row for encoding in encodings for row in encoding['attention_mask']]

    batch_encodings = {
        'input_ids': pad_sequence(id_rows, batch_first=True, padding_value=pad_token_id),
        # Padded positions get mask 0 so the model ignores them.
        'attention_mask': pad_sequence(mask_rows, batch_first=True, padding_value=0),
    }

    # Convert labels to tensor
    labels = torch.tensor(labels)

    return labels, images, batch_encodings

if __name__ == "__main__":
    # Smoke test: iterate the 'test' split once and print the tensor shapes
    # of every batch.
    preprocessing = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])

    test_set = JsonlDataset(root_dir='path', mode='test', transform=preprocessing)
    loader = DataLoader(test_set, batch_size=8, shuffle=True, collate_fn=collate_fn)

    for labels, images, encodings in loader:
        # Printed in order: token ids (batch_size, max_seq_length),
        # images (batch_size, 3, 224, 224), labels (batch_size,)
        for tensor in (encodings['input_ids'], images, labels):
            print(tensor.shape)
