#!/usr/bin/env python3
"""
PyTorch Dataset for annDNA training
"""

import torch
from torch.utils.data import Dataset
import numpy as np


class GenomeDataset(Dataset):
    """Map-style dataset backed by three parallel ``.npy`` arrays.

    For a given split the constructor loads ``{split}_input_ids.npy``,
    ``{split}_labels.npy`` and ``{split}_attention_mask.npy`` from
    ``data_dir``. Sample ``i`` is made of entry ``i`` of each array,
    converted to ``torch.long`` tensors.
    """

    def __init__(self, data_dir, split='train'):
        """
        Load every array of one split and check that they line up.

        Args:
            data_dir: Path to processed data directory
            split: File-name prefix selecting the split, e.g. 'train' or 'val'

        Raises:
            ValueError: If ``labels`` or ``attention_mask`` does not contain
                the same number of samples as ``input_ids``.
        """
        super().__init__()
        print(f"Loading {split} data...")

        self.input_ids = np.load(f"{data_dir}/{split}_input_ids.npy")
        self.labels = np.load(f"{data_dir}/{split}_labels.npy")
        self.attention_mask = np.load(f"{data_dir}/{split}_attention_mask.npy")

        # __len__ is derived from input_ids only, while __getitem__ indexes
        # all three arrays with the same idx. Fail here, at load time, rather
        # than with an IndexError mid-epoch (or silently unused rows).
        num_samples = len(self.input_ids)
        companions = (
            ("labels", self.labels),
            ("attention_mask", self.attention_mask),
        )
        for name, array in companions:
            if len(array) != num_samples:
                raise ValueError(
                    f"{split}_{name}.npy has {len(array)} samples but "
                    f"{split}_input_ids.npy has {num_samples}"
                )

        print(f"Loaded {num_samples:,} {split} samples")

    def __len__(self):
        """Return the number of samples in the loaded split."""
        return len(self.input_ids)

    def __getitem__(self, idx):
        """
        Return one sample as a dict of ``torch.long`` tensors.

        Args:
            idx: Index into the first axis of the loaded arrays

        Returns:
            Dict with keys 'input_ids', 'labels' and 'attention_mask'.
        """
        # torch.tensor always copies, so the returned tensors never alias
        # the arrays held by the dataset.
        return {
            'input_ids': torch.tensor(self.input_ids[idx], dtype=torch.long),
            'labels': torch.tensor(self.labels[idx], dtype=torch.long),
            'attention_mask': torch.tensor(self.attention_mask[idx], dtype=torch.long),
        }
