import copy
import re
from abc import ABC
from typing import (
    Generic,
    TypeVar,
    Union,
    Sequence,
    Callable,
    Optional,
    Dict,
    Any,
    Iterable,
    List,
    Set,
    Tuple,
    NamedTuple,
    Mapping,
)

from typing_extensions import Protocol

import warnings
from torch.utils.data.dataset import Dataset

from avalanche.benchmarks.scenarios.generic_definitions import (
    TCLExperience,
    TCLStream,
    ClassificationExperience,
    TCLScenario,
)
from avalanche.benchmarks.scenarios.lazy_dataset_sequence import (
    LazyDatasetSequence,
)
from avalanche.benchmarks.utils import make_classification_dataset
from avalanche.benchmarks.utils.classification_dataset import (
    ClassificationDataset,
)
from avalanche.benchmarks.utils.dataset_utils import manage_advanced_indexing

TGenericCLClassificationScenario = TypeVar(
    "TGenericCLClassificationScenario", bound="GenericCLScenario"
)
TGenericClassificationExperience = TypeVar(
    "TGenericClassificationExperience", bound="GenericClassificationExperience"
)
TGenericScenarioStream = TypeVar(
    "TGenericScenarioStream", bound="ClassificationStream"
)

# Accepted forms for the experience data of a user stream definition: a
# single dataset, a sequence of datasets (one per experience) or, for lazily
# generated streams, a (dataset generator, number of experiences) tuple.
TStreamDataOrigin = Union[
    ClassificationDataset,
    Sequence[ClassificationDataset],
    Tuple[Iterable[ClassificationDataset], int],
]
# Task labels of each experience, each given as an int or a set of ints.
TStreamTaskLabels = Optional[Sequence[Union[int, Set[int]]]]
# The dataset the experiences of a stream were obtained from (may be None).
TOriginDataset = Optional[Dataset]


# The definitions used to accept user stream definitions.
# These definitions allow for a simpler usage, as they don't
# mandate setting the task labels and the origin dataset.
class StreamUserDef(NamedTuple):
    """
    User-facing definition of a stream.

    Only ``exps_data`` is mandatory. Missing task labels and the laziness
    flag are inferred when the definition is checked by
    ``GenericCLScenario._check_and_adapt_user_stream_def``.
    """

    # Experience data: a dataset, a sequence of datasets or, for lazy
    # streams, a (generator, stream length) tuple.
    exps_data: TStreamDataOrigin
    # Task labels of each experience (int or set of ints). If None, they are
    # extracted from the datasets (not allowed for lazy streams).
    exps_task_labels: TStreamTaskLabels = None
    # The dataset from which the experiences were obtained. May be None.
    origin_dataset: TOriginDataset = None
    # Explicit laziness flag. If None, the stream is considered lazy only when
    # ``exps_data`` is a tuple.
    is_lazy: Optional[bool] = None


# Plain-tuple forms accepted for a user stream definition. Elements are
# positional, in the same order as the StreamUserDef fields; the optional
# fourth element (is_lazy) may also be None, meaning "infer laziness".
TStreamUserDef = Union[
    Tuple[TStreamDataOrigin, TStreamTaskLabels, TOriginDataset, Optional[bool]],
    Tuple[TStreamDataOrigin, TStreamTaskLabels, TOriginDataset],
    Tuple[TStreamDataOrigin, TStreamTaskLabels],
    Tuple[TStreamDataOrigin],
]

# Maps each stream name to its user definition, given either as a
# StreamUserDef or as one of the plain-tuple forms of TStreamUserDef.
TStreamsUserDict = Dict[str, Union[StreamUserDef, TStreamUserDef]]


# The definitions used to store (checked and normalized) stream definitions
class StreamDef(NamedTuple):
    """
    Normalized definition of a stream, as produced by
    ``GenericCLScenario._check_and_adapt_user_stream_def``.
    """

    # Sequence of the datasets of each experience (possibly generated lazily).
    exps_data: LazyDatasetSequence
    # Task labels of each experience, always normalized to a set of ints.
    exps_task_labels: Sequence[Set[int]]
    # The dataset from which the experiences were obtained. May be None.
    origin_dataset: TOriginDataset
    # True if the stream experiences are generated lazily.
    is_lazy: bool


# Maps each stream name to its checked and normalized definition.
TStreamsDict = Dict[str, StreamDef]

# A valid stream name starts with a letter, followed by any number of
# letters, digits or underscores.
STREAM_NAME_REGEX = re.compile(r"^[A-Za-z][A-Za-z_\d]*$")


class GenericCLScenario(Generic[TCLExperience]):
    """
    Base implementation of a Continual Learning benchmark instance.
    A Continual Learning benchmark instance is defined by a set of streams of
    experiences (batches or tasks depending on the terminology). Each experience
    contains the training (or test, or validation, ...) data that becomes
    available at a certain time instant.

    Experiences are usually defined in children classes, with this class serving
    as the more general implementation. This class handles the most simple type
    of assignment: each stream is defined as a list of experiences, each
    experience is defined by a dataset.

    Defining the "train" and "test" streams is mandatory. This class supports
    custom streams as well. Custom streams can be accessed by using the
    `streamname_stream` field of the created instance.

    The name of custom streams can only contain letters, numbers or the "_"
    character and must not start with a number.
    """

    def __init__(
        self: TGenericCLClassificationScenario,
        *,
        stream_definitions: TStreamsUserDict,
        complete_test_set_only: bool = False,
        experience_factory: Optional[
            Callable[["ClassificationStream", int], TCLExperience]
        ] = None,
    ):
        """
        Creates an instance of a Continual Learning benchmark instance.

        The benchmark instance is defined by a stream definition dictionary,
        which describes the content of each stream. The "train" and "test"
        stream are mandatory. Any other custom stream can be added.

        There is no constraint on the amount of experiences in each stream
        (excluding the case in which `complete_test_set_only` is set).

        :param stream_definitions: The stream definitions dictionary. Must
            be a dictionary where the key is the stream name and the value
            is the definition of that stream. "train" and "test" streams are
            mandatory. This class supports custom streams as well. The name of
            custom streams can only contain letters, numbers and the "_"
            character and must not start with a number. Streams can be defined
            in two ways: static and lazy. In the static case, the
            stream must be a tuple containing 1, 2, 3 or 4 elements:
            - The first element must be a list containing the datasets
            describing each experience. Datasets must be instances of
            :class:`ClassificationDataset`. A single dataset is also accepted
            and is treated as a stream made of one experience.
            - The second element is optional and must be a list containing the
            task labels of each experience (as an int or a set of ints).
            If the stream definition tuple contains only one element (the list
            of datasets), then the task labels for each experience will be
            obtained by inspecting the content of the datasets.
            - The third element is optional and must be a reference to the
            originating dataset (if applicable). For instance, for SplitMNIST
            this may be a reference to the whole MNIST dataset. If the stream
            definition tuple contains less than 3 elements, then the reference
            to the original dataset will be set to None.
            - The fourth element is optional and explicitly flags the stream
            as lazy (True) or static (False). If missing or None, the stream
            is considered lazy only if the first element is a tuple.
            In the lazy case, the stream must be defined as a tuple with 2
            (or more) elements:
            - The first element must be a tuple containing the dataset generator
            (one for each experience) and the number of experiences in that
            stream. When the stream is explicitly flagged as lazy, a
            :class:`LazyDatasetSequence` is accepted as well.
            - The second element must be a list containing the task labels of
            each experience (as an int or a set of ints).
        :param complete_test_set_only: If True, the test stream will contain
            a single experience containing the complete test set. This also
            means that the definition for the test stream must contain the
            definition for a single experience.
        :param experience_factory: If not None, a callable that, given the
            stream and the experience ID, returns an experience
            instance. This parameter is usually used in subclasses (when
            invoking the super constructor) to specialize the experience class.
            Defaults to None, which means that the
            :class:`GenericClassificationExperience` constructor will be used.
        :raises ValueError: If the stream definitions are malformed or if
            ``complete_test_set_only`` is True and the test stream contains
            more than one experience.
        """

        self.stream_definitions = GenericCLScenario._check_stream_definitions(
            stream_definitions
        )
        """
        A structure containing the definition of the streams.
        """

        self.original_train_dataset: Optional[
            Dataset
        ] = self.stream_definitions["train"].origin_dataset
        """ The original training set. May be None. """

        self.original_test_dataset: Optional[Dataset] = self.stream_definitions[
            "test"
        ].origin_dataset
        """ The original test set. May be None. """

        self.train_stream: ClassificationStream[
            TCLExperience, TGenericCLClassificationScenario
        ] = ClassificationStream("train", self)
        """
        The stream used to obtain the training experiences.
        This stream can be sliced in order to obtain a subset of this stream.
        """

        self.test_stream: ClassificationStream[
            TCLExperience, TGenericCLClassificationScenario
        ] = ClassificationStream("test", self)
        """
        The stream used to obtain the test experiences. This stream can be
        sliced in order to obtain a subset of this stream.

        Beware that, in certain scenarios, this stream may contain a single
        element. Check the ``complete_test_set_only`` field for more details.
        """

        self.complete_test_set_only: bool = bool(complete_test_set_only)
        """
        If True, only the complete test set will be returned from experience
        instances.

        This flag is usually set to True in scenarios where having one separate
        test set aligned to each training experience is impossible or doesn't
        make sense from a semantic point of view.
        """

        if self.complete_test_set_only and (
            len(self.stream_definitions["test"].exps_data) > 1
        ):
            raise ValueError(
                "complete_test_set_only is True, but the test stream"
                " contains more than one experience"
            )

        if experience_factory is None:
            experience_factory = GenericClassificationExperience

        self.experience_factory: Callable[
            [TGenericScenarioStream, int], TCLExperience
        ] = experience_factory

        # Create the original_<stream_name>_dataset fields for other streams
        self._make_original_dataset_fields()

        # Create the <stream_name>_stream fields for other streams
        self._make_stream_fields()

    @property
    def streams(
        self,
    ) -> Dict[
        str,
        "ClassificationStream["
        "TCLExperience, TGenericCLClassificationScenario]",
    ]:
        """
        A dictionary mapping each stream name to the corresponding
        ``<stream_name>_stream`` field of this benchmark.
        """
        return {
            stream_name: getattr(self, f"{stream_name}_stream")
            for stream_name in self.stream_definitions.keys()
        }

    @property
    def n_experiences(self) -> int:
        """The number of incremental training experiences contained
        in the train stream."""
        return len(self.stream_definitions["train"].exps_data)

    @property
    def task_labels(self) -> Sequence[List[int]]:
        """The task labels of each training experience (one list each)."""
        train_def = self.stream_definitions["train"]
        return [
            list(exp_t_labels) for exp_t_labels in train_def.exps_task_labels
        ]

    def get_reproducibility_data(self) -> Dict[str, Any]:
        """
        Gets the data needed to reproduce this experiment.

        This data can be stored using the pickle module or some other mechanism.
        It can then be loaded by passing it as the ``reproducibility_data``
        parameter in the constructor.

        Child classes should create their own reproducibility dictionary.
        This means that the implementation found in :class:`GenericCLScenario`
        will return an empty dictionary, which is meaningless.

        In order to obtain the same benchmark instance, the reproducibility
        data must be passed to the constructor along with the exact same
        input datasets.

        :return: A dictionary containing the data needed to reproduce the
            experiment.
        """

        return dict()

    @property
    def classes_in_experience(
        self,
    ) -> Mapping[str, Sequence[Optional[Set[int]]]]:
        """
        A dictionary mapping each stream (by name) to a list.

        Each element of the list is a set describing the classes included in
        that experience (identified by its index).

        In previous releases this field contained the list of sets for the
        training stream (that is, there was no way to obtain the list for other
        streams). That behavior is deprecated and support for that usage way
        will be removed in the future.
        """

        return LazyStreamClassesInExps(self)

    def get_classes_timeline(
        self, current_experience: int, stream: str = "train"
    ):
        """
        Returns the classes timeline given the ID of an experience.

        Given an experience ID, this method returns the classes in that
        experience, previously seen classes, the cumulative class list and a
        list of classes that will be encountered in next experiences of the
        same stream.

        Beware that by default this will obtain the timeline of an experience
        of the **training** stream. Use the stream parameter to select another
        stream.

        :param current_experience: The reference experience ID.
        :param stream: The stream name.
        :return: A tuple composed of four lists: the first list contains the
            IDs of classes in this experience, the second contains IDs of
            classes seen in previous experiences, the third returns a cumulative
            list of classes (that is, the union of the first two list) while the
            last one returns a list of classes that will be encountered in next
            experiences. Beware that each of these elements can be None when
            the benchmark is initialized by using a lazy generator.
        """
        # Look up the lazy per-experience class sets of the stream only once
        stream_classes = self.classes_in_experience[stream]

        class_set_current_exp = stream_classes[current_experience]
        if class_set_current_exp is not None:
            # May be None in lazy benchmarks
            classes_in_this_exp = list(class_set_current_exp)
        else:
            classes_in_this_exp = None

        class_set_prev_exps = self._union_or_none(
            stream_classes[exp_id] for exp_id in range(0, current_experience)
        )
        if class_set_prev_exps is not None:
            previous_classes = list(class_set_prev_exps)
        else:
            previous_classes = None

        if (
            class_set_current_exp is not None
            and class_set_prev_exps is not None
        ):
            classes_seen_so_far = list(
                class_set_current_exp.union(class_set_prev_exps)
            )
        else:
            classes_seen_so_far = None

        stream_n_exps = len(stream_classes)
        class_set_future_exps = self._union_or_none(
            stream_classes[exp_id]
            for exp_id in range(current_experience + 1, stream_n_exps)
        )
        if class_set_future_exps is not None:
            future_classes = list(class_set_future_exps)
        else:
            future_classes = None

        return (
            classes_in_this_exp,
            previous_classes,
            classes_seen_so_far,
            future_classes,
        )

    @staticmethod
    def _union_or_none(
        class_sets: Iterable[Optional[Set[int]]],
    ) -> Optional[Set[int]]:
        """
        Computes the union of the given class sets.

        Iteration stops at the first None element (the classes of an
        experience may be unknown in lazy benchmarks), in which case None is
        returned.

        :param class_sets: An iterable of class sets (or None).
        :return: The union of the sets, or None if any of them is None.
        """
        union = set()
        for class_set in class_sets:
            if class_set is None:
                return None
            union.update(class_set)
        return union

    def _make_original_dataset_fields(self):
        # "train" and "test" already have dedicated fields set in __init__
        for stream_name, stream_def in self.stream_definitions.items():
            if stream_name in ["train", "test"]:
                continue

            setattr(
                self,
                f"original_{stream_name}_dataset",
                stream_def.origin_dataset,
            )

    def _make_stream_fields(self):
        # "train" and "test" already have dedicated fields set in __init__
        for stream_name in self.stream_definitions.keys():
            if stream_name in ["train", "test"]:
                continue

            stream_obj = ClassificationStream(stream_name, self)
            setattr(self, f"{stream_name}_stream", stream_obj)

    @staticmethod
    def _check_stream_definitions(
        stream_definitions: TStreamsUserDict,
    ) -> TStreamsDict:
        """
        A function used to check the input stream definitions.

        This function returns the adapted definition in which the
        missing optional fields are filled. If the input definition doesn't
        follow the expected structure, a `ValueError` will be raised.

        :param stream_definitions: The input stream definitions.
        :return: The checked and adapted stream definitions.
        :raises ValueError: If the "train" or "test" stream is missing, if a
            stream name is invalid or if a stream definition is malformed.
        """
        if "train" not in stream_definitions:
            raise ValueError("No train stream found!")

        if "test" not in stream_definitions:
            raise ValueError("No test stream found!")

        streams_defs = dict()
        for stream_name, stream_def in stream_definitions.items():
            GenericCLScenario._check_stream_name(stream_name)
            streams_defs[
                stream_name
            ] = GenericCLScenario._check_and_adapt_user_stream_def(
                stream_def, stream_name
            )

        return streams_defs

    @staticmethod
    def _check_stream_name(stream_name: Any):
        """
        Checks that the stream name is a str matching STREAM_NAME_REGEX.

        :raises ValueError: If the name has the wrong type or format.
        """
        if not isinstance(stream_name, str):
            raise ValueError('Invalid type for stream name. Must be a "str"')

        if STREAM_NAME_REGEX.fullmatch(stream_name) is None:
            raise ValueError(f"Invalid name for stream {stream_name}")

    @staticmethod
    def _check_and_adapt_user_stream_def(
        stream_def: TStreamUserDef, stream_name: str
    ) -> StreamDef:
        """
        Checks a single user stream definition and converts it to a
        :class:`StreamDef`.

        :param stream_def: A tuple (or :class:`StreamUserDef`) with 1 to 4
            elements: experience data, task labels, origin dataset, is_lazy.
        :param stream_name: The name of the stream, used in error messages.
        :return: The normalized stream definition, whose task labels are
            always sets of ints.
        :raises ValueError: If the definition is malformed.
        """
        exp_data = stream_def[0]
        task_labels = stream_def[1] if len(stream_def) > 1 else None
        origin_dataset = stream_def[2] if len(stream_def) > 2 else None
        is_lazy = stream_def[3] if len(stream_def) > 3 else None

        if is_lazy or (isinstance(exp_data, tuple) and (is_lazy is None)):
            # Creation based on a generator
            if is_lazy:
                # We also check for LazyDatasetSequence, which is sufficient
                # per se (only if is_lazy==True, otherwise is treated as a
                # standard Sequence)
                if not isinstance(exp_data, LazyDatasetSequence) and (
                    not isinstance(exp_data, tuple) or len(exp_data) != 2
                ):
                    raise ValueError(
                        f"The stream {stream_name} was flagged as "
                        f"lazy-generated but its definition is not a "
                        f"2-elements tuple (generator and stream length)."
                    )
            elif len(exp_data) != 2 or not isinstance(exp_data[1], int):
                raise ValueError(
                    f"The stream {stream_name} was detected "
                    f"as lazy-generated but its definition is not a "
                    f"2-elements tuple. If you're trying to define a "
                    f"non-lazily generated stream, don't use a tuple "
                    f"when passing the list of datasets, use a list "
                    f"instead."
                )

            if isinstance(exp_data, LazyDatasetSequence):
                stream_length = len(exp_data)
            else:
                # exp_data[0] must contain the generator
                stream_length = exp_data[1]
            is_lazy = True
        elif isinstance(exp_data, ClassificationDataset):
            # Single element
            exp_data = [exp_data]
            is_lazy = False
            stream_length = 1
        else:
            # Standard def
            stream_length = len(exp_data)
            is_lazy = False

        if not is_lazy:
            for dataset in exp_data:
                if not isinstance(dataset, ClassificationDataset):
                    raise ValueError(
                        "All experience datasets must be instances of"
                        " ClassificationDataset"
                    )

        if task_labels is None:
            if is_lazy:
                raise ValueError(
                    "Task labels must be defined for each experience when "
                    "creating the stream using a generator."
                )

            # Extract task labels from the datasets
            task_labels = [
                set(exp_dataset.targets_task_labels) for exp_dataset in exp_data
            ]
        else:
            # Standardize task labels structure: each element becomes a set
            task_labels = [
                {exp_t_labels}
                if isinstance(exp_t_labels, int)
                else (
                    exp_t_labels
                    if isinstance(exp_t_labels, set)
                    else set(exp_t_labels)
                )
                for exp_t_labels in task_labels
            ]

        if stream_length != len(task_labels):
            # Use stream_length: for a lazy (generator, length) definition,
            # len(exp_data) would be the size of the tuple, not of the stream.
            raise ValueError(
                f"{stream_length} experiences have been defined, but task "
                f"labels for {len(task_labels)} experiences are given."
            )

        if is_lazy:
            if isinstance(exp_data, LazyDatasetSequence):
                lazy_sequence = exp_data
            else:
                lazy_sequence = LazyDatasetSequence(exp_data[0], stream_length)
        else:
            lazy_sequence = LazyDatasetSequence(exp_data, stream_length)
            lazy_sequence.load_all_experiences()

        return StreamDef(lazy_sequence, task_labels, origin_dataset, is_lazy)


class ClassificationScenarioStream(Protocol[TCLScenario, TCLExperience]):
    """
    A scenario stream describes a sequence of incremental experiences.
    Experiences are described as :class:`IExperience` instances. They contain a
    set of patterns which has become available at a particular time instant
    along with any optional, scenario specific, metadata.

    Most scenarios expose two different streams: the training stream and the
    test stream.
    """

    name: str
    """
    The name of the stream.
    """

    benchmark: TCLScenario
    """
    A reference to the scenario this stream belongs to.
    """

    @property
    def scenario(self) -> TCLScenario:
        """This property is DEPRECATED, use self.benchmark instead."""
        warnings.warn(
            "Using self.scenario is deprecated in ScenarioStream. "
            "Consider using self.benchmark instead.",
            stacklevel=2,
        )
        return self.benchmark

    def __getitem__(
        self: TCLStream, experience_idx: Union[int, slice, Iterable[int]]
    ) -> Union[TCLExperience, TCLStream]:
        """
        Gets an experience given its experience index (or a stream slice given
        the experience order).

        :param experience_idx: An int describing the experience index or an
            iterable/slice object describing a slice of this stream.
        :return: The Experience instance associated to the given experience
            index or a sliced stream instance.
        """
        ...

    def __len__(self) -> int:
        """
        Used to get the length of this stream (the amount of experiences).

        :return: The amount of experiences in this stream.
        """
        ...


class ClassificationStream(
    Generic[TCLExperience, TGenericCLClassificationScenario],
    ClassificationScenarioStream[
        TGenericCLClassificationScenario, TCLExperience
    ],
    Sequence[TCLExperience],
):
    """
    A stream of experiences of a :class:`GenericCLScenario`.

    Experiences are not stored in the stream: each one is created on access
    by the ``experience_factory`` of the benchmark. A stream can be sliced, in
    which case ``slice_ids`` lists the IDs (in the original stream) of the
    experiences included in the slice.
    """

    def __init__(
        self: TGenericScenarioStream,
        name: str,
        benchmark: TGenericCLClassificationScenario,
        *,
        slice_ids: Optional[List[int]] = None,
    ):
        self.slice_ids: Optional[List[int]] = slice_ids
        """
        Describes which experiences are contained in the current stream slice.
        Can be None, which means that this object is the original stream. """

        self.name: str = name
        """
        The name of the stream (for instance: "train", "test", "valid", ...).
        """

        self.benchmark = benchmark
        """
        A reference to the benchmark.
        """

    def __len__(self) -> int:
        """
        Gets the number of experiences this stream is made of.

        :return: The number of experiences in this stream.
        """
        if self.slice_ids is None:
            return len(self.benchmark.stream_definitions[self.name].exps_data)
        else:
            return len(self.slice_ids)

    def __getitem__(
        self, exp_idx: Union[int, slice, Iterable[int]]
    ) -> Union[TCLExperience, TCLStream]:
        """
        Gets an experience given its experience index (or a stream slice given
        the experience order).

        :param exp_idx: An int describing the experience index (negative
            values count from the end of the stream) or an iterable/slice
            object describing a slice of this stream.

        :return: The experience instance associated to the given experience
            index or a sliced stream instance.
        :raises IndexError: If the int index is out of bounds.
        """
        if not isinstance(exp_idx, int):
            return self._create_slice(exp_idx)

        stream_len = len(self)
        # Normalize negative indices as standard sequences do, so that the
        # experience factory always receives a non-negative experience ID.
        position = exp_idx + stream_len if exp_idx < 0 else exp_idx
        if not 0 <= position < stream_len:
            raise IndexError(
                "Experience index out of bounds: " + str(int(exp_idx))
            )

        if self.slice_ids is None:
            return self.benchmark.experience_factory(self, position)
        # Map the position in the slice to the ID in the original stream
        return self.benchmark.experience_factory(
            self, self.slice_ids[position]
        )

    def _create_slice(
        self: TGenericScenarioStream,
        exps_slice: Union[int, slice, Iterable[int]],
    ) -> TCLStream:
        """
        Creates a sliced version of this stream.

        In its base version, a shallow copy of this stream is created and
        then its ``slice_ids`` field is adapted.

        :param exps_slice: The slice to use.
        :return: A sliced version of this stream.
        """
        stream_copy = copy.copy(self)
        slice_exps = _get_slice_ids(exps_slice, len(self))

        if self.slice_ids is None:
            stream_copy.slice_ids = slice_exps
        else:
            # Slicing a slice: translate positions into original stream IDs
            stream_copy.slice_ids = [self.slice_ids[x] for x in slice_exps]
        return stream_copy

    def drop_previous_experiences(self, to_exp: int) -> None:
        """
        Drop the reference to experiences up to a certain experience ID
        (inclusive).

        This means that any reference to experiences with ID [0, to_exp] will
        be released. By dropping the reference to previous experiences, the
        memory associated with them can be freed, especially the one occupied by
        the dataset. However, if external references to the experience or the
        dataset still exist, dropping previous experiences at the stream level
        will have little to no impact on the memory usage.

        To make sure that the underlying dataset can be freed, make sure that:
        - No reference to previous datasets or experiences are kept in your
            code;
        - The replay implementation doesn't keep a reference to previous
            datasets (in which case, is better to store a copy of the raw
            tensors instead);
        - The benchmark is being generated using a lazy initializer.

        By dropping previous experiences, those experiences will no longer be
        available in the stream. Trying to access them will result in an
        exception.

        Beware that the operation is applied to the stream definition, which is
        shared by all slices of this stream: IDs refer to the original
        (unsliced) stream.

        :param to_exp: The ID of the last exp to drop (inclusive). Can be a
            negative number, in which case this method doesn't have any effect.
            Can be greater or equal to the stream length, in which case all
            currently loaded experiences will be dropped.
        :return: None
        """
        self.benchmark.stream_definitions[
            self.name
        ].exps_data.drop_previous_experiences(to_exp)


class LazyStreamClassesInExps(Mapping[str, Sequence[Optional[Set[int]]]]):
    """
    Read-only mapping from each stream name to a lazy sequence describing the
    classes found in each experience of that stream.

    For backward compatibility, non-str keys are still accepted: they are
    interpreted (with a deprecation warning) as experience indices of the
    "train" stream.
    """

    def __init__(self, benchmark: GenericCLScenario):
        self._benchmark = benchmark
        # Used by the deprecated classes_in_experience[exp_id] access
        self._default_lcie = LazyClassesInExps(benchmark, stream="train")

    def __len__(self):
        return len(self._benchmark.stream_definitions)

    def __getitem__(self, stream_name_or_exp_id):
        if isinstance(stream_name_or_exp_id, str):
            # Unknown streams must raise KeyError, as in any Mapping: the
            # inherited __contains__ and get() rely on it.
            if stream_name_or_exp_id not in self._benchmark.stream_definitions:
                raise KeyError(stream_name_or_exp_id)
            return LazyClassesInExps(
                self._benchmark, stream=stream_name_or_exp_id
            )

        warnings.warn(
            "Using classes_in_experience[exp_id] is deprecated. "
            "Consider using classes_in_experience[stream_name][exp_id] "
            "instead.",
            stacklevel=2,
        )
        return self._default_lcie[stream_name_or_exp_id]

    def __iter__(self):
        yield from self._benchmark.stream_definitions.keys()


class LazyClassesInExps(Sequence[Optional[Set[int]]]):
    """
    Lazy sequence describing the classes found in each experience of a stream.

    Elements are computed on access from the targets recorded in the stream
    definition. An element is None when those targets are not available.
    """

    def __init__(self, benchmark: GenericCLScenario, stream: str = "train"):
        self._benchmark = benchmark
        self._stream = stream

    def __len__(self):
        stream_obj = self._benchmark.streams[self._stream]
        return len(stream_obj)

    def __getitem__(self, exp_id) -> Set[int]:
        n_exps = len(self)
        return manage_advanced_indexing(
            exp_id,
            self._get_single_exp_classes,
            n_exps,
            LazyClassesInExps._slice_collate,
        )

    def __str__(self):
        rendered = [str(self[position]) for position in range(len(self))]
        return "[{}]".format(", ".join(rendered))

    def _get_single_exp_classes(self, exp_id):
        stream_def = self._benchmark.stream_definitions[self._stream]
        if not stream_def.is_lazy:
            if exp_id not in stream_def.exps_data.targets_field_sequence:
                raise IndexError
        exp_targets = stream_def.exps_data.targets_field_sequence[exp_id]
        if exp_targets is None:
            return None
        return set(exp_targets)

    @staticmethod
    def _slice_collate(*classes_in_exps: Optional[Set[int]]):
        # The whole slice is unknown if the classes of any experience are
        for exp_classes in classes_in_exps:
            if exp_classes is None:
                return None
        return list(map(list, classes_in_exps))


def _get_slice_ids(
    slice_definition: Union[int, slice, Iterable[int]], sliceable_len: int
) -> List[int]:
    # Obtain experiences list from slice object (or any iterable)
    exps_list: List[int]
    if isinstance(slice_definition, slice):
        exps_list = list(range(*slice_definition.indices(sliceable_len)))
    elif isinstance(slice_definition, int):
        exps_list = [slice_definition]
    elif (
        hasattr(slice_definition, "shape")
        and len(getattr(slice_definition, "shape")) == 0
    ):
        exps_list = [int(slice_definition)]
    else:
        exps_list = list(slice_definition)

    # Check experience id(s) boundaries
    if max(exps_list) >= sliceable_len:
        raise IndexError(
            "Experience index out of range: " + str(max(exps_list))
        )

    if min(exps_list) < 0:
        raise IndexError(
            "Experience index out of range: " + str(min(exps_list))
        )

    return exps_list


class AbstractClassificationExperience(
    ClassificationExperience[TGenericCLClassificationScenario, TCLStream], ABC
):
    """
    Base class for a learning experience of a classification benchmark.

    An experience groups the patterns that become available at a given point
    of a stream. What an experience contains is decided by the benchmark that
    creates it. In a New Classes scenario, each experience holds the patterns
    of a subset of the classes of the original training set. In a New
    Instances scenario, an experience holds new patterns of classes that were
    already seen.
    """

    def __init__(
        self,
        origin_stream: TCLStream,
        current_experience: int,
        classes_in_this_exp: Sequence[int],
        previous_classes: Sequence[int],
        classes_seen_so_far: Sequence[int],
        future_classes: Optional[Sequence[int]],
    ):
        """
        Stores the origin stream, the experience index and the classes
        timeline of this experience.

        :param origin_stream: The stream this experience belongs to.
        :param current_experience: The index of this experience, as an int.
        :param classes_in_this_exp: The classes found in this experience.
        :param previous_classes: The classes found in previous experiences.
        :param classes_seen_so_far: The classes of the current and previous
            experiences.
        :param future_classes: The classes of the following experiences.
        """
        #: The stream this experience was obtained from.
        self.origin_stream: TCLStream = origin_stream

        #: Shortcut to the benchmark that owns the origin stream.
        self.benchmark: TCLScenario = origin_stream.benchmark

        #: Position of this experience in its stream (usually 0-based and
        #: incremental).
        self.current_experience: int = current_experience

        #: The classes found in this experience.
        self.classes_in_this_experience: Sequence[int] = classes_in_this_exp

        #: The classes found in previous experiences.
        self.previous_classes: Sequence[int] = previous_classes

        #: The classes of the current and previous experiences.
        self.classes_seen_so_far: Sequence[int] = classes_seen_so_far

        #: The classes of the following experiences.
        self.future_classes: Optional[Sequence[int]] = future_classes

    @property
    def task_label(self) -> int:
        """
        The single task label of this experience.

        This is a shortcut for experiences whose patterns all share one task
        label. Scenarios without task labels usually use a placeholder
        (such as 0), so the value is never None.

        :raises ValueError: If ``task_labels`` does not contain exactly one
            label.
        """
        labels = self.task_labels
        if len(labels) == 1:
            return labels[0]
        raise ValueError(
            "The task_label property can only be accessed "
            "when the experience contains a single task label"
        )


class GenericClassificationExperience(
    AbstractClassificationExperience[
        TGenericCLClassificationScenario,
        ClassificationStream[
            TGenericClassificationExperience, TGenericCLClassificationScenario
        ],
    ]
):
    """
    Learning experience backed by a :class:`GenericCLScenario` instance.

    The dataset and the task labels of the experience are read from the
    stream definitions of the benchmark that owns the origin stream.
    Instances are usually obtained by indexing a benchmark stream.
    """

    def __init__(
        self: TGenericClassificationExperience,
        origin_stream: ClassificationStream[
            TGenericClassificationExperience, TGenericCLClassificationScenario
        ],
        current_experience: int,
    ):
        """
        Builds the experience at index ``current_experience`` of the given
        stream.

        :param origin_stream: The stream this experience belongs to.
        :param current_experience: The index of this experience, as an int.
        """
        owner_benchmark = origin_stream.benchmark
        stream_name = origin_stream.name

        stream_def = owner_benchmark.stream_definitions[stream_name]
        self.dataset: ClassificationDataset = stream_def.exps_data[
            current_experience
        ]

        timeline = owner_benchmark.get_classes_timeline(
            current_experience, stream=stream_name
        )
        this_exp_classes, past_classes, seen_classes, next_classes = timeline

        super().__init__(
            origin_stream,
            current_experience,
            this_exp_classes,
            past_classes,
            seen_classes,
            next_classes,
        )

    def _get_stream_def(self):
        """Returns the benchmark stream definition of the origin stream."""
        stream_name = self.origin_stream.name
        return self.benchmark.stream_definitions[stream_name]

    @property
    def task_labels(self) -> List[int]:
        """The task labels of this experience, as a list."""
        labels_per_exp = self._get_stream_def().exps_task_labels
        return list(labels_per_exp[self.current_experience])


# Public API of this module, as exposed by ``from ... import *``.
# Private helpers (e.g. _get_slice_ids) are intentionally not listed.
__all__ = [
    "StreamUserDef",
    "TStreamUserDef",
    "TStreamsUserDict",
    "StreamDef",
    "TStreamsDict",
    "TGenericCLClassificationScenario",
    "GenericCLScenario",
    "ClassificationStream",
    "AbstractClassificationExperience",
    "GenericClassificationExperience",
]
