from datasets import load_dataset, Dataset
from src.data.metadata_utils import add_example_metadata, deserialized_ex_metadata
from src.utils.logging_utils import get_logger
from src.data.utils import DsColumn, MetaAttr

from copy import deepcopy

# Module-level logger, requested from the project logging helper under this module's name.
logger = get_logger(__name__)

def merge_tagging_metadata(*, dataset_name: str, other_revision: str, other_tag_key: str, num_proc: int = 1, **kwargs):
    """Copy one tagging-metadata entry from another revision into the default revision.

    The dataset is loaded twice: once without a revision (the "latest" copy)
    and once at ``other_revision``. For every example, the value stored under
    ``other_tag_key`` in the other revision's ``"_meta"`` metadata is added to
    the matching example of the latest copy via ``add_example_metadata``.

    Examples are matched by position within the same-named split, and the
    pairing is verified through the ``"_id"`` and ``"_ds_id"`` fields.

    Args:
        dataset_name: Name passed to ``load_dataset`` for both revisions.
        other_revision: Revision that already carries ``other_tag_key``.
        other_tag_key: Key inside the ``"_meta"`` metadata to copy over.
        num_proc: Number of processes forwarded to ``map``.
        **kwargs: Accepted and ignored.

    Returns:
        The latest dataset (keyed by split name). It is returned unchanged if
        the first "train" example already contains ``other_tag_key``.

    Raises:
        ValueError: If the other revision's first "train" example lacks
            ``other_tag_key``, if a split of the latest dataset is missing
            from or longer than its counterpart in the other revision, or if
            two positionally paired examples have different ids.
    """
    latest_ds = load_dataset(dataset_name)

    # Presence of the key is probed on the first "train" example only.
    if other_tag_key in deserialized_ex_metadata(latest_ds["train"][0])["_meta"]:
        logger.warning(f"Tag key {other_tag_key} already exists in the dataset. Skipping merge.")
        return latest_ds

    other_ds = load_dataset(dataset_name, revision=other_revision)
    if other_tag_key not in deserialized_ex_metadata(other_ds["train"][0])["_meta"]:
        raise ValueError(f"Tag key {other_tag_key} not found in the other dataset.")

    # Validate every split up front so a mismatch fails before any (possibly
    # long-running) map starts, instead of as an IndexError part-way through.
    split_names = list(latest_ds)
    for split_name in split_names:
        if split_name not in other_ds:
            raise ValueError(
                f"Split {split_name} not found in revision {other_revision} of {dataset_name}."
            )
        if len(other_ds[split_name]) < len(latest_ds[split_name]):
            raise ValueError(
                f"Split {split_name} of revision {other_revision} has fewer examples "
                f"({len(other_ds[split_name])}) than the latest dataset "
                f"({len(latest_ds[split_name])})."
            )

    def _make_merger(other_split):
        """Build the per-example callback bound to one split of the other revision."""

        def _merge_ex_tagging_metadata(ex, idx):
            other_ex = other_split[idx]
            # Explicit check rather than `assert`: it must still run under
            # `python -O`, otherwise misaligned rows would be merged silently.
            if ex["_id"] != other_ex["_id"] or ex["_ds_id"] != other_ex["_ds_id"]:
                raise ValueError(
                    f"Example mismatch at index {idx}: "
                    f"latest (_id={ex['_id']!r}, _ds_id={ex['_ds_id']!r}) vs "
                    f"other (_id={other_ex['_id']!r}, _ds_id={other_ex['_ds_id']!r})."
                )
            other_meta = deserialized_ex_metadata(other_ex)["_meta"]
            new_metadata = {other_tag_key: deepcopy(other_meta[other_tag_key])}
            return add_example_metadata(ex, new_metadata=new_metadata)

        return _merge_ex_tagging_metadata

    # `with_indices` yields indices relative to the split being mapped, so each
    # split must be paired with the same-named split of the other revision
    # (always indexing "train" would mis-pair every other split).
    for split_name in split_names:
        latest_ds[split_name] = latest_ds[split_name].map(
            _make_merger(other_ds[split_name]),
            num_proc=num_proc,
            with_indices=True,
            desc="Adding metadata to latest dataset",
        )

    return latest_ds


def fix_ds_id(*, dataset: Dataset, **kwargs):
    """Replace string values of the ``DsColumn.DS_ID`` column with integer ids.

    Each distinct value of the column is assigned its position in the sorted
    list of distinct values. The original string is stored in the example
    metadata under ``MetaAttr.DS_NAME`` via ``add_example_metadata``.

    Args:
        dataset: Dataset whose ``DsColumn.DS_ID`` column holds string names.
        **kwargs: Accepted and ignored.

    Returns:
        The mapped dataset with integer ids in ``DsColumn.DS_ID``.

    Raises:
        TypeError: If an example's ``DsColumn.DS_ID`` value is not a string
            (for instance because the column was already converted).
    """
    # Built once: a dict lookup per row replaces a linear `list.index` scan
    # while yielding exactly the same sorted-rank ids.
    distinct_names = sorted(set(dataset[DsColumn.DS_ID]))
    ds_id_by_name = {name: rank for rank, name in enumerate(distinct_names)}

    def _fix_ds_id(ex):
        ds_name = ex[DsColumn.DS_ID]
        # Explicit check rather than `assert` so it is not stripped under
        # `python -O`, where a non-string id would otherwise be re-ranked
        # and written back as the dataset name.
        if not isinstance(ds_name, str):
            raise TypeError(
                f"Expected a str dataset id, got {type(ds_name).__name__}: {ds_name!r}"
            )
        ex[DsColumn.DS_ID] = ds_id_by_name[ds_name]
        ex = add_example_metadata(ex, new_metadata={MetaAttr.DS_NAME: ds_name})
        return ex

    dataset = dataset.map(
        _fix_ds_id,
        desc="Fixing ds_id",
    )

    return dataset