"""Download and preprocess the Python Edu dataset from the Hugging Face Hub.

This script downloads the Python Edu dataset from the Hugging Face Hub, processes the
dataset by downloading the contents of each blob ID from an S3 bucket, and saves the
processed dataset to disk.
"""

import argparse
import gzip
import operator
from logging import INFO
from typing import Any

import boto3
from botocore import UNSIGNED
from botocore.config import Config
from botocore.exceptions import ClientError
from datasets import load_dataset
from flwr.common import log


def download_contents(blob_id: str) -> dict[str, Any]:
    """Download and decompress the content of a blob from an S3 bucket.

    The object stored under ``content/<blob_id>`` is fetched with the module-level
    ``s3`` client from the bucket named by the module-level ``bucket_name`` (both
    are assigned in the ``__main__`` section of this script), decompressed with
    gzip, and decoded as UTF-8. Bytes that are not valid UTF-8 are dropped.

    If the key does not exist, an informational message is logged and a result
    with an empty text and ``download_success`` set to False is returned.

    Parameters
    ----------
    blob_id : str
        The ID of the blob to download from the S3 bucket.

    Returns
    -------
    dict[str, Any]
        A dictionary containing the following keys:
        - "text" (str): The decompressed content of the blob, or an empty string
          if the blob was not found.
        - "download_success" (bool): A flag indicating whether the download was
          successful.

    Example
    -------
    >>> content = download_contents("example_blob_id")
    >>> print(content["text"])

    Raises
    ------
    ClientError
        If there is an error in retrieving the blob from the S3 bucket, other than the
        blob not being found. Exceptions of any other type (for example from gzip
        decompression) are not handled here and propagate to the caller.

    """
    key = f"content/{blob_id}"
    try:
        obj = s3.get_object(Bucket=bucket_name, Key=key)
        body = obj["Body"]
        try:
            with gzip.GzipFile(fileobj=body) as fin:
                content = fin.read().decode("utf-8", errors="ignore")
        finally:
            # GzipFile does not close a file object passed via ``fileobj``, so the
            # response stream must be closed explicitly, including when reading or
            # decompression fails part-way through.
            body.close()
    except ClientError as e:
        if e.response["Error"]["Code"] == "NoSuchKey":
            log(INFO, f"File not found: {key}")
            return {"text": "", "download_success": False}
        raise
    else:
        return {"text": content, "download_success": True}


def main(output_dir: str, num_proc: int) -> None:
    """Build the text-only Python Edu dataset and write it to ``output_dir``.

    The ``python-edu`` configuration of ``HuggingFaceTB/smollm-corpus`` (``train``
    split) is loaded from the Hugging Face Hub, every ``blob_id`` is passed to
    ``download_contents``, rows whose download did not succeed are discarded, all
    columns except ``text`` are removed, and the result is saved to disk.

    Parameters
    ----------
    output_dir : str
        The directory where the processed dataset will be saved.
    num_proc : int
        The number of processes used for loading the dataset and for the
        download step.

    Example
    -------
    >>> main("/nfs-share-old/datasets_repo/python_edu", 192)

    """
    # Stage 1: fetch the dataset index (blob IDs and metadata) from the Hub.
    index = load_dataset(
        "HuggingFaceTB/smollm-corpus",
        "python-edu",
        split="train",
        num_proc=num_proc,
    )

    # Stage 2: download the file contents for each blob ID.
    fetched = index.map(download_contents, input_columns="blob_id", num_proc=num_proc)  # type: ignore[reportCallIssue]

    # Stage 3: keep only the rows whose download succeeded.
    succeeded = fetched.filter(operator.itemgetter("download_success"))

    # Stage 4: drop every column except `text`.
    present = succeeded.column_names
    assert present is not None
    assert "text" in present
    text_only = succeeded.remove_columns(
        column_names=[name for name in present if name != "text"],
    )

    # Stage 5: persist the result.
    text_only.save_to_disk(output_dir)


if __name__ == "__main__":
    # Command-line interface: each entry is (flag, keyword arguments for add_argument).
    cli_options = (
        (
            "--output_dir",
            {
                "type": str,
                "required": True,
                "help": "The directory where the processed dataset will be saved.",
            },
        ),
        (
            "--num_proc",
            {
                "type": int,
                "default": 1,
                "help": "The number of processes to use for data processing.",
            },
        ),
    )
    parser = argparse.ArgumentParser(
        description="Download and preprocess the Python Edu dataset from the HF Hub.",
    )
    for flag, spec in cli_options:
        parser.add_argument(flag, **spec)
    args = parser.parse_args()

    # Module-level S3 client and bucket name; `download_contents` reads both as
    # globals. Requests are sent unsigned (no credentials).
    unsigned_config = Config(signature_version=UNSIGNED)
    s3 = boto3.client("s3", region_name="us-west-2", config=unsigned_config)
    bucket_name = "softwareheritage"

    main(args.output_dir, args.num_proc)
