from __future__ import annotations

import functools
from pathlib import Path
from typing import Callable, Concatenate, Self, TypeGuard, cast, overload

import datasets.features.features
import numpy as np

import datasets
from datasets import Dataset

# A deferred transformation step: a callable that receives a ``Dataset`` and
# returns a ``Dataset``.
type DatasetMapping = Callable[[Dataset], Dataset]


def map_queue[T: DatasetBuilder, **P](
    func: Callable[Concatenate[T, P], DatasetMapping],
) -> Callable[Concatenate[T, P], T]:
    def wrapper(self: T, *args: P.args, **kwargs: P.kwargs) -> T:
        self._queue(func(self, *args, **kwargs))
        return self

    return wrapper


class DatasetBuilder:
    def __init__(self, dataset: Dataset) -> None:
        self.__dataset = dataset
        self.__map_queue: list[DatasetMapping] = []

    def _queue(self, func: DatasetMapping) -> None:
        self.__map_queue.append(func)

    @classmethod
    def load(
        cls,
        dataset_path: str | Path,
        keep_in_memory: bool | None = None,
        storage_options: dict | None = None,
        **options,
    ):
        dataset = Dataset.load_from_disk(
            dataset_path,
            keep_in_memory=keep_in_memory,
            storage_options=storage_options,
        )
        return cls(dataset, **options)

    @map_queue
    def filter(
        self,
        function: Callable | None = None,
        with_indices: bool = False,
        with_rank: bool = False,
        input_columns: str | list[str] | None = None,
        batched: bool = False,
        batch_size: int | None = 1000,
        keep_in_memory: bool = False,
        load_from_cache_file: bool | None = None,
        cache_file_name: str | None = None,
        writer_batch_size: int | None = 1000,
        fn_kwargs: dict | None = None,
        num_proc: int | None = None,
        suffix_template: str = "_{rank:05d}_of_{num_proc:05d}",
        new_fingerprint: str | None = None,
        desc: str | None = None,
    ) -> DatasetMapping:
        return lambda dataset: dataset.filter(
            function=function,
            with_indices=with_indices,
            with_rank=with_rank,
            input_columns=input_columns,
            batched=batched,
            batch_size=batch_size,
            keep_in_memory=keep_in_memory,
            load_from_cache_file=load_from_cache_file,
            cache_file_name=cache_file_name,
            writer_batch_size=writer_batch_size,
            fn_kwargs=fn_kwargs,
            num_proc=num_proc,
            suffix_template=suffix_template,
            new_fingerprint=new_fingerprint,
            desc=desc,
        )

    @overload
    def doif[T: DatasetBuilder](
        self,
        condition: Callable[[Self], TypeGuard[T]],
        /,
        function: Callable[[T], T],
    ) -> Self: ...
    @overload
    def doif[T: DatasetBuilder](
        self,
        condition: Callable[[Self], bool],
        /,
        function: Callable[[Self], Self],
    ) -> Self: ...
    def doif[T: DatasetBuilder](
        self,
        condition: Callable[[Self], bool] | Callable[[Self], TypeGuard[T]],
        /,
        function: Callable[[Self], Self] | Callable[[T], T],
    ) -> Self:
        return function(self) if condition(self) else self  # type: ignore

    @map_queue
    def map(
        self,
        function: Callable | None = None,
        with_indices: bool = False,
        with_rank: bool = False,
        input_columns: str | list[str] | None = None,
        batched: bool = False,
        batch_size: int | None = 1000,
        drop_last_batch: bool = False,
        remove_columns: str | list[str] | None = None,
        keep_in_memory: bool = False,
        load_from_cache_file: bool | None = None,
        cache_file_name: str | None = None,
        writer_batch_size: int | None = 1000,
        features: datasets.Features | None = None,
        disable_nullable: bool = False,
        fn_kwargs: dict | None = None,
        num_proc: int | None = None,
        suffix_template: str = "_{rank:05d}_of_{num_proc:05d}",
        new_fingerprint: str | None = None,
        desc: str | None = None,
    ) -> DatasetMapping:
        return lambda dataset: dataset.map(
            function=function,
            with_indices=with_indices,
            with_rank=with_rank,
            input_columns=input_columns,
            batched=batched,
            batch_size=batch_size,
            drop_last_batch=drop_last_batch,
            remove_columns=remove_columns,
            keep_in_memory=keep_in_memory,
            load_from_cache_file=load_from_cache_file,
            cache_file_name=cache_file_name,
            writer_batch_size=writer_batch_size,
            features=features,
            disable_nullable=disable_nullable,
            fn_kwargs=fn_kwargs,
            num_proc=num_proc,
            suffix_template=suffix_template,
            new_fingerprint=new_fingerprint,
            desc=desc,
        )

    @map_queue
    def mapif(
        self,
        condition: Callable[[Self], bool],
        /,
        function: Callable | None = None,
        with_indices: bool = False,
        with_rank: bool = False,
        input_columns: str | list[str] | None = None,
        batched: bool = False,
        batch_size: int | None = 1000,
        drop_last_batch: bool = False,
        remove_columns: str | list[str] | None = None,
        keep_in_memory: bool = False,
        load_from_cache_file: bool | None = None,
        cache_file_name: str | None = None,
        writer_batch_size: int | None = 1000,
        features: datasets.Features | None = None,
        disable_nullable: bool = False,
        fn_kwargs: dict | None = None,
        num_proc: int | None = None,
        suffix_template: str = "_{rank:05d}_of_{num_proc:05d}",
        new_fingerprint: str | None = None,
        desc: str | None = None,
    ) -> DatasetMapping:
        return lambda dataset: (
            dataset.map(
                function=function,
                with_indices=with_indices,
                with_rank=with_rank,
                input_columns=input_columns,
                batched=batched,
                batch_size=batch_size,
                drop_last_batch=drop_last_batch,
                remove_columns=remove_columns,
                keep_in_memory=keep_in_memory,
                load_from_cache_file=load_from_cache_file,
                cache_file_name=cache_file_name,
                writer_batch_size=writer_batch_size,
                features=features,
                disable_nullable=disable_nullable,
                fn_kwargs=fn_kwargs,
                num_proc=num_proc,
                suffix_template=suffix_template,
                new_fingerprint=new_fingerprint,
                desc=desc,
            )
            if condition(self)
            else dataset
        )

    @map_queue
    def merge(self, other: Self) -> DatasetMapping:
        return lambda dataset: Dataset.from_dict(
            {
                col: dataset[col] + other.unwrap()[col]
                for col in self.__dataset.column_names
            }
        )

    @map_queue
    def shuffle(
        self,
        seed: int | None = None,
        generator: np.random.Generator | None = None,
        keep_in_memory: bool = False,
        load_from_cache_file: bool | None = None,
        indices_cache_file_name: str | None = None,
        writer_batch_size: int | None = 1000,
        new_fingerprint: str | None = None,
    ) -> DatasetMapping:
        return lambda dataset: dataset.shuffle(
            seed=seed,
            generator=generator,
            keep_in_memory=keep_in_memory,
            load_from_cache_file=load_from_cache_file,
            indices_cache_file_name=indices_cache_file_name,
            writer_batch_size=writer_batch_size,
            new_fingerprint=new_fingerprint,
        )

    def __add__(self, other: Self) -> Self:
        return self.merge(other)

    def __iadd__(self, other: Self) -> Self:
        return self.merge(other)

    @map_queue
    def flatten(
        self, new_fingerprint: str | None = None, max_depth: int = 16
    ) -> DatasetMapping:
        return lambda dataset: dataset.flatten(
            new_fingerprint=new_fingerprint, max_depth=max_depth
        )

    @map_queue
    def rename_column(
        self,
        original_column_name: str,
        new_column_name: str,
        new_fingerprint: str | None = None,
    ) -> DatasetMapping:
        return lambda dataset: dataset.rename_column(
            original_column_name,
            new_column_name,
            new_fingerprint=new_fingerprint,
        )

    @map_queue
    def remove_columns(
        self,
        column_names: str | list[str],
        new_fingerprint: str | None = None,
    ) -> DatasetMapping:
        return lambda dataset: dataset.remove_columns(
            column_names, new_fingerprint=new_fingerprint
        )

    @map_queue
    def cast(
        self,
        features: datasets.Features,
        batch_size: int | None = 1000,
        keep_in_memory: bool = False,
        load_from_cache_file: bool | None = None,
        cache_file_name: str | None = None,
        writer_batch_size: int | None = 1000,
        num_proc: int | None = None,
    ) -> DatasetMapping:
        return lambda dataset: dataset.cast(
            features,
            batch_size=batch_size,
            keep_in_memory=keep_in_memory,
            load_from_cache_file=load_from_cache_file,
            cache_file_name=cache_file_name,
            writer_batch_size=writer_batch_size,
            num_proc=num_proc,
        )

    @map_queue
    def cast_column(
        self,
        column: str,
        feature: datasets.features.features.FeatureType,
        new_fingerprint: str | None = None,
    ) -> DatasetMapping:
        return lambda dataset: dataset.cast_column(
            column, feature, new_fingerprint=new_fingerprint
        )

    @map_queue
    def take(self, n: int) -> DatasetMapping:
        return lambda dataset: dataset.take(n)

    def unwrap(
        self,
        type: str | None = None,
        columns: list | None = None,
        output_all_columns: bool = False,
        **format_kwargs,
    ) -> Dataset:
        dataset = self.__dataset
        while self.__map_queue:
            func = self.__map_queue.pop(0)
            dataset = func(dataset)

        dataset.set_format(
            type=type,
            columns=columns,
            output_all_columns=output_all_columns,
            **format_kwargs,
        )
        return dataset

    @map_queue
    def summary(self):
        def wrapper(dataset: Dataset) -> Dataset:
            print(f"Queued operations: {len(self.__map_queue)}")
            print(f"Number of rows: {len(dataset)}")
            print(f"Columns: {dataset.column_names}")
            print(f"Features: {dataset.features}")
            return dataset

        return wrapper
