import json

from src.data.utils import CustomColName
from src.utils.logging_utils import get_logger
from src.data.utils import DsColumn
from copy import deepcopy

logger = get_logger(name=__name__)


def deserialized_ex_metadata(example):
    """Decode the JSON-encoded metadata entry of a single example in place.

    Args:
        example: Mapping-like record whose ``CustomColName.META.value`` entry
            holds a JSON document accepted by ``json.loads``.

    Returns:
        The same ``example`` object, with the metadata entry replaced by the
        decoded Python value.
    """
    meta_col = CustomColName.META.value
    example[meta_col] = json.loads(example[meta_col])
    return example


def deserialize_metadata(*, dataset, **kwargs):
    """Apply ``deserialized_ex_metadata`` to every example of ``dataset``.

    Args:
        dataset: Object exposing a ``map`` method that accepts a per-example
            function and a ``load_from_cache_file`` keyword (presumably a
            Hugging Face ``Dataset`` -- confirm against callers).
        **kwargs: Accepted for call-site compatibility; not used here.

    Returns:
        Whatever ``dataset.map`` returns.
    """
    logger.info("Deserializing metadata")
    # load_from_cache_file=False is forwarded unchanged to ``map``.
    return dataset.map(deserialized_ex_metadata, load_from_cache_file=False)


def serialized_ex_metadata(example):
    """JSON-encode the metadata entry of a single example unless it is a string.

    Args:
        example: Mapping-like record with a ``DsColumn.META`` entry.

    Returns:
        The same ``example`` object. A metadata value that is already a ``str``
        is left untouched; any other value is replaced by ``json.dumps`` of it.
    """
    meta_col = DsColumn.META
    current = example[meta_col]
    if isinstance(current, str):
        # Already a string: treated as serialized, nothing to do.
        return example
    example[meta_col] = json.dumps(current)
    return example


def serialize_ds_metadata(*, dataset, num_proc: int = 1, **kwargs):
    """Apply ``serialized_ex_metadata`` to every example of ``dataset``.

    Args:
        dataset: Object exposing a ``map`` method that accepts ``num_proc`` and
            ``desc`` keywords (presumably a Hugging Face ``Dataset`` -- confirm
            against callers).
        num_proc: Forwarded unchanged to ``dataset.map``.
        **kwargs: Accepted for call-site compatibility; not used here.

    Returns:
        Whatever ``dataset.map`` returns.
    """
    logger.info("Serializing metadata")
    map_options = {"num_proc": num_proc, "desc": "Serializing metadata"}
    return dataset.map(serialized_ex_metadata, **map_options)


def flatten_example_metadata(example, *, meta_join_char: str | None = None):
    """Promote the entries of an example's metadata dict to top-level keys.

    The metadata entry (``CustomColName.META.value``) is decoded first when it
    is a string. Each non-dict metadata value is copied to ``example`` under
    its own key. Each dict value is expanded one level: its items are copied
    under their sub-key, or under ``f"{key}{meta_join_char}{sub_key}"`` when
    ``meta_join_char`` is given. The metadata entry itself is then removed.

    Args:
        example: Mapping-like record holding the metadata entry.
        meta_join_char: Optional separator used to prefix nested keys with
            their parent key. ``None`` keeps nested keys unprefixed.

    Returns:
        The same ``example`` object with the flattened keys added and the
        metadata entry deleted.

    Raises:
        ValueError: If two metadata entries would be written to the same
            flattened key. Keys written before the collision was detected
            remain set on ``example``.
    """
    meta_col = CustomColName.META.value
    if isinstance(example[meta_col], str):
        example = deserialized_ex_metadata(example)

    metadata = example[meta_col]

    seen_keys = set()
    for key, value in metadata.items():
        if not isinstance(value, dict):
            flat_items = ((key, value),)
        elif meta_join_char is None:
            flat_items = value.items()
        else:
            flat_items = (
                (f"{key}{meta_join_char}{sub_key}", sub_value)
                for sub_key, sub_value in value.items()
            )

        for flat_key, flat_value in flat_items:
            # Check the *final* column name: comparing the unprefixed sub-key
            # would both miss real collisions between prefixed names and
            # reject nested keys that are unambiguous once prefixed.
            if flat_key in seen_keys:
                raise ValueError(f"Duplicate key in metadata: {flat_key}")
            example[flat_key] = flat_value
            seen_keys.add(flat_key)

    del example[meta_col]
    return example


def flatten_metadata(
    *,
    dataset,
    num_proc: int = 1,
    meta_join_char: str | None = None,
    meta_to_keep: list[str] | None = None,
    **kwargs,
):
    """Flatten the metadata of every example and optionally prune new columns.

    Args:
        dataset: Object exposing ``column_names``, ``map`` and
            ``remove_columns`` (presumably a Hugging Face ``Dataset`` --
            confirm against callers).
        num_proc: Forwarded unchanged to ``dataset.map``.
        meta_join_char: Passed to ``flatten_example_metadata`` through
            ``fn_kwargs``.
        meta_to_keep: When not ``None``, every column of the mapped dataset
            that is neither one of the input dataset's columns nor listed here
            is removed. Names are compared against the mapped dataset's column
            names.
        **kwargs: Accepted for call-site compatibility; not used here.

    Returns:
        The mapped (and possibly pruned) dataset.
    """
    input_columns = dataset.column_names
    flattened = dataset.map(
        flatten_example_metadata,
        num_proc=num_proc,
        fn_kwargs={"meta_join_char": meta_join_char},
        desc="Flattening metadata",
    )

    if meta_to_keep is not None:
        retained = input_columns + meta_to_keep
        to_drop = [name for name in flattened.column_names if name not in retained]
        flattened = flattened.remove_columns(to_drop)

    logger.info(f"Flattened dataset columns: {flattened.column_names}")
    return flattened


def add_example_metadata(example, *, new_metadata: dict):
    """Merge ``new_metadata`` into an example's metadata entry.

    The existing metadata is read from ``DsColumn.META`` (an empty dict is
    used when the key is absent). If it is a string it is decoded with
    ``json.loads`` before merging and the merged result is re-encoded with
    ``json.dumps``; otherwise the merged dict is stored as-is. On key
    conflicts the values from ``new_metadata`` win.

    Args:
        example: Mapping-like record supporting ``get`` and item assignment.
        new_metadata: Entries to add to (or override in) the metadata.

    Returns:
        The same ``example`` object with the updated metadata entry.
    """
    meta_col = DsColumn.META
    existing = example.get(meta_col, {})

    # Remember the incoming representation so the output mirrors it.
    was_serialized = isinstance(existing, str)
    if was_serialized:
        existing = json.loads(existing)

    example[meta_col] = {**existing, **new_metadata}

    if was_serialized:
        example[meta_col] = json.dumps(example[meta_col])

    return example
