from typing import Any

from datasets import DatasetDict, concatenate_datasets

from src.data.chatml_utils import sort_chatml_msg_keys
from src.data.raw_data_loader import RawDataLoader
from src.data.utils import CustomColName, DatasetConfig
from src.utils.logging_utils import get_logger
from src.utils.pipeline_utils import PipelineStep

# Module-level logger, named after this module.
logger = get_logger(name=__name__)


class CreateDataMix(PipelineStep):
    """Pipeline step that merges several chat datasets into one train/test mix.

    Each configured dataset is loaded through ``RawDataLoader`` and validated to
    contain exactly one dataset id, the same one in its train and eval splits.
    Its chatml message keys are then sorted (the result replaces the
    ``messages`` column). All train splits are concatenated into ``train`` and
    all eval splits into ``test``. The resulting ``DatasetDict`` is optionally
    pushed to the Hugging Face Hub as a private dataset.
    """

    class Config(PipelineStep.Config):
        class Config:
            # Allow field types such as ``DatasetConfig`` that have no built-in
            # validation (pydantic-style inner Config; assumes PipelineStep.Config
            # is a pydantic model -- confirm).
            arbitrary_types_allowed = True

        # Token used when pushing the final dataset to the Hugging Face Hub.
        hf_token: str
        # Number of worker processes passed to ``map`` when sorting message keys.
        num_proc: int
        # Datasets to load and mix, in concatenation order.
        dataset_cfg_list: list[DatasetConfig]
        # Only ``name_or_path`` is used: it is the Hub repo the mix is pushed to.
        output_ds_config: DatasetConfig
        # When False, the mix is built and logged but not uploaded.
        push_to_hf_hub: bool = True

    def __init__(
        self,
        *,
        raw_data_loader: RawDataLoader,
        **kwargs,
    ) -> None:
        """Create the step.

        Args:
            raw_data_loader: Loader used to fetch each dataset in ``dataset_cfg_list``.
            **kwargs: Forwarded to ``PipelineStep`` and parsed into ``Config``.
        """
        super().__init__(**kwargs)
        # Parse kwargs into this step's Config so the fields declared above are
        # available as ``self.config.*``.
        self.config = self.Config(**kwargs)
        self.raw_data_loader = raw_data_loader

    def _load_and_prepare(self, ds_cfg: DatasetConfig) -> tuple[Any, Any, Any]:
        """Load one dataset, validate its dataset id and sort its chatml message keys.

        Args:
            ds_cfg: Configuration of the dataset to load.

        Returns:
            ``(train_split, eval_split, ds_id)``, where ``ds_id`` is the single
            dataset id shared by both splits.

        Raises:
            ValueError: If either split does not hold exactly one dataset id, or
                if the train and eval splits carry different ids.
        """
        ds = self.raw_data_loader.load(dataset_config=ds_cfg)

        # Sanity: every split must come from exactly one, identical, source dataset.
        # Explicit raises instead of ``assert`` so the check survives ``python -O``.
        ds_id_col = CustomColName.DS_ID.value
        train_ds_ids = set(ds[ds_cfg.train_split_name][ds_id_col])
        eval_ds_ids = set(ds[ds_cfg.eval_split_name][ds_id_col])
        if len(train_ds_ids) != 1 or len(eval_ds_ids) != 1:
            raise ValueError(
                f"Expected exactly one dataset id per split, got train={sorted(map(str, train_ds_ids))} "
                f"and eval={sorted(map(str, eval_ds_ids))} for dataset config {ds_cfg}"
            )
        if train_ds_ids != eval_ds_ids:
            raise ValueError(
                f"Train and eval splits have different dataset ids ({train_ds_ids} vs {eval_ds_ids}) "
                f"for dataset config {ds_cfg}"
            )

        # Sort chatml message keys
        ds = ds.map(sort_chatml_msg_keys, num_proc=self.config.num_proc, desc="Sorting chatml message keys")
        for split in ds.keys():
            # Map ignores reordering of messages keys, so the sorted messages are
            # written to a separate column and renamed back to "messages" here.
            ds[split] = ds[split].rename_column(CustomColName.REORDERED_MSGS.value, "messages")

        return ds[ds_cfg.train_split_name], ds[ds_cfg.eval_split_name], train_ds_ids.pop()

    def _call(self, **kwargs) -> Any:
        """Build the data mix and push it to the Hub if ``push_to_hf_hub`` is set.

        Args:
            **kwargs: Accepted for ``PipelineStep`` interface compatibility; unused.

        Raises:
            ValueError: If ``dataset_cfg_list`` is empty, or if a dataset fails the
                dataset-id sanity checks in ``_load_and_prepare``.
        """
        if not self.config.dataset_cfg_list:
            # concatenate_datasets would fail on an empty list anyway; fail early and clearly.
            raise ValueError("dataset_cfg_list is empty; there are no datasets to mix")

        train_datasets, eval_datasets = [], []
        seen_ds_ids = set()
        for ds_cfg in self.config.dataset_cfg_list:
            train_ds, eval_ds, ds_id = self._load_and_prepare(ds_cfg)
            if ds_id in seen_ds_ids:
                # Not fatal: configs may deliberately reuse a source dataset, but it is
                # usually a mistake worth surfacing.
                logger.warning(f"Dataset id {ds_id!r} appears in more than one dataset config of the mix")
            seen_ds_ids.add(ds_id)
            train_datasets.append(train_ds)
            eval_datasets.append(eval_ds)

        logger.info(f"Loaded {len(train_datasets)} train datasets and {len(eval_datasets)} eval datasets")

        # Combine datasets
        final_ds = DatasetDict(
            train=concatenate_datasets(train_datasets),
            test=concatenate_datasets(eval_datasets),
        )
        logger.info(f"Final dataset:\n{final_ds}")

        if self.config.push_to_hf_hub:
            final_ds.push_to_hub(
                self.config.output_ds_config.name_or_path,
                private=True,
                token=self.config.hf_token,
            )
            logger.info(f"Dataset {self.config.output_ds_config.name_or_path} pushed to hub")
