from abc import ABC, abstractmethod
from torch.utils.data import DataLoader
from transformers import PreTrainedTokenizer
from typing import Any, List

from src.dataset_processing.perplexity.common.config.base_configs import PerplexityDatasetConfig
from src.dataset_processing.perplexity.common.dataset import BasePerplexityDataset
from src.dataset_processing.perplexity.common.models.dataset_entry import PerplexityDatasetEntry
from src.dataset_processing.perplexity.common.models.dataset_result import PerplexityDatasetResult

class BaseProcessor(ABC):
    """Abstract template shared by every perplexity dataset processor.

    Subclasses provide the dataset-specific loading and parsing hooks;
    ``process_dataset`` chains those hooks and wraps the outcome in a
    ``DataLoader``.
    """

    def __init__(self, tokenizer: PreTrainedTokenizer, stride: int = 512):
        """Keep the tokenizer and stride as instance attributes.

        Args:
            tokenizer: Stored as ``self.tokenizer``.
            stride: Stored as ``self.stride``; defaults to 512.

        Neither attribute is read by this base class itself -- presumably
        they are consumed by concrete subclasses (confirm against them).
        """
        self.stride = stride
        self.tokenizer = tokenizer

    @abstractmethod
    def load_raw_data(self, config: PerplexityDatasetConfig) -> Any:
        """Fetch the dataset-specific raw data described by ``config``."""

    @abstractmethod
    def process_raw_data(self, raw_data: Any, config: PerplexityDatasetConfig) -> List[PerplexityDatasetEntry]:
        """Turn dataset-specific raw data into a list of dataset entries."""

    def process_dataset(self, config: PerplexityDatasetConfig) -> DataLoader:
        """Run the shared load -> process -> wrap pipeline.

        Args:
            config: Passed to both hooks and to ``PerplexityDatasetResult``;
                its ``batch_size`` sets the loader's batch size.

        Returns:
            A non-shuffling ``DataLoader`` over a ``BasePerplexityDataset``
            built from the processed entries.
        """
        parsed_entries = self.process_raw_data(self.load_raw_data(config), config)

        # Entries are routed through PerplexityDatasetResult before reaching
        # the dataset -- NOTE(review): presumably for validation; confirm.
        bundle = PerplexityDatasetResult(entries=parsed_entries, config=config)
        dataset = BasePerplexityDataset(bundle.entries)
        return DataLoader(dataset, batch_size=config.batch_size, shuffle=False)