from collections import Counter
from enum import StrEnum
from functools import wraps
from pathlib import Path
from typing import Any

import yaml
from datasets import Dataset, DatasetDict, ClassLabel, concatenate_datasets

from src.data.raw_data_loader import RawDataLoader
from src.data.utils import DatasetConfig
from src.utils.logging_utils import get_logger


# Module-level logger used by the dataset helpers defined below.
logger = get_logger(name=__name__)
class DsColumn(StrEnum):
    """Names of bookkeeping columns attached to dataset examples.

    Members are ``StrEnum`` values, so they can be used directly wherever a
    plain column-name string is expected.
    """
    ID = "_id"  # set by add_example_uids: integer id derived from the reference column
    DS_ID = "_ds_id"  # set by add_example_uids: copy of the reference column value
    IDX = "_idx"  # NOTE(review): not used in the visible code; presumably a row-index column - confirm
    META = "_meta"  # metadata column that filter_long_examples keeps when token counts are added


class TokenMetaAttr(StrEnum):
    """Keys of the token-count metadata that filter_long_examples attaches to each example."""
    TOKEN_META = "_token_meta"  # top-level key that wraps the three counters below
    TOKEN_CNT = "_token_cnt"  # length of the example's ``input_ids``
    LABEL_TOKEN_CNT = "_label_token_cnt"  # number of ``labels`` entries that are not -100
    PROMPT_TOKEN_CNT = "_prompt_token_cnt"  # TOKEN_CNT minus LABEL_TOKEN_CNT
class SelectionMetric(StrEnum):
    """String identifiers for selection metrics.

    NOTE(review): not referenced in the visible part of this file; the
    consumers of these values live elsewhere - confirm their semantics there.
    """
    PPL = "ppl"  # presumably perplexity - TODO confirm against the caller
    LOSS = "loss"

class MetaAttr(StrEnum):
    DS_NAME: str = "_ds_name"




def parse_args(func):
    """Decorator that lets a dataset helper accept a ``Dataset`` or a ``DatasetDict``.

    The wrapped function is always called with a single split as ``dataset``:

    * If ``dataset`` is a ``DatasetDict``, the ``split_name`` keyword argument is
      required; that split is passed to ``func`` and the result is written back
      into the ``DatasetDict`` under the same key (the input is mutated in place).
    * Otherwise ``dataset`` is passed through unchanged.

    In both cases the value returned by ``func`` is returned to the caller, and
    all keyword arguments (including ``split_name`` when given) are forwarded to
    ``func``, so decorated functions must tolerate extra keyword arguments.

    Args:
        func: Keyword-only callable accepting ``dataset`` plus arbitrary kwargs.

    Returns:
        The wrapping function, carrying ``func``'s name, docstring and signature
        metadata.

    Raises:
        AssertionError: If ``dataset`` is a ``DatasetDict`` and ``split_name`` is
            missing or ``None``.
    """

    # ``wraps`` keeps __name__/__doc__/__wrapped__ of the decorated helper so that
    # logs, tracebacks and signature introspection show the real function.
    @wraps(func)
    def wrapper(*, dataset: "Dataset | DatasetDict", **kwargs: Any) -> Any:
        split_name = kwargs.get("split_name")
        using_ds_dict = isinstance(dataset, DatasetDict)
        if using_ds_dict:
            assert split_name is not None, "split_name must be provided when dataset is a DatasetDict"
            _ds = dataset[split_name]
        else:
            _ds = dataset

        result = func(dataset=_ds, **kwargs)

        if using_ds_dict:
            # Store the processed split back so the caller's DatasetDict stays in sync.
            dataset[split_name] = result

        return result

    return wrapper


def load_dataset(
    *,
    raw_dataset_loader: RawDataLoader,
    dataset_config: DatasetConfig,
    split: str | None = None,
    limit: int | None = None,
    revision: str | None = None,
    **kwargs,
) -> Any:
    """Load a dataset through ``raw_dataset_loader`` and optionally truncate / pick a split.

    Args:
        raw_dataset_loader: Loader whose ``load(dataset_config=..., revision=...)``
            returns a mapping of split name to dataset.
        dataset_config: Configuration of the dataset to load; its ``name_or_path``
            is only used for logging here.
        split: If given, only ``ds[split]`` is returned.
        limit: If given, every split is truncated to at most ``limit`` leading
            examples (debugging aid). Splits shorter than ``limit`` are kept whole.
        revision: Forwarded to the loader.
        **kwargs: Accepted for call-site compatibility and ignored.

    Returns:
        The loaded mapping of splits, or the single requested split.
    """
    logger.info(f"Loading dataset: {dataset_config.name_or_path}")
    ds = raw_dataset_loader.load(dataset_config=dataset_config, revision=revision)
    if limit is not None:
        # Warn once rather than once per split.
        logger.warning(f"Limiting dataset to {limit} examples!!. DEBUG ONLY")
        for split_key in ds.keys():
            # Clamp to the split size: selecting past the end raises IndexError.
            num_kept = min(limit, len(ds[split_key]))
            ds[split_key] = ds[split_key].select(range(num_kept))

    if split is not None:
        ds = ds[split]

    logger.info(f"Loaded dataset name={dataset_config.name_or_path} split={split}:\n{ds}")
    return ds



@parse_args
def select_subset(*, selection_method: str, **kwargs) -> Any:
    if selection_method == "rand":
        return select_random_subset(**kwargs)
    if selection_method == "strat":
        return select_stratified_subset(**kwargs)
    if selection_method == "uni":
        return select_uniform_subset(**kwargs)

    raise ValueError(f"Invalid method: {selection_method}")


@parse_args
def select_random_subset(
    *,
    dataset: Dataset,
    subset_size: int | float,
    data_seed: int = 42,
    **kwargs,
) -> Any:
    """Return ``subset_size`` examples drawn at random from ``dataset``.

    Args:
        dataset: Dataset to sample from.
        subset_size: Number of examples, or (when a float) the fraction of
            ``len(dataset)`` to keep, truncated to an integer.
        data_seed: Seed passed to ``dataset.shuffle``.
        **kwargs: Accepted for call-site compatibility and ignored.

    Returns:
        The first ``subset_size`` examples of the shuffled dataset.
    """
    # A float is interpreted as a fraction of the dataset size.
    subset_size = int(len(dataset) * subset_size) if isinstance(subset_size, float) else subset_size

    logger.info(f"Selecting random subset of {subset_size} examples")

    subset = dataset.shuffle(seed=data_seed).select(range(subset_size))
    assert len(subset) == subset_size, f"Expected {subset_size} examples, got {len(subset)}"

    return subset


@parse_args
def select_stratified_subset(
    *,
    dataset: Dataset,
    subset_size: int | float,
    col_name: str,
    data_seed: int = 42,
    **kwargs,
) -> Dataset | DatasetDict:
    """Return a subset of ``dataset`` whose ``col_name`` distribution mirrors the original.

    Args:
        dataset: Dataset to sample from.
        subset_size: Number of examples, or (when a float) the fraction of
            ``len(dataset)`` to keep, truncated to an integer.
        col_name: Column whose values define the strata.
        data_seed: Seed for the stratified split.
        **kwargs: Accepted for call-site compatibility and ignored.

    Returns:
        A dataset with exactly ``subset_size`` examples.

    Raises:
        AssertionError: If ``subset_size`` exceeds ``len(dataset)`` or the
            resulting subset has an unexpected size.
    """
    if isinstance(subset_size, float):
        subset_size = int(len(dataset) * subset_size)

    assert len(dataset) >= subset_size, f"Dataset has {len(dataset)} examples, requested {subset_size}"

    logger.info(f"Selecting stratified subset of {subset_size} examples")

    if subset_size == len(dataset):
        # train_test_split rejects a test_size equal to the dataset size; the
        # only possible "subset" is then the whole (shuffled) dataset.
        subset = dataset.shuffle(seed=data_seed)
    else:
        labels = dataset[col_name]
        # First-occurrence order instead of set order: set iteration order of
        # strings varies between interpreter runs, which would change the class
        # encoding and therefore the split obtained for a fixed seed.
        unique_labels = list(dict.fromkeys(labels))

        # Stratification needs a ClassLabel column; use a temporary copy so the
        # original column keeps its type.
        _ds = dataset.add_column("_tmp_col", labels)
        _ds = _ds.cast_column("_tmp_col", feature=ClassLabel(names=unique_labels))  # type: ignore

        subset = _ds.train_test_split(
            test_size=subset_size,
            seed=data_seed,
            stratify_by_column="_tmp_col",
            shuffle=True,
        )["test"]
        subset = subset.remove_columns("_tmp_col")

    per_class_counts = Counter(subset[col_name])
    logger.info(f"Selected stratified subset of {subset_size} examples: {per_class_counts}")

    assert len(subset) == subset_size, f"Expected {subset_size} examples, got {len(subset)}"

    return subset


@parse_args
def select_uniform_subset(
    *,
    dataset: Dataset,
    subset_size: int | float,
    col_name: str,
    data_seed: int = 42,
    **kwargs,
) -> Dataset:
    """Return a subset with (as far as possible) the same number of examples per label.

    The dataset is shuffled once with ``data_seed``; for every distinct value of
    ``col_name`` the first ``subset_size // num_labels`` shuffled examples are
    taken. If some labels have fewer examples than that, the shortfall is topped
    up with randomly chosen examples that were *not* already selected, so the
    result never contains the same row twice.

    The result may be smaller than ``subset_size`` by the remainder of the
    integer division (less than the number of labels).

    NOTE: selection is deterministic for a given ``data_seed``, but the concrete
    rows differ from earlier revisions of this function, which re-shuffled each
    label group and could duplicate rows in the top-up step.

    Args:
        dataset: Dataset to sample from.
        subset_size: Number of examples, or (when a float) the fraction of
            ``len(dataset)`` to keep, truncated to an integer.
        col_name: Column holding the label to balance on.
        data_seed: Seed passed to ``dataset.shuffle``.
        **kwargs: Accepted for call-site compatibility and ignored.

    Returns:
        The selected subset, grouped by label with any top-up rows at the end.

    Raises:
        AssertionError: If the subset size deviates from ``subset_size`` by
            ``2 * num_labels`` or more.
        ZeroDivisionError: If ``dataset`` has no examples (no labels).
    """
    if isinstance(subset_size, float):
        subset_size = int(len(dataset) * subset_size)

    logger.info(f"Selecting uniform subset of {subset_size} examples")

    _ds = dataset.shuffle(seed=data_seed)

    # Single pass: row positions of the shuffled dataset grouped by label, in
    # first-occurrence order (deterministic, unlike set iteration order).
    indices_by_label: dict[Any, list[int]] = {}
    for row_idx, label in enumerate(_ds[col_name]):
        indices_by_label.setdefault(label, []).append(row_idx)

    num_labels = len(indices_by_label)
    num_ex_per_label = subset_size // num_labels

    selected_indices: list[int] = []
    leftover_indices: list[int] = []
    missing_cnt = 0
    for label_indices in indices_by_label.values():
        # Rows are already in random order, so a prefix is a random sample.
        selected_indices.extend(label_indices[:num_ex_per_label])
        leftover_indices.extend(label_indices[num_ex_per_label:])
        missing_cnt += max(0, num_ex_per_label - len(label_indices))

    if missing_cnt > 0:
        logger.warning(
            f"Missing {missing_cnt} examples to create uniform subset. Using random selection for the missing."
        )
        # Restore shuffled order so the top-up is label-agnostic, and draw only
        # from rows that have not been selected yet.
        leftover_indices.sort()
        selected_indices.extend(leftover_indices[:missing_cnt])

    subset = _ds.select(selected_indices)
    assert abs(subset_size - len(subset)) < 2 * num_labels, f"Expected {subset_size} examples, got {len(subset)}"

    per_class_counts = Counter(subset[col_name])
    logger.info(f"Selected uniform subset of {subset_size} examples: {per_class_counts}")

    return subset


@parse_args
def filter_dataset_by_col_value(
    *,
    dataset: Dataset,
    col_name: str,
    col_value: Any,
    **kwargs,
) -> Dataset:
    """Keep only the examples whose ``col_name`` value matches ``col_value``.

    Args:
        dataset: Dataset to filter.
        col_name: Column to inspect.
        col_value: Value to match. A ``list`` is treated as a set of accepted
            values (membership test); anything else is compared with ``==``.
        **kwargs: Accepted for call-site compatibility and ignored.

    Returns:
        The filtered dataset.
    """
    # Decide once whether we test membership or equality.
    accepts_many = isinstance(col_value, list)

    def _matches(example):
        current = example[col_name]
        return current in col_value if accepts_many else current == col_value

    _ds = dataset.filter(_matches, desc=f"Filtering dataset by {col_name}={col_value}")

    logger.info(f"Filtered dataset by {col_name}={col_value}: \n{_ds}")

    return _ds


@parse_args
def filter_long_examples(
    *,
    dataset: Dataset,
    max_seq_length: int,
    num_proc: int = 1,
    tokenizer_name: str = "philschmid/meta-llama-3-tokenizer",
    add_token_cnt: bool = True,
    **kwargs,
) -> Dataset:
    """Drop examples whose tokenized length exceeds ``max_seq_length``.

    The dataset is tokenized without truncation via ``encode_sft_dataset``,
    filtered on the length of ``input_ids``, and then reduced back to its
    original columns (plus the metadata column when ``add_token_cnt`` is set).

    Args:
        dataset: Dataset to filter.
        max_seq_length: Maximum allowed number of tokens per example (inclusive).
        num_proc: Number of processes for tokenization, filtering and mapping.
        tokenizer_name: Name or path given to ``AutoTokenizer.from_pretrained``.
        add_token_cnt: When true, total / label / prompt token counts are added
            to each example's metadata under ``TokenMetaAttr.TOKEN_META``.
        **kwargs: Accepted and ignored. Needed because ``parse_args`` forwards
            ``split_name`` (and any other extra keyword) to the wrapped function.

    Returns:
        The filtered dataset with the default (non-tensor) output format.

    Raises:
        AssertionError: If no example survives the filter.
    """
    from transformers import AutoTokenizer
    from src.data.tulu_3_sft_utils import encode_sft_dataset
    from src.data.metadata_utils import add_example_metadata

    # Copy: the list is extended below and must not alias the dataset's own list.
    cols_to_keep = list(dataset.column_names)

    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    tokenized_dataset = encode_sft_dataset(
        dataset=dataset,
        tokenizer=tokenizer,
        max_len=None,  # No truncation
        num_proc=num_proc,
        # overwrite_cache=True,
        cols_to_keep=cols_to_keep,
    )

    # NOTE(review): `.shape` implies encode_sft_dataset yields tensor/array-valued
    # `input_ids` and `labels` - confirm against its implementation.
    filtered_ds = tokenized_dataset.filter(
        lambda x: x["input_ids"].shape[0] <= max_seq_length,
        num_proc=num_proc,
        desc=f"Selecting examples with length <= {max_seq_length}",
    )

    if add_token_cnt:

        def _add_token_meta(example):
            # There is no padding no truncation
            overall_token_cnt = example["input_ids"].shape[0]
            # Positions labelled -100 are excluded from the label count.
            labels_token_cnt = (example["labels"] != -100).sum().item()
            prompt_token_cnt = overall_token_cnt - labels_token_cnt
            new_metadata = {
                TokenMetaAttr.TOKEN_META: {
                    TokenMetaAttr.TOKEN_CNT: overall_token_cnt,
                    TokenMetaAttr.LABEL_TOKEN_CNT: labels_token_cnt,
                    TokenMetaAttr.PROMPT_TOKEN_CNT: prompt_token_cnt,
                }
            }
            example = add_example_metadata(example, new_metadata=new_metadata)
            return example

        filtered_ds = filtered_ds.map(
            _add_token_meta,
            num_proc=num_proc,
            desc="Adding token metadata",
        )
        if DsColumn.META not in cols_to_keep:
            cols_to_keep.append(DsColumn.META)

    # Drop the tokenization columns; only the original (and metadata) columns remain.
    filtered_ds = filtered_ds.remove_columns([c for c in filtered_ds.column_names if c not in cols_to_keep])
    filtered_ds.set_format(type=None)  # Reset format to default

    assert len(filtered_ds) > 0, "No examples left after filtering"

    logger.info(f"Filtered dataset (max len{max_seq_length}): {filtered_ds}")
    return filtered_ds


def create_ds_cfg(
    *,
    dataset: Any,
    dataset_name: str | None,
    dataset_id: str | None,
    train_split_name: str = "train",
    eval_split_name: str | None = None,
    config_template_file: Path = Path("config/train/dataset/_dataset.yaml"),
    dataset_name_template: str | None = None,
    dataset_id_template: str | None = None,
    hf_space_name: str | None = None,
    **kwargs,
) -> Any:
    """Write a dataset YAML config derived from a template and return ``dataset`` unchanged.

    The template is loaded, its ``name_or_path``, ``eval_split_name``,
    ``train_split_name`` and ``meta.id_suffix`` entries are overwritten, and the
    result is saved next to the template as ``<dataset_id with '.' -> '-'>.yaml``.

    Args:
        dataset: Passed through untouched so the function can sit in a pipeline.
        dataset_name: Dataset name; mutually exclusive with ``dataset_name_template``.
        dataset_id: Dataset id; mutually exclusive with ``dataset_id_template``.
        train_split_name: Value written to ``train_split_name``.
        eval_split_name: Value written to ``eval_split_name``.
        config_template_file: YAML template; must contain a ``meta`` mapping.
        dataset_name_template: ``str.format`` template filled from ``kwargs``.
        dataset_id_template: ``str.format`` template filled from ``kwargs``.
        hf_space_name: Optional namespace prepended to the name as ``<space>/<name>``.
        **kwargs: Values available to the name / id templates.

    Returns:
        The ``dataset`` argument, unchanged.

    Raises:
        AssertionError: If both a value and its template are given, or a
            template still contains ``{`` after formatting.
        ValueError: If no dataset id is available, or ``hf_space_name`` is given
            without a dataset name.
    """
    if dataset_name_template:
        assert dataset_name is None, "Only one of dataset_name or dataset_name_template should be provided"
        dataset_name = dataset_name_template.format(**kwargs)
        assert "{" not in dataset_name, "All template keys should be filled"

    if dataset_id_template:
        assert dataset_id is None, "Only one of dataset_id or dataset_id_template should be provided"
        dataset_id = dataset_id_template.format(**kwargs)
        assert "{" not in dataset_id, "All template keys should be filled"

    # The id names the output file, so fail early with a clear message instead
    # of an AttributeError on ``None.replace`` further down.
    if dataset_id is None:
        raise ValueError("Either dataset_id or dataset_id_template must be provided")

    if hf_space_name:
        if dataset_name is None:
            # Would otherwise silently produce "<space>/None".
            raise ValueError("dataset_name (or dataset_name_template) is required when hf_space_name is set")
        dataset_name = f"{hf_space_name}/{dataset_name}"

    with config_template_file.open("r", encoding="utf-8") as f:
        ds_cfg = yaml.safe_load(f)

    ds_cfg["name_or_path"] = dataset_name
    ds_cfg["eval_split_name"] = eval_split_name
    ds_cfg["train_split_name"] = train_split_name
    ds_cfg["meta"]["id_suffix"] = dataset_id

    # Dots are replaced in the file name only; the stored id keeps them.
    out_file = config_template_file.parent / f"{dataset_id.replace('.', '-')}.yaml"
    with out_file.open("w", encoding="utf-8") as f:
        yaml.dump(ds_cfg, f)

    logger.info(f"Dataset config created: {out_file}")

    return dataset


@parse_args
def add_example_uids(
    *,
    dataset: Dataset,
    ds_ref_col: str,
    num_proc: int = 1,
    **kwargs,
) -> Dataset:
    """Add the ``DsColumn.ID`` and ``DsColumn.DS_ID`` columns to every example.

    ``DsColumn.DS_ID`` is a copy of the example's ``ds_ref_col`` value.
    ``DsColumn.ID`` is an integer assigned per distinct ``ds_ref_col`` value, in
    order of first appearance in ``dataset`` (0 for the first distinct value,
    1 for the next, ...). Examples sharing a reference value share an id.

    Args:
        dataset: Dataset to annotate.
        ds_ref_col: Column whose values identify the examples.
        num_proc: Number of processes used by ``dataset.map``.
        **kwargs: Accepted for call-site compatibility and ignored.

    Returns:
        The dataset with the two id columns added.
    """
    # First-occurrence order makes the ids reproducible across runs (set order
    # of strings is not), and the dict gives O(1) lookup per example instead of
    # a linear ``list.index`` scan.
    uid_by_ref_value = {ref_value: uid for uid, ref_value in enumerate(dict.fromkeys(dataset[ds_ref_col]))}

    def _add_uids(example):
        ref_value = example[ds_ref_col]
        example[DsColumn.ID] = uid_by_ref_value[ref_value]
        example[DsColumn.DS_ID] = ref_value
        return example

    _ds = dataset.map(
        _add_uids,
        num_proc=num_proc,
        desc="Adding unique UIDs to examples",
    )

    logger.info(f"Added unique UIDs to examples: columns={_ds.column_names}")

    return _ds
