"""
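Model-agnostic data loader for the TOFU dataset.
Works with any Hugging Face-style tokenizer that supports padding/truncation to max_length and returns input_ids and attention_mask.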
"""

import torch
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset
import logging
from typing import Tuple, Dict, Any

logger = logging.getLogger(__name__)

class TOFUDataset(Dataset):
    """Model-agnostic TOFU dataset.

    Loads one subset/split of ``locuslab/TOFU`` with ``load_dataset`` and
    tokenizes every question/answer pair eagerly at construction time.

    Each item is a dict with ``input_ids``, ``attention_mask`` and ``labels``.
    ``labels`` is a separate tensor (no storage shared with ``input_ids``) in
    which positions with ``attention_mask == 0`` (padding) are set to -100,
    the default ``ignore_index`` of PyTorch's cross-entropy loss, so padding
    does not contribute to the training loss.
    """

    def __init__(self, subset: str, tokenizer, max_length: int = 512, split: str = "train"):
        """
        Args:
            subset: TOFU config name passed to ``load_dataset`` (e.g. "full").
            tokenizer: Callable tokenizer. It is called with ``max_length``,
                ``padding='max_length'``, ``truncation=True`` and
                ``return_tensors='pt'`` and must return ``input_ids`` and
                ``attention_mask``.
            max_length: Length each sequence is padded/truncated to.
            split: Dataset split to load.
        """
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.subset = subset

        # Load the requested TOFU subset/split
        logger.info(f"Loading TOFU dataset subset: {subset}")
        self.dataset = load_dataset("locuslab/TOFU", subset, split=split)
        logger.info(f"Loaded {len(self.dataset)} samples")

        # Tokenize everything up front so __getitem__ is a plain list lookup
        self.processed_data = self._preprocess_data()

    def _preprocess_data(self):
        """Format and tokenize every sample.

        Samples with neither a question nor an answer are skipped. Samples
        that fail during tokenization are logged as a warning and skipped,
        so one bad record does not abort the whole subset.

        Returns:
            List of example dicts as produced by ``_to_example``.
        """
        processed = []

        for item in self.dataset:
            question = item.get('question', '')
            answer = item.get('answer', '')

            # Plain "Q: ... / A: ..." format that needs no model-specific template
            if question and answer:
                text = f"Q: {question}\nA: {answer}"
            elif question:
                text = f"Q: {question}"
            elif answer:
                text = f"A: {answer}"
            else:
                continue

            try:
                tokens = self.tokenizer(
                    text,
                    max_length=self.max_length,
                    padding='max_length',
                    truncation=True,
                    return_tensors='pt'
                )
                processed.append(self._to_example(tokens))

            except Exception as e:
                # Best effort: skip this sample but keep processing the rest
                logger.warning(f"Error processing sample: {e}")
                continue

        logger.info(f"Processed {len(processed)} samples")
        return processed

    @staticmethod
    def _to_example(tokens) -> Dict[str, torch.Tensor]:
        """Turn a batch-of-one tokenizer output into a single training example."""
        # squeeze(0) drops only the batch dimension; a bare squeeze() would
        # also collapse the sequence dimension when max_length == 1.
        input_ids = tokens['input_ids'].squeeze(0)
        attention_mask = tokens['attention_mask'].squeeze(0)
        # masked_fill returns a new tensor, so labels never alias input_ids;
        # padded positions get -100 so the loss ignores them.
        labels = input_ids.masked_fill(attention_mask == 0, -100)
        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'labels': labels,
        }

    def __len__(self):
        return len(self.processed_data)

    def __getitem__(self, idx):
        return self.processed_data[idx]


class DummyDataset(Dataset):
    """Synthetic Q/A dataset for testing without downloading TOFU.

    Item ``idx`` is a short question/answer string mentioning ``idx``,
    tokenized on every access. Items have the keys ``input_ids``,
    ``attention_mask`` and ``labels``. ``labels`` is a separate tensor in
    which padding positions (``attention_mask == 0``) are set to -100 so
    they are ignored by PyTorch's cross-entropy loss.
    """

    def __init__(self, size: int, tokenizer, max_length: int = 512):
        self.size = size
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        text = f"Q: What is {idx}?\nA: The answer is {idx}."

        tokens = self.tokenizer(
            text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        # squeeze(0) drops only the batch dimension; a bare squeeze() would
        # also collapse the sequence dimension when max_length == 1.
        input_ids = tokens['input_ids'].squeeze(0)
        attention_mask = tokens['attention_mask'].squeeze(0)

        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            # New tensor (no shared storage with input_ids) with padding ignored
            'labels': input_ids.masked_fill(attention_mask == 0, -100),
        }


def create_dataloaders(
    tokenizer,
    subset: str = "full",
    max_length: int = 128,
    batch_size: int = 8,
    train_split: float = 0.8,
    use_dummy: bool = False,
    seed: int = 42
) -> Tuple[DataLoader, DataLoader]:
    """
    Create train/validation dataloaders - works with any tokenizer

    If loading or tokenizing TOFU fails, or yields no usable samples, the
    error is logged with its traceback and synthetic DummyDataset data is
    used instead.

    Args:
        tokenizer: Any tokenizer object
        subset: TOFU subset to use
        max_length: Maximum sequence length
        batch_size: Batch size
        train_split: Fraction of TOFU samples used for training; must be in
            (0, 1]. The remainder forms the validation set.
        use_dummy: Use dummy data for testing
        seed: Seed for the train/val random split, so the split is
            reproducible across runs.

    Returns:
        (train_loader, val_loader)

    Raises:
        ValueError: If ``train_split`` is outside (0, 1] (TOFU path only).
    """

    if use_dummy:
        logger.info("Using dummy dataset")
        train_dataset = DummyDataset(800, tokenizer, max_length)
        val_dataset = DummyDataset(200, tokenizer, max_length)

    else:
        # Validate before the try: a bad argument is a caller error, not a
        # data-loading failure that should silently fall back to dummy data.
        if not 0.0 < train_split <= 1.0:
            raise ValueError(f"train_split must be in (0, 1], got {train_split}")

        try:
            # Load TOFU dataset
            full_dataset = TOFUDataset(subset, tokenizer, max_length, "train")
            n_samples = len(full_dataset)
            if n_samples == 0:
                # An empty dataset would otherwise crash later, outside this
                # try, when the shuffling train DataLoader is built
                # (RandomSampler rejects 0 samples).
                raise ValueError(f"TOFU subset '{subset}' produced no usable samples")

            # Keep at least one training sample so the train loader is non-empty
            train_size = max(1, int(train_split * n_samples))
            val_size = n_samples - train_size

            # Seeded generator makes the split identical across runs
            train_dataset, val_dataset = torch.utils.data.random_split(
                full_dataset,
                [train_size, val_size],
                generator=torch.Generator().manual_seed(seed),
            )

            logger.info(f"Created dataloaders: Train={len(train_dataset)}, Val={len(val_dataset)}")

        except Exception as e:
            # Deliberate best-effort fallback; log the traceback so the
            # failure is visible rather than hidden behind dummy data.
            logger.exception(f"Error loading TOFU: {e}")
            logger.info("Falling back to dummy dataset")

            train_dataset = DummyDataset(800, tokenizer, max_length)
            val_dataset = DummyDataset(200, tokenizer, max_length)

    # Create dataloaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    return train_loader, val_loader


def get_available_subsets() -> Dict[str, str]:
    """Return the TOFU subset names this loader supports, with descriptions.

    ``"full"`` is listed because it is the default subset used by
    ``create_dataloaders``. Each ``forgetNN`` subset holds only the forget
    split; its complementary retain split is the matching ``retainMM``
    subset (NN + MM = 100).

    Returns:
        A new dict mapping subset name to a human-readable description.
    """
    return {
        "full": "Full dataset (forget + retain)",
        "forget01": "1% forget set (retain counterpart: retain99)",
        "forget05": "5% forget set (retain counterpart: retain95)",
        "forget10": "10% forget set (retain counterpart: retain90)",
        "retain90": "90% retain set (complement of forget10)",
        "retain95": "95% retain set (complement of forget05)",
        "retain99": "99% retain set (complement of forget01)",
    }


if __name__ == "__main__":
    # Smoke test with a mock tokenizer and the dummy dataset (no download needed)
    print("Testing data loader...")

    class MockTokenizer:
        """Minimal stand-in for a tokenizer's __call__ interface."""

        def __init__(self):
            self.vocab_size = 1000
            self.pad_token = "<pad>"

        def __call__(self, text, max_length=512, padding='max_length', truncation=True, return_tensors='pt'):
            # Up to 50 dummy "real" token ids, then pad id 0 up to max_length
            n_real = min(50, max_length)
            tokens = list(range(n_real)) + [0] * (max_length - n_real)
            # Like a real tokenizer, only non-padding positions are attended
            mask = [1] * n_real + [0] * (max_length - n_real)

            return {
                'input_ids': torch.tensor(tokens).unsqueeze(0),
                'attention_mask': torch.tensor(mask).unsqueeze(0)
            }

    tokenizer = MockTokenizer()
    batch_size = 2
    max_length = 128
    train_loader, val_loader = create_dataloaders(
        tokenizer, max_length=max_length, use_dummy=True, batch_size=batch_size
    )

    print(f"Train batches: {len(train_loader)}")
    print(f"Val batches: {len(val_loader)}")

    # Test one batch
    batch = next(iter(train_loader))
    print(f"Batch keys: {batch.keys()}")
    print(f"Input shape: {tuple(batch['input_ids'].shape)}")
    assert tuple(batch['input_ids'].shape) == (batch_size, max_length), "unexpected input_ids shape"
    assert tuple(batch['labels'].shape) == (batch_size, max_length), "unexpected labels shape"

    print("Data loader test passed!")