from pathlib import Path
from datasets import Dataset, load_dataset
from src.data.metadata_utils import add_example_metadata
from src.utils.logging_utils import get_logger
from enum import StrEnum

# Module-level logger, obtained from the project's logging helper.
logger = get_logger(__name__)

class DeitaAttr(StrEnum):
    """String keys used when storing DEITA scores in example metadata.

    ``DEITA_META`` names the outer metadata entry; ``COMPLEXITY`` and
    ``QUALITY`` name the two score lists nested inside that entry.
    """

    # Outer metadata key grouping both score lists.
    DEITA_META = "_deita"
    # Key of the complexity scores inside the DEITA entry.
    COMPLEXITY = "_complexity_scores"
    # Key of the quality scores inside the DEITA entry.
    QUALITY = "_quality_scores"

def create_deita_ds(*, dataset: Dataset, deita_data_file: Path, num_proc: int = 1, **kwargs) -> Dataset:
    """Attach DEITA complexity and quality scores to every example of ``dataset``.

    Scores are read from a JSON file and matched to examples through their
    ``"id"`` field. An example without a DEITA record receives two empty
    score lists, and a warning is logged for it.

    Args:
        dataset: Dataset whose examples carry an ``"id"`` field.
        deita_data_file: JSON file loaded with ``load_dataset("json", ...)``;
            its records must provide ``"id"``, ``"complexity_scores"`` and
            ``"quality_scores"``.
        num_proc: Number of processes forwarded to ``Dataset.map``.
        **kwargs: Accepted and ignored by this function.

    Returns:
        The mapped dataset. Each example has been passed through
        ``add_example_metadata`` with a ``DeitaAttr.DEITA_META`` entry holding
        the ``DeitaAttr.COMPLEXITY`` and ``DeitaAttr.QUALITY`` score lists.

    Raises:
        AssertionError: If ``dataset`` holds 10 or more examples beyond the
            number of DEITA records.
    """
    # ``data_files`` is documented as str / sequence / mapping, so pass a plain
    # string instead of relying on ``Path`` objects being accepted.
    deita_data = load_dataset("json", data_files=str(deita_data_file), split="train")
    logger.info(f"Loaded DEITA data from {deita_data_file} with {len(deita_data)} examples")

    # Keep only the two score lists per id. The mapping is captured by the
    # closure below, so it should not drag whole DEITA records along.
    scores_by_id = {
        record["id"]: (record["complexity_scores"], record["quality_scores"])
        for record in deita_data
    }

    def _add_deita_scores(ex):
        # Look up the scores once; fall back to empty lists for unknown ids.
        scores = scores_by_id.get(ex["id"])
        if scores is None:
            logger.warning(f"Example {ex['id']} not found in DEITA data")
            complexity_scores, quality_scores = [], []
        else:
            complexity_scores, quality_scores = scores
        new_meta = {
            DeitaAttr.DEITA_META: {
                DeitaAttr.COMPLEXITY: complexity_scores,
                DeitaAttr.QUALITY: quality_scores,
            }
        }
        return add_example_metadata(ex, new_metadata=new_meta)

    scored = dataset.map(_add_deita_scores, desc="Adding DEITA scores", num_proc=num_proc)

    # Raised explicitly so the check is not stripped under ``python -O``; the
    # exception type stays AssertionError for backward compatibility.
    # NOTE(review): the check is one-sided, as in the original code -- a DEITA
    # file with more records than the dataset passes. Confirm whether an
    # absolute difference was intended.
    if len(scored) - len(deita_data) >= 10:
        raise AssertionError(f"Dataset size mismatch: {len(scored)} != {len(deita_data)}")
    return scored


def create_deita_subset(*, dataset: Dataset, subset_size: float, deita_data_file: Path, num_proc: int = 1, **kwargs) -> Dataset:
    """Keep only the examples of ``dataset`` whose id is listed in a DEITA file.

    Args:
        dataset: Dataset whose examples carry an ``"id"`` field.
        subset_size: Fraction of ``dataset`` the result is expected to
            contain; the expected size is ``int(subset_size * len(dataset))``.
        deita_data_file: JSON file loaded with ``load_dataset("json", ...)``;
            its records must provide an ``"id"`` field.
        num_proc: Number of processes forwarded to ``Dataset.filter``.
        **kwargs: Accepted and ignored by this function.

    Returns:
        The filtered dataset.

    Raises:
        AssertionError: If the filtered dataset does not contain exactly
            ``int(subset_size * len(dataset))`` examples.
    """
    # ``data_files`` is documented as str / sequence / mapping, so pass a plain
    # string instead of relying on ``Path`` objects being accepted.
    deita_data = load_dataset("json", data_files=str(deita_data_file), split="train")
    logger.info(f"Loaded DEITA data from {deita_data_file} with {len(deita_data)} examples")

    # Only membership is needed, so collect the ids into a set rather than
    # keeping every full record in a dict.
    selected_ids = set(deita_data["id"])

    subset = dataset.filter(lambda x: x["id"] in selected_ids, desc="Filtering dataset", num_proc=num_proc)

    # NOTE(review): ``int()`` truncates, so float rounding in the product can
    # shift the expected size by one (e.g. 0.29 * 100). Kept as-is because the
    # DEITA file is presumably produced with the same formula -- confirm.
    expected_size = int(subset_size * len(dataset))
    # Raised explicitly so the check is not stripped under ``python -O``; the
    # exception type stays AssertionError for backward compatibility.
    if len(subset) != expected_size:
        raise AssertionError(
            f"DEITA subset size mismatch: got {len(subset)} examples, expected {expected_size}"
        )
    logger.info(f"Created DEITA subset with {len(subset)} examples")

    return subset
