from collections import Counter
from math import ceil
from typing import Callable

import pandas as pd
from datasets import ClassLabel, DatasetDict
from pydantic import model_validator

from src.data.raw_data_loader import RawDataLoader
from src.data.utils import DEF_COLS, CustomColName, DatasetConfig, update_example_metadata
from src.utils.logging_utils import get_logger
from src.utils.pipeline_utils import PipelineStep

# Module-level logger shared by every method of the pipeline step below.
logger = get_logger(name=__name__)


class ParseDataset(PipelineStep):
    class Config(PipelineStep.Config):
        """Options controlling parsing, dataset-id assignment, eval splitting and upload.

        Cross-field rules enforced by ``validate``:

        * ``subsets`` requires ``splitting_col`` and excludes ``dataset_id``.
        * ``eval_ratio`` requires ``eval_src_split``.
        """

        # Worker count forwarded as ``num_proc`` to the dataset ``map``/``filter`` calls.
        num_proc: int = 1
        # When true, ``_run`` uploads the final dataset through ``_push_to_hub``.
        push_to_hf_hub: bool = True
        # Divisor applied to the total character count to estimate a token count.
        chars_per_token: float = 3.6
        # Token passed to ``push_to_hub``.
        hf_token: str | None = None
        # Not read by any method visible in this class.
        custom_parsers: list[Callable] = []
        # Constant dataset id written to every example; mutually exclusive with ``subsets``.
        dataset_id: int | None = None
        # Columns that survive ``_remove_unused_columns``.
        cols_to_keep: list[str] = DEF_COLS
        # Additional columns folded into the example metadata before removal.
        extra_metadata_cols: list[str] = []
        # Maps subset id -> info dict; each info dict must provide a "names" iterable.
        subsets: dict | None = None
        # Column whose value is looked up in the subset-name -> subset-id map.
        splitting_col: str | None = None
        # Dataset id used when the splitting column value is not a known subset name.
        unknown_subset_id: int = -1
        # Forwarded as ``test_size`` when carving out the eval split.
        eval_ratio: float | int | None = None
        # Name of the split that gets divided into train/test.
        eval_src_split: str | None = None
        # Forwarded as ``stratify_by_column`` for the eval split.
        eval_split_strat_col: str | None = None
        # Per-example hooks: dicts with a "mode" ("before" or "after") and a "fct" callable.
        custom_map_fcts: list[dict] = []

        # Reverse lookup (subset name -> subset id) built by ``validate``.
        _subset_id_map: dict | None = None

        @model_validator(mode="after")
        def validate(self):
            """Check cross-field constraints and build the subset-name lookup.

            Raises:
                ValueError: If ``subsets`` is given without ``splitting_col`` or
                    together with ``dataset_id``, or if ``eval_ratio`` is given
                    without ``eval_src_split``.
            """
            # Explicit raises instead of ``assert`` so the checks survive ``python -O``.
            if self.subsets is not None:
                if self.splitting_col is None:
                    raise ValueError("splitting_col must be set if subsets are provided")
                if self.dataset_id is not None:
                    raise ValueError("dataset_id must be None if subsets are provided")

                # Invert {subset_id: {"names": [...]}} into {name: subset_id};
                # a name listed under several ids keeps the last one seen.
                self._subset_id_map = {
                    subset_name: subset_id
                    for subset_id, subset_info in self.subsets.items()
                    for subset_name in subset_info["names"]
                }

            if self.eval_ratio is not None and self.eval_src_split is None:
                raise ValueError("eval_src_split must be set if eval_ratio is provided")

            return self

        @property
        def subset_id_map(self):
            """Subset-name -> subset-id mapping, or ``None`` when no subsets are configured."""
            return self._subset_id_map

    def __init__(
        self,
        original_dataset: DatasetConfig,
        raw_data_loader: RawDataLoader,
        **kwargs,
    ):
        """Store the data source and build the step configuration.

        Args:
            original_dataset: Dataset description later handed to the loader.
            raw_data_loader: Loader used by ``_run`` to fetch the raw dataset.
            **kwargs: Forwarded unchanged to both the parent initializer and
                ``Config``.
        """
        super().__init__(**kwargs)
        self.original_dataset = original_dataset
        self.raw_data_loader = raw_data_loader
        # The same keyword arguments that configured the parent also populate Config.
        self.config = self.Config(**kwargs)

    def _run(self) -> None:
        """Load, parse, filter and trim the dataset, then optionally upload it.

        Order of operations: load the raw data, annotate the examples, drop
        flagged examples, create the eval split when ``eval_ratio`` is set,
        remove unused columns, and finally push to the Hub when
        ``push_to_hf_hub`` is enabled.
        """
        cfg = self.config
        source = self.original_dataset

        dataset = self.raw_data_loader.load(dataset_config=source)
        dataset = self._parse_examples(ds=dataset, ds_name=source.name_or_path)
        dataset = self._remove_flagged_examples(dataset)

        if cfg.eval_ratio is not None:
            dataset = self._create_eval_ds(dataset)

        dataset = self._remove_unused_columns(dataset)
        logger.info(f"Final dataset: {dataset}")

        if not cfg.push_to_hf_hub:
            return

        # Only the last path component of the source name is used as the repo name.
        hub_name = source.name_or_path.split("/")[-1]
        self._push_to_hub(dataset=dataset, dataset_name=hub_name)

    def _create_eval_ds(self, ds: DatasetDict) -> DatasetDict:
        """Split ``config.eval_src_split`` into a train and a test split.

        The split uses ``config.eval_ratio`` as ``test_size``, a fixed seed of
        123 and ``config.eval_split_strat_col`` for stratification. The
        returned dictionary contains only the resulting ``train`` and ``test``
        splits; any other split of ``ds`` is not carried over.

        Args:
            ds: Dataset containing the source split.

        Returns:
            A ``DatasetDict`` with ``train`` and ``test`` entries.

        Raises:
            KeyError: If the configured source split is not present in ``ds``.
        """
        src_split = self.config.eval_src_split
        if src_split not in ds:
            # Same exception type as the plain lookup, but names the available splits.
            raise KeyError(f"eval_src_split {src_split!r} not found in dataset splits: {list(ds.keys())}")

        _tmp = ds[src_split].train_test_split(
            test_size=self.config.eval_ratio,
            seed=123,
            stratify_by_column=self.config.eval_split_strat_col,
        )

        ds = DatasetDict(train=_tmp["train"], test=_tmp["test"])
        logger.info(f"Created eval ds from {self.config.eval_src_split} with ratio {self.config.eval_ratio}: {ds}")
        return ds

    def _push_to_hub(self, dataset: DatasetDict, dataset_name: str) -> None:
        """Upload ``dataset`` under ``dataset_name`` with ``private=True``.

        Args:
            dataset: Dataset to upload.
            dataset_name: Repository name passed to ``push_to_hub``.
        """
        hub_token = self.config.hf_token
        dataset.push_to_hub(dataset_name, token=hub_token, private=True)
        logger.info(f"Dataset {dataset_name} pushed to hub")

    def _parse_examples(self, *, ds: DatasetDict, ds_name: str) -> DatasetDict:
        """Annotate every example with ids, turn counts, a token estimate and a length bucket.

        Per example, in order: custom "before" hooks, example id, message
        clean-up with turn statistics, token-count estimate, optional dataset
        id, custom "after" hooks. Afterwards a token-count bucket column is
        added through ``_add_len_buckets``.

        Args:
            ds: Splits to process. Each example is accessed as
                ``example["messages"]``, a list of dicts with ``"role"`` and
                ``"content"`` keys.
            ds_name: Name used only in the completion log message.

        Returns:
            The mapped dataset including the added columns.
        """

        def _parse_ex_chatml_attr(example, **kwargs):
            """Left-strip message contents, flag empty messages and count turns per role."""
            messages = example["messages"]

            # Flag the example when ANY message is blank. The flag is always
            # written, so examples without messages still get the column (False).
            has_blank_message = False
            for message in messages:
                content = message["content"].lstrip()
                message["content"] = content
                if content.strip() == "":
                    has_blank_message = True
            example["messages"] = messages
            example[CustomColName.FLAGGED.value] = has_blank_message

            role_counts = Counter(message["role"] for message in messages)
            example[CustomColName.NUM_TURNS.value] = len(messages)
            example[CustomColName.NUM_USR_TURNS.value] = role_counts["user"]
            example[CustomColName.NUM_SYS_TURNS.value] = role_counts["system"]
            example[CustomColName.NUM_ASSIST_TURNS.value] = role_counts["assistant"]

            return example

        def _add_content_len_estimate(example, **kwargs):
            """Estimate the token count from the summed character length of all messages."""
            total_chars = sum(len(message["content"]) for message in example["messages"])
            example[CustomColName.TOKEN_CNT_ESTIMATE.value] = ceil(total_chars / self.config.chars_per_token)
            return example

        def _add_ex_id(example, ex_idx: int, **kwargs):
            """Store the index supplied by ``map`` as the example id."""
            example[CustomColName.ID.value] = ex_idx
            return example

        def _add_dataset_id(example, subset_id_map: dict | None = None, **kwargs):
            """Write the fixed dataset id, or the subset id looked up via the splitting column."""
            if self.config.dataset_id is not None:
                example[CustomColName.DS_ID.value] = self.config.dataset_id
            else:
                # Without configured subsets there is no map; every example then
                # falls back to ``unknown_subset_id`` instead of failing on ``None.get``.
                lookup = subset_id_map or {}
                example[CustomColName.DS_ID.value] = lookup.get(
                    example[self.config.splitting_col], self.config.unknown_subset_id
                )
            return example

        # Resolve the hook lists once rather than re-scanning the config per example.
        before_hooks = [cfg["fct"] for cfg in self.config.custom_map_fcts if cfg["mode"] == "before"]
        after_hooks = [cfg["fct"] for cfg in self.config.custom_map_fcts if cfg["mode"] == "after"]
        needs_dataset_id = self.config.splitting_col is not None or self.config.dataset_id is not None

        def _map_fct(example, ex_idx, **kwargs):
            """Run the custom hooks and built-in annotation steps on one example."""
            for hook in before_hooks:
                example = hook(example, **kwargs)

            example = _add_ex_id(example, ex_idx, **kwargs)
            example = _parse_ex_chatml_attr(example, **kwargs)
            example = _add_content_len_estimate(example, **kwargs)
            if needs_dataset_id:
                example = _add_dataset_id(example, **kwargs)

            for hook in after_hooks:
                example = hook(example, **kwargs)

            return example

        ds = ds.map(
            _map_fct,
            with_indices=True,
            num_proc=self.config.num_proc,
            fn_kwargs={"subset_id_map": self.config.subset_id_map},
        )

        ds = self._add_len_buckets(ds)

        logger.info(f"Parse ds examples Done! ({ds_name}):\n{ds}")
        return ds

    def _add_len_buckets(self, ds: DatasetDict) -> DatasetDict:
        """Add a token-count bucket column to every split and drop singleton buckets.

        Each split gets a ``ClassLabel`` column holding the index of the bin
        its token-count estimate falls into. Within each split, examples whose
        bucket holds fewer than two examples are removed, since such buckets
        are not enough for a train/test split.

        Args:
            ds: Splits that already contain the token-count estimate column.

        Returns:
            The same ``DatasetDict`` with the bucket column added per split.
        """
        # Bin edges on the estimated token count; bucket i covers (bins[i], bins[i + 1]].
        bins = [0, 128, 512, 1024, 2048, 4096, 8192, float("inf")]
        labels = list(range(len(bins) - 1))
        token_col = CustomColName.TOKEN_CNT_ESTIMATE.value
        bucket_col = CustomColName.TOKEN_CNT_BUCKET.value

        for key in list(ds.keys()):
            # include_lowest=True puts a count of exactly 0 into the first bucket;
            # otherwise it falls outside every bin and becomes NaN.
            buckets = pd.cut(pd.Series(ds[key][token_col]), bins=bins, labels=labels, include_lowest=True)
            ds[key] = ds[key].add_column(
                bucket_col, buckets.tolist(), feature=ClassLabel(names=labels)
            )  # type: ignore

            # Check if there are buckets with less than 2 examples (not enough for train/test split)
            per_bucket_cnt = Counter(buckets)
            buckets_to_remove = [bucket for bucket, cnt in per_bucket_cnt.items() if cnt < 2]
            if buckets_to_remove:
                logger.info(f"Removing buckets with less than 2 examples: {buckets_to_remove}")
                keep_mask = ~buckets.isin(buckets_to_remove)
                indices_to_keep = buckets.index[keep_mask].tolist()
                _tmp_len = len(ds[key])
                ds[key] = ds[key].select(indices_to_keep)
                logger.info(f"Removed {(_tmp_len - len(ds[key]))} examples")

        return ds

    def _remove_unused_columns(self, ds):
        """Fold extra columns into the example metadata and drop columns not kept.

        Metadata keys are all columns starting with ``"_"`` that are neither
        the metadata column itself nor listed in ``config.cols_to_keep``, plus
        ``config.extra_metadata_cols``. Column names are read from the first
        split only.

        Args:
            ds: Dataset whose splits are trimmed.

        Returns:
            The mapped dataset, with every column outside ``config.cols_to_keep``
            passed to ``map`` as ``remove_columns``.
        """

        def _flatten_ex_metadata(example, *, metadata_keys: list):
            """Collect the metadata columns of one example and hand them to the metadata helper."""
            meta = {k: example[k] for k in metadata_keys}
            return update_example_metadata(example=example, metadata=meta)

        # Remove columns that are not needed
        ds_column_names = ds[list(ds.keys())[0]].column_names

        # Built once so the membership test does not rebuild a list per column.
        excluded_from_meta = {CustomColName.META.value, *self.config.cols_to_keep}
        metadata_keys = [
            c for c in ds_column_names if c.startswith("_") and c not in excluded_from_meta
        ] + list(self.config.extra_metadata_cols)

        # ``remove_columns`` is documented as str | list[str]; a sorted list also
        # gives a deterministic order.
        cols_to_remove = sorted(set(ds_column_names) - set(self.config.cols_to_keep))
        ds = ds.map(
            _flatten_ex_metadata,
            num_proc=self.config.num_proc,
            fn_kwargs={"metadata_keys": metadata_keys},
            remove_columns=cols_to_remove,
        )
        logger.info(f"Metadata keys: {metadata_keys}")

        return ds

    def _remove_flagged_examples(self, ds):
        """Drop examples whose flag column is truthy, then remove the flag column.

        Args:
            ds: Dataset (or dictionary of splits) containing the flag column.

        Returns:
            The filtered dataset without the flag column.
        """

        def _num_examples(dataset) -> int:
            """Count examples; ``len`` of a ``DatasetDict`` would count splits instead."""
            if isinstance(dataset, DatasetDict):
                return sum(len(split) for split in dataset.values())
            return len(dataset)

        _tmp_len = _num_examples(ds)
        ds = ds.filter(
            lambda x: not x[CustomColName.FLAGGED.value],
            num_proc=self.config.num_proc,
            desc="Filtering flagged examples",
        )
        ds = ds.remove_columns([CustomColName.FLAGGED.value])
        logger.info(f"Removed {(_tmp_len - _num_examples(ds))} flagged examples")

        return ds
