import argparse
import random
from collections import defaultdict

from datasets import load_dataset

import os
from pathlib import Path
from datasets import ClassLabel


# Seed for the subsampling RNG so repeated runs select the same rows.
RNG_SEED = 42

# Hugging Face home directory. ``HF_HOME`` wins when set; otherwise fall back
# to the library's documented default (``$XDG_CACHE_HOME/huggingface`` or
# ``~/.cache/huggingface``) instead of crashing on ``Path(None)``.
_HF_HOME = os.getenv("HF_HOME") or os.path.join(
    os.getenv("XDG_CACHE_HOME") or os.path.join("~", ".cache"),
    "huggingface",
)

# Hub cache directory that holds the ``datasets--<org>--<name>`` folders.
ROOT = Path(_HF_HOME).expanduser() / "hub"

# Hub organisation that generated subsets are pushed to.
ORG = "SmolVEncoder"

def _collect_dict_first_value_in_list(dict, keys):
    """Return the value stored under the first of *keys* present in *dict*.

    Keys are tried in the order given. A key counts as present even when its
    value is falsy or ``None``. ``None`` is returned when no key is present.
    """
    matches = (dict[candidate] for candidate in keys if candidate in dict)
    return next(matches, None)

def get_txt_classnames(ds_name: str, root: Path = ROOT) -> "ClassLabel | None":
    """Build a ClassLabel from the first ``classnames.txt`` in the dataset cache.

    The dataset's cache folder is ``root / "datasets--<ds_name>"`` with every
    ``/`` in *ds_name* replaced by ``--``. The folder is walked top-down and
    the first directory containing a ``classnames.txt`` wins.

    Args:
        ds_name: Dataset identifier, e.g. ``"org/name"``.
        root: Directory that contains the ``datasets--*`` cache folders.

    Returns:
        A ``ClassLabel`` initialised from the file, or ``None`` when no
        ``classnames.txt`` exists under the dataset folder (including when
        the folder itself is missing, since ``os.walk`` then yields nothing).
    """
    dataset_path = root / f"datasets--{ds_name.replace('/', '--')}"
    print(f"Searching for classnames.txt in {dataset_path}")
    # Use distinct loop names so the ``root`` parameter is not shadowed.
    for dirpath, _dirnames, filenames in os.walk(dataset_path):
        if "classnames.txt" in filenames:
            return ClassLabel(names_file=str(Path(dirpath) / "classnames.txt"))
    return None

def get_cls_map(dataset, dataset_name, root: Path = ROOT) -> ClassLabel:
    """Find a label mapping for *dataset*.

    The dataset's own ``cls`` feature is used when it is a ``ClassLabel``.
    Otherwise a ``classnames.txt`` lookup is attempted via
    ``get_txt_classnames(dataset_name, root)``. Returns ``None`` when neither
    source provides a mapping.
    """
    features = dataset.features
    embedded = features["cls"] if "cls" in features else None
    if isinstance(embedded, ClassLabel):
        print(f"Using existing ClassLabel feature from dataset {dataset.info.dataset_name}")
        return embedded

    # Fall back to a class-name file in the dataset's cache directory.
    from_file = get_txt_classnames(dataset_name, root)
    if from_file is None:
        print(f"No class label mapping found for {dataset.info.dataset_name}")
        return None

    print(f"Found class label mapping for {dataset.info.dataset_name}")
    return from_file

def subsample(dataset, num_samples: int, balanced: bool = False):
    """Return a shuffled subset of *dataset* with about *num_samples* rows.

    Sampling is seeded with ``RNG_SEED``, so the same inputs always select
    the same rows.

    Parameters
    ----------
    dataset : datasets.Dataset
        The source dataset. Must contain a ``cls`` column with class labels.
    num_samples : int
        Desired total number of examples in the returned subset.
    balanced : bool, optional
        If ``True``, draw ``num_samples // n_classes`` examples from every
        class, each class capped at its own size, then draw the division
        remainder at random from the rows not yet chosen. A class with fewer
        rows than the per-class quota is NOT compensated for, so the result
        can be smaller than *num_samples*. If ``False``, sample
        proportionally to the class distribution in *dataset* and correct
        rounding error so that exactly *num_samples* rows are returned.

    Returns
    -------
    datasets.Dataset
        ``dataset.select(...)`` over the chosen row indices, in shuffled
        order.

    Raises
    ------
    ValueError
        If *num_samples* is not positive, if the dataset has no rows, if a
        proportional request exceeds the dataset size, or if a balanced
        request cannot fill its remainder from the unused rows.
    """
    if num_samples <= 0:
        raise ValueError("num_samples must be a positive integer")

    # Build mapping {class_label: [row_indices]}
    label_to_indices = defaultdict(list)
    for i, label in enumerate(dataset["cls"]):
        label_to_indices[label].append(i)

    if not label_to_indices:
        # Previously a ZeroDivisionError (balanced) or an opaque
        # "Sample larger than population" error (proportional).
        raise ValueError("Cannot subsample an empty dataset.")

    classes = list(label_to_indices)
    rng = random.Random(RNG_SEED)

    selected_indices = []

    if balanced:
        # Equal quota per class; each class is capped by its own size.
        per_class, remainder = divmod(num_samples, len(classes))

        for lbl in classes:
            available = label_to_indices[lbl]
            take = min(per_class, len(available))
            selected_indices.extend(rng.sample(available, take))

        if remainder:
            # Pool the still-unused rows (same order as before, so the seeded
            # draw is unchanged) and sample the remainder from them. The set
            # keeps the membership test O(1) instead of scanning a list.
            already_chosen = set(selected_indices)
            pool = [idx for lbl in classes for idx in label_to_indices[lbl]
                    if idx not in already_chosen]
            if len(pool) < remainder:
                raise ValueError(
                    "Not enough samples to satisfy the balanced request. Try a smaller num_samples.")
            selected_indices.extend(rng.sample(pool, remainder))
    else:
        # Proportional sampling (stratified by class).
        total = len(dataset)
        if num_samples > total:
            # This case always failed later inside rng.sample(); fail early
            # with a message that names the actual problem.
            raise ValueError(
                f"Requested {num_samples} samples but the dataset only has {total} rows.")

        # First round: round() may over- or undershoot; corrected below.
        for idxs in label_to_indices.values():
            prop = len(idxs) / total
            take = min(int(round(prop * num_samples)), len(idxs))
            selected_indices.extend(rng.sample(idxs, take))

        # Adjust for rounding error.
        if len(selected_indices) < num_samples:
            deficit = num_samples - len(selected_indices)
            already_chosen = set(selected_indices)
            pool = [i for i in range(total) if i not in already_chosen]
            selected_indices.extend(rng.sample(pool, deficit))
        elif len(selected_indices) > num_samples:
            selected_indices = rng.sample(selected_indices, num_samples)

    rng.shuffle(selected_indices)
    return dataset.select(selected_indices)

def format_dataset_example(example, cls_map = None):
    """
    Normalise one dataset row to the ``{"image", "label"}`` layout.

    Args:
        example (dict): The row to convert. The image is taken from the first
            present key among "jpg", "webp", "png" and "image"; the class
            label from the first present key among "cls" and "label".
        cls_map (object, optional): Label mapping. When it exposes an
            ``int2str`` method and the label is an ``int``, the label is
            converted with ``cls_map.int2str``; otherwise it is passed
            through unchanged.

    Returns:
        dict: ``{"image": <image>, "label": <label>}``.

    Raises:
        ValueError: If no image key or no label key yields a non-``None``
            value.
    """
    picture = _collect_dict_first_value_in_list(example, ["jpg", "webp", "png", "image"])
    if picture is None:
        raise ValueError("Example must contain an image in jpg, webp, or png format.")

    label = _collect_dict_first_value_in_list(example, ["cls", "label"])
    if label is None:
        raise ValueError("Example must contain a class label under the key 'cls' or 'label'.")

    # Only integer labels are decoded, and only when a mapping is available.
    decodable = hasattr(cls_map, "int2str") and isinstance(label, int)
    if decodable:
        label = cls_map.int2str(label)

    return {"image": picture, "label": label}


def _parse_num_samples(value: str):
    """argparse ``type`` callback for ``--num_samples``.

    Accepts a positive integer (absolute row count, returned as ``int``) or a
    fraction in ``(0, 1]`` (share of the split, returned as ``float``).
    Anything else is rejected with ``argparse.ArgumentTypeError``.
    """
    try:
        count = int(value)
    except ValueError:
        pass
    else:
        if count <= 0:
            raise argparse.ArgumentTypeError("num_samples must be a positive integer")
        return count

    try:
        fraction = float(value)
    except ValueError:
        raise argparse.ArgumentTypeError(
            f"invalid sample count or fraction: {value!r}") from None
    # The chained comparison also rejects NaN and infinities.
    if not 0.0 < fraction <= 1.0:
        raise argparse.ArgumentTypeError(
            "a fractional num_samples must be in the interval (0, 1]")
    return fraction


def main():
    """CLI entry point: load a split, optionally subsample it, and publish it.

    Steps: parse arguments, load the requested split, subsample when
    ``--num_samples`` is given (a fractional value is converted to a row
    count relative to the split size), reduce every row to the
    ``image``/``label`` columns, then push the result to the Hub under
    ``ORG``. If the push fails for any reason the dataset is saved under
    ``./datasets/`` instead.
    """
    parser = argparse.ArgumentParser(
        description="Subsample a Hugging Face dataset, optionally balancing class counts.")
    parser.add_argument("dataset_name", help="Dataset name or canonical path on the Hub (e.g. 'cifar10').")
    # The original used type=int, which made the float branch below dead code.
    parser.add_argument("--num_samples", type=_parse_num_samples, default=None,
                        help="Total number of samples to keep, or a fraction in (0, 1] of the split.")
    parser.add_argument("--split", default="train", help="Split to load (default: train).")
    parser.add_argument("--balanced_classes", action="store_true",
                        help="If set, sample the same number of examples for every class.")
    args = parser.parse_args()

    print(f"Loading {args.dataset_name}[{args.split}] …")
    dataset = load_dataset(args.dataset_name, split=args.split)

    num_samples = args.num_samples
    if num_samples is not None:
        if isinstance(num_samples, float):
            # Fraction of the split -> absolute row count.
            num_samples = int(num_samples * len(dataset))
        subset = subsample(dataset, num_samples, args.balanced_classes)
    else:
        subset = dataset

    cls_map = get_cls_map(subset, args.dataset_name)
    formatted_subset = subset.map(format_dataset_example, fn_kwargs={"cls_map": cls_map})
    formatted_subset = formatted_subset.select_columns(["image", "label"])

    # Repo name: last path component of the dataset id, with everything up to
    # the first "_" and then everything up to the first "-" stripped.
    ds_name = args.dataset_name.split("/")[-1].split("_", 1)[-1].split("-", 1)[-1]
    repo_name = f"{ORG}/{ds_name}"
    if num_samples is not None:
        repo_name += f"-sub{num_samples}"
    if args.balanced_classes:
        repo_name += "-balanced"
    try:
        print(f"Pushing subset to Hub as '{repo_name}' …")
        formatted_subset.push_to_hub(repo_name)
    except Exception as e:
        # Deliberate best-effort fallback: never lose the prepared subset
        # just because the upload failed.
        print(f"Error pushing to Hub: {e}")
        print("Pushing to Hub failed, saving locally instead.")
        formatted_subset.save_to_disk(f"./datasets/{repo_name.split('/', 1)[-1]}")

# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
