import os
import pandas as pd
import numpy as np
from PIL import Image
from torch.utils.data import Dataset
from transformers import T5Tokenizer, BertTokenizer

class FashionProductDataset(Dataset):
    """Image + text classification dataset over the Fashion Product Images CSV.

    Each sample pairs an RGB product image with a tokenized text description
    ("<productDisplayName> (<usage>)") and an integer label for `task`.
    """

    # Tasks understood by `_select_task`.
    _TASKS = ('category', 'gender')
    # A masterCategory needs strictly more rows than this to be kept.
    _MIN_CATEGORY_COUNT = 100

    def __init__(self, root_dir, split='train', transform=None,
                 tokenizer_model="google/t5-efficient-tiny", max_seq_len=64,
                 task='category'):
        """
        Args:
            root_dir: Path to the 'fashion-product-images-small' folder; it is
                expected to contain 'styles.csv' and an 'images' sub-folder.
            split: 'train' selects the first 80% of a fixed (seed 42) random
                permutation of the filtered rows; any other value (e.g. 'test')
                selects the remaining 20%.
            transform: Optional callable applied to the RGB PIL image.
            tokenizer_model: Currently ignored -- the 'prajjwal1/bert-mini'
                tokenizer is always loaded. Kept for backward compatibility.
            max_seq_len: Length every description is padded/truncated to.
            task: 'category' (masterCategory values with more than 100 rows;
                the class count therefore depends on the CSV) or 'gender'
                (binary: 'Women' -> 0, 'Men' -> 1).

        Raises:
            ValueError: if `task` is not 'category' or 'gender'.
        """
        # Fail fast, before loading the tokenizer or reading the CSV.
        if task not in self._TASKS:
            raise ValueError(
                f"Unknown task {task!r}; expected one of {self._TASKS}")

        self.root_dir = root_dir
        self.img_dir = os.path.join(root_dir, 'images')
        self.csv_path = os.path.join(root_dir, 'styles.csv')
        self.transform = transform
        # NOTE: `tokenizer_model` is not used; the T5 tokenizer it names was
        # replaced by this BERT-mini tokenizer.
        self.tokenizer = BertTokenizer.from_pretrained('prajjwal1/bert-mini')
        self.max_seq_len = max_seq_len
        self.task = task

        # 1. Load metadata; on_bad_lines='skip' drops malformed CSV rows.
        df = pd.read_csv(self.csv_path, on_bad_lines='skip')

        # 2. Filter rows and fix the label set for the chosen task.
        df, self.classes, self.label_col = self._select_task(df, task)
        self.class_to_idx = {c: i for i, c in enumerate(self.classes)}

        # 3. Image file names. assign() returns a new frame instead of writing
        #    into a filtered view (avoids pandas' SettingWithCopyWarning).
        df = df.assign(filename=df['id'].astype(str) + ".jpg")

        # 4. Deterministic 80/20 split.
        self.df = self._split_frame(df, split)

        print(f"Loaded Fashion Dataset ({split}): {len(self.df)} samples.")
        print(f"Task: {task.upper()} ({len(self.classes)} classes)")
        print(f"Classes: {self.class_to_idx}")

    @classmethod
    def _select_task(cls, df, task):
        """Filter `df` for `task`; return (filtered_df, classes, label_col).

        Raises:
            ValueError: for an unknown task.
        """
        if task == 'gender':
            # Men/Women only (drop Boys/Girls/Unisex) for a clean binary task.
            df = df[df['gender'].isin(['Men', 'Women'])]
            return df, ['Women', 'Men'], 'gender'
        if task == 'category':
            # Master category (Apparel, Accessories, Footwear, ...); rare
            # categories are dropped.
            counts = df['masterCategory'].value_counts()
            keep_cats = counts[counts > cls._MIN_CATEGORY_COUNT].index.tolist()
            df = df[df['masterCategory'].isin(keep_cats)]
            return df, sorted(keep_cats), 'masterCategory'
        raise ValueError(
            f"Unknown task {task!r}; expected one of {cls._TASKS}")

    @staticmethod
    def _split_frame(df, split, seed=42, train_frac=0.8):
        """Return the train part (first `train_frac`) or the held-out rest of a
        seeded shuffle of `df`, re-indexed from 0.

        A private RandomState is used so the global NumPy RNG is untouched.
        RandomState(seed).permutation gives the same order as the former
        np.random.seed(seed); np.random.permutation(...), so splits are unchanged.
        """
        order = np.random.RandomState(seed).permutation(len(df))
        cut = int(len(df) * train_frac)
        rows = order[:cut] if split == 'train' else order[cut:]
        return df.iloc[rows].reset_index(drop=True)

    @staticmethod
    def _compose_text(desc, usage):
        """Build the text description, e.g. "Nike Men Blue Running Shoe (Casual)".

        Missing (NaN/None) fields are omitted rather than rendered as "nan".
        """
        desc = '' if pd.isna(desc) else str(desc)
        if pd.isna(usage):
            return desc
        return f"{desc} ({usage})" if desc else f"({usage})"

    def _load_image(self, img_path):
        """Return the RGB image at `img_path` (transformed if a transform is set),
        or None if the file is missing or cannot be decoded.

        Errors raised by the transform itself are not swallowed.
        """
        if not os.path.exists(img_path):
            return None
        try:
            # Context manager closes the underlying file; convert() returns a
            # new, fully loaded image that does not depend on that file.
            with Image.open(img_path) as img:
                image = img.convert('RGB')
        except OSError:  # includes PIL's UnidentifiedImageError / truncated files
            return None
        return self.transform(image) if self.transform else image

    @property
    def vocab_size(self):
        """Return the vocabulary size of the tokenizer."""
        return len(self.tokenizer)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return the sample at `idx`, or the next loadable one.

        If the image for `idx` is missing or unreadable, rows idx+1, idx+2, ...
        (wrapping around) are tried until one loads.

        Returns:
            dict with 'image', 'input_ids' and 'attention_mask' (1-D tensors of
            length max_seq_len), 'label' (int) and 'idx' (the row actually used).

        Raises:
            IndexError: if `idx` itself is out of range (or the dataset is empty).
            RuntimeError: if no image in the dataset can be loaded.
        """
        n = len(self)
        if n == 0:
            raise IndexError(f"index {idx} out of range for empty dataset")

        current = idx
        for _ in range(n):
            # The first lookup uses the caller's idx as-is, so an out-of-range
            # index still raises IndexError (needed by sequence iteration).
            row = self.df.iloc[current]
            image = self._load_image(os.path.join(self.img_dir, row['filename']))
            if image is not None:
                break
            current = (current + 1) % n
        else:
            raise RuntimeError(
                f"No loadable image found in {self.img_dir} "
                f"(tried all {n} samples starting at index {idx})")

        # --- Text ---
        raw_text = self._compose_text(row['productDisplayName'], row['usage'])
        encoded = self.tokenizer(
            raw_text,
            padding='max_length',
            truncation=True,
            max_length=self.max_seq_len,
            return_tensors='pt'
        )

        # --- Label ---
        label = self.class_to_idx[row[self.label_col]]

        return {
            'image': image,
            'input_ids': encoded['input_ids'].squeeze(0),
            'attention_mask': encoded['attention_mask'].squeeze(0),
            'label': label,
            'idx': current
        }
