from functools import partial
from typing import Optional

import torch as t
from datasets import Dataset
from datasets import Dataset as HFDataset
from datasets import IterableDataset, load_dataset
from torch.utils.data import DataLoader
from transformers import PreTrainedTokenizerBase

from auto_encoder import debug
from data.combined_dataset import get_combined_dataset
from data.data_prep import chunk_examples, data_collator

# Default batch size for the text dataloaders: a tiny batch of 3 when the
# auto_encoder `debug` flag is set, otherwise 32.
DEFAULT_BATCH_SIZE = 32 if not debug else 3


def get_train_dataloader(
    use_combined_dataset: bool = False,
    batch_size: int = DEFAULT_BATCH_SIZE,
    tokenizer: Optional[PreTrainedTokenizerBase] = None,
    max_length: int = 256,
    seed: int = 42,
) -> DataLoader:
    """Build the training ``DataLoader`` over a streamed text dataset.

    Args:
        use_combined_dataset: If True, use the first element returned by
            ``get_combined_dataset()`` as the dataset, unchanged. Otherwise
            stream the ``HuggingFaceFW/fineweb`` ``CC-MAIN-2024-10`` train
            split, shuffle it, and map ``chunk_examples`` over it in batches.
        batch_size: Number of examples per batch.
        tokenizer: Optional tokenizer bound into ``chunk_examples`` and
            ``data_collator``. When None, both helpers are used unbound.
        max_length: Bound into ``chunk_examples``, but only when ``tokenizer``
            is given.
        seed: Seed for the streaming shuffle. Only used on the fineweb path.

    Returns:
        A ``DataLoader`` that collates batches with ``data_collator`` and uses
        pinned memory.
    """
    if tokenizer is not None:
        _chunk_examples = partial(chunk_examples, tokenizer=tokenizer, max_length=max_length)
        _data_collator = partial(data_collator, tokenizer=tokenizer)
    else:
        _chunk_examples = chunk_examples
        _data_collator = data_collator

    if use_combined_dataset:
        # NOTE(review): this branch never calls _chunk_examples, so tokenizer and
        # max_length only reach the collator here. The step count is unused.
        train_dataset, _ = get_combined_dataset()
    else:
        # Alternative corpus previously tried here: "allenai/c4", "en".
        dataset: IterableDataset = load_dataset(
            "HuggingFaceFW/fineweb",
            "CC-MAIN-2024-10",
            split="train",
            streaming=True,
            trust_remote_code=True,
        )  # type: ignore
        # shuffle() returns a NEW dataset rather than shuffling in place. The
        # result must be kept, or the stream is never shuffled.
        dataset = dataset.shuffle(seed=seed)

        # Batched map with the original columns removed, so chunk_examples
        # fully defines the output rows.
        # NOTE(review): assumes the streamed dataset exposes column_names
        # (resolved features). Confirm this if the source dataset changes.
        train_dataset = dataset.map(
            _chunk_examples, batched=True, remove_columns=dataset.column_names
        )
        # NOTE: no with_format("torch") is applied. An earlier call discarded
        # its result, so examples have always reached data_collator unformatted,
        # and that behaviour is preserved here.

    return DataLoader(
        train_dataset,  #  type: ignore
        batch_size=batch_size,
        collate_fn=_data_collator,
        pin_memory=True,
    )


def get_eval_dataloader(
    batch_size: int = DEFAULT_BATCH_SIZE,
    tokenizer: Optional[PreTrainedTokenizerBase] = None,
    max_length: int = 256,
    seed: int = 42,
) -> DataLoader:
    """Build the evaluation ``DataLoader`` over the streamed fineweb dataset.

    NOTE(review): this reads the same source (``HuggingFaceFW/fineweb``,
    ``CC-MAIN-2024-10``, split ``"train"``) with the same default seed as the
    non-combined path of ``get_train_dataloader``. Evaluation batches are
    therefore drawn from the training stream. Confirm this is intended.

    Args:
        batch_size: Number of examples per batch.
        tokenizer: Optional tokenizer bound into ``chunk_examples`` and
            ``data_collator``. When None, both helpers are used unbound.
        max_length: Bound into ``chunk_examples``, but only when ``tokenizer``
            is given.
        seed: Seed for the streaming shuffle.

    Returns:
        A ``DataLoader`` that collates batches with ``data_collator`` and uses
        pinned memory.
    """
    if tokenizer is not None:
        _chunk_examples = partial(chunk_examples, tokenizer=tokenizer, max_length=max_length)
        _data_collator = partial(data_collator, tokenizer=tokenizer)
    else:
        _chunk_examples = chunk_examples
        _data_collator = data_collator

    # Alternative previously tried here: "allenai/c4", "en", split="validation".
    # streaming=True yields an IterableDataset, not a map-style Dataset.
    dataset: IterableDataset = load_dataset(
        "HuggingFaceFW/fineweb",
        "CC-MAIN-2024-10",
        split="train",
        streaming=True,
        trust_remote_code=True,
    )  # type: ignore
    # shuffle() returns a NEW dataset. The result must be kept for it to apply.
    dataset = dataset.shuffle(seed=seed)

    # NOTE(review): assumes column_names is resolved for the stream. Confirm
    # this if the source dataset changes.
    eval_dataset = dataset.map(
        _chunk_examples, batched=True, remove_columns=dataset.column_names
    )
    # NOTE: no with_format("torch") is applied. An earlier call discarded its
    # result, so data_collator has always received unformatted examples.

    return DataLoader(
        eval_dataset,  #  type: ignore
        batch_size=batch_size,
        collate_fn=_data_collator,
        pin_memory=True,
    )


def get_image_dataloader(
    batch_size: int, dataset_name: str = "ylecun/mnist"
) -> tuple[DataLoader, DataLoader]:
    """Build shuffled train and eval ``DataLoader``s for a hub image dataset.

    Both splits are formatted with ``with_format("torch")``. Items are returned
    as torch tensors by the ``datasets`` torch formatter.

    NOTE(review): the hub cifar10/cifar100 datasets may name their image column
    ``"img"`` rather than ``"image"``. Confirm before indexing batches by key.

    Args:
        batch_size: Examples per batch for both loaders.
        dataset_name: Hub id passed to ``load_dataset``. Must be one of
            ``"ylecun/mnist"``, ``"cifar10"`` or ``"cifar100"``.

    Returns:
        ``(train_loader, eval_loader)``, built from the ``"train"`` and
        ``"test"`` splits respectively. Both use ``shuffle=True``.

    Raises:
        ValueError: If ``dataset_name`` is not one of the supported names.
    """
    supported = ("ylecun/mnist", "cifar10", "cifar100")
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if dataset_name not in supported:
        raise ValueError(
            f"Unsupported dataset_name {dataset_name!r}; expected one of {supported}"
        )

    def _make_loader(split: str) -> DataLoader:
        """Load one split, format it for torch, and wrap it in a DataLoader."""
        split_dataset: HFDataset = load_dataset(dataset_name, split=split, trust_remote_code=True)  # type: ignore
        # An earlier version called with_transform(...) followed by
        # set_format("torch"). set_format replaces the custom transform, so only
        # torch formatting ever took effect. It is applied directly here.
        split_dataset = split_dataset.with_format("torch")
        # NOTE(review): the eval loader is shuffled too, matching prior behaviour.
        return DataLoader(split_dataset, batch_size=batch_size, shuffle=True)  # type: ignore

    return _make_loader("train"), _make_loader("test")


if __name__ == "__main__":
    # Smoke test: print the input_ids shape of the first 5 training batches.
    from itertools import islice

    dataloader = get_train_dataloader()
    # islice stops after exactly 5 batches. The previous enumerate/break loop
    # pulled a 6th batch from the stream and then threw it away.
    for batch in islice(dataloader, 5):
        print(batch["input_ids"].shape)
