import os

# Set before the HuggingFace-related imports below so the setting is already in
# the environment when any tokenizer is created. Presumably silences the
# tokenizers fork/parallelism warning triggered by DataLoader workers and
# `datasets.map(num_proc=...)` — confirm if the warning source changes.
os.environ["TOKENIZERS_PARALLELISM"] = "false"  # to avoid HF Tokenizer warning

import pytorch_lightning as pl

from torch.utils.data import DataLoader, ConcatDataset

from datasets import load_dataset


class ConstrainedLanguageDataModule(pl.LightningDataModule):
    """Common base for the LightningDataModules in this module.

    Holds the two DataLoader settings every subclass forwards to its
    dataloaders.

    Args:
        batch_size: Number of examples per batch.
        num_workers: Number of DataLoader worker processes.
    """

    def __init__(self, batch_size=32, num_workers=16):
        super().__init__()
        # Plain attributes; subclasses read them when building DataLoaders.
        self.batch_size, self.num_workers = batch_size, num_workers


class HuggingfaceDataModule(ConstrainedLanguageDataModule):
    """
    Generic DataModule for loading Huggingface datasets for text classification tasks.

    Handles standard datasets with a 'text' and 'label' field, or can be configured
    to handle more complex datasets that need preprocessing.

    Args:
        dataset_name: Name/path passed to ``datasets.load_dataset``.
        num_classes: Number of target classes (returned by ``get_num_classes``).
        batch_size: Batch size used by every dataloader.
        num_workers: DataLoader worker count; also used (minimum 1) as ``num_proc``
            when combining text fields.
        val_split: Fraction of the train split held out for validation when the
            dataset has no validation split of its own.
        seed: Seed for that train/validation split.
        dataset_config: Optional config name passed to ``load_dataset``.
        text_fields: Optional field name, or sequence of field names, joined with
            " [SEP] " into a single 'text' column.
        label_field: Name of the label column.
        test_split: Split used for testing; validation data is used if it is absent.
    """

    def __init__(
        self,
        dataset_name,
        num_classes,
        batch_size=32,
        num_workers=16,
        val_split=0.1,
        seed=42,
        dataset_config=None,
        text_fields=None,
        label_field="label",
        test_split="test",
    ):
        super().__init__(batch_size, num_workers)
        self.dataset_name = dataset_name
        self.dataset_config = dataset_config
        self.num_classes = num_classes
        self.val_split = val_split
        self.seed = seed
        self.text_fields = text_fields  # For datasets with multiple text fields to combine
        self.label_field = label_field
        self.test_split = test_split

    def setup(self, stage=None):
        """
        Load the dataset and prepare splits.

        Handles different dataset structures:
        1. Simple datasets with predefined train/test splits
        2. Datasets that need to create a validation split from training data
        3. Datasets with several validation splits (e.g. "validation_matched" and
           "validation_mismatched"), which are concatenated
        4. Datasets with multiple text fields that need to be combined

        Sets ``self.train_data``, ``self.val_data`` and ``self.test_data``.
        """
        dataset = load_dataset(self.dataset_name, self.dataset_config)

        # Handle datasets with multiple text fields that need to be combined (like MNLI)
        if self.text_fields is not None:
            dataset = self._preprocess_text_fields(dataset)

        # Any split whose name contains "validation" counts. Checking only for an
        # exact "validation" key would miss datasets that have only suffixed
        # validation splits and wrongly carve validation data out of train.
        validation_splits = [split for split in dataset.keys() if "validation" in split]

        if not validation_splits:
            # No validation split at all: hold out part of the training data.
            train_val = dataset["train"].train_test_split(test_size=self.val_split, seed=self.seed)
            self.train_data = train_val["train"]
            self.val_data = train_val["test"]
        else:
            self.train_data = dataset["train"]
            if len(validation_splits) == 1:
                self.val_data = dataset[validation_splits[0]]
            else:
                self.val_data = ConcatDataset([dataset[split] for split in validation_splits])

        # Set test data
        if self.test_split in dataset:
            self.test_data = dataset[self.test_split]
        else:
            # Use validation data as test if no test split is available
            self.test_data = self.val_data

        # Print dataset information for debugging
        print(f"Train set size: {len(self.train_data)}")
        print(f"Validation set size: {len(self.val_data)}")
        print(f"Test set size: {len(self.test_data)}")

    def _preprocess_text_fields(self, dataset):
        """
        Combine one or more text fields into a single 'text' field.

        For every example, the configured fields are converted with ``str`` and
        joined with " [SEP] ". Every column of the "train" split except the label
        column is removed from all splits.

        Args:
            dataset: The HuggingFace DatasetDict (must contain a "train" split).

        Returns:
            The processed dataset with a unified 'text' field.

        Raises:
            ValueError: If ``text_fields`` names no fields.
        """
        # Accept a single field name as well as a sequence of names; a bare
        # string would otherwise be iterated character by character.
        if isinstance(self.text_fields, str):
            fields = [self.text_fields]
        else:
            fields = list(self.text_fields)
        if not fields:
            raise ValueError("text_fields must name at least one field")
        label_field = self.label_field

        # Close over plain locals only (not ``self``), so the function is cheap to
        # pickle for ``num_proc`` workers and its cache fingerprint does not
        # depend on the whole DataModule.
        def combine_text_fields(batch):
            columns = (batch[field] for field in fields)
            texts = [" [SEP] ".join(str(value) for value in row) for row in zip(*columns)]
            return {
                "text": texts,
                label_field: batch[label_field],
            }

        # Process all splits
        return dataset.map(
            combine_text_fields,
            remove_columns=[col for col in dataset["train"].column_names if col != label_field],
            num_proc=max(self.num_workers, 1),
            batched=True,
            desc="Preprocessing text fields",
        )

    def train_dataloader(self):
        """Return the shuffled training dataloader."""
        return DataLoader(
            self.train_data, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers, pin_memory=True
        )

    def val_dataloader(self):
        """Return the validation dataloader."""
        return DataLoader(self.val_data, batch_size=self.batch_size, num_workers=self.num_workers, pin_memory=True)

    def test_dataloader(self):
        """Return the test dataloader."""
        return DataLoader(self.test_data, batch_size=self.batch_size, num_workers=self.num_workers, pin_memory=True)

    def get_num_classes(self):
        """Return the number of classes in the dataset."""
        return self.num_classes


class IMDBDataModule(HuggingfaceDataModule):
    """
    Preset DataModule for the IMDB movie-review sentiment dataset (2 classes).

    Loads "imdb" with the default 'text'/'label' fields through
    HuggingfaceDataModule. ``val_split`` and ``seed`` are forwarded and control
    the validation hold-out taken from train when no validation split exists.
    """

    def __init__(self, batch_size=32, num_workers=16, val_split=0.1, seed=42):
        preset = {"dataset_name": "imdb", "num_classes": 2}
        loader_options = {
            "batch_size": batch_size,
            "num_workers": num_workers,
            "val_split": val_split,
            "seed": seed,
        }
        super().__init__(**preset, **loader_options)




class MNLIDataModule(ConstrainedLanguageDataModule):
    """
    DataModule for GLUE MNLI (3-way natural language inference).

    Premise and hypothesis are joined into one 'text' column with " [SEP] ".
    Because the MNLI test set is blind, validation and testing both use the
    concatenation of the matched and mismatched validation splits.
    """

    def __init__(self, batch_size=32, num_workers=16):
        super().__init__(batch_size, num_workers)
        self.dataset_name = "nyu-mll/glue"
        self.dataset_config = "mnli"
        self.num_classes = 3

    def setup(self, stage=None):
        """
        Load the MNLI dataset with all its splits.

        The MNLI dataset has the following structure:
        - train
        - validation_matched
        - validation_mismatched
        - test_matched
        - test_mismatched

        Each example contains:
        - premise: text that establishes a scenario
        - hypothesis: text to be evaluated against the premise
        - label: 0 (entailment), 1 (neutral), or 2 (contradiction);
          -1 in the test splits, whose labels are withheld
        - idx: example index

        After mapping, every split has only 'text' and 'label' columns.
        """
        self.dataset = load_dataset(self.dataset_name, self.dataset_config)

        def concat_premise_hypothesis(batch):
            # Batched map: each value is a list covering the whole batch.
            return {
                "text": [p + " [SEP] " + h for p, h in zip(batch["premise"], batch["hypothesis"])],
                "label": batch["label"],
            }

        # Concatenate premise and hypothesis into a single text field
        self.dataset = self.dataset.map(
            concat_premise_hypothesis,
            remove_columns=self.dataset["train"].column_names,
            num_proc=max(self.num_workers, 1),
            batched=True,
            desc="Concatenating premise and hypothesis",
        )
        # Access the different splits
        self.train_data = self.dataset["train"]
        self.val_matched_data = self.dataset["validation_matched"]
        self.val_mismatched_data = self.dataset["validation_mismatched"]
        self.test_matched_data = self.dataset["test_matched"]
        self.test_mismatched_data = self.dataset["test_mismatched"]

        # Print dataset information for debugging
        print(f"Train set size: {len(self.train_data)}")
        print(f"Validation matched size: {len(self.val_matched_data)}")
        print(f"Validation mismatched size: {len(self.val_mismatched_data)}")
        print(f"Test matched size: {len(self.test_matched_data)}")
        print(f"Test mismatched size: {len(self.test_mismatched_data)}")

    def train_dataloader(self):
        """Return the shuffled training dataloader."""
        return DataLoader(
            self.train_data, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers, pin_memory=True
        )

    def val_dataloader(self):
        """
        Return a single validation dataloader that combines matched and mismatched data.
        """
        # Concatenate the matched and mismatched validation datasets
        combined_val_dataset = ConcatDataset([self.val_matched_data, self.val_mismatched_data])

        # pin_memory matches the train loader and the other DataModules' loaders.
        return DataLoader(
            combined_val_dataset, batch_size=self.batch_size, num_workers=self.num_workers, pin_memory=True
        )

    def test_dataloader(self):
        """
        Return a single test dataloader that combines matched and mismatched data.

        Note: Using validation data for testing as MNLI test is blind.
        """
        # Use val for testing because MNLI test is blind
        return self.val_dataloader()

    def get_num_classes(self):
        """Return the number of classes in the MNLI dataset (3)."""
        return self.num_classes
