# Copyright 2024
# [ANONYMIZED_INSTITUTION],
# [ANONYMIZED_FACULTY],
# [ANONYMIZED_DEPARTMENT]
#
# Authors:
# AUTHOR_1 (author1@example.com)
# AUTHOR_2 (author2@example.com)
#
# Code generation tools and workflows:
# First versions of this code were potentially generated
# with the help of AI writing assistants including
# GitHub Copilot, ChatGPT, Microsoft Copilot, Google Gemini.
# Afterwards, the generated segments were manually reviewed and edited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Helper function for truncating a dataset to a specified number of samples."""

import logging

import datasets

# Fallback logger for callers that do not pass their own.
default_logger: logging.Logger = logging.getLogger(name=__name__)


def truncate_dataset_with_maximum_the_actual_number_of_samples(
    dataset: datasets.Dataset,
    number_of_samples: int,
    logger: logging.Logger = default_logger,
) -> datasets.Dataset:
    """Limit a dataset to its first `number_of_samples` entries.

    Args:
        dataset: The dataset to shorten.
        number_of_samples: How many leading samples to keep.
            Pass -1 to keep everything. A positive value larger than
            `len(dataset)` is capped at `len(dataset)`.
        logger: Receives one warning and one info message whenever the
            requested count had to be capped.

    Returns:
        `dataset` itself when `number_of_samples` is -1; otherwise the result of
        `dataset.select` over the indices `range(actual_number_of_samples)`.

    Raises:
        ValueError: If `number_of_samples` is neither -1 nor greater than 0.
    """
    # -1 is the "no truncation" marker: hand the input back untouched.
    if number_of_samples == -1:
        return dataset

    # Zero and every remaining negative value are rejected.
    if not number_of_samples > 0:
        error_message: str = f"Expected {number_of_samples = } to be -1 or a positive integer"
        raise ValueError(
            error_message,
        )

    # Never ask for more samples than the dataset actually holds.
    actual_number_of_samples: int = min(
        number_of_samples,
        len(dataset),
    )

    if actual_number_of_samples < number_of_samples:
        # The request was capped; report both the mismatch and the fallback.
        logger.warning(
            msg=f"Requested {number_of_samples = } samples, "  # noqa: G004 - low overhead
            f"but dataset only has {len(dataset) = } samples.",
        )
        logger.info(
            msg=f"Using {actual_number_of_samples = } samples instead.",  # noqa: G004 - low overhead
        )

    # Keep only the leading samples.
    return dataset.select(
        indices=range(actual_number_of_samples),
    )
