from datasets import Dataset
from enum import Enum
from src.utils.logging_utils import get_logger
from src.data.utils import CustomColName

# Logger for this module (named via `__name__`); used to report the generated subset.
logger = get_logger(name=__name__)


class InstagColName(str, Enum):
    """Name components for the InsTag annotation columns.

    Members subclass ``str``, so each compares equal to its string value.
    ``diversity_first_diverse_sampling`` looks up columns named
    ``f"{INSTAG_META_NAME.value}.{member.value}"``, e.g.
    ``"_instag._instag_complexity"``.
    """

    # Prefix joined with a "." to the other members' values to form column names.
    INSTAG_META_NAME = "_instag"
    # NOTE(review): not referenced in this file chunk; presumably the raw tagger output — confirm.
    INSTAG_RAW = "_instag_raw"
    # Per-row tag collection; the sampler converts each value into a set of tags.
    INSTAG_TAGS = "_instags"
    # Descending sort key for the sampler; its comment equates complexity with tag count.
    INSTAG_COMPLEXITY = "_instag_complexity"
    # NOTE(review): not referenced in this file chunk; meaning unconfirmed.
    INSTAG_DIVERSITY = "_instag_diversity"


def _select_diverse_indices(tag_sets: list[set], subset_size: int) -> list[int]:
    """Greedily pick row positions following Algorithm 1 of the InsTag paper.

    Rows are visited in the order of ``tag_sets``. Within one pass, a row is
    picked when its tags include at least one tag not yet covered by rows
    picked earlier in the same pass. Picked rows are dropped from later
    passes, and the covered-tag set is reset at the start of every pass.

    Args:
        tag_sets: Tag set of each row, in visiting order.
        subset_size: Number of rows to pick.

    Returns:
        Positions into ``tag_sets`` in the order they were picked. The list is
        shorter than ``subset_size`` when rows run out or a pass picks nothing
        (e.g. only rows with empty tag sets remain).
    """
    selected: list[int] = []
    remaining = list(range(len(tag_sets)))
    while remaining and len(selected) < subset_size:
        covered: set = set()  # T_s^B in the paper's notation; reset every pass
        skipped: list[int] = []
        for idx in remaining:
            tags = tag_sets[idx]
            if tags <= covered:
                # Adds no new tag in this pass; reconsider it in the next pass.
                skipped.append(idx)
                continue
            selected.append(idx)
            covered |= tags
            if len(selected) == subset_size:
                return selected
        if len(skipped) == len(remaining):
            # Nothing was picked in this pass, so further passes cannot progress.
            break
        remaining = skipped
    return selected


def diversity_first_diverse_sampling(*, dataset: Dataset, subset_size: int | float, **kwargs) -> Dataset:
    """Select a tag-diverse subset of ``dataset`` (Algorithm 1 of the InsTag paper).

    Rows are sorted by the InsTag complexity column in descending order. They
    are then picked greedily in repeated passes: within a pass, a row is taken
    only if it contributes a tag not yet seen in that pass.

    Args:
        dataset: Dataset with the columns ``"_instag._instag_complexity"``
            and ``"_instag._instags"``.
        subset_size: Number of rows to select, or a float fraction of
            ``len(dataset)`` (truncated with ``int``).
        **kwargs: Ignored.

    Returns:
        A new ``Dataset`` built from the selected rows, in selection order.

    Raises:
        ValueError: If a required column is missing, if ``subset_size`` is
            negative or larger than ``len(dataset)``, or if the greedy passes
            cannot reach ``subset_size`` rows (e.g. too many rows have no tags).
    """

    def _get_col_name(col_name: InstagColName) -> str:
        return f"{InstagColName.INSTAG_META_NAME.value}.{col_name.value}"

    complexity_col = _get_col_name(InstagColName.INSTAG_COMPLEXITY)
    tags_col = _get_col_name(InstagColName.INSTAG_TAGS)

    if isinstance(subset_size, float):
        subset_size = int(subset_size * len(dataset))

    # Validate explicitly: `assert` would be stripped under `python -O`.
    for col in (complexity_col, tags_col):
        if col not in dataset.column_names:
            raise ValueError(f"Column {col} not found in dataset")
    if not 0 <= subset_size <= len(dataset):
        raise ValueError(
            f"subset_size must be between 0 and {len(dataset)}, got {subset_size}"
        )

    # Sort by complexity (i.e., tag count) descending so tag-rich rows are visited first.
    dataset = dataset.sort(complexity_col, reverse=True)

    # Read the tag column once instead of fetching full rows on every pass.
    # A None tag value is treated like an empty tag list (never selectable).
    tag_sets = [set(tags) if tags is not None else set() for tags in dataset[tags_col]]
    selected = _select_diverse_indices(tag_sets, subset_size)
    if len(selected) != subset_size:
        raise ValueError(f"Expected {subset_size} samples, got {len(selected)}")

    subset = Dataset.from_list([dataset[i] for i in selected])
    logger.info(f"Generated Subset w InsTag:\n{subset}")
    return subset
