import os
from tqdm import tqdm
import glob
import re
import csv
from PIL import Image
import torch
from torch.utils.data import Dataset
from transformers import BertTokenizer
from transformers import T5Tokenizer

class UPMCFoodDataset(Dataset):
    """Paired image/text dataset backed by a per-split CSV of titles.

    Expected layout under ``root_dir``::

        images/<split>/<class_name>/<img_name>
        texts/<split>_titles.csv      # rows: img_name, text, class_name

    Attributes:
        samples: List of ``(img_path, raw_text, label)`` tuples.
        labels: List of integer labels, parallel to ``samples``.
        classes: Class names, in order of first appearance in the CSV.
        class_to_idx: Mapping from class name to its integer label.

    NOTE: label indices are assigned in order of first appearance in *this*
    split's CSV, so two instances (e.g. train and test) only share a label
    mapping if their CSVs introduce the classes in the same order.
    """

    # Compiled once at class creation; clean_text runs for every item.
    _HTML_TAG_RE = re.compile(r'<[^>]+>')
    _WHITESPACE_RE = re.compile(r'\s+')

    def __init__(self, root_dir, split='train', transform=None, max_seq_len=128, max_samples=None):
        """
        Args:
            root_dir (str): Path to upmc-food101 folder (containing 'images' and 'texts').
            split (str): 'train' or 'test'.
            transform (callable): Transform to apply to images.
            max_seq_len (int): Maximum sequence length for BERT.
            max_samples (int, optional): Stop reading the CSV once this many
                valid samples were collected. ``None`` loads everything;
                ``0`` (or a negative value) yields an empty dataset.

        Raises:
            FileNotFoundError: If the image directory, the text directory or
                the split's CSV file does not exist.
        """
        self.root_dir = root_dir
        self.img_dir = os.path.join(root_dir, 'images', split)
        if not os.path.exists(self.img_dir):
            raise FileNotFoundError(f"Image directory {self.img_dir} not found.")
        self.txt_dir = os.path.join(root_dir, 'texts')
        if not os.path.exists(self.txt_dir):
            raise FileNotFoundError(f"Text directory {self.txt_dir} not found.")
        self.transform = transform
        self.max_seq_len = max_seq_len

        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        # Alternative tokenizer: T5Tokenizer.from_pretrained('google/t5-efficient-tiny')

        self.samples = []
        self.classes = []
        self.class_to_idx = {}
        self.labels = []

        csv_path = os.path.join(self.txt_dir, f'{split}_titles.csv')
        if not os.path.exists(csv_path):
            raise FileNotFoundError(f"CSV file {csv_path} not found.")

        print(f"Loading {split} data from {csv_path}...")
        self._load_samples(csv_path, max_samples)
        print(f"Found {len(self.samples)} valid samples in {split} set with {len(self.classes)} classes.")

    def _load_samples(self, csv_path, max_samples=None):
        """Fill samples/labels/classes/class_to_idx from the rows of csv_path.

        Rows that do not have exactly three columns are skipped, as are rows
        whose image file does not exist under ``self.img_dir``.
        """
        # newline='' is what the csv module asks for: it lets the reader
        # handle line endings itself, so newlines embedded in quoted fields
        # are not rewritten by the text layer.
        with open(csv_path, 'r', encoding='utf-8', newline='') as f:
            for row in csv.reader(f):
                # Checked before processing the row so that max_samples=0
                # really produces an empty dataset.
                if max_samples is not None and len(self.samples) >= max_samples:
                    break
                if len(row) != 3:
                    continue
                img_name, text_desc, class_name = row

                # The class is registered even when its image turns out to be
                # missing below, so label indices do not depend on which
                # files happen to be present on disk.
                if class_name not in self.class_to_idx:
                    self.class_to_idx[class_name] = len(self.classes)
                    self.classes.append(class_name)

                # Only the path is stored; the image is loaded in __getitem__.
                img_path = os.path.join(self.img_dir, class_name, img_name)
                if not os.path.exists(img_path):
                    continue
                label = self.class_to_idx[class_name]
                self.samples.append((img_path, text_desc, label))
                self.labels.append(label)

    def clean_text(self, text):
        """Return text with HTML-like tags removed and whitespace collapsed."""
        text = self._HTML_TAG_RE.sub('', text)
        return self._WHITESPACE_RE.sub(' ', text).strip()

    def __len__(self):
        """Return the number of valid samples found at construction time."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load, transform and tokenize the sample at position idx.

        Returns:
            dict with keys 'image' (transformed image, or a zero tensor if
            loading/transforming failed), 'input_ids' and 'attention_mask'
            (tokenizer outputs with the batch dimension removed), 'label'
            (int) and 'idx' (the index that was requested).
        """
        img_path, text_desc, label = self.samples[idx]

        # 1. Load and transform image
        try:
            # The context manager releases the file handle promptly instead
            # of relying on garbage collection; convert() returns a separate
            # image that stays valid after the file is closed.
            with Image.open(img_path) as img:
                image = img.convert('RGB')
            if self.transform:
                image = self.transform(image)
        except Exception as e:
            # Deliberate best-effort: one unreadable file should not abort an
            # epoch, so report it and substitute a black image.
            print(f"Error loading image {img_path}: {e}")
            # NOTE: fixed 3x224x224 shape; assumes the transform produces
            # tensors of that shape -- confirm against the transform in use.
            image = torch.zeros((3, 224, 224))

        # 2. Clean and tokenize text
        raw_text = self.clean_text(text_desc)
        encoded = self.tokenizer(
            raw_text,
            padding='max_length',
            truncation=True,
            max_length=self.max_seq_len,
            return_tensors='pt'
        )

        input_ids = encoded['input_ids'].squeeze(0)  # Remove batch dim
        attention_mask = encoded['attention_mask'].squeeze(0)

        return {
            'image': image,
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'label': label,
            'idx': idx
        }